
Commit b6935ad

NXP backend: Add QAT support for NeutronQuantizer (#15692)
### Summary

Adds quantization-aware training (QAT) support for the NeutronQuantizer.

### Test plan

New test cases covering QAT mode were added, plus a dedicated test that trains a simple NN in QAT mode.
1 parent c4878e3 commit b6935ad

34 files changed, +891 -504 lines changed
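
For orientation before the per-file diffs, here is a minimal sketch of how the new `is_qat` flag could be used in the standard PT2E quantization-aware training flow. The flow itself (`export_for_training`, `prepare_qat_pt2e`, `convert_pt2e`), the toy model, and the `neutron_target_spec` placeholder are assumptions about typical usage, not content of this commit.

```python
import torch
import torch.nn as nn
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e

from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer

# Tiny stand-in network; any model built from the patterns handled by the quantizer would do.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.LazyLinear(10)
).train()
example_inputs = (torch.randn(1, 3, 32, 32),)
model(*example_inputs)  # materialize LazyLinear before export

# Export to an ATen-level graph that is still trainable.
exported = torch.export.export_for_training(model, example_inputs).module()

# NeutronTargetSpec for the target board; its construction is target-specific
# and not shown in this diff, so it is left as a placeholder here.
neutron_target_spec = ...
quantizer = NeutronQuantizer(neutron_target_spec, is_qat=True)

# QAT: fake-quantize modules are inserted instead of the PTQ observers.
prepared = prepare_qat_pt2e(exported, quantizer)

optimizer = torch.optim.SGD(prepared.parameters(), lr=1e-3)
for _ in range(10):  # placeholder fine-tuning loop on random data
    loss = prepared(torch.randn(8, 3, 32, 32)).square().mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Fold the learned fake-quant statistics into a statically quantized graph.
quantized = convert_pt2e(prepared)
```

Under those assumptions, the only quantizer-side difference from a post-training flow is constructing `NeutronQuantizer(..., is_qat=True)`.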

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 113 additions & 63 deletions
@@ -54,7 +54,13 @@
 )
 from torch import fx
 from torch.ao.quantization.quantizer.utils import _annotate_output_qspec
-from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
+from torchao.quantization.pt2e import (
+    FakeQuantize,
+    FusedMovingAvgObsFakeQuantize,
+    HistogramObserver,
+    MinMaxObserver,
+    MovingAverageMinMaxObserver,
+)
 from torchao.quantization.pt2e.quantizer import (
     ComposableQuantizer,
     DerivedQuantizationSpec,
@@ -154,78 +160,120 @@ def get_supported_operators(cls) -> list[OperatorConfig]:
 
 
 # Quantization Specification used by Neutron NPU
-act_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-128,
-    quant_max=127,
-    qscheme=torch.per_tensor_affine,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
-)
-
-wgt_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-127,
-    quant_max=127,
-    qscheme=torch.per_tensor_symmetric,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=MinMaxObserver,
-    ch_axis=0,
-)
+def act_qspec(is_qat: bool):
+    eps = 2**-12
+    observer_or_fake_quant_ctr = (
+        FusedMovingAvgObsFakeQuantize.with_args(
+            observer=MovingAverageMinMaxObserver, eps=eps
+        )
+        if is_qat
+        else HistogramObserver.with_args(eps=eps)
+    )
+
+    return QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-128,
+        quant_max=127,
+        qscheme=torch.per_tensor_affine,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+    )
+
+
+def wgt_qspec(is_qat: bool):
+    observer_or_fake_quant_ctr = (
+        FakeQuantize.with_args(observer=MovingAverageMinMaxObserver)
+        if is_qat
+        else MinMaxObserver
+    )
+
+    return QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-127,
+        quant_max=127,
+        qscheme=torch.per_tensor_symmetric,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+        ch_axis=0,
+    )
+
+
+def wgt_fc_qspec(is_qat: bool):
+    observer_or_fake_quant_ctr = (
+        FakeQuantize.with_args(observer=MovingAverageMinMaxObserver)
+        if is_qat
+        else MinMaxObserver
+    )
+
+    return QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=-127,
+        quant_max=127,
+        qscheme=torch.per_tensor_symmetric,
+        is_dynamic=False,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
+    )
 
-wgt_fc_qspec = QuantizationSpec(
-    dtype=torch.int8,
-    quant_min=-127,
-    quant_max=127,
-    qscheme=torch.per_tensor_symmetric,
-    is_dynamic=False,
-    observer_or_fake_quant_ctr=MinMaxObserver,
-)
 
 # Is set by the *PatternQuantizer directly.
 bias_qspec = None
 
 
 class NeutronQuantizer(ComposableQuantizer):
-    def __init__(self, neutron_target_spec: NeutronTargetSpec):
+    def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False):
         self.neutron_target_spec = neutron_target_spec
-        static_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None)
-        static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None)
+        self.is_qat = is_qat
+
+        static_qconfig = QuantizationConfig(
+            act_qspec(is_qat=is_qat),
+            act_qspec(is_qat=is_qat),
+            wgt_qspec(is_qat=is_qat),
+            None,
+        )
+        static_fc_qconfig = QuantizationConfig(
+            act_qspec(is_qat=is_qat),
+            act_qspec(is_qat=is_qat),
+            wgt_fc_qspec(is_qat=is_qat),
+            None,
+        )
+
+        OpQuantizer = NeutronAtenQuantizer
         super().__init__(
             [
-                NeutronAtenQuantizer(AbsPattern(), static_qconfig),
-                NeutronAtenQuantizer(AdaptiveAvgPoolPattern(), static_qconfig),
-                NeutronAtenQuantizer(AddTensorPattern(), static_qconfig),
-                NeutronAtenQuantizer(AddmmPattern(self), static_fc_qconfig),
-                NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig),
-                NeutronAtenQuantizer(CatPattern(), static_qconfig),
-                NeutronAtenQuantizer(Conv1dPattern(), static_qconfig),
-                NeutronAtenQuantizer(Conv2dPattern(self), static_qconfig),
-                NeutronAtenQuantizer(ConvTranspose2dPattern(), static_qconfig),
-                NeutronAtenQuantizer(DropoutPattern(), static_qconfig),
-                NeutronAtenQuantizer(FlattenPattern(), static_qconfig),
-                NeutronAtenQuantizer(HardTanhPattern(), static_qconfig),
-                NeutronAtenQuantizer(HardTanhInPlacePattern(), static_qconfig),
-                NeutronAtenQuantizer(LinearPattern(self), static_fc_qconfig),
-                NeutronAtenQuantizer(MaxPoolPattern(), static_qconfig),
-                NeutronAtenQuantizer(MeanDimPattern(), static_qconfig),
-                NeutronAtenQuantizer(MmPattern(self), static_qconfig),
-                NeutronAtenQuantizer(MulTensorPattern(), static_qconfig),
-                NeutronAtenQuantizer(PadPattern(), static_qconfig),
-                NeutronAtenQuantizer(PermutePattern(), static_qconfig),
-                NeutronAtenQuantizer(ReluPattern(), static_qconfig),
-                NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig),
-                NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
-                NeutronAtenQuantizer(SigmoidPattern(), static_qconfig),
-                NeutronAtenQuantizer(SliceTensorPattern(), static_qconfig),
-                NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
-                NeutronAtenQuantizer(SubTensorPattern(), static_qconfig),
-                NeutronAtenQuantizer(TanhPattern(), static_qconfig),
-                NeutronAtenQuantizer(TanhInPlacePattern(), static_qconfig),
-                NeutronAtenQuantizer(TransposeIntPattern(), static_qconfig),
-                NeutronAtenQuantizer(ViewPattern(), static_qconfig),
+                OpQuantizer(AbsPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(AdaptiveAvgPoolPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(AddTensorPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(AddmmPattern(self, is_qat=is_qat), static_fc_qconfig),
+                OpQuantizer(AvgPoolPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(CatPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(Conv1dPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(Conv2dPattern(self, is_qat=is_qat), static_qconfig),
+                OpQuantizer(ConvTranspose2dPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(DropoutPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(FlattenPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(HardTanhPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(HardTanhInPlacePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(LinearPattern(self, is_qat=is_qat), static_fc_qconfig),
+                OpQuantizer(MaxPoolPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(MeanDimPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(MmPattern(self, is_qat=is_qat), static_qconfig),
+                OpQuantizer(MulTensorPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(PadPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(PermutePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ReluPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ReluInPlacePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ReshapePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(SigmoidPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(SliceTensorPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(SoftMaxPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(SubTensorPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(TanhPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(TanhInPlacePattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(TransposeIntPattern(is_qat=is_qat), static_qconfig),
+                OpQuantizer(ViewPattern(is_qat=is_qat), static_qconfig),
             ]
         )
+
         # Mapping ops defined in quantizer partition types to its quantizer
         self.op_to_quantizer = {
            pt: q for q in self.quantizers for pt in q.pattern.partition_types()
@@ -235,7 +283,9 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec):
            pt: False for q in self.quantizers for pt in q.pattern.partition_types()
         }
         self.cluster_quantizers = [
-            NeutronAtenQuantizer(ActivationsConcatClusterPattern(self), static_qconfig)
+            NeutronAtenQuantizer(
+                ActivationsConcatClusterPattern(self, is_qat=is_qat), static_qconfig
+            )
         ]
 
     def transform_for_annotation(
@@ -288,7 +338,7 @@ def _annotate_inputs(self, model: fx.GraphModule):
                continue
 
            if node.op == "placeholder" and len(node.users) > 0:
-                _annotate_output_qspec(node, act_qspec)
+                _annotate_output_qspec(node, act_qspec(self.is_qat))
                self._mark_input_node_as_annotated(node)
 
    def validate(self, model: torch.fx.GraphModule) -> None:
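
To make the PTQ/QAT switch above concrete: in PTQ mode the specs request plain observers (HistogramObserver, MinMaxObserver) that only record value ranges, while in QAT mode they request fake-quantize modules that also round activations and weights through the int8 grid during training. Below is a standalone sketch of such a module, constructed directly rather than via `.with_args`; the exact keyword arguments are illustrative, not copied from the diff.

```python
import torch
from torchao.quantization.pt2e import FakeQuantize, MovingAverageMinMaxObserver

# Roughly the module the QAT branch of wgt_qspec asks the prepare step to insert:
# a moving-average range observer wrapped in fake quantization.
fq = FakeQuantize(
    observer=MovingAverageMinMaxObserver,
    quant_min=-127,
    quant_max=127,
    dtype=torch.int8,
    qscheme=torch.per_tensor_symmetric,
)

x = torch.randn(4, 8, requires_grad=True)
y = fq(x)            # updates the running min/max, then quantizes and dequantizes x
y.sum().backward()   # straight-through estimator: gradients still reach x
print(fq.scale, fq.zero_point)  # qparams derived from the observed range
```

`FusedMovingAvgObsFakeQuantize`, used by the QAT branch of `act_qspec`, behaves the same way but fuses the observation and fake-quantization steps into a single fused op.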

backends/nxp/quantizer/patterns.py

Lines changed: 46 additions & 16 deletions
@@ -14,7 +14,11 @@
 from torch import fx
 from torch._ops import OpOverload
 from torch.fx import Node
-from torchao.quantization.pt2e import PerChannelMinMaxObserver
+from torchao.quantization.pt2e import (
+    FakeQuantize,
+    MovingAveragePerChannelMinMaxObserver,
+    PerChannelMinMaxObserver,
+)
 from torchao.quantization.pt2e.quantizer import (
     DerivedQuantizationSpec,
     FixedQParamsQuantizationSpec,
@@ -59,7 +63,8 @@ class PartitionAnchors:
         | tuple[fx.Node, NodeArgsIdx, SharedQuantizationSpec],
     ] = field(default_factory=list)
     weights: list[
-        tuple[fx.Node, NodeArgsIdx] | tuple[fx.Node, NodeArgsIdx, QuantizationSpec],
+        tuple[fx.Node, NodeArgsIdx]
+        | tuple[fx.Node, NodeArgsIdx, QuantizationSpec | FakeQuantize],
     ] = field(default_factory=list)
     biases: list[
         tuple[fx.Node, NodeArgsIdx]
@@ -69,12 +74,18 @@ class PartitionAnchors:
     literals: list[tuple[fx.Node, NodeArgsIdx]] = field(default_factory=list)
     output: list[
         tuple[fx.Node]
-        | tuple[fx.Node, FixedQParamsQuantizationSpec | SharedQuantizationSpec],
+        | tuple[
+            fx.Node,
+            FixedQParamsQuantizationSpec | SharedQuantizationSpec,
+        ],
     ] = field(default_factory=list)
     empty: bool = False
 
 
 class QuantizationPattern(ABC):
+    def __init__(self, is_qat: bool = False):
+        self.is_qat = is_qat
+
     @abstractmethod
     def partition_types(self) -> list[OpOverload]:
         """
@@ -148,11 +159,12 @@ def get_anchors_for_fixed_quant_specs(
     zero_point: int,
     quant_min: int = -128,
     quant_max: int = 127,
+    is_qat: bool = False,
 ) -> PartitionAnchors:
     node = fused_partition[0].nodes[-1]
     assert len(fused_partition[0].input_nodes) == 1
 
-    qspec = FixedQParamsQuantizationSpec(
+    qspec_or_fake_quantize = FixedQParamsQuantizationSpec(
         dtype=torch.int8,
         scale=scale,
         zero_point=zero_point,
@@ -166,7 +178,7 @@ def get_anchors_for_fixed_quant_specs(
         weights=[],
         biases=[],
         output=[
-            (node, qspec),
+            (node, qspec_or_fake_quantize),
         ],
     )
 
@@ -190,7 +202,9 @@ def partition_types(self):
 
 
 class AddmmPattern(QuantizationPattern):
-    def __init__(self, neutron_quantizer):
+    def __init__(self, neutron_quantizer, is_qat: bool):
+        super().__init__(is_qat=is_qat)
+
         self.neutron_quantizer = neutron_quantizer
         self.neutron_target_info = (
             self.neutron_quantizer.neutron_target_spec.neutron_target_info
@@ -365,7 +379,11 @@ def get_anchors(
             ch_axis=0,
         )
 
-        weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
+        weight_observer_or_fake_quant_ctr = (
+            FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver)
+            if self.is_qat
+            else PerChannelMinMaxObserver
+        )
         weight_quantization_spec = QuantizationSpec(
             dtype=torch.int8,
             observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr,
@@ -399,7 +417,9 @@ def partition_types(self) -> list[OpOverload]:
 
 
 class Conv2dPattern(ConvPattern):
-    def __init__(self, neutron_quantizer):
+    def __init__(self, neutron_quantizer, is_qat: bool = False):
+        super().__init__(is_qat=is_qat)
+
         self.neutron_quantizer = neutron_quantizer
         self.neutron_target_info = (
             self.neutron_quantizer.neutron_target_spec.neutron_target_info
@@ -426,7 +446,11 @@ def get_anchors(
             ch_axis=0,
         )
 
-        weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
+        weight_observer_or_fake_quant_ctr = (
+            FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver)
+            if self.is_qat
+            else PerChannelMinMaxObserver
+        )
         weight_quantization_spec = QuantizationSpec(
             dtype=torch.int8,
             observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr,
@@ -563,7 +587,9 @@ def replacement_op(self):
 
 
 class LinearPattern(QuantizationPattern):
-    def __init__(self, neutron_quantizer):
+    def __init__(self, neutron_quantizer, is_qat: bool = False):
+        super().__init__(is_qat=is_qat)
+
         self.neutron_quantizer = neutron_quantizer
         self.neutron_target_info = (
             self.neutron_quantizer.neutron_target_spec.neutron_target_info
@@ -637,7 +663,9 @@ def partition_types(self):
 
 
 class MmPattern(QuantizationPattern):
-    def __init__(self, neutron_quantizer):
+    def __init__(self, neutron_quantizer, is_qat: bool = False):
+        super().__init__(is_qat=is_qat)
+
         self.neutron_quantizer = neutron_quantizer
         self.neutron_target_info = (
             self.neutron_quantizer.neutron_target_spec.neutron_target_info
@@ -802,7 +830,7 @@ def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
         return get_anchors_for_fixed_quant_specs(
-            fused_partition, scale=1.0 / 256.0, zero_point=-128
+            fused_partition, scale=1.0 / 256.0, zero_point=-128, is_qat=self.is_qat
         )
 
 
@@ -820,7 +848,7 @@ def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
         return get_anchors_for_fixed_quant_specs(
-            fused_partition, scale=1.0 / 256.0, zero_point=-128
+            fused_partition, scale=1.0 / 256.0, zero_point=-128, is_qat=self.is_qat
         )
 
 
@@ -838,7 +866,7 @@ def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors:
         return get_anchors_for_fixed_quant_specs(
-            fused_partition, scale=1.0 / 128.0, zero_point=0
+            fused_partition, scale=1.0 / 128.0, zero_point=0, is_qat=self.is_qat
         )
 
 
@@ -856,7 +884,7 @@ def get_anchors(
         self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
         return get_anchors_for_fixed_quant_specs(
-            fused_partition, scale=1.0 / 128.0, zero_point=0
+            fused_partition, scale=1.0 / 128.0, zero_point=0, is_qat=self.is_qat
         )
 
 
@@ -884,7 +912,9 @@ class ActivationsConcatClusterPattern(QuantizationPattern):
 
     """
 
-    def __init__(self, neutron_quantizer):
+    def __init__(self, neutron_quantizer, is_qat: bool = False):
+        super().__init__(is_qat=is_qat)
+
         self.neutron_quantizer = neutron_quantizer
         self.neutron_target_info = (
             self.neutron_quantizer.neutron_target_spec.neutron_target_info
