Skip to content

Commit eeafd42

Browse files
authored
Skip triton sdpa in gemma3 (#190)
* init
* init
* lint fix
* Revert "init" (This reverts commit 6938f59.)
* Revert "lint fix" (This reverts commit 10f8a6e.)
* lint fix - 2
1 parent 8967fe9 commit eeafd42

File tree

1 file changed

+15
-14
lines changed
  • optimum/exporters/executorch/recipes

1 file changed

+15
-14
lines changed

optimum/exporters/executorch/recipes/cuda.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
ExecutorchProgram,
2626
to_edge_transform_and_lower,
2727
)
28+
from executorch.exir.backend.compile_spec_schema import CompileSpec
2829
from optimum.executorch.passes.remove_padding_idx_embedding_pass import (
2930
RemovePaddingIdxEmbeddingPass,
3031
)
@@ -65,8 +66,6 @@ def export_to_executorch_with_cuda(
6566
For encoder-decoder models or multimodal models, it may generate multiple programs.
6667
"""
6768
# Import here to avoid version conflicts.
68-
from torch._inductor.decomposition import conv1d_to_conv2d
69-
7069
from executorch.backends.cuda.cuda_backend import CudaBackend
7170
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
7271

@@ -80,18 +79,20 @@ def _lower_to_executorch(
8079
if len(exported_programs) == 1:
8180
exported_programs = {"forward": next(iter(exported_programs.values()))}
8281

83-
# CUDA backend compile spec with method name.
84-
partitioners = {
85-
key: [CudaPartitioner([CudaBackend.generate_method_name_compile_spec(key)])]
86-
for key in exported_programs.keys()
87-
}
88-
# Add decompositions for triton to generate kernels.
89-
for key, ep in exported_programs.items():
90-
exported_programs[key] = ep.run_decompositions(
91-
{
92-
aten.conv1d.default: conv1d_to_conv2d,
93-
}
94-
)
82+
# Check if this is a Gemma3 model and prepare appropriate compile specs
83+
model_type = getattr(model.config, "model_type", None)
84+
85+
# For Gemma3, skip the Triton SDPA kernels: the model performs better without them.
86+
partitioners = {}
87+
for key in exported_programs.keys():
88+
compile_specs = [CudaBackend.generate_method_name_compile_spec(key)]
89+
90+
# Add Gemma3-specific compile spec if needed
91+
if model_type == "gemma3":
92+
compile_specs.append(CompileSpec(key="triton_kernel_mode", value=b"OFF"))
93+
94+
partitioners[key] = [CudaPartitioner(compile_specs)]
95+
9596
et_prog = to_edge_transform_and_lower(
9697
exported_programs,
9798
partitioner=partitioners,

0 commit comments

Comments
 (0)