Skip to content

make ScatterMoe more easily patchable for HF MoE modeling#82

Open
winglian wants to merge 1 commit intohuggingface:mainfrom
winglian:hf-moe
Open

make ScatterMoe more easily patchable for HF MoE modeling#82
winglian wants to merge 1 commit intohuggingface:mainfrom
winglian:hf-moe

Conversation

@winglian
Copy link

@winglian winglian commented Dec 5, 2025

This PR makes it easier to simply use something like the snippet below for a drop in replacement for all the MoEs based on qwen2 in transformers v5. The original ScatterMoEGatedMLP isn't quite usable since it relies on router instead of gate and input_linear, output_linear instead of gate_up_proj, down_proj in the v5 modeling.

register_kernel_mapping({
    "HFScatterMoEParallelExperts": {
        "cuda": {
            Mode.TRAINING: LayerRepository(
                repo_id="axolotl-ai-co/scattermoe",
                layer_name="HFScatterMoEGatedMLP",
            ),
            Mode.INFERENCE: LayerRepository(
                repo_id="axolotl-ai-co/scattermoe",
                layer_name="HFScatterMoEGatedMLP",
            ),
        },
    }
})
replace_kernel_forward_from_hub(Qwen2MoeSparseMoeBlock, "HFScatterMoEParallelExperts")

@winglian winglian requested a review from MekkCyber as a code owner December 5, 2025 14:45
@MekkCyber
Copy link
Contributor

cc @shawntan

@danieldk
Copy link
Member

danieldk commented Dec 5, 2025

Hi, @shawntan, any chance you could review this PR?

@winglian
Copy link
Author

winglian commented Dec 5, 2025

Here's a quick SFT test of olmoe.

Screenshot 2025-12-05 at 9 46 36 AM

@shawntan
Copy link
Contributor

shawntan commented Dec 8, 2025

Yeah this looks good and is a good idea, but have you tested it end-to-end with use_kernels=True?

I have an issue with another community kernel: #76

@winglian
Copy link
Author

yes, this was tested with use_kernels=True

@shawntan
Copy link
Contributor

shawntan commented Dec 12, 2025

Okay. I'm trying to test GraniteMoEHybrid with use_kernels=True to make sure these kernels work, but it seems the SiLU kernel does not work with non-contiguous tensors, and it seems both GraniteMoEHybrid and Qwen2MoE will be affected by this. See this comment: #76 (comment)

The decision seems to be to assert hidden_states.is_contiguous(). Which breaks both models, as far as I understand.

UPDATE: confirmed that it breaks Qwen2MoE as well.

@winglian Not fully sure how the error didn't show up when testing with use_kernels=True, but I'm interested to know how you get around it.

@MekkCyber FYI.

@MekkCyber
Copy link
Contributor

Hey @shawntan could you share a code snippet please ?

@winglian
Copy link
Author

I'm testing with transformers v5/main since that's really the only way these will work with MoE. Are you also using v5 or v4.57.x?

@shawntan
Copy link
Contributor

shawntan commented Dec 18, 2025

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import inspect


device = torch.device('cuda:0')
# model_path = "ibm-granite/granite-4.0-h-tiny"
model_path = "Qwen/Qwen1.5-MoE-A2.7B"
tokenizer = AutoTokenizer.from_pretrained(model_path)


model = AutoModelForCausalLM.from_pretrained(
    model_path,
    dtype=torch.bfloat16, device_map=device,
    use_kernels=True # comment and uncomment this to see the issue.
)

print("Model class path:", inspect.getfile(model.__class__))
# change input text as desired
chat = "Please list one IBM Research laboratory located in the United States. You should only output its name and location."
chat = [{ "role": "user", "content": chat }]
chat = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
print("Prompt:")
print(chat)
print()
# tokenize the text
input_tokens = tokenizer(chat, return_tensors="pt").to(device)
# generate output tokens
output = model.generate(**input_tokens, max_new_tokens=100)
# decode output tokens into text
output = tokenizer.batch_decode(output)
# print output
print("Output:")
print(output[0])

use_kernels=True:

use_kernel_func_from_hub is not available in the installed kernels version. Please upgrade kernels to use this feature.
Loading weights: 100%|█████████████████████████████████| 387/387 [00:10<00:00, 36.60it/s, Materializing param=model.norm.weight]
Fetching 7 files: 100%|█████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 3268.41it/s]
Fetching 17 files: 100%|██████████████████████████████████████████████████████████████████████| 17/17 [00:00<00:00, 4678.07it/s]
Download complete: : 0.00B [00:00, ?B/s] [00:00, ?B/s]
Download complete: : 0.00B [00:00, ?B/s]                                                                 | 0/17 [00:00<?, ?it/s]
Model class path: /proj/checkpoints/shawntan/transformers/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
Prompt:
<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Please list one IBM Research laboratory located in the United States. You should only output its name and location.<|im_end|>
<|im_start|>assistant


Output:
<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Please list one IBM Research laboratory located in the United States. You should only output its name and location.<|im_end|>
<|im_start|>assistant
866-5999.<|endoftext|>

use_kernels=False:

use_kernel_func_from_hub is not available in the installed kernels version. Please upgrade kernels to use this feature.
Loading weights: 100%|█████████████████████████████████| 387/387 [00:10<00:00, 36.75it/s, Materializing param=model.norm.weight]
Model class path: /proj/checkpoints/shawntan/transformers/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
Prompt:
<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Please list one IBM Research laboratory located in the United States. You should only output its name and location.<|im_end|>
<|im_start|>assistant


Output:
<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Please list one IBM Research laboratory located in the United States. You should only output its name and location.<|im_end|>
<|im_start|>assistant
You are a helpful assistant. Please list one IBM Research laboratory located in the United States. You should only output its name and location. I'm sorry, as an AI language model, I don't have access to current information. However, you can visit the IBM Research website to find a list of their laboratories located in the United States.<|endoftext|>

@winglian v5, but as of now I'm not sure if this error occurs only with generation workflow.

@sayakpaul
Copy link
Member

@danieldk anything pending for this merge?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants