Skip to content

Commit f2086f5

Browse files
committed
Add mambaout builder support, pretrained weight remap
1 parent c6ef54e commit f2086f5

File tree

1 file changed

+55
-76
lines changed

1 file changed

+55
-76
lines changed

timm/models/mambaout.py

Lines changed: 55 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,16 @@
55
MetaFormer (https://github.com/sail-sg/metaformer),
66
InceptionNeXt (https://github.com/sail-sg/inceptionnext)
77
"""
8-
from functools import partial
98
from typing import Optional
109

1110
import torch
1211
import torch.nn as nn
13-
import torch.nn.functional as F
14-
from timm.models.layers import trunc_normal_, DropPath, LayerNorm
15-
from .vision_transformer import LayerScale
16-
from ._manipulate import checkpoint_seq
17-
from timm.models.registry import register_model
12+
1813
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
14+
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale
15+
from ._builder import build_model_with_cfg
16+
from ._manipulate import checkpoint_seq
17+
from ._registry import register_model
1918

2019

2120
class Stem(nn.Module):
@@ -275,6 +274,7 @@ def __init__(
275274
act_layer=nn.GELU,
276275
conv_ratio=1.0,
277276
kernel_size=7,
277+
stem_mid_norm=True,
278278
ls_init_value=None,
279279
drop_path_rate=0.,
280280
drop_rate=0.,
@@ -293,7 +293,13 @@ def __init__(
293293
num_stage = len(depths)
294294
self.num_stage = num_stage
295295

296-
self.stem = Stem(in_chans, dims[0], act_layer=act_layer, norm_layer=norm_layer)
296+
self.stem = Stem(
297+
in_chans,
298+
dims[0],
299+
mid_norm=stem_mid_norm,
300+
act_layer=act_layer,
301+
norm_layer=norm_layer,
302+
)
297303
prev_dim = dims[0]
298304
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
299305
self.stages = nn.ModuleList()
@@ -338,7 +344,7 @@ def forward_features(self, x):
338344
x = s(x)
339345
return x
340346

341-
def forward_head(self, x):
347+
def forward_head(self, x, pre_logits: bool = False):
342348
x = x.mean((1, 2))
343349
x = self.norm(x)
344350
x = self.head(x)
@@ -350,6 +356,21 @@ def forward(self, x):
350356
return x
351357

352358

359+
def checkpoint_filter_fn(state_dict, model):
    """Rename the keys of a checkpoint to match this module's parameter layout.

    Args:
        state_dict: Checkpoint mapping. If it contains a ``'model'`` entry,
            that nested mapping is used instead.
        model: Not referenced by this function; presumably kept so the
            signature matches what the caller of the filter expects.

    Returns:
        A new dict holding the same values under renamed keys:
        ``downsample_layers.0.*``  -> ``stem.*``,
        ``stages.<i>.<j>*``        -> ``stages.<i>.blocks.<j>*``,
        ``downsample_layers.<i>*`` -> ``stages.<i>.downsample*``.
    """
    import re

    if 'model' in state_dict:
        state_dict = state_dict['model']

    # Dots are escaped so they only match literal '.' separators, and the
    # patterns are compiled once instead of on every key.
    block_pattern = re.compile(r'stages\.([0-9]+)\.([0-9]+)')
    downsample_pattern = re.compile(r'downsample_layers\.([0-9]+)')

    out_dict = {}
    for k, v in state_dict.items():
        # The first downsample layer becomes the stem; this must run before
        # the generic downsample rename below.
        k = k.replace('downsample_layers.0.', 'stem.')
        k = block_pattern.sub(r'stages.\1.blocks.\2', k)
        k = downsample_pattern.sub(r'stages.\1.downsample', k)
        out_dict[k] = v

    return out_dict
372+
373+
353374
def _cfg(url='', **kwargs):
354375
return {
355376
'url': url,
@@ -376,105 +397,63 @@ def _cfg(url='', **kwargs):
376397
}
377398

378399

400+
def _create_mambaout(variant, pretrained=False, **kwargs):
    """Build a MambaOut model through ``build_model_with_cfg``.

    Args:
        variant: Variant name, passed to the builder as its second argument.
        pretrained: Passed to the builder as its third argument.
        **kwargs: Forwarded unchanged to the builder.

    Returns:
        Whatever ``build_model_with_cfg`` returns.
    """
    return build_model_with_cfg(
        MambaOut,
        variant,
        pretrained,
        # Renames checkpoint keys to this module's parameter layout.
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
        **kwargs,
    )
408+
409+
379410
# a series of MambaOut models
380411
@register_model
def mambaout_femto(pretrained=False, **kwargs):
    """Create the ``mambaout_femto`` variant.

    Defaults are depths (3, 3, 9, 3) and dims (48, 96, 192, 288); entries in
    ``kwargs`` override these defaults on key collision.
    """
    model_args = dict(depths=(3, 3, 9, 3), dims=(48, 96, 192, 288))
    return _create_mambaout('mambaout_femto', pretrained=pretrained, **dict(model_args, **kwargs))
393415

394416
# Kobe Memorial Version with 24 Gated CNN blocks
395417
@register_model
def mambaout_kobe(pretrained=False, **kwargs):
    """Create the ``mambaout_kobe`` variant.

    Defaults are depths [3, 3, 15, 3] and dims [48, 96, 192, 288]; entries in
    ``kwargs`` override these defaults on key collision.
    """
    model_args = dict(depths=[3, 3, 15, 3], dims=[48, 96, 192, 288])
    return _create_mambaout('mambaout_kobe', pretrained=pretrained, **dict(model_args, **kwargs))
408421

409422
@register_model
def mambaout_tiny(pretrained=False, **kwargs):
    """Create the ``mambaout_tiny`` variant.

    Defaults are depths [3, 3, 9, 3] and dims [96, 192, 384, 576]; entries in
    ``kwargs`` override these defaults on key collision.
    """
    model_args = dict(depths=[3, 3, 9, 3], dims=[96, 192, 384, 576])
    return _create_mambaout('mambaout_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
421426

422427

423428
@register_model
def mambaout_small(pretrained=False, **kwargs):
    """Create the ``mambaout_small`` variant.

    Defaults are depths [3, 4, 27, 3] and dims [96, 192, 384, 576]; entries in
    ``kwargs`` override these defaults on key collision.
    """
    model_args = dict(depths=[3, 4, 27, 3], dims=[96, 192, 384, 576])
    return _create_mambaout('mambaout_small', pretrained=pretrained, **dict(model_args, **kwargs))
435432

436433

437434
@register_model
def mambaout_base(pretrained=False, **kwargs):
    """Create the ``mambaout_base`` variant.

    Defaults are depths [3, 4, 27, 3] and dims [128, 256, 512, 768]; entries
    in ``kwargs`` override these defaults on key collision.
    """
    model_args = dict(depths=[3, 4, 27, 3], dims=[128, 256, 512, 768])
    return _create_mambaout('mambaout_base', pretrained=pretrained, **dict(model_args, **kwargs))
449438

450439

451440
@register_model
def mambaout_small_rw(pretrained=False, **kwargs):
    """Create the ``mambaout_small_rw`` variant.

    Uses the same depths and dims as ``mambaout_small`` in this file, with
    ``stem_mid_norm=False`` and ``ls_init_value=1e-6``. Entries in ``kwargs``
    override these defaults on key collision.
    """
    model_args = dict(
        depths=[3, 4, 27, 3],
        dims=[96, 192, 384, 576],
        stem_mid_norm=False,
        ls_init_value=1e-6,
    )
    return _create_mambaout('mambaout_small_rw', pretrained=pretrained, **dict(model_args, **kwargs))
465449

466450

467451
@register_model
def mambaout_base_rw(pretrained=False, **kwargs):
    """Create the ``mambaout_base_rw`` variant.

    Uses depths (3, 4, 27, 3) and dims (128, 256, 512, 768), with
    ``stem_mid_norm=False`` and ``ls_init_value=1e-6``. Entries in ``kwargs``
    override these defaults on key collision.
    """
    model_args = dict(
        depths=(3, 4, 27, 3),
        dims=(128, 256, 512, 768),
        stem_mid_norm=False,
        ls_init_value=1e-6,
    )
    return _create_mambaout('mambaout_base_rw', pretrained=pretrained, **dict(model_args, **kwargs))

0 commit comments

Comments
 (0)