From f5d70c2a1a410dd2fc3e8f48f70fde87a5dddab7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 3 Dec 2025 13:41:55 -0800 Subject: [PATCH 1/4] Switcth to 2.9.1 w/ python 3.13 for upper versions in unit tests --- .github/workflows/tests.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 393e862b7a..7dd1ded48e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,11 +16,11 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python: ['3.10', '3.12'] - torch: [{base: '1.13.0', vision: '0.14.0'}, {base: '2.5.1', vision: '0.20.1'}] + python: ['3.10', '3.13'] + torch: [{base: '1.13.0', vision: '0.14.0'}, {base: '2.9.1', vision: '0.24.1'}] testmarker: ['-k "not test_models"', '-m base', '-m cfg', '-m torchscript', '-m features', '-m fxforward', '-m fxbackward'] exclude: - - python: '3.12' + - python: '3.13' torch: {base: '1.13.0', vision: '0.14.0'} runs-on: ${{ matrix.os }} From 31f665187dc5071318d17ff4ce3563e3f437d26b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 3 Dec 2025 13:46:52 -0800 Subject: [PATCH 2/4] Add some addition attention impl, inspired by PR #2048. Add some basic unit tests to cover broader timm pooling options. Make ROPE pooling more flexible with latest RotaryEmbedding impl. --- tests/test_layers_pool.py | 435 ++++++++++++++++++++++++++++++++ timm/layers/__init__.py | 4 +- timm/layers/attention_pool2d.py | 24 +- timm/layers/other_pool.py | 290 +++++++++++++++++++++ timm/layers/pos_embed_sincos.py | 6 + timm/layers/slot_pool.py | 254 +++++++++++++++++++ 6 files changed, 1007 insertions(+), 6 deletions(-) create mode 100644 tests/test_layers_pool.py create mode 100644 timm/layers/other_pool.py create mode 100644 timm/layers/slot_pool.py diff --git a/tests/test_layers_pool.py b/tests/test_layers_pool.py new file mode 100644 index 0000000000..b8f5908454 --- /dev/null +++ b/tests/test_layers_pool.py @@ -0,0 +1,435 @@ +"""Tests for timm pooling layers.""" +import pytest +import torch +import torch.nn as nn + +import importlib +import os + +torch_backend = os.environ.get('TORCH_BACKEND') +if torch_backend is not None: + importlib.import_module(torch_backend) +torch_device = os.environ.get('TORCH_DEVICE', 'cpu') + + +# Adaptive Avg/Max Pooling Tests + +class TestAdaptiveAvgMaxPool: + """Test adaptive_avgmax_pool module.""" + + def test_adaptive_avgmax_pool2d(self): + from timm.layers import adaptive_avgmax_pool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + out = adaptive_avgmax_pool2d(x, 1) + assert out.shape == (2, 64, 1, 1) + # Should be average of avg and max + expected = 0.5 * (x.mean(dim=(2, 3), keepdim=True) + x.amax(dim=(2, 3), keepdim=True)) + assert torch.allclose(out, expected) + + def test_select_adaptive_pool2d(self): + from timm.layers import select_adaptive_pool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + + out_avg = select_adaptive_pool2d(x, pool_type='avg', output_size=1) + assert out_avg.shape == (2, 64, 1, 1) + assert torch.allclose(out_avg, x.mean(dim=(2, 3), keepdim=True)) + + out_max = select_adaptive_pool2d(x, pool_type='max', output_size=1) + assert out_max.shape == (2, 64, 1, 1) + assert torch.allclose(out_max, x.amax(dim=(2, 3), keepdim=True)) + + def test_adaptive_avgmax_pool2d_module(self): + from timm.layers import AdaptiveAvgMaxPool2d + x = torch.randn(2, 64, 14, 14, device=torch_device) + pool = AdaptiveAvgMaxPool2d(output_size=1).to(torch_device) + out = pool(x) + assert out.shape == (2, 64, 1, 1) + + def 
test_select_adaptive_pool2d_module(self): + from timm.layers import SelectAdaptivePool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + + for pool_type in ['avg', 'max', 'avgmax', 'catavgmax']: + pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True).to(torch_device) + out = pool(x) + if pool_type == 'catavgmax': + assert out.shape == (2, 128) # concatenated + else: + assert out.shape == (2, 64) + + def test_select_adaptive_pool2d_fast(self): + from timm.layers import SelectAdaptivePool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + + for pool_type in ['fast', 'fastavg', 'fastmax', 'fastavgmax', 'fastcatavgmax']: + pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True).to(torch_device) + out = pool(x) + if 'cat' in pool_type: + assert out.shape == (2, 128) + else: + assert out.shape == (2, 64) + + +# Attention Pool Tests + +class TestAttentionPool: + """Test attention-based pooling layers.""" + + def test_attention_pool_latent_basic(self): + from timm.layers import AttentionPoolLatent + x = torch.randn(2, 49, 64, device=torch_device) + pool = AttentionPoolLatent(in_features=64, num_heads=4).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_attention_pool_latent_multi_latent(self): + from timm.layers import AttentionPoolLatent + x = torch.randn(2, 49, 64, device=torch_device) + pool = AttentionPoolLatent( + in_features=64, + num_heads=4, + latent_len=4, + pool_type='avg', + ).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_attention_pool2d_basic(self): + from timm.layers import AttentionPool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = AttentionPool2d(in_features=64, feat_size=7).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_attention_pool2d_different_feat_size(self): + from timm.layers import AttentionPool2d + # Test with different spatial sizes (requires pos_embed interpolation) + pool = AttentionPool2d(in_features=64, feat_size=7).to(torch_device) + for size in [7, 14]: + x = torch.randn(2, 64, size, size, device=torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_rot_attention_pool2d_basic(self): + from timm.layers import RotAttentionPool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = RotAttentionPool2d(in_features=64, ref_feat_size=7).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_rot_attention_pool2d_different_sizes(self): + from timm.layers import RotAttentionPool2d + pool = RotAttentionPool2d(in_features=64, ref_feat_size=7).to(torch_device) + for size in [7, 14, 10]: + x = torch.randn(2, 64, size, size, device=torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_rot_attention_pool2d_rope_types(self): + from timm.layers import RotAttentionPool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + for rope_type in ['base', 'cat', 'dinov3']: + pool = RotAttentionPool2d( + in_features=64, + ref_feat_size=7, + rope_type=rope_type, + ).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + +# LSE Pool Tests + +class TestLsePool: + """Test LogSumExp pooling layers.""" + + def test_lse_plus_2d_basic(self): + from timm.layers import LsePlus2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = LsePlus2d().to(torch_device) + out = pool(x) + assert out.shape == (2, 64, 1, 1) + + def test_lse_plus_2d_flatten(self): + from timm.layers import LsePlus2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = LsePlus2d(flatten=True).to(torch_device) + 
out = pool(x) + assert out.shape == (2, 64) + + def test_lse_plus_1d_basic(self): + from timm.layers import LsePlus1d + x = torch.randn(2, 49, 64, device=torch_device) + pool = LsePlus1d().to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_lse_high_r_approximates_max(self): + from timm.layers import LsePlus2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = LsePlus2d(r=100.0, r_learnable=False).to(torch_device) + out = pool(x) + out_max = x.amax(dim=(2, 3), keepdim=True) + assert torch.allclose(out, out_max, atol=0.1) + + def test_lse_low_r_approximates_avg(self): + from timm.layers import LsePlus2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = LsePlus2d(r=0.01, r_learnable=False).to(torch_device) + out = pool(x) + out_avg = x.mean(dim=(2, 3), keepdim=True) + assert torch.allclose(out, out_avg, atol=0.1) + + def test_lse_learnable_r_gradient(self): + from timm.layers import LsePlus2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = LsePlus2d(r=10.0, r_learnable=True).to(torch_device) + out = pool(x).sum() + out.backward() + assert pool.r.grad is not None + assert pool.r.grad.abs() > 0 + + +# SimPool Tests + +class TestSimPool: + """Test SimPool attention-based pooling layers.""" + + def test_simpool_2d_basic(self): + from timm.layers import SimPool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = SimPool2d(dim=64).to(torch_device) + out = pool(x) + assert out.shape == (2, 1, 64) + + def test_simpool_2d_flatten(self): + from timm.layers import SimPool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = SimPool2d(dim=64, flatten=True).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_simpool_1d_basic(self): + from timm.layers import SimPool1d + x = torch.randn(2, 49, 64, device=torch_device) + pool = SimPool1d(dim=64).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_simpool_multi_head(self): + from timm.layers import SimPool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + for num_heads in [1, 2, 4, 8]: + pool = SimPool2d(dim=64, num_heads=num_heads, flatten=True).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_simpool_with_gamma(self): + from timm.layers import SimPool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = SimPool2d(dim=64, gamma=2.0, flatten=True).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + assert not torch.isnan(out).any() + + def test_simpool_qk_norm(self): + from timm.layers import SimPool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = SimPool2d(dim=64, qk_norm=True, flatten=True).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + +# Slot Pool Tests + +class TestSlotPool: + """Test Slot Attention pooling layers.""" + + def test_slot_pool_basic(self): + from timm.layers import SlotPool + x = torch.randn(2, 49, 64, device=torch_device) + pool = SlotPool(dim=64).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_slot_pool_2d_basic(self): + from timm.layers import SlotPool2d + x = torch.randn(2, 64, 7, 7, device=torch_device) + pool = SlotPool2d(dim=64).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_slot_pool_multi_slot(self): + from timm.layers import SlotPool + x = torch.randn(2, 49, 64, device=torch_device) + for num_slots in [1, 2, 4, 8]: + pool = SlotPool(dim=64, num_slots=num_slots).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def 
test_slot_pool_iterations(self): + from timm.layers import SlotPool + x = torch.randn(2, 49, 64, device=torch_device) + for iters in [1, 2, 3, 5]: + pool = SlotPool(dim=64, iters=iters).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_slot_pool_pool_types(self): + from timm.layers import SlotPool + x = torch.randn(2, 49, 64, device=torch_device) + for pool_type in ['max', 'avg', 'first']: + pool = SlotPool(dim=64, num_slots=4, pool_type=pool_type).to(torch_device) + out = pool(x) + assert out.shape == (2, 64) + + def test_slot_pool_stochastic_train_mode(self): + from timm.layers import SlotPool + x = torch.randn(2, 49, 64, device=torch_device) + pool = SlotPool(dim=64, stochastic_init=True).to(torch_device) + pool.train() + out1 = pool(x) + out2 = pool(x) + # Should differ in train mode with stochastic init + assert not torch.allclose(out1, out2) + + def test_slot_pool_stochastic_eval_mode(self): + from timm.layers import SlotPool + x = torch.randn(2, 49, 64, device=torch_device) + pool = SlotPool(dim=64, stochastic_init=True).to(torch_device) + pool.eval() + out1 = pool(x) + out2 = pool(x) + # Should be deterministic in eval mode + assert torch.allclose(out1, out2) + + +# Common Tests (Gradient, JIT, dtype) + +class TestPoolingCommon: + """Common tests across all pooling layers.""" + + @pytest.mark.parametrize('pool_cls,kwargs,input_shape', [ + ('LsePlus2d', {}, (2, 64, 7, 7)), + ('LsePlus1d', {}, (2, 49, 64)), + ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)), + ('SimPool1d', {'dim': 64}, (2, 49, 64)), + ('SlotPool', {'dim': 64}, (2, 49, 64)), + ('SlotPool2d', {'dim': 64}, (2, 64, 7, 7)), + ('SelectAdaptivePool2d', {'pool_type': 'avg', 'flatten': True}, (2, 64, 7, 7)), + ('AttentionPoolLatent', {'in_features': 64, 'num_heads': 4}, (2, 49, 64)), + ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)), + ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)), + ]) + def test_gradient_flow(self, pool_cls, kwargs, input_shape): + import timm.layers as layers + x = torch.randn(*input_shape, device=torch_device, requires_grad=True) + pool = getattr(layers, pool_cls)(**kwargs).to(torch_device) + out = pool(x) + loss = out.sum() + loss.backward() + assert x.grad is not None + assert x.grad.abs().sum() > 0 + + @pytest.mark.parametrize('pool_cls,kwargs,input_shape', [ + ('LsePlus2d', {}, (2, 64, 7, 7)), + ('LsePlus1d', {}, (2, 49, 64)), + ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)), + ('SimPool1d', {'dim': 64}, (2, 49, 64)), + ('SlotPool', {'dim': 64, 'iters': 2}, (2, 49, 64)), + ('SlotPool2d', {'dim': 64, 'iters': 2}, (2, 64, 7, 7)), + ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)), + ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)), + ]) + def test_torchscript(self, pool_cls, kwargs, input_shape): + import timm.layers as layers + x = torch.randn(*input_shape, device=torch_device) + pool = getattr(layers, pool_cls)(**kwargs).to(torch_device) + pool.eval() + scripted = torch.jit.script(pool) + out_orig = pool(x) + out_script = scripted(x) + assert torch.allclose(out_orig, out_script, atol=1e-5) + + @pytest.mark.parametrize('pool_cls,kwargs,input_shape', [ + ('LsePlus2d', {'flatten': True}, (2, 64, 7, 7)), + ('LsePlus1d', {}, (2, 49, 64)), + ('SimPool2d', {'dim': 64, 'flatten': True}, (2, 64, 7, 7)), + ('SimPool1d', {'dim': 64}, (2, 49, 64)), + ('SlotPool', {'dim': 64}, (2, 49, 64)), + ('SlotPool2d', {'dim': 64}, (2, 64, 7, 7)), + ('AttentionPool2d', {'in_features': 64, 'feat_size': 
7}, (2, 64, 7, 7)), + ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)), + ]) + def test_eval_deterministic(self, pool_cls, kwargs, input_shape): + import timm.layers as layers + x = torch.randn(*input_shape, device=torch_device) + pool = getattr(layers, pool_cls)(**kwargs).to(torch_device) + pool.eval() + with torch.no_grad(): + out1 = pool(x) + out2 = pool(x) + assert torch.allclose(out1, out2) + + @pytest.mark.parametrize('pool_cls,kwargs,input_shape', [ + ('LsePlus2d', {'flatten': True}, (2, 64, 7, 7)), + ('SimPool2d', {'dim': 64, 'flatten': True}, (2, 64, 7, 7)), + ('SlotPool2d', {'dim': 64}, (2, 64, 7, 7)), + ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)), + ]) + def test_different_spatial_sizes(self, pool_cls, kwargs, input_shape): + import timm.layers as layers + B, C, _, _ = input_shape + pool = getattr(layers, pool_cls)(**kwargs).to(torch_device) + for H, W in [(7, 7), (14, 14), (1, 1), (3, 5)]: + x = torch.randn(B, C, H, W, device=torch_device) + out = pool(x) + assert out.shape[0] == B + assert out.shape[-1] == C + + +# BlurPool Tests + +class TestBlurPool: + """Test BlurPool anti-aliasing layer.""" + + def test_blur_pool_2d_basic(self): + from timm.layers import BlurPool2d + x = torch.randn(2, 64, 14, 14, device=torch_device) + pool = BlurPool2d(channels=64).to(torch_device) + out = pool(x) + assert out.shape == (2, 64, 7, 7) + + def test_blur_pool_2d_stride(self): + from timm.layers import BlurPool2d + x = torch.randn(2, 64, 28, 28, device=torch_device) + pool = BlurPool2d(channels=64, stride=4).to(torch_device) + out = pool(x) + assert out.shape == (2, 64, 8, 8) + + +# Pool1d Tests + +class TestPool1d: + """Test 1D pooling utilities.""" + + def test_global_pool_nlc(self): + from timm.layers import global_pool_nlc + x = torch.randn(2, 49, 64, device=torch_device) + + # By default, avg/max excludes first token (num_prefix_tokens=1) + out_avg = global_pool_nlc(x, pool_type='avg') + assert out_avg.shape == (2, 64) + assert torch.allclose(out_avg, x[:, 1:].mean(dim=1)) + + out_max = global_pool_nlc(x, pool_type='max') + assert out_max.shape == (2, 64) + assert torch.allclose(out_max, x[:, 1:].amax(dim=1)) + + out_first = global_pool_nlc(x, pool_type='token') + assert out_first.shape == (2, 64) + assert torch.allclose(out_first, x[:, 0]) + + # Test with reduce_include_prefix=True + out_avg_all = global_pool_nlc(x, pool_type='avg', reduce_include_prefix=True) + assert torch.allclose(out_avg_all, x.mean(dim=1)) diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index f40a1f77cb..e01abaadd3 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -18,7 +18,7 @@ from .attention import Attention, AttentionRope, maybe_add_mask from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2 from .attention_pool import AttentionPoolLatent -from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding +from .attention_pool2d import AttentionPool2d, RotAttentionPool2d from .blur_pool import BlurPool2d, create_aa from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead from .cond_conv2d import CondConv2d, get_condconv_initializer @@ -107,7 +107,9 @@ from .patch_dropout import PatchDropout, PatchDropoutWithIndices, patch_dropout_forward from .patch_embed import PatchEmbed, PatchEmbedWithSize, PatchEmbedInterpolator, resample_patch_embed from .pool1d import global_pool_nlc +from .other_pool import LsePlus2d, LsePlus1d, 
SimPool2d, SimPool1d from .pool2d_same import AvgPool2dSame, create_pool2d +from .slot_pool import SlotPool, SlotPool2d from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc from .pos_embed_rel import ( RelPosMlp, diff --git a/timm/layers/attention_pool2d.py b/timm/layers/attention_pool2d.py index cc26aecdf4..6a813630d3 100644 --- a/timm/layers/attention_pool2d.py +++ b/timm/layers/attention_pool2d.py @@ -15,7 +15,7 @@ from .config import use_fused_attn from .helpers import to_2tuple from .pos_embed import resample_abs_pos_embed -from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding +from .pos_embed_sincos import apply_rot_embed_cat, create_rope_embed from .weight_init import trunc_normal_ @@ -44,6 +44,7 @@ def __init__( pool_type: str = 'token', class_token: bool = False, drop_rate: float = 0., + rope_type: str = 'cat', device=None, dtype=None, ): @@ -65,6 +66,7 @@ def __init__( self.pool_type = pool_type.lower() self.scale = self.head_dim ** -0.5 self.fused_attn = use_fused_attn() + self.rope_type = rope_type if class_token: self.cls_token = nn.Parameter(torch.zeros(1, embed_dim, **dd)) @@ -80,7 +82,16 @@ def __init__( self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias, **dd) self.drop = nn.Dropout(drop_rate) self.proj = nn.Linear(embed_dim, self.out_features, **dd) - self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size, **dd) + + self.pos_embed = create_rope_embed( + rope_type=rope_type, + dim=embed_dim, + num_heads=num_heads, + in_pixels=False, + ref_feat_shape=ref_feat_size, + rotate_half=False, + **dd, + ) def init_weights(self, zero_init_last: bool = False): if self.qkv is None: @@ -129,9 +140,12 @@ def forward(self, x, pre_logits: bool = False): x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) q, k, v = x.unbind(0) - rse, rce = self.pos_embed.get_embed((H, W)) - q = torch.cat([q[:, :, :1, :], apply_rot_embed(q[:, :, 1:, :], rse, rce)], dim=2).type_as(v) - k = torch.cat([k[:, :, :1, :], apply_rot_embed(k[:, :, 1:, :], rse, rce)], dim=2).type_as(v) + rope = self.pos_embed.get_embed((H, W)) + if isinstance(rope, tuple): + # RotaryEmbedding returns (sin, cos) tuple - concatenate for apply_rot_embed_cat + rope = torch.cat(rope, dim=-1) + q = torch.cat([q[:, :, :1, :], apply_rot_embed_cat(q[:, :, 1:, :], rope)], dim=2).type_as(v) + k = torch.cat([k[:, :, :1, :], apply_rot_embed_cat(k[:, :, 1:, :], rope)], dim=2).type_as(v) if self.fused_attn: x = nn.functional.scaled_dot_product_attention(q, k, v) diff --git a/timm/layers/other_pool.py b/timm/layers/other_pool.py new file mode 100644 index 0000000000..716f4ca638 --- /dev/null +++ b/timm/layers/other_pool.py @@ -0,0 +1,290 @@ +""" Non-Local Attention Pooling Layers + +A collection of global pooling layers that go beyond simple avg/max pooling. + +LSEPool - LogSumExp pooling, a smooth approximation between avg and max pooling +SimPool - Attention-based pooling from 'Keep It SimPool' (ICCV 2023) + +Based on implementations from: +* LSE Pooling: custom implementation by Bill Psomas +* SimPool: https://arxiv.org/abs/2309.06891 - 'Keep It SimPool: Who Said Supervised Transformers + Suffer from Attention Deficit?' by Bill Psomas et al. 
+ +Hacked together by / Copyright 2024 Ross Wightman, original code by Bill Psomas +""" +from typing import Optional, Type, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .config import use_fused_attn + + +class LsePlus2d(nn.Module): + """LogSumExp (LSE) Pooling for 2D inputs. + + A smooth approximation to max pooling that provides a learnable interpolation between + average and max pooling. When r is large, LSE approaches max pooling; when r is small, + it approaches average pooling. + + Implements: (1/r) * log((1/n) * sum(exp(r * (x - x_max)))) + x_max + + The x_max subtraction provides numerical stability. + """ + + def __init__( + self, + r: float = 10.0, + r_learnable: bool = True, + flatten: bool = False, + device=None, + dtype=None, + ): + """ + Args: + r: Initial value of the pooling parameter. Higher = closer to max pooling. + r_learnable: If True, r is a learnable parameter. + flatten: If True, flatten spatial dims in output. + """ + super().__init__() + if r_learnable: + self.r = nn.Parameter(torch.tensor(r, device=device, dtype=dtype)) + else: + self.register_buffer('r', torch.tensor(r, device=device, dtype=dtype)) + self.flatten = flatten + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_max = F.adaptive_max_pool2d(x, 1) + exp_x = torch.exp(self.r * (x - x_max)) + sum_exp = exp_x.mean(dim=(2, 3), keepdim=True) + out = x_max + (1.0 / self.r) * torch.log(sum_exp) + if self.flatten: + out = out.flatten(1) + return out + + +class LsePlus1d(nn.Module): + """LogSumExp (LSE) Pooling for sequence (NLC) inputs. + + A smooth approximation to max pooling that provides a learnable interpolation between + average and max pooling. When r is large, LSE approaches max pooling; when r is small, + it approaches average pooling. + """ + + def __init__( + self, + r: float = 10.0, + r_learnable: bool = True, + device=None, + dtype=None, + ): + """ + Args: + r: Initial value of the pooling parameter. Higher = closer to max pooling. + r_learnable: If True, r is a learnable parameter. + """ + super().__init__() + if r_learnable: + self.r = nn.Parameter(torch.tensor(r, device=device, dtype=dtype)) + else: + self.register_buffer('r', torch.tensor(r, device=device, dtype=dtype)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: (B, N, C) + x_max = x.max(dim=1, keepdim=True).values + exp_x = torch.exp(self.r * (x - x_max)) + sum_exp = exp_x.mean(dim=1, keepdim=True) + out = x_max + (1.0 / self.r) * torch.log(sum_exp) + return out.squeeze(1) # (B, C) + + +class SimPool2d(nn.Module): + """SimPool: Simple Attention-Based Pooling for 2D (NCHW) inputs. + + From 'Keep It SimPool: Who Said Supervised Transformers Suffer from Attention Deficit?' + https://arxiv.org/abs/2309.06891 + + Uses GAP as query initialization and applies cross-attention between the GAP query + and spatial features to produce a weighted pooled representation. + """ + fused_attn: torch.jit.Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 1, + qkv_bias: bool = False, + qk_norm: bool = False, + gamma: Optional[float] = None, + norm_layer: Optional[Type[nn.Module]] = None, + flatten: bool = False, + device=None, + dtype=None, + ): + """ + Args: + dim: Input feature dimension (number of channels). + num_heads: Number of attention heads. + qkv_bias: If True, add bias to query and key projections. + qk_norm: If True, apply normalization to queries and keys. + gamma: If provided, apply power normalization to values with this exponent. 
+ norm_layer: Normalization layer for patches and optionally qk_norm. + flatten: If True, flatten output to (B, C). + """ + super().__init__() + dd = {'device': device, 'dtype': dtype} + assert dim % num_heads == 0, 'dim must be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.gamma = gamma + self.flatten = flatten + self.fused_attn = use_fused_attn() + + norm_layer = norm_layer or nn.LayerNorm + self.norm = norm_layer(dim, **dd) + self.q = nn.Linear(dim, dim, bias=qkv_bias, **dd) + self.k = nn.Linear(dim, dim, bias=qkv_bias, **dd) + if qk_norm: + self.q_norm = norm_layer(self.head_dim, **dd) + self.k_norm = norm_layer(self.head_dim, **dd) + else: + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, C, H, W = x.shape + N = H * W + + # Reshape to (B, N, C) for attention + x = x.flatten(2).transpose(1, 2) # (B, N, C) + + # GAP as query initialization + q = x.mean(dim=1, keepdim=True) # (B, 1, C) + + # Normalize patches for keys and values + x_norm = self.norm(x) + + # Project query and keys + q = self.q(q).reshape(B, 1, self.num_heads, self.head_dim).transpose(1, 2) + k = self.k(x_norm).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) + v = x_norm.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) + + q, k = self.q_norm(q), self.k_norm(k) + + if self.gamma is not None: + # Power normalization on values + v_min = v.amin(dim=-2, keepdim=True) + v_shifted = v - v_min + 1e-6 + if self.fused_attn: + attn_out = F.scaled_dot_product_attention(q, k, v_shifted.pow(self.gamma)) + else: + attn = (q * self.scale) @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn_out = attn @ v_shifted.pow(self.gamma) + out = attn_out.pow(1.0 / self.gamma) + else: + if self.fused_attn: + out = F.scaled_dot_product_attention(q, k, v) + else: + attn = (q * self.scale) @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + out = attn @ v + + # (B, num_heads, 1, head_dim) -> (B, C) or (B, 1, C) + out = out.transpose(1, 2).reshape(B, 1, C) + if self.flatten: + out = out.squeeze(1) + return out + + +class SimPool1d(nn.Module): + """SimPool: Simple Attention-Based Pooling for sequence (NLC) inputs. + + From 'Keep It SimPool: Who Said Supervised Transformers Suffer from Attention Deficit?' + https://arxiv.org/abs/2309.06891 + + Uses GAP as query initialization and applies cross-attention between the GAP query + and sequence tokens to produce a weighted pooled representation. + """ + fused_attn: torch.jit.Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 1, + qkv_bias: bool = False, + qk_norm: bool = False, + gamma: Optional[float] = None, + norm_layer: Optional[Type[nn.Module]] = None, + device=None, + dtype=None, + ): + """ + Args: + dim: Input feature dimension. + num_heads: Number of attention heads. + qkv_bias: If True, add bias to query and key projections. + qk_norm: If True, apply normalization to queries and keys. + gamma: If provided, apply power normalization to values with this exponent. + norm_layer: Normalization layer for tokens and optionally qk_norm. 
+ """ + super().__init__() + dd = {'device': device, 'dtype': dtype} + assert dim % num_heads == 0, 'dim must be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.gamma = gamma + self.fused_attn = use_fused_attn() + + norm_layer = norm_layer or nn.LayerNorm + self.norm = norm_layer(dim, **dd) + self.q = nn.Linear(dim, dim, bias=qkv_bias, **dd) + self.k = nn.Linear(dim, dim, bias=qkv_bias, **dd) + if qk_norm: + self.q_norm = norm_layer(self.head_dim, **dd) + self.k_norm = norm_layer(self.head_dim, **dd) + else: + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + + # GAP as query initialization + q = x.mean(dim=1, keepdim=True) # (B, 1, C) + + # Normalize tokens for keys and values + x_norm = self.norm(x) + + # Project query and keys + q = self.q(q).reshape(B, 1, self.num_heads, self.head_dim).transpose(1, 2) + k = self.k(x_norm).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) + v = x_norm.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) + + q, k = self.q_norm(q), self.k_norm(k) + + if self.gamma is not None: + # Power normalization on values + v_min = v.amin(dim=-2, keepdim=True) + v_shifted = v - v_min + 1e-6 + if self.fused_attn: + attn_out = F.scaled_dot_product_attention(q, k, v_shifted.pow(self.gamma)) + else: + attn = (q * self.scale) @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn_out = attn @ v_shifted.pow(self.gamma) + out = attn_out.pow(1.0 / self.gamma) + else: + if self.fused_attn: + out = F.scaled_dot_product_attention(q, k, v) + else: + attn = (q * self.scale) @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + out = attn @ v + + # (B, num_heads, 1, head_dim) -> (B, C) + out = out.transpose(1, 2).reshape(B, C) + return out diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index 9d314e3a26..dde0d5e1a6 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -1164,13 +1164,19 @@ def create_rope_embed( Rotary embedding module """ if rope_type == 'base': + kwargs.pop('rotate_half', None) # doesn't support return RotaryEmbedding(dim=dim // num_heads, **kwargs) elif rope_type == 'cat': + kwargs.pop('rotate_half', None) # doesn't support return RotaryEmbeddingCat(dim=dim // num_heads, **kwargs) elif rope_type == 'mixed': # Mixed requires depth parameter, generates differing embeddings per layer and head + kwargs.pop('in_pixels', None) # doesn't support + kwargs.pop('ref_feat_shape', None) # doesn't support return RotaryEmbeddingMixed(dim=dim, num_heads=num_heads, **kwargs) elif rope_type == 'dinov3': + kwargs.pop('in_pixels', None) # doesn't support + kwargs.pop('ref_feat_shape', None) # doesn't support return RotaryEmbeddingDinoV3(dim=dim // num_heads, **kwargs) else: raise ValueError(f"Unknown RoPE type: {rope_type}") diff --git a/timm/layers/slot_pool.py b/timm/layers/slot_pool.py new file mode 100644 index 0000000000..ddb30c5fe5 --- /dev/null +++ b/timm/layers/slot_pool.py @@ -0,0 +1,254 @@ +""" Slot Attention Pooling + +Slot Attention mechanism adapted for use as a pooling layer. + +Based on 'Object-Centric Learning with Slot Attention' by Locatello et al. 
+https://arxiv.org/abs/2006.15055 + +Original implementation reference: +https://github.com/lucidrains/slot-attention (MIT License) + +Adapted for timm by Ross Wightman, original PR code by Bill Psomas +""" +from typing import Optional, Type + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .config import use_fused_attn +from .mlp import Mlp + + +class SlotPool(nn.Module): + """Slot Attention pooling module. + + Unlike standard attention pooling, Slot Attention uses iterative refinement + with competition between slots. The softmax is applied over slots (not keys), + causing slots to compete for explaining input locations. + + This creates a soft clustering effect where each slot specializes to different + parts of the input, useful for object-centric representations. + + For standard pooling use cases, set num_slots=1 and iters=1 to get behavior + closer to AttentionPoolLatent but with the slot attention update mechanism. + """ + fused_attn: torch.jit.Final[bool] + + def __init__( + self, + dim: int, + num_slots: int = 1, + iters: int = 3, + hidden_dim: Optional[int] = None, + mlp_ratio: float = 2.0, + qkv_bias: bool = True, + stochastic_init: bool = False, + pool_type: str = 'max', + eps: float = 1e-8, + norm_layer: Optional[Type[nn.Module]] = None, + act_layer: Optional[Type[nn.Module]] = nn.GELU, + device=None, + dtype=None, + ): + """ + Args: + dim: Input feature dimension. + num_slots: Number of slot vectors. For pooling, 1 is typical. + iters: Number of iterative refinement steps. + hidden_dim: Hidden dimension for slot MLP. Defaults to dim * mlp_ratio. + mlp_ratio: Ratio for hidden dim if hidden_dim not specified. + qkv_bias: If True, add bias to q, k, v projections. + stochastic_init: If True, initialize slots with learned mu + learned sigma * noise. + If False, slots are deterministically initialized from learned parameters. + pool_type: How to aggregate multiple slots - 'max', 'avg', or 'first'. + eps: Small constant for numerical stability in attention normalization. + norm_layer: Normalization layer class. + act_layer: Activation layer class for MLP. 
+ """ + super().__init__() + dd = {'device': device, 'dtype': dtype} + self.num_slots = num_slots + self.iters = iters + self.eps = eps + self.scale = dim ** -0.5 + self.stochastic_init = stochastic_init + self.pool_type = pool_type + self.fused_attn = use_fused_attn() + + norm_layer = norm_layer or nn.LayerNorm + + # Slot initialization parameters + self.slots_mu = nn.Parameter(torch.zeros(1, 1, dim, **dd)) + self.slots_log_sigma = nn.Parameter(torch.zeros(1, 1, dim, **dd)) + + # Projections - separate q, k, v (no fused qkv since q comes from slots, kv from input) + self.q = nn.Linear(dim, dim, bias=qkv_bias, **dd) + self.k = nn.Linear(dim, dim, bias=qkv_bias, **dd) + self.v = nn.Linear(dim, dim, bias=qkv_bias, **dd) + + # GRU for slot updates + self.gru = nn.GRUCell(dim, dim, **dd) + + # MLP after GRU update + hidden_dim = hidden_dim or int(dim * mlp_ratio) + self.norm_mlp = norm_layer(dim, **dd) + self.mlp = Mlp(dim, hidden_dim, act_layer=act_layer, **dd) + + # Input normalization + self.norm_input = norm_layer(dim, **dd) + self.norm_slots = norm_layer(dim, **dd) + + self._init_weights() + + def _init_weights(self): + nn.init.xavier_uniform_(self.slots_mu) + if self.stochastic_init: + nn.init.xavier_uniform_(self.slots_log_sigma) + + def _init_slots(self, B: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + """Initialize slot vectors.""" + mu = self.slots_mu.expand(B, self.num_slots, -1) + if self.stochastic_init and self.training: + sigma = self.slots_log_sigma.exp().expand(B, self.num_slots, -1) + slots = mu + sigma * torch.randn_like(mu) + else: + # Deterministic: just use mu repeated for each slot + # Add small learned perturbation per slot to break symmetry + slots = mu + return slots + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: Input tensor of shape (B, N, C) where N is sequence length. + + Returns: + Pooled output of shape (B, C). 
+ """ + B, N, C = x.shape + device, dtype = x.device, x.dtype + + # Initialize slots + slots = self._init_slots(B, device, dtype) + + # Normalize input and compute k, v (constant across iterations) + x = self.norm_input(x) + k = self.k(x) # (B, N, C) + v = self.v(x) # (B, N, C) + + # Iterative refinement + for _ in range(self.iters): + slots_prev = slots + + # Normalize slots and compute queries + slots = self.norm_slots(slots) + q = self.q(slots) # (B, num_slots, C) + + # Compute attention: (B, num_slots, N) + # Note: we do NOT use F.sdpa here because we need softmax over slots (dim=1), + # not over keys (dim=-1) as standard attention does + attn = (q @ k.transpose(-2, -1)) * self.scale # (B, num_slots, N) + + # Softmax over SLOTS (not keys) - this is the key difference from standard attention + # Each input location decides which slot to route to + attn = attn.softmax(dim=1) # normalize over slots + attn = attn + self.eps + + # Weighted mean over slots (normalize so weights sum to 1 per slot) + attn = attn / attn.sum(dim=-1, keepdim=True) # (B, num_slots, N) + + # Aggregate values into slots + updates = attn @ v # (B, num_slots, C) + + # GRU update + slots = self.gru( + updates.reshape(B * self.num_slots, C), + slots_prev.reshape(B * self.num_slots, C), + ) + slots = slots.reshape(B, self.num_slots, C) + + # MLP residual + slots = slots + self.mlp(self.norm_mlp(slots)) + + # Pool slots to single vector + if self.pool_type == 'max': + out = slots.max(dim=1).values + elif self.pool_type == 'avg': + out = slots.mean(dim=1) + elif self.pool_type == 'first': + out = slots[:, 0] + else: + raise ValueError(f"Unknown pool_type: {self.pool_type}") + + return out + + +class SlotPool2d(nn.Module): + """Slot Attention pooling for 2D (NCHW) inputs. + + Convenience wrapper that handles NCHW -> NLC conversion. + """ + def __init__( + self, + dim: int, + num_slots: int = 1, + iters: int = 3, + hidden_dim: Optional[int] = None, + mlp_ratio: float = 2.0, + qkv_bias: bool = True, + stochastic_init: bool = False, + pool_type: str = 'max', + eps: float = 1e-8, + norm_layer: Optional[Type[nn.Module]] = None, + act_layer: Optional[Type[nn.Module]] = nn.GELU, + flatten: bool = True, + device=None, + dtype=None, + ): + """ + Args: + dim: Input feature dimension (channels). + num_slots: Number of slot vectors. + iters: Number of iterative refinement steps. + hidden_dim: Hidden dimension for slot MLP. + mlp_ratio: Ratio for hidden dim if hidden_dim not specified. + qkv_bias: If True, add bias to q, k, v projections. + stochastic_init: If True, use stochastic slot initialization during training. + pool_type: How to aggregate multiple slots - 'max', 'avg', or 'first'. + eps: Small constant for numerical stability. + norm_layer: Normalization layer class. + act_layer: Activation layer class for MLP. + flatten: If True, output shape is (B, C). If False, (B, 1, C). + """ + super().__init__() + self.flatten = flatten + self.pool = SlotPool( + dim=dim, + num_slots=num_slots, + iters=iters, + hidden_dim=hidden_dim, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + stochastic_init=stochastic_init, + pool_type=pool_type, + eps=eps, + norm_layer=norm_layer, + act_layer=act_layer, + device=device, + dtype=dtype, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: Input tensor of shape (B, C, H, W). + + Returns: + Pooled output of shape (B, C) if flatten=True, else (B, 1, C). 
+ """ + x = x.flatten(2).transpose(1, 2) # (B, H*W, C) + out = self.pool(x) # (B, C) + if not self.flatten: + out = out.unsqueeze(1) + return out From 480a2b29c62a0a584119fd1d8141698547466071 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 3 Dec 2025 19:13:37 -0800 Subject: [PATCH 3/4] Remove SlotPool, was expensive and hard to work with. Tweak flatten for simpool2d and lse --- tests/test_layers_pool.py | 105 +++------------- timm/layers/__init__.py | 1 - timm/layers/other_pool.py | 10 +- timm/layers/slot_pool.py | 254 -------------------------------------- 4 files changed, 17 insertions(+), 353 deletions(-) delete mode 100644 timm/layers/slot_pool.py diff --git a/tests/test_layers_pool.py b/tests/test_layers_pool.py index b8f5908454..a282d74340 100644 --- a/tests/test_layers_pool.py +++ b/tests/test_layers_pool.py @@ -148,14 +148,15 @@ def test_lse_plus_2d_basic(self): x = torch.randn(2, 64, 7, 7, device=torch_device) pool = LsePlus2d().to(torch_device) out = pool(x) - assert out.shape == (2, 64, 1, 1) + # Default is flatten=True + assert out.shape == (2, 64) - def test_lse_plus_2d_flatten(self): + def test_lse_plus_2d_no_flatten(self): from timm.layers import LsePlus2d x = torch.randn(2, 64, 7, 7, device=torch_device) - pool = LsePlus2d(flatten=True).to(torch_device) + pool = LsePlus2d(flatten=False).to(torch_device) out = pool(x) - assert out.shape == (2, 64) + assert out.shape == (2, 64, 1, 1) def test_lse_plus_1d_basic(self): from timm.layers import LsePlus1d @@ -169,7 +170,7 @@ def test_lse_high_r_approximates_max(self): x = torch.randn(2, 64, 7, 7, device=torch_device) pool = LsePlus2d(r=100.0, r_learnable=False).to(torch_device) out = pool(x) - out_max = x.amax(dim=(2, 3), keepdim=True) + out_max = x.amax(dim=(2, 3)) assert torch.allclose(out, out_max, atol=0.1) def test_lse_low_r_approximates_avg(self): @@ -177,7 +178,7 @@ def test_lse_low_r_approximates_avg(self): x = torch.randn(2, 64, 7, 7, device=torch_device) pool = LsePlus2d(r=0.01, r_learnable=False).to(torch_device) out = pool(x) - out_avg = x.mean(dim=(2, 3), keepdim=True) + out_avg = x.mean(dim=(2, 3)) assert torch.allclose(out, out_avg, atol=0.1) def test_lse_learnable_r_gradient(self): @@ -200,13 +201,6 @@ def test_simpool_2d_basic(self): x = torch.randn(2, 64, 7, 7, device=torch_device) pool = SimPool2d(dim=64).to(torch_device) out = pool(x) - assert out.shape == (2, 1, 64) - - def test_simpool_2d_flatten(self): - from timm.layers import SimPool2d - x = torch.randn(2, 64, 7, 7, device=torch_device) - pool = SimPool2d(dim=64, flatten=True).to(torch_device) - out = pool(x) assert out.shape == (2, 64) def test_simpool_1d_basic(self): @@ -220,14 +214,14 @@ def test_simpool_multi_head(self): from timm.layers import SimPool2d x = torch.randn(2, 64, 7, 7, device=torch_device) for num_heads in [1, 2, 4, 8]: - pool = SimPool2d(dim=64, num_heads=num_heads, flatten=True).to(torch_device) + pool = SimPool2d(dim=64, num_heads=num_heads).to(torch_device) out = pool(x) assert out.shape == (2, 64) def test_simpool_with_gamma(self): from timm.layers import SimPool2d x = torch.randn(2, 64, 7, 7, device=torch_device) - pool = SimPool2d(dim=64, gamma=2.0, flatten=True).to(torch_device) + pool = SimPool2d(dim=64, gamma=2.0).to(torch_device) out = pool(x) assert out.shape == (2, 64) assert not torch.isnan(out).any() @@ -235,74 +229,10 @@ def test_simpool_with_gamma(self): def test_simpool_qk_norm(self): from timm.layers import SimPool2d x = torch.randn(2, 64, 7, 7, device=torch_device) - pool = SimPool2d(dim=64, qk_norm=True, 
flatten=True).to(torch_device) - out = pool(x) - assert out.shape == (2, 64) - - -# Slot Pool Tests - -class TestSlotPool: - """Test Slot Attention pooling layers.""" - - def test_slot_pool_basic(self): - from timm.layers import SlotPool - x = torch.randn(2, 49, 64, device=torch_device) - pool = SlotPool(dim=64).to(torch_device) - out = pool(x) - assert out.shape == (2, 64) - - def test_slot_pool_2d_basic(self): - from timm.layers import SlotPool2d - x = torch.randn(2, 64, 7, 7, device=torch_device) - pool = SlotPool2d(dim=64).to(torch_device) + pool = SimPool2d(dim=64, qk_norm=True).to(torch_device) out = pool(x) assert out.shape == (2, 64) - def test_slot_pool_multi_slot(self): - from timm.layers import SlotPool - x = torch.randn(2, 49, 64, device=torch_device) - for num_slots in [1, 2, 4, 8]: - pool = SlotPool(dim=64, num_slots=num_slots).to(torch_device) - out = pool(x) - assert out.shape == (2, 64) - - def test_slot_pool_iterations(self): - from timm.layers import SlotPool - x = torch.randn(2, 49, 64, device=torch_device) - for iters in [1, 2, 3, 5]: - pool = SlotPool(dim=64, iters=iters).to(torch_device) - out = pool(x) - assert out.shape == (2, 64) - - def test_slot_pool_pool_types(self): - from timm.layers import SlotPool - x = torch.randn(2, 49, 64, device=torch_device) - for pool_type in ['max', 'avg', 'first']: - pool = SlotPool(dim=64, num_slots=4, pool_type=pool_type).to(torch_device) - out = pool(x) - assert out.shape == (2, 64) - - def test_slot_pool_stochastic_train_mode(self): - from timm.layers import SlotPool - x = torch.randn(2, 49, 64, device=torch_device) - pool = SlotPool(dim=64, stochastic_init=True).to(torch_device) - pool.train() - out1 = pool(x) - out2 = pool(x) - # Should differ in train mode with stochastic init - assert not torch.allclose(out1, out2) - - def test_slot_pool_stochastic_eval_mode(self): - from timm.layers import SlotPool - x = torch.randn(2, 49, 64, device=torch_device) - pool = SlotPool(dim=64, stochastic_init=True).to(torch_device) - pool.eval() - out1 = pool(x) - out2 = pool(x) - # Should be deterministic in eval mode - assert torch.allclose(out1, out2) - # Common Tests (Gradient, JIT, dtype) @@ -314,8 +244,6 @@ class TestPoolingCommon: ('LsePlus1d', {}, (2, 49, 64)), ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)), ('SimPool1d', {'dim': 64}, (2, 49, 64)), - ('SlotPool', {'dim': 64}, (2, 49, 64)), - ('SlotPool2d', {'dim': 64}, (2, 64, 7, 7)), ('SelectAdaptivePool2d', {'pool_type': 'avg', 'flatten': True}, (2, 64, 7, 7)), ('AttentionPoolLatent', {'in_features': 64, 'num_heads': 4}, (2, 49, 64)), ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)), @@ -336,8 +264,6 @@ def test_gradient_flow(self, pool_cls, kwargs, input_shape): ('LsePlus1d', {}, (2, 49, 64)), ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)), ('SimPool1d', {'dim': 64}, (2, 49, 64)), - ('SlotPool', {'dim': 64, 'iters': 2}, (2, 49, 64)), - ('SlotPool2d', {'dim': 64, 'iters': 2}, (2, 64, 7, 7)), ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)), ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)), ]) @@ -352,12 +278,10 @@ def test_torchscript(self, pool_cls, kwargs, input_shape): assert torch.allclose(out_orig, out_script, atol=1e-5) @pytest.mark.parametrize('pool_cls,kwargs,input_shape', [ - ('LsePlus2d', {'flatten': True}, (2, 64, 7, 7)), + ('LsePlus2d', {}, (2, 64, 7, 7)), ('LsePlus1d', {}, (2, 49, 64)), - ('SimPool2d', {'dim': 64, 'flatten': True}, (2, 64, 7, 7)), + ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)), ('SimPool1d', 
{'dim': 64}, (2, 49, 64)), - ('SlotPool', {'dim': 64}, (2, 49, 64)), - ('SlotPool2d', {'dim': 64}, (2, 64, 7, 7)), ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)), ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)), ]) @@ -372,9 +296,8 @@ def test_eval_deterministic(self, pool_cls, kwargs, input_shape): assert torch.allclose(out1, out2) @pytest.mark.parametrize('pool_cls,kwargs,input_shape', [ - ('LsePlus2d', {'flatten': True}, (2, 64, 7, 7)), - ('SimPool2d', {'dim': 64, 'flatten': True}, (2, 64, 7, 7)), - ('SlotPool2d', {'dim': 64}, (2, 64, 7, 7)), + ('LsePlus2d', {}, (2, 64, 7, 7)), + ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)), ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)), ]) def test_different_spatial_sizes(self, pool_cls, kwargs, input_shape): diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index e01abaadd3..f9148db665 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -109,7 +109,6 @@ from .pool1d import global_pool_nlc from .other_pool import LsePlus2d, LsePlus1d, SimPool2d, SimPool1d from .pool2d_same import AvgPool2dSame, create_pool2d -from .slot_pool import SlotPool, SlotPool2d from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc from .pos_embed_rel import ( RelPosMlp, diff --git a/timm/layers/other_pool.py b/timm/layers/other_pool.py index 716f4ca638..cced920fbf 100644 --- a/timm/layers/other_pool.py +++ b/timm/layers/other_pool.py @@ -37,7 +37,7 @@ def __init__( self, r: float = 10.0, r_learnable: bool = True, - flatten: bool = False, + flatten: bool = True, device=None, dtype=None, ): @@ -118,7 +118,6 @@ def __init__( qk_norm: bool = False, gamma: Optional[float] = None, norm_layer: Optional[Type[nn.Module]] = None, - flatten: bool = False, device=None, dtype=None, ): @@ -139,7 +138,6 @@ def __init__( self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 self.gamma = gamma - self.flatten = flatten self.fused_attn = use_fused_attn() norm_layer = norm_layer or nn.LayerNorm @@ -192,10 +190,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: attn = attn.softmax(dim=-1) out = attn @ v - # (B, num_heads, 1, head_dim) -> (B, C) or (B, 1, C) - out = out.transpose(1, 2).reshape(B, 1, C) - if self.flatten: - out = out.squeeze(1) + # (B, num_heads, 1, head_dim) -> (B, C) or (B, C) + out = out.transpose(1, 2).reshape(B, C) return out diff --git a/timm/layers/slot_pool.py b/timm/layers/slot_pool.py deleted file mode 100644 index ddb30c5fe5..0000000000 --- a/timm/layers/slot_pool.py +++ /dev/null @@ -1,254 +0,0 @@ -""" Slot Attention Pooling - -Slot Attention mechanism adapted for use as a pooling layer. - -Based on 'Object-Centric Learning with Slot Attention' by Locatello et al. -https://arxiv.org/abs/2006.15055 - -Original implementation reference: -https://github.com/lucidrains/slot-attention (MIT License) - -Adapted for timm by Ross Wightman, original PR code by Bill Psomas -""" -from typing import Optional, Type - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .config import use_fused_attn -from .mlp import Mlp - - -class SlotPool(nn.Module): - """Slot Attention pooling module. - - Unlike standard attention pooling, Slot Attention uses iterative refinement - with competition between slots. The softmax is applied over slots (not keys), - causing slots to compete for explaining input locations. 
- - This creates a soft clustering effect where each slot specializes to different - parts of the input, useful for object-centric representations. - - For standard pooling use cases, set num_slots=1 and iters=1 to get behavior - closer to AttentionPoolLatent but with the slot attention update mechanism. - """ - fused_attn: torch.jit.Final[bool] - - def __init__( - self, - dim: int, - num_slots: int = 1, - iters: int = 3, - hidden_dim: Optional[int] = None, - mlp_ratio: float = 2.0, - qkv_bias: bool = True, - stochastic_init: bool = False, - pool_type: str = 'max', - eps: float = 1e-8, - norm_layer: Optional[Type[nn.Module]] = None, - act_layer: Optional[Type[nn.Module]] = nn.GELU, - device=None, - dtype=None, - ): - """ - Args: - dim: Input feature dimension. - num_slots: Number of slot vectors. For pooling, 1 is typical. - iters: Number of iterative refinement steps. - hidden_dim: Hidden dimension for slot MLP. Defaults to dim * mlp_ratio. - mlp_ratio: Ratio for hidden dim if hidden_dim not specified. - qkv_bias: If True, add bias to q, k, v projections. - stochastic_init: If True, initialize slots with learned mu + learned sigma * noise. - If False, slots are deterministically initialized from learned parameters. - pool_type: How to aggregate multiple slots - 'max', 'avg', or 'first'. - eps: Small constant for numerical stability in attention normalization. - norm_layer: Normalization layer class. - act_layer: Activation layer class for MLP. - """ - super().__init__() - dd = {'device': device, 'dtype': dtype} - self.num_slots = num_slots - self.iters = iters - self.eps = eps - self.scale = dim ** -0.5 - self.stochastic_init = stochastic_init - self.pool_type = pool_type - self.fused_attn = use_fused_attn() - - norm_layer = norm_layer or nn.LayerNorm - - # Slot initialization parameters - self.slots_mu = nn.Parameter(torch.zeros(1, 1, dim, **dd)) - self.slots_log_sigma = nn.Parameter(torch.zeros(1, 1, dim, **dd)) - - # Projections - separate q, k, v (no fused qkv since q comes from slots, kv from input) - self.q = nn.Linear(dim, dim, bias=qkv_bias, **dd) - self.k = nn.Linear(dim, dim, bias=qkv_bias, **dd) - self.v = nn.Linear(dim, dim, bias=qkv_bias, **dd) - - # GRU for slot updates - self.gru = nn.GRUCell(dim, dim, **dd) - - # MLP after GRU update - hidden_dim = hidden_dim or int(dim * mlp_ratio) - self.norm_mlp = norm_layer(dim, **dd) - self.mlp = Mlp(dim, hidden_dim, act_layer=act_layer, **dd) - - # Input normalization - self.norm_input = norm_layer(dim, **dd) - self.norm_slots = norm_layer(dim, **dd) - - self._init_weights() - - def _init_weights(self): - nn.init.xavier_uniform_(self.slots_mu) - if self.stochastic_init: - nn.init.xavier_uniform_(self.slots_log_sigma) - - def _init_slots(self, B: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: - """Initialize slot vectors.""" - mu = self.slots_mu.expand(B, self.num_slots, -1) - if self.stochastic_init and self.training: - sigma = self.slots_log_sigma.exp().expand(B, self.num_slots, -1) - slots = mu + sigma * torch.randn_like(mu) - else: - # Deterministic: just use mu repeated for each slot - # Add small learned perturbation per slot to break symmetry - slots = mu - return slots - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: Input tensor of shape (B, N, C) where N is sequence length. - - Returns: - Pooled output of shape (B, C). 
- """ - B, N, C = x.shape - device, dtype = x.device, x.dtype - - # Initialize slots - slots = self._init_slots(B, device, dtype) - - # Normalize input and compute k, v (constant across iterations) - x = self.norm_input(x) - k = self.k(x) # (B, N, C) - v = self.v(x) # (B, N, C) - - # Iterative refinement - for _ in range(self.iters): - slots_prev = slots - - # Normalize slots and compute queries - slots = self.norm_slots(slots) - q = self.q(slots) # (B, num_slots, C) - - # Compute attention: (B, num_slots, N) - # Note: we do NOT use F.sdpa here because we need softmax over slots (dim=1), - # not over keys (dim=-1) as standard attention does - attn = (q @ k.transpose(-2, -1)) * self.scale # (B, num_slots, N) - - # Softmax over SLOTS (not keys) - this is the key difference from standard attention - # Each input location decides which slot to route to - attn = attn.softmax(dim=1) # normalize over slots - attn = attn + self.eps - - # Weighted mean over slots (normalize so weights sum to 1 per slot) - attn = attn / attn.sum(dim=-1, keepdim=True) # (B, num_slots, N) - - # Aggregate values into slots - updates = attn @ v # (B, num_slots, C) - - # GRU update - slots = self.gru( - updates.reshape(B * self.num_slots, C), - slots_prev.reshape(B * self.num_slots, C), - ) - slots = slots.reshape(B, self.num_slots, C) - - # MLP residual - slots = slots + self.mlp(self.norm_mlp(slots)) - - # Pool slots to single vector - if self.pool_type == 'max': - out = slots.max(dim=1).values - elif self.pool_type == 'avg': - out = slots.mean(dim=1) - elif self.pool_type == 'first': - out = slots[:, 0] - else: - raise ValueError(f"Unknown pool_type: {self.pool_type}") - - return out - - -class SlotPool2d(nn.Module): - """Slot Attention pooling for 2D (NCHW) inputs. - - Convenience wrapper that handles NCHW -> NLC conversion. - """ - def __init__( - self, - dim: int, - num_slots: int = 1, - iters: int = 3, - hidden_dim: Optional[int] = None, - mlp_ratio: float = 2.0, - qkv_bias: bool = True, - stochastic_init: bool = False, - pool_type: str = 'max', - eps: float = 1e-8, - norm_layer: Optional[Type[nn.Module]] = None, - act_layer: Optional[Type[nn.Module]] = nn.GELU, - flatten: bool = True, - device=None, - dtype=None, - ): - """ - Args: - dim: Input feature dimension (channels). - num_slots: Number of slot vectors. - iters: Number of iterative refinement steps. - hidden_dim: Hidden dimension for slot MLP. - mlp_ratio: Ratio for hidden dim if hidden_dim not specified. - qkv_bias: If True, add bias to q, k, v projections. - stochastic_init: If True, use stochastic slot initialization during training. - pool_type: How to aggregate multiple slots - 'max', 'avg', or 'first'. - eps: Small constant for numerical stability. - norm_layer: Normalization layer class. - act_layer: Activation layer class for MLP. - flatten: If True, output shape is (B, C). If False, (B, 1, C). - """ - super().__init__() - self.flatten = flatten - self.pool = SlotPool( - dim=dim, - num_slots=num_slots, - iters=iters, - hidden_dim=hidden_dim, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - stochastic_init=stochastic_init, - pool_type=pool_type, - eps=eps, - norm_layer=norm_layer, - act_layer=act_layer, - device=device, - dtype=dtype, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: Input tensor of shape (B, C, H, W). - - Returns: - Pooled output of shape (B, C) if flatten=True, else (B, 1, C). 
- """ - x = x.flatten(2).transpose(1, 2) # (B, H*W, C) - out = self.pool(x) # (B, C) - if not self.flatten: - out = out.unsqueeze(1) - return out From 3f0be6853fcd35e7cfe5769a70f2937c1c9c7a3c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 4 Dec 2025 07:34:58 -0800 Subject: [PATCH 4/4] Add drop block support to ByobNet --- timm/models/byobnet.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 0e33aabcfd..75bbde3a2f 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -42,6 +42,7 @@ NormMlpClassifierHead, ConvNormAct, BatchNormAct2d, + DropBlock2d, EvoNorm2dS0a, AttentionPool2d, RotAttentionPool2d, @@ -1339,11 +1340,42 @@ def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, mo block_kwargs.update(override_kwargs(block_cfg.block_kwargs, model_cfg.block_kwargs)) +def drop_blocks( + drop_prob: float = 0., + block_size: int = 3, + num_stages: int = 4, +) -> List[Optional[partial]]: + """Create DropBlock layer partials for each stage. + + DropBlock is applied to the last two stages only, following common practice. + The block_size specifies the size for the final stage; the second-to-last + stage uses a larger block size scaled to account for 2x larger feature maps. + + Args: + drop_prob: Drop probability for DropBlock. + block_size: Block size for the final stage. Second-to-last stage + uses `block_size * 2 - 1` to scale with feature map size. + num_stages: Number of stages in the model. + + Returns: + List of DropBlock partial instances or None for each stage. + """ + assert num_stages >= 2 + dbs = [None] * num_stages + if drop_prob: + # Scale block size for second-to-last stage (2x larger feature maps) + dbs[-2] = partial(DropBlock2d, drop_prob=drop_prob, block_size=block_size * 2 - 1, gamma_scale=0.25) + dbs[-1] = partial(DropBlock2d, drop_prob=drop_prob, block_size=block_size, gamma_scale=1.00) + return dbs + + def create_byob_stages( cfg: ByoModelCfg, drop_path_rate: float, output_stride: int, stem_feat: Dict[str, Any], + drop_block_rate: float = 0., + drop_block_size: int = 3, feat_size: Optional[int] = None, layers: Optional[LayerFn] = None, block_kwargs_fn: Optional[Callable] = update_block_kwargs, @@ -1353,8 +1385,10 @@ def create_byob_stages( layers = layers or LayerFn() feature_info = [] block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks] + num_stages = len(block_cfgs) depths = [sum([bc.d for bc in stage_bcs]) for stage_bcs in block_cfgs] dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True) + dbs = drop_blocks(drop_block_rate, drop_block_size, num_stages) dilation = 1 net_stride = stem_feat['reduction'] prev_chs = stem_feat['num_chs'] @@ -1384,6 +1418,7 @@ def create_byob_stages( group_size=group_size, bottle_ratio=block_cfg.br, downsample=cfg.downsample, + drop_block=dbs[stage_idx], drop_path_rate=dpr[stage_idx][block_idx], layers=layers, device=device, @@ -1437,6 +1472,8 @@ def __init__( output_stride: int = 32, img_size: Optional[Union[int, Tuple[int, int]]] = None, drop_rate: float = 0., + drop_block_rate: float = 0., + drop_block_size: int = 3, drop_path_rate: float = 0., zero_init_last: bool = True, device=None, @@ -1452,6 +1489,8 @@ def __init__( output_stride: Output stride of network, one of (8, 16, 32). img_size: Image size for fixed image size models (i.e. self-attn). drop_rate: Classifier dropout rate. + drop_block_rate: DropBlock drop rate. 
+ drop_block_size: DropBlock block size for final stage (scales up for earlier stages). drop_path_rate: Stochastic depth drop-path rate. zero_init_last: Zero-init last weight of residual path. **kwargs: Extra kwargs overlayed onto cfg. @@ -1490,6 +1529,8 @@ def __init__( drop_path_rate, output_stride, stem_feat[-1], + drop_block_rate=drop_block_rate, + drop_block_size=drop_block_size, layers=stage_layers, feat_size=feat_size, **dd,
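A minimal usage sketch of the layers and options added in this patch series. This is not part of the patches themselves; it assumes the patches above are applied, output shapes follow the unit tests in tests/test_layers_pool.py, and the 'resnet51q' model name is only an illustrative byobnet config for the DropBlock kwargs.

import torch
import timm
from timm.layers import LsePlus2d, SimPool2d, SimPool1d, RotAttentionPool2d

x_2d = torch.randn(2, 64, 7, 7)   # NCHW feature map
x_1d = torch.randn(2, 49, 64)     # NLC token sequence

# LSE pooling: learnable r interpolates between average (small r) and max (large r) pooling.
lse = LsePlus2d(r=10.0, r_learnable=True)
print(lse(x_2d).shape)            # torch.Size([2, 64]) - flatten defaults to True after patch 3

# SimPool: a GAP-initialized query cross-attends over the normalized tokens.
sim = SimPool2d(dim=64, num_heads=4, gamma=2.0)
print(sim(x_2d).shape)            # torch.Size([2, 64])
print(SimPool1d(dim=64)(x_1d).shape)  # torch.Size([2, 64])

# RotAttentionPool2d now accepts a rope_type selecting among the RotaryEmbedding variants.
rot = RotAttentionPool2d(in_features=64, ref_feat_size=7, rope_type='dinov3')
print(rot(x_2d).shape)            # torch.Size([2, 64])

# DropBlock in ByobNet (patch 4): applied to the last two stages only, with block size
# scaled to 2 * drop_block_size - 1 for the second-to-last stage (illustrative model name).
model = timm.create_model('resnet51q', drop_block_rate=0.1, drop_block_size=3)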