This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 1e5b27b

anmarques, dbogunowicz, and rahul-tuli authored
Feature custom qat module (#999)
* Extend quantizable layer types (#881)
  * Creates support for adding modules to the list of quantizable modules via a modifier flag.
  * Passed layer_class_argument to recursive call.
  * Changed flag name to have better contrast to exclude_module_types.
  * Style and quality fixes.
  * Update src/sparseml/pytorch/sparsification/quantization/helpers.py
    Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
  * De-activate conversions that remove Q/DQ after matmul.
  * Remove unused code.
  * Use public property name instead of alias.
  * Replace conditional by dictionary.
  * Style and quality fixes.
  * Changed flag name to custom_quantizable_module_types to make it a bit less confusing when contrasted with the existing flag exclude_module_types (sometimes the same module needs to be listed in both flags).
  * Fix calls to set _exclude_module_types.

  Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
  Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>

* Revert dtype conversion. Dictionary was breaking in some cases (#989)

Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>
1 parent e7c5009 commit 1e5b27b

File tree

3 files changed (+40, -35 lines)

src/sparseml/pytorch/sparsification/quantization/helpers.py

Lines changed: 16 additions & 7 deletions
@@ -481,20 +481,25 @@ def configure_module_default_qconfigs(module: Module):
         submodule.configure_qconfig()
 
 
-def add_quant_dequant(module, name=None, parent_module=None):
+def add_quant_dequant(
+    module: torch.nn.Module, name=None, parent_module=None, layer_class_names=None
+):
     """
     Wraps all Conv and Linear submodule with a qconfig with a QuantWrapper
     :param module: the module to modify
     :param name: name of the module to modify; default to None
     :param parent_module: parent module containing the module to modify; default to None
+    :param layer_class_names: list of module class names to be added to the
+        list of quantizable modules
     :return: the modified module
     """
     named_children = module.named_children()
-    if (
-        type(module) in _QUANTIZABLE_MODULE_TYPES
-        and hasattr(module, "qconfig")
-        and module.qconfig
-    ):
+    is_quantizable = type(module) in _QUANTIZABLE_MODULE_TYPES
+    if layer_class_names:
+        is_quantizable = (
+            is_quantizable or module.__class__.__name__ in layer_class_names
+        )
+    if is_quantizable and hasattr(module, "qconfig") and module.qconfig:
         module = torch_quantization.QuantWrapper(module)
         if parent_module is not None and len(list(named_children)) <= 0:
             if "." in name:

@@ -508,7 +513,11 @@ def add_quant_dequant(module, name=None, parent_module=None):
             setattr(parent_module, name, module)
     else:
         for name, child in named_children:
-            setattr(module, name, add_quant_dequant(child))
+            setattr(
+                module,
+                name,
+                add_quant_dequant(child, layer_class_names=layer_class_names),
+            )
     return module
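For illustration, a minimal sketch of what the new layer_class_names argument enables (the MyCustomLayer class and the import path are assumptions for this example, not part of the diff): a module whose class name is listed gets wrapped in a QuantWrapper even though its type is not in _QUANTIZABLE_MODULE_TYPES.

    import torch

    # Assumed import path, matching the file shown in this diff
    from sparseml.pytorch.sparsification.quantization.helpers import add_quant_dequant

    class MyCustomLayer(torch.nn.Module):  # hypothetical custom layer
        def forward(self, x):
            return x * 2

    module = MyCustomLayer()
    # add_quant_dequant only wraps modules that carry a qconfig
    module.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")

    # Without layer_class_names the module would be left untouched; with it,
    # the class-name match makes the module quantizable and it gets wrapped.
    wrapped = add_quant_dequant(module, layer_class_names=["MyCustomLayer"])
    assert isinstance(wrapped, torch.quantization.QuantWrapper)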
src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py

Lines changed: 21 additions & 8 deletions
@@ -135,6 +135,8 @@ class QuantizationModifier(ScheduledModifier):
         batch-normalization modules
     :param exclude_module_types: optional list of module class names
         to not propagate quantization configs to. Default is None
+    :param custom_quantizable_module_types: optional list of module class names
+        to be added to the list of quantizable modules. Default is None
     :param activation_qconfig_kwargs: Additional kwargs for quantization of
         activations.
     :param weight_qconfig_kwargs: Additional kwargs for quantization of

@@ -162,6 +164,7 @@ def __init__(
         num_calibration_steps: Optional[int] = None,
         exclude_batchnorm: bool = True,
         exclude_module_types: Optional[List[str]] = None,
+        custom_quantizable_module_types: Optional[List[str]] = None,
         activation_qconfig_kwargs: Optional[Dict[str, Any]] = None,
         weight_qconfig_kwargs: Optional[Dict[str, Any]] = None,
         tensorrt: bool = False,

@@ -195,6 +198,7 @@ def __init__(
         self._weight_bits = weight_bits
         self._exclude_batchnorm = exclude_batchnorm
         self._exclude_module_types = exclude_module_types
+        self._custom_quantizable_module_types = custom_quantizable_module_types
 
         self._modules_to_quantize = None
         self._qat_enabled = False

@@ -389,6 +393,14 @@ def quantize_embedding_activations(self) -> bool:
         else:
             return self._quantize_embedding_activations
 
+    @ModifierProp()
+    def custom_quantizable_module_types(self) -> Union[List[str], None]:
+        """
+        :return: optional list of module class names to be included
+            in list of quantizable modules. Default is None
+        """
+        return self._custom_quantizable_module_types
+
     @ModifierProp()
     def exclude_module_types(self) -> Union[List[str], None]:
         """

@@ -651,8 +663,9 @@ def _enable_module_qat(self, module: Module):
             # wrap all conv / linear blocks in with quantization observers
             torch_quantization.propagate_qconfig_(quant_module)
             configure_module_default_qconfigs(quant_module)
-
-            add_quant_dequant(quant_module, name, module)
+            add_quant_dequant(
+                quant_module, name, module, self.custom_quantizable_module_types
+            )
 
         # Remove output quantization from appropriate modules
         remove_activation_qat_by_layer_name(

@@ -661,16 +674,16 @@ def _enable_module_qat(self, module: Module):
 
         # remove qconfigs for module types in exclude_module_types
         to_exclude = []
-        if self._exclude_module_types:
-            to_exclude.extend(self._exclude_module_types)
+        if self.exclude_module_types:
+            to_exclude.extend(self.exclude_module_types)
 
         # if exclude_batchnorm flag is used, add batch norm layers to list of
         # modules to exclude qconfig
-        if self._exclude_batchnorm:
+        if self.exclude_batchnorm:
             to_exclude.extend(["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"])
 
         self._exclude_module_types = to_exclude
-        if self._exclude_module_types:
+        if self.exclude_module_types:
             self._strip_excluded_module_qconfigs(module)
 
         # set modules with proper qconfigs to QAT mode

@@ -753,9 +766,9 @@ def _freeze_bn_stats_update_ready(self, epoch: float) -> bool:
         )
 
     def _strip_excluded_module_qconfigs(self, module: Module):
-        if not self._exclude_module_types:
+        if not self.exclude_module_types:
             return
-        excluded_classes = set(self._exclude_module_types)
+        excluded_classes = set(self.exclude_module_types)
         for submodule in module.modules():
             if submodule.__class__.__name__ in excluded_classes and hasattr(
                 submodule, "qconfig"

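A minimal usage sketch of the new modifier flag (the layer class names here are hypothetical, and the import path is an assumption based on this file's location): listing a class name in custom_quantizable_module_types adds it to the quantizable set, while exclude_module_types still strips propagated qconfigs; as the commit message notes, the same module sometimes needs to appear in both.

    # Assumed import path for the modifier defined in this file
    from sparseml.pytorch.sparsification.quantization import QuantizationModifier

    modifier = QuantizationModifier(
        start_epoch=0.0,
        # hypothetical custom layer class to wrap with quant/dequant stubs
        custom_quantizable_module_types=["MyCustomLayer"],
        # module class names that should not receive propagated qconfigs
        exclude_module_types=["LayerNorm"],
    )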
src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py

Lines changed: 3 additions & 20 deletions
@@ -323,6 +323,7 @@ def _attribute_to_kwarg(attribute: onnx.AttributeProto):
 def _quantize_array(
     array: numpy.ndarray, scale: float, zero_point: int, dtype: Any = numpy.uint8
 ) -> numpy.ndarray:
+
     if dtype == numpy.uint8:
         tensor_dtype = torch.quint8
     elif dtype == numpy.int8:

@@ -1060,25 +1061,8 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
         if not bias_add_node or bias_add_node.op_type != "Add":
             continue
 
-        # Optionally find output QDQ block which will be deleted
-        output_quantize_node = graph.get_node_single_child(bias_add_node)
-        if (
-            not output_quantize_node
-            or output_quantize_node.op_type not in _QUANTIZE_OP_NAMES
-        ):
-            output_quantize_node = None
-
-        output_dequantize_node = (
-            graph.get_node_single_child(output_quantize_node)
-            if output_quantize_node
-            else None
-        )
-        if (
-            not output_dequantize_node
-            or output_dequantize_node.op_type not in _QUANTIZE_OP_NAMES
-        ):
-            output_quantize_node = None
-            output_dequantize_node = None
+        output_quantize_node = None
+        output_dequantize_node = None
 
         input_quantize_params = get_quantization_params(
             model, input_quantize_node, include_target=False

@@ -1587,7 +1571,6 @@ def quantize_torch_qat_export(
     _convert_quantizable_gemm_no_activations(model)
     quantize_resnet_identity_add_inputs(model)
     _remove_duplicate_quantize_ops(model)
-    _cleanup_unused_quants(model)
 
     graph = ONNXGraph(model)
     graph.sort_nodes_topologically()

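For context on the quantize_qat_export.py change, a sketch (node and tensor names are invented for illustration) of the output QuantizeLinear/DequantizeLinear pair that _convert_quantizable_matmul_and_add previously located after the bias Add and deleted; with this commit it is always left in place:

    from onnx import helper

    # Hypothetical Q/DQ pair following a converted MatMul + Add block
    output_quantize_node = helper.make_node(
        "QuantizeLinear",
        inputs=["add_output", "output_scale", "output_zero_point"],
        outputs=["add_output_quantized"],
    )
    output_dequantize_node = helper.make_node(
        "DequantizeLinear",
        inputs=["add_output_quantized", "output_scale", "output_zero_point"],
        outputs=["add_output_dequantized"],
    )
    # Keeping these nodes preserves explicit output quantization parameters
    # in the exported graph instead of folding them away.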