From 20aea82a310efbb6613055e9f15cd1a05ad1fa63 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 26 Nov 2025 10:48:25 -0800
Subject: [PATCH] Add coord attn and some variants that I had lying around
 from old experiments.

---
 timm/layers/coord_attn.py  | 341 +++++++++++++++++++++++++++++++++++++
 timm/layers/create_attn.py |   9 +
 2 files changed, 350 insertions(+)
 create mode 100644 timm/layers/coord_attn.py

diff --git a/timm/layers/coord_attn.py b/timm/layers/coord_attn.py
new file mode 100644
index 0000000000..1bd149fa7b
--- /dev/null
+++ b/timm/layers/coord_attn.py
@@ -0,0 +1,341 @@
+""" Coordinate Attention and Variants
+
+Coordinate Attention decomposes channel attention into two 1D feature encoding processes
+to capture long-range dependencies with precise positional information. This module includes
+the original implementation along with simplified and related variants.
+
+Papers / References:
+- Coordinate Attention: `Coordinate Attention for Efficient Mobile Network Design` - https://arxiv.org/abs/2103.02907
+- Efficient Local Attention: `ELA: Efficient Local Attention for Deep Convolutional Neural Networks` - https://arxiv.org/abs/2403.01123
+
+Hacked together by / Copyright 2025 Ross Wightman
+"""
+from typing import Optional, Type, Union
+
+import torch
+from torch import nn
+
+from .create_act import create_act_layer
+from .helpers import make_divisible
+from .norm import GroupNorm1
+
+
+class CoordAttn(nn.Module):
+    def __init__(
+            self,
+            channels: int,
+            rd_ratio: float = 1. / 16,
+            rd_channels: Optional[int] = None,
+            rd_divisor: int = 8,
+            se_factor: float = 2 / 3,
+            bias: bool = False,
+            act_layer: Type[nn.Module] = nn.Hardswish,
+            norm_layer: Optional[Type[nn.Module]] = nn.BatchNorm2d,
+            gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
+            has_skip: bool = False,
+            device=None,
+            dtype=None,
+    ):
+        """Coordinate Attention module for spatial feature recalibration.
+
+        Introduced in "Coordinate Attention for Efficient Mobile Network Design" (CVPR 2021).
+        Decomposes channel attention into two 1D feature encoding processes along the height and
+        width axes to capture long-range dependencies with precise positional information.
+
+        Args:
+            channels: Number of input channels.
+            rd_ratio: Reduction ratio for bottleneck channel calculation.
+            rd_channels: Explicit number of bottleneck channels, overrides rd_ratio if set.
+            rd_divisor: Divisor for making bottleneck channels divisible.
+            se_factor: Applied to rd_ratio for final channel count (keeps param count similar to SE).
+            bias: Whether to use bias in convolution layers.
+            act_layer: Activation module class for bottleneck.
+            norm_layer: Normalization module class, None for no normalization.
+            gate_layer: Gate activation, either 'sigmoid', 'hardsigmoid', or a module class.
+            has_skip: Whether to add residual skip connection to output.
+            device: Device to place tensors on.
+            dtype: Data type for tensors.
+        """
+        dd = {'device': device, 'dtype': dtype}
+        super().__init__()
+        self.has_skip = has_skip
+        if not rd_channels:
+            rd_channels = make_divisible(channels * rd_ratio * se_factor, rd_divisor, round_limit=0.)
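+            # e.g. channels=512: 512 * (1/16) * (2/3) ~ 21.3 -> make_divisible rounds to 24 with rd_divisor=8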
+
+        self.conv1 = nn.Conv2d(channels, rd_channels, kernel_size=1, stride=1, padding=0, bias=bias, **dd)
+        self.bn1 = norm_layer(rd_channels, **dd) if norm_layer is not None else nn.Identity()
+        self.act = act_layer()
+
+        self.conv_h = nn.Conv2d(rd_channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, **dd)
+        self.conv_w = nn.Conv2d(rd_channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, **dd)
+        self.gate = create_act_layer(gate_layer)
+
+    def forward(self, x):
+        identity = x
+
+        N, C, H, W = x.size()
+
+        # Strip pooling: (N, C, H, 1) from averaging over W, (N, C, 1, W) from averaging over H
+        x_h = x.mean(3, keepdim=True)
+        x_w = x.mean(2, keepdim=True)
+
+        # Stack both strips along dim 2 so they share one bottleneck conv: (N, C, H + W, 1)
+        x_w = x_w.transpose(-1, -2)
+        y = torch.cat([x_h, x_w], dim=2)
+        y = self.conv1(y)
+        y = self.bn1(y)
+        y = self.act(y)
+        # Split back into per-axis encodings and restore the width orientation
+        x_h, x_w = torch.split(y, [H, W], dim=2)
+        x_w = x_w.transpose(-1, -2)
+
+        a_h = self.gate(self.conv_h(x_h))  # (N, C, H, 1)
+        a_w = self.gate(self.conv_w(x_w))  # (N, C, 1, W)
+
+        out = identity * a_w * a_h
+        if self.has_skip:
+            out = out + identity
+
+        return out
+
+
+class SimpleCoordAttn(nn.Module):
+    """Simplified Coordinate Attention variant.
+
+    Reduces complexity while maintaining the core coordinate attention mechanism
+    of separate height and width attention by using:
+    * linear layers instead of convolutions
+    * no normalization
+    * additive re-combination of the two axes before a single gate
+    """
+
+    def __init__(
+            self,
+            channels: int,
+            rd_ratio: float = 0.25,
+            rd_channels: Optional[int] = None,
+            rd_divisor: int = 8,
+            se_factor: float = 2 / 3,
+            bias: bool = True,
+            act_layer: Type[nn.Module] = nn.SiLU,
+            gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
+            has_skip: bool = False,
+            device=None,
+            dtype=None,
+    ):
+        """
+        Args:
+            channels: Number of input channels.
+            rd_ratio: Reduction ratio for bottleneck channel calculation.
+            rd_channels: Explicit number of bottleneck channels, overrides rd_ratio if set.
+            rd_divisor: Divisor for making bottleneck channels divisible.
+            se_factor: Applied to rd_ratio for final channel count (keeps param count similar to SE).
+            bias: Whether to use bias in linear layers.
+            act_layer: Activation module class for bottleneck.
+            gate_layer: Gate activation, either 'sigmoid', 'hardsigmoid', or a module class.
+            has_skip: Whether to add residual skip connection to output.
+            device: Device to place tensors on.
+            dtype: Data type for tensors.
+        """
+        dd = {'device': device, 'dtype': dtype}
+        super().__init__()
+        self.has_skip = has_skip
+
+        if not rd_channels:
+            rd_channels = make_divisible(channels * rd_ratio * se_factor, rd_divisor, round_limit=0.)
+
+        self.fc1 = nn.Linear(channels, rd_channels, bias=bias, **dd)
+        self.act = act_layer()
+        self.fc_h = nn.Linear(rd_channels, channels, bias=bias, **dd)
+        self.fc_w = nn.Linear(rd_channels, channels, bias=bias, **dd)
+
+        self.gate = create_act_layer(gate_layer)
+
+    def forward(self, x):
+        identity = x
+
+        # Strip pooling
+        x_h = x.mean(dim=3)  # (N, C, H)
+        x_w = x.mean(dim=2)  # (N, C, W)
+
+        # Shared bottleneck projection
+        x_h = self.act(self.fc1(x_h.transpose(1, 2)))  # (N, H, rd_c)
+        x_w = self.act(self.fc1(x_w.transpose(1, 2)))  # (N, W, rd_c)
+
+        # Separate attention heads
+        a_h = self.fc_h(x_h).transpose(1, 2).unsqueeze(-1)  # (N, C, H, 1)
+        a_w = self.fc_w(x_w).transpose(1, 2).unsqueeze(-2)  # (N, C, 1, W)
+
+        out = identity * self.gate(a_h + a_w)
+        if self.has_skip:
+            out = out + identity
+
+        return out
+
+
+class EfficientLocalAttn(nn.Module):
+    """Efficient Local Attention.
+
+    Lightweight alternative to Coordinate Attention that preserves spatial
+    information without channel reduction. Uses 1D depthwise convolutions
+    and GroupNorm for better generalization.
+
+    Paper: https://arxiv.org/abs/2403.01123
+    """
+
+    def __init__(
+            self,
+            channels: int,
+            kernel_size: int = 7,
+            bias: bool = False,
+            act_layer: Type[nn.Module] = nn.SiLU,
+            gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
+            norm_layer: Optional[Type[nn.Module]] = GroupNorm1,
+            has_skip: bool = False,
+            device=None,
+            dtype=None,
+    ):
+        """
+        Args:
+            channels: Number of input channels.
+            kernel_size: Kernel size for 1D depthwise convolutions.
+            bias: Whether to use bias in convolution layers.
+            act_layer: Activation module class applied after normalization.
+            gate_layer: Gate activation, either 'sigmoid', 'hardsigmoid', or a module class.
+            norm_layer: Normalization module class, None for no normalization.
+            has_skip: Whether to add residual skip connection to output.
+            device: Device to place tensors on.
+            dtype: Data type for tensors.
+        """
+        dd = {'device': device, 'dtype': dtype}
+        super().__init__()
+        self.has_skip = has_skip
+
+        self.conv_h = nn.Conv2d(
+            channels, channels,
+            kernel_size=(kernel_size, 1),
+            stride=1,
+            padding=(kernel_size // 2, 0),
+            groups=channels,
+            bias=bias,
+            **dd
+        )
+        self.conv_w = nn.Conv2d(
+            channels, channels,
+            kernel_size=(1, kernel_size),
+            stride=1,
+            padding=(0, kernel_size // 2),
+            groups=channels,
+            bias=bias,
+            **dd
+        )
+        if norm_layer is not None:
+            self.norm_h = norm_layer(channels, **dd)
+            self.norm_w = norm_layer(channels, **dd)
+        else:
+            self.norm_h = nn.Identity()
+            self.norm_w = nn.Identity()
+        self.act = act_layer()
+        self.gate = create_act_layer(gate_layer)
+
+    def forward(self, x):
+        identity = x
+
+        # Strip pooling: (N, C, H, W) -> (N, C, H, 1) and (N, C, 1, W)
+        x_h = x.mean(dim=3, keepdim=True)
+        x_w = x.mean(dim=2, keepdim=True)
+
+        # 1D depthwise conv + norm + act along each axis
+        x_h = self.act(self.norm_h(self.conv_h(x_h)))  # (N, C, H, 1)
+        x_w = self.act(self.norm_w(self.conv_w(x_w)))  # (N, C, 1, W)
+
+        # Generate attention maps
+        a_h = self.gate(x_h)  # (N, C, H, 1)
+        a_w = self.gate(x_w)  # (N, C, 1, W)
+
+        out = identity * a_h * a_w
+        if self.has_skip:
+            out = out + identity
+
+        return out
+
+
+class StripAttn(nn.Module):
+    """Minimal Strip Attention.
+
+    Lightweight spatial attention using strip pooling with optional learned refinement.
+    """
+
+    def __init__(
+            self,
+            channels: int,
+            use_conv: bool = True,
+            kernel_size: int = 3,
+            bias: bool = False,
+            gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
+            has_skip: bool = False,
+            device=None,
+            dtype=None,
+            **_,
+    ):
+        """
+        Args:
+            channels: Number of input channels.
+            use_conv: Whether to apply depthwise convolutions for learned spatial refinement.
+            kernel_size: Kernel size for 1D depthwise convolutions when use_conv is True.
+            bias: Whether to use bias in convolution layers.
+            gate_layer: Gate activation, either 'sigmoid', 'hardsigmoid', or a module class.
+            has_skip: Whether to add residual skip connection to output.
+            device: Device to place tensors on.
+            dtype: Data type for tensors.
+ """ + dd = {'device': device, 'dtype': dtype} + super().__init__() + self.has_skip = has_skip + self.use_conv = use_conv + + if use_conv: + self.conv_h = nn.Conv2d( + channels, channels, + kernel_size=(kernel_size, 1), + stride=1, + padding=(kernel_size // 2, 0), + groups=channels, + bias=bias, + **dd + ) + self.conv_w = nn.Conv2d( + channels, channels, + kernel_size=(1, kernel_size), + stride=1, + padding=(0, kernel_size // 2), + groups=channels, + bias=bias, + **dd + ) + else: + self.conv_h = nn.Identity() + self.conv_w = nn.Identity() + + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + identity = x + + # Strip pooling + x_h = x.mean(dim=3, keepdim=True) # (N, C, H, 1) + x_w = x.mean(dim=2, keepdim=True) # (N, C, 1, W) + + # Optional learned refinement + x_h = self.conv_h(x_h) + x_w = self.conv_w(x_w) + + # Combine and gate + a_hw = self.gate(x_h + x_w) # broadcasts to (N, C, H, W) + + out = identity * a_hw + if self.has_skip: + out = out + identity + + return out + diff --git a/timm/layers/create_attn.py b/timm/layers/create_attn.py index cc7e91ea9a..b1cbb36664 100644 --- a/timm/layers/create_attn.py +++ b/timm/layers/create_attn.py @@ -7,6 +7,7 @@ from .bottleneck_attn import BottleneckAttn from .cbam import CbamModule, LightCbamModule +from .coord_attn import CoordAttn, EfficientLocalAttn, StripAttn, SimpleCoordAttn from .eca import EcaModule, CecaModule from .gather_excite import GatherExcite from .global_context import GlobalContext @@ -47,6 +48,14 @@ def get_attn(attn_type): module_cls = CbamModule elif attn_type == 'lcbam': module_cls = LightCbamModule + elif attn_type == 'coord': + module_cls = CoordAttn + elif attn_type == 'scoord': + module_cls = SimpleCoordAttn + elif attn_type == 'ela': + module_cls = EfficientLocalAttn + elif attn_type == 'strip': + module_cls = StripAttn # Attention / attention-like modules w/ significant params # Typically replace some of the existing workhorse convs in a network architecture.
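Usage sketch (not part of the patch, assumes it is applied): each new attn_type string
resolves through get_attn as registered above; all four modules take channels as their
first argument and return a tensor of the input shape:

    import torch
    from timm.layers.create_attn import get_attn

    x = torch.randn(2, 64, 32, 32)
    for attn_type in ('coord', 'scoord', 'ela', 'strip'):
        attn = get_attn(attn_type)(channels=64)  # defaults for all other args
        assert attn(x).shape == x.shape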
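Note (not part of the patch): the variants differ mainly in how the two 1D axis encodings
are recombined. CoordAttn and EfficientLocalAttn gate each axis separately and multiply the
broadcast maps, while SimpleCoordAttn and StripAttn sum the encodings before a single gate.
A minimal sketch of the two recombination styles with toy tensors (shapes illustrative only):

    import torch

    h = torch.randn(2, 64, 32, 1)   # per-height encoding, (N, C, H, 1)
    w = torch.randn(2, 64, 1, 48)   # per-width encoding, (N, C, 1, W)

    mul_gate = torch.sigmoid(h) * torch.sigmoid(w)  # CoordAttn / EfficientLocalAttn style
    add_gate = torch.sigmoid(h + w)                 # SimpleCoordAttn / StripAttn style
    assert mul_gate.shape == add_gate.shape == (2, 64, 32, 48)  # both broadcast to full maps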