From 81cdc759e68002e01685017b99f91ea8613d7bae Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 22 May 2026 20:49:55 +0900 Subject: [PATCH] [Frontend] Fix index/i64 type mismatch in expert-mask codegen (issue #228) The second clause of the index-cast guard at mlir_ops.py:268 compared the whole [tile_size, dtype] list against the string "index", which is never true. With i64 on the lhs and index on the rhs the code fell through to the same-bit-width branch (MLIR_TO_BIT["i64"] == MLIR_TO_BIT["index"] == 64) and emitted arith.cmpi between vector and vector, which mlir-opt rejects. This blocked any MoE model whose (i64_buf == arange) expert-mask pattern landed with that operand orientation -- first observed on deepseek_v3. Fix replaces the dead clause with op_type2[1] == "index" so the operand2-side index_cast at lines 285-288 is reachable, normalizing the rhs to i64 before the cmpi. Add tests/test_expert_mask.py as a focused regression covering the (expert_idx_i64.unsqueeze(-1) == arange(N)) -> torch.where pattern. --- PyTorchSimFrontend/mlir/mlir_ops.py | 2 +- tests/test_expert_mask.py | 46 +++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 tests/test_expert_mask.py diff --git a/PyTorchSimFrontend/mlir/mlir_ops.py b/PyTorchSimFrontend/mlir/mlir_ops.py index 58e8b73b..217129e8 100644 --- a/PyTorchSimFrontend/mlir/mlir_ops.py +++ b/PyTorchSimFrontend/mlir/mlir_ops.py @@ -265,7 +265,7 @@ def binary_elementwise_common(operand1, operand2): # Data type check if op_type1[1] != op_type2[1]: - if op_type1[1] == "index" or op_type1 == "index": + if op_type1[1] == "index" or op_type2[1] == "index": if op_type1[1] == "index": # index -> target type: 2-step casting if target is float if op_type2[1][0] == "f": diff --git a/tests/test_expert_mask.py b/tests/test_expert_mask.py new file mode 100644 index 00000000..4d240206 --- /dev/null +++ b/tests/test_expert_mask.py @@ -0,0 +1,46 @@ +import torch +import torch._dynamo +import torch.utils.cpp_extension + + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + + +def test_expert_mask(device, batch=4, num_experts=8): + # Regression test for issue #228: + # (i64_buf == arange) was emitting arith.cmpi with mismatched + # vector / vector operands because the operand2-side + # index-cast branch in binary_elementwise_common was guarded by a + # typo'd condition. + def expert_mask(expert_idx, scores): + j = torch.arange(num_experts, device=expert_idx.device, dtype=torch.int64) + mask = expert_idx.unsqueeze(-1) == j.unsqueeze(0) + return torch.where(mask, scores, torch.zeros_like(scores)) + + expert_idx = torch.randint(0, num_experts, (batch,), dtype=torch.int64) + scores = torch.randn(batch, num_experts, dtype=torch.float32) + + cpu_out = expert_mask(expert_idx, scores) + + opt_fn = torch.compile(dynamic=False)(expert_mask) + npu_out = opt_fn(expert_idx.to(device=device), scores.to(device=device)) + + test_result("ExpertMask (i64 == arange)", npu_out, cpu_out) + + +if __name__ == "__main__": + device = torch.device("npu:0") + test_expert_mask(device)