
Commit b24885d

Fix TPP scalar mul fusion issue (#3397)
1 parent aa0ea39 commit b24885d

File tree: 2 files changed (+10, -9 lines)


intel_extension_for_pytorch/transformers/models/reference/modules/decoder.py

Lines changed: 3 additions & 1 deletion
@@ -97,7 +97,9 @@ def MllamaVisionEncoderLayer_forward(
             hidden_state = self.mlp.fc2(hidden_state)
             hidden_state = self.gate_ffn.tanh() * hidden_state
         else:
-            hidden_state = self.mlp_linear_mul(hidden_state, self.gate_ffn.tanh())
+            hidden_state = self.mlp_linear_mul(
+                hidden_state, self.gate_ffn.tanh().expand_as(residual)
+            )
         hidden_state = residual + hidden_state
     else:
         if self.distributed:
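
For context, a minimal sketch of what the fix relies on. The shape requirement of the fused TPP linear+mul kernel is an assumption inferred from this change, and the shapes below are illustrative: the eager branch lets PyTorch broadcast the one-element gate, while the fused path is handed an explicit broadcast view via expand_as, which shape-matches the gate to the activation without copying data.

import torch

# Illustrative shapes; the real model uses its own hidden size.
gate_ffn = torch.nn.Parameter(torch.zeros(1))  # per-layer scalar gate
hidden_state = torch.randn(2, 16, 1280)        # (batch, seq, hidden)
residual = torch.randn_like(hidden_state)

# Eager branch: PyTorch broadcasts the 1-element gate automatically.
eager = gate_ffn.tanh() * hidden_state

# Fused branch (assumed requirement): the multiplier must already have
# the output's shape. expand_as returns a broadcast view, so the gate
# is shape-matched to residual without materializing a copy.
expanded_gate = gate_ffn.tanh().expand_as(residual)
assert expanded_gate.shape == residual.shape
assert torch.equal(expanded_gate * hidden_state, eager)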

tests/cpu/test_ipex_optimize_transformers_nightly.py

Lines changed: 7 additions & 8 deletions
@@ -195,14 +195,13 @@
         lambda m: m.model.layers[0].self_attn.__class__,
         lambda m: m.model.layers[0].__class__,
     ),
-    # TODO: uncomment when TPP issue is fixed
-    # model_info(
-    #     "mllama",
-    #     transformers.models.mllama.modeling_mllama.MllamaForConditionalGeneration,
-    #     True,
-    #     lambda m: m.language_model.model.layers[0].self_attn.__class__,
-    #     lambda m: m.language_model.model.layers[0].__class__,
-    # ),
+    model_info(
+        "mllama",
+        transformers.models.mllama.modeling_mllama.MllamaForConditionalGeneration,
+        True,
+        lambda m: m.language_model.model.layers[0].self_attn.__class__,
+        lambda m: m.language_model.model.layers[0].__class__,
+    ),
     model_info(
         "maira2",
         Maira2ForConditionalGeneration,
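
This hunk re-enables the previously commented-out mllama entry now that the TPP fusion issue is fixed. As a rough sketch of how such an entry might be consumed, the record layout and field names below are assumptions, not the test file's actual definition:

from collections import namedtuple

# Hypothetical record layout; field names are illustrative only.
model_info = namedtuple(
    "model_info",
    ["name", "model_class", "enabled", "attn_class_getter", "decoder_class_getter"],
)

def classes_after_optimize(info, optimized_model):
    # The two lambdas let the test look up the (possibly replaced)
    # attention and decoder-layer classes on the optimized model.
    return (
        info.attn_class_getter(optimized_model),
        info.decoder_class_getter(optimized_model),
    )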
