Flash Attention 2: IO-Aware Exact Attention, Memory Math, and PyTorch Implementation
Flash Attention 2 doesn't approximate attention — it reorders computation to minimize GPU memory reads and writes. Here's the math, the memory analysis, and a working implementation.
- Flash Attention
- Transformers
- GPU
- PyTorch
- Research
9 min read
Standard attention is memory-bound, not compute-bound. On a modern A100 GPU, the tensor cores can perform matrix multiplications far faster than HBM (High-Bandwidth Memory) can supply data. The attention operation — computing softmax(QK^T / sqrt(d)) · V — materializes an N×N attention matrix in HBM, reads it back, then writes the V-weighted sums. For N=4096, that's roughly 128MB of intermediate data (fp32 scores plus probabilities) per attention head per forward pass, read and written multiple times. Flash Attention 2 eliminates these intermediate reads and writes by fusing the entire attention computation into a single kernel, using SRAM (on-chip shared memory, with roughly an order of magnitude more bandwidth than HBM) as a scratchpad.
The Memory Bottleneck: Why Standard Attention Is Slow
To understand Flash Attention, you need to understand the GPU memory hierarchy:
| Memory | Size | Bandwidth | Latency |
|---|---|---|---|
| Registers | ~256 KB per SM | ~10 TB/s | ~1 cycle |
| SRAM (shared memory) | ~192 KB per SM | ~19 TB/s | ~5 cycles |
| HBM (GPU DRAM) | 40–80 GB | ~2 TB/s | ~200 cycles |
Standard self-attention moves data between HBM and SRAM multiple times:
- Load Q, K → compute S = QK^T/sqrt(d) → write S to HBM (write pass 1)
- Load S → compute P = softmax(S) → write P to HBM (write pass 2)
- Load P, V → compute O = PV → write O to HBM (write pass 3)
For sequence length N and head dimension d, the S and P matrices are both N×N. Memory complexity is O(N²). For N=4096, each matrix is ~16.8M floats ≈ 64MB in fp32 (128MB for S and P together), all of it going through slow HBM.
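A quick sanity check of these numbers (a back-of-envelope sketch, not a profiler trace):

```python
# Per-head HBM footprint of standard attention at N=4096, d=64, fp32.
N, d = 4096, 64
bytes_per_float = 4

score_matrix = N * N * bytes_per_float   # S (and P) are each N x N
qkv = 3 * N * d * bytes_per_float        # the Q, K, V inputs, by contrast

print(f"S matrix:       {score_matrix / 2**20:.0f} MiB")      # 64 MiB
print(f"S + P together: {2 * score_matrix / 2**20:.0f} MiB")  # 128 MiB
print(f"Q + K + V:      {qkv / 2**20:.0f} MiB")               # 3 MiB
```

The inputs themselves are tiny; it is the quadratic intermediates that dominate traffic.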
Flash Attention's insight: softmax can be computed incrementally without materializing the full N×N matrix. This lets the entire QK^T → softmax → V pipeline be fused into one kernel that streams Q, K, V through SRAM in tiles and writes only the final output O to HBM — the N×N intermediates never leave the chip.
The Online Softmax Algorithm
The key algorithmic ingredient is online softmax (or numerically stable incremental softmax). Standard softmax over a row requires two passes: one to find the max (for numerical stability) and one to compute the exponentials and normalization. Flash Attention uses a trick that does this in a single left-to-right scan.
For a row x = [x_1, x_2, ..., x_N], we want softmax(x)_i = exp(x_i - max(x)) / Σ exp(x_j - max(x)).
The incremental algorithm maintains running statistics:
- m: running max
- l: running sum of exponentials
When we see a new block of values, we update both in O(block_size) time.
```python
import torch

def online_softmax_demo(scores: torch.Tensor) -> torch.Tensor:
    """
    Demonstrate online (incremental) softmax computation.
    Processes blocks of size 2 to show the mechanics.
    """
    N = scores.shape[0]
    block_size = 2
    m = torch.tensor(float('-inf'))  # running max
    l = torch.tensor(0.0)            # running sum of exp(x - m)
    output = torch.zeros(N)
    stored_blocks = []
    for i in range(0, N, block_size):
        block = scores[i:i+block_size]
        m_new = torch.max(m, block.max())
        # Rescale previous contributions with updated max
        l_new = torch.exp(m - m_new) * l + torch.exp(block - m_new).sum()
        stored_blocks.append(block)
        m = m_new
        l = l_new
    # Final pass: compute softmax values
    for i, block in enumerate(stored_blocks):
        start = i * block_size
        output[start:start+len(block)] = torch.exp(block - m) / l
    return output

scores = torch.tensor([1.0, 3.0, 2.0, 0.5, 4.0, 1.5])
online_result = online_softmax_demo(scores)
reference = torch.softmax(scores, dim=0)
print("Online softmax:", online_result.round(decimals=4))
print("Reference:     ", reference.round(decimals=4))
print("Max diff:      ", (online_result - reference).abs().max().item())
```
Output:
Online softmax: tensor([0.0299, 0.2209, 0.0813, 0.0181, 0.6005, 0.0493])
Reference:      tensor([0.0299, 0.2209, 0.0813, 0.0181, 0.6005, 0.0493])
Max diff:       0.0
The online algorithm produces identical results to standard softmax but processes data in blocks. In Flash Attention, each block is a tile that fits in SRAM — this is what eliminates the HBM round-trip.
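To see why a tile fits comfortably on-chip, here is a rough sketch of the SRAM working set for one tile. The block size, head dimension, and fp16 storage are illustrative assumptions (real kernels tune tile sizes per GPU and keep accumulators in fp32 registers):

```python
# Approximate SRAM working set for one attention tile.
# Assumes fp16 tiles, block size 64, head dim 64 (typical ballpark choices).
block, d, bytes_fp16 = 64, 64, 2

q_tile = block * d * bytes_fp16       # Q block
k_tile = block * d * bytes_fp16       # K block
v_tile = block * d * bytes_fp16       # V block
s_tile = block * block * bytes_fp16   # per-tile score matrix
o_tile = block * d * bytes_fp16       # output accumulator

total = q_tile + k_tile + v_tile + s_tile + o_tile
print(f"Tile working set: {total / 1024:.0f} KB")  # 40 KB, well under 192 KB of SRAM
```

Even with generous accounting, a tile uses a fraction of the ~192 KB available per SM, leaving room for double buffering.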
Flash Attention 2: What Changed From v1
Flash Attention 1 (Dao et al., 2022) introduced tiled softmax but had suboptimal work partitioning. Flash Attention 2 (Dao, 2023) made three key improvements:
- Better parallelism: FA1 parallelized only over the batch and head dimensions, leaving SMs idle when batch × heads is small. FA2 additionally parallelizes across the sequence length — each block of Q rows gets its own thread block — keeping more SMs busy for long sequences.
- Fewer non-matmul FLOPs: FA1 did extra rescaling work per block. FA2 reorganizes the algorithm to move rescaling out of the inner loop.
- Separate forward/backward kernels: FA2 has distinct highly-tuned kernels for forward and backward passes.
The result: FA2 achieves ~50–73% of A100 peak FLOPS on attention, versus ~25–40% for FA1 and under ~10% for standard PyTorch attention.
PyTorch Implementation: Tiled Attention
The following is an educational implementation of the Flash Attention 2 forward pass algorithm in pure PyTorch. This is not the actual CUDA kernel (which is written in C++/CUDA and uses PTX intrinsics) but it demonstrates the algorithmic structure exactly.
```python
import torch
import math

def flash_attention_forward(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                            block_size: int = 64) -> torch.Tensor:
    """
    Educational implementation of Flash Attention 2 forward pass.
    Q, K, V: (batch, heads, seq_len, head_dim)
    Returns: output O of same shape as Q
    """
    B, H, N, d = Q.shape
    scale = 1.0 / math.sqrt(d)
    O = torch.zeros_like(Q)
    L = torch.zeros(B, H, N, device=Q.device, dtype=Q.dtype)  # log-sum-exp normalizer (saved for backward)
    # Outer loop: iterate over blocks of Q (rows)
    for q_start in range(0, N, block_size):
        q_end = min(q_start + block_size, N)
        Q_block = Q[:, :, q_start:q_end, :]  # (B, H, Bq, d)
        # Running statistics for this Q block
        m_i = torch.full((B, H, q_end - q_start), float('-inf'), device=Q.device)
        l_i = torch.zeros(B, H, q_end - q_start, device=Q.device)
        O_i = torch.zeros(B, H, q_end - q_start, d, device=Q.device)
        # Inner loop: iterate over blocks of K, V (columns)
        for kv_start in range(0, N, block_size):
            kv_end = min(kv_start + block_size, N)
            K_block = K[:, :, kv_start:kv_end, :]  # (B, H, Bkv, d)
            V_block = V[:, :, kv_start:kv_end, :]
            # Compute attention scores for this tile
            S_block = torch.einsum('bhid,bhjd->bhij', Q_block, K_block) * scale  # (B, H, Bq, Bkv)
            # Incremental softmax update
            m_block = S_block.max(dim=-1).values  # (B, H, Bq)
            m_new = torch.maximum(m_i, m_block)
            # Rescale previous output and normalization
            exp_diff = torch.exp(m_i - m_new)  # (B, H, Bq)
            O_i = O_i * exp_diff.unsqueeze(-1)
            l_i = l_i * exp_diff
            # Add new block's contribution
            P_block = torch.exp(S_block - m_new.unsqueeze(-1))  # (B, H, Bq, Bkv)
            O_i = O_i + torch.einsum('bhij,bhjd->bhid', P_block, V_block)
            l_i = l_i + P_block.sum(dim=-1)
            m_i = m_new
        # Normalize and write output
        O[:, :, q_start:q_end, :] = O_i / l_i.unsqueeze(-1)
        L[:, :, q_start:q_end] = m_i + torch.log(l_i)
    return O

# Verify against standard attention
torch.manual_seed(42)
B, H, N, d = 2, 4, 128, 64
Q = torch.randn(B, H, N, d)
K = torch.randn(B, H, N, d)
V = torch.randn(B, H, N, d)

# Flash attention output
O_flash = flash_attention_forward(Q, K, V, block_size=32)

# Standard attention output (reference)
scale = 1.0 / math.sqrt(d)
S = torch.einsum('bhid,bhjd->bhij', Q, K) * scale
P = torch.softmax(S, dim=-1)
O_ref = torch.einsum('bhij,bhjd->bhid', P, V)

max_diff = (O_flash - O_ref).abs().max().item()
print(f"Max absolute difference: {max_diff:.2e}")
print(f"Output shape: {O_flash.shape}")
print(f"Match (tol=1e-5): {max_diff < 1e-5}")
```
Output:
Max absolute difference: 2.38e-07
Output shape: torch.Size([2, 4, 128, 64])
Match (tol=1e-5): True
The outputs match standard attention to within floating-point rounding error (~2e-7). This is exact attention — not approximation. Flash Attention achieves its speedup through computation reordering, not numerical approximation.
Using Flash Attention in Practice
In modern PyTorch (≥2.0), Flash Attention is available through torch.nn.functional.scaled_dot_product_attention (SDPA), which automatically dispatches to the Flash Attention CUDA kernel when the inputs and hardware allow it. The torch.backends.cuda.sdp_kernel context manager below forces the Flash backend (newer PyTorch releases expose the same control as torch.nn.attention.sdpa_kernel):
```python
import torch
import torch.nn.functional as F

torch.manual_seed(42)
B, H, N, d = 2, 8, 512, 64
device = "cuda" if torch.cuda.is_available() else "cpu"
Q = torch.randn(B, H, N, d, device=device)
K = torch.randn(B, H, N, d, device=device)
V = torch.randn(B, H, N, d, device=device)

# PyTorch 2.0+ SDPA — automatically uses Flash Attention on CUDA
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    output = F.scaled_dot_product_attention(Q, K, V)

print(f"Output shape: {output.shape}")
print(f"Device: {output.device}")
```
Output:
Output shape: torch.Size([2, 8, 512, 64])
Device: cuda:0
Note: the output shows cpu if no CUDA GPU is available. The Flash Attention kernel requires a recent NVIDIA GPU (SM80+, e.g. A100, or RTX 30-series and newer) for maximum performance.
For models loaded through Hugging Face transformers, passing attn_implementation="flash_attention_2" to from_pretrained enables it (torch.nn.MultiheadAttention, for its part, routes through SDPA internally in recent PyTorch releases):
```python
import torch
from transformers import AutoModelForCausalLM

# Enable Flash Attention 2 in Hugging Face transformers
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    device_map="auto",
)
print("Model loaded with Flash Attention 2")
print(f"Attention implementation: {model.config._attn_implementation}")
```
Output:
Model loaded with Flash Attention 2
Attention implementation: flash_attention_2
Note: requires the flash-attn package (pip install flash-attn --no-build-isolation) and a CUDA-capable GPU with SM80+.
Memory Complexity: The Numbers
The memory savings from Flash Attention are concrete:
| Sequence length | Standard attention scores (fp16) | Flash Attention extra memory |
|---|---|---|
| 512 | 0.5 MB per head | ~2 KB per head |
| 2048 | 8 MB per head | ~8 KB per head |
| 8192 | 128 MB per head | ~32 KB per head |
| 32768 | 2 GB per head | ~128 KB per head |
Flash Attention's extra memory is O(N) — it keeps only per-row log-sum-exp statistics (estimated here in fp32) instead of the O(N²) score matrix. For a 32K context (e.g. GPT-4-32K), the difference is 2 GB versus ~128 KB per head. With 32 heads, that's 64 GB versus ~4 MB — the difference between fitting in GPU memory and OOM.
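These figures can be reproduced in a few lines. The Flash column is a rough O(N) estimate based on the log-sum-exp statistics; the exact extra footprint depends on the kernel implementation:

```python
# Per-head memory: O(N^2) standard score matrix vs O(N) Flash statistics.
# Assumes fp16 scores and fp32 log-sum-exp values (illustrative, not exact).
rows = {}
for N in (512, 2048, 8192, 32768):
    standard_bytes = N * N * 2    # full N x N score matrix, fp16
    flash_extra_bytes = N * 4     # per-row log-sum-exp statistics, fp32
    rows[N] = (standard_bytes, flash_extra_bytes)
    print(f"N={N:6d}: standard {standard_bytes / 2**20:8.1f} MB, "
          f"flash extra {flash_extra_bytes / 1024:4.0f} KB")
```

The quadratic column explodes; the linear column stays in kilobytes even at 32K tokens.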
Gotcha: causal masking is free in Flash Attention
Standard attention needs to explicitly construct and apply a causal mask matrix (another N×N tensor). Flash Attention 2 implements causal masking in the kernel itself with zero additional memory — the lower-triangular structure means the kernel simply skips the upper triangle during tiling. Always pass is_causal=True when using SDPA for autoregressive models:
```python
import torch
import torch.nn.functional as F

Q = torch.randn(1, 8, 1024, 64, device="cuda" if torch.cuda.is_available() else "cpu")
K = torch.randn_like(Q)
V = torch.randn_like(Q)

# Causal masking at zero extra memory cost
output = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print(f"Causal output shape: {output.shape}")
```
Output:
Causal output shape: torch.Size([1, 8, 1024, 64])
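One way to convince yourself is_causal=True really computes the same thing as an explicit mask — a small CPU-friendly check against a materialized lower-triangular boolean mask (dimensions here are arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
Q = torch.randn(1, 2, 16, 8)
K = torch.randn_like(Q)
V = torch.randn_like(Q)

# Kernel-internal causal masking
out_causal = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

# Explicit N x N boolean mask (True = position may attend)
mask = torch.tril(torch.ones(16, 16, dtype=torch.bool))
out_masked = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask)

print((out_causal - out_masked).abs().max().item() < 1e-5)  # True
```

Same result, but the explicit-mask path allocates an N×N tensor that the is_causal path never creates.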
Paper Reference
- arXiv: 2307.08691
- Venue: ICLR 2024
- Authors: Tri Dao
- Contribution: Rewrites the Flash Attention tiling algorithm to reduce non-matmul FLOPs by 2× and improve parallelism across heads, achieving 50–73% of A100 peak FLOPS on attention compared to ~7% for standard PyTorch.
Conclusion
Flash Attention 2 is not a new attention mechanism — it is standard softmax attention, computed exactly but scheduled around the GPU memory hierarchy. The insight is that moving data between HBM and SRAM dominates the runtime of attention, so the algorithm is restructured to keep all intermediate state in SRAM and never write the N×N matrices to HBM. The result is attention that runs several times faster with O(N) extra memory instead of O(N²). In PyTorch 2.0+, this is available without code changes via scaled_dot_product_attention. For Hugging Face models, attn_implementation="flash_attention_2" turns it on. There is little reason to use standard attention for production inference on CUDA hardware.
The next post covers Rotary Positional Embeddings (RoPE) — how they encode position more effectively than learned embeddings and why they've become the default in modern LLMs.