diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 098e5f03506..6e98989874e 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -5,7 +5,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - +import logging from collections import defaultdict from collections.abc import Sequence @@ -110,8 +110,13 @@ UnsqueezeBeforeRepeatPass, UnsqueezeScalarPlaceholdersPass, ) - from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec +from executorch.backends.arm.common.pipeline_config import ( + ArmPassPipelineConfig, + FuseDuplicateUsersConfig, + SoftmaxDecompositionConfig, +) from executorch.backends.arm.tosa.specification import ( tosa_spec_in_set, TosaLoweringContext, @@ -124,11 +129,45 @@ from torch.fx.passes.infra.pass_base import PassResult from torch.nn.modules import Module +logger = logging.getLogger(__name__) + class ArmPassManager(PassManager): - def __init__(self, tosa_spec: TosaSpecification) -> None: - self.tosa_spec = tosa_spec + def __init__(self, compile_spec: ArmCompileSpec) -> None: + self.compile_spec = compile_spec + self.tosa_spec = compile_spec.tosa_spec + self._skip_pass_types: tuple[type, ...] = () super().__init__() + self.configure_skip_passes() + + def configure_skip_passes( + self, + override_config: ArmPassPipelineConfig | None = None, + ) -> tuple[type, ...]: + """ + Configures the pass manager to skip certain passes based on the ArmPassPipelineConfig class + found in the compile spec. + """ + skip_set: set[type] = set() + + config = override_config or self.compile_spec.get_pass_pipeline_config() + logger.debug(f"Skip Config: {config}") + + match config.softmax: + case SoftmaxDecompositionConfig.MASKED: + skip_set.add(DecomposeSoftmaxUnstablePass) + case SoftmaxDecompositionConfig.UNSTABLE: + skip_set.add(DecomposeSoftmaxPass) + skip_set.add(DecomposeMaskedFillPass) + + if config.fuse_duplicate_users is FuseDuplicateUsersConfig.DISABLED: + skip_set.add(FuseDuplicateUsersPass) + + self._skip_pass_types = tuple(skip_set) + skip_names = [skipped_pass.__name__ for skipped_pass in self._skip_pass_types] + logger.debug(f"Passes in skip list: {skip_names}") + + return self._skip_pass_types def validate_constraints_mandatory(self): """ @@ -165,6 +204,11 @@ def _transform(self, graph_module: GraphModule): with TosaLoweringContext(self.tosa_spec): return self(graph_module).graph_module + def add_pass(self, pipeline_pass): + if type(pipeline_pass) in self._skip_pass_types: + return + super().add_pass(pipeline_pass) + def _tosa_pipeline( self, exported_program: ExportedProgram, graph_module: GraphModule ) -> GraphModule: @@ -373,11 +417,8 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): DecomposeSqrtPass(), DecomposeSiluPass(), DecomposeAvgPool2dPass(), - ( - DecomposeSoftmaxUnstablePass() - if self.tosa_spec.is_U55_subset - else DecomposeSoftmaxPass() - ), + DecomposeSoftmaxUnstablePass(), + DecomposeSoftmaxPass(), ConvertMinMaxPass(), ] ) @@ -386,7 +427,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_passes( [ ReplaceInfAndLimitValuesPass(), - DecomposeMaskedFillPass() if not self.tosa_spec.is_U55_subset else None, + DecomposeMaskedFillPass(), ] ) diff --git a/backends/arm/common/arm_compile_spec.py b/backends/arm/common/arm_compile_spec.py index 1075594e901..dda2930b306 100644 --- a/backends/arm/common/arm_compile_spec.py +++ b/backends/arm/common/arm_compile_spec.py @@ -10,10 +10,12 @@ # JIT compiler flows. # +import json from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum +from executorch.backends.arm.common.pipeline_config import ArmPassPipelineConfig from executorch.backends.arm.tosa import TosaSpecification from executorch.exir.backend.compile_spec_schema import CompileSpec @@ -36,6 +38,7 @@ class DebugMode(Enum): _DEBUG_ARTIFACT_KEY = "debug_artifact_path" _DEBUG_MODE_KEY = "dump_debug_info" _OUTPUT_REORDER_KEY = "ouput_reorder_workaround" + _TRANSFORM_PIPELINE_CONFIG_KEY = "transform_pipeline_config" def _set_compile_specs( self, @@ -44,6 +47,7 @@ def _set_compile_specs( path_for_intermediates: str | None = None, tosa_debug_mode: DebugMode | None = None, output_order_workaround: bool = True, + pipeline_config: ArmPassPipelineConfig | None = None, ): """Set all values of dataclass directly.""" self.tosa_spec = tosa_spec @@ -51,6 +55,7 @@ def _set_compile_specs( self.path_for_intermediates = path_for_intermediates self.tosa_debug_mode = tosa_debug_mode self.output_order_workaround = output_order_workaround + self._pipeline_config = pipeline_config @classmethod def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 @@ -60,6 +65,7 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 path_for_intermediates: str | None = None tosa_debug_mode: ArmCompileSpec.DebugMode | None = None output_order_workaround: bool = True + pipeline_config: ArmPassPipelineConfig | None = None unknown_specs: dict[str, str] = {} for spec in compile_specs: key = spec.key @@ -98,6 +104,12 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 tosa_debug_mode = ArmCompileSpec.DebugMode[val] elif key == ArmCompileSpec._OUTPUT_REORDER_KEY: output_order_workaround = val # type: ignore[assignment] + elif key == ArmCompileSpec._TRANSFORM_PIPELINE_CONFIG_KEY: + if pipeline_config is not None: + raise ValueError( + "More than one transform pipeline entry in compile spec." + ) + pipeline_config = ArmPassPipelineConfig.from_dict(json.loads(val)) else: unknown_specs[key] = val @@ -120,6 +132,7 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 path_for_intermediates=path_for_intermediates, tosa_debug_mode=tosa_debug_mode, output_order_workaround=output_order_workaround, + pipeline_config=pipeline_config, ) cls.from_list_hook(compile_spec, unknown_specs) compile_spec.validate() @@ -189,8 +202,33 @@ def to_list(self): ) ) + if self._pipeline_config is not None and not self._pipeline_config.is_default(): + compile_spec.append( + CompileSpec( + ArmCompileSpec._TRANSFORM_PIPELINE_CONFIG_KEY, + self._pipeline_config.serialize(), + ) + ) return compile_spec + def get_pass_pipeline_config(self) -> ArmPassPipelineConfig: + """ + Returns configuration that controls how the Arm pass pipeline should behave. + Subclasses may override to tweak defaults for specific targets. + """ + if self._pipeline_config is None: + self._pipeline_config = self._create_default_pipeline_config() + return self._pipeline_config + + def set_pass_pipeline_config(self, config: ArmPassPipelineConfig) -> None: + self._pipeline_config = config + + def _create_default_pipeline_config(self) -> ArmPassPipelineConfig: + config = ArmPassPipelineConfig() + if self.tosa_spec.is_U55_subset: + config.disable_masked_softmax() + return config + def get_intermediate_path(self) -> str | None: """ Gets the path used for dumping intermediate results such as tosa and pte. diff --git a/backends/arm/common/pipeline_config.py b/backends/arm/common/pipeline_config.py new file mode 100644 index 00000000000..bbceb3c0c60 --- /dev/null +++ b/backends/arm/common/pipeline_config.py @@ -0,0 +1,59 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +from dataclasses import dataclass, fields +from enum import auto, Enum +from typing import Any + + +class SoftmaxDecompositionConfig(Enum): + MASKED = auto() + UNSTABLE = auto() + + +class FuseDuplicateUsersConfig(Enum): + ENABLED = auto() + DISABLED = auto() + + +@dataclass +class ArmPassPipelineConfig: + softmax: SoftmaxDecompositionConfig = SoftmaxDecompositionConfig.MASKED + fuse_duplicate_users: FuseDuplicateUsersConfig = FuseDuplicateUsersConfig.ENABLED + + def disable_masked_softmax(self) -> None: + self.softmax = SoftmaxDecompositionConfig.UNSTABLE + + def disable_fuse_duplicate_users(self) -> None: + self.fuse_duplicate_users = FuseDuplicateUsersConfig.DISABLED + + def is_default(self) -> bool: + return ( + self.softmax is SoftmaxDecompositionConfig.MASKED + and self.fuse_duplicate_users is FuseDuplicateUsersConfig.ENABLED + ) + + def to_dict(self) -> dict[str, str]: + return {f.name: getattr(self, f.name).name for f in fields(self)} + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ArmPassPipelineConfig": + config = cls() + for f in fields(cls): + raw_value = data.get(f.name) + if raw_value is None: + continue + enum_type = f.type + setattr(config, f.name, enum_type[raw_value]) + return config + + def serialize(self) -> bytes: + """Return a serialized representation of this config.""" + return json.dumps(self.to_dict()).encode() + + def __repr__(self): + fields = ", ".join(f"{name}={value!r}" for name, value in self.__dict__.items()) + return f"({fields})" diff --git a/backends/arm/ethosu/compile_spec.py b/backends/arm/ethosu/compile_spec.py index b53034c365e..e2c49840f80 100644 --- a/backends/arm/ethosu/compile_spec.py +++ b/backends/arm/ethosu/compile_spec.py @@ -4,14 +4,13 @@ # LICENSE file in the root directory of this source tree. from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec - +from executorch.backends.arm.common.pipeline_config import ( # noqa: unused + ArmPassPipelineConfig, +) from executorch.backends.arm.tosa import ( # type: ignore[import-not-found] TosaSpecification, ) - -from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-not-found] - CompileSpec, -) +from executorch.exir.backend.compile_spec_schema import CompileSpec class EthosUCompileSpec(ArmCompileSpec): @@ -43,7 +42,6 @@ def __init__( """ self.target = target - # Set vela compiler flags if config_ini is None: config_ini = "Arm/vela.ini" @@ -57,25 +55,26 @@ def __init__( ] ) # default system config and memory mode - if "ethos-u55" in self.target: + target_lower = self.target.lower() + if "ethos-u55" in target_lower: if system_config is None: system_config = "Ethos_U55_High_End_Embedded" if memory_mode is None: memory_mode = "Shared_Sram" - elif "ethos-u85" in self.target: + elif "ethos-u85" in target_lower: if system_config is None: system_config = "Ethos_U85_SYS_DRAM_Mid" if memory_mode is None: memory_mode = "Sram_Only" else: - raise RuntimeError(f"Unknown ethos target: {self.target}") + raise RuntimeError(f"Unknown ethos target: {target}") compiler_flags.append(f"--system-config={system_config}") compiler_flags.append(f"--memory-mode={memory_mode}") # Set TOSA version. base_tosa_version = "TOSA-1.0+INT+int16" - if "u55" in self.target: + if "u55" in target_lower: # Add the Ethos-U55 extension marker base_tosa_version += "+u55" tosa_spec = TosaSpecification.create_from_string(base_tosa_version) @@ -109,3 +108,8 @@ def validate(self): def get_output_format(cls) -> str: """Return the artifact format emitted by this compile spec.""" return "vela" + + def _create_default_pipeline_config(self) -> ArmPassPipelineConfig: + # Any u55 subset passes are treated as tosa specification configs + # As such, they should be added to the base class default. + return super()._create_default_pipeline_config() diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 467209fcb75..a383f44890f 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -339,11 +339,13 @@ class TOSAQuantizer(Quantizer): def __init__( self, compile_spec_or_tosa_spec: TosaSpecification | ArmCompileSpec ) -> None: - super().__init__() + self.compile_spec: ArmCompileSpec if isinstance(compile_spec_or_tosa_spec, TosaSpecification): - self.tosa_spec = compile_spec_or_tosa_spec - self.compile_spec = None + from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec + + self.compile_spec = TosaCompileSpec(compile_spec_or_tosa_spec) + self.tosa_spec = self.compile_spec.tosa_spec elif isinstance(compile_spec_or_tosa_spec, ArmCompileSpec): self.compile_spec = compile_spec_or_tosa_spec self.tosa_spec = self.compile_spec.tosa_spec @@ -432,9 +434,8 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule: # TODO: Fix the need to lazily import this. from executorch.backends.arm._passes import ArmPassManager - return ArmPassManager(self.tosa_spec).transform_for_annotation_pipeline( - graph_module=model - ) + pass_manager = ArmPassManager(self.compile_spec) + return pass_manager.transform_for_annotation_pipeline(graph_module=model) def annotate(self, model: GraphModule) -> GraphModule: """Annotate the graph with the configured quantization settings. diff --git a/backends/arm/test/misc/test_call_operator_submodule.py b/backends/arm/test/misc/test_call_operator_submodule.py index 799c546e24e..03201c86f59 100644 --- a/backends/arm/test/misc/test_call_operator_submodule.py +++ b/backends/arm/test/misc/test_call_operator_submodule.py @@ -9,7 +9,7 @@ from executorch.backends.arm._passes.arm_pass import ArmPass from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager -from executorch.backends.arm.tosa.specification import TosaSpecification +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from torch.fx import GraphModule from torch.fx.passes.infra.pass_base import PassResult @@ -58,7 +58,7 @@ def test_call_operator_runs_once_for_cond_submodules() -> None: graph_module = exported.graph_module recording_pass = _DepthRecordingPass(graph_module) - pass_manager = ArmPassManager(TosaSpecification.create_from_string("TOSA-1.00+FP")) + pass_manager = ArmPassManager(TosaCompileSpec("TOSA-1.00+FP")) pass_manager.add_pass(recording_pass) pass_manager._transform(graph_module) diff --git a/backends/arm/test/misc/test_pass_pipeline_config.py b/backends/arm/test/misc/test_pass_pipeline_config.py new file mode 100644 index 00000000000..e89a235ae9a --- /dev/null +++ b/backends/arm/test/misc/test_pass_pipeline_config.py @@ -0,0 +1,35 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.arm._passes import ( + DecomposeSoftmaxUnstablePass, + FuseDuplicateUsersPass, +) +from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager +from executorch.backends.arm.common.pipeline_config import ArmPassPipelineConfig +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec +from executorch.backends.arm.tosa.specification import TosaSpecification + + +def test_pipeline_config_override_outside_compile_spec(): + compile_spec = TosaCompileSpec( + TosaSpecification.create_from_string("TOSA-1.00+INT") + ) + default_manager = ArmPassManager(compile_spec) + default_skip_passes = default_manager._skip_pass_types + assert FuseDuplicateUsersPass not in default_skip_passes + assert DecomposeSoftmaxUnstablePass in default_skip_passes + + override_compile_spec = TosaCompileSpec( + TosaSpecification.create_from_string("TOSA-1.00+INT") + ) + override_config = ArmPassPipelineConfig() + override_config.disable_fuse_duplicate_users() + override_compile_spec.set_pass_pipeline_config(override_config) + override_manager = ArmPassManager(override_compile_spec) + skip_passes = override_manager._skip_pass_types + + assert FuseDuplicateUsersPass in skip_passes + assert DecomposeSoftmaxUnstablePass in skip_passes diff --git a/backends/arm/test/misc/test_pass_required_order.py b/backends/arm/test/misc/test_pass_required_order.py index 2745d25a498..694e1997d0f 100644 --- a/backends/arm/test/misc/test_pass_required_order.py +++ b/backends/arm/test/misc/test_pass_required_order.py @@ -8,6 +8,7 @@ import pytest from executorch.backends.arm._passes.arm_pass_manager import ArmPass, ArmPassManager +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from executorch.backends.arm.tosa.specification import TosaSpecification from executorch.exir.pass_base import ExportPass @@ -30,7 +31,8 @@ class IndependentPass(ArmPass): def _setup_pass_manager(passes: List[ArmPass] | None = None): tosa_spec = TosaSpecification.create_from_string("TOSA-1.00+INT") - pass_manager = ArmPassManager(tosa_spec) + compile_spec = TosaCompileSpec(tosa_spec) + pass_manager = ArmPassManager(compile_spec) if passes is not None: for p in passes: pass_manager.add_pass(p) diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 8cd1d8db5af..bb344c369a2 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -204,7 +204,6 @@ def run_artifact(self, inputs): class RunPasses(tester.RunPasses): - @no_type_check def __init__( self, @@ -868,7 +867,7 @@ def run_transform_for_annotation_pipeline( artifact = self.get_artifact(stage) if self.cur == StageType.EXPORT: new_gm = ArmPassManager( - self.compile_spec.tosa_spec + self.compile_spec ).transform_for_annotation_pipeline(graph_module=artifact.graph_module) else: raise RuntimeError("Can only run passes on Export stage.") diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index 5df9668a540..b92656c007a 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -129,7 +129,6 @@ def __init__( Union[Sequence[PassType], Dict[str, Sequence[PassType]]] ] = None, ): - self.tester = ArmTester( module, example_inputs=test_data, @@ -314,7 +313,6 @@ def run(self): class TOSAPipelineMaker(BasePipelineMaker, Generic[T]): - @staticmethod def is_tosa_ref_model_available(): """Checks if the TOSA reference model is available.""" @@ -992,7 +990,10 @@ def __init__( tosa_spec = tosa_profiles[tosa_version] - compile_spec = common.get_tosa_compile_spec(tosa_spec, custom_path=custom_path) + compile_spec: ArmCompileSpec = common.get_tosa_compile_spec( + tosa_spec, + custom_path=custom_path, + ) super().__init__( module, test_data, @@ -1063,7 +1064,6 @@ def __init__( ] = None, tosa_extensions: Optional[List[str]] = None, ): - if tosa_extensions is None: tosa_extensions = [] tosa_spec = TosaSpecification.create_from_string( diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index 38cc6e255de..69d16a2a708 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -288,7 +288,7 @@ def _preprocess_module( # noqa: C901 # TODO: Fix the need to lazily import this. from executorch.backends.arm._passes import ArmPassManager - graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline( # type: ignore + graph_module = ArmPassManager(compile_spec).transform_to_backend_pipeline( # type: ignore exported_program=edge_program, graph_module=graph_module ) @@ -377,9 +377,14 @@ def filter_tosa_compile_specs( ``TOSABackend.preprocess``. """ + + pipeline_config = compile_spec.get_pass_pipeline_config() + tosa_compile_spec = TosaCompileSpec(compile_spec.tosa_spec) + tosa_compile_spec.set_pass_pipeline_config(pipeline_config) return ( - TosaCompileSpec(compile_spec.tosa_spec) - .dump_intermediate_artifacts_to(compile_spec.get_intermediate_path()) + tosa_compile_spec.dump_intermediate_artifacts_to( + compile_spec.get_intermediate_path() + ) .dump_debug_info(compile_spec.tosa_debug_mode) .set_output_order_workaround(compile_spec.output_order_workaround) ) diff --git a/backends/arm/tosa/compile_spec.py b/backends/arm/tosa/compile_spec.py index 98671031e3d..5cd72ce04b3 100644 --- a/backends/arm/tosa/compile_spec.py +++ b/backends/arm/tosa/compile_spec.py @@ -4,6 +4,9 @@ # LICENSE file in the root directory of this source tree. from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec +from executorch.backends.arm.common.pipeline_config import ( # noqa: unused + ArmPassPipelineConfig, +) from executorch.backends.arm.tosa import TosaSpecification @@ -21,6 +24,7 @@ def __init__(self, tosa_spec: TosaSpecification | str): if isinstance(tosa_spec, str): tosa_spec = TosaSpecification.create_from_string(tosa_spec) self._set_compile_specs(tosa_spec, []) + self.validate() def validate(self): """Ensure that no unsupported compiler flags were supplied.""" @@ -34,3 +38,11 @@ def validate(self): def get_output_format(cls) -> str: """Return the artifact format emitted by this compile spec.""" return "tosa" + + @classmethod + def from_list_hook(cls, compile_spec, specs: dict[str, str]): + super().from_list_hook(compile_spec, specs) + + def _create_default_pipeline_config(self): + config = super()._create_default_pipeline_config() + return config diff --git a/backends/arm/vgf/compile_spec.py b/backends/arm/vgf/compile_spec.py index 0e160492a9e..b5b13f59939 100644 --- a/backends/arm/vgf/compile_spec.py +++ b/backends/arm/vgf/compile_spec.py @@ -6,6 +6,9 @@ import logging from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec +from executorch.backends.arm.common.pipeline_config import ( # noqa: unused + ArmPassPipelineConfig, +) from executorch.backends.arm.tosa import ( # type: ignore[import-not-found] TosaSpecification, ) @@ -62,3 +65,9 @@ def validate(self): def get_output_format(cls) -> str: """Return the artifact format emitted by this compile spec.""" return "vgf" + + def _create_default_pipeline_config(self) -> ArmPassPipelineConfig: + config = super()._create_default_pipeline_config() + # GRPHCOMP-3140 / MLETORCH-1529 + config.disable_fuse_duplicate_users() + return config