Flash Attention explained
Your fine-tuning job dies at sequence length 8,192 with an out-of-memory error — even
though the model weights fit comfortably on an A100. The culprit is not parameter count;
it is the attention matrix. Standard scaled dot-product attention
materializes an n × n score tensor for every head and layer, flooding
GPU high-bandwidth memory (HBM) with reads and writes. Flash Attention
(Dao et al., 2022) is an IO-aware CUDA kernel that never stores the full
attention matrix. It tiles queries, keys, and values through fast on-chip SRAM, fuses
softmax with the matrix multiply, and recomputes values during the backward pass instead
of checkpointing gigabytes of intermediates. The result: 2–4× wall-clock
speedups and dramatically lower memory use at long contexts — without changing the
mathematical output of attention. This guide explains why attention is memory-bound,
how tiling and online softmax work, FlashAttention-2 parallelism improvements,
integration in PyTorch and serving engines, a Harbor LLM deployment worked example,
a decision table, pitfalls, and a production checklist. For attention intuition, see
attention mechanism explained;
for where attention sits in the full stack, see
transformer architecture explained.
Why standard attention hurts
Given query matrix Q, key matrix K, and value matrix
V (each shaped [sequence_length, head_dim] per head), vanilla
attention computes:
Attention(Q, K, V) = softmax(QKᵀ / √d) · V
The intermediate QKᵀ tensor has shape
[n, n] per head. For a 32-layer, 32-head model at n = 8192
in FP16, one score map takes 128 MiB, so a single sequence needs 4 GiB per layer and
128 GiB across all layers if every map is kept, before gradients, optimizer states, or the
KV cache
during inference. FLOPs scale as O(n²), but on modern GPUs the
bottleneck is often HBM bandwidth: repeatedly reading and writing
those huge tensors to slow off-chip memory while fast SRAM sits underutilized.
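The arithmetic behind that estimate can be checked with a few lines of Python. This is a minimal sketch for one sequence (batch size 1), assuming 2 bytes per FP16 element; the helper name `score_matrix_bytes` is illustrative and not part of any library.

```python
def score_matrix_bytes(n, heads, layers, bytes_per_element=2):
    """Bytes needed to hold every n x n attention score map for one sequence."""
    return n * n * bytes_per_element * heads * layers

MIB = 2**20
GIB = 2**30

per_head = score_matrix_bytes(8192, heads=1, layers=1)
per_layer = score_matrix_bytes(8192, heads=32, layers=1)
all_layers = score_matrix_bytes(8192, heads=32, layers=32)

print(per_head / MIB, "MiB per head")        # 128.0
print(per_layer / GIB, "GiB per layer")      # 4.0
print(all_layers / GIB, "GiB for 32 layers") # 128.0
```

Doubling `n` quadruples every figure, which is why a job that fits at 4k tokens can fail at 8k.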
Training vs inference bottlenecks
During training, you need both forward activations (or recomputation
strategy) and backward gradients through attention. Materializing QKᵀ
for backprop is expensive. During inference prefill, you compute
attention over the entire prompt at once — again quadratic in prompt length. Decode
steps are cheaper per token (one new query against cached keys), but prefill for long
documents still benefits from fused attention kernels. Flash Attention addresses the
prefill and training regimes where full quadratic attention maps would otherwise
dominate memory and wall time.
IO-aware tiling and online softmax
Flash Attention’s core idea is fusion with tiling. Instead of
computing the full QKᵀ matrix in HBM, the kernel loads
blocks of Q, K, and V into
SRAM (shared memory / L1 on the GPU), computes partial attention outputs for that
tile, and accumulates results incrementally. The trick is online
softmax: softmax normally needs the full row of logits to subtract the max
for numerical stability. Flash Attention tracks running row maxima and normalization
sums as tiles arrive, updating the partial output without ever seeing the complete row
at once. The final per-head output is mathematically identical to
standard attention — this is not an approximation like low-rank attention; it is an
exact reordering of the same operations.
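The running-max and running-sum bookkeeping is easier to follow in code. Below is a minimal pure-Python sketch for a single query row. It illustrates the idea only and is not the CUDA kernel; the function names and the block size are assumptions chosen for the example.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def naive_attention_row(q, keys, values):
    """Reference: softmax over the complete row of scores."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in keys]
    row_max = max(scores)
    weights = [math.exp(s - row_max) for s in scores]
    denom = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / denom
            for j in range(len(values[0]))]

def online_attention_row(q, keys, values, block_size=3):
    """Same output, but keys and values are visited one block at a time."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = float("-inf")     # largest score seen so far
    running_sum = 0.0               # softmax denominator relative to running_max
    acc = [0.0] * len(values[0])    # unnormalized weighted sum of values
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [dot(q, k) * scale for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new maximum (0.0 on the first block).
        correction = math.exp(running_max - new_max)
        weights = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * correction + sum(weights)
        acc = [a * correction + sum(w * v[j] for w, v in zip(weights, v_blk))
               for j, a in enumerate(acc)]
        running_max = new_max
    return [a / running_sum for a in acc]

q = [0.3, -1.2, 0.8, 0.5]
keys = [[0.1 * i - 0.2 * j for j in range(4)] for i in range(7)]
values = [[float(i + j) for j in range(3)] for i in range(7)]

reference = naive_attention_row(q, keys, values)
tiled = online_attention_row(q, keys, values, block_size=3)
max_diff = max(abs(a - b) for a, b in zip(reference, tiled))
print(max_diff)  # differs only by floating-point rounding
```

The tiled version never holds more than `block_size` scores at once, yet it matches the reference up to rounding, which is the sense in which the reordering is exact.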
Recomputation in the backward pass
To save memory during training, Flash Attention does not store the full attention matrix for backpropagation. It keeps only the attention output and one softmax normalization statistic per row, then recomputes each attention block from Q, K, and V during the backward pass, trading extra compute for far less HBM traffic. On bandwidth-bound workloads this trade wins: recomputing in fast SRAM is cheaper than reading and writing giant tensors in HBM. This is the same memory/compute exchange principle behind gradient checkpointing, applied at the kernel level inside attention specifically.
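A small sketch of why one saved number per row is enough: given a row's log-sum-exp from the forward pass, the softmax probabilities for any block can be rebuilt independently. The function names here are illustrative. In the real kernel the scores themselves are also recomputed from Q and K; they are passed in directly below to keep the example short.

```python
import math

def forward_row_stat(scores):
    """Forward pass keeps one scalar per row: log(sum(exp(scores)))."""
    row_max = max(scores)
    return row_max + math.log(sum(math.exp(s - row_max) for s in scores))

def recompute_block_probs(block_scores, row_lse):
    """Backward pass: rebuild softmax probabilities for one block of the row."""
    return [math.exp(s - row_lse) for s in block_scores]

scores = [0.5, 2.0, -1.0, 3.5, 0.0, 1.25]
lse = forward_row_stat(scores)          # the only per-row value that is stored

# Rebuild the row in two independent blocks, as a tiled backward pass would.
probs = recompute_block_probs(scores[:3], lse) + recompute_block_probs(scores[3:], lse)
print(sum(probs))  # approximately 1.0
```

Storing `lse` costs O(n) per head, compared with O(n²) for the full probability matrix.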
FlashAttention-2
The second revision improves work partitioning across warps and thread blocks, reduces non-matmul FLOPs in the inner loop, and sequences operations so more time is spent in Tensor Core matrix multiplies. Reported speedups over FlashAttention-1 are often another ~2× on A100-class hardware. FlashAttention-3 targets Hopper (H100) with FP8 and asynchronous memory copies. In practice, check which version your framework and GPU generation actually ship — the naming moves faster than blog posts.
Memory and speed: what changes in practice
Standard attention memory for the score matrix scales
O(n²) per layer per head (ignoring batch for a moment). Flash
Attention reduces memory to O(n) in sequence length for
attention intermediates, because score tiles live only in on-chip SRAM while HBM holds
just the output and per-row softmax statistics. That
is what unlocks longer training sequences and larger batch sizes on fixed GPU VRAM.
Wall-clock speedups come from fewer bytes moved across the HBM/SRAM
boundary and better Tensor Core utilization, not from asymptotically fewer FLOPs.
Short sequences (e.g. n < 512) may see modest gains; very long
contexts (4k, 8k, 32k+) are where Flash Attention routinely turns an OOM into a
finished job. Always benchmark your actual sequence-length distribution; micro-benchmarks
at n = 128 mislead.
| Regime | Standard attention pain | Flash Attention benefit |
|---|---|---|
| Long-context fine-tuning | OOM on attention maps | Fits longer n or larger batch |
| Prefill on 8k+ prompts | Slow, memory-spiky | Lower latency first token |
| Short chat turns (<1k) | Already fast | Marginal; other ops dominate |
| Multi-head + many layers | Memory multiplies per layer | Compounds savings across depth |
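To get a feel for the bytes-moved argument, the asymptotic HBM access counts from Dao et al. (2022) can be compared directly: on the order of n·d + n² element accesses for standard attention versus n²·d²/M for Flash Attention, where M is the SRAM capacity in elements. The sketch below drops all constant factors, and the SRAM size of 51,200 elements (roughly 100 KB of FP16) is an assumption, so treat the result as an order-of-magnitude indication rather than a prediction.

```python
def hbm_accesses_standard(n, d):
    """Order-of-magnitude HBM element accesses for standard attention."""
    return n * d + n * n

def hbm_accesses_flash(n, d, sram_elements):
    """Order-of-magnitude HBM element accesses for tiled attention."""
    return n * n * d * d / sram_elements

n, d = 8192, 128
sram_elements = 51_200  # assumed: about 100 KB of on-chip SRAM holding FP16 values

ratio = hbm_accesses_standard(n, d) / hbm_accesses_flash(n, d, sram_elements)
print(f"standard / flash HBM accesses: about {ratio:.1f}x")
```

The ratio grows as head dimension shrinks or SRAM grows, which is one reason measured speedups vary across models and GPU generations.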
Where you get Flash Attention today
You rarely call Flash Attention directly. It ships inside framework and serving integrations:
- PyTorch 2.x: torch.nn.functional.scaled_dot_product_attention (SDPA) dispatches to Flash, memory-efficient, or math backends based on dtype, head dimension, GPU capability, and mask type. Use the torch.backends.cuda.sdp_kernel context manager (torch.nn.attention.sdpa_kernel in newer releases) to restrict dispatch to Flash when debugging.
- Hugging Face Transformers: attn_implementation="sdpa" or "flash_attention_2" on supported models (Llama, Mistral, etc.).
- vLLM / TensorRT-LLM / TGI: production model serving stacks bundle fused attention in their CUDA graphs for prefill and decode.
- Standalone flash-attn package: pip-installable kernels for custom models; requires a compatible NVIDIA GPU, matching CUDA/PyTorch builds, and often head dims divisible by 8.
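As a quick sanity check of the PyTorch path, the sketch below calls SDPA with a causal mask and compares it with a hand-written reference that materializes the score matrix. It assumes PyTorch 2.x is installed; the tensor sizes are arbitrary, and on a machine without CUDA the call runs on CPU in FP32, where a fused GPU kernel is not exercised.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Layout expected by SDPA: [batch, heads, seq_len, head_dim]
q, k, v = (torch.randn(1, 2, 16, 8, device=device, dtype=dtype) for _ in range(3))

# Fused path: PyTorch picks a backend; the n x n matrix is not returned.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Reference path: materialize scores, mask the future, softmax, multiply by V.
scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
causal = torch.ones(16, 16, dtype=torch.bool, device=device).tril()
scores = scores.masked_fill(~causal, float("-inf"))
reference = torch.softmax(scores, dim=-1) @ v

max_diff = (fused.float() - reference.float()).abs().max().item()
print(f"device={device} dtype={dtype} max |fused - reference| = {max_diff:.2e}")
```

Any remaining difference should be at the level of floating-point rounding for the chosen dtype, since both paths compute the same function.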
Compatibility constraints
Flash kernels impose practical limits: FP16/BF16 are standard; FP32 may fall back to slower paths. Head dimensions beyond the kernel's supported maximum (128 in FlashAttention-1, 256 in FlashAttention-2) or odd sizes can disable fusion. Causal masking, dropout, and attention bias support improved across versions but still vary, so verify against your library version and model card. AMD ROCm builds and Apple Metal have separate fused-attention paths (not always called “Flash Attention” but solving the same problem).
Worked example: Harbor Support LLM prefill tuning
Harbor Support runs a 7B instruct model on one A10G (24 GB) to draft ticket replies. Median prompts are 900 tokens, but 5% of enterprise tickets attach 6–12k-token log dumps. Before optimization, prefill on a 10k-token attachment spiked VRAM past 22 GB and triggered batching collapses in the serving queue.
- Baseline profiling: PyTorch profiler showed 68% of prefill time in aten::scaled_dot_product_attention with the math backend (Flash disabled due to an older Transformers pin).
- Enable SDPA Flash: Upgraded to Transformers 4.40+, set attn_implementation="sdpa", confirmed torch.backends.cuda.flash_sdp_enabled() is True on A10G.
- Chunk long attachments: RAG still retrieves relevant log sections; full 10k prefill dropped to ~3.2k effective tokens for 95% of SLA-bound replies (a product change, not a kernel change, but it pairs with the headroom Flash opened).
- Batch size recovery: With attention memory down, continuous batching raised from 4 to 7 concurrent prefills without OOM; P95 prefill latency fell 41% on >4k-token inputs.
- Regression eval: Same held-out ticket set; ROUGE-L and human spot-checks unchanged, consistent with exact rather than approximate attention.
Pair with LLM cost optimization (smaller batches were costing GPU idle time) and prompt caching for repeated system prefixes.
Decision table: which attention path to use
| Scenario | Recommended path | Why |
|---|---|---|
| Training/fine-tuning LLMs on NVIDIA | FlashAttention-2 via SDPA or flash_attn | Memory-linear attention; longer n |
| Production vLLM/TGI serving | Engine default fused kernels | Already integrated in CUDA graphs |
| Custom attention (ALiBi, exotic masks) | Verify mask support; may need math fallback | Not every mask ships in fused kernels |
| CPU / edge inference | Standard attention or the framework's CPU SDPA path | The CUDA Flash kernels are GPU-specific |
| Need attention weights exported | Math attention (no fusion) | Fused kernels skip materializing weights |
| Very short sequences | Either; profile first | Overhead may dominate |
Common pitfalls
- Assuming Flash is on: SDPA silently falls back to the memory-efficient or math backend when dtype, head size, or mask constraints fail. Log which backend ran.
- Chasing FLOPs instead of bytes: Recomputation adds compute; Flash wins when HBM traffic was the bottleneck, not always on compute-bound shapes.
- Ignoring prefill vs decode: KV-cache decode is a different kernel path; Flash Attention headlines apply most to prefill/training.
- Version skew: CUDA, PyTorch, and flash-attn wheels must match; pip install errors are common on mismatched driver stacks.
- Needing attention maps for interpretability: Fused paths do not return full softmax weights; use hooks or a separate interpretability pass.
- Benchmarking at wrong n: Gains scale with sequence length; test your production percentile, not toy lengths.
Production checklist
- Profile whether attention is memory-bound (HBM near limit during prefill).
- Confirm GPU architecture supports Flash / SDPA fused path for your dtype and head dim.
- Enable SDPA or flash_attention_2 in training and serving configs.
- Log attention backend selection in staging (Flash vs mem_efficient vs math).
- Benchmark P50/P95 prefill latency at 50th and 95th percentile sequence lengths.
- Run quality regression after enabling; output should match exact attention.
- Document fallback if exotic masks or head dims block fusion.
- Align CUDA, PyTorch, and optional flash-attn versions in lockstep.
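For the benchmarking item, the sequence lengths to test can be read off logged prompt lengths. A minimal sketch using the nearest-rank percentile definition follows; the sample lengths are hypothetical, invented only to show a distribution with a long tail.

```python
import math

def percentile_nearest_rank(values, pct):
    """Smallest value such that at least pct percent of the data is <= it."""
    ordered = sorted(values)
    rank = max(1, math.ceil(pct * len(ordered) / 100))
    return ordered[rank - 1]

# Hypothetical prompt lengths in tokens, as might be sampled from serving logs.
prompt_lengths = [850, 900, 920, 870, 950, 10200, 880, 910, 6400, 905,
                  890, 930, 860, 940, 915, 875, 925, 895, 885, 935]

p50 = percentile_nearest_rank(prompt_lengths, 50)
p95 = percentile_nearest_rank(prompt_lengths, 95)
print(p50, p95)  # 905 6400
```

With a tail like this, a benchmark run only at the median length would miss the inputs where attention memory matters most.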
Key takeaways
- Attention is often memory-bandwidth bound; the n × n score matrix is the usual VRAM killer.
- Flash Attention tiles Q/K/V through SRAM, uses online softmax, and recomputes in backward: exact, not approximate.
- Memory scales linearly in n for attention intermediates; speedups grow with context length.
- PyTorch SDPA, Transformers, and serving engines expose Flash without writing custom CUDA.
- Verify the fused kernel actually runs; silent fallback to math attention erases expected gains.
Related reading
- Attention mechanism explained — Q/K/V intuition, scaled dot-product, multi-head attention
- Transformer architecture explained — encoder-decoder stacks, positional encoding, FFN blocks
- LLM KV cache explained — prefill vs decode memory and PagedAttention
- Model serving explained — batching, GPU runtimes, and production latency