|
| 1 | +"""Differential Attention |
| 2 | +
|
| 3 | +Paper: 'Differential Transformer' - https://arxiv.org/abs/2410.05258 |
| 4 | +
|
| 5 | +Reference impl: https://github.com/microsoft/unilm/tree/master/Diff-Transformer |
| 6 | +
|
| 7 | +Hacked together by / Copyright 2024, Ross Wightman |
| 8 | +""" |
| 9 | +import math |
| 10 | +from typing import Optional, Type |
| 11 | + |
| 12 | +import torch |
| 13 | +import torch.nn as nn |
| 14 | +import torch.nn.functional as F |
| 15 | + |
| 16 | +from .attention import maybe_add_mask |
| 17 | +from .config import use_fused_attn |
| 18 | +from .norm import RmsNorm |
| 19 | + |
| 20 | + |
| 21 | +class DiffAttention(nn.Module): |
| 22 | + """Differential Attention module. |
| 23 | +
|
| 24 | + Computes attention as the difference between two softmax attention maps, which helps |
| 25 | + cancel out noise and promotes sparse attention patterns. The module splits Q and K |
| 26 | + into two groups, computes separate attention maps, and subtracts one from the other |
| 27 | + scaled by a learnable lambda parameter. |
| 28 | +
|
| 29 | + The attention output is computed as: |
| 30 | + Attn = softmax(Q1 @ K1^T) - lambda * softmax(Q2 @ K2^T) |
| 31 | + Output = Attn @ V |
| 32 | +
|
| 33 | + Supports both fused (scaled_dot_product_attention) and manual implementations. |
| 34 | + """ |
| 35 | + fused_attn: torch.jit.Final[bool] |
| 36 | + |
| 37 | + def __init__( |
| 38 | + self, |
| 39 | + dim: int, |
| 40 | + num_heads: int = 8, |
| 41 | + qkv_bias: bool = False, |
| 42 | + qk_norm: bool = False, |
| 43 | + scale_norm: bool = False, |
| 44 | + proj_bias: bool = True, |
| 45 | + attn_drop: float = 0., |
| 46 | + proj_drop: float = 0., |
| 47 | + norm_layer: Optional[Type[nn.Module]] = None, |
| 48 | + depth: int = 0, |
| 49 | + dual_lambda: bool = False, |
| 50 | + device=None, |
| 51 | + dtype=None, |
| 52 | + ) -> None: |
| 53 | + """Initialize the DiffAttention module. |
| 54 | +
|
| 55 | + Args: |
| 56 | + dim: Input dimension of the token embeddings. |
| 57 | + num_heads: Number of attention heads. |
| 58 | + qkv_bias: Whether to use bias in the query, key, value projections. |
| 59 | + qk_norm: Whether to apply normalization to query and key vectors. |
| 60 | + scale_norm: Whether to apply normalization before the output projection. |
| 61 | + proj_bias: Whether to use bias in the output projection. |
| 62 | + attn_drop: Dropout rate applied to the attention weights. |
| 63 | + proj_drop: Dropout rate applied after the output projection. |
| 64 | + norm_layer: Normalization layer constructor (defaults to RmsNorm). |
| 65 | + depth: Block depth index, used to compute depth-dependent lambda_init. |
| 66 | + dual_lambda: If True, use simplified dual scalar lambda parameterization |
| 67 | + (2 params). If False, use the paper's original formulation with |
| 68 | + lambda_q/k vectors (4 * head_dim params). |
| 69 | + """ |
| 70 | + super().__init__() |
| 71 | + dd = {'device': device, 'dtype': dtype} |
| 72 | + assert dim % num_heads == 0, 'dim should be divisible by num_heads' |
| 73 | + if norm_layer is None: |
| 74 | + norm_layer = RmsNorm |
| 75 | + self.num_heads = num_heads |
| 76 | + self.head_dim = dim // num_heads // 2 |
| 77 | + self.scale = self.head_dim ** -0.5 |
| 78 | + self.fused_attn = use_fused_attn() |
| 79 | + |
| 80 | + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd) |
| 81 | + self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity() |
| 82 | + self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity() |
| 83 | + self.attn_drop = nn.Dropout(attn_drop) |
| 84 | + self.attn_drop_p = attn_drop |
| 85 | + self.norm = norm_layer(dim, **dd) if scale_norm else nn.Identity() |
| 86 | + self.proj = nn.Linear(dim, dim, bias=proj_bias, **dd) |
| 87 | + self.proj_drop = nn.Dropout(proj_drop) |
| 88 | + |
| 89 | + self.dual_lambda = dual_lambda |
| 90 | + if dual_lambda: |
| 91 | + self.lambda_a = nn.Parameter(torch.empty((), dtype=torch.float32, device=device)) |
| 92 | + self.lambda_b = nn.Parameter(torch.empty((), dtype=torch.float32, device=device)) |
| 93 | + self.lambda_q1 = self.lambda_k1 = self.lambda_q2 = self.lambda_k2 = None |
| 94 | + else: |
| 95 | + self.lambda_a = self.lambda_b = None |
| 96 | + self.lambda_q1 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device)) |
| 97 | + self.lambda_k1 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device)) |
| 98 | + self.lambda_q2 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device)) |
| 99 | + self.lambda_k2 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device)) |
| 100 | + |
| 101 | + self.sub_norm = RmsNorm(2 * self.head_dim, eps=1e-5, **dd) |
| 102 | + |
| 103 | + self.lambda_init = 0.8 |
| 104 | + self.set_lambda_init(depth) |
| 105 | + self.reset_parameters() |
| 106 | + |
| 107 | + def set_lambda_init(self, depth: int): |
| 108 | + self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth) |
| 109 | + |
| 110 | + def reset_parameters(self): |
| 111 | + if self.dual_lambda: |
| 112 | + nn.init.zeros_(self.lambda_a) |
| 113 | + nn.init.zeros_(self.lambda_b) |
| 114 | + else: |
| 115 | + nn.init.normal_(self.lambda_q1, mean=0, std=0.1) |
| 116 | + nn.init.normal_(self.lambda_k1, mean=0, std=0.1) |
| 117 | + nn.init.normal_(self.lambda_q2, mean=0, std=0.1) |
| 118 | + nn.init.normal_(self.lambda_k2, mean=0, std=0.1) |
| 119 | + |
| 120 | + def _compute_lambda(self) -> torch.Tensor: |
| 121 | + if self.lambda_a is not None: |
| 122 | + lambda_1 = torch.exp(self.lambda_a) |
| 123 | + lambda_2 = torch.exp(self.lambda_b) |
| 124 | + else: |
| 125 | + lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()) |
| 126 | + lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()) |
| 127 | + return lambda_1 - lambda_2 + self.lambda_init |
| 128 | + |
| 129 | + def forward( |
| 130 | + self, |
| 131 | + x: torch.Tensor, |
| 132 | + attn_mask: Optional[torch.Tensor] = None, |
| 133 | + ) -> torch.Tensor: |
| 134 | + B, N, C = x.shape |
| 135 | + |
| 136 | + q, k, v = self.qkv(x).chunk(3, dim=2) |
| 137 | + q = q.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2) |
| 138 | + k = k.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2) |
| 139 | + v = v.reshape(B, N, self.num_heads, 2 * self.head_dim).transpose(1, 2) |
| 140 | + |
| 141 | + q, k = self.q_norm(q), self.k_norm(k) |
| 142 | + |
| 143 | + lambda_full = self._compute_lambda().type_as(q) |
| 144 | + |
| 145 | + if self.fused_attn: |
| 146 | + q = q.reshape(B, self.num_heads, 2, N, self.head_dim) |
| 147 | + k = k.reshape(B, self.num_heads, 2, N, self.head_dim) |
| 148 | + q1, q2 = q.unbind(2) |
| 149 | + k1, k2 = k.unbind(2) |
| 150 | + |
| 151 | + dropout_p = self.attn_drop_p if self.training else 0.0 |
| 152 | + attn1 = F.scaled_dot_product_attention(q1, k1, v, attn_mask=attn_mask, dropout_p=dropout_p) |
| 153 | + attn2 = F.scaled_dot_product_attention(q2, k2, v, attn_mask=attn_mask, dropout_p=dropout_p) |
| 154 | + |
| 155 | + x = attn1 - lambda_full * attn2 |
| 156 | + else: |
| 157 | + q = q * self.scale |
| 158 | + attn = q @ k.transpose(-2, -1) |
| 159 | + attn = maybe_add_mask(attn, attn_mask) |
| 160 | + attn = attn.softmax(dim=-1) |
| 161 | + attn = self.attn_drop(attn) |
| 162 | + |
| 163 | + attn = attn.view(B, self.num_heads, 2, N, N) |
| 164 | + attn = attn[:, :, 0] - lambda_full * attn[:, :, 1] |
| 165 | + x = attn @ v |
| 166 | + |
| 167 | + x = self.sub_norm(x) |
| 168 | + x = x * (1 - self.lambda_init) |
| 169 | + x = x.transpose(1, 2).reshape(B, N, C) |
| 170 | + |
| 171 | + x = self.norm(x) |
| 172 | + x = self.proj(x) |
| 173 | + x = self.proj_drop(x) |
| 174 | + |
| 175 | + return x |
0 commit comments