Skip to content

Commit eb5c004

Browse files
committed
sm scale log2
1 parent 5a8182d commit eb5c004

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

labml_nn/transformers/flash/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
362362
# Precalculate $\frac{\sigma}{\log 2}$.
363363
#
364364
# We will use this when calculating $S_{ij}$, so `S` will store $\frac{S_{ij}}{\log 2} = S_{ij} \log_2 e$ instead.
365-
sm_scale = sm_scale * 1.44269504
365+
sm_scale_log2 = sm_scale * 1.44269504
366366

367367
# Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\infty$ and $l_i$ to $1$. So in the first update,
368368
# the effect of initial $l_i$ is $e^{m_i - m_{i}^{\text{new}}} l_i = 0$.
@@ -381,7 +381,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
381381
# Inner loop up to the diagonal block
382382
b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q,
383383
p_kT, p_v,
384-
sm_scale,
384+
sm_scale_log2,
385385
BLOCK_Q, d_head, BLOCK_K,
386386
offs_i, offs_j,
387387
j=tl.full([], 0, tl.int32), # type: ignore
@@ -392,7 +392,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
392392
)
393393
# Diagonal block with masking within it
394394
b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
395-
sm_scale,
395+
sm_scale_log2,
396396
BLOCK_Q, d_head, BLOCK_K,
397397
offs_i, offs_j,
398398
j=i * BLOCK_Q,
@@ -404,7 +404,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
404404
else:
405405
# Iterate through all $K_j$
406406
b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
407-
sm_scale,
407+
sm_scale_log2,
408408
BLOCK_Q, d_head, BLOCK_K,
409409
offs_i, offs_j,
410410
j=tl.full([], 0, tl.int32), # type: ignore
@@ -423,7 +423,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale, t_lse, t_o,
423423
@triton.jit
424424
def _attn_fwd_inner(b_o, b_l, b_m, b_q,
425425
p_kT, p_v,
426-
scale,
426+
sm_scale_log2,
427427
BLOCK_Q: tl.constexpr,
428428
d_head: tl.constexpr,
429429
BLOCK_K: tl.constexpr,
@@ -446,7 +446,7 @@ def _attn_fwd_inner(b_o, b_l, b_m, b_q,
446446
b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
447447
# Compute $\frac{S_{ij}}{\log 2} = \frac{\sigma}{\log 2} Q_i K_j^T$
448448
b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
449-
b_s = b_s * scale
449+
b_s = b_s * sm_scale_log2
450450

451451
# Apply causal mask
452452
if MASK:

0 commit comments

Comments
 (0)