
Commit 3ccb428

Efficientnet fixes (#1035)
* Quantization improvements
* Replace Swish by SiLU. Fix padding of input conv
* Update image size
* Style and quality fixes
1 parent 4bf5d02 commit 3ccb428
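
The combined changes can be smoke-tested by building one of the affected models and running a forward pass. A minimal sketch, assuming the registry key follows the params_dict naming shown further down and that ModelRegistry.create works with its default arguments:

import torch

from sparseml.pytorch.models.registry import ModelRegistry

# Hypothetical smoke test: the "efficientnet_b2" key and default create()
# arguments are assumed, not confirmed by this commit.
model = ModelRegistry.create("efficientnet_b2")
model.eval()

# 288x288 is the updated recommended resolution for b2 in this commit; the
# network uses adaptive pooling, so other sizes also run.
with torch.no_grad():
    out = model(torch.randn(1, 3, 288, 288))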

File tree

2 files changed: +88, -31 lines

src/sparseml/pytorch/models/classification/efficientnet.py

Lines changed: 87 additions & 30 deletions
@@ -19,8 +19,9 @@
 
 import math
 from collections import OrderedDict
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
+import torch
 from torch import Tensor
 from torch.nn import (
     AdaptiveAvgPool2d,
@@ -31,11 +32,18 @@
     Module,
     Sequential,
     Sigmoid,
+    SiLU,
     Softmax,
 )
 
+
+try:
+    from torch.nn.quantized import FloatFunctional
+except Exception:
+    FloatFunctional = None
+
 from sparseml.pytorch.models.registry import ModelRegistry
-from sparseml.pytorch.nn import SqueezeExcite, Swish
+from sparseml.pytorch.nn import SqueezeExcite
 
 
 __all__ = [
@@ -52,6 +60,36 @@
 ]
 
 
+class _Add(Module):
+    def __init__(self):
+        super().__init__()
+
+        if FloatFunctional:
+            self.functional = FloatFunctional()
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 2,
+            "num_outputs": 0,
+        }
+
+    def forward(self, a: Tensor, b: Tensor):
+        if FloatFunctional:
+            return self.functional.add(a, b)
+        else:
+            return torch.add(a, b)
+
+
+class QATSiLU(SiLU):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 1,
+            "num_outputs": 1,
+        }
+
+
 class _InvertedBottleneckBlock(Module):
     def __init__(
         self,
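
The _Add module and QATSiLU above exist so the residual add and the activation can carry their own observers during quantization-aware training; wrap_qat / qat_wrapper_kwargs are a SparseML convention for that wrapping. Outside quantized mode, FloatFunctional.add is a plain pass-through around torch.add, so FP32 behavior is unchanged. A quick check of that equivalence (assumes a torch build that still ships the torch.nn.quantized namespace):

import torch
from torch.nn.quantized import FloatFunctional

a, b = torch.randn(2, 8), torch.randn(2, 8)

functional_add = FloatFunctional()
# In FP32 mode the post-process hook is an Identity, so this equals a + b.
assert torch.equal(functional_add.add(a, b), a + b)
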
@@ -68,7 +106,13 @@ def __init__(
         self._out_channels = out_channels
         self._stride = stride
         self._se_mod = se_mod
-        expanded_channels = int(in_channels * expansion_ratio)
+        expanded_channels = _scale_num_channels(in_channels, expansion_ratio)
+        squeezed_channels = (
+            max(1, int(in_channels * se_ratio))
+            if se_ratio and 0 < se_ratio <= 1
+            else None
+        )
+
         self.expand = (
             Sequential(
                 OrderedDict(
@@ -83,7 +127,10 @@ def __init__(
                             ),
                         ),
                         ("bn", BatchNorm2d(num_features=expanded_channels)),
-                        ("act", Swish(num_channels=expanded_channels)),
+                        (
+                            "act",
+                            QATSiLU() if squeezed_channels else SiLU(),
+                        ),
                     ]
                 )
             )
@@ -108,30 +155,32 @@ def __init__(
                         ),
                     ),
                     ("bn", BatchNorm2d(num_features=expanded_channels)),
-                    ("act", Swish(num_channels=expanded_channels)),
+                    (
+                        "act",
+                        QATSiLU() if squeezed_channels else SiLU(),
+                    ),
                 ]
             )
         )
 
-        squeezed_channels = (
-            max(1, int(in_channels * se_ratio))
-            if se_ratio and 0 < se_ratio <= 1
-            else None
-        )
-
         if self._se_mod:
             self.se = (
-                SqueezeExcite(out_channels, squeezed_channels)
+                SqueezeExcite(out_channels, squeezed_channels, "silu")
                 if squeezed_channels
                 else None
             )
         else:
             self.se = (
-                SqueezeExcite(expanded_channels, squeezed_channels)
+                SqueezeExcite(expanded_channels, squeezed_channels, "silu")
                 if squeezed_channels
                 else None
             )
 
+        if self._stride == 1 and self._in_channels == self._out_channels:
+            self.add = _Add()
+        else:
+            self.add = None
+
         self.project = Sequential(
             OrderedDict(
                 [
@@ -165,8 +214,8 @@ def forward(self, inp: Tensor):
         if self.se is not None and self._se_mod:
             out = out * self.se(out)
 
-        if self._stride == 1 and self._in_channels == self._out_channels:
-            out = out + inp
+        if self.add is not None:
+            out = self.add(out, inp)
 
         return out
 
@@ -188,7 +237,7 @@ def __init__(
             bias=False,
         )
         self.bn = BatchNorm2d(num_features=out_channels)
-        self.act = Swish(out_channels)
+        self.act = SiLU()
         self.pool = AdaptiveAvgPool2d(1)
         self.dropout = Dropout(p=dropout)
         self.fc = Linear(out_channels, classes)
@@ -283,11 +332,12 @@ def __init__(
                             out_channels=sec_settings[0].in_channels,
                             kernel_size=3,
                             stride=2,
+                            padding=1,
                             bias=False,
                         ),
                     ),
                     ("bn", BatchNorm2d(num_features=sec_settings[0].in_channels)),
-                    ("act", Swish(sec_settings[0].in_channels)),
+                    ("act", SiLU()),
                 ]
             )
         )
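
The added padding=1 on the stem's 3x3 stride-2 convolution keeps the usual EfficientNet halving of the input resolution (224 -> 112 for b0) instead of dropping a row and column. A quick shape check; the 32 stem channels are used only for illustration:

import torch
from torch.nn import Conv2d

x = torch.randn(1, 3, 224, 224)

# Without padding: floor((224 - 3) / 2) + 1 = 111
print(Conv2d(3, 32, kernel_size=3, stride=2, bias=False)(x).shape)
# torch.Size([1, 32, 111, 111])

# With padding=1: floor((224 + 2 - 3) / 2) + 1 = 112
print(Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)(x).shape)
# torch.Size([1, 32, 112, 112])
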
@@ -317,33 +367,40 @@ def create_section(settings: EfficientNetSectionSettings) -> Sequential:
         stride = settings.stride
         blocks = []
 
-        for _ in range(settings.num_blocks):
+        for block in range(settings.num_blocks):
             blocks.append(
                 _InvertedBottleneckBlock(
-                    in_channels=in_channels,
+                    in_channels=in_channels if block == 0 else settings.out_channels,
                     out_channels=settings.out_channels,
                     kernel_size=settings.kernel_size,
                     expansion_ratio=settings.expansion_ratio,
-                    stride=stride,
+                    stride=stride if block == 0 else 1,
                     se_ratio=settings.se_ratio,
                     se_mod=settings.se_mod,
                 )
             )
-            in_channels = settings.out_channels
 
         return Sequential(*blocks)
 
 
-def _scale_num_channels(channels: int, width_mult: float) -> int:
-    divisor = 8
-    scaled = channels * width_mult
-    scaled = max(divisor, int(scaled + divisor / 2) // divisor * divisor)
+def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
+    """
+    This function is taken from the original tf repo.
+    It ensures that all layers have a channel number that is divisible by 8
+    It can be seen here:
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+    """
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return int(math.ceil(new_v))
 
-    if scaled < 0.9 * channels:
-        # prevent rounding by more than 10%
-        scaled += divisor
 
-    return int(scaled)
+def _scale_num_channels(channels: int, width_mult: float) -> int:
+    return _make_divisible(channels * width_mult, 8)
 
 
 def _scale_num_blocks(blocks: int, depth_mult: float) -> int:
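
For reference, the extracted _make_divisible helper rounds channel counts to a multiple of 8 while never shrinking them by more than 10%. A standalone worked example of the same arithmetic (the formula is repeated here rather than importing the private helper):

divisor = 8

# channels=32, width_mult=1.1  ->  v = 35.2
v = 32 * 1.1
new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)  # int(39.2) // 8 * 8 = 32
assert new_v == 32 and new_v >= 0.9 * v  # 32 >= 31.68, so no +divisor correction

# channels=32, width_mult=1.2  ->  v = 38.4 rounds up to 40
v = 32 * 1.2
new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)  # int(42.4) // 8 * 8 = 40
assert new_v == 40
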
@@ -438,7 +495,7 @@ def _efficient_net_params(model_name):
     params_dict = {
         "efficientnet_b0": (1.0, 1.0, 0.2, 224),
        "efficientnet_b1": (1.0, 1.1, 0.2, 240),
-        "efficientnet_b2": (1.1, 1.2, 0.3, 260),
+        "efficientnet_b2": (1.1, 1.2, 0.3, 288),
         "efficientnet_b3": (1.2, 1.4, 0.3, 300),
         "efficientnet_b4": (1.4, 1.8, 0.4, 380),
         "efficientnet_b5": (1.6, 2.2, 0.4, 456),

src/sparseml/pytorch/nn/se.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ class SqueezeExcite(Module):
     :param expanded_channels: the number of channels to expand to in the SE layer
     :param squeezed_channels: the number of channels to squeeze down to in the SE layer
     :param act_type: the activation type to use in the SE layer; options:
-        [relu, relu6, prelu, lrelu, swish]
+        [relu, relu6, prelu, lrelu, swish, silu]
     """
 
     def __init__(
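
With silu accepted, the SE block can be constructed with the activation the EfficientNet blocks now pass in. A minimal sketch using the parameter names from the docstring above; the channel counts are illustrative only:

from sparseml.pytorch.nn import SqueezeExcite

# e.g. a block with 96 expanded channels squeezed down to 4
se = SqueezeExcite(expanded_channels=96, squeezed_channels=4, act_type="silu")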
