@@ -362,7 +362,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
362362 # Precalculate $\frac{\sigma}{\log 2}$.
363363 #
364364 # We will be use this when calculating $S_{ij}$ so `S` will store $S_{ij} \log 2$ instead.
365- sm_scale = sm_scale * 1.44269504
365+ sm_scale_log2 = sm_scale * 1.44269504
366366
367367 # Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\inf$ and $l_i$ to $1$. So in the first update,
368368 # the effect of initial $l_i$ is $e^{m_i - m_{i}^{\text{new}}} l_i = 0$.
@@ -381,7 +381,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
381381 # Inner loop upto the diagonal block
382382 b_o , b_l , b_m = _attn_fwd_inner (b_o , b_l , b_m , b_q ,
383383 p_kT , p_v ,
384- sm_scale ,
384+ sm_scale_log2 ,
385385 BLOCK_Q , d_head , BLOCK_K ,
386386 offs_i , offs_j ,
387387 j = tl .full ([], 0 , tl .int32 ), # type: ignore
@@ -392,7 +392,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
392392 )
393393 # Diagonal block with masking within it
394394 b_o , b_l , b_m = _attn_fwd_inner (b_o , b_l , b_m , b_q , p_kT , p_v ,
395- sm_scale ,
395+ sm_scale_log2 ,
396396 BLOCK_Q , d_head , BLOCK_K ,
397397 offs_i , offs_j ,
398398 j = i * BLOCK_Q ,
@@ -404,7 +404,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
404404 else :
405405 # Iterate through all $K_j$
406406 b_o , b_l , b_m = _attn_fwd_inner (b_o , b_l , b_m , b_q , p_kT , p_v ,
407- sm_scale ,
407+ sm_scale_log2 ,
408408 BLOCK_Q , d_head , BLOCK_K ,
409409 offs_i , offs_j ,
410410 j = tl .full ([], 0 , tl .int32 ), # type: ignore
@@ -423,7 +423,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
423423@triton .jit
424424def _attn_fwd_inner (b_o , b_l , b_m , b_q ,
425425 p_kT , p_v ,
426- scale ,
426+ sm_scale_log2 ,
427427 BLOCK_Q : tl .constexpr ,
428428 d_head : tl .constexpr ,
429429 BLOCK_K : tl .constexpr ,
@@ -446,7 +446,7 @@ def _attn_fwd_inner(b_o, b_l, b_m, b_q,
446446 b_kT = tl .load (p_kT , boundary_check = (1 ,), padding_option = "zero" )
447447 # Compute $(\log 2) S_ij = (\log 2) \sigma Q_i K_j^T$
448448 b_s = tl .dot (b_q , b_kT , out_dtype = HI_PRES_TL )
449- b_s = b_s * scale
449+ b_s = b_s * sm_scale_log2
450450
451451 # Apply causal mask
452452 if MASK :
0 commit comments