Commit 4542cf0

Add features_only, other bits to mambaout, define different base alternatives

1 parent c2da12c · commit 4542cf0
File tree: 1 file changed (+95, -24 lines)

timm/models/mambaout.py
Lines changed: 95 additions & 24 deletions
@@ -5,6 +5,7 @@
 MetaFormer (https://github.com/sail-sg/metaformer),
 InceptionNeXt (https://github.com/sail-sg/inceptionnext)
 """
+from collections import OrderedDict
 from typing import Optional
 
 import torch
@@ -120,7 +121,7 @@ class MlpHead(nn.Module):
 
     def __init__(
             self,
-            dim,
+            in_features,
             num_classes=1000,
             pool_type='avg',
             act_layer=nn.GELU,
@@ -130,27 +131,47 @@ def __init__(
             bias=True,
     ):
         super().__init__()
-        hidden_features = int(mlp_ratio * dim)
+        if mlp_ratio is not None:
+            hidden_size = int(mlp_ratio * in_features)
+        else:
+            hidden_size = None
         self.pool_type = pool_type
+        self.in_features = in_features
+        self.hidden_size = hidden_size or in_features
+
+        self.norm = norm_layer(in_features)
+        if hidden_size:
+            self.pre_logits = nn.Sequential(OrderedDict([
+                ('fc', nn.Linear(in_features, hidden_size)),
+                ('act', act_layer()),
+                ('norm', norm_layer(hidden_size))
+            ]))
+            self.num_features = hidden_size
+        else:
+            self.num_features = in_features
+            self.pre_logits = nn.Identity()
 
-        self.norm1 = norm_layer(dim)
-        self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
-        self.act = act_layer()
-        self.norm2 = norm_layer(hidden_features)
-        self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
+        self.fc = nn.Linear(hidden_size, num_classes, bias=bias)
         self.head_dropout = nn.Dropout(drop_rate)
 
+    def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
+        if pool_type is not None:
+            self.pool_type = pool_type
+        if reset_other:
+            self.norm = nn.Identity()
+            self.pre_logits = nn.Identity()
+            self.num_features = self.in_features
+        self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
     def forward(self, x, pre_logits: bool = False):
         if self.pool_type == 'avg':
             x = x.mean((1, 2))
-        x = self.norm1(x)
-        x = self.fc1(x)
-        x = self.act(x)
-        x = self.norm2(x)
+        x = self.norm(x)
+        x = self.pre_logits(x)
         x = self.head_dropout(x)
         if pre_logits:
             return x
-        x = self.fc2(x)
+        x = self.fc(x)
         return x
 
 
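To make the reworked head concrete, here is a small standalone sketch (not part of the commit) of how the new MlpHead behaves. It assumes the class is importable from timm.models.mambaout and passes mlp_ratio / norm_layer explicitly, since their defaults are not visible in this hunk:

import torch
import torch.nn as nn
from timm.models.mambaout import MlpHead  # assumed import path

head = MlpHead(in_features=768, num_classes=1000, mlp_ratio=3, norm_layer=nn.LayerNorm)
x = torch.randn(2, 7, 7, 768)          # MambaOut features are channels-last (NHWC)
logits = head(x)                       # avg pool -> norm -> pre_logits MLP -> dropout -> fc
embeds = head(x, pre_logits=True)      # hidden_size-dim features, before the final fc
print(logits.shape, embeds.shape)      # torch.Size([2, 1000]), torch.Size([2, 2304])
head.reset(10)                         # swaps fc for a 10-class linear, keeps pre_logits

The old norm1/fc1/act/norm2/fc2 attributes collapse into norm, a pre_logits block, and fc, which is what the checkpoint key remapping further down relies on.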
@@ -284,6 +305,7 @@ def __init__(
             norm_layer=LayerNorm,
             act_layer=nn.GELU,
             conv_ratio=1.0,
+            expansion_ratio=8/3,
             kernel_size=7,
             stem_mid_norm=True,
             ls_init_value=None,
@@ -303,6 +325,7 @@ def __init__(
 
         num_stage = len(depths)
         self.num_stage = num_stage
+        self.feature_info = []
 
         self.stem = Stem(
             in_chans,
@@ -313,16 +336,20 @@ def __init__(
         )
         prev_dim = dims[0]
         dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
-        self.stages = nn.ModuleList()
         cur = 0
+        curr_stride = 4
+        self.stages = nn.Sequential()
         for i in range(num_stage):
             dim = dims[i]
+            stride = 2 if curr_stride == 2 or i > 0 else 1
+            curr_stride *= stride
             stage = MambaOutStage(
                 dim=prev_dim,
                 dim_out=dim,
                 depth=depths[i],
                 kernel_size=kernel_size,
                 conv_ratio=conv_ratio,
+                expansion_ratio=expansion_ratio,
                 downsample=downsample if i > 0 else '',
                 ls_init_value=ls_init_value,
                 norm_layer=norm_layer,
@@ -331,6 +358,8 @@ def __init__(
             )
             self.stages.append(stage)
             prev_dim = dim
+            # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
+            self.feature_info += [dict(num_chs=prev_dim, reduction=curr_stride, module=f'stages.{i}')]
             cur += depths[i]
 
         if head_fn == 'default':
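The feature_info entries recorded here are plain dicts describing each stage's channel count, reduction, and module path; they can be inspected directly once a variant is built. A quick sketch, assuming this branch of timm is installed:

import timm

m = timm.create_model('mambaout_base', pretrained=False)
for fi in m.feature_info:
    # e.g. {'num_chs': ..., 'reduction': 4, 'module': 'stages.0'}, then reduction 8, 16, 32
    print(fi)

With the stem at stride 4 and stage 0 at stride 1 (per the NOTE above), the reductions come out as 4, 8, 16, 32 across the four stages.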
@@ -352,6 +381,8 @@ def __init__(
                 norm_layer=norm_layer,
                 drop_rate=drop_rate,
             )
+        self.num_features = prev_dim
+        self.hidden_size = self.head.num_features
 
         self.apply(self._init_weights)
 
@@ -362,13 +393,31 @@ def _init_weights(self, m):
                 nn.init.constant_(m.bias, 0)
 
     @torch.jit.ignore
-    def no_weight_decay(self):
-        return {}
+    def group_matcher(self, coarse=False):
+        return dict(
+            stem=r'^stem',
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+)\.downsample', (0,)), # blocks
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+            ]
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self) -> nn.Module:
+        return self.head.fc
+
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
+        self.head.reset(num_classes, global_pool)
 
     def forward_features(self, x):
         x = self.stem(x)
-        for s in self.stages:
-            x = s(x)
+        x = self.stages(x)
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
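These additions line up with timm's standard model interface: gradient-checkpointing toggles, classifier access/reset, and parameter grouping for layer-wise LR decay. A usage sketch (any registered variant works):

import timm

m = timm.create_model('mambaout_base', pretrained=False)
m.set_grad_checkpointing(True)          # sets grad_checkpointing on every stage
m.reset_classifier(num_classes=10)      # rebuilds the head's fc via its reset() method
print(m.get_classifier())               # the new 10-class nn.Linear
groups = m.group_matcher(coarse=False)  # regex spec consumed by timm's param-group helpers
print(groups['stem'], groups['blocks'])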
@@ -391,10 +440,14 @@ def checkpoint_filter_fn(state_dict, model):
         k = k.replace('downsample_layers.0.', 'stem.')
         k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
         k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
+        # remap head names
         if k.startswith('norm.'):
-            k = k.replace('norm.', 'head.norm1.')
-        elif k.startswith('head.norm.'):
-            k = k.replace('head.norm.', 'head.norm2.')
+            # this is moving to head since it's after the pooling
+            k = k.replace('norm.', 'head.norm.')
+        elif k.startswith('head.'):
+            k = k.replace('head.fc1.', 'head.pre_logits.fc.')
+            k = k.replace('head.norm.', 'head.pre_logits.norm.')
+            k = k.replace('head.fc2.', 'head.fc.')
         out_dict[k] = v
 
     return out_dict
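Tracing the replacements above, representative original checkpoint keys map onto the new head layout like this:

'norm.weight'      -> 'head.norm.weight'            (the final norm now sits in the head, after pooling)
'head.fc1.weight'  -> 'head.pre_logits.fc.weight'
'head.norm.weight' -> 'head.pre_logits.norm.weight'
'head.fc2.weight'  -> 'head.fc.weight'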
@@ -405,7 +458,7 @@ def _cfg(url='', **kwargs):
         'url': url,
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
         'crop_pct': 1.0, 'interpolation': 'bicubic',
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head.fc',
         **kwargs
     }
 
@@ -422,7 +475,8 @@ def _cfg(url='', **kwargs):
     'mambaout_base': _cfg(
         url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_base.pth'),
     'mambaout_small_rw': _cfg(),
-    'mambaout_base_rw': _cfg(),
+    'mambaout_base_slim_rw': _cfg(),
+    'mambaout_base_plus_rw': _cfg(),
 }
 
 
@@ -480,12 +534,29 @@ def mambaout_small_rw(pretrained=False, **kwargs):
 
 
 @register_model
-def mambaout_base_rw(pretrained=False, **kwargs):
+def mambaout_base_slim_rw(pretrained=False, **kwargs):
     model_args = dict(
         depths=(3, 4, 27, 3),
         dims=(128, 256, 512, 768),
+        expansion_ratio=2.5,
+        conv_ratio=1.25,
         stem_mid_norm=False,
+        downsample='conv_nf',
+        ls_init_value=1e-6,
+        head_fn='norm_mlp',
+    )
+    return _create_mambaout('mambaout_base_slim_rw', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def mambaout_base_plus_rw(pretrained=False, **kwargs):
+    model_args = dict(
+        depths=(3, 4, 27, 3),
+        dims=(128, 256, 512, 768),
+        expansion_ratio=3.0,
+        stem_mid_norm=False,
+        downsample='conv_nf',
         ls_init_value=1e-6,
         head_fn='norm_mlp',
     )
-    return _create_mambaout('mambaout_base_rw', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_mambaout('mambaout_base_plus_rw', pretrained=pretrained, **dict(model_args, **kwargs))
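Both new variants are registered and can be instantiated immediately, though their default _cfg entries carry no pretrained weight URLs yet. A quick sketch (attribute names per this diff; exact parameter counts not asserted here):

import timm

slim = timm.create_model('mambaout_base_slim_rw', pretrained=False)
plus = timm.create_model('mambaout_base_plus_rw', pretrained=False)
for name, m in [('slim', slim), ('plus', plus)]:
    n_params = sum(p.numel() for p in m.parameters()) / 1e6
    print(f'{name}: {n_params:.1f}M params, num_features={m.num_features}, hidden_size={m.hidden_size}')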
