Skip to content

Commit 2f501c5

Browse files
authored
Decompose after export in export_llama.
Differential Revision: D87826410. Pull Request resolved: #15951.
1 parent 8577a02 · commit 2f501c5

File tree

4 files changed: +35 −17 lines changed

backends/xnnpack/test/passes/test_propagate_custom_meta_pass.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,15 @@
2020
)
2121
from executorch.backends.xnnpack.test.tester import Quantize as XNNPackQuantize, Tester
2222
from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
23+
24+
from executorch.exir import ExecutorchProgramManager
25+
from executorch.exir._serialize import _deserialize_pte_binary
2326
from executorch.exir.passes.external_constants_pass import (
2427
delegate_external_constants_pass_unlifted,
2528
)
29+
from executorch.extension.flat_tensor.serialize.serialize import (
30+
_deserialize_to_flat_tensor,
31+
)
2632

2733
from torchao.quantization.granularity import PerGroup
2834
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig
@@ -87,7 +93,7 @@ def _test_linear(
8793
self,
8894
partitioner: XnnpackPartitioner,
8995
quantization_stage: Union[BaseStages.Quantize, BaseStages.Quantize_],
90-
):
96+
) -> ExecutorchProgramManager:
9197
eager_model = self.ModuleLinear(
9298
in_size=1,
9399
input_channels=32,
@@ -106,7 +112,7 @@ def _test_linear(
106112
exec = tester.get_artifact()
107113
program_buffer = exec.buffer
108114
self.assertEqual(len(exec._tensor_data), 1)
109-
data_buffer = bytes(exec._tensor_data.pop("model"))
115+
data_buffer = bytes(exec._tensor_data["model"])
110116
self.assertTrue(len(data_buffer) > 200)
111117
from executorch.extension.pybindings import portable_lib as runtime
112118

@@ -122,6 +128,8 @@ def _test_linear(
122128
# test_inputs
123129
# )
124130

131+
return exec
132+
125133
def test_quantize_(self):
126134
# Quantize with torchao quantize_ API.
127135
DynamicallyQuantizedPartitioner = XnnpackPartitioner(
@@ -132,9 +140,16 @@ def test_quantize_(self):
132140
weight_dtype=torch.int4,
133141
weight_granularity=PerGroup(32),
134142
)
135-
self._test_linear(
143+
exec = self._test_linear(
136144
DynamicallyQuantizedPartitioner, BaseStages.Quantize_(config=linear_config)
137145
)
146+
# PTE file has no named data.
147+
pte_file = _deserialize_pte_binary(exec.buffer)
148+
self.assertEqual(pte_file.named_data, None)
149+
150+
# PTD file contains quantized weight and scale.
151+
ptd_file = _deserialize_to_flat_tensor(bytes(exec._tensor_data["model"]))
152+
self.assertEqual(len(ptd_file.named_data), 2)
138153

139154
def test_pt2e_quantize(self):
140155
# Quantize with pt2e quantize.
@@ -156,6 +171,15 @@ def test_pt2e_quantize(self):
156171
partitioner = XnnpackPartitioner(
157172
config_precisions=precision, per_op_mode=per_op_mode
158173
)
159-
self._test_linear(
174+
exec = self._test_linear(
160175
partitioner, XNNPackQuantize(quantization_config=quant_config)
161176
)
177+
# PTE file has no named data.
178+
pte_file = _deserialize_pte_binary(exec.buffer)
179+
self.assertEqual(pte_file.named_data, None)
180+
181+
# PTD file contains quantized weight, and potentially scale.
182+
ptd_file = _deserialize_to_flat_tensor(
183+
bytes(exec._tensor_data["model"])
184+
)
185+
self.assertTrue(len(ptd_file.named_data) >= 1)

examples/models/llama/source_transformation/quantize.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,6 @@ def filter_fn(m, fqn):
179179
),
180180
filter_fn=filter_fn,
181181
)
182-
183-
model = unwrap_tensor_subclass(model)
184-
185182
# TODO: deal with checkpoint / computation dtype decoupling.
186183

187184
if verbose:

extension/llm/export/builder.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
from torch.nn.attention import SDPBackend
3939
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
4040
from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer
41-
from torchao.utils import unwrap_tensor_subclass
4241

4342
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
4443
logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -137,15 +136,15 @@ def __init__(
137136
if not self.dynamic_shapes and self.enable_dynamic_shape:
138137
if not self.use_kv_cache:
139138
# Only one input argument: tokens
140-
# Here we -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad
139+
# Here we use -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad
141140
self.dynamic_shapes = (
142141
{1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
143142
)
144143
else:
145144
# Two input arguments: tokens and input_pos but input_pos is static shape.
146-
145+
# Here we use -1 due to export limitation (same as non-kv-cache case above).
147146
self.dynamic_shapes = (
148-
{1: torch.export.Dim("token_dim", max=self.max_seq_len)},
147+
{1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
149148
{"input_pos": {0: 1}},
150149
)
151150

@@ -203,11 +202,6 @@ def _get_edge_config(self) -> EdgeCompileConfig:
203202
return edge_config
204203

205204
def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
206-
if module is not None:
207-
unwrap_tensor_subclass(module)
208-
else:
209-
unwrap_tensor_subclass(self.model)
210-
211205
dynamic_shape = self._get_dynamic_shape()
212206
# 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
213207
# 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
@@ -226,6 +220,8 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
226220
dynamic_shapes=dynamic_shape,
227221
strict=True,
228222
)
223+
# Functionalize the graph, and decompose subclasses from torchao quantize.
224+
exported_module = exported_module.run_decompositions({})
229225
return exported_module
230226

231227
def export(self) -> "LLMEdgeManager":

extension/llm/export/test/test_builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ def test_get_dynamic_shape_with_dynamic_shape_enabled_with_kv_cache(self) -> Non
8888
# Check first element (tokens dimension)
8989
self.assertIsInstance(result[0], dict)
9090
self.assertIn(1, result[0])
91-
self.assertEqual(result[0][1].max, self.max_seq_len)
91+
# max is max_seq_len - 1 due to export limitation
92+
self.assertEqual(result[0][1].max, self.max_seq_len - 1)
9293

9394
# Check second element (input_pos dimension)
9495
self.assertIsInstance(result[1], dict)

0 commit comments

Comments
 (0)