@@ -19,8 +19,9 @@
 
 import math
 from collections import OrderedDict
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
+import torch
 from torch import Tensor
 from torch.nn import (
     AdaptiveAvgPool2d,
@@ -31,11 +32,18 @@
     Module,
     Sequential,
     Sigmoid,
+    SiLU,
     Softmax,
 )
 
+
+try:
+    from torch.nn.quantized import FloatFunctional
+except Exception:
+    FloatFunctional = None
+
 from sparseml.pytorch.models.registry import ModelRegistry
-from sparseml.pytorch.nn import SqueezeExcite, Swish
+from sparseml.pytorch.nn import SqueezeExcite
 
 
 __all__ = [
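
A note on the guarded import above: FloatFunctional is only needed for
quantization-aware training, and the try/except keeps the file importable on
torch builds that lack it. On recent PyTorch the class also lives under
torch.ao.nn.quantized; a hedged sketch of a fallback chain covering both
layouts (an assumption, not part of this commit):

    # Hypothetical extension of the guarded import, assuming the torch.ao
    # layout used by newer PyTorch releases; not the code this commit ships.
    try:
        from torch.nn.quantized import FloatFunctional
    except Exception:
        try:
            from torch.ao.nn.quantized import FloatFunctional  # newer torch
        except Exception:
            FloatFunctional = None  # no quantized ops; fall back to torch.add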
@@ -52,6 +60,36 @@
 ]
 
 
+class _Add(Module):
+    def __init__(self):
+        super().__init__()
+
+        if FloatFunctional:
+            self.functional = FloatFunctional()
+            self.wrap_qat = True
+            self.qat_wrapper_kwargs = {
+                "num_inputs": 2,
+                "num_outputs": 0,
+            }
+
+    def forward(self, a: Tensor, b: Tensor):
+        if FloatFunctional:
+            return self.functional.add(a, b)
+        else:
+            return torch.add(a, b)
+
+
+class QATSiLU(SiLU):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 1,
+            "num_outputs": 1,
+        }
+
+
 class _InvertedBottleneckBlock(Module):
     def __init__(
         self,
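
The two helper classes added above are what make quantization-aware training
(QAT) work here. _Add turns the residual `+` into a Module: FloatFunctional
gives torch's quantization machinery an op it can attach observers to, whereas
a bare `a + b` leaves nothing to hook. The wrap_qat / qat_wrapper_kwargs
attributes appear to be the convention SparseML's quantization modifier keys
on to wrap such modules (a reading of the code, not documented API). A minimal
check that float-mode behavior is unchanged:

    # Minimal sketch: in float mode FloatFunctional.add is plain elementwise
    # addition, so _Add is numerically a no-op refactor.
    import torch
    from torch.nn.quantized import FloatFunctional

    ff = FloatFunctional()
    a, b = torch.randn(2, 8), torch.randn(2, 8)
    assert torch.equal(ff.add(a, b), a + b)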
@@ -68,7 +106,13 @@ def __init__(
         self._out_channels = out_channels
         self._stride = stride
         self._se_mod = se_mod
-        expanded_channels = int(in_channels * expansion_ratio)
+        expanded_channels = _scale_num_channels(in_channels, expansion_ratio)
+        squeezed_channels = (
+            max(1, int(in_channels * se_ratio))
+            if se_ratio and 0 < se_ratio <= 1
+            else None
+        )
+
         self.expand = (
             Sequential(
                 OrderedDict(
@@ -83,7 +127,10 @@ def __init__(
                             ),
                         ),
                         ("bn", BatchNorm2d(num_features=expanded_channels)),
-                        ("act", Swish(num_channels=expanded_channels)),
+                        (
+                            "act",
+                            QATSiLU() if squeezed_channels else SiLU(),
+                        ),
                     ]
                 )
             )
@@ -108,30 +155,32 @@ def __init__(
                         ),
                     ),
                     ("bn", BatchNorm2d(num_features=expanded_channels)),
-                    ("act", Swish(num_channels=expanded_channels)),
+                    (
+                        "act",
+                        QATSiLU() if squeezed_channels else SiLU(),
+                    ),
                 ]
             )
         )
 
-        squeezed_channels = (
-            max(1, int(in_channels * se_ratio))
-            if se_ratio and 0 < se_ratio <= 1
-            else None
-        )
-
         if self._se_mod:
             self.se = (
-                SqueezeExcite(out_channels, squeezed_channels)
+                SqueezeExcite(out_channels, squeezed_channels, "silu")
                 if squeezed_channels
                 else None
             )
         else:
             self.se = (
-                SqueezeExcite(expanded_channels, squeezed_channels)
+                SqueezeExcite(expanded_channels, squeezed_channels, "silu")
                 if squeezed_channels
                 else None
            )
 
+        if self._stride == 1 and self._in_channels == self._out_channels:
+            self.add = _Add()
+        else:
+            self.add = None
+
         self.project = Sequential(
             OrderedDict(
                 [
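
Hoisting squeezed_channels above self.expand lets the activation choice depend
on it: QATSiLU is used only when an SE block will follow. Worked numbers for a
typical block, assuming in_channels=16, expansion_ratio=6, se_ratio=0.25:

    # expanded width: 16 * 6 = 96, already a multiple of 8
    assert _scale_num_channels(16, 6) == 96
    # SE bottleneck width: max(1, int(16 * 0.25))
    assert max(1, int(16 * 0.25)) == 4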
@@ -165,8 +214,8 @@ def forward(self, inp: Tensor):
         if self.se is not None and self._se_mod:
             out = out * self.se(out)
 
-        if self._stride == 1 and self._in_channels == self._out_channels:
-            out = out + inp
+        if self.add is not None:
+            out = self.add(out, inp)
 
         return out
 
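
Precomputing self.add in __init__ keeps forward free of shape checks and, more
importantly, routes the skip connection through a module quantization can see.
A hypothetical instantiation (keyword names taken from the create_section call
below; treat the values as illustrative only):

    block = _InvertedBottleneckBlock(
        in_channels=16, out_channels=16, kernel_size=3,
        expansion_ratio=6, stride=1, se_ratio=0.25, se_mod=False,
    )
    assert block.add is not None  # stride 1 + matching widths -> residual add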
@@ -188,7 +237,7 @@ def __init__(
             bias=False,
         )
         self.bn = BatchNorm2d(num_features=out_channels)
-        self.act = Swish(out_channels)
+        self.act = SiLU()
         self.pool = AdaptiveAvgPool2d(1)
         self.dropout = Dropout(p=dropout)
         self.fc = Linear(out_channels, classes)
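
Replacing the custom Swish with torch.nn.SiLU is behavior-preserving: SiLU is
exactly swish with beta = 1, i.e. x * sigmoid(x), and the stock module is
parameter-free (no num_channels argument) and export-friendly. A quick
equivalence check:

    import torch
    from torch.nn import SiLU

    x = torch.randn(4)
    assert torch.allclose(SiLU()(x), x * torch.sigmoid(x))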
@@ -283,11 +332,12 @@ def __init__(
                             out_channels=sec_settings[0].in_channels,
                             kernel_size=3,
                             stride=2,
+                            padding=1,
                             bias=False,
                         ),
                     ),
                     ("bn", BatchNorm2d(num_features=sec_settings[0].in_channels)),
-                    ("act", Swish(sec_settings[0].in_channels)),
+                    ("act", SiLU()),
                 ]
             )
         )
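
The added padding=1 gives the stride-2, 3x3 stem the expected "same"
downsampling: output size is floor((H + 2 * pad - k) / s) + 1, so a 224 input
now maps to 112 rather than the 111 the unpadded conv produced. A shape check
(the 32 stem channels are assumed for illustration):

    import torch
    from torch.nn import Conv2d

    stem = Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
    out = stem(torch.randn(1, 3, 224, 224))
    assert out.shape == (1, 32, 112, 112)  # (224 + 2 - 3) // 2 + 1 == 112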
@@ -317,33 +367,40 @@ def create_section(settings: EfficientNetSectionSettings) -> Sequential:
     stride = settings.stride
     blocks = []
 
-    for _ in range(settings.num_blocks):
+    for block in range(settings.num_blocks):
         blocks.append(
             _InvertedBottleneckBlock(
-                in_channels=in_channels,
+                in_channels=in_channels if block == 0 else settings.out_channels,
                 out_channels=settings.out_channels,
                 kernel_size=settings.kernel_size,
                 expansion_ratio=settings.expansion_ratio,
-                stride=stride,
+                stride=stride if block == 0 else 1,
                 se_ratio=settings.se_ratio,
                 se_mod=settings.se_mod,
             )
         )
-        in_channels = settings.out_channels
 
     return Sequential(*blocks)
 
 
-def _scale_num_channels(channels: int, width_mult: float) -> int:
-    divisor = 8
-    scaled = channels * width_mult
-    scaled = max(divisor, int(scaled + divisor / 2) // divisor * divisor)
+def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
+    """
+    Taken from the original TensorFlow repo:
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+    It ensures that all layers have a channel count that is divisible by
+    the given divisor.
+    """
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that rounding down does not reduce the value by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return int(math.ceil(new_v))
 
-    if scaled < 0.9 * channels:
-        # prevent rounding by more than 10%
-        scaled += divisor
 
-    return int(scaled)
+def _scale_num_channels(channels: int, width_mult: float) -> int:
+    return _make_divisible(channels * width_mult, 8)
 
 
 def _scale_num_blocks(blocks: int, depth_mult: float) -> int:
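
Two changes land in this hunk. First, create_section derives each block's
in_channels and stride inline: only the first block of a section changes width
or downsamples, which matches what the old loop-variable reassignment did.
Second, channel scaling is factored through _make_divisible. Worked examples
of the rounding rule with divisor 8:

    # 32 * 1.1 = 35.2 -> nearest multiple of 8 is 32; 32 >= 0.9 * 35.2, keep 32
    assert _scale_num_channels(32, 1.1) == 32
    # 19 -> int(19 + 4) // 8 * 8 == 16, but 16 < 0.9 * 19 = 17.1 -> bump to 24
    assert _make_divisible(19, 8) == 24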
@@ -438,7 +495,7 @@ def _efficient_net_params(model_name):
     params_dict = {
         "efficientnet_b0": (1.0, 1.0, 0.2, 224),
         "efficientnet_b1": (1.0, 1.1, 0.2, 240),
-        "efficientnet_b2": (1.1, 1.2, 0.3, 260),
+        "efficientnet_b2": (1.1, 1.2, 0.3, 288),
         "efficientnet_b3": (1.2, 1.4, 0.3, 300),
         "efficientnet_b4": (1.4, 1.8, 0.4, 380),
         "efficientnet_b5": (1.6, 2.2, 0.4, 456),
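
Each tuple is (width_mult, depth_mult, dropout, input resolution): width_mult
drives _scale_num_channels, depth_mult drives _scale_num_blocks, and the
resolution is presumably an evaluation default, so the b2 change from 260 to
288 does not alter weight shapes. A sanity check against the well-known b2
head width (assumes the standard 1280-channel b0 head):

    assert _scale_num_channels(1280, 1.1) == 1408  # EfficientNet-B2 head width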