This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 7d37cae

protect fused modules with multiple hooks from dict mutate exception (#872)

1 parent e0524d6
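For context on the failure mode: judging by the workaround in this diff, the "dict mutate exception" comes from torch's fusion utilities transferring a module's hooks onto the new fused module by deleting from the hook dict while iterating it, so fusion crashes as soon as one module in a fuse block carries more than one hook. A minimal repro sketch of that upstream behavior on affected torch versions (the model and hooks are illustrative, not taken from this commit):

import torch
from torch import nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()

# a single hook fuses fine; the second hook on the same module is what
# used to trigger the failure this commit works around
model[0].register_forward_pre_hook(lambda mod, inp: None)
model[0].register_forward_pre_hook(lambda mod, inp: None)

# on affected torch versions this raises
# RuntimeError: OrderedDict mutated during iteration
torch.quantization.fuse_modules(model, [["0", "1", "2"]], inplace=True)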


src/sparseml/pytorch/sparsification/quantization/helpers.py

Lines changed: 76 additions & 1 deletion
@@ -17,7 +17,7 @@
 """
 
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU
@@ -33,6 +33,7 @@
 from dataclasses import dataclass, field
 
 from sparseml.pytorch.nn import ReLU as ReLU_nm
+from sparseml.pytorch.utils import get_layer
 
 
 __all__ = [
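The newly imported get_layer resolves a submodule name (as produced by named_modules) back to its Module object; the helpers added below call it as get_layer(name, module), name first. A small usage sketch with an illustrative model:

from torch import nn
from sparseml.pytorch.utils import get_layer

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
conv = get_layer("0", model)  # same call shape as get_layer(block[0], module)
assert conv is model[0]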
@@ -91,6 +92,25 @@
     else None
 )
 
+_FUSED_MODULE_TYPES = (
+    (
+        # Conv based layers
+        nni.ConvBn1d,
+        nni.ConvBn2d,
+        nni.ConvBn3d,
+        nni.ConvReLU1d,
+        nni.ConvReLU2d,
+        nni.ConvReLU3d,
+        nni.ConvBnReLU1d,
+        nni.ConvBnReLU2d,
+        nni.ConvBnReLU3d,
+        # Linear Layers
+        nni.LinearReLU,
+    )
+    if nni  # nni will always import if torch.quantization is available
+    else tuple()
+)
+
 
 @dataclass
 class QConfigProperties:
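A quick illustration of what _FUSED_MODULE_TYPES matches: eval-mode fusion collapses a Conv-BN-ReLU block into a single torch.nn.intrinsic module, which is the type the isinstance filter in _add_fused_block_hooks (added further below) keys on. A sketch, assuming nni refers to torch.nn.intrinsic per this file's conditional import:

import torch
import torch.nn.intrinsic as nni
from torch import nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()
torch.quantization.fuse_modules(model, [["0", "1", "2"]], inplace=True)

# eval-mode fusion folds the BatchNorm and leaves one fused module behind;
# nni.ConvReLU2d is one of the entries in _FUSED_MODULE_TYPES above
assert isinstance(model[0], nni.ConvReLU2d)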
@@ -675,7 +695,16 @@ def fuse_module_conv_bn_relus(
     if len(current_block) > 1:
         conv_blocks.append(current_block)
     if conv_blocks:
+        # manually save and move hooks surrounding fused blocks into new fused modules
+        # due to torch.quantization error when a module has more than one hook
+        block_hooks = _delete_get_block_hooks(module, conv_blocks)
+
+        # run torch fusion
         torch_quantization.fuse_modules(module, conv_blocks, inplace=True)
+
+        # add hooks back
+        _add_fused_block_hooks(module, block_hooks)
+
     return module
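A hedged end-to-end sketch of the fixed path: with the two new calls in place, hooks registered on a block before fusion should end up re-registered on the resulting fused module instead of crashing fusion. The toy model is illustrative, and calling fuse_module_conv_bn_relus with only the module argument assumes its remaining parameters have defaults:

from torch import nn
from sparseml.pytorch.sparsification.quantization.helpers import (
    fuse_module_conv_bn_relus,
)

def pre_hook(mod, inp):  # illustrative hook
    return None

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()
model[0].register_forward_pre_hook(pre_hook)
model[0].register_forward_pre_hook(pre_hook)  # the second hook used to crash

model = fuse_module_conv_bn_relus(model)
assert len(model[0]._forward_pre_hooks) == 2  # hooks now live on the fused module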

@@ -701,6 +730,52 @@ def prepare_embeddings_qat(
             _prepare_qat_embedding(submodule, qconfig)
 
 
+def _delete_get_block_hooks(
+    module: Module,
+    fuse_blocks: List[List[str]],
+) -> List[Tuple[Any, Any]]:
+    block_hooks = []
+    for block in fuse_blocks:
+        pre_hooks = []
+        post_hooks = []
+
+        # get first and last Module objects in block by their names
+        block_head = get_layer(block[0], module)
+        block_tail = get_layer(block[-1], module)
+
+        for handle_id, pre_hook_fn in list(block_head._forward_pre_hooks.items()):
+            pre_hooks.append(pre_hook_fn)
+            del block_head._forward_pre_hooks[handle_id]
+
+        for handle_id, hook_fn in list(block_tail._forward_hooks.items()):
+            post_hooks.append(hook_fn)
+            del block_tail._forward_hooks[handle_id]
+
+        block_hooks.append((pre_hooks, post_hooks))
+
+    return block_hooks
+
+
+def _add_fused_block_hooks(module: Module, block_hooks: List[Tuple[Any, Any]]):
+    fused_modules = [
+        mod for mod in module.modules() if isinstance(mod, _FUSED_MODULE_TYPES)
+    ]
+
+    if len(fused_modules) != len(block_hooks):
+        raise RuntimeError(
+            f"Number of fused modules ({len(fused_modules)}) after layer fusion "
+            f"in module {module.__class__.__name__} does not match the expected "
+            f"({len(block_hooks)}). Module may have already been fused or a "
+            "block may have been skipped during torch.quantization.fuse_modules"
+        )
+
+    for fused_module, (pre_hooks, post_hooks) in zip(fused_modules, block_hooks):
+        for pre_hook in pre_hooks:
+            fused_module.register_forward_pre_hook(pre_hook)
+        for post_hook in post_hooks:
+            fused_module.register_forward_hook(post_hook)
+
+
 def _prepare_qat_embedding(embedding: Module, qconfig: "torch.quantization.QConfig"):
     embedding.weight_fake_quant = qconfig.weight()
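The two helpers are deliberately symmetric, and the list(...) wrapped around items() in _delete_get_block_hooks is the detail that snapshots each hook dict before entries are deleted, sidestepping the same mutation-during-iteration trap. A standalone round-trip sketch (illustrative model; the helpers are private, so importing them directly is shown only to demonstrate the mechanism, not as a supported API):

import torch
from torch import nn
from sparseml.pytorch.sparsification.quantization.helpers import (
    _add_fused_block_hooks,
    _delete_get_block_hooks,
)

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()
model[0].register_forward_pre_hook(lambda mod, inp: None)   # block head hook
model[2].register_forward_hook(lambda mod, inp, out: None)  # block tail hook

blocks = [["0", "1", "2"]]
saved = _delete_get_block_hooks(model, blocks)  # strip hooks off block ends
torch.quantization.fuse_modules(model, blocks, inplace=True)
_add_fused_block_hooks(model, saved)  # re-register on the new fused module

assert len(model[0]._forward_pre_hooks) == 1
assert len(model[0]._forward_hooks) == 1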
