From eb9a9112f6230df57d21c8281208fc3350d01a7d Mon Sep 17 00:00:00 2001
From: llcnt
Date: Mon, 17 Nov 2025 13:33:44 +0000
Subject: [PATCH 01/11] feat: draft rednoe

---
 src/pruna/algorithms/red_noe.py  | 119 +++++++++++++++++++++++++++++++
 src/pruna/engine/model_checks.py |  37 ++++++++++
 2 files changed, 156 insertions(+)
 create mode 100644 src/pruna/algorithms/red_noe.py

diff --git a/src/pruna/algorithms/red_noe.py b/src/pruna/algorithms/red_noe.py
new file mode 100644
index 00000000..e398f60d
--- /dev/null
+++ b/src/pruna/algorithms/red_noe.py
@@ -0,0 +1,119 @@
+# Copyright 2025 - Pruna AI GmbH. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import json
+import tempfile
+from collections.abc import Iterable
+from pathlib import Path
+from typing import Any
+
+from ConfigSpace import UniformIntegerHyperparameter
+import diffusers
+
+from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
+from pruna.algorithms.base.tags import AlgorithmTag as tags
+from pruna.config.smash_config import SmashConfigPrefixWrapper
+from pruna.engine.model_checks import is_moe_lm, is_transformers_pipeline_with_moe_lm
+from pruna.engine.utils import load_json_config, move_to_device, safe_memory_cleanup
+
+
+class RedNOE(PrunaAlgorithmBase):
+    """
+    Implement RedNOE for LMs and diffusers pipelines with MoE blocks.
+
+    RedNOE is a method to Reduce the Number Of Experts per token.
+    """
+
+    algorithm_name: str = "red_noe"
+    group_tags: list[str] = [tags.PRUNER]
+    references: dict[str, str] = {}
+    tokenizer_required: bool = False
+    processor_required: bool = False
+    dataset_required: bool = False
+    runs_on: list[str] = ["cuda", "accelerate"]
+    save_fn: None = None
+    compatible_after: Iterable[str] = ["*"]
+
+    def get_hyperparameters(self) -> list:
+        """
+        Configure all algorithm-specific hyperparameters with ConfigSpace.
+
+        Returns
+        -------
+        list
+            The hyperparameters.
+        """
+        return [
+            UniformIntegerHyperparameter(
+                name="num_experts_per_token",
+                lower=1,
+                upper=256,
+                default_value=2,
+                meta=dict(desc="Number of experts triggered per token."),
+            )
+        ]
+
+    def model_check_fn(self, model: Any) -> bool:
+        """
+        Check if the model is a causal language model or a diffusers pipeline with a MoE block.
+
+        Parameters
+        ----------
+        model : Any
+            The model to check.
+
+        Returns
+        -------
+        bool
+            True if the model is a causal language model or a diffusers pipeline with a MoE block, False otherwise.
+        """
+        return is_moe_lm(model) or is_transformers_pipeline_with_moe_lm(model)
+
+    def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
+        """
+        Reduce the number of experts per token in the config.
+
+        Parameters
+        ----------
+        model : Any
+            The model to reduce the number of experts per token in.
+        smash_config : SmashConfigPrefixWrapper
+            The configuration for the reduction of the number of experts per token.
+
+        Returns
+        -------
+        Any
+            The model with the reduced number of experts per token.
+        """
+        model_name_or_path = getattr(model, "name_or_path", None)
+        config_path = Path(model_name_or_path) / "config.json"
+        if not config_path.exists():
+            raise FileNotFoundError(f"Config file not found at {config_path}")
+        else:
+            with config_path.open("r", encoding="utf-8") as f:
+                config_json = json.load(f)
+            config_json["num_experts_per_tok"] = smash_config["num_experts_per_token"]
+            with config_path.open("w", encoding="utf-8") as f:
+                json.dump(config_json, f, indent=2)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            move_to_device(model, "cpu")
+            model.save_pretrained(temp_dir)
+            # Get the pipeline class name
+            model_index = load_json_config(temp_dir, "model_index.json")
+            cls = getattr(diffusers, model_index["_class_name"])
+            safe_memory_cleanup()
+            model = cls.from_pretrained(temp_dir)
+        return model
diff --git a/src/pruna/engine/model_checks.py b/src/pruna/engine/model_checks.py
index 3ea41321..5905a78a 100644
--- a/src/pruna/engine/model_checks.py
+++ b/src/pruna/engine/model_checks.py
@@ -24,6 +24,7 @@
     MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
 )
+from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
 from transformers.pipelines.automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
 from transformers.pipelines.text2text_generation import Text2TextGenerationPipeline
 from transformers.pipelines.text_generation import TextGenerationPipeline
@@ -105,6 +106,25 @@ def is_speech_seq2seq_model(model: Any) -> bool:
     return False
 
 
+def is_moe_lm(model: Any) -> bool:
+    """
+    Check if the model is a MoE LM.
+
+    Currently all MoE LMs are based on Mixtral in transformers.
+
+    Parameters
+    ----------
+    model : Any
+        The model to check.
+
+    Returns
+    -------
+    bool
+        True if the model is a MoE LM, False otherwise.
+    """
+    return isinstance(model, MixtralForCausalLM)
+
+
 def is_transformers_pipeline_with_causal_lm(model: Any) -> bool:
     """
     Check if the model is a transformers pipeline (for tasks like text generation, classification, etc.).
@@ -158,6 +178,23 @@ def is_transformers_pipeline_with_speech_recognition(model: Any) -> bool:
     )
 
 
+def is_transformers_pipeline_with_moe_lm(model: Any) -> bool:
+    """
+    Check if the model is a transformers pipeline with a MoE LM.
+
+    Parameters
+    ----------
+    model : Any
+        The model to check.
+
+    Returns
+    -------
+    bool
+        True if the model is a transformers pipeline with a MoE LM, False otherwise.
+    """
+    return isinstance(model, TextGenerationPipeline) and is_moe_lm(getattr(model, "model", None))
+
+
 def is_diffusers_pipeline(model: Any, include_video: bool = False) -> bool:
     """
     Check if the model is a diffusers pipeline.
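
Note (not part of the patch series): in Mixtral-style MoE blocks, num_experts_per_tok is the k of the
router's top-k expert selection, which is why rewriting that single config.json field and reloading the
checkpoint, as this first patch does, is enough to change how many experts fire per token. A minimal
sketch of the routing step that this field controls, in the spirit of Mixtral's implementation (the
function and tensor names here are illustrative, not taken from the patch):

    import torch

    def route_tokens(router_logits: torch.Tensor, num_experts_per_tok: int):
        # router_logits: (num_tokens, num_experts) gating scores
        routing_weights = torch.softmax(router_logits, dim=-1)
        # keep only the k highest-scoring experts per token;
        # a smaller k means fewer expert forward passes per token
        topk_weights, topk_ids = torch.topk(routing_weights, num_experts_per_tok, dim=-1)
        # renormalize so the selected experts' weights sum to 1, as Mixtral does
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_ids
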
From 592174bfeede2f2e3644bce805c7f5833207f2d6 Mon Sep 17 00:00:00 2001
From: llcnt
Date: Mon, 1 Dec 2025 16:14:17 +0000
Subject: [PATCH 02/11] feat: use tmpdir to save/load with modified config

---
 src/pruna/algorithms/red_noe.py | 32 ++++++++++++++++++++++----------
 1 file changed, 22 insertions(+), 10 deletions(-)

diff --git a/src/pruna/algorithms/red_noe.py b/src/pruna/algorithms/red_noe.py
index e398f60d..595b069c 100644
--- a/src/pruna/algorithms/red_noe.py
+++ b/src/pruna/algorithms/red_noe.py
@@ -21,13 +21,14 @@
 from typing import Any
 
 from ConfigSpace import UniformIntegerHyperparameter
-import diffusers
+from transformers import AutoModelForCausalLM
 
 from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
 from pruna.algorithms.base.tags import AlgorithmTag as tags
 from pruna.config.smash_config import SmashConfigPrefixWrapper
+from pruna.config.target_modules import TargetModules
 from pruna.engine.model_checks import is_moe_lm, is_transformers_pipeline_with_moe_lm
-from pruna.engine.utils import load_json_config, move_to_device, safe_memory_cleanup
+from pruna.engine.utils import get_device_map, move_to_device, safe_memory_cleanup
 
 
 class RedNOE(PrunaAlgorithmBase):
@@ -63,7 +64,15 @@ def get_hyperparameters(self) -> list:
                 upper=256,
                 default_value=2,
                 meta=dict(desc="Number of experts triggered per token."),
-            )
+            ),
+            TargetModules(
+                name="target_name",
+                default_value="num_experts_per_tok",
+                meta=dict(
+                    desc="Name of the parameter in the config.json file to be modified, "
+                    "e.g. 'num_experts_per_tok' for mixtral models. "
+                ),
+            ),
         ]
 
     def model_check_fn(self, model: Any) -> bool:
@@ -78,9 +87,13 @@ def model_check_fn(self, model: Any) -> bool:
         Returns
         -------
         bool
-            True if the model is a causal language model or a diffusers pipeline with a MoE block, False otherwise.
+            True if the model is a MoE LM or a transformers pipeline with a MoE block, False otherwise.
         """
-        return is_moe_lm(model) or is_transformers_pipeline_with_moe_lm(model)
+        # Hunyuan3-image is a MoE model, but it does not depend on Mixtral
+        if model.__class__.__name__ == "HunyuanImage3ForCausalMM":
+            return True
+        else:
+            return is_moe_lm(model) or is_transformers_pipeline_with_moe_lm(model)
 
     def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
         """
@@ -105,15 +118,14 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
         Any
             The model with the reduced number of experts per token.
         """
         model_name_or_path = getattr(model, "name_or_path", None)
         config_path = Path(model_name_or_path) / "config.json"
         if not config_path.exists():
             raise FileNotFoundError(f"Config file not found at {config_path}")
         else:
             with config_path.open("r", encoding="utf-8") as f:
                 config_json = json.load(f)
-            config_json["num_experts_per_tok"] = smash_config["num_experts_per_token"]
+            config_json[smash_config["target_name"]] = smash_config["num_experts_per_token"]
             with config_path.open("w", encoding="utf-8") as f:
                 json.dump(config_json, f, indent=2)
+        device_map = get_device_map(model)
+        # we need to save and reload with the new config, because the config object is immutable.
         with tempfile.TemporaryDirectory() as temp_dir:
             move_to_device(model, "cpu")
             model.save_pretrained(temp_dir)
-            # Get the pipeline class name
-            model_index = load_json_config(temp_dir, "model_index.json")
-            cls = getattr(diffusers, model_index["_class_name"])
             safe_memory_cleanup()
-            model = cls.from_pretrained(temp_dir)
+            model = AutoModelForCausalLM.from_pretrained(temp_dir, device_map=device_map)
         return model

From 855d3289a7d3686bc5ebf76b613aa211c2d9ce30 Mon Sep 17 00:00:00 2001
From: llcnt
Date: Mon, 1 Dec 2025 16:44:00 +0000
Subject: [PATCH 03/11] feat: add unit test

---
 tests/algorithms/testers/red_noe.py | 13 +++++++++++++
 tests/fixtures.py                   |  3 +++
 2 files changed, 16 insertions(+)
 create mode 100644 tests/algorithms/testers/red_noe.py

diff --git a/tests/algorithms/testers/red_noe.py b/tests/algorithms/testers/red_noe.py
new file mode 100644
index 00000000..d6f32552
--- /dev/null
+++ b/tests/algorithms/testers/red_noe.py
@@ -0,0 +1,13 @@
+from pruna.algorithms.red_noe import RedNOE
+
+from .base_tester import AlgorithmTesterBase
+
+
+class TestRedNOE(AlgorithmTesterBase):
+    """Test the RedNOE algorithm."""
+
+    models = ["qwen3_next_moe_tiny_random"]
+    reject_models = ["sd_tiny_random"]
+    allow_pickle_files = False
+    algorithm_class = RedNOE
+    metrics = ["perplexity"]
diff --git a/tests/fixtures.py b/tests/fixtures.py
index d5bd55d7..1c2f24c1 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -197,4 +197,7 @@ def get_autoregressive_text_to_image_model(model_id: str) -> tuple[Any, SmashConfig]:
     "wan_tiny_random": partial(get_diffusers_model, "pruna-test/wan-t2v-tiny-random", torch_dtype=torch.bfloat16),
     "flux_tiny": partial(get_diffusers_model, "pruna-test/tiny_flux", torch_dtype=torch.float16),
     "tiny_llama": partial(get_automodel_transformers, "pruna-test/tiny_llama", torch_dtype=torch.bfloat16),
+    "qwen3_next_moe_tiny_random": partial(
+        get_automodel_transformers, "tiny-random/qwen3-next-moe", torch_dtype=torch.bfloat16
+    ),
 }

From 3aea1ecdbf7eb1d0dca7d0697723186523e82149 Mon Sep 17 00:00:00 2001
From: llcnt
Date: Tue, 2 Dec 2025 15:34:42 +0000
Subject: [PATCH 04/11] feat: make check fn more general and fix device

---
 src/pruna/algorithms/red_noe.py  | 31 +++++++++++++++----------------
 src/pruna/engine/model_checks.py |  3 +--
 2 files changed, 16 insertions(+), 18 deletions(-)

diff --git a/src/pruna/algorithms/red_noe.py b/src/pruna/algorithms/red_noe.py
index 595b069c..ebfd8b44 100644
--- a/src/pruna/algorithms/red_noe.py
+++ b/src/pruna/algorithms/red_noe.py
@@ -67,7 +67,7 @@ def get_hyperparameters(self) -> list:
             ),
             TargetModules(
                 name="target_name",
-                default_value="num_experts_per_tok",
+                default_value={"include": ["num_experts_per_tok"], "exclude": []},
                 meta=dict(
                     desc="Name of the parameter in the config.json file to be modified, "
                     "e.g. 'num_experts_per_tok' for mixtral models. "
@@ -111,21 +111,20 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
         Any
             The model with the reduced number of experts per token.
""" - model_name_or_path = getattr(model, "name_or_path", None) - config_path = Path(model_name_or_path) / "config.json" - if not config_path.exists(): - raise FileNotFoundError(f"Config file not found at {config_path}") - else: - with config_path.open("r", encoding="utf-8") as f: - config_json = json.load(f) - config_json[smash_config["target_name"]] = smash_config["num_experts_per_token"] - with config_path.open("w", encoding="utf-8") as f: - json.dump(config_json, f, indent=2) - device_map = get_device_map(model) - # we need to save and reload with the new config, because immutable object. - with tempfile.TemporaryDirectory() as temp_dir: - move_to_device(model, "cpu") - model.save_pretrained(temp_dir) + device_map = get_device_map(model) + # we need to save and reload with the new config, because immutable object. + with tempfile.TemporaryDirectory() as temp_dir: + move_to_device(model, "cpu") + model.save_pretrained(temp_dir) + config_path = Path(temp_dir) / "config.json" + if not config_path.exists(): + raise FileNotFoundError(f"Config file not found at {config_path}") + else: + with config_path.open("r", encoding="utf-8") as f: + config_json = json.load(f) + config_json[smash_config["target_name"]["include"][0]] = smash_config["num_experts_per_token"] + with config_path.open("w", encoding="utf-8") as f: + json.dump(config_json, f, indent=2) safe_memory_cleanup() model = AutoModelForCausalLM.from_pretrained(temp_dir, device_map=device_map) return model diff --git a/src/pruna/engine/model_checks.py b/src/pruna/engine/model_checks.py index 5905a78a..6a4bbb75 100644 --- a/src/pruna/engine/model_checks.py +++ b/src/pruna/engine/model_checks.py @@ -24,7 +24,6 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, ) -from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM from transformers.pipelines.automatic_speech_recognition import AutomaticSpeechRecognitionPipeline from transformers.pipelines.text2text_generation import Text2TextGenerationPipeline from transformers.pipelines.text_generation import TextGenerationPipeline @@ -122,7 +121,7 @@ def is_moe_lm(model: Any) -> bool: bool True if the model is a MoE LM, False otherwise. """ - return isinstance(model, MixtralForCausalLM) + return hasattr(model, "num_experts") def is_transformers_pipeline_with_causal_lm(model: Any) -> bool: From 8964d9b36db8829095c7043b484b1c66ace973e6 Mon Sep 17 00:00:00 2001 From: llcnt Date: Wed, 3 Dec 2025 16:45:12 +0000 Subject: [PATCH 05/11] feat: del uv.lock to avoid transformers pined version --- .gitignore | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 3c364e4d..e713c7d4 100644 --- a/.gitignore +++ b/.gitignore @@ -109,6 +109,9 @@ ipython_config.py # https://pdm.fming.dev/#use-with-ide .pdm.toml +# uv lock file +uv.lock + # PEP 582; used by e.g. 
 __pypackages__/
 
@@ -163,4 +166,4 @@ cython_debug/
 # ignore llama repository in resources
 /resources/llama.cpp/
 
-tests/openai
\ No newline at end of file
+tests/openai

From 3e71ec62633de1de90a670971f37ce17550f5779 Mon Sep 17 00:00:00 2001
From: llcnt
Date: Wed, 3 Dec 2025 17:04:38 +0000
Subject: [PATCH 06/11] feat: update numpydoc version to avoid Sphinx version errors

---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index 44df84bd..9fbd6b8d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -173,6 +173,7 @@ dev = [
     "twine",
     "pyc-wheel",
     "ruff",
+    "numpydoc>=1.9.0",
     "numpydoc-validation",
     "pytest",
     "pytest-cov",

From f0e48b0a259297504b0186fe8f16fbc7b92d4431 Mon Sep 17 00:00:00 2001
From: llcnt
Date: Fri, 12 Dec 2025 11:09:38 +0000
Subject: [PATCH 07/11] fix: add _apply for pipelines

---
 src/pruna/algorithms/red_noe.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/pruna/algorithms/red_noe.py b/src/pruna/algorithms/red_noe.py
index ebfd8b44..d18edf0c 100644
--- a/src/pruna/algorithms/red_noe.py
+++ b/src/pruna/algorithms/red_noe.py
@@ -111,6 +111,9 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
         Any
             The model with the reduced number of experts per token.
         """
+        if is_transformers_pipeline_with_moe_lm(model):
+            return self._apply_to_model_within_transformers_pipeline(model, smash_config)
+
         device_map = get_device_map(model)
         # we need to save and reload with the new config, because the config object is immutable.
         with tempfile.TemporaryDirectory() as temp_dir:
@@ -122,7 +125,12 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
             else:
                 with config_path.open("r", encoding="utf-8") as f:
                     config_json = json.load(f)
-                config_json[smash_config["target_name"]["include"][0]] = smash_config["num_experts_per_token"]
+                target_names = smash_config["target_name"]["include"]
+                if not target_names:
+                    raise ValueError(
+                        "The 'include' list in 'target_name' is empty. Please provide at least one config parameter name to modify."
+                    )
+                config_json[target_names[0]] = smash_config["num_experts_per_token"]
                 with config_path.open("w", encoding="utf-8") as f:
                     json.dump(config_json, f, indent=2)
             safe_memory_cleanup()

From 69389d970235599bc074d9c953de0c4f9372b4ef Mon Sep 17 00:00:00 2001
From: llcnt
Date: Tue, 23 Dec 2025 16:36:57 +0000
Subject: [PATCH 08/11] feat: change name to reduceNOE

---
 src/pruna/algorithms/{red_noe.py => reduce_noe.py}     | 6 +++---
 tests/algorithms/testers/{red_noe.py => reduce_noe.py} | 8 ++++----
 2 files changed, 7 insertions(+), 7 deletions(-)
 rename src/pruna/algorithms/{red_noe.py => reduce_noe.py} (96%)
 rename tests/algorithms/testers/{red_noe.py => reduce_noe.py} (54%)

diff --git a/src/pruna/algorithms/red_noe.py b/src/pruna/algorithms/reduce_noe.py
similarity index 96%
rename from src/pruna/algorithms/red_noe.py
rename to src/pruna/algorithms/reduce_noe.py
index d18edf0c..634db7b5 100644
--- a/src/pruna/algorithms/red_noe.py
+++ b/src/pruna/algorithms/reduce_noe.py
@@ -31,11 +31,11 @@
 from pruna.engine.utils import get_device_map, move_to_device, safe_memory_cleanup
 
 
-class RedNOE(PrunaAlgorithmBase):
+class ReduceNOE(PrunaAlgorithmBase):
     """
-    Implement RedNOE for LMs and diffusers pipelines with MoE blocks.
+    Implement ReduceNOE for LMs and transformers pipelines with MoE blocks.
 
-    RedNOE is a method to Reduce the Number Of Experts per token.
+    ReduceNOE is a method to Reduce the Number Of Experts per token.
     """
 
     algorithm_name: str = "red_noe"
diff --git a/tests/algorithms/testers/red_noe.py b/tests/algorithms/testers/reduce_noe.py
similarity index 54%
rename from tests/algorithms/testers/red_noe.py
rename to tests/algorithms/testers/reduce_noe.py
index d6f32552..aa5dd76b 100644
--- a/tests/algorithms/testers/red_noe.py
+++ b/tests/algorithms/testers/reduce_noe.py
@@ -1,13 +1,13 @@
-from pruna.algorithms.red_noe import RedNOE
+from pruna.algorithms.red_noe import ReduceNOE
 
 from .base_tester import AlgorithmTesterBase
 
 
-class TestRedNOE(AlgorithmTesterBase):
-    """Test the RedNOE algorithm."""
+class TestReduceNOE(AlgorithmTesterBase):
+    """Test the ReduceNOE algorithm."""
 
     models = ["qwen3_next_moe_tiny_random"]
     reject_models = ["sd_tiny_random"]
     allow_pickle_files = False
-    algorithm_class = RedNOE
+    algorithm_class = ReduceNOE
     metrics = ["perplexity"]

From 04ee0cfd479ab299abc8130dcf0c1b8faf2fbe96 Mon Sep 17 00:00:00 2001
From: llcnt
Date: Tue, 23 Dec 2025 16:46:15 +0000
Subject: [PATCH 09/11] feat: simplify usage with unconstrained hyperparameters

---
 src/pruna/algorithms/reduce_noe.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/src/pruna/algorithms/reduce_noe.py b/src/pruna/algorithms/reduce_noe.py
index 634db7b5..bdfcea09 100644
--- a/src/pruna/algorithms/reduce_noe.py
+++ b/src/pruna/algorithms/reduce_noe.py
@@ -25,8 +25,8 @@
 
 from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
 from pruna.algorithms.base.tags import AlgorithmTag as tags
+from pruna.config.hyperparameters import UnconstrainedHyperparameter
 from pruna.config.smash_config import SmashConfigPrefixWrapper
-from pruna.config.target_modules import TargetModules
 from pruna.engine.model_checks import is_moe_lm, is_transformers_pipeline_with_moe_lm
 from pruna.engine.utils import get_device_map, move_to_device, safe_memory_cleanup
 
@@ -65,9 +65,9 @@ def get_hyperparameters(self) -> list:
                 default_value=2,
                 meta=dict(desc="Number of experts triggered per token."),
             ),
-            TargetModules(
+            UnconstrainedHyperparameter(
                 name="target_name",
-                default_value={"include": ["num_experts_per_tok"], "exclude": []},
+                default_value="num_experts_per_tok",
                 meta=dict(
                     desc="Name of the parameter in the config.json file to be modified, "
                     "e.g. 'num_experts_per_tok' for mixtral models. "
@@ -125,12 +125,10 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
             else:
                 with config_path.open("r", encoding="utf-8") as f:
                     config_json = json.load(f)
-                target_names = smash_config["target_name"]["include"]
-                if not target_names:
-                    raise ValueError(
-                        "The 'include' list in 'target_name' is empty. Please provide at least one config parameter name to modify."
- ) - config_json[target_names[0]] = smash_config["num_experts_per_token"] + target_names = smash_config["target_name"] + if target_names not in config_json: + raise KeyError(f"Target name '{target_names}' not found in config file at {config_path}") + config_json[target_names] = smash_config["num_experts_per_token"] with config_path.open("w", encoding="utf-8") as f: json.dump(config_json, f, indent=2) safe_memory_cleanup() From 61a954a794c00318f7894a303828f55529bc6471 Mon Sep 17 00:00:00 2001 From: llcnt Date: Tue, 23 Dec 2025 16:51:18 +0000 Subject: [PATCH 10/11] fix: adapt path to new name --- tests/algorithms/testers/reduce_noe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/algorithms/testers/reduce_noe.py b/tests/algorithms/testers/reduce_noe.py index aa5dd76b..84344123 100644 --- a/tests/algorithms/testers/reduce_noe.py +++ b/tests/algorithms/testers/reduce_noe.py @@ -1,4 +1,4 @@ -from pruna.algorithms.red_noe import ReduceNOE +from pruna.algorithms.reduce_noe import ReduceNOE from .base_tester import AlgorithmTesterBase From a4053375a1ecc497a458f5e49e14dfa9e2b5bd8a Mon Sep 17 00:00:00 2001 From: llcnt Date: Mon, 19 Jan 2026 14:33:02 +0000 Subject: [PATCH 11/11] fix: adjust name and singular vs plural --- src/pruna/algorithms/reduce_noe.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/pruna/algorithms/reduce_noe.py b/src/pruna/algorithms/reduce_noe.py index bdfcea09..82e7c411 100644 --- a/src/pruna/algorithms/reduce_noe.py +++ b/src/pruna/algorithms/reduce_noe.py @@ -38,7 +38,7 @@ class ReduceNOE(PrunaAlgorithmBase): ReduceNOE is a method to Reduce the Number Of Experts per token. """ - algorithm_name: str = "red_noe" + algorithm_name: str = "reduce_noe" group_tags: list[str] = [tags.PRUNER] references: dict[str, str] = {} tokenizer_required: bool = False @@ -125,10 +125,10 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: else: with config_path.open("r", encoding="utf-8") as f: config_json = json.load(f) - target_names = smash_config["target_name"] - if target_names not in config_json: - raise KeyError(f"Target name '{target_names}' not found in config file at {config_path}") - config_json[target_names] = smash_config["num_experts_per_token"] + target_name = smash_config["target_name"] + if target_name not in config_json: + raise KeyError(f"Target name '{target_name}' not found in config file at {config_path}") + config_json[target_name] = smash_config["num_experts_per_token"] with config_path.open("w", encoding="utf-8") as f: json.dump(config_json, f, indent=2) safe_memory_cleanup()