make ScatterMoe more easily patchable for HF MoE modeling #82
winglian wants to merge 1 commit into huggingface:main from
Conversation
cc @shawntan
Hi @shawntan, any chance you could review this PR?
Yeah, this looks good and is a good idea, but have you tested it end-to-end? I have an issue with another community kernel: #76
yes, this was tested with
Okay. I'm trying to test GraniteMoEHybrid with
The decision seems to be to

UPDATE: confirmed that it breaks Qwen2MoE as well. @winglian Not fully sure how the error didn't show up when testing with

@MekkCyber FYI.
Hey @shawntan, could you share a code snippet please?
I'm testing with transformers v5/main since that's really the only way these will work with MoE. Are you also using v5 or v4.57.x? |
```python
import inspect

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device("cuda:0")

# model_path = "ibm-granite/granite-4.0-h-tiny"
model_path = "Qwen/Qwen1.5-MoE-A2.7B"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    dtype=torch.bfloat16,
    device_map=device,
    use_kernels=True,  # comment and uncomment this to see the issue.
)
print("Model class path:", inspect.getfile(model.__class__))

# change input text as desired
chat = "Please list one IBM Research laboratory located in the United States. You should only output its name and location."
chat = [{"role": "user", "content": chat}]
chat = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
print("Prompt:")
print(chat)
print()

# tokenize the text
input_tokens = tokenizer(chat, return_tensors="pt").to(device)
# generate output tokens
output = model.generate(**input_tokens, max_new_tokens=100)
# decode output tokens into text
output = tokenizer.batch_decode(output)
# print output
print("Output:")
print(output[0])
```
@winglian v5, but as of now I'm not sure if this error occurs only with the generation workflow.
@danieldk anything pending for this merge?

This PR makes it easier to use something like the snippet below as a drop-in replacement for all the MoEs based on Qwen2 in transformers v5. The original
`ScatterMoEGatedMLP` isn't quite usable since it relies on `router` instead of `gate`, and on `input_linear`/`output_linear` instead of `gate_up_proj`/`down_proj`, in the v5 modeling.
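To illustrate the attribute-name mismatch described above, here is a minimal sketch of one way a patchable wrapper could alias the v5-style names onto the ScatterMoE-style ones. The class bodies and attribute values are stand-ins invented for this example (the real modules hold `nn.Linear`/expert weights); only the attribute names `router`, `input_linear`, `output_linear`, `gate`, `gate_up_proj`, and `down_proj` come from the discussion above.

```python
# Hypothetical sketch: expose transformers-v5 MoE attribute names
# (`gate`, `gate_up_proj`, `down_proj`) as aliases of the
# ScatterMoE-style names (`router`, `input_linear`, `output_linear`).
# The classes and string values here are stand-ins, not the real kernel code.

class ScatterMoEGatedMLP:
    """Stand-in for the original kernel module (ScatterMoE-style names)."""

    def __init__(self):
        self.router = "router-weights"
        self.input_linear = "input-weights"
        self.output_linear = "output-weights"


class PatchableScatterMoEGatedMLP(ScatterMoEGatedMLP):
    """Alias the v5 modeling's attribute names onto the originals,
    so code written against the v5 naming finds what it expects."""

    @property
    def gate(self):
        return self.router

    @property
    def gate_up_proj(self):
        return self.input_linear

    @property
    def down_proj(self):
        return self.output_linear


mlp = PatchableScatterMoEGatedMLP()
print(mlp.gate)  # prints: router-weights
```

In practice the aliasing would need to cover parameter registration and state-dict loading as well, which is why renaming the attributes in the kernel itself (as this PR does) is the cleaner fix than wrapping.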