1717"""
1818
1919from copy import deepcopy
20- from typing import Any , Callable , Dict , List , Optional , Union
20+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
2121
2222import torch
2323from torch .nn import BatchNorm2d , Conv2d , Embedding , Module , ReLU
@@ -33,6 +33,7 @@
 from dataclasses import dataclass, field
 
 from sparseml.pytorch.nn import ReLU as ReLU_nm
+from sparseml.pytorch.utils import get_layer
 
 
 __all__ = [
@@ -91,6 +92,25 @@
     else None
 )
 
+_FUSED_MODULE_TYPES = (
+    (
+        # Conv based layers
+        nni.ConvBn1d,
+        nni.ConvBn2d,
+        nni.ConvBn3d,
+        nni.ConvReLU1d,
+        nni.ConvReLU2d,
+        nni.ConvReLU3d,
+        nni.ConvBnReLU1d,
+        nni.ConvBnReLU2d,
+        nni.ConvBnReLU3d,
+        # Linear Layers
+        nni.LinearReLU,
+    )
+    if nni  # nni will always import if torch.quantization is available
+    else tuple()
+)
+
 
 @dataclass
 class QConfigProperties:
@@ -675,7 +695,16 @@ def fuse_module_conv_bn_relus(
         if len(current_block) > 1:
             conv_blocks.append(current_block)
     if conv_blocks:
+        # manually save and move hooks surrounding fused blocks into new fused modules
+        # due to torch.quantization error when a module has more than one hook
+        block_hooks = _delete_get_block_hooks(module, conv_blocks)
+
+        # run torch fusion
         torch_quantization.fuse_modules(module, conv_blocks, inplace=True)
+
+        # add hooks back
+        _add_fused_block_hooks(module, block_hooks)
+
     return module
 
 
@@ -701,6 +730,52 @@ def prepare_embeddings_qat(
             _prepare_qat_embedding(submodule, qconfig)
 
 
+def _delete_get_block_hooks(
+    module: Module,
+    fuse_blocks: List[List[str]],
+) -> List[Tuple[Any, Any]]:
+    block_hooks = []
+    for block in fuse_blocks:
+        pre_hooks = []
+        post_hooks = []
+
+        # get first and last Module objects in block by their names
+        block_head = get_layer(block[0], module)
+        block_tail = get_layer(block[-1], module)
+
+        for handle_id, pre_hook_fn in list(block_head._forward_pre_hooks.items()):
+            pre_hooks.append(pre_hook_fn)
+            del block_head._forward_pre_hooks[handle_id]
+
+        for handle_id, hook_fn in list(block_tail._forward_hooks.items()):
+            post_hooks.append(hook_fn)
+            del block_tail._forward_hooks[handle_id]
+
+        block_hooks.append((pre_hooks, post_hooks))
+
+    return block_hooks
+
+
+def _add_fused_block_hooks(module: Module, block_hooks: List[Tuple[Any, Any]]):
+    fused_modules = [
+        mod for mod in module.modules() if isinstance(mod, _FUSED_MODULE_TYPES)
+    ]
+
+    if len(fused_modules) != len(block_hooks):
+        raise RuntimeError(
+            f"Number of fused modules ({len(fused_modules)}) after layer fusion in "
+            f"module {module.__class__.__name__} does not match the expected number "
+            f"({len(block_hooks)}). The module may have already been fused, or a "
+            "block may have been skipped by torch.quantization.fuse_modules"
+        )
+
+    for fused_module, (pre_hooks, post_hooks) in zip(fused_modules, block_hooks):
+        for pre_hook in pre_hooks:
+            fused_module.register_forward_pre_hook(pre_hook)
+        for post_hook in post_hooks:
+            fused_module.register_forward_hook(post_hook)
+
+
 def _prepare_qat_embedding(embedding: Module, qconfig: "torch.quantization.QConfig"):
     embedding.weight_fake_quant = qconfig.weight()
 
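
For context, a minimal sketch (not part of the commit) of the behavior this change preserves: a forward hook registered on the tail of a fusable Conv-BN-ReLU block survives fusion by being re-registered on the resulting fused module instead of being lost. The import path of fuse_module_conv_bn_relus and the toy Sequential model are illustrative assumptions and may differ between sparseml versions.

import torch
from torch.nn import BatchNorm2d, Conv2d, ReLU, Sequential

# assumed import path for this version of sparseml; may differ in other releases
from sparseml.pytorch.utils.quantization import fuse_module_conv_bn_relus

# eval mode so torch fusion folds the BatchNorm into the conv
model = Sequential(Conv2d(3, 8, 3), BatchNorm2d(8), ReLU()).eval()

# hook on the last module of the Conv-BN-ReLU block that fusion will collapse
model[2].register_forward_hook(lambda mod, inp, out: print("hook fired"))

model = fuse_module_conv_bn_relus(model)

# the saved hook was moved onto the fused module, so it still fires
model(torch.randn(1, 3, 16, 16))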