
Commit 6cd8b2d

Fix issue in IPEX TP (#3409)
1 parent: a6f288b

3 files changed: +7, -3 lines

intel_extension_for_pytorch/transformers/optimize.py

Lines changed: 5 additions & 1 deletion
@@ -53,6 +53,7 @@ def lowering_class_cpu(m, target_m, new_class, config, tpp=False, woq=False):
 
 
 distributed = False
+is_deepspeed = False
 
 
 def is_distributed(m, ds_layers):
@@ -62,6 +63,8 @@ def is_distributed(m, ds_layers):
             ds_layers,
         ):
             global distributed
+            global is_deepspeed
+            is_deepspeed = True
             distributed = True
             return
         is_distributed(sub_m, ds_layers)
@@ -465,7 +468,8 @@ def model_convert_reference(_model):
     rank = ipex_comm.get_rank() if ipex_comm.has_ccl else 0
     if world_size > 1:
         global distributed
-        if distributed:
+        global is_deepspeed
+        if is_deepspeed:
             need_ipex_tp = False
         else:
             need_ipex_tp = True
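
Why the change: optimize.py previously reused the single module-level `distributed` flag both to record that a model is already sharded by DeepSpeed and to decide whether IPEX should apply its own tensor parallelism, so IPEX TP was skipped for any model marked distributed. The dedicated `is_deepspeed` flag narrows the skip to the DeepSpeed case. A minimal sketch of the resulting decision, with names simplified from the real module (not the exact IPEX code):

    # Module-level flags, mirroring optimize.py after this commit.
    distributed = False   # some sharded layout was detected
    is_deepspeed = False  # the sharding specifically came from DeepSpeed

    def need_ipex_tp(world_size: int) -> bool:
        # IPEX tensor parallelism is skipped only when DeepSpeed already
        # owns the sharding; a bare `distributed` no longer disables it.
        return world_size > 1 and not is_deepspeed

    print(need_ipex_tp(2))  # True unless is_deepspeed was set while scanning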

intel_extension_for_pytorch/transformers/tensor_parallel.py

Lines changed: 1 addition & 1 deletion
@@ -561,7 +561,7 @@ def shard_mha_weights(
                     rank,
                     world_size,
                     shard_by_head=True,
-                    value_with_share_qk=True,
+                    value_with_share_qk=value_with_share_qk,
                 )
                 # del sub_m.__dict__["_modules"][l_name]
                 setattr(sub_m, l_name, TPLinear)
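
Why the change: `shard_mha_weights` accepts a `value_with_share_qk` argument, but the inner sharding call hardcoded `value_with_share_qk=True`, silently overriding whatever the caller passed. The fix forwards the caller's value. A toy illustration of this bug class, with hypothetical names (not the IPEX API):

    def shard(shard_by_head, value_with_share_qk):
        return {"by_head": shard_by_head, "share_qk": value_with_share_qk}

    def wrapper_buggy(value_with_share_qk=False):
        # Bug: the caller's flag is ignored and always overridden.
        return shard(shard_by_head=True, value_with_share_qk=True)

    def wrapper_fixed(value_with_share_qk=False):
        # Fix: forward the caller's flag unchanged.
        return shard(shard_by_head=True, value_with_share_qk=value_with_share_qk)

    assert wrapper_buggy(False)["share_qk"] is True   # old behavior: wrong
    assert wrapper_fixed(False)["share_qk"] is False  # new behavior: respected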

tests/cpu/test_ipex_tensor_parallel.py

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ def tensor_parallel_with_optimize_transformers(self, model):
         input_dict["position_ids"] = position_ids.unsqueeze(0)
         ref_m = copy.deepcopy(model)
         for dtype in [torch.float32, torch.bfloat16]:
-            ipex_model = ipex.optimize_transformers(model, dtype=dtype)
+            ipex_model = ipex.llm.optimize(model, dtype=dtype)
             with torch.no_grad(), torch.cpu.amp.autocast(
                 enabled=True if dtype is torch.bfloat16 else False
             ):
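
Why the change: the test called the old entry point `ipex.optimize_transformers`, which has been superseded by `ipex.llm.optimize`; updating keeps the test on the current API. A hedged usage sketch of the renamed call (the model name is a placeholder, not from this commit):

    import torch
    import intel_extension_for_pytorch as ipex
    from transformers import AutoModelForCausalLM

    # Placeholder model; IPEX applies LLM-specific optimizations to
    # supported decoder architectures and falls back gracefully otherwise.
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model = ipex.llm.optimize(model, dtype=torch.bfloat16)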
