5 changes: 4 additions & 1 deletion .gitignore
@@ -109,6 +109,9 @@ ipython_config.py
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# uv lock file
uv.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

@@ -163,4 +166,4 @@ cython_debug/
# ignore llama repository in resources
/resources/llama.cpp/

tests/openai
tests/openai
1 change: 1 addition & 0 deletions pyproject.toml
@@ -173,6 +173,7 @@ dev = [
"twine",
"pyc-wheel",
"ruff",
"numpydoc>=1.9.0",
"numpydoc-validation",
"pytest",
"pytest-cov",
136 changes: 136 additions & 0 deletions src/pruna/algorithms/reduce_noe.py
@@ -0,0 +1,136 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import json
import tempfile
from collections.abc import Iterable
from pathlib import Path
from typing import Any

from ConfigSpace import UniformIntegerHyperparameter
from transformers import AutoModelForCausalLM

from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
from pruna.algorithms.base.tags import AlgorithmTag as tags
from pruna.config.hyperparameters import UnconstrainedHyperparameter
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.engine.model_checks import is_moe_lm, is_transformers_pipeline_with_moe_lm
from pruna.engine.utils import get_device_map, move_to_device, safe_memory_cleanup


class ReduceNOE(PrunaAlgorithmBase):
"""
Implement ReduceNOE for LMs and diffusers pipelines with MoE blocks.

ReduceNOE is a method to Reduce the Number Of Experts per token.
"""

algorithm_name: str = "reduce_noe"
group_tags: list[str] = [tags.PRUNER]
references: dict[str, str] = {}
tokenizer_required: bool = False
processor_required: bool = False
dataset_required: bool = False
runs_on: list[str] = ["cuda", "accelerate"]
save_fn: None = None
compatible_after: Iterable[str] = ["*"]

def get_hyperparameters(self) -> list:
"""
Configure all algorithm-specific hyperparameters with ConfigSpace.

Returns
-------
list
The hyperparameters.
"""
return [
UniformIntegerHyperparameter(
name="num_experts_per_token",
lower=1,
upper=256,
default_value=2,
meta=dict(desc="Number of experts triggered per token."),
),
UnconstrainedHyperparameter(
name="target_name",
default_value="num_experts_per_tok",
meta=dict(
desc="Name of of the parameter in the config.json file to be modified, "
"e.g. 'num_experts_per_tok' for mixtral models. "
),
),
]

def model_check_fn(self, model: Any) -> bool:
"""
Check if the model is a causal language model or a diffusers pipeline with a MoE block.

Parameters
----------
model : Any
The model to check.

Returns
-------
bool
True if the model is a MoE LM or a transformers pipeline with a MoE block, False otherwise.
"""
# HunyuanImage 3 is a MoE model, but it is not based on Mixtral
if model.__class__.__name__ == "HunyuanImage3ForCausalMM":
return True
else:
return is_moe_lm(model) or is_transformers_pipeline_with_moe_lm(model)

def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
"""
Reduce the number of experts per token in the config.

Parameters
----------
model : Any
The model to reduce the number of experts per token in.
smash_config : SmashConfigPrefixWrapper
The configuration for the reduction of the number of experts per token.

Returns
-------
Any
The model with the reduced number of experts per token.
"""
if is_transformers_pipeline_with_moe_lm(model):
return self._apply_to_model_within_transformers_pipeline(model, smash_config)

device_map = get_device_map(model)
# we need to save and reload with the new config, because the config object is immutable.
with tempfile.TemporaryDirectory() as temp_dir:
move_to_device(model, "cpu")
model.save_pretrained(temp_dir)
config_path = Path(temp_dir) / "config.json"
if not config_path.exists():
raise FileNotFoundError(f"Config file not found at {config_path}")
else:
with config_path.open("r", encoding="utf-8") as f:
config_json = json.load(f)
target_name = smash_config["target_name"]
if target_name not in config_json:
raise KeyError(f"Target name '{target_name}' not found in config file at {config_path}")
config_json[target_name] = smash_config["num_experts_per_token"]
with config_path.open("w", encoding="utf-8") as f:
json.dump(config_json, f, indent=2)
safe_memory_cleanup()
model = AutoModelForCausalLM.from_pretrained(temp_dir, device_map=device_map)
return model
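
For context, a rough usage sketch of the new algorithm from the user side. This is a minimal sketch assuming Pruna's usual `SmashConfig`/`smash` entry points; the `"pruner"` key and the `reduce_noe_` hyperparameter prefix are assumptions inferred from the `PRUNER` group tag and `SmashConfigPrefixWrapper`, and the model id is only an example.

```python
# Hypothetical usage sketch -- the "pruner" key and the "reduce_noe_" prefix are
# assumptions, not taken from this PR.
from transformers import AutoModelForCausalLM

from pruna import SmashConfig, smash

# Any Mixtral-style MoE causal LM; the model id is illustrative.
model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

smash_config = SmashConfig()
smash_config["pruner"] = "reduce_noe"
smash_config["reduce_noe_num_experts_per_token"] = 1            # route each token through a single expert
smash_config["reduce_noe_target_name"] = "num_experts_per_tok"  # config.json key to patch

# The algorithm saves the model, patches config.json, and reloads it on the
# original device map (see _apply above).
smashed_model = smash(model=model, smash_config=smash_config)
```
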
36 changes: 36 additions & 0 deletions src/pruna/engine/model_checks.py
@@ -105,6 +105,25 @@ def is_speech_seq2seq_model(model: Any) -> bool:
return False


def is_moe_lm(model: Any) -> bool:
"""
Check if the model is a MoE LM.

Currently, all MoE LMs in transformers are based on Mixtral.

Parameters
----------
model : Any
The model to check.

Returns
-------
bool
True if the model is a MoE LM, False otherwise.
"""
return hasattr(model, "num_experts")


def is_transformers_pipeline_with_causal_lm(model: Any) -> bool:
"""
Check if the model is a transformers pipeline (for tasks like text generation, classification, etc.).
@@ -158,6 +177,23 @@ def is_transformers_pipeline_with_speech_recognition(model: Any) -> bool:
)


def is_transformers_pipeline_with_moe_lm(model: Any) -> bool:
"""
Check if the model is a transformers pipeline with a MoE LM.

Parameters
----------
model : Any
The model to check.

Returns
-------
bool
True if the model is a transformers pipeline with a MoE LM, False otherwise.
"""
return isinstance(model, TextGenerationPipeline) and is_moe_lm(getattr(model, "model", None))


def is_diffusers_pipeline(model: Any, include_video: bool = False) -> bool:
"""
Check if the model is a diffusers pipeline.
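
Based only on the checks added above, a small sketch of how they might be exercised; the model id and tokenizer are illustrative, and whether a given checkpoint exposes `num_experts` depends on its transformers implementation.

```python
# Illustrative sketch of the new detection helpers; it mirrors the hasattr /
# isinstance logic in the diff and makes no claims about a specific checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline

from pruna.engine.model_checks import is_moe_lm, is_transformers_pipeline_with_moe_lm

model_id = "mistralai/Mixtral-8x7B-v0.1"  # example MoE checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# True only if the loaded model object exposes a `num_experts` attribute.
print(is_moe_lm(model))

# True only for a TextGenerationPipeline whose underlying model passes is_moe_lm.
pipe = TextGenerationPipeline(model=model, tokenizer=tokenizer)
print(is_transformers_pipeline_with_moe_lm(pipe))
```
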
13 changes: 13 additions & 0 deletions tests/algorithms/testers/reduce_noe.py
@@ -0,0 +1,13 @@
from pruna.algorithms.reduce_noe import ReduceNOE

from .base_tester import AlgorithmTesterBase


class TestReduceNOE(AlgorithmTesterBase):
"""Test the ReduceNOE algorithm."""

models = ["qwen3_next_moe_tiny_random"]
reject_models = ["sd_tiny_random"]
allow_pickle_files = False
algorithm_class = ReduceNOE
metrics = ["perplexity"]
3 changes: 3 additions & 0 deletions tests/fixtures.py
@@ -197,4 +197,7 @@ def get_autoregressive_text_to_image_model(model_id: str) -> tuple[Any, SmashCon
"wan_tiny_random": partial(get_diffusers_model, "pruna-test/wan-t2v-tiny-random", torch_dtype=torch.bfloat16),
"flux_tiny": partial(get_diffusers_model, "pruna-test/tiny_flux", torch_dtype=torch.float16),
"tiny_llama": partial(get_automodel_transformers, "pruna-test/tiny_llama", torch_dtype=torch.bfloat16),
"qwen3_next_moe_tiny_random": partial(
get_automodel_transformers, "tiny-random/qwen3-next-moe", torch_dtype=torch.bfloat16
),
}