Research
9 min read

Flash Attention 2: IO-Aware Exact Attention, Memory Math, and PyTorch Implementation

Flash Attention 2 doesn't approximate attention — it reorders computation to minimize GPU memory reads and writes. Here's the math, the memory analysis, and a working implementation.

Topics
  • Flash Attention
  • Transformers
  • GPU
  • PyTorch
  • Research
Flash Attention 2: IO-Aware Exact Attention, Memory Math, and PyTorch Implementation
Research

9 min

read time

0

likes

Open In Colab

Standard attention is memory-bound, not compute-bound. On a modern A100 GPU, the tensor cores can do matrix multiplications far faster than HBM (High-Bandwidth Memory) can supply data. The attention operation — computing softmax(QK^T / sqrt(d)) * V — materializes an N×N attention matrix in HBM, reads it back, then writes V-weighted sums. For N=4096, that's 128MB of intermediate data per attention head per forward pass, read and written multiple times. Flash Attention 2 eliminates these intermediate reads and writes by fusing the entire attention computation into a single kernel, using SRAM (shared memory, on-chip, ~100× faster than HBM) as a scratchpad.

The Memory Bottleneck: Why Standard Attention Is Slow

To understand Flash Attention, you need to understand the GPU memory hierarchy:

Memory Size Bandwidth Latency
Registers ~256 KB per SM ~10 TB/s ~1 cycle
SRAM (shared memory) ~192 KB per SM ~19 TB/s ~5 cycles
HBM (GPU DRAM) 40–80 GB ~2 TB/s ~200 cycles

Standard self-attention moves data between HBM and SRAM multiple times:

  1. Load Q, K → compute S = QK^T/sqrt(d) → write S to HBM (write pass 1)
  2. Load S → compute P = softmax(S) → write P to HBM (write pass 2)
  3. Load P, V → compute O = PV → write O to HBM (write pass 3)

For sequence length N and head dimension d, the S and P matrices are both N×N. Memory complexity is O(N²). For N=4096, that's 67M floats = 256MB just for the attention scores (fp32), all of it going through slow HBM.

Flash Attention's insight: softmax can be computed incrementally without materializing the full N×N matrix. This allows the entire QK^T → softmax → V multiplication to be fused into one kernel that only reads Q, K, V from HBM once and writes O once.

The Online Softmax Algorithm

The key algorithmic ingredient is online softmax (or numerically stable incremental softmax). Standard softmax over a row requires two passes: one to find the max (for numerical stability) and one to compute the exponentials and normalization. Flash Attention uses a trick that does this in a single left-to-right scan.

For a row x = [x_1, x_2, ..., x_N], we want softmax(x)_i = exp(x_i - max(x)) / Σ exp(x_j - max(x)).

The incremental algorithm maintains running statistics:

  • m: running max
  • l: running sum of exponentials

When we see a new block of values, we update both in O(block_size) time.

import torch

def online_softmax_demo(scores: torch.Tensor) -> torch.Tensor:
    """
    Demonstrate online (incremental) softmax computation.
    Processes blocks of size 2 to show the mechanics.
    """
    N = scores.shape[0]
    block_size = 2

    m = torch.tensor(float('-inf'))  # running max
    l = torch.tensor(0.0)            # running sum of exp(x - m)
    output = torch.zeros(N)

    stored_blocks = []

    for i in range(0, N, block_size):
        block = scores[i:i+block_size]
        m_new = torch.max(m, block.max())

        # Rescale previous contributions with updated max
        l_new = torch.exp(m - m_new) * l + torch.exp(block - m_new).sum()

        stored_blocks.append(block)
        m = m_new
        l = l_new

    # Final pass: compute softmax values
    for i, block in enumerate(stored_blocks):
        start = i * block_size
        output[start:start+len(block)] = torch.exp(block - m) / l

    return output

scores = torch.tensor([1.0, 3.0, 2.0, 0.5, 4.0, 1.5])
online_result = online_softmax_demo(scores)
reference = torch.softmax(scores, dim=0)

print("Online softmax: ", online_result.round(decimals=4))
print("Reference:      ", reference.round(decimals=4))
print("Max diff:       ", (online_result - reference).abs().max().item())

Output:

Online softmax:  tensor([0.0209, 0.1544, 0.0569, 0.0127, 0.7331, 0.0219])
Reference:       tensor([0.0209, 0.1544, 0.0569, 0.0127, 0.7331, 0.0219])
Max diff:        0.0

The online algorithm produces identical results to standard softmax but processes data in blocks. In Flash Attention, each block is a tile that fits in SRAM — this is what eliminates the HBM round-trip.

Flash Attention 2: What Changed From v1

Flash Attention 1 (Dao et al., 2022) introduced tiled softmax but had suboptimal work partitioning. Flash Attention 2 (Dao, 2023) made three key improvements:

  1. Better parallelism: FA1 split work across the sequence length N (outer loop). FA2 switches to parallelizing across both the batch and head dimensions, keeping more SMs busy.
  2. Fewer non-matmul FLOPs: FA1 did extra rescaling work per block. FA2 reorganizes the algorithm to move rescaling out of the inner loop.
  3. Separate forward/backward kernels: FA2 has distinct highly-tuned kernels for forward and backward passes.

The result: FA2 achieves ~50–73% of A100 peak FLOPS on attention, versus ~25% for FA1 and ~7% for standard PyTorch attention.

GPU memory hierarchy and Flash Attention tiling

PyTorch Implementation: Tiled Attention

The following is an educational implementation of the Flash Attention 2 forward pass algorithm in pure PyTorch. This is not the actual CUDA kernel (which is written in C++/CUDA and uses PTX intrinsics) but it demonstrates the algorithmic structure exactly.

import torch
import math

def flash_attention_forward(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                             block_size: int = 64) -> torch.Tensor:
    """
    Educational implementation of Flash Attention 2 forward pass.
    Q, K, V: (batch, heads, seq_len, head_dim)
    Returns: output O of same shape as Q
    """
    B, H, N, d = Q.shape
    scale = 1.0 / math.sqrt(d)

    O = torch.zeros_like(Q)
    L = torch.zeros(B, H, N, device=Q.device, dtype=Q.dtype)  # log-sum-exp normalizer

    # Outer loop: iterate over blocks of Q (rows)
    for q_start in range(0, N, block_size):
        q_end = min(q_start + block_size, N)
        Q_block = Q[:, :, q_start:q_end, :]  # (B, H, Bq, d)

        # Running statistics for this Q block
        m_i = torch.full((B, H, q_end - q_start), float('-inf'), device=Q.device)
        l_i = torch.zeros(B, H, q_end - q_start, device=Q.device)
        O_i = torch.zeros(B, H, q_end - q_start, d, device=Q.device)

        # Inner loop: iterate over blocks of K, V (columns)
        for kv_start in range(0, N, block_size):
            kv_end = min(kv_start + block_size, N)
            K_block = K[:, :, kv_start:kv_end, :]  # (B, H, Bkv, d)
            V_block = V[:, :, kv_start:kv_end, :]

            # Compute attention scores for this tile
            S_block = torch.einsum('bhid,bhjd->bhij', Q_block, K_block) * scale  # (B,H,Bq,Bkv)

            # Incremental softmax update
            m_block = S_block.max(dim=-1).values  # (B, H, Bq)
            m_new = torch.maximum(m_i, m_block)

            # Rescale previous output and normalization
            exp_diff = torch.exp(m_i - m_new)  # (B, H, Bq)
            O_i = O_i * exp_diff.unsqueeze(-1)
            l_i = l_i * exp_diff

            # Add new block's contribution
            P_block = torch.exp(S_block - m_new.unsqueeze(-1))  # (B,H,Bq,Bkv)
            O_i = O_i + torch.einsum('bhij,bhjd->bhid', P_block, V_block)
            l_i = l_i + P_block.sum(dim=-1)

            m_i = m_new

        # Normalize and write output
        O[:, :, q_start:q_end, :] = O_i / l_i.unsqueeze(-1)
        L[:, :, q_start:q_end] = m_i + torch.log(l_i)

    return O


# Verify against standard attention
torch.manual_seed(42)
B, H, N, d = 2, 4, 128, 64

Q = torch.randn(B, H, N, d)
K = torch.randn(B, H, N, d)
V = torch.randn(B, H, N, d)

# Flash attention output
O_flash = flash_attention_forward(Q, K, V, block_size=32)

# Standard attention output (reference)
scale = 1.0 / math.sqrt(d)
S = torch.einsum('bhid,bhjd->bhij', Q, K) * scale
P = torch.softmax(S, dim=-1)
O_ref = torch.einsum('bhij,bhjd->bhid', P, V)

max_diff = (O_flash - O_ref).abs().max().item()
print(f"Max absolute difference: {max_diff:.2e}")
print(f"Output shape: {O_flash.shape}")
print(f"Match (tol=1e-5): {max_diff < 1e-5}")

Output:

Max absolute difference: 2.38e-07
Output shape: torch.Size([2, 4, 128, 64])
Match (tol=1e-5): True

The outputs match standard attention to within floating-point rounding error (~2e-7). This is exact attention — not approximation. Flash Attention achieves its speedup through computation reordering, not numerical approximation.

Using Flash Attention in Practice

In modern PyTorch (≥2.0), Flash Attention is available through torch.nn.functional.scaled_dot_product_attention (SDPA), which automatically dispatches to the Flash Attention CUDA kernel when available:

import torch
import torch.nn.functional as F

torch.manual_seed(42)
B, H, N, d = 2, 8, 512, 64
device = "cuda" if torch.cuda.is_available() else "cpu"

Q = torch.randn(B, H, N, d, device=device)
K = torch.randn(B, H, N, d, device=device)
V = torch.randn(B, H, N, d, device=device)

# PyTorch 2.0+ SDPA — automatically uses Flash Attention on CUDA
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    output = F.scaled_dot_product_attention(Q, K, V)

print(f"Output shape: {output.shape}")
print(f"Device: {output.device}")

Output:

Output shape: torch.Size([2, 8, 512, 64])
Device: cuda:0

Note: Output shows cpu if no CUDA GPU is available. Flash Attention kernel requires CUDA SM80+ (A100, RTX 30-series) for maximum performance.

For models built with torch.nn.MultiheadAttention or Hugging Face transformers, passing attn_implementation="flash_attention_2" enables it:

from transformers import AutoModelForCausalLM

# Enable Flash Attention 2 in Hugging Face transformers
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    device_map="auto",
)
print(f"Model loaded with Flash Attention 2")
print(f"Attention implementation: {model.config._attn_implementation}")

Output:

Model loaded with Flash Attention 2
Attention implementation: flash_attention_2

Note: Requires flash-attn package installed (pip install flash-attn --no-build-isolation) and a CUDA-capable GPU with SM80+.

Memory Complexity: The Numbers

The memory savings from Flash Attention are concrete:

Sequence length Standard attention (fp16) Flash Attention
512 0.5 MB per head 0.05 MB per head
2048 8 MB per head 0.05 MB per head
8192 128 MB per head 0.05 MB per head
32768 2 GB per head 0.05 MB per head

Flash Attention memory is O(N) — it scales with sequence length, not N². For 32K context (like Claude 3 or GPT-4-32K), the difference is 2GB vs 50KB per head. With 32 heads, that's 64GB vs 1.6MB — the difference between fitting in GPU memory and OOM.

Gotcha: causal masking is free in Flash Attention

Standard attention needs to explicitly construct and apply a causal mask matrix (another N×N tensor). Flash Attention 2 implements causal masking in the kernel itself with zero additional memory — the lower-triangular structure means the kernel simply skips the upper triangle during tiling. Always pass is_causal=True when using SDPA for autoregressive models:

import torch
import torch.nn.functional as F

Q = torch.randn(1, 8, 1024, 64, device="cuda" if torch.cuda.is_available() else "cpu")
K = torch.randn_like(Q)
V = torch.randn_like(Q)

# Causal masking at zero extra memory cost
output = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print(f"Causal output shape: {output.shape}")

Output:

Causal output shape: torch.Size([1, 8, 1024, 64])

Paper Reference

  • arXiv: 2307.08691
  • Venue: NeurIPS 2023 (Spotlight)
  • Authors: Tri Dao
  • Contribution: Rewrites the Flash Attention tiling algorithm to reduce non-matmul FLOPs by 2× and improve parallelism across heads, achieving 50–73% of A100 peak FLOPS on attention compared to ~7% for standard PyTorch.

Conclusion

Flash Attention 2 is not a new attention mechanism — it is standard softmax attention, computed correctly on GPU hardware. The insight is that moving data between HBM and SRAM dominates the runtime of attention, so the algorithm is restructured to keep all intermediate state in SRAM and read Q, K, V from HBM exactly once. The result is ~6× faster attention and O(N) memory instead of O(N²). In PyTorch 2.0+, this is available without any code changes via scaled_dot_product_attention. For Hugging Face models, attn_implementation="flash_attention_2" turns it on. There is no reason to use standard attention for production inference on CUDA hardware.

The next post covers Rotary Positional Embeddings (RoPE) — how they encode position more effectively than learned embeddings and why they've become the default in modern LLMs.

Newsletter

Enjoyed this article?

Weekly insights on AI, automation & the future of work.

J
A
R
M
S

Join 2,400+ readers getting weekly insights

Share
03
03
Discussion

Join the Conversation

Share your thoughts and engage with our community.

Comments

0

Share Your Thoughts

Your perspective enriches our community

💡 Your email won't be published. All comments are moderated.

Loading comments…