11"""
2+ ---
3+ title: Flash Attention
4+ summary: >
5+ This is a PyTorch/Triton implementation of Flash Attention 2
6+ with explanations.
7+ ---
8+
29# Flash Attention
310
Flash attention speeds up the transformer attention mechanism by reducing the number of
memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM.

It was introduced in the paper
[FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/abs/2205.14135)
and further optimized in
[FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691).
The official CUDA implementation can be found at [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention).

Our implementation is based on
[Triton's example implementation](https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html).

*Note: You can click on the mathematical symbols or identifiers to highlight them.*

You can run [test.py](./test.html) to check the correctness and measure the performance of this implementation.

## Forward pass

Here's the attention forward pass. The formulas are for a single attention head.
$Q_i$ is the query (row) vector at position $i$,
and $K_j$ and $V_j$ are the key and value row vectors at position $j$.
$O_i$ is the output vector at position $i$.

\begin{align}
S_{ij} &= \sigma Q_i K_j^T
\\
L_i &= \sum_j e^{S_{ij}}
\\
P_{ij} &= \frac{e^{S_{ij}}}{L_i}
\\
O_i &= \sum_j P_{ij} V_j
\\
&= \frac{1}{L_i} \sum_j e^{S_{ij}} V_j
\end{align}

$S_{ij}$ is the attention score matrix before softmax,
$L_i$ is the softmax denominator,
and $P_{ij}$ is the attention matrix after softmax.

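For reference, here is a minimal (memory-hungry) PyTorch sketch of these formulas for a single head;
the function name and the scale argument `sigma` ($\sigma$) are just for illustration:

```python
import torch


def naive_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sigma: float) -> torch.Tensor:
    # $S_{ij} = \sigma Q_i K_j^T$
    s = sigma * q @ k.transpose(-1, -2)
    # $P_{ij} = e^{S_{ij}} / L_i$ (softmax over $j$)
    p = torch.softmax(s, dim=-1)
    # $O_i = \sum_j P_{ij} V_j$
    return p @ v
```
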
#### Flash Attention Optimization

You can compute $O_i$, instead of doing the full softmax,
by computing the sum of exponents $l_i$ and the unnormalized output $\tilde{O}_i$
while iterating over keys:

\begin{align}
l_i &= \sum_j e^{S_{ij} - m_i}
\\
\tilde{O}_i &= \sum_j e^{S_{ij} - m_i} V_j
\end{align}

where $m_i = \max_j S_{ij}$ is a running maximum, kept for numerical stability.
These are accumulated block by block, rescaling $l_i$ and $\tilde{O}_i$ whenever $m_i$ changes,
and the output is normalized at the end:

$$O_i = \frac{\tilde{O}_i}{l_i}$$

This reduces memory usage since we never materialize the full $S_{ij}$ or $P_{ij}$ matrices.
It is also faster because those large matrices never have to be written to and read from HBM;
the kernel only loads blocks of $K$ and $V$ as it iterates over them.

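Here is a sketch of that streaming computation in plain PyTorch for a single head, processing
keys and values in blocks (the block size and the names are illustrative, not the kernel's actual parameters):

```python
import torch


def streaming_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        sigma: float, block_k: int = 128) -> torch.Tensor:
    m = torch.full((q.shape[0], 1), float('-inf'))  # running maximum $m_i$
    l = torch.zeros(q.shape[0], 1)                  # running sum of exponents $l_i$
    o = torch.zeros(q.shape[0], v.shape[1])         # unnormalized output $\tilde{O}_i$
    for j in range(0, k.shape[0], block_k):
        s = sigma * q @ k[j:j + block_k].T          # $S_{ij}$ for this block of keys
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        alpha = torch.exp(m - m_new)                # rescale the old accumulators
        p = torch.exp(s - m_new)                    # exponents relative to the new maximum
        l = l * alpha + p.sum(dim=-1, keepdim=True)
        o = o * alpha + p @ v[j:j + block_k]
        m = m_new
    # $O_i = \tilde{O}_i / l_i$
    return o / l
```

The Triton kernel below follows the same scheme, but processes a block of queries per program
and keeps the accumulators on-chip.
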
## Backward pass

Here's the standard backward pass. $dO_i$ is the gradient of the loss with respect to the output $O_i$.

\begin{align}
dV_j &= \sum_i P_{ij} dO_i
\\
dP_{ij} &= dO_i V_j^T
\\
D_i &= \sum_j P_{ij} dP_{ij} = dO_i O_i^T
\\
dS_{ij} &= P_{ij} dP_{ij} - D_i P_{ij}
\\
dQ_i &= \sigma \sum_j dS_{ij} K_j
\\
dK_j &= \sigma \sum_i dS_{ij} Q_i
\end{align}

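A minimal PyTorch sketch of these formulas for a single head (the names and the explicit
recomputation of the forward pass are just for illustration):

```python
import torch


def naive_attention_backward(q, k, v, do, sigma):
    s = sigma * q @ k.T                      # $S_{ij}$
    p = torch.softmax(s, dim=-1)             # $P_{ij}$
    o = p @ v                                # $O_i$
    dv = p.T @ do                            # $dV_j = \sum_i P_{ij} dO_i$
    dp = do @ v.T                            # $dP_{ij} = dO_i V_j^T$
    d = (do * o).sum(dim=-1, keepdim=True)   # $D_i = dO_i O_i^T$
    ds = p * (dp - d)                        # $dS_{ij} = P_{ij} dP_{ij} - D_i P_{ij}$
    dq = sigma * ds @ k                      # $dQ_i$
    dk = sigma * ds.T @ q                    # $dK_j$
    return dq, dk, dv
```

The actual implementation never materializes $P_{ij}$; it recomputes the attention scores block
by block using the $L_i$ saved from the forward pass.
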
Flash attention saves $L_i$ from the forward pass since it doesn't take much memory,
so the backward pass doesn't have to recompute $l_i$ or $m_i$.

It first computes $D_i$.
Then it iterates over the queries and computes (accumulates) $dK_j$ and $dV_j$.
Finally it iterates over the keys and computes (accumulates) $dQ_i$.

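In this implementation that corresponds to three Triton kernels, launched roughly as follows
(a sketch only; the grids and the full argument lists are in the elided launch code):

```python
# 1. Precompute D_i = dO_i . O_i for every query position
_attn_bwd_d[grid_d](o, do, ...)
# 2. Accumulate dK_j and dV_j; the kernel loops over query blocks internally
_attn_bwd_dkdv[grid_kv](q, k, v, sm_scale, ...)
# 3. Accumulate dQ_i; the kernel loops over key blocks internally
_attn_bwd_dq[grid_q](q, k, v, do, ...)
```
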
In both the forward and backward passes we compute logarithms and exponentials in base $2$ instead of base $e$, for performance.
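This relies on the identity $e^x = 2^{x \log_2 e}$, for example:

```python
import math

import torch

x = torch.randn(8)
log2_e = math.log2(math.e)
# exp2 is cheaper than exp on the GPU; the results match up to floating point error
assert torch.allclose(torch.exp(x), torch.exp2(x * log2_e))
```
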
99146"""
100147
101148from typing import Any , Tuple
110157
class AttentionFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any,
                q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                causal: bool, sm_scale: float) -> torch.Tensor:
        """
        ### Forward pass

        Group query attention forward pass. Returns the output in shape `[batch_size, n_heads, q_seq_len, d_head]`.

        :param ctx: is the context for torch autograd
        :param q: has shape `[batch_size, n_heads, q_seq_len, d_head]`
        :param k: has shape `[batch_size, k_heads, kv_seq_len, d_head]`
        :param v: has shape `[batch_size, k_heads, kv_seq_len, d_head]`
        :param causal: whether to apply causal attention mask
        :param sm_scale: softmax scale factor $\sigma$
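
        A minimal usage sketch (shapes and dtypes are illustrative; `sm_scale` is typically $1 / \sqrt{d_{head}}$):

        ```python
        q = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float16, requires_grad=True)
        k = torch.randn(1, 2, 1024, 64, device='cuda', dtype=torch.float16, requires_grad=True)
        v = torch.randn(1, 2, 1024, 64, device='cuda', dtype=torch.float16, requires_grad=True)
        o = AttentionFunc.apply(q, k, v, True, 64 ** -0.5)
        o.sum().backward()
        ```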
125175 """
126176 batch_size , n_heads , q_seq_len , d_head = q .shape
127177 _ , k_heads , kv_seq_len , _ = k .shape
@@ -171,6 +221,8 @@ def forward(ctx: Any, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
    @staticmethod
    def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
        """
        ### Backward pass

        The backward pass computes the gradients of the input tensors.

        :param ctx: is the context for torch autograd
@@ -264,22 +316,27 @@ def _get_autotune_configs(inner_loop: str) -> list:
264316 """
265317
266318 configs = []
267- # List possible BLOCK_Q and BLOCK_K that satisfy BLOCK_Q divisible by BLOCK_K
268- # and also try to cover a wide range
269- for bm in [64 , 128 , 256 ]:
270- # We'll try bn in [16, 32, 64, 128] that are divisors and <= bm
271- for bn in [64 , 128 , 256 ]:
272- if inner_loop == 'key' and bm % bn != 0 :
319+
320+ # Possible options for `BLOCK_Q`
321+ for bq in [64 , 128 , 256 ]:
322+ # Possible options for `BLOCK_K`
323+ for bk in [64 , 128 , 256 ]:
324+ # If the inner loop is along keys the `BLOCK_Q` must be a multiple of `BLOCK_K` for causal masking
325+ if inner_loop == 'key' and bq % bk != 0 :
273326 continue
274- if inner_loop == 'query' and bn % bm != 0 :
327+ # Similarly when the inner loop is along queries
328+ if inner_loop == 'query' and bk % bq != 0 :
275329 continue
330+
331+ # Number of stages and warps
276332 for s in [2 , 3 , 4 ]:
277333 for w in [4 , 8 ]:
278- if bm * bn < 128 * 128 and w == 8 :
334+ if bq * bk < 128 * 128 and w == 8 :
279335 continue
280336
281- configs .append (triton .Config ({'BLOCK_Q' : bm , 'BLOCK_K' : bn }, num_stages = s , num_warps = w ))
337+ configs .append (triton .Config ({'BLOCK_Q' : bq , 'BLOCK_K' : bk }, num_stages = s , num_warps = w ))
282338
339+ # **Use `return configs` to autotune. Trying all combinations is slow for testing.**
283340 return configs [:1 ]
284341
285342
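# These configurations are meant to be consumed by `triton.autotune`.
# A sketch of how a kernel below might be decorated (the `key` list here is an assumption for
# illustration, not necessarily what the elided code uses):
#
#     @triton.autotune(configs=_get_autotune_configs(inner_loop='key'), key=['q_seq_len', 'kv_seq_len', 'd_head'])
#     @triton.jit
#     def _attn_fwd(...):
#         ...
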
@@ -292,34 +349,37 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale_log2e, t_lse, t_o,
              kv_seq_len: tl.constexpr,
              d_head: tl.constexpr,
              is_causal: tl.constexpr,
              BLOCK_Q: tl.constexpr,
              BLOCK_K: tl.constexpr,
              ):
    """
    ### Triton kernel for the Flash attention forward pass

    :param t_q: queries $Q_i$
    :param t_k: keys $K_j$
    :param t_v: values $V_j$
    :param sm_scale_log2e: $\sigma \log_2 e$, the softmax scale multiplied by $\log_2 e$
    :param t_lse: $\log_2 \sum_j e^{S_{ij}}$ (out)
    :param t_o: output $O_i$ (out)
    :param n_groups: number of groups in GQA
    :param q_seq_len: query sequence length
    :param kv_seq_len: key/value sequence length
    :param d_head: number of dimensions in a head
    :param BLOCK_Q: block size for the query sequence length
    :param BLOCK_K: block size for the key sequence length
    :param is_causal: whether to apply the causal attention mask

    Strides `z`, `h`, `m` and `d` denote the strides of the corresponding dimensions
    (`batch_size`, `n_heads`, `q_seq_len`, `d_head`) of the query.
    Stride `n` denotes the stride along `kv_seq_len` of the key.
    """

    # We are computing the attention output $O_i$ for queries `i` ... `i + BLOCK_Q` in batch/head combination $z$.
    i = tl.program_id(0)
    z = tl.program_id(1) // n_groups
    g = tl.program_id(1) % n_groups

    # #### Create block pointers
    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                            (q_seq_len, d_head),
                            (d_head, 1),
@@ -354,6 +414,7 @@ def _attn_fwd(t_q, t_k, t_v, sm_scale_log2e, t_lse, t_o,
    # Initialize offsets
    offs_i = i * BLOCK_Q + tl.arange(0, BLOCK_Q)
    offs_j = tl.arange(0, BLOCK_K)

    # Mask for $Q$ for the last block
    i_mask = offs_i < q_seq_len

@@ -427,6 +488,12 @@ def _attn_fwd_inner(b_o, b_l, b_m, b_q,
                    q_seq_len: tl.constexpr,
                    kv_seq_len: tl.constexpr
                    ):
    """
    #### Inner loop to calculate $O_i$

    This iterates through keys and values starting from `j` for `steps` number of steps.
    In each step it processes `BLOCK_K` entries of keys/values.
    """
    tl.static_assert(BLOCK_Q % BLOCK_K == 0)

    # Move $K_j$ and $V_j$ pointers
@@ -492,6 +559,9 @@ def _attn_bwd_d(t_o, t_do,
                q_seq_len: tl.constexpr,
                n_groups: tl.constexpr,
                ):
    """
    #### Triton kernel to compute $D_i$
    """
    i = tl.program_id(0) * BLOCK_Q
    z = tl.program_id(1)

@@ -539,9 +609,10 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
                   BLOCK_K: tl.constexpr,
                   ):
    """
    #### Triton kernel to compute $dK_j$ and $dV_j$
    """

    # Compute $dK_j$ and $dV_j$ for keys `j` ... `j + BLOCK_K` by iterating over $Q_i$
    j = tl.program_id(0) * BLOCK_K
    z = tl.program_id(1)

@@ -623,7 +694,7 @@ def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
        kv_seq_len=kv_seq_len,
    )

    # Inner loop on queries after the diagonal
    b_dk, b_dv = _attn_bwd_dkdv_inner(
        b_dk, b_dv,
        p_qT, b_k, b_v, p_do,
@@ -671,7 +742,9 @@ def _attn_bwd_dkdv_inner(b_dk, b_dv,
                         MASK: tl.constexpr,
                         q_seq_len: tl.constexpr,
                         kv_seq_len: tl.constexpr):
    """
    #### Inner loop to calculate $dK_j$ and $dV_j$
    """

    # To apply the mask
    tl.static_assert(BLOCK_K % BLOCK_Q == 0)
@@ -755,6 +828,10 @@ def _attn_bwd_dq(t_q, t_k, t_v, t_do,
                 BLOCK_Q: tl.constexpr,
                 BLOCK_K: tl.constexpr,
                 ):
    """
    #### Triton kernel to compute $dQ_i$
    """

    i = tl.program_id(0) * BLOCK_Q
    z = tl.program_id(1) // n_groups
    g = tl.program_id(1) % n_groups  # TODO
@@ -863,7 +940,9 @@ def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
                       MASK: tl.constexpr,
                       q_seq_len: tl.constexpr,
                       kv_seq_len: tl.constexpr):
    """
    #### Inner loop to calculate $dQ_i$
    """

    # Offsets
    offs_i = i + tl.arange(0, BLOCK_Q)