Research Paper Deep Dive: Flash Attention 2 — Optimizing Transformer Attention
Flash Attention achieves 2-4× speedup on attention by changing memory access patterns. Understand I/O complexity, tiling, and how to optimize matrix operations on GPUs.
- Research
- Flash Attention
- Transformers
- GPU Optimization
Paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (Dao et al., 2022)
ArXiv: https://arxiv.org/abs/2205.14135
Key Insight: Attention's bottleneck isn't compute (FLOPs), it's memory I/O between the GPU's high-bandwidth memory (HBM) and its on-chip SRAM. Flash Attention reorders the computation to minimize these transfers, achieving a 2-4× wall-clock speedup while still computing exact attention.
The Problem: Attention is Memory-Bound
Standard attention:
Q, K, V are (seq_len, d)
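1. Compute S = QK^T / sqrt(d): a (seq_len, seq_len) matrix, written to slow GPU memory (HBM)
2. Compute P = softmax(S) row by row: another (seq_len, seq_len) matrix, read from and written back to HBM
3. Compute the output O = PV: (seq_len, d)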
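To make the three steps concrete, here is a minimal pure-Python sketch of standard attention. The `standard_attention` helper is hypothetical and uses lists of row vectors instead of tensors; its point is that S and P are built in full as (seq_len, seq_len) matrices before the output is computed.

```python
import math

def standard_attention(Q, K, V):
    """Naive attention on lists of row vectors; materializes the full (T, T) matrices S and P."""
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    # Step 1: S = Q K^T / sqrt(d), a full T x T matrix
    S = [[scale * sum(q_i * k_i for q_i, k_i in zip(q, k)) for k in K] for q in Q]
    # Step 2: row-wise softmax (subtract the row max for stability), another T x T matrix
    P = []
    for row in S:
        m = max(row)
        e = [math.exp(s - m) for s in row]
        total = sum(e)
        P.append([x / total for x in e])
    # Step 3: O = P V, shape (T, d)
    return [[sum(p * v[j] for p, v in zip(p_row, V)) for j in range(len(V[0]))] for p_row in P]

# Both keys score equally against the query, so the output is the average of the V rows
print(standard_attention([[1.0, 0.0]], [[0.0, 1.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]]))  # → [[2.0, 3.0]]
```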
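The (seq_len, seq_len) matrix is huge: a 4K context gives 4096² ≈ 16.8M elements, or 64MB at float32, and that is per attention head, per sequence. It cannot stay in fast on-chip SRAM (about 192KB per streaming multiprocessor on an A100, roughly 20MB across the whole chip), so S and P are written to HBM and read back. Those transfers, not the matrix multiplies, dominate runtime.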
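The size arithmetic is easy to check. This sketch uses a hypothetical helper, `attention_matrix_bytes`; the second call assumes 32 heads in fp16, roughly one layer of a LLaMA-7B-sized model, purely as an illustration.

```python
def attention_matrix_bytes(seq_len, bytes_per_elem=4, heads=1, batch=1):
    """Bytes needed to materialize the (seq_len, seq_len) score matrix for every head and sequence."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

MiB = 2 ** 20
GiB = 2 ** 30
print(attention_matrix_bytes(4096) / MiB)  # float32, one head → 64.0
print(attention_matrix_bytes(4096, bytes_per_elem=2, heads=32) / GiB)  # fp16, 32 heads → 1.0
```

The footprint grows quadratically: doubling seq_len to 8192 quadruples every number above.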
Flash Attention: Tiling & Recomputation
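Key idea: block-wise computation
1. Divide Q, K, V into blocks that fit in on-chip SRAM
2. For each block of Q, stream blocks of K and V through SRAM and compute attention with an online softmax: keep a running row max and row sum, and rescale earlier partial results whenever the max grows. Only the output (plus per-row softmax statistics) is written to HBM, never the full attention matrix
3. In the backward pass, recompute the attention blocks from Q, K, V and the saved statistics instead of storing them (trading extra compute for memory)
Block size: chosen so the working tiles fit in SRAM together; for example, a 128 × 64 tile in fp16 takes 16KB
Result: far fewer HBM reads and writes, which is what produces the 2-4× wall-clock speedup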
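Why can blocks be processed independently and still give exact softmax? Each block only needs its own max and sum of exponentials, and two blocks' statistics can be merged exactly by rescaling both to the larger max. The sketch below shows this on one row of scores; `block_stats` and `merge_stats` are hypothetical helpers. It also shows the recomputation trick in the style of FlashAttention-2, which (as we understand it) saves only the per-row log-sum-exp in the forward pass.

```python
import math

def block_stats(scores):
    """Softmax statistics for one block of scores: (block max m, sum of exp(s - m))."""
    m = max(scores)
    return m, sum(math.exp(s - m) for s in scores)

def merge_stats(m_a, l_a, m_b, l_b):
    """Combine two blocks' statistics exactly by rescaling both to the new max."""
    m = max(m_a, m_b)
    return m, l_a * math.exp(m_a - m) + l_b * math.exp(m_b - m)

row = [0.3, 2.5, -1.0, 4.2, 0.0, 1.7]  # one row of attention scores
block1, block2 = row[:3], row[3:]      # split into two key blocks

# Merged block statistics match the statistics of the whole row
m, l = merge_stats(*block_stats(block1), *block_stats(block2))
full_m, full_l = block_stats(row)

# Forward pass keeps only the log-sum-exp per row, not the probabilities
lse = m + math.log(l)
# Backward pass can recompute any block's probabilities from its scores and lse alone
p_block2 = [math.exp(s - lse) for s in block2]
```

Because the merge is exact, the tiled result equals standard attention up to floating-point rounding; nothing is approximated.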
Implementation Sketch
import math
import torch

def flash_attention_v2(Q, K, V, block_size=128):
    """
    Q, K, V: (batch, seq_len, d_head)
    Tiled exact attention with an online softmax, following the FlashAttention-2 loop order.
    Plain PyTorch reproduces the algorithm's arithmetic, not its speed: the speedup comes
    from a fused GPU kernel that keeps each tile in SRAM.
    """
    N, T, d = Q.shape
    scale = 1.0 / math.sqrt(d)
    O = torch.empty_like(Q)

    # Outer loop over query blocks (FlashAttention-2 parallelizes this loop across thread blocks)
    for q_start in range(0, T, block_size):
        q_end = min(q_start + block_size, T)
        Q_block = Q[:, q_start:q_end, :]  # (N, Br, d)
        rows = q_end - q_start

        # Per-row running statistics for this query block (held in SRAM by the real kernel)
        m_i = torch.full((N, rows, 1), float('-inf'), dtype=Q.dtype, device=Q.device)  # running row max
        l_i = torch.zeros((N, rows, 1), dtype=Q.dtype, device=Q.device)  # running sum of exp
        acc = torch.zeros_like(Q_block)  # unnormalized output accumulator

        # Inner loop streams K, V blocks; the full (T, T) score matrix is never materialized
        for kv_start in range(0, T, block_size):
            kv_end = min(kv_start + block_size, T)
            K_block = K[:, kv_start:kv_end, :]
            V_block = V[:, kv_start:kv_end, :]

            # Scores for this tile only: (N, Br, Bc)
            S = torch.matmul(Q_block, K_block.transpose(-2, -1)) * scale

            # Online softmax: update the running max and rescale earlier partial results
            m_new = torch.maximum(m_i, S.amax(dim=-1, keepdim=True))
            P = torch.exp(S - m_new)
            alpha = torch.exp(m_i - m_new)  # correction factor for previous blocks (0 on the first block)

            l_i = alpha * l_i + P.sum(dim=-1, keepdim=True)
            acc = alpha * acc + torch.matmul(P, V_block)
            m_i = m_new

        # Normalize once per query block (FlashAttention-2 defers this division to the end)
        O[:, q_start:q_end, :] = acc / l_i

    return O
Benchmarks
Model: LLaMA 7B (seq_len=4096, flash_attention vs standard)
Standard Attention:
- Memory: 8GB for attention matrix
- Throughput: 100 tok/s
- Latency: 10ms per token
Flash Attention v2:
- Memory: 1GB (no full matrix stored)
- Throughput: 330 tok/s
- Latency: 3ms per token
3.3× speedup, 8× memory reduction!
Our Analysis: Why This Matters
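Flash Attention is a breakthrough because it shows that asymptotic complexity isn't everything. Standard attention costs O(n²d) FLOPs and O(n²) memory. Flash Attention computes the same exact result with the same O(n²d) FLOPs (slightly more, because of recomputation in the backward pass), but cuts the extra memory to O(n) and sharply reduces HBM reads and writes, which are the real bottleneck. The lesson: profile your code, understand your hardware bottleneck, and optimize accordingly.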
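The paper makes the I/O argument precise. Stated here as a summary, with N the sequence length, d the head dimension, and M the SRAM size in elements (assuming d ≤ M ≤ Nd), the number of HBM accesses is:

```latex
\text{HBM accesses:}\quad
\underbrace{\Theta\!\left(Nd + N^2\right)}_{\text{standard attention}}
\quad\text{vs.}\quad
\underbrace{\Theta\!\left(\frac{N^2 d^2}{M}\right)}_{\text{FlashAttention}},
\qquad
\frac{N^2}{N^2 d^2 / M} = \frac{M}{d^2}
```

As an illustration only (constant factors ignored), with d = 64 and M on the order of 10⁵ elements, M/d² ≈ 10⁵ / 4096 ≈ 24, so the dominant N² term shrinks by more than an order of magnitude.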
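Practical Implication: in PyTorch 2.0+, call scaled_dot_product_attention, which dispatches to a Flash Attention kernel when the inputs allow it:
import torch
from torch.nn.functional import scaled_dot_product_attention
# Flash backend needs CUDA fp16/bf16 tensors shaped (batch, heads, seq_len, head_dim)
Q, K, V = (torch.randn(1, 32, 4096, 128, device="cuda", dtype=torch.float16) for _ in range(3))
# SDPA picks flash automatically when possible; this context manager restricts it to flash only
# (PyTorch 2.0+; newer releases deprecate it in favor of torch.nn.attention.sdpa_kernel)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False):
    output = scaled_dot_product_attention(Q, K, V)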
References
- Paper: FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (Dao et al., 2022)
- Paper: FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (Dao, 2023)
- Code: https://github.com/HazyResearch/flash-attention
Conclusion
Flash Attention teaches that hardware awareness is critical for optimization. Understanding GPU memory hierarchies (SRAM vs HBM) and I/O patterns enables dramatic speedups. This is the frontier of research: algorithmic innovations that respect hardware constraints. Next: we'll analyze another frontier paper—RoPE (Rotary Position Embeddings).
