@@ -107,6 +107,82 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


+class DiffAttention(nn.Module):
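+    """Differential attention (cf. DIFF Transformer).
+
+    Each head is split into two half-width sub-heads; their softmax attention maps are
+    differenced, with the second map scaled by a learned lambda, to cancel common-mode
+    attention noise.
+    """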
+    fused_attn: Final[bool]
+
+    def __init__(
+            self,
+            dim: int,
+            num_heads: int = 8,
+            qkv_bias: bool = False,
+            qk_norm: bool = False,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            norm_layer: nn.Module = RmsNorm,
+    ) -> None:
+        super().__init__()
+        assert dim % (2 * num_heads) == 0, 'dim should be divisible by 2 * num_heads'
+        self.num_heads = num_heads
+        # each head is split into two half-width sub-heads whose attention maps are differenced
+        self.head_dim = dim // num_heads // 2
+        self.scale = self.head_dim ** -0.5
+        self.fused_attn = use_fused_attn()
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
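+        # lambda is re-parameterized per layer as exp(sum(lambda_q1 * lambda_k1)) - exp(sum(lambda_q2 * lambda_k2)) + lambda_init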
+        self.lambda_init = 0.8
+        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
+        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
+        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
+        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
+
+        # RMSNorm applied per head to the 2 * head_dim differential attention output
+        self.sub_norm = RmsNorm(2 * self.head_dim, eps=1e-5)
+
+    def _set_lambda_init(self, depth: int):
+        # depth-dependent initial value for lambda
+        self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, N, C = x.shape
+        q, k, v = self.qkv(x).chunk(3, dim=2)
+        # q, k are split into 2 * num_heads half-width sub-heads; v keeps num_heads full-width heads
+        q = q.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
+        k = k.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
+        v = v.reshape(B, N, self.num_heads, 2 * self.head_dim).transpose(1, 2)
+        q, k = self.q_norm(q), self.k_norm(k)
+
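+        # differential attention: softmax(q1 k1^T) v - lambda * softmax(q2 k2^T) v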
+        if self.fused_attn:
+            q = q.reshape(B, self.num_heads, 2, N, self.head_dim)
+            k = k.reshape(B, self.num_heads, 2, N, self.head_dim)
+            q1, q2 = q.unbind(2)
+            k1, k2 = k.unbind(2)
+            attn1 = F.scaled_dot_product_attention(q1, k1, v, dropout_p=self.attn_drop.p if self.training else 0.)
+            attn2 = F.scaled_dot_product_attention(q2, k2, v, dropout_p=self.attn_drop.p if self.training else 0.)
+            lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
+            lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
+            lambda_full = lambda_1 - lambda_2 + self.lambda_init
+            x = attn1 - lambda_full * attn2
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
+            lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
+            lambda_full = lambda_1 - lambda_2 + self.lambda_init
+            attn = attn.view(B, self.num_heads, 2, N, N)
+            # subtract the paired attention maps to cancel common-mode attention noise
+            attn = attn[:, :, 0] - lambda_full * attn[:, :, 1]
+            x = attn @ v
+
+        x = self.sub_norm(x)
+        # fixed rescaling of the normalized differential output by (1 - lambda_init)
+        x = x * (1 - self.lambda_init)
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
 class LayerScale(nn.Module):
     def __init__(
             self,