Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 2a19442

Browse files
authored
fix issues in configure_module_default_qconfigs (#359)
1 parent c4062fb commit 2a19442

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

src/sparseml/pytorch/utils/quantization/helpers.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,10 +275,12 @@ def configure_module_qat_wrappers(module: Module):
275275
276276
:param module: module to potentially wrap the submodules of
277277
"""
278-
for submodule in module.modules():
279-
for child_name, child_module in module.named_children():
280-
if hasattr(child_module, "wrap_qat") and child_module.wrap_qat:
281-
setattr(submodule, child_name, QATWrapper.from_module(child_module))
278+
# wrap any children of the given module as a QATWrapper if required
279+
for child_name, child_module in module.named_children():
280+
if hasattr(child_module, "wrap_qat") and child_module.wrap_qat:
281+
setattr(module, child_name, QATWrapper.from_module(child_module))
282+
# recurse on child module
283+
configure_module_qat_wrappers(child_module)
282284

283285

284286
def configure_module_default_qconfigs(module: Module):

tests/sparseml/pytorch/utils/quantization/test_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,13 @@ def _count_submodule_instances(module, clazz):
8282
reason="torch quantization not available",
8383
)
8484
def test_configure_module_qat_wrappers():
85-
module = _ModuleWrapper(_QATMatMul())
85+
module = _ModuleWrapper(_ModuleWrapper(_QATMatMul()))
8686

8787
assert not _module_has_quant_stubs(module)
8888

8989
configure_module_qat_wrappers(module)
9090

91-
qat_matmul = module.module
91+
qat_matmul = module.module.module
9292

9393
assert isinstance(qat_matmul, QATWrapper)
9494
assert _module_has_quant_stubs(module)

0 commit comments

Comments
 (0)