Commit c2da12c (parent f2086f5)
Update rw models, fix heads

1 file changed: timm/models/mambaout.py (49 additions & 17 deletions)

@@ -8,10 +8,10 @@
 from typing import Optional
 
 import torch
-import torch.nn as nn
+from torch import nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale
+from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model
@@ -122,6 +122,7 @@ def __init__(
             self,
             dim,
             num_classes=1000,
+            pool_type='avg',
             act_layer=nn.GELU,
             mlp_ratio=4,
             norm_layer=LayerNorm,
@@ -130,17 +131,25 @@ def __init__(
     ):
         super().__init__()
         hidden_features = int(mlp_ratio * dim)
+        self.pool_type = pool_type
+
+        self.norm1 = norm_layer(dim)
         self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
         self.act = act_layer()
-        self.norm = norm_layer(hidden_features)
+        self.norm2 = norm_layer(hidden_features)
         self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
         self.head_dropout = nn.Dropout(drop_rate)
 
-    def forward(self, x):
+    def forward(self, x, pre_logits: bool = False):
+        if self.pool_type == 'avg':
+            x = x.mean((1, 2))
+        x = self.norm1(x)
         x = self.fc1(x)
         x = self.act(x)
-        x = self.norm(x)
+        x = self.norm2(x)
         x = self.head_dropout(x)
+        if pre_logits:
+            return x
         x = self.fc2(x)
         return x
 
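
For context, a minimal usage sketch of the reworked MlpHead (hypothetical tensor sizes; assumes MlpHead is accessed from timm.models.mambaout). The head now takes unpooled channels-last features, pools them itself when pool_type='avg', and can return pre-logits features:

    import torch
    from timm.models.mambaout import MlpHead  # assumption: module-level class, not re-exported

    head = MlpHead(576, num_classes=1000, pool_type='avg')
    feats = torch.randn(2, 7, 7, 576)        # (B, H, W, C) channels-last trunk features
    logits = head(feats)                     # pool -> norm1 -> fc1 -> act -> norm2 -> dropout -> fc2
    embeds = head(feats, pre_logits=True)    # stops before fc2, returns the hidden MLP features
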
@@ -208,7 +217,7 @@ def __init__(
             expansion_ratio=8 / 3,
             kernel_size=7,
             conv_ratio=1.0,
-            downsample: bool = False,
+            downsample: str = '',
             ls_init_value: Optional[float] = None,
             norm_layer=LayerNorm,
             act_layer=nn.GELU,
@@ -218,8 +227,10 @@ def __init__(
         dim_out = dim_out or dim
         self.grad_checkpointing = False
 
-        if downsample:
+        if downsample == 'conv':
             self.downsample = Downsample(dim, dim_out, norm_layer=norm_layer)
+        elif downsample == 'conv_nf':
+            self.downsample = DownsampleNormFirst(dim, dim_out, norm_layer=norm_layer)
         else:
             assert dim == dim_out
             self.downsample = nn.Identity()
@@ -276,10 +287,10 @@ def __init__(
             kernel_size=7,
             stem_mid_norm=True,
             ls_init_value=None,
+            downsample='conv',
             drop_path_rate=0.,
             drop_rate=0.,
-            output_norm=LayerNorm,
-            head_fn=MlpHead,
+            head_fn='default',
             **kwargs,
     ):
         super().__init__()
@@ -312,7 +323,7 @@ def __init__(
                 depth=depths[i],
                 kernel_size=kernel_size,
                 conv_ratio=conv_ratio,
-                downsample=i > 0,
+                downsample=downsample if i > 0 else '',
                 ls_init_value=ls_init_value,
                 norm_layer=norm_layer,
                 act_layer=act_layer,
@@ -322,9 +333,25 @@ def __init__(
             prev_dim = dim
             cur += depths[i]
 
-        self.norm = output_norm(prev_dim)
-
-        self.head = head_fn(prev_dim, num_classes, drop_rate=drop_rate)
+        if head_fn == 'default':
+            # specific to this model, unusual norm -> pool -> fc -> act -> norm -> fc combo
+            self.head = MlpHead(
+                prev_dim,
+                num_classes,
+                pool_type='avg',
+                drop_rate=drop_rate,
+                norm_layer=norm_layer,
+            )
+        else:
+            # more typical norm -> pool -> fc -> act -> fc
+            self.head = ClNormMlpClassifierHead(
+                prev_dim,
+                num_classes,
+                hidden_size=int(prev_dim * 4),
+                pool_type='avg',
+                norm_layer=norm_layer,
+                drop_rate=drop_rate,
+            )
 
         self.apply(self._init_weights)
 
@@ -336,7 +363,7 @@ def _init_weights(self, m):
 
     @torch.jit.ignore
    def no_weight_decay(self):
-        return {'norm'}
+        return {}
 
     def forward_features(self, x):
         x = self.stem(x)
@@ -345,9 +372,7 @@ def forward_features(self, x):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
-        x = x.mean((1, 2))
-        x = self.norm(x)
-        x = self.head(x)
+        x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
         return x
 
     def forward(self, x):
@@ -366,6 +391,10 @@ def checkpoint_filter_fn(state_dict, model):
         k = k.replace('downsample_layers.0.', 'stem.')
         k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
         k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
+        if k.startswith('norm.'):
+            k = k.replace('norm.', 'head.norm1.')
+        elif k.startswith('head.norm.'):
+            k = k.replace('head.norm.', 'head.norm2.')
         out_dict[k] = v
 
     return out_dict
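
As a sanity check on the remapping added above, a small standalone sketch (hypothetical key names) of where old checkpoint keys land now that the trunk's final norm lives inside the head:

    # Illustrative only: mirrors the two new branches in checkpoint_filter_fn.
    for k in ('norm.weight', 'norm.bias', 'head.norm.weight', 'head.fc2.weight'):
        if k.startswith('norm.'):
            k2 = k.replace('norm.', 'head.norm1.')        # old trunk norm -> head.norm1
        elif k.startswith('head.norm.'):
            k2 = k.replace('head.norm.', 'head.norm2.')   # old head norm -> head.norm2
        else:
            k2 = k
        print(k, '->', k2)
    # norm.weight -> head.norm1.weight
    # norm.bias -> head.norm1.bias
    # head.norm.weight -> head.norm2.weight
    # head.fc2.weight -> head.fc2.weight
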
@@ -443,7 +472,9 @@ def mambaout_small_rw(pretrained=False, **kwargs):
         depths=[3, 4, 27, 3],
         dims=[96, 192, 384, 576],
         stem_mid_norm=False,
+        downsample='conv_nf',
         ls_init_value=1e-6,
+        head_fn='norm_mlp',
     )
     return _create_mambaout('mambaout_small_rw', pretrained=pretrained, **dict(model_args, **kwargs))
 
@@ -455,5 +486,6 @@ def mambaout_base_rw(pretrained=False, **kwargs):
         dims=(128, 256, 512, 768),
         stem_mid_norm=False,
         ls_init_value=1e-6,
+        head_fn='norm_mlp',
     )
     return _create_mambaout('mambaout_base_rw', pretrained=pretrained, **dict(model_args, **kwargs))
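
A hedged end-to-end sketch of the updated rw variants (assumes this revision of timm and the default 224x224 config; pretrained weights for these variants may not be published):

    import timm
    import torch

    # mambaout_small_rw / mambaout_base_rw now pass head_fn='norm_mlp',
    # selecting ClNormMlpClassifierHead instead of the model-specific MlpHead.
    model = timm.create_model('mambaout_small_rw', pretrained=False)
    x = torch.randn(1, 3, 224, 224)
    logits = model(x)                                     # (1, 1000) with default num_classes
    feats = model.forward_features(x)                     # NHWC feature map from the trunk
    embeds = model.forward_head(feats, pre_logits=True)   # pre-classifier features
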
