
Commit bc1085b

Implement diff attention in its own layer file. Add support to vision_transformer.py and naflexvit.py ViTs
1 parent 3091df4 commit bc1085b

5 files changed (+255, -13 lines)

timm/layers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@
 from .create_conv2d import create_conv2d
 from .create_norm import get_norm_layer, create_norm_layer
 from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
+from .diff_attention import DiffAttention
 from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path, calculate_drop_path_rates
 from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
 from .evo_norm import (
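
With DiffAttention exported from timm.layers, the layer can be exercised on its own. A minimal sketch (the batch, token, and width sizes below are illustrative assumptions, not part of the commit):

import torch
from timm.layers import DiffAttention

# dim must be divisible by num_heads; each of the two attention maps uses dim // num_heads // 2 channels per head
attn = DiffAttention(dim=384, num_heads=6, depth=2)
x = torch.randn(2, 197, 384)   # (batch, tokens, dim)
y = attn(x)                    # output keeps the input shape: (2, 197, 384)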

timm/layers/diff_attention.py

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
"""Differential Attention

Paper: 'Differential Transformer' - https://arxiv.org/abs/2410.05258

Reference impl: https://github.com/microsoft/unilm/tree/master/Diff-Transformer

Hacked together by / Copyright 2024, Ross Wightman
"""
import math
from typing import Optional, Type

import torch
import torch.nn as nn
import torch.nn.functional as F

from .attention import maybe_add_mask
from .config import use_fused_attn
from .norm import RmsNorm


class DiffAttention(nn.Module):
    """Differential Attention module.

    Computes attention as the difference between two softmax attention maps, which helps
    cancel out noise and promotes sparse attention patterns. The module splits Q and K
    into two groups, computes separate attention maps, and subtracts one from the other
    scaled by a learnable lambda parameter.

    The attention output is computed as:
        Attn = softmax(Q1 @ K1^T) - lambda * softmax(Q2 @ K2^T)
        Output = Attn @ V

    Supports both fused (scaled_dot_product_attention) and manual implementations.
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            scale_norm: bool = False,
            proj_bias: bool = True,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: Optional[Type[nn.Module]] = None,
            depth: int = 0,
            dual_lambda: bool = False,
            device=None,
            dtype=None,
    ) -> None:
        """Initialize the DiffAttention module.

        Args:
            dim: Input dimension of the token embeddings.
            num_heads: Number of attention heads.
            qkv_bias: Whether to use bias in the query, key, value projections.
            qk_norm: Whether to apply normalization to query and key vectors.
            scale_norm: Whether to apply normalization before the output projection.
            proj_bias: Whether to use bias in the output projection.
            attn_drop: Dropout rate applied to the attention weights.
            proj_drop: Dropout rate applied after the output projection.
            norm_layer: Normalization layer constructor (defaults to RmsNorm).
            depth: Block depth index, used to compute depth-dependent lambda_init.
            dual_lambda: If True, use simplified dual scalar lambda parameterization
                (2 params). If False, use the paper's original formulation with
                lambda_q/k vectors (4 * head_dim params).
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        if norm_layer is None:
            norm_layer = RmsNorm
        self.num_heads = num_heads
        self.head_dim = dim // num_heads // 2
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
        self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.attn_drop_p = attn_drop
        self.norm = norm_layer(dim, **dd) if scale_norm else nn.Identity()
        self.proj = nn.Linear(dim, dim, bias=proj_bias, **dd)
        self.proj_drop = nn.Dropout(proj_drop)

        self.dual_lambda = dual_lambda
        if dual_lambda:
            self.lambda_a = nn.Parameter(torch.empty((), dtype=torch.float32, device=device))
            self.lambda_b = nn.Parameter(torch.empty((), dtype=torch.float32, device=device))
            self.lambda_q1 = self.lambda_k1 = self.lambda_q2 = self.lambda_k2 = None
        else:
            self.lambda_a = self.lambda_b = None
            self.lambda_q1 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
            self.lambda_k1 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
            self.lambda_q2 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
            self.lambda_k2 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))

        self.sub_norm = RmsNorm(2 * self.head_dim, eps=1e-5, **dd)

        self.lambda_init = 0.8
        self.set_lambda_init(depth)
        self.reset_parameters()

    def set_lambda_init(self, depth: int):
        self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)

    def reset_parameters(self):
        if self.dual_lambda:
            nn.init.zeros_(self.lambda_a)
            nn.init.zeros_(self.lambda_b)
        else:
            nn.init.normal_(self.lambda_q1, mean=0, std=0.1)
            nn.init.normal_(self.lambda_k1, mean=0, std=0.1)
            nn.init.normal_(self.lambda_q2, mean=0, std=0.1)
            nn.init.normal_(self.lambda_k2, mean=0, std=0.1)

    def _compute_lambda(self) -> torch.Tensor:
        if self.lambda_a is not None:
            lambda_1 = torch.exp(self.lambda_a)
            lambda_2 = torch.exp(self.lambda_b)
        else:
            lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float())
            lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float())
        return lambda_1 - lambda_2 + self.lambda_init

    def forward(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, N, C = x.shape

        q, k, v = self.qkv(x).chunk(3, dim=2)
        q = q.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(B, N, self.num_heads, 2 * self.head_dim).transpose(1, 2)

        q, k = self.q_norm(q), self.k_norm(k)

        lambda_full = self._compute_lambda().type_as(q)

        if self.fused_attn:
            q = q.reshape(B, self.num_heads, 2, N, self.head_dim)
            k = k.reshape(B, self.num_heads, 2, N, self.head_dim)
            q1, q2 = q.unbind(2)
            k1, k2 = k.unbind(2)

            dropout_p = self.attn_drop_p if self.training else 0.0
            attn1 = F.scaled_dot_product_attention(q1, k1, v, attn_mask=attn_mask, dropout_p=dropout_p)
            attn2 = F.scaled_dot_product_attention(q2, k2, v, attn_mask=attn_mask, dropout_p=dropout_p)

            x = attn1 - lambda_full * attn2
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = maybe_add_mask(attn, attn_mask)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)

            attn = attn.view(B, self.num_heads, 2, N, N)
            attn = attn[:, :, 0] - lambda_full * attn[:, :, 1]
            x = attn @ v

        x = self.sub_norm(x)
        x = x * (1 - self.lambda_init)
        x = x.transpose(1, 2).reshape(B, N, C)

        x = self.norm(x)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
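
Restated in the paper's notation, the lambda handling above amounts to (no new behavior, just the math the code implements):

\lambda = \exp(\lambda_{q1} \cdot \lambda_{k1}) - \exp(\lambda_{q2} \cdot \lambda_{k2}) + \lambda_{\mathrm{init}}
\qquad \text{(or } \lambda = \exp(\lambda_a) - \exp(\lambda_b) + \lambda_{\mathrm{init}} \text{ when dual\_lambda=True)}

\lambda_{\mathrm{init}} = 0.8 - 0.6\, e^{-0.3\,\mathrm{depth}}

\mathrm{head} = (1 - \lambda_{\mathrm{init}})\,\mathrm{RMSNorm}\big(A_1 V - \lambda\, A_2 V\big),
\qquad A_i = \mathrm{softmax}\big(Q_i K_i^{\top} / \sqrt{d_h}\big)

where d_h = dim // num_heads // 2 and the per-head outputs are concatenated before the final projection.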

timm/models/eva.py

Lines changed: 1 addition & 0 deletions
@@ -389,6 +389,7 @@ def __init__(
             attn_head_dim: Optional[int] = None,
             device=None,
             dtype=None,
+            **kwargs,
     ):
         """ Initialize the post-norm EVA transformer block.

timm/models/naflexvit.py

Lines changed: 9 additions & 5 deletions
@@ -132,6 +132,7 @@ class NaFlexVitCfg:
     act_layer: Optional[str] = None  # Activation layer for MLP blocks
     block_fn: Optional[str] = None  # Transformer block implementation class name
     mlp_layer: Optional[str] = None  # MLP implementation class name
+    attn_layer: Optional[str] = None  # Attention layer implementation (e.g., 'attn', 'diff')

     # EVA-specific parameters
     attn_type: str = 'standard'  # Attention type: 'standard', 'eva', 'rope'

@@ -289,13 +290,15 @@ def get_block_fn(cfg: NaFlexVitCfg) -> Callable:
     else:
         # Standard ViT block
         block_fn = cfg.block_fn or Block
+        block_kwargs = {}
         if cfg.scale_mlp_norm or cfg.scale_attn_inner_norm:
             # param names differ between EVA vs non-EVA block types
-            block_fn = partial(
-                block_fn,
-                scale_mlp_norm=cfg.scale_mlp_norm,
-                scale_attn_norm=cfg.scale_attn_inner_norm
-            )
+            block_kwargs['scale_mlp_norm'] = cfg.scale_mlp_norm
+            block_kwargs['scale_attn_norm'] = cfg.scale_attn_inner_norm
+        if cfg.attn_layer:
+            block_kwargs['attn_layer'] = cfg.attn_layer
+        if block_kwargs:
+            block_fn = partial(block_fn, **block_kwargs)
     return block_fn

@@ -1214,6 +1217,7 @@ def __init__(
                 norm_layer=norm_layer,
                 act_layer=act_layer,
                 mlp_layer=mlp_layer,
+                depth=i,
                 **dd,
             )
             for i in range(cfg.depth)
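
A sketch of how the new config field is meant to flow (attn_layer is the only field introduced by this commit; the other values are illustrative and assume the existing NaFlexVitCfg fields):

from timm.models.naflexvit import NaFlexVitCfg

cfg = NaFlexVitCfg(
    embed_dim=384,
    depth=12,
    num_heads=6,
    attn_layer='diff',  # picked up by get_block_fn() and folded into partial(Block, attn_layer='diff')
)
# Each block then resolves 'diff' to DiffAttention via the ViT block's attention factory,
# and receives depth=i for the depth-dependent lambda_init (per the hunks above).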

timm/models/vision_transformer.py

Lines changed: 69 additions & 8 deletions
@@ -47,6 +47,7 @@
 )
 from timm.layers import (
     Attention,
+    DiffAttention,
     AttentionPoolLatent,
     PatchEmbed,
     Mlp,

@@ -79,6 +80,49 @@
 _logger = logging.getLogger(__name__)


+ATTN_LAYERS = {
+    '': Attention,
+    'attn': Attention,
+    'diff': DiffAttention,
+}
+
+
+def _create_attn(
+        attn_layer: LayerType,
+        dim: int,
+        num_heads: int,
+        qkv_bias: bool = False,
+        qk_norm: bool = False,
+        scale_norm: bool = False,
+        proj_bias: bool = True,
+        attn_drop: float = 0.,
+        proj_drop: float = 0.,
+        norm_layer: Optional[Type[nn.Module]] = None,
+        depth: int = 0,
+        **kwargs,
+) -> nn.Module:
+    if isinstance(attn_layer, str):
+        attn_layer = ATTN_LAYERS.get(attn_layer, None)
+        assert attn_layer is not None, f'Unknown attn_layer: {attn_layer}'
+
+    # Only pass depth to attention layers that use it
+    if issubclass(attn_layer, DiffAttention):
+        kwargs['depth'] = depth
+
+    return attn_layer(
+        dim,
+        num_heads=num_heads,
+        qkv_bias=qkv_bias,
+        qk_norm=qk_norm,
+        scale_norm=scale_norm,
+        proj_bias=proj_bias,
+        attn_drop=attn_drop,
+        proj_drop=proj_drop,
+        norm_layer=norm_layer,
+        **kwargs,
+    )
+
+
 class Block(nn.Module):
     """Transformer block with pre-normalization."""

@@ -99,6 +143,8 @@ def __init__(
             act_layer: Type[nn.Module] = nn.GELU,
             norm_layer: Type[nn.Module] = LayerNorm,
             mlp_layer: Type[nn.Module] = Mlp,
+            attn_layer: LayerType = Attention,
+            depth: int = 0,
             device=None,
             dtype=None,
     ) -> None:

@@ -118,12 +164,15 @@ def __init__(
             act_layer: Activation layer.
             norm_layer: Normalization layer.
             mlp_layer: MLP layer.
+            attn_layer: Attention layer type (class or string).
+            depth: Block index, passed to attention layer for depth-dependent init.
         """
         super().__init__()
         dd = {'device': device, 'dtype': dtype}

         self.norm1 = norm_layer(dim, **dd)
-        self.attn = Attention(
+        self.attn = _create_attn(
+            attn_layer,
             dim,
             num_heads=num_heads,
             qkv_bias=qkv_bias,

@@ -133,7 +182,8 @@ def __init__(
             attn_drop=attn_drop,
             proj_drop=proj_drop,
             norm_layer=norm_layer,
-            **dd
+            depth=depth,
+            **dd,
         )
         self.ls1 = LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity()
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

@@ -175,14 +225,17 @@ def __init__(
             act_layer: Type[nn.Module] = nn.GELU,
             norm_layer: Type[nn.Module] = LayerNorm,
             mlp_layer: Type[nn.Module] = Mlp,
-            device = None,
-            dtype = None,
+            attn_layer: LayerType = Attention,
+            depth: int = 0,
+            device=None,
+            dtype=None,
     ) -> None:
         super().__init__()
         dd = {'device': device, 'dtype': dtype}
         self.init_values = init_values

-        self.attn = Attention(
+        self.attn = _create_attn(
+            attn_layer,
             dim,
             num_heads=num_heads,
             qkv_bias=qkv_bias,

@@ -192,6 +245,7 @@ def __init__(
             attn_drop=attn_drop,
             proj_drop=proj_drop,
             norm_layer=norm_layer,
+            depth=depth,
             **dd,
         )
         self.norm1 = norm_layer(dim, **dd)

@@ -351,8 +405,10 @@ def __init__(
             act_layer: Type[nn.Module] = nn.GELU,
             norm_layer: Type[nn.Module] = LayerNorm,
             mlp_layer: Type[nn.Module] = Mlp,
-            device = None,
-            dtype = None
+            attn_layer: LayerType = Attention,
+            depth: int = 0,
+            device=None,
+            dtype=None,
     ) -> None:
         dd = {'device': device, 'dtype': dtype}
         super().__init__()

@@ -362,7 +418,8 @@ def __init__(
         for _ in range(num_parallel):
             self.attns.append(nn.Sequential(OrderedDict([
                 ('norm', norm_layer(dim, **dd)),
-                ('attn', Attention(
+                ('attn', _create_attn(
+                    attn_layer,
                     dim,
                     num_heads=num_heads,
                     qkv_bias=qkv_bias,

@@ -372,6 +429,7 @@ def __init__(
                     attn_drop=attn_drop,
                     proj_drop=proj_drop,
                     norm_layer=norm_layer,
+                    depth=depth,
                     **dd,
                 )),
                 ('ls', LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity()),

@@ -482,6 +540,7 @@ def __init__(
             act_layer: Optional[LayerType] = None,
             block_fn: Type[nn.Module] = Block,
             mlp_layer: Type[nn.Module] = Mlp,
+            attn_layer: LayerType = Attention,
             device=None,
             dtype=None,
     ) -> None:

@@ -592,6 +651,8 @@ def __init__(
                 norm_layer=norm_layer,
                 act_layer=act_layer,
                 mlp_layer=mlp_layer,
+                attn_layer=attn_layer,
+                depth=i,
                 **dd,
             )
             for i in range(depth)])