
Commit f1aa2de

E5M2 KVCache for Pure IPEX (#3375)
* add ut; ut pass w/o mul_attenion_weights_and_value_of_head optimization
* enable mul_attenion_weights_and_value_of_head optimization
* Change frontend interface partially and need to be continued
* Change beam_search related frontend interfaces
* refine code
* fix flake
* refine code
* Improve code style
* clean comment
* fix ut
* Compatible with old code
* improve code style
* get kv_cache_dtype by model.config
* temp remove duplicate deq
* enable e5m2 on deepspeed
* add assert
* delete kv_cache_dtype on ipex.llm.optimize interface
* improve code style; revert 6df5ed5007487b52b8366681746b3b22827e8940
* improve code style
1 parent 3e27750 commit f1aa2de
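
In outline, the change threads the new KV cache dtype through three layers: a --kv-cache-dtype flag in the example inference scripts, a kv_cache_dtype attribute attached to the Hugging Face model config, and dtype-aware initialization of the placeholder past_key_values in the beam/greedy generation frontends (plus the C++ MaskedMultiHeadAttention kernel, whose diff is not rendered below). A minimal standalone sketch of the flag-to-dtype mapping used by the scripts, assuming PyTorch 2.1+ where torch.float8_e5m2 is available; the real parsers define many more options:

import argparse

import torch

# Reduced to the single new option for illustration.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--kv-cache-dtype",
    type=str,
    choices=["auto", "fp8_e5m2"],
    default="auto",
    help='Data type for kv cache storage. If "auto", will use model data type.',
)
args = parser.parse_args(["--kv-cache-dtype", "fp8_e5m2"])

# "auto" maps to None (keep the model data type); "fp8_e5m2" selects the
# 8-bit float format with 5 exponent bits and 2 mantissa bits.
kv_cache_dtype = None if args.kv_cache_dtype == "auto" else torch.float8_e5m2
print(kv_cache_dtype)  # torch.float8_e5m2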

File tree

12 files changed: +1290, -129 lines changed


csrc/cpu/aten/kernels/MaskedMultiHeadAttentionKrnl.cpp

Lines changed: 387 additions & 58 deletions
Large diffs are not rendered by default.

examples/cpu/llm/inference/distributed/run_generation_with_deepspeed.py

Lines changed: 19 additions & 0 deletions
@@ -197,6 +197,17 @@
     help="Quantize weight symmetrically for weight only quantization. It usually brings better latency at"
     " the cost of accuracy. It has not effect if you are loading low-precision checkpoints.",
 )
+parser.add_argument(
+    "--kv-cache-dtype",
+    type=str,
+    choices=[
+        "auto",
+        "fp8_e5m2",
+    ],
+    default="auto",
+    help='Data type for kv cache storage. If "auto", will use model '
+    "data type. fp8 type now supports e5m2.",
+)
 parser.add_argument(
     "--low-precision-checkpoint",
     default="",
@@ -206,6 +217,7 @@
     " quantization with INT4 weight.",
 )

+
 args = parser.parse_args()


@@ -350,6 +362,13 @@ def get_checkpoint_files(model_name_or_path):
     config = AutoConfig.from_pretrained(
         args.config_file, torchscript=True, trust_remote_code=True
     )
+
+if args.kv_cache_dtype == "auto":
+    kv_cache_dtype = None
+elif args.kv_cache_dtype == "fp8_e5m2":
+    kv_cache_dtype = torch.float8_e5m2
+config.kv_cache_dtype = kv_cache_dtype
+
 if not hasattr(config, "text_max_length") and args.prompt is None:
     config.text_max_length = int(args.input_tokens) + int(args.max_new_tokens)
 if model_type == "mpt" and args.prompt is None:

examples/cpu/llm/inference/run.py

Lines changed: 13 additions & 0 deletions
@@ -300,6 +300,17 @@ def main(args_in: Optional[List[str]] = None) -> None:
         " In other cases, this feature is always turned on regardless of this argument and it does not"
         " conflict with the accuracy test.",
     )
+    parser.add_argument(
+        "--kv-cache-dtype",
+        type=str,
+        choices=[
+            "auto",
+            "fp8_e5m2",
+        ],
+        default="auto",
+        help='Data type for kv cache storage. If "auto", will use model '
+        "data type. fp8 type now supports e5m2.",
+    )
     args = parser.parse_args(args_in)

     parent_path = Path(__file__).parent.absolute()
@@ -335,6 +346,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         infer_cmd.extend(["--num-iter", str(args.num_iter)])
         infer_cmd.extend(["--num-warmup", str(args.num_warmup)])
         infer_cmd.extend(["--batch-size", str(args.batch_size)])
+        infer_cmd.extend(["--kv-cache-dtype", args.kv_cache_dtype])
         if args.vision_text_model:
             infer_cmd.extend(["--vision-text-model"])
         if args.greedy:
@@ -630,6 +642,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         infer_cmd.extend(["--num-iter", str(args.num_iter)])
         infer_cmd.extend(["--num-warmup", str(args.num_warmup)])
         infer_cmd.extend(["--batch-size", str(args.batch_size)])
+        infer_cmd.extend(["--kv-cache-dtype", args.kv_cache_dtype])
         if args.local_rank is not None:
             infer_cmd.extend(["--local_rank", str(args.local_rank)])
         if args.greedy:

examples/cpu/llm/inference/single_instance/run_generation.py

Lines changed: 18 additions & 0 deletions
@@ -105,6 +105,17 @@
     action="store_true",
     help="whether or not it is vision-text multi-model structure",
 )
+parser.add_argument(
+    "--kv-cache-dtype",
+    type=str,
+    choices=[
+        "auto",
+        "fp8_e5m2",
+    ],
+    default="auto",
+    help='Data type for kv cache storage. If "auto", will use model '
+    "data type. fp8 type now supports e5m2.",
+)

 args = parser.parse_args()
 print(args)
@@ -154,6 +165,13 @@
     trust_remote_code=True,
     torch_dtype=amp_dtype,
 )
+
+if args.kv_cache_dtype == "auto":
+    kv_cache_dtype = None
+elif args.kv_cache_dtype == "fp8_e5m2":
+    kv_cache_dtype = torch.float8_e5m2
+config.kv_cache_dtype = kv_cache_dtype
+
 if not hasattr(config, "text_max_length") and args.prompt is None:
     config.text_max_length = int(args.input_tokens) + int(args.max_new_tokens)
 if model_type == "mpt" and args.prompt is None:
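
All three example scripts resolve the flag the same way and then hang the result on the model config (config.kv_cache_dtype) rather than passing it as a separate argument; per the commit message, the kv_cache_dtype parameter was deliberately removed from the ipex.llm.optimize interface, so the config is the only carrier. The practical motivation is memory: each cached key/value element shrinks from 2 bytes (fp16/bf16) to 1 byte in E5M2. A rough back-of-the-envelope sketch, using a hypothetical Llama-2-7B-like shape (32 layers, 32 heads, head_dim 128) chosen purely for illustration:

# Hypothetical shape for illustration only; not taken from this commit.
num_layers, num_heads, head_dim = 32, 32, 128
batch, seq_len = 1, 4096


def kv_cache_bytes(bytes_per_elem: int) -> int:
    # Factor of 2 covers the key and the value tensors.
    return 2 * num_layers * batch * seq_len * num_heads * head_dim * bytes_per_elem


print(kv_cache_bytes(2) / 2**30)  # fp16/bf16 cache: 2.0 GiB
print(kv_cache_bytes(1) / 2**30)  # fp8 e5m2 cache:  1.0 GiB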

intel_extension_for_pytorch/transformers/generation/beam_sample.py

Lines changed: 36 additions & 10 deletions
@@ -178,6 +178,12 @@ def _beam_sample(
             "Maira2ForConditionalGeneration",
         ]:
             first_token = False
+            if hasattr(self.config, "kv_cache_dtype"):
+                kv_cache_dtype = self.config.kv_cache_dtype
+            elif hasattr(self, "dtype"):
+                kv_cache_dtype = self.dtype
+            else:
+                kv_cache_dtype = torch.float
             if model_inputs["past_key_values"] is None:
                 first_token = True
                 if self.model_backbone == "T5ForConditionalGeneration":
@@ -189,8 +195,12 @@
                         [
                             (
                                 torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
-                                torch.zeros([1, 1, 1, 1]).contiguous(),
-                                torch.zeros([1, 1, 1, 1]).contiguous(),
+                                torch.zeros([1, 1, 1, 1])
+                                .contiguous()
+                                .to(kv_cache_dtype),
+                                torch.zeros([1, 1, 1, 1])
+                                .contiguous()
+                                .to(kv_cache_dtype),
                                 beam_idx_tmp,
                                 torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
                                 self.decoder.block[i]
@@ -247,10 +257,14 @@
                                 torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
                                 torch.zeros(
                                     [int(batch_size * num_beams), num_head, 1, head_dim]
-                                ).contiguous(),
+                                )
+                                .contiguous()
+                                .to(kv_cache_dtype),
                                 torch.zeros(
                                     [int(batch_size * num_beams), num_head, 1, head_dim]
-                                ).contiguous(),
+                                )
+                                .contiguous()
+                                .to(kv_cache_dtype),
                                 beam_idx_tmp,
                             )
                             for i in range(self.config.num_hidden_layers)
@@ -265,8 +279,12 @@
                         [
                             (
                                 torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
-                                torch.zeros([1, 1, 1, 1]).contiguous(),
-                                torch.zeros([1, 1, 1, 1]).contiguous(),
+                                torch.zeros([1, 1, 1, 1])
+                                .contiguous()
+                                .to(kv_cache_dtype),
+                                torch.zeros([1, 1, 1, 1])
+                                .contiguous()
+                                .to(kv_cache_dtype),
                                 beam_idx_tmp,
                                 torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
                                 self.model.decoder.layers[i]
@@ -324,8 +342,12 @@
                                     torch.zeros(
                                         1, 0, 0, 1, dtype=torch.long
                                     ).contiguous(),
-                                    torch.zeros([1, 1, 1, 1]).contiguous(),
-                                    torch.zeros([1, 1, 1, 1]).contiguous(),
+                                    torch.zeros([1, 1, 1, 1])
+                                    .contiguous()
+                                    .to(kv_cache_dtype),
+                                    torch.zeros([1, 1, 1, 1])
+                                    .contiguous()
+                                    .to(kv_cache_dtype),
                                     beam_idx_tmp,
                                 )
                                 if i
@@ -343,8 +365,12 @@
                         [
                             (
                                 torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
-                                torch.zeros([1, 1, 1, 1]).contiguous(),
-                                torch.zeros([1, 1, 1, 1]).contiguous(),
+                                torch.zeros([1, 1, 1, 1])
+                                .contiguous()
+                                .to(kv_cache_dtype),
+                                torch.zeros([1, 1, 1, 1])
+                                .contiguous()
+                                .to(kv_cache_dtype),
                                 beam_idx_tmp,
                             )
                             for i in range(num_hidden_layers)

intel_extension_for_pytorch/transformers/generation/beam_search.py

Lines changed: 42 additions & 12 deletions
@@ -206,6 +206,12 @@ def _beam_search(
         ]:
             first_token = False
             has_position_id = model_inputs.get("position_ids", None) is not None
+            if hasattr(self.config, "kv_cache_dtype"):
+                kv_cache_dtype = self.config.kv_cache_dtype
+            elif hasattr(self, "dtype"):
+                kv_cache_dtype = self.dtype
+            else:
+                kv_cache_dtype = torch.float
             if model_inputs["past_key_values"] is None:
                 first_token = True
                 if self.model_backbone == "T5ForConditionalGeneration":
@@ -217,8 +223,12 @@
                         [
                             (
                                 torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
-                                torch.zeros([1, 1, 1, 1]).contiguous(),
-                                torch.zeros([1, 1, 1, 1]).contiguous(),
+                                torch.zeros([1, 1, 1, 1])
+                                .contiguous()
+                                .to(kv_cache_dtype),
+                                torch.zeros([1, 1, 1, 1])
+                                .contiguous()
+                                .to(kv_cache_dtype),
                                 beam_idx_tmp,
                                 torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
                                 self.decoder.block[i]
@@ -275,10 +285,14 @@
                                 torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
                                 torch.zeros(
                                     [int(batch_size * num_beams), num_head, 1, head_dim]
-                                ).contiguous(),
+                                )
+                                .contiguous()
+                                .to(kv_cache_dtype),
                                 torch.zeros(
                                     [int(batch_size * num_beams), num_head, 1, head_dim]
-                                ).contiguous(),
+                                )
+                                .contiguous()
+                                .to(kv_cache_dtype),
                                 beam_idx_tmp,
                             )
                             for i in range(self.config.num_hidden_layers)
@@ -293,8 +307,12 @@
                        [
                             (
                                 torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
-                                torch.zeros([1, 1, 1, 1]).contiguous(),
-                                torch.zeros([1, 1, 1, 1]).contiguous(),
+                                torch.zeros([1, 1, 1, 1])
+                                .contiguous()
+                                .to(kv_cache_dtype),
+                                torch.zeros([1, 1, 1, 1])
+                                .contiguous()
+                                .to(kv_cache_dtype),
                                 beam_idx_tmp,
                                 torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
                                 self.model.decoder.layers[i]
@@ -353,15 +371,23 @@
                                     torch.zeros(
                                         1, 0, 0, 1, dtype=torch.long
                                     ).contiguous(),
-                                    torch.zeros([1, 1, 1, 1]).contiguous(),
-                                    torch.zeros([1, 1, 1, 1]).contiguous(),
+                                    torch.zeros([1, 1, 1, 1])
+                                    .contiguous()
+                                    .to(kv_cache_dtype),
+                                    torch.zeros([1, 1, 1, 1])
+                                    .contiguous()
+                                    .to(kv_cache_dtype),
                                     beam_idx_tmp,
                                 )
                                 if i
                                 not in self.config.text_config.cross_attention_layers
                                 else (
-                                    torch.zeros([1, 1, 1, head_dim]).contiguous(),
-                                    torch.zeros([1, 1, 1, head_dim]).contiguous(),
+                                    torch.zeros([1, 1, 1, head_dim])
+                                    .contiguous()
+                                    .to(kv_cache_dtype),
+                                    torch.zeros([1, 1, 1, head_dim])
+                                    .contiguous()
+                                    .to(kv_cache_dtype),
                                 )
                             )
                             for i in range(num_hidden_layers)
@@ -372,8 +398,12 @@
                        [
                             (
                                 torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
-                                torch.zeros([1, 1, 1, 1]).contiguous(),
-                                torch.zeros([1, 1, 1, 1]).contiguous(),
+                                torch.zeros([1, 1, 1, 1])
+                                .contiguous()
+                                .to(kv_cache_dtype),
+                                torch.zeros([1, 1, 1, 1])
+                                .contiguous()
+                                .to(kv_cache_dtype),
                                 beam_idx_tmp,
                             )
                             for i in range(num_hidden_layers)

intel_extension_for_pytorch/transformers/generation/greedy_search.py

Lines changed: 36 additions & 14 deletions
@@ -170,6 +170,12 @@ def _greedy_search(
             "Maira2ForConditionalGeneration",
         ]:
             first_token = False
+            if hasattr(self.config, "kv_cache_dtype"):
+                kv_cache_dtype = self.config.kv_cache_dtype
+            elif hasattr(self, "dtype"):
+                kv_cache_dtype = self.dtype
+            else:
+                kv_cache_dtype = torch.float
             input_bs = input_ids.size()[0]
             if model_inputs["past_key_values"] is None:
                 first_token = True
@@ -182,8 +188,12 @@
                     [
                         (
                             torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
-                            torch.zeros([1, 1, 1, 1]).contiguous(),
-                            torch.zeros([1, 1, 1, 1]).contiguous(),
+                            torch.zeros([1, 1, 1, 1])
+                            .contiguous()
+                            .to(kv_cache_dtype),
+                            torch.zeros([1, 1, 1, 1])
+                            .contiguous()
+                            .to(kv_cache_dtype),
                             beam_idx_tmp,
                             torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
                             self.decoder.block[i]
@@ -232,8 +242,12 @@
                    [
                         (
                             torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
-                            torch.zeros([1, 1, 1, 1]).contiguous(),
-                            torch.zeros([1, 1, 1, 1]).contiguous(),
+                            torch.zeros([1, 1, 1, 1])
+                            .contiguous()
+                            .to(kv_cache_dtype),
+                            torch.zeros([1, 1, 1, 1])
+                            .contiguous()
+                            .to(kv_cache_dtype),
                             beam_idx_tmp,
                             torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
                             self.model.decoder.layers[i]
@@ -291,12 +305,12 @@
                    [
                         (
                             torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
-                            torch.zeros(
-                                [input_bs, num_head, 1, head_dim]
-                            ).contiguous(),
-                            torch.zeros(
-                                [input_bs, num_head, 1, head_dim]
-                            ).contiguous(),
+                            torch.zeros([input_bs, num_head, 1, head_dim])
+                            .contiguous()
+                            .to(kv_cache_dtype),
+                            torch.zeros([input_bs, num_head, 1, head_dim])
+                            .contiguous()
+                            .to(kv_cache_dtype),
                             beam_idx_tmp,
                         )
                         for i in range(num_hidden_layers)
@@ -314,8 +328,12 @@
                                 torch.zeros(
                                     1, 0, 0, 1, dtype=torch.long
                                 ).contiguous(),
-                                torch.zeros([1, 1, 1, 1]).contiguous(),
-                                torch.zeros([1, 1, 1, 1]).contiguous(),
+                                torch.zeros([1, 1, 1, 1])
+                                .contiguous()
+                                .to(kv_cache_dtype),
+                                torch.zeros([1, 1, 1, 1])
+                                .contiguous()
+                                .to(kv_cache_dtype),
                                 beam_idx_tmp,
                             )
                             if i
@@ -333,8 +351,12 @@
                    [
                         (
                             torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
-                            torch.zeros([1, 1, 1, 1]).contiguous(),
-                            torch.zeros([1, 1, 1, 1]).contiguous(),
+                            torch.zeros([1, 1, 1, 1])
+                            .contiguous()
+                            .to(kv_cache_dtype),
+                            torch.zeros([1, 1, 1, 1])
+                            .contiguous()
+                            .to(kv_cache_dtype),
                             beam_idx_tmp,
                         )
                         for i in range(num_hidden_layers)
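
The three generation frontends (_beam_sample, _beam_search, _greedy_search) all resolve the cache dtype with the same fallback chain: config.kv_cache_dtype if the attribute exists, else the model dtype, else float32, and then allocate the placeholder past_key_values tensors in that dtype. A condensed standalone sketch of the pattern; the helper names are mine, and the beam_idx_tmp shape is an assumption, not taken from this diff:

import torch


def resolve_kv_cache_dtype(model):
    # Mirrors the fallback added above; the extra None check only keeps the
    # sketch self-contained when "auto" left config.kv_cache_dtype as None.
    if getattr(model.config, "kv_cache_dtype", None) is not None:
        return model.config.kv_cache_dtype
    if hasattr(model, "dtype"):
        return model.dtype
    return torch.float


def make_dummy_layer_past(kv_cache_dtype, input_bs=1, num_head=1, head_dim=1):
    # One per-layer entry shaped like the tuples built in the diffs above:
    # (offset placeholder, key, value, beam index placeholder).
    beam_idx_tmp = torch.zeros(2048, dtype=torch.long)  # size is an assumption
    return (
        torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
        torch.zeros([input_bs, num_head, 1, head_dim]).contiguous().to(kv_cache_dtype),
        torch.zeros([input_bs, num_head, 1, head_dim]).contiguous().to(kv_cache_dtype),
        beam_idx_tmp,
    )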
