6 changes: 3 additions & 3 deletions .github/workflows/tests.yml
@@ -16,11 +16,11 @@ jobs:
    strategy:
      matrix:
        os: [ubuntu-latest]
-       python: ['3.10', '3.12']
-       torch: [{base: '1.13.0', vision: '0.14.0'}, {base: '2.5.1', vision: '0.20.1'}]
+       python: ['3.10', '3.13']
+       torch: [{base: '1.13.0', vision: '0.14.0'}, {base: '2.9.1', vision: '0.24.1'}]
        testmarker: ['-k "not test_models"', '-m base', '-m cfg', '-m torchscript', '-m features', '-m fxforward', '-m fxbackward']
        exclude:
-         - python: '3.12'
+         - python: '3.13'
            torch: {base: '1.13.0', vision: '0.14.0'}
    runs-on: ${{ matrix.os }}

358 changes: 358 additions & 0 deletions tests/test_layers_pool.py
@@ -0,0 +1,358 @@
"""Tests for timm pooling layers."""
import pytest
import torch
import torch.nn as nn

import importlib
import os

torch_backend = os.environ.get('TORCH_BACKEND')
if torch_backend is not None:
    importlib.import_module(torch_backend)
torch_device = os.environ.get('TORCH_DEVICE', 'cpu')


# Adaptive Avg/Max Pooling Tests
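# The avg/max variants exercised below reduce to simple compositions of
# adaptive avg and max pooling. Roughly, as a sketch of the semantics
# (assuming torch.nn.functional as F; not necessarily timm's exact code):
#   avgmax(x)    = 0.5 * (F.adaptive_avg_pool2d(x, 1) + F.adaptive_max_pool2d(x, 1))
#   catavgmax(x) = torch.cat([avg, max], dim=1)  # doubles the channel dim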

class TestAdaptiveAvgMaxPool:
    """Test adaptive_avgmax_pool module."""

    def test_adaptive_avgmax_pool2d(self):
        from timm.layers import adaptive_avgmax_pool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        out = adaptive_avgmax_pool2d(x, 1)
        assert out.shape == (2, 64, 1, 1)
        # Should be average of avg and max
        expected = 0.5 * (x.mean(dim=(2, 3), keepdim=True) + x.amax(dim=(2, 3), keepdim=True))
        assert torch.allclose(out, expected)

    def test_select_adaptive_pool2d(self):
        from timm.layers import select_adaptive_pool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)

        out_avg = select_adaptive_pool2d(x, pool_type='avg', output_size=1)
        assert out_avg.shape == (2, 64, 1, 1)
        assert torch.allclose(out_avg, x.mean(dim=(2, 3), keepdim=True))

        out_max = select_adaptive_pool2d(x, pool_type='max', output_size=1)
        assert out_max.shape == (2, 64, 1, 1)
        assert torch.allclose(out_max, x.amax(dim=(2, 3), keepdim=True))

    def test_adaptive_avgmax_pool2d_module(self):
        from timm.layers import AdaptiveAvgMaxPool2d
        x = torch.randn(2, 64, 14, 14, device=torch_device)
        pool = AdaptiveAvgMaxPool2d(output_size=1).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64, 1, 1)

    def test_select_adaptive_pool2d_module(self):
        from timm.layers import SelectAdaptivePool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)

        for pool_type in ['avg', 'max', 'avgmax', 'catavgmax']:
            pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True).to(torch_device)
            out = pool(x)
            if pool_type == 'catavgmax':
                assert out.shape == (2, 128)  # concatenated
            else:
                assert out.shape == (2, 64)

    def test_select_adaptive_pool2d_fast(self):
        from timm.layers import SelectAdaptivePool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)

        for pool_type in ['fast', 'fastavg', 'fastmax', 'fastavgmax', 'fastcatavgmax']:
            pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True).to(torch_device)
            out = pool(x)
            if 'cat' in pool_type:
                assert out.shape == (2, 128)
            else:
                assert out.shape == (2, 64)


# Attention Pool Tests
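# AttentionPoolLatent pools a (B, N, C) token sequence by cross-attending a
# small set of learned latent queries against the tokens; AttentionPool2d and
# RotAttentionPool2d do the analogous reduction for (B, C, H, W) feature maps,
# using learned absolute position embeddings and rotary (ROPE) embeddings
# respectively. That distinction is why the variable-spatial-size tests below
# matter: absolute pos_embed must be interpolated, while rotary coordinates
# adapt to the input resolution.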

class TestAttentionPool:
    """Test attention-based pooling layers."""

    def test_attention_pool_latent_basic(self):
        from timm.layers import AttentionPoolLatent
        x = torch.randn(2, 49, 64, device=torch_device)
        pool = AttentionPoolLatent(in_features=64, num_heads=4).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64)

    def test_attention_pool_latent_multi_latent(self):
        from timm.layers import AttentionPoolLatent
        x = torch.randn(2, 49, 64, device=torch_device)
        pool = AttentionPoolLatent(
            in_features=64,
            num_heads=4,
            latent_len=4,
            pool_type='avg',
        ).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64)

    def test_attention_pool2d_basic(self):
        from timm.layers import AttentionPool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = AttentionPool2d(in_features=64, feat_size=7).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64)

    def test_attention_pool2d_different_feat_size(self):
        from timm.layers import AttentionPool2d
        # Test with different spatial sizes (requires pos_embed interpolation)
        pool = AttentionPool2d(in_features=64, feat_size=7).to(torch_device)
        for size in [7, 14]:
            x = torch.randn(2, 64, size, size, device=torch_device)
            out = pool(x)
            assert out.shape == (2, 64)

    def test_rot_attention_pool2d_basic(self):
        from timm.layers import RotAttentionPool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = RotAttentionPool2d(in_features=64, ref_feat_size=7).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64)

    def test_rot_attention_pool2d_different_sizes(self):
        from timm.layers import RotAttentionPool2d
        pool = RotAttentionPool2d(in_features=64, ref_feat_size=7).to(torch_device)
        for size in [7, 14, 10]:
            x = torch.randn(2, 64, size, size, device=torch_device)
            out = pool(x)
            assert out.shape == (2, 64)

    def test_rot_attention_pool2d_rope_types(self):
        from timm.layers import RotAttentionPool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        for rope_type in ['base', 'cat', 'dinov3']:
            pool = RotAttentionPool2d(
                in_features=64,
                ref_feat_size=7,
                rope_type=rope_type,
            ).to(torch_device)
            out = pool(x)
            assert out.shape == (2, 64)


# LSE Pool Tests
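# LSE+ pooling interpolates between average and max pooling via a temperature r:
#     lse_r(x) = (1/r) * log(mean_i exp(r * x_i))
# As r -> inf this approaches max(x); as r -> 0+ it approaches mean(x), which
# the two approximation tests below check. A minimal equivalent reduction for
# a (B, C, H, W) tensor (a sketch, not necessarily timm's exact code):
#     y = (torch.logsumexp(r * x.flatten(2), dim=-1) - math.log(H * W)) / r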

class TestLsePool:
    """Test LogSumExp pooling layers."""

    def test_lse_plus_2d_basic(self):
        from timm.layers import LsePlus2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = LsePlus2d().to(torch_device)
        out = pool(x)
        # Default is flatten=True
        assert out.shape == (2, 64)

    def test_lse_plus_2d_no_flatten(self):
        from timm.layers import LsePlus2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = LsePlus2d(flatten=False).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64, 1, 1)

    def test_lse_plus_1d_basic(self):
        from timm.layers import LsePlus1d
        x = torch.randn(2, 49, 64, device=torch_device)
        pool = LsePlus1d().to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64)

    def test_lse_high_r_approximates_max(self):
        from timm.layers import LsePlus2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = LsePlus2d(r=100.0, r_learnable=False).to(torch_device)
        out = pool(x)
        out_max = x.amax(dim=(2, 3))
        assert torch.allclose(out, out_max, atol=0.1)

    def test_lse_low_r_approximates_avg(self):
        from timm.layers import LsePlus2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = LsePlus2d(r=0.01, r_learnable=False).to(torch_device)
        out = pool(x)
        out_avg = x.mean(dim=(2, 3))
        assert torch.allclose(out, out_avg, atol=0.1)

    def test_lse_learnable_r_gradient(self):
        from timm.layers import LsePlus2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = LsePlus2d(r=10.0, r_learnable=True).to(torch_device)
        out = pool(x).sum()
        out.backward()
        assert pool.r.grad is not None
        assert pool.r.grad.abs() > 0


# SimPool Tests
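# SimPool (Psomas et al., "Keep It SimPool", ICCV 2023) replaces plain GAP
# with a single cross-attention step: the global-average-pooled vector forms
# the query and the patch features form the keys/values. As I understand the
# paper, the optional `gamma` applies a power transform to the values, turning
# the weighted sum into a generalized (power) mean; the gamma test below only
# asserts shape and numerical stability, not that exact formulation.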

class TestSimPool:
    """Test SimPool attention-based pooling layers."""

    def test_simpool_2d_basic(self):
        from timm.layers import SimPool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = SimPool2d(dim=64).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64)

    def test_simpool_1d_basic(self):
        from timm.layers import SimPool1d
        x = torch.randn(2, 49, 64, device=torch_device)
        pool = SimPool1d(dim=64).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64)

    def test_simpool_multi_head(self):
        from timm.layers import SimPool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        for num_heads in [1, 2, 4, 8]:
            pool = SimPool2d(dim=64, num_heads=num_heads).to(torch_device)
            out = pool(x)
            assert out.shape == (2, 64)

    def test_simpool_with_gamma(self):
        from timm.layers import SimPool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = SimPool2d(dim=64, gamma=2.0).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64)
        assert not torch.isnan(out).any()

    def test_simpool_qk_norm(self):
        from timm.layers import SimPool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = SimPool2d(dim=64, qk_norm=True).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64)


# Common Tests (gradient flow, TorchScript, determinism, input sizes)

class TestPoolingCommon:
    """Common tests across all pooling layers."""

    @pytest.mark.parametrize('pool_cls,kwargs,input_shape', [
        ('LsePlus2d', {}, (2, 64, 7, 7)),
        ('LsePlus1d', {}, (2, 49, 64)),
        ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)),
        ('SimPool1d', {'dim': 64}, (2, 49, 64)),
        ('SelectAdaptivePool2d', {'pool_type': 'avg', 'flatten': True}, (2, 64, 7, 7)),
        ('AttentionPoolLatent', {'in_features': 64, 'num_heads': 4}, (2, 49, 64)),
        ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
        ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
    ])
    def test_gradient_flow(self, pool_cls, kwargs, input_shape):
        import timm.layers as layers
        x = torch.randn(*input_shape, device=torch_device, requires_grad=True)
        pool = getattr(layers, pool_cls)(**kwargs).to(torch_device)
        out = pool(x)
        loss = out.sum()
        loss.backward()
        assert x.grad is not None
        assert x.grad.abs().sum() > 0

    @pytest.mark.parametrize('pool_cls,kwargs,input_shape', [
        ('LsePlus2d', {}, (2, 64, 7, 7)),
        ('LsePlus1d', {}, (2, 49, 64)),
        ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)),
        ('SimPool1d', {'dim': 64}, (2, 49, 64)),
        ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
        ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
    ])
    def test_torchscript(self, pool_cls, kwargs, input_shape):
        import timm.layers as layers
        x = torch.randn(*input_shape, device=torch_device)
        pool = getattr(layers, pool_cls)(**kwargs).to(torch_device)
        pool.eval()
        scripted = torch.jit.script(pool)
        out_orig = pool(x)
        out_script = scripted(x)
        assert torch.allclose(out_orig, out_script, atol=1e-5)

    @pytest.mark.parametrize('pool_cls,kwargs,input_shape', [
        ('LsePlus2d', {}, (2, 64, 7, 7)),
        ('LsePlus1d', {}, (2, 49, 64)),
        ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)),
        ('SimPool1d', {'dim': 64}, (2, 49, 64)),
        ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
        ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
    ])
    def test_eval_deterministic(self, pool_cls, kwargs, input_shape):
        import timm.layers as layers
        x = torch.randn(*input_shape, device=torch_device)
        pool = getattr(layers, pool_cls)(**kwargs).to(torch_device)
        pool.eval()
        with torch.no_grad():
            out1 = pool(x)
            out2 = pool(x)
        assert torch.allclose(out1, out2)

    @pytest.mark.parametrize('pool_cls,kwargs,input_shape', [
        ('LsePlus2d', {}, (2, 64, 7, 7)),
        ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)),
        ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
    ])
    def test_different_spatial_sizes(self, pool_cls, kwargs, input_shape):
        import timm.layers as layers
        B, C, _, _ = input_shape
        pool = getattr(layers, pool_cls)(**kwargs).to(torch_device)
        for H, W in [(7, 7), (14, 14), (1, 1), (3, 5)]:
            x = torch.randn(B, C, H, W, device=torch_device)
            out = pool(x)
            assert out.shape[0] == B
            assert out.shape[-1] == C


# BlurPool Tests
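# BlurPool (Zhang, "Making Convolutional Networks Shift-Invariant Again",
# ICML 2019) applies a fixed binomial low-pass filter before subsampling by
# `stride`, improving shift-invariance. Spatial dims shrink by roughly a
# factor of `stride`; the exact output size depends on the layer's padding,
# hence 28x28 -> 8x8 with stride=4 below rather than exactly 28/4 = 7.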

class TestBlurPool:
    """Test BlurPool anti-aliasing layer."""

    def test_blur_pool_2d_basic(self):
        from timm.layers import BlurPool2d
        x = torch.randn(2, 64, 14, 14, device=torch_device)
        pool = BlurPool2d(channels=64).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64, 7, 7)

    def test_blur_pool_2d_stride(self):
        from timm.layers import BlurPool2d
        x = torch.randn(2, 64, 28, 28, device=torch_device)
        pool = BlurPool2d(channels=64, stride=4).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64, 8, 8)


# Pool1d Tests
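# global_pool_nlc reduces (N, L, C) token sequences. By default the first
# num_prefix_tokens=1 tokens (e.g. a class token) are excluded from avg/max
# reductions, and pool_type='token' simply returns that first token.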

class TestPool1d:
    """Test 1D pooling utilities."""

    def test_global_pool_nlc(self):
        from timm.layers import global_pool_nlc
        x = torch.randn(2, 49, 64, device=torch_device)

        # By default, avg/max excludes first token (num_prefix_tokens=1)
        out_avg = global_pool_nlc(x, pool_type='avg')
        assert out_avg.shape == (2, 64)
        assert torch.allclose(out_avg, x[:, 1:].mean(dim=1))

        out_max = global_pool_nlc(x, pool_type='max')
        assert out_max.shape == (2, 64)
        assert torch.allclose(out_max, x[:, 1:].amax(dim=1))

        out_first = global_pool_nlc(x, pool_type='token')
        assert out_first.shape == (2, 64)
        assert torch.allclose(out_first, x[:, 0])

        # Test with reduce_include_prefix=True
        out_avg_all = global_pool_nlc(x, pool_type='avg', reduce_include_prefix=True)
        assert torch.allclose(out_avg_all, x.mean(dim=1))
3 changes: 2 additions & 1 deletion timm/layers/__init__.py
@@ -18,7 +18,7 @@
from .attention import Attention, AttentionRope, maybe_add_mask
from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
from .attention_pool import AttentionPoolLatent
-from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
+from .attention_pool2d import AttentionPool2d, RotAttentionPool2d
from .blur_pool import BlurPool2d, create_aa
from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead
from .cond_conv2d import CondConv2d, get_condconv_initializer
@@ -107,6 +107,7 @@
from .patch_dropout import PatchDropout, PatchDropoutWithIndices, patch_dropout_forward
from .patch_embed import PatchEmbed, PatchEmbedWithSize, PatchEmbedInterpolator, resample_patch_embed
from .pool1d import global_pool_nlc
+from .other_pool import LsePlus2d, LsePlus1d, SimPool2d, SimPool1d
from .pool2d_same import AvgPool2dSame, create_pool2d
from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
from .pos_embed_rel import (