Conversation
This PR has been inactive for 10 days and is now marked as stale.
bugbot run
Cursor Bugbot has reviewed your changes and found 4 potential issues.
sharpenb left a comment
The overall structure is clear.
- I did not check whether the detailed functions could be factorized differently for more compact code.
```python
if is_moe_lm(model)
else model_config.moe_topk[0]
)
# qwen_moe can use different intermediate size compared to mixtral.
```
What do we mean by intermediate size?
I agree that this is what the comment should explain instead
```python
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

# use ray to parallelize the tuning
ray.init(ignore_reinit_error=True)
```
The utilisation of ray makes me think of the optimization agent we worked on.
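For readers skimming the thread, here is a minimal sketch of the fan-out pattern the snippet above points at. `tune_for_batch_size` is a hypothetical stand-in for the actual tuning routine in moe_kernel_tuner.py; only the ray fan-out/fan-in is illustrated.

```python
# Hedged sketch only: `tune_for_batch_size` is a placeholder, not the PR's code.
import ray

ray.init(ignore_reinit_error=True)

batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]


@ray.remote(num_gpus=1)  # one GPU per tuning task, assumed for illustration
def tune_for_batch_size(batch_size: int) -> dict:
    """Placeholder: benchmark candidate kernel configs for one batch size."""
    return {"batch_size": batch_size, "best_config": {}}


# Launch one remote task per batch size and gather the results.
futures = [tune_for_batch_size.remote(bs) for bs in batch_sizes]
results = ray.get(futures)
```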
gsprochette left a comment
Super nice feature! I have basically no comment on the content but left comments on the form. I left close to no comments on the form of benchmark_config since it is imported; I get the value of keeping it as is.
I do have 4 questions/suggestions about the general structure of the code:
- ray is not declared as a dependency; should it be in an extra, e.g. vllm? Could we import it inside import_algorithm_packages, by isolating everything except the MoeKernelTuner(PrunaAlgorithmBase) class in a utils.py and importing it only in the import_algorithm_packages method?
- Should the tuning be done again if the model is loaded on a setup with a different triton version? If so, we can use the reapply saving function and check at the beginning of apply whether the artifact already exists and matches the setup, in which case we skip the tuning, but still tune otherwise.
- The _apply method is very long. Below are some suggestions for splitting/simplifying it to make it more readable and more type-friendly.
- The moe_kernel_tuner.py file is very long; the utils split in question 1 would also make it lighter, WDYT?
For (i) in the _apply method, I think the code should be made clearer. Currently most of the logic is a series of if..else blocks checking whether we are in the general is_moe_lm case or the HunyuanImage3ForCausalMM exception, and extracting hyperparameters: nb_experts, topk, intermediate_size, hidden_size, shard_intermediate_size.
I think it would be clearer if:
- in (i): a check for HunyuanImage3ForCausalMM -> call _extract_hunyuan_dimensions, whose output is nb_experts, shard_intermediate_size, hidden_size and topk; and in the general case call _extract_transformers_moe_dimensions, which has the same output (sketched in code below);
- in each of these functions, get the config and make an actual typing check so we know the attributes exist. The docstring of these functions, or the comment in (i) when collecting these values, can explain what these different variables represent in the MoE operations.
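To make the proposal concrete, a hedged sketch of the two helpers described above; every config attribute name and the tp_size handling are assumptions for illustration, only the output tuple (nb_experts, shard_intermediate_size, hidden_size, topk) follows the comment.

```python
# Sketch only: attribute names and the tp_size sharding are assumptions,
# not the PR's actual implementation.
from typing import Any


def _extract_transformers_moe_dimensions(model_config: Any, tp_size: int = 1) -> tuple[int, int, int, int]:
    """Return (nb_experts, shard_intermediate_size, hidden_size, topk) for the general is_moe_lm case."""
    nb_experts = getattr(model_config, "num_experts", None)
    topk = getattr(model_config, "num_experts_per_tok", None)  # the PR also reads moe_topk[0] in some cases
    intermediate_size = getattr(model_config, "moe_intermediate_size", None)
    hidden_size = getattr(model_config, "hidden_size", None)
    if None in (nb_experts, topk, intermediate_size, hidden_size):
        raise ValueError(f"{type(model_config).__name__} is missing expected MoE attributes.")
    shard_intermediate_size = intermediate_size // tp_size  # sharding convention assumed here
    return nb_experts, shard_intermediate_size, hidden_size, topk


def _extract_hunyuan_dimensions(model_config: Any, tp_size: int = 1) -> tuple[int, int, int, int]:
    """Same output for the HunyuanImage3ForCausalMM exception; its config fields would differ."""
    raise NotImplementedError("Hypothetical helper; mirror the structure above with Hunyuan's config fields.")
```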
```python
from .base_tester import AlgorithmTesterBase


class TestMoeKernelTuner(AlgorithmTesterBase):
```
I'm not sure that AlgorithmTesterBase is really able to test this algorithm, since it's not changing the model. Two options that I would be satisfied with:
- the minimal option would be to manually smash the model and check that the artifacts exist (a sketch of this option follows below);
- the more advanced option would be to implement a post-smash hook checking for the artifacts, and another hook after loading to make sure the artifact_loader reloaded the missing config files (e.g. by deleting the artifacts not in the saved_model_dir and checking they are restored after loading).
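A sketch of the minimal option, assuming pruna's documented smash entry point and a placeholder MoE checkpoint; the "kernel" group key is hypothetical, and only the artifact path mirrors the cache location discussed later in this review.

```python
# Minimal-option sketch: smash a small MoE model and assert the tuned-config
# artifact was written. The group key and the checkpoint are placeholders.
from pruna import SmashConfig, smash
from transformers import AutoModelForCausalLM


def test_moe_kernel_tuner_writes_artifacts():
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B")  # placeholder MoE model

    smash_config = SmashConfig()
    smash_config["kernel"] = "moe_kernel_tuner"  # hypothetical group/name mapping

    smash(model=model, smash_config=smash_config)

    # Path mirrors the smash_config.cache_dir usage shown further down in this PR.
    artifact = smash_config.cache_dir / "moe_kernel_tuned_configs" / "moe_kernel_tuner.json"
    assert artifact.exists(), "tuning artifact was not written to the pruna cache"
```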
| "imageio-ffmpeg", | ||
| "jaxtyping", | ||
| "peft>=0.17.1", | ||
| "vllm>=0.11.0", |
Should vllm be a main dependency? Could we add it into a vllm extra and add it to the list of extras in the CI?
```python
processor_required: bool = False
runs_on: list[str] = ["cuda", "accelerate"]
dataset_required: bool = False
compatible_before: Iterable[str] = [
```
can we get this list of tags as the result of a function/call on tags so it's automatically expanded if we add new tags?
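One possible shape for that, sketched with a hypothetical registry and per-class tags attribute (neither is claimed to be pruna's actual API):

```python
# Hypothetical sketch: derive compatible_before from a registry of algorithm
# classes and their tags instead of hard-coding names. ALGORITHM_REGISTRY and
# the per-class `tags` attribute are assumptions, not pruna internals.
def algorithms_with_tags(registry: dict[str, type], wanted: set[str]) -> list[str]:
    """Collect the names of registered algorithms that carry any of the wanted tags."""
    return [
        name
        for name, algorithm_cls in registry.items()
        if wanted & set(getattr(algorithm_cls, "tags", ()))
    ]


# compatible_before: Iterable[str] = algorithms_with_tags(ALGORITHM_REGISTRY, {"quantizer", "pruner"})
```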
```python
class MoeKernelTuner(PrunaAlgorithmBase):
    """
    Tune the MoE Triton kernel for the model.
```
Maybe this is a good place to state that MoE stands for Mixture of Experts
```python
model_config = model.config
if model_config is None:
    raise ValueError(f"Model {model.__class__.__name__} has no config.")
nb_experts = model_config.num_experts  # number of experts
```
The comment is uninformative and incorrectly formatted, can we remove it?
For typing reasons, there should be either a protection nb_experts = getattr(model_config, "num_experts", None) followed by a check if nb_experts is None, or better, if available, the check above should verify the type of model_config with isinstance(model_config, BaseConfigClass), where BaseConfigClass declares the num_experts attribute. Both options are sketched below.
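A short sketch of the two protections; BaseMoEConfig is a stand-in for whichever config base class actually declares num_experts.

```python
# Sketch of both options suggested above; `BaseMoEConfig` is a placeholder
# for the real config base class (BaseConfigClass in the comment).
from typing import Any


class BaseMoEConfig:
    num_experts: int


def get_nb_experts(model_config: Any) -> int:
    """Option 1: defensive getattr followed by an explicit None check."""
    nb_experts = getattr(model_config, "num_experts", None)
    if nb_experts is None:
        raise ValueError(f"{type(model_config).__name__} does not define num_experts.")
    return nb_experts


def get_nb_experts_typed(model_config: Any) -> int:
    """Option 2: narrow the type first so static checkers know the attribute exists."""
    if not isinstance(model_config, BaseMoEConfig):
        raise TypeError(f"Expected a MoE config, got {type(model_config).__name__}.")
    return model_config.num_experts
```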
```python
filename_vllm = path_to_vllm_configs / filename
if not pathlib.Path(filename_vllm).exists():
    pruna_logger.info(f"Writing best config to {filename_vllm}...")
    with open(filename_vllm, "w") as f:
        json.dump({"triton_version": imported_packages["triton"].__version__, **configs}, f, indent=4)
        f.write("\n")
```
This block feels redundant with the first one; can we factorize it, or are there specificities I'm not getting? A possible factorization is sketched below.
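If there are no hidden specificities, one way to factorize the two cache-writing blocks could look like this sketch; the helper name and the commented call sites are assumptions.

```python
# Sketch only: a single helper reused for both the huggingface and vllm cache
# locations, instead of two near-identical blocks. Names are assumptions.
import json
import pathlib


def write_config_if_missing(target: pathlib.Path, configs: dict, triton_version: str, logger) -> None:
    """Write the tuned configs (tagged with the triton version) unless the file already exists."""
    if target.exists():
        return
    logger.info(f"Writing best config to {target}...")
    with open(target, "w") as f:
        json.dump({"triton_version": triton_version, **configs}, f, indent=4)
        f.write("\n")


# write_config_if_missing(path_to_hf_configs / filename, configs, triton_version, pruna_logger)
# write_config_if_missing(path_to_vllm_configs / filename, configs, triton_version, pruna_logger)
```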
```python
keys, values = zip(*param_ranges.items())
for config_values in product(*values):
    config = dict(zip(keys, config_values))
    configs.append(config)
```
Suggested change:
```python
configs = [
    dict(zip(param_ranges.keys(), config_values))
    for config_values in product(*param_ranges.values())
]
```
```python
imported_packages = MoeKernelTuner().import_algorithm_packages()
save_dir = Path(model_path) / "moe_kernel_tuned_configs"
with open(save_dir / "moe_kernel_tuner.json") as f:
    payload = json.load(f)
```
why payload? could the name be more explicit?
```python
save_configs(
    best_configs,
    nb_experts,
    shard_intermediate_size,
    dtype,
    use_fp8_w8a8,
    use_int8_w8a16,
    None,
    smash_config["path_to_huggingface_hub_cache"],
    smash_config["path_to_vllm_cache"],
    imported_packages,
)
# results to be saved for later loading
```
Can the payload name be improved?
```python
# store artifacts in pruna cache for later saving/loading
save_dir = smash_config.cache_dir / "moe_kernel_tuned_configs"
save_dir.mkdir(parents=True, exist_ok=True)
with open(save_dir / "moe_kernel_tuner.json", "w") as f:
```
Why not place it directly in smash_config.cache_dir if only one file will be contained in moe_kernel_tuned_configs?
Description
This PR is inspired by the vLLM benchmarks (the benchmark_config fn is copied from here) and enables one to tune the MoE (Triton) kernel used in vllm.
This new algorithm MoeKernelTuner does not modify the model. It generates a tuned configuration that is saved in:
- the hf and vllm caches (the kernels lib will make use of the optimized config);
- moe_kernel_tuned_configs in the model directory (to be later re-used without waiting for tuning, when loading the model with pruna).
The core modifications are in: […] (iii) ray for parallelization; (iv) the best configurations are saved in the hf and vllm caches (so that after smashing, the hf cache and vllm cache are already populated with optimal configs that the user can use), and in the pruna cache (similar to what we do with save_before_apply).
Related Issue
Fixes #(issue number)
Type of Change
How Has This Been Tested?
Checklist
Additional Notes
Notebook for testing with vllm is available here. On H100 for qwen3Coder-30B, latency goes from 6.43 ms (before tuning) to 5.83 ms (after tuning) while using vllm.