
Commit 344551f

fix bug to support DeepSpeed TP loading of INT4 checkpoints (#3396)

* fix bug.
* flake format.
1 parent a0c5475 commit 344551f
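
Both files apply the same fix: the (state_dict, quant_method) value handed to ipex.llm.optimize is now built and passed under a single variable, low_precision_checkpoint, which is initialized to None before any branching, instead of a separately named low_precision_dict that the diff suggests was not bound on every path reaching the optimize call. A minimal sketch of that scoping pattern, with placeholder names and values rather than the scripts' real ones:

    # Sketch only: illustrates the scoping pattern the fix addresses,
    # not the actual DeepSpeed/ipex example code.
    def buggy_pattern(woq_enabled: bool, checkpoint_path: str):
        if woq_enabled and checkpoint_path:
            low_precision_dict = ({"weight": [0]}, "quant_method")  # placeholder tuple
        # NameError if the branch above was not taken (e.g. WOQ disabled or no path given).
        return low_precision_dict

    def fixed_pattern(woq_enabled: bool, checkpoint_path: str):
        low_precision_checkpoint = None  # always bound before any branching
        if woq_enabled and checkpoint_path:
            low_precision_checkpoint = ({"weight": [0]}, "quant_method")
        return low_precision_checkpoint  # None, or (state_dict, quant_method)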

2 files changed, +11 -8 lines changed

examples/cpu/llm/inference/distributed/run_accuracy_with_deepspeed.py

Lines changed: 7 additions & 4 deletions
@@ -404,13 +404,13 @@ def write_checkpoints_json():
         self.model = self.model.module
         import pathlib

+        low_precision_checkpoint = None
         if args.low_precision_checkpoint != "":
             pathname = args.low_precision_checkpoint
             assert os.path.exists(
                 pathname
             ), f"Checkpoint file does not exist: {pathname}"
             if os.path.isfile(pathname):
-                low_precision_checkpoint = None
                 if pathname.endswith(".pt") or pathname.endswith(".pth"):
                     low_precision_checkpoint = torch.load(pathname, weights_only=True)
                 elif pathname.endswith(".safetensors"):
@@ -625,9 +625,12 @@ def write_checkpoints_json():
                         low_precision_checkpoint_dict[key] = data[
                             :, q_head_start * dim : q_head_end * dim
                         ]
-            low_precision_dict = (low_precision_checkpoint_dict, quant_method)
+            low_precision_checkpoint = (
+                low_precision_checkpoint_dict,
+                quant_method,
+            )
         else:
-            low_precision_dict = None
+            low_precision_checkpoint = None

         self.model = ipex.llm.optimize(
             self.model.eval(),
@@ -636,7 +639,7 @@ def write_checkpoints_json():
             inplace=True,
             deployment_mode=False,
             cache_weight_for_large_batch=args.cache_weight_for_large_batch,
-            low_precision_checkpoint=low_precision_dict,
+            low_precision_checkpoint=low_precision_checkpoint,
         )

         self.base_model = self.model

examples/cpu/llm/inference/distributed/run_generation_with_deepspeed.py

Lines changed: 4 additions & 4 deletions
@@ -489,14 +489,14 @@ def write_checkpoints_json():
 # to ipex
 if use_ipex:
     ipex_woq_enabled = args.ipex_weight_only_quantization
+    low_precision_checkpoint = None
     if ipex_woq_enabled:
         if args.low_precision_checkpoint != "":
             pathname = args.low_precision_checkpoint
             assert os.path.exists(
                 pathname
             ), f"Checkpoint file does not exist: {pathname}"
             if os.path.isfile(pathname):
-                low_precision_checkpoint = None
                 if pathname.endswith(".pt") or pathname.endswith(".pth"):
                     low_precision_checkpoint = torch.load(pathname, weights_only=True)
                 elif pathname.endswith(".safetensors"):
@@ -703,10 +703,10 @@ def write_checkpoints_json():
                         low_precision_checkpoint_dict[key] = data[
                             :, q_head_start * dim : q_head_end * dim
                         ]
-            low_precision_dict = (low_precision_checkpoint_dict, quant_method)
+            low_precision_checkpoint = (low_precision_checkpoint_dict, quant_method)

         else:
-            low_precision_dict = None
+            low_precision_checkpoint = None

     model = ipex.llm.optimize(
         model.eval(),
@@ -715,7 +715,7 @@ def write_checkpoints_json():
         inplace=True,
         deployment_mode=args.deployment_mode,
         cache_weight_for_large_batch=args.cache_weight_for_large_batch,
-        low_precision_checkpoint=low_precision_dict,
+        low_precision_checkpoint=low_precision_checkpoint,
     )
 # Generate
 print_rank0(f"*** Starting to generate {num_tokens} tokens with bs={args.batch_size}")
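
For orientation, a condensed sketch of the corrected flow in run_generation_with_deepspeed.py, pieced together from the hunks above (run_accuracy_with_deepspeed.py follows the same pattern). The tensor-parallel head slicing that fills low_precision_checkpoint_dict and sets quant_method sits between the hunks and is not shown in this diff, so it is represented by a hypothetical shard_checkpoint_for_tp helper; .safetensors and directory checkpoints are likewise elided:

    import os

    import torch


    def prepare_low_precision_checkpoint(args, ipex_woq_enabled, shard_checkpoint_for_tp):
        # Bind the variable up front so the optimize call downstream never sees it undefined.
        low_precision_checkpoint = None
        if ipex_woq_enabled and args.low_precision_checkpoint != "":
            pathname = args.low_precision_checkpoint
            assert os.path.exists(pathname), f"Checkpoint file does not exist: {pathname}"
            if os.path.isfile(pathname) and pathname.endswith((".pt", ".pth")):
                low_precision_checkpoint = torch.load(pathname, weights_only=True)
            # shard_checkpoint_for_tp stands in for the elided tensor-parallel slicing
            # that produces the per-rank state dict and the quantization method name.
            low_precision_checkpoint_dict, quant_method = shard_checkpoint_for_tp(
                low_precision_checkpoint
            )
            low_precision_checkpoint = (low_precision_checkpoint_dict, quant_method)
        return low_precision_checkpoint

The returned value, None or a (state_dict, quant_method) tuple, is what both scripts now pass as low_precision_checkpoint= to ipex.llm.optimize.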

0 commit comments
