
Commit c9c973b

Experimenting with differential attention

1 parent f689c85

1 file changed: timm/models/vision_transformer.py (76 additions, 0 deletions)
@@ -107,6 +107,82 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


class DiffAttention(nn.Module):
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = RmsNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads // 2
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.lambda_init = 0.8
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))

        self.sub_norm = RmsNorm(2 * self.head_dim, eps=1e-5)

    def _set_lambda_init(self, depth: int):
        self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=2)
        q = q.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(B, N, self.num_heads, 2 * self.head_dim).transpose(1, 2)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            q = q.reshape(B, self.num_heads, 2, N, self.head_dim)
            k = k.reshape(B, self.num_heads, 2, N, self.head_dim)
            q1, q2 = q.unbind(2)
            k1, k2 = k.unbind(2)
            attn1 = F.scaled_dot_product_attention(q1, k1, v)
            attn2 = F.scaled_dot_product_attention(q2, k2, v)
            lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
            lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
            lambda_full = lambda_1 - lambda_2 + self.lambda_init
            x = attn1 - lambda_full * attn2
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
            lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
            lambda_full = lambda_1 - lambda_2 + self.lambda_init
            attn = attn.view(B, self.num_heads, 2, N, N)
            attn = attn[:, :, 0] - lambda_full * attn[:, :, 1]
            x = attn @ v

        x = self.sub_norm(x)
        x = x * (1 - self.lambda_init)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class LayerScale(nn.Module):
    def __init__(
            self,
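As added, DiffAttention splits the queries and keys of each head into two halves, computes two softmax attention maps, and combines them as attn1 - lambda * attn2, where lambda = exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init; the result is passed through RmsNorm over each head's 2 * head_dim channels and rescaled by (1 - lambda_init). A minimal smoke test might look like the sketch below; it is not part of the commit, it assumes the patched timm/models/vision_transformer.py is importable, and the ViT-B-like shapes (batch 2, 197 tokens, dim 768) are illustrative only.

# Hypothetical smoke test for DiffAttention (not part of the commit).
# Assumes the patched timm/models/vision_transformer.py is on the import path.
import torch
from timm.models.vision_transformer import DiffAttention

attn = DiffAttention(dim=768, num_heads=12, qkv_bias=True)
attn._set_lambda_init(depth=3)  # optional: depth-dependent lambda_init, as defined above
attn.eval()

x = torch.randn(2, 197, 768)    # (batch, tokens, dim), e.g. ViT-B/16 with a class token
with torch.no_grad():
    y = attn(x)
print(y.shape)                  # torch.Size([2, 197, 768])

With dim=768 and num_heads=12, the split gives head_dim = 32: each of the twelve heads attends twice over 32-dim queries/keys and once over 64-dim values, and the output projection maps the concatenated heads back to 768 channels.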