
feat: moe kernel tuning #482

Open

llcnt wants to merge 21 commits into main from feat/moe_kernel_tuning

Conversation

@llcnt
Collaborator

@llcnt llcnt commented Dec 23, 2025

Description

This PR is inspired by vLLM benchmarks (the benchmark_config fn is copied from here) and enables tuning the MoE (Triton) kernel used in vLLM.
The new MoeKernelTuner algorithm does not modify the model. It generates a tuned configuration (a sketch of the resulting config file is shown at the end of this description) that is saved in:

  • the vLLM configs folder (so that using the model on the same GPU afterwards makes vLLM use this optimized config);
  • the RedhatAI kernel folder in the HF cache (so that using the MoE kernels from the kernels lib will make use of the optimized config);
  • a moe_kernel_tuned_configs folder in the model directory (to be re-used later, without waiting for tuning, when loading the model with pruna).

The core modifications are in:

  • the new moe_kernel_tuner.py file: (i) it does not modify the model, so it is compatible with every other algorithm before/after; (ii) the user can select dtypes as well as the size of the parameter grid search; (iii) the kernel is tuned for batch sizes (i.e. the input dimension M) from 1 to 8192, using ray for parallelization; (iv) the best configurations are saved in the HF and vLLM caches (so that after smashing, both caches are already populated with optimal configs the user can use) and in the pruna cache (similar to what we do with save_before_apply);
  • the save_artifacts.py file (we move the tuned config from the pruna cache to the saved path);
  • the load_artifacts.py file (for re-saving the tuned config inside vllm/hf cache when loading a smashed model).
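
For illustration, below is a minimal sketch of what a tuned configuration file could look like, following vLLM's fused-MoE config conventions; the filename, parameter values, and path are assumptions made for this sketch, not necessarily what the PR writes.

import json
from pathlib import Path

# One entry per tuned batch size M, mapping to Triton launch parameters.
best_configs = {
    "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128,
          "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
    "8192": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64,
             "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
}

# vLLM looks fused-MoE configs up by expert count (E), shard intermediate size (N)
# and GPU name; the PR mirrors the same content into the HF kernels cache and the
# pruna cache so every consumer finds it.
path_to_vllm_configs = Path("vllm_moe_configs")  # placeholder path
filename = "E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json"
path_to_vllm_configs.mkdir(parents=True, exist_ok=True)
with open(path_to_vllm_configs / filename, "w") as f:
    json.dump(best_configs, f, indent=4)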

Related Issue

Fixes #(issue number)

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Additional Notes

A notebook for testing with vLLM is available here. On an H100 with Qwen3-Coder-30B, latency goes from 6.43 ms (before tuning) to 5.83 ms (after tuning) when using vLLM.

@llcnt llcnt force-pushed the feat/moe_kernel_tuning branch from 78c6657 to 5764274 Compare December 23, 2025 17:02
@github-actions

github-actions bot commented Jan 6, 2026

This PR has been inactive for 10 days and is now marked as stale.

@github-actions github-actions bot added the stale label Jan 6, 2026
@llcnt llcnt removed the stale label Jan 7, 2026
@llcnt llcnt marked this pull request as ready for review January 7, 2026 14:21
@github-actions

This PR has been inactive for 10 days and is now marked as stale.

@github-actions github-actions bot added the stale label Jan 18, 2026
@llcnt llcnt mentioned this pull request Jan 19, 2026
10 tasks
@github-actions github-actions bot closed this Jan 25, 2026
@llcnt llcnt removed the stale label Jan 26, 2026
@llcnt llcnt reopened this Jan 26, 2026
@llcnt llcnt force-pushed the feat/moe_kernel_tuning branch 2 times, most recently from 5d949a3 to a99e964 Compare February 9, 2026 14:52
@llcnt llcnt force-pushed the feat/moe_kernel_tuning branch from 89b9bca to e779dbb Compare February 9, 2026 15:09
@llcnt
Collaborator Author

llcnt commented Feb 9, 2026

bugbot run

@cursor cursor bot left a comment

Cursor Bugbot has reviewed your changes and found 4 potential issues.

Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.

Comment @cursor review or bugbot run to trigger another review on this PR

Member

@sharpenb sharpenb left a comment

The overall structure is clear.

  • I did not check if the detailed functions could be factorized differently for more compact code.

if is_moe_lm(model)
else model_config.moe_topk[0]
)
# qwen_moe can use different intermediate size compared to mixtral.
Member

What do we mean by intermediate size?

Collaborator

I agree that this is what the comment should explain instead

batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

# use ray to parallelize the tuning
ray.init(ignore_reinit_error=True)
Member

The utilisation of ray makes me think of the optimization agent we worked on.

Collaborator

@gsprochette gsprochette left a comment

Super nice feature, I have basically no comments on the content but left comments on the form. I left close to no comments on the form of benchmark_config since it is imported; I get the value of keeping it as is.

I do have 4 questions/suggestions about the general structure of the code:

  1. ray is not declared as a dependency; should it be in an extra, e.g. vllm? Could we import it inside import_algorithm_packages by isolating everything except MoeKernelTuner(PrunaAlgorithmBase) in a utils.py and importing it only in the import_algorithm_packages method?
  2. Should the tuning be done again if the model is loaded on a setup with a different triton version? If so, we can use the reapply saving function and check at the beginning of apply whether the artifact already exists and matches the setup, in which case we skip the tuning, but still tune otherwise.
  3. The _apply method is very long. Below are some suggestions for splitting/simplifying it and making it more readable and more type-friendly.
  4. The moe_kernel_tuner.py file is very long; the utils split from question 1 would also make it lighter, WDYT?

For (3), in the _apply method, I think the code should be made clearer. Currently most of the logic is a series of if..else branches checking whether we are in the general is_moe_lm case or the HunyuanImage3ForCausalMM exception, and extracting hyperparameters: nb_experts, topk, intermediate_size, hidden_size, shard_intermediate_size.
I think it would be clearer if:

  • in _apply: check for HunyuanImage3ForCausalMM and call _extract_hunyuan_dimensions, whose output is nb_experts, shard_intermediate_size, hidden_size and topk; in the general case call _extract_transformers_moe_dimensions, which has the same output (see the sketch below);
  • in each of these functions: get the config and make an actual typing check so we know the attributes exist. The docstring of these functions, or the comment in _apply when calling them, can explain what these variables represent in the MoE operations.
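
A rough sketch of the suggested split; the MoeDimensions container and the config attribute names are illustrative assumptions, not the actual pruna/transformers API:

from typing import NamedTuple


class MoeDimensions(NamedTuple):
    """Dimensions needed to tune the fused-MoE kernel."""

    nb_experts: int               # number of routed experts
    topk: int                     # experts activated per token
    hidden_size: int              # model hidden dimension
    shard_intermediate_size: int  # per-shard intermediate dimension of the experts


def _extract_transformers_moe_dimensions(model) -> MoeDimensions:
    """General case: read the MoE dimensions from a transformers-style config."""
    config = getattr(model, "config", None)
    if config is None or not hasattr(config, "num_experts"):
        raise ValueError(f"{model.__class__.__name__} does not expose MoE dimensions.")
    return MoeDimensions(
        nb_experts=config.num_experts,
        topk=config.num_experts_per_tok,
        hidden_size=config.hidden_size,
        # 2x because the gate and up projections are fused in the Triton kernel.
        shard_intermediate_size=2 * config.moe_intermediate_size,
    )

_extract_hunyuan_dimensions would have the same return type, so _apply only dispatches on the model class and the rest of the tuning code stays branch-free.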

from .base_tester import AlgorithmTesterBase


class TestMoeKernelTuner(AlgorithmTesterBase):
Collaborator

I'm not sure that AlgorithmTesterBase is really able to test this algorithm, since it's not changing the model. Two options that I would be satisfied with:

  • the minimal option would be to manually smash the model and check that the artifacts exist (see the sketch after this list);
  • the more advanced option would be to implement a post-smash hook checking for the artifacts, and another hook after loading to make sure the artifact_loader reloaded the missing config files (e.g. by deleting the artifacts not in the saved_model_dir and checking that they are restored after loading).
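
For the minimal option, something along these lines could do; the fixture, the SmashConfig arguments, and the way the algorithm is enabled are placeholders/assumptions:

from pathlib import Path

from pruna import SmashConfig, smash


def test_moe_kernel_tuner_writes_artifacts(tmp_path, small_moe_model):  # hypothetical fixture
    smash_config = SmashConfig(cache_dir=tmp_path)    # cache_dir argument is an assumption
    smash_config["moe_kernel_tuner"] = True           # hypothetical way of enabling the algorithm
    smash(model=small_moe_model, smash_config=smash_config)

    # The artifact the algorithm stores in the pruna cache for later saving/loading.
    tuned_config = Path(tmp_path) / "moe_kernel_tuned_configs" / "moe_kernel_tuner.json"
    assert tuned_config.exists()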

"imageio-ffmpeg",
"jaxtyping",
"peft>=0.17.1",
"vllm>=0.11.0",
Collaborator

Should vllm be a main dependency? Could we add it into a vllm extra and add it to the list of extras in the CI?

processor_required: bool = False
runs_on: list[str] = ["cuda", "accelerate"]
dataset_required: bool = False
compatible_before: Iterable[str] = [
Collaborator

Can we get this list of tags as the result of a function call on the tags so it's automatically expanded if we add new tags? Something like the sketch below, for instance.
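
Purely illustrative; the registry and the modifies_model attribute are assumptions, not actual pruna internals:

def non_model_modifying_algorithm_names(registry: dict[str, type]) -> list[str]:
    """All registered algorithms that leave the model untouched are compatible before this one."""
    return [
        name
        for name, algorithm_cls in registry.items()
        if not getattr(algorithm_cls, "modifies_model", True)
    ]


# compatible_before: Iterable[str] = non_model_modifying_algorithm_names(ALGORITHM_REGISTRY)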


class MoeKernelTuner(PrunaAlgorithmBase):
    """
    Tune the MoE Triton kernel for the model.
Collaborator

Maybe this is a good place to state that MoE stands for Mixture of Experts

model_config = model.config
if model_config is None:
    raise ValueError(f"Model {model.__class__.__name__} has no config.")
nb_experts = model_config.num_experts  # number of experts
Collaborator

The comment is uninformative and incorrectly formatted, can we remove it?

Collaborator

For typing reasons, there should either be a protection nb_experts = getattr(model_config, "num_experts", None) followed by a check if nb_experts is None, or better if available, the check above should verify the type of model_config with isinstance(model_config, BaseConfigClass), where BaseConfigClass declares the num_experts attribute.

Comment on lines +791 to +796
filename_vllm = path_to_vllm_configs / filename
if not pathlib.Path(filename_vllm).exists():
    pruna_logger.info(f"Writing best config to {filename_vllm}...")
    with open(filename_vllm, "w") as f:
        json.dump({"triton_version": imported_packages["triton"].__version__, **configs}, f, indent=4)
        f.write("\n")
Collaborator

This block feels redundant with the first one; can we factorize this, or are there specificities I'm not getting?

Comment on lines +504 to +508
keys, values = zip(*param_ranges.items())
for config_values in product(*values):
    config = dict(zip(keys, config_values))
    configs.append(config)

Collaborator

Suggested change
- keys, values = zip(*param_ranges.items())
- for config_values in product(*values):
-     config = dict(zip(keys, config_values))
-     configs.append(config)
+ configs = [
+     dict(zip(param_ranges.keys(), config_values))
+     for config_values in product(*param_ranges.values())
+ ]

imported_packages = MoeKernelTuner().import_algorithm_packages()
save_dir = Path(model_path) / "moe_kernel_tuned_configs"
with open(save_dir / "moe_kernel_tuner.json") as f:
    payload = json.load(f)
Collaborator

Why payload? Could the name be more explicit?

Comment on lines +267 to +279
save_configs(
    best_configs,
    nb_experts,
    shard_intermediate_size,
    dtype,
    use_fp8_w8a8,
    use_int8_w8a16,
    None,
    smash_config["path_to_huggingface_hub_cache"],
    smash_config["path_to_vllm_cache"],
    imported_packages,
)
# results to be saved for later loading
Collaborator

Can the payload name be improved?

# store artifacts in pruna cache for later saving/loading
save_dir = smash_config.cache_dir / "moe_kernel_tuned_configs"
save_dir.mkdir(parents=True, exist_ok=True)
with open(save_dir / "moe_kernel_tuner.json", "w") as f:
Collaborator

Why not place it directly in smash_config.cache_dir if only one file will be contained in moe_kernel_tuned_configs?
