add scattermoe kernel for fast MoE training #40365
mayank31398 wants to merge 5 commits into huggingface:main
Conversation
ArthurZucker left a comment:
Hello @mayank31398! Nice PR, happy to add something like that. Do you mind using kernels like what we do for GPT_OSS?
This way we keep a slow path that is compatible with all torch versions and all hardware, we don't have code changes for the core modeling, and we just have the kernel on the Hub!
WDYT? 🤗
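For context, the pattern being referred to looks roughly like this — a minimal sketch, assuming the `kernels` decorator API; the layer name and module here are hypothetical, not the actual GPT_OSS code:

```python
# Sketch of the hub-kernel pattern: the class keeps a pure-PyTorch forward as
# the slow path, and kernelize() can later swap in a Hub kernel registered
# under the given name. Names below are illustrative.
import torch
from torch import nn
from kernels import use_kernel_forward_from_hub


@use_kernel_forward_from_hub("MyMoeMLP")  # hypothetical kernel layer name
class MyMoeMLP(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Slow path: plain PyTorch, runs on any device and any torch version.
        return self.proj(hidden_states)
```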
@ArthurZucker scattermoe doesn't support bias for now, I will add this soon!
[For maintainers] Suggested jobs to run (before merge): run-slow: granitemoe
Is there an existing triton kernel you could point to that I could follow?
This one: https://huggingface.co/kernels-community/megablocks/tree/main/torch-ext/megablocks (might not be triton), and otherwise https://huggingface.co/kernels-community/triton_kernels, which is fully triton!
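In that setup, pulling one of these kernels from the Hub is a one-liner with the `kernels` package; a minimal sketch using the repo ids linked above:

```python
# Sketch: fetch reference kernels from the Hub with the `kernels` package.
from kernels import get_kernel

megablocks = get_kernel("kernels-community/megablocks")
triton_kernels = get_kernel("kernels-community/triton_kernels")
```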
Sorry it took a while. I've tried to piece together the various guides for community kernels to package https://huggingface.co/shawntan/scattermoe. This is what I have so far, and the following seems to work:

```python
from pathlib import Path

import torch
from torch import nn

from kernels import (
    LocalLayerRepository,
    use_kernel_mapping,
    Mode,
    use_kernel_forward_from_hub,
    kernelize,
)
from transformers import AutoTokenizer, AutoConfig
from transformers.activations import ACT2FN
from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
    GraniteMoeHybridConfig,
    GraniteMoeHybridParallelExperts,
    GraniteMoeHybridTopKGating,
)


@use_kernel_forward_from_hub("ScatterMoEGatedMLP")
class GraniteMoeHybridMoE(nn.Module):
    """
    A sparsely gated mixture-of-experts layer with 1-layer feed-forward networks as experts.

    Args:
        config:
            Configuration object with model hyperparameters.
    """

    def __init__(self, config: GraniteMoeHybridConfig):
        super().__init__()
        self.input_size = config.hidden_size
        self.hidden_size = config.intermediate_size
        self.activation = ACT2FN[config.hidden_act]
        self.input_linear = GraniteMoeHybridParallelExperts(
            config.num_local_experts, self.input_size, self.hidden_size * 2
        )
        self.output_linear = GraniteMoeHybridParallelExperts(
            config.num_local_experts, self.hidden_size, self.input_size
        )
        self.router = GraniteMoeHybridTopKGating(
            input_size=self.input_size,
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
        )

    def forward(self, layer_input):
        """
        Forward pass of the mixture-of-experts layer.

        Args:
            layer_input (Tensor):
                Input tensor.

        Returns:
            Tensor:
                Output tensor.
            Tensor:
                Router logits.
        """
        bsz, length, emb_size = layer_input.size()
        layer_input = layer_input.reshape(-1, emb_size)
        # Route each token to its top-k experts: batch_index maps the grouped
        # expert inputs back to token positions, batch_gates are the routing weights.
        _, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input)

        expert_inputs = layer_input[batch_index]
        hidden_states = self.input_linear(expert_inputs, expert_size)
        # Gated activation: the up-projection produces 2x hidden_size, split into
        # an activated half and a gating half.
        chunked_hidden_states = hidden_states.chunk(2, dim=-1)
        hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]
        expert_outputs = self.output_linear(hidden_states, expert_size)

        # Weight each expert output by its gate and scatter-add back to token positions.
        expert_outputs = expert_outputs * batch_gates[:, None]
        zeros = torch.zeros(
            (bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device
        )
        layer_output = zeros.index_add(0, batch_index, expert_outputs)
        layer_output = layer_output.view(bsz, length, self.input_size)
        return layer_output, router_logits


model_path = "ibm-granite/granite-4.0-h-tiny-base"
device = torch.device("cuda")

kernel_layer_mapping = {
    "ScatterMoEGatedMLP": {
        "cuda": LocalLayerRepository(
            repo_path=Path("/u/shawntan/hf_scattermoe"),
            package_name="scattermoe",
            layer_name="ScatterMoEGatedMLP",
        )
        # "cuda": LayerRepository(
        #     repo_id="shawntan/scattermoe",
        #     layer_name="ScatterMoEGatedMLP",
        # )
    }
}
# scattermoe = get_kernel("shawntan/scattermoe")

config = AutoConfig.from_pretrained(model_path, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_path)

model = GraniteMoeHybridMoE(config).to(device)
for p in model.parameters():
    torch.nn.init.normal_(p, std=0.02)

x = torch.randn(4, 4096, 1536, device=device)
out_reference, _ = model(x)

# Swap in the ScatterMoE kernel and compare against the reference forward.
with use_kernel_mapping(kernel_layer_mapping):
    model = kernelize(model, mode=Mode.TRAINING)
out_kernel, _ = model(x)

print((out_reference - out_kernel).abs().max())
```

What further steps need to be done for submitting it to community kernels? And also to include it in the …
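For the Hub-based path, the mapping would presumably end up looking like the commented-out `LayerRepository` lines above once the kernel is published; a hedged sketch, assuming `shawntan/scattermoe` exposes a `ScatterMoEGatedMLP` layer:

```python
# Sketch: swap the local checkout for the published Hub repo (assumes
# shawntan/scattermoe is uploaded with a ScatterMoEGatedMLP layer, as in the
# commented-out lines in the snippet above).
from kernels import LayerRepository, Mode, kernelize, use_kernel_mapping

hub_mapping = {
    "ScatterMoEGatedMLP": {
        "cuda": LayerRepository(
            repo_id="shawntan/scattermoe",
            layer_name="ScatterMoEGatedMLP",
        )
    }
}

with use_kernel_mapping(hub_mapping):
    model = kernelize(model, mode=Mode.TRAINING)
```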
@ArthurZucker any thoughts?
Hey @shawntan, thanks for this contribution, the kernel looks good to me! We can have it in …
Alright! What do you need specifically? I can come up with some benchmarks for the Granite models.
MekkCyber left a comment:
Yes, that would be great! We need some latency and memory benchmarks with and without the kernel, for different sequence lengths, to see if we get speedups or improved memory consumption.
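A minimal sketch of such a benchmark, assuming a CUDA device and the MoE module from the snippet above (the harness itself is illustrative, not part of the PR):

```python
# Sketch: measure per-iteration latency and peak GPU memory for a module
# across sequence lengths, to compare the reference vs. kernelized forward.
import torch


def benchmark(module, batch=4, hidden=1536, seq_lens=(512, 2048, 8192), iters=10):
    for seq_len in seq_lens:
        x = torch.randn(batch, seq_len, hidden, device="cuda")
        torch.cuda.reset_peak_memory_stats()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        for _ in range(3):  # warmup
            module(x)
        torch.cuda.synchronize()
        start.record()
        for _ in range(iters):
            module(x)
        end.record()
        torch.cuda.synchronize()
        ms = start.elapsed_time(end) / iters
        peak_gib = torch.cuda.max_memory_allocated() / 2**30
        print(f"seq_len={seq_len}: {ms:.2f} ms/iter, peak {peak_gib:.2f} GiB")
```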
https://arxiv.org/pdf/2403.08245 — the ScatterMoE report has some benchmarks on previous models compared against MegaBlocks. It includes both speedup and memory-usage comparisons.
Very nice performance @shawntan! Thanks for sharing! So the kernel is used for training, not inference?
It can be used for both, but the gains will mainly come in training, or in cases of prefill. I've made the draft changes for …
Started another PR here, #41458, because it is significantly different from the current one.

No description provided.