
Commit 9262c57

flash attention
1 parent 4752644 commit 9262c57

File tree

8 files changed, +1202 -927 lines changed


docs/index.html

Lines changed: 1 addition & 0 deletions
@@ -80,6 +80,7 @@ <h3><strong><a href="https://nn.labml.ai/ja/">Japanese (translated)</a></strong>
 <h2>Paper Implementations</h2>
 <h4><a href="transformers/index.html">Transformers</a></h4>
 <ul><li><a href="transformers/mha.html">Multi-headed attention</a> </li>
+<li><a href="transformers/flash/index.html">Triton Flash Attention</a> </li>
 <li><a href="transformers/models.html">Transformer building blocks</a> </li>
 <li><a href="transformers/xl/index.html">Transformer XL</a> </li>
 <li><a href="transformers/xl/relative_mha.html">Relative multi-headed attention</a> </li>

docs/transformers/flash/index.html

Lines changed: 830 additions & 724 deletions
Large diffs are not rendered by default.

docs/transformers/flash/test.html

Lines changed: 220 additions & 155 deletions
Large diffs are not rendered by default.

labml_nn/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
 #### ✨ [Transformers](transformers/index.html)
 
 * [Multi-headed attention](transformers/mha.html)
+* [Triton Flash Attention](transformers/flash/index.html)
 * [Transformer building blocks](transformers/models.html)
 * [Transformer XL](transformers/xl/index.html)
 * [Relative multi-headed attention](transformers/xl/relative_mha.html)

labml_nn/transformers/flash/__init__.py

Lines changed: 110 additions & 31 deletions
@@ -1,8 +1,36 @@
 """
+---
+title: Flash Attention
+summary: >
+  This is a PyTorch/Triton implementation of Flash Attention 2
+  with explanations.
+---
+
 # Flash Attention
 
+Flash attention speeds up the transformer attention mechanism by reducing the number of
+memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM.
+
+It was introduced in the paper
+[FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/abs/2205.14135)
+and further optimized in the paper
+[FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691).
+The official CUDA implementation can be found at [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention).
+
+Our implementation is based on
+[Triton's example implementation](https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html).
+
+*Note: You can click on the mathematical symbols or identifiers to highlight them*.
+
+You can run [test.py](./test.html) to check the correctness and measure the performance of this implementation.
+
 ## Forward pass
 
+Here's the attention forward pass. The formulas represent a single attention head.
+$Q_i$ is the query vector (row vector) at position $i$,
+and $K_j$ and $V_j$ are the key and value row vectors at position $j$.
+$O_i$ is the output vector at position $i$.
+
 \begin{align}
 S_{ij} &= \sigma Q_i K_j^T
 \\
@@ -15,6 +43,12 @@
 &= \frac{1}{L_i} \sum_j e^{S_{ij}} V_j
 \end{align}
 
+$S_{ij}$ is the attention score matrix before softmax,
+$L_i$ is the softmax denominator,
+and $P_{ij}$ is the attention matrix after softmax.
+
+#### Flash Attention Optimization
+
 You can compute $O_i$, instead of doing the full softmax,
 by computing the sum of exponents $l_i$ and the unnormalized output $\tilde{O}_i$
 while iterating over keys:
@@ -57,8 +91,14 @@
 
 $$O_i = \frac{\tilde{O}_i}{l_i}$$
 
+This reduces memory usage since we don't have to compute the full $S_{ij}$ or $P_{ij}$ matrices.
+It also speeds things up since we don't have to load these large matrices;
+instead, it only loads blocks of $K$ and $V$ as it iterates over them.
+
 ## Backward pass
 
+Here's the standard backward pass. $dO_i$ is the gradient vector on the output $O_i$.
+
 \begin{align}
 dV_j &= \sum_i P_{ij} dO_i
 \\
@@ -95,7 +135,14 @@
 dS_{ij} = P_{ij} dP_{ij} - D_i P_{ij}
 \end{align}
 
-*Note: $Q_i$, $K_j$, $dQ_i$, etc are row vectors.*
+Flash attention saves $L_i$ from the forward pass since it doesn't take much memory.
+So during the backward pass it doesn't have to recompute $l_i$ or $m_i$.
+
+It first computes $D_i$.
+Then it iterates over the queries and computes (accumulates) $dK_j$ and $dV_j$.
+Finally it iterates over the keys and computes (accumulates) $dQ_i$.
+
+In both the forward and backward passes we calculate logarithms and exponentials in base $2$ instead of base $e$ for performance.
 """
 
 from typing import Any, Tuple
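
To make the streaming formulation above concrete, here is a minimal single-head sketch in plain PyTorch (the function name and `block_k` are illustrative; unlike the Triton kernels below, it uses base-$e$ exponentials and does not block over queries):

```python
import torch


def online_softmax_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                             sm_scale: float, block_k: int = 64) -> torch.Tensor:
    """Compute O_i for one head by streaming over key/value blocks."""
    q_len, d_head = Q.shape
    m = torch.full((q_len, 1), float('-inf'))  # running row-wise max of S_ij
    l = torch.zeros(q_len, 1)                  # running sum of exponents l_i
    o = torch.zeros(q_len, d_head)             # unnormalized output \tilde{O}_i
    for j in range(0, K.shape[0], block_k):
        S = sm_scale * Q @ K[j:j + block_k].T            # scores for this key block
        m_new = torch.maximum(m, S.amax(dim=1, keepdim=True))
        alpha = torch.exp(m - m_new)                     # rescale old accumulators
        P = torch.exp(S - m_new)
        l = alpha * l + P.sum(dim=1, keepdim=True)
        o = alpha * o + P @ V[j:j + block_k]
        m = m_new
    return o / l                                         # O_i = \tilde{O}_i / l_i


# Matches the full softmax for a random single head
Q, K, V = torch.randn(3, 128, 32).unbind(0)
ref = torch.softmax(0.125 * Q @ K.T, dim=-1) @ V
assert torch.allclose(online_softmax_attention(Q, K, V, 0.125), ref, atol=1e-5)
```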
@@ -110,9 +157,12 @@
 
 class AttentionFunc(torch.autograd.Function):
 @staticmethod
-def forward(ctx: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+def forward(ctx: Any,
+q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
 causal: bool, sm_scale: float) -> torch.Tensor:
 """
+### Forward pass
+
 Group query attention forward pass. Returns the output in shape `[batch_size, n_heads, q_seq_len, d_head]`.
 
 :param ctx: is the context for torch gradient descent
@@ -121,7 +171,7 @@ def forward(ctx: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
 :param k: has shape `[batch_size, k_heads, kv_seq_len, d_head]`
 :param v: has shape `[batch_size, k_heads, kv_seq_len, d_head]`
 :param causal: whether to apply causal attention mask
-:param sm_scale: softmax scale factor
+:param sm_scale: softmax scale factor $\sigma$
 """
 batch_size, n_heads, q_seq_len, d_head = q.shape
 _, k_heads, kv_seq_len, _ = k.shape
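
For reference, a usage sketch of this autograd function (the sizes, the `float16` dtype and the 4:1 head grouping below are illustrative assumptions; the shapes follow the parameter documentation above):

```python
import math
import torch
from labml_nn.transformers.flash import AttentionFunc

batch_size, n_heads, k_heads, seq_len, d_head = 2, 8, 2, 1024, 64

# Grouped-query attention: 8 query heads share 2 key/value heads (4 groups)
q = torch.randn(batch_size, n_heads, seq_len, d_head, device='cuda', dtype=torch.float16, requires_grad=True)
k = torch.randn(batch_size, k_heads, seq_len, d_head, device='cuda', dtype=torch.float16, requires_grad=True)
v = torch.randn(batch_size, k_heads, seq_len, d_head, device='cuda', dtype=torch.float16, requires_grad=True)

# Causal attention with the usual 1/sqrt(d_head) softmax scale
o = AttentionFunc.apply(q, k, v, True, 1 / math.sqrt(d_head))
o.sum().backward()  # populates q.grad, k.grad and v.grad
```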
@@ -171,6 +221,8 @@ def forward(ctx: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
 @staticmethod
 def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
 """
+### Backward pass
+
 The backward pass computes the gradients of the input tensors.
 
 :param ctx: is the context for torch gradient descent
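
A sanity-check sketch for these gradients against a naive implementation ([test.py](./test.html) is the repository's own test; `naive_attention`, the KV-head expansion via `repeat_interleave` and the tolerance below are assumptions):

```python
import torch
from labml_nn.transformers.flash import AttentionFunc


def naive_attention(q, k, v, causal, sm_scale):
    # Expand key/value heads so each group of query heads sees its shared KV head
    n_groups = q.shape[1] // k.shape[1]
    k = k.repeat_interleave(n_groups, dim=1)
    v = v.repeat_interleave(n_groups, dim=1)
    s = sm_scale * q @ k.transpose(-2, -1)
    if causal:
        i = torch.arange(q.shape[2], device=q.device)[:, None]
        j = torch.arange(k.shape[2], device=q.device)[None, :]
        s = s.masked_fill(j > i, float('-inf'))
    return torch.softmax(s.float(), dim=-1).to(v.dtype) @ v


def check_gradients(q, k, v, sm_scale, atol=1e-2):
    do = torch.randn_like(q)
    dq_ref, dk_ref, dv_ref = torch.autograd.grad(naive_attention(q, k, v, True, sm_scale), (q, k, v), do)
    dq, dk, dv = torch.autograd.grad(AttentionFunc.apply(q, k, v, True, sm_scale), (q, k, v), do)
    for a, b in [(dq, dq_ref), (dk, dk_ref), (dv, dv_ref)]:
        assert torch.allclose(a, b, atol=atol, rtol=0.0)
```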
@@ -264,22 +316,27 @@ def _get_autotune_configs(inner_loop: str) -> list:
 """
 
 configs = []
-# List possible BLOCK_Q and BLOCK_K that satisfy BLOCK_Q divisible by BLOCK_K
-# and also try to cover a wide range
-for bm in [64, 128, 256]:
-# We'll try bn in [16, 32, 64, 128] that are divisors and <= bm
-for bn in [64, 128, 256]:
-if inner_loop == 'key' and bm % bn != 0:
+
+# Possible options for `BLOCK_Q`
+for bq in [64, 128, 256]:
+# Possible options for `BLOCK_K`
+for bk in [64, 128, 256]:
+# If the inner loop is along keys, `BLOCK_Q` must be a multiple of `BLOCK_K` for causal masking
+if inner_loop == 'key' and bq % bk != 0:
 continue
-if inner_loop == 'query' and bn % bm != 0:
+# Similarly when the inner loop is along queries
+if inner_loop == 'query' and bk % bq != 0:
 continue
+
+# Number of stages and warps
 for s in [2, 3, 4]:
 for w in [4, 8]:
-if bm * bn < 128 * 128 and w == 8:
+if bq * bk < 128 * 128 and w == 8:
 continue
 
-configs.append(triton.Config({'BLOCK_Q': bm, 'BLOCK_K': bn}, num_stages=s, num_warps=w))
+configs.append(triton.Config({'BLOCK_Q': bq, 'BLOCK_K': bk}, num_stages=s, num_warps=w))
 
+# **Use `return configs` to autotune. Trying all combinations is slow for testing.**
 return configs[:1]
 
 
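To inspect the search space, or to see how the full list would be used when autotuning is re-enabled, here is a sketch; the decorators on the actual kernels are not part of this diff, so the pattern below is only how Triton's fused attention tutorial consumes such configs:

```python
from labml_nn.transformers.flash import _get_autotune_configs

# Print the (currently truncated) search space
for cfg in _get_autotune_configs(inner_loop='key'):
    print(cfg)

# With `return configs`, Triton would benchmark every config on the first launch for a
# new combination of the `key` arguments and cache the fastest one, e.g.:
#
#   @triton.autotune(configs=_get_autotune_configs('key'), key=['q_seq_len', 'kv_seq_len', 'd_head'])
#   @triton.jit
#   def _attn_fwd(...):
#       ...
```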
@@ -292,34 +349,37 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale_log2e, t_lse, t_o,
 kv_seq_len: tl.constexpr,
 d_head: tl.constexpr,
 is_causal: tl.constexpr,
-BLOCK_Q: tl.constexpr, # q seq len block
-BLOCK_K: tl.constexpr, # k seq len block
+BLOCK_Q: tl.constexpr,
+BLOCK_K: tl.constexpr,
 ):
 """
-:param t_q: query
-:param t_k: keys
-:param t_v: values
-:param sm_scale: softmax scale
+### Triton kernel for Flash attention forward pass
+
+:param t_q: queries $Q_i$
+:param t_k: keys $K_j$
+:param t_v: values $V_j$
+:param sm_scale_log2e: $\sigma \log_2 e$ softmax scale multiplied by $\log_2 e$
 :param t_lse: $\log_2 \sum_j e^{S_{ij}}$ (out)
-:param t_o: output (out)
-:param n_groups: number of groups
+:param t_o: $O_i$ output
+:param n_groups: number of groups in GQA
 :param q_seq_len: query sequence length
 :param kv_seq_len: key/value sequence length
-:param d_head: size of a head
-:param BLOCK_Q: block size for query sequence length
-:param BLOCK_K: block size for key sequence length
+:param d_head: number of dimensions in a head
+:param BLOCK_Q: block size for query sequence length
+:param BLOCK_K: block size for key sequence length
 :param is_causal: whether causal attention
 
 Strides `z`, `h`, `m` and `d` denote the stride of the corresponding dimensions
-(`batch_size`, `n_heads`, `seq_len`, `d_head`) in the query.
-Stride `n` denote the stride on `seq_len` of key.
+(`batch_size`, `n_heads`, `q_seq_len`, `d_head`) in the query.
+Stride `n` denotes the stride on `kv_seq_len` of the key.
 """
 
+# We are computing the attention output $O_i$ for `i` ... `i + BLOCK_Q` in batch/head combination $z$.
 i = tl.program_id(0)
 z = tl.program_id(1) // n_groups
-g = tl.program_id(1) % n_groups # TODO
+g = tl.program_id(1) % n_groups
 
-# Create block pointers
+# #### Create block pointers
 p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
 (q_seq_len, d_head),
 (d_head, 1),
@@ -354,6 +414,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale_log2e, t_lse, t_o,
 # Initialize offsets
 offs_i = i * BLOCK_Q + tl.arange(0, BLOCK_Q)
 offs_j = tl.arange(0, BLOCK_K)
+
 # Mask for $Q$ for the last block
 i_mask = offs_i < q_seq_len
 
427488
q_seq_len: tl.constexpr,
428489
kv_seq_len: tl.constexpr
429490
):
491+
"""
492+
#### Inner loop to calculate $O_i$
493+
494+
This iterates through keys and values starting from `j` for `steps` number of steps.
495+
In each step it processes `BLOCK_K` entries of keys/values.
496+
"""
430497
tl.static_assert(BLOCK_Q % BLOCK_K == 0)
431498

432499
# Move $K_j$ and $V_j$ pointers
@@ -492,6 +559,9 @@ def _attn_bwd_d(t_o, t_do,
492559
q_seq_len: tl.constexpr,
493560
n_groups: tl.constexpr,
494561
):
562+
"""
563+
#### Triton kernel to compute $D_i$
564+
"""
495565
i = tl.program_id(0) * BLOCK_Q
496566
z = tl.program_id(1)
497567

@@ -539,9 +609,10 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
539609
BLOCK_K: tl.constexpr,
540610
):
541611
"""
542-
Compute $dK_j$ and $dV_j$ for $j1 \dots j2$ by iterating over $Q_i$
612+
#### Triton kernel to compute $dK_j$ and $dV_j$
543613
"""
544614

615+
# Compute $dK_j$ and $dV_j$ for `j` ... `j + BLOCK_K` by iterating over $Q_i$
545616
j = tl.program_id(0) * BLOCK_K
546617
z = tl.program_id(1)
547618

@@ -623,7 +694,7 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
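For reference, each step of this kernel's inner loop accumulates the block-wise pieces of the backward formulas above (a restatement in the same symbols, where the sum runs over the current block of queries):

\begin{align}
dV_j &\leftarrow dV_j + \sum_{i \in \text{block}} P_{ij} \, dO_i
\\
dK_j &\leftarrow dK_j + \sigma \sum_{i \in \text{block}} dS_{ij} \, Q_i
\end{align}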
623694
kv_seq_len=kv_seq_len,
624695
)
625696

626-
# Innerloop on queries after the diagonal
697+
# Inner loop on queries after the diagonal
627698
b_dk, b_dv = _attn_bwd_dkdv_inner(
628699
b_dk, b_dv,
629700
p_qT, b_k, b_v, p_do,
@@ -671,7 +742,9 @@ def _attn_bwd_dkdv_inner(b_dk, b_dv,
671742
MASK: tl.constexpr,
672743
q_seq_len: tl.constexpr,
673744
kv_seq_len: tl.constexpr):
674-
"""Inner loop along query"""
745+
"""
746+
#### Inner loop to calculate $dK_j$, $dV_j$
747+
"""
675748

676749
# To apply the mask
677750
tl.static_assert(BLOCK_K % BLOCK_Q == 0)
@@ -755,6 +828,10 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
755828
BLOCK_Q: tl.constexpr,
756829
BLOCK_K: tl.constexpr,
757830
):
831+
"""
832+
#### Triton kernel to compute $dQ_i$
833+
"""
834+
758835
i = tl.program_id(0) * BLOCK_Q
759836
z = tl.program_id(1) // n_groups
760837
g = tl.program_id(1) % n_groups # TODO
@@ -863,7 +940,9 @@ def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
863940
MASK: tl.constexpr,
864941
q_seq_len: tl.constexpr,
865942
kv_seq_len: tl.constexpr):
866-
"""Inner loop over key"""
943+
"""
944+
#### Inner loop to calculate $dQ_i$
945+
"""
867946

868947
# Offsets
869948
offs_i = i + tl.arange(0, BLOCK_Q)
