Flash Attention explained

Your fine-tuning job dies at sequence length 8,192 with an out-of-memory error, even though the model weights fit comfortably on an A100. The culprit is not parameter count; it is the attention matrix. Standard scaled dot-product attention materializes an n × n score tensor for every head and layer, flooding GPU high-bandwidth memory (HBM) with reads and writes. Flash Attention (Dao et al., 2022) is an IO-aware CUDA kernel that never materializes the full attention matrix in HBM. It tiles queries, keys, and values through fast on-chip SRAM, fuses the softmax with the matrix multiplies, and recomputes attention scores during the backward pass instead of storing gigabytes of intermediates. The reported result is 2–4× wall-clock speedups and dramatically lower memory use at long contexts, without changing the mathematical output of attention.

This guide covers why attention is memory-bound, how tiling and online softmax work, what FlashAttention-2 improved, how to get Flash Attention in PyTorch and serving engines, a worked example from the Harbor Support deployment, a decision table, common pitfalls, and a production checklist. For attention intuition, see attention mechanism explained; for where attention sits in the full stack, see transformer architecture explained.

Why standard attention hurts

Given query matrix Q, key matrix K, and value matrix V (each shaped [sequence_length, head_dim] per head), vanilla attention computes:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}}\right) V, \qquad d = \text{head\_dim}$$
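To make the formula concrete, here is a minimal pure-Python sketch of one head. It is not a production kernel: it deliberately builds the full n × n list of scores, which is exactly the intermediate Flash Attention avoids. The function name `naive_attention` is illustrative.

```python
import math

def naive_attention(Q, K, V):
    """Single-head softmax(Q K^T / sqrt(d)) V that materializes all n x n scores."""
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    # Full score matrix: one row of n logits per query, so O(n^2) memory.
    scores = [[scale * sum(q_i * k_i for q_i, k_i in zip(q, k)) for k in K] for q in Q]
    out = []
    for row in scores:
        m = max(row)                          # subtract the row max for numerical stability
        weights = [math.exp(s - m) for s in row]
        total = sum(weights)
        probs = [w / total for w in weights]  # softmax over the whole row
        # Output row = probability-weighted sum of the value vectors.
        out.append([sum(p * v[j] for p, v in zip(probs, V)) for j in range(len(V[0]))])
    return out, scores

Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out, scores = naive_attention(Q, K, V)
print(len(scores), len(scores[0]))  # 3 3 (an n x n score matrix for n = 3)
```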

The intermediate QKᵀ tensor has shape [n, n] per head. For a 32-layer, 32-head model at n = 8192 in FP16, each score map takes 8192² × 2 bytes = 128 MiB, so one copy per head per layer adds up to about 128 GiB at batch size 1, before gradients, optimizer states, or the KV cache during inference. FLOPs scale as O(n²), but on modern GPUs the bottleneck is often HBM bandwidth: kernels repeatedly read and write these huge tensors to slow off-chip memory while fast on-chip SRAM sits underutilized.
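The arithmetic behind that figure fits in a small helper (the name `score_map_bytes` is illustrative). It counts a single FP16 copy of each score map at batch size 1. Eager training implementations may keep more than one n × n tensor per head, so treat the result as a lower bound for this component.

```python
def score_map_bytes(seq_len, heads, layers, batch=1, bytes_per_elem=2):
    """Bytes to hold one n x n score map per head per layer (FP16 = 2 bytes per element)."""
    return batch * layers * heads * seq_len * seq_len * bytes_per_elem

GiB = 1024 ** 3
per_head = score_map_bytes(8192, heads=1, layers=1)
total = score_map_bytes(8192, heads=32, layers=32)
print(per_head / GiB)  # 0.125 (128 MiB for one head in one layer)
print(total / GiB)     # 128.0
```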

Training vs inference bottlenecks

During training, you need both the forward activations (or a recomputation strategy) and the backward gradients through attention, and keeping QKᵀ and its softmax around for backprop is expensive. During inference prefill, you compute attention over the entire prompt at once, which is again quadratic in prompt length. Decode steps are cheaper per token (one new query against cached keys), but prefill for long documents still benefits from fused attention kernels. Flash Attention addresses the prefill and training regimes, where full quadratic attention maps would otherwise dominate memory and wall time.
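The difference in work between the two inference phases shows up clearly in a sketch. Below, `decode_step` (a hypothetical helper, not any serving engine's API) appends one token's key and value to a Python-list KV cache and attends with a single query, so it touches t scores for t cached tokens. `causal_prefill_scores` (also illustrative) counts the n(n+1)/2 scores a causal prefill computes over n tokens.

```python
import math

def decode_step(q, k_cache, v_cache, k_new, v_new):
    """One decode step: append the new token's K/V, then attend with a single query.

    Work and temporary memory are O(t) for t cached tokens, versus O(t^2) scores
    for a full causal prefill over the same tokens.
    """
    k_cache.append(k_new)
    v_cache.append(v_new)
    scale = 1.0 / math.sqrt(len(q))
    logits = [scale * sum(a * b for a, b in zip(q, k)) for k in k_cache]  # t scores, not t x t
    m = max(logits)
    w = [math.exp(s - m) for s in logits]
    z = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, v_cache)) / z for j in range(len(v_new))]

def causal_prefill_scores(n):
    """Number of (query, key) scores a causal prefill over n tokens computes."""
    return n * (n + 1) // 2

k_cache, v_cache = [[0.0, 0.0]], [[1.0, 2.0]]
print(decode_step([1.0, 0.0], k_cache, v_cache, [0.0, 0.0], [3.0, 4.0]))  # [2.0, 3.0]
print(causal_prefill_scores(8192))  # 33558528
```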

IO-aware tiling and online softmax

Flash Attention’s core idea is fusion with tiling. Instead of computing the full QKᵀ matrix in HBM, the kernel loads blocks of Q, K, and V into SRAM (shared memory / L1 on the GPU), computes partial attention outputs for that tile, and accumulates the results incrementally. The trick is online softmax. A standard softmax needs the full row of logits so it can subtract the row max for numerical stability; Flash Attention instead keeps a running row maximum and a running normalization sum as tiles arrive, rescaling the partial output whenever the maximum grows, so it never needs the complete row at once. The final per-head output equals standard attention in exact arithmetic (floating-point results can differ in the last bits because the summation order changes). This is not an approximation like low-rank attention; it is an exact reordering of the same operations.
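Online softmax is easiest to trust once you see it agree with the full-row version. The sketch below is pure Python, one head, processing one query row at a time for clarity (a real kernel handles a block of query rows together in SRAM on the GPU). `tiled_attention`, `reference_attention`, and the `block_size` value are illustrative, not the flash-attn API.

```python
import math

def reference_attention(Q, K, V):
    """Single-head attention with a full softmax over each row of scores."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        w = [math.exp(x - m) for x in s]
        z = sum(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) / z for j in range(len(V[0]))])
    return out

def tiled_attention(Q, K, V, block_size=2):
    """Same result, but K/V are consumed one tile at a time with online softmax.

    Per query row we keep a running max (m), a running softmax denominator (denom),
    and an unnormalized output accumulator (acc); no row ever sees all its scores at once.
    """
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        m, denom, acc = float("-inf"), 0.0, [0.0] * len(V[0])
        for start in range(0, len(K), block_size):
            k_tile, v_tile = K[start:start + block_size], V[start:start + block_size]
            s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]
            m_new = max(m, max(s))
            correction = math.exp(m - m_new)  # rescales earlier tiles; exp(-inf) = 0 at first
            p = [math.exp(x - m_new) for x in s]
            denom = denom * correction + sum(p)
            acc = [a * correction + sum(pi * v[j] for pi, v in zip(p, v_tile))
                   for j, a in enumerate(acc)]
            m = m_new
        out.append([a / denom for a in acc])  # normalize once, after the last tile
    return out

Q = [[0.1, 0.2], [0.5, -0.3], [1.0, 1.0]]
K = [[0.3, 0.1], [-0.2, 0.4], [0.9, 0.0], [0.0, -1.0], [0.5, 0.5]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0], [0.5, 0.5]]
ref = reference_attention(Q, K, V)
tiled = tiled_attention(Q, K, V, block_size=2)  # tiles of 2, 2, and 1 keys
print(max(abs(a - b) for r1, r2 in zip(ref, tiled) for a, b in zip(r1, r2)))  # rounding-level only
```

The final division by `denom` happens once per row, which is why partial outputs can be carried unnormalized between tiles.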

Recomputation in the backward pass

To save memory during training, Flash Attention does not store the full attention matrix for backpropagation. It keeps only the output and small per-row softmax statistics, then recomputes the attention scores and probabilities block by block during the backward pass, trading extra compute for far less HBM traffic. On bandwidth-bound workloads this trade wins: recomputing in fast SRAM is cheaper than writing and rereading giant tensors in HBM. It is the same memory/compute exchange behind gradient checkpointing, applied at the kernel level inside attention.
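A sketch of the recomputation idea, under simplifying assumptions: the forward pass saves one logsumexp value per query row (FlashAttention-2 stores per-row statistics in this form), and the backward pass rebuilds each probability as P_ij = exp(s_ij − lse_i) from Q, K, and that saved value. For brevity it handles one full score row at a time instead of tiles and shows only the dV = Pᵀ·dO term; dQ and dK reuse the same recomputed probabilities plus a few extra terms. All function names are illustrative.

```python
import math

def row_scores(q, K, scale):
    """Scaled dot products of one query against every key (one row of Q K^T)."""
    return [scale * sum(a * b for a, b in zip(q, k)) for k in K]

def forward_save_lse(Q, K, V):
    """Forward pass that keeps only the output and one logsumexp value per query row."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out, lse = [], []
    for q in Q:
        s = row_scores(q, K, scale)
        m = max(s)
        L = m + math.log(sum(math.exp(x - m) for x in s))  # log of the softmax denominator
        lse.append(L)                                       # n floats saved, not n x n
        out.append([sum(math.exp(x - L) * v[j] for x, v in zip(s, V))
                    for j in range(len(V[0]))])
    return out, lse

def grad_v_recompute(Q, K, dO, lse):
    """dV = P^T dO, rebuilding each probability P_ij = exp(s_ij - lse_i) on the fly."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    dV = [[0.0] * len(dO[0]) for _ in K]
    for q, do_row, L in zip(Q, dO, lse):
        for j, s_ij in enumerate(row_scores(q, K, scale)):  # scores recomputed, never stored
            p_ij = math.exp(s_ij - L)
            for c, g in enumerate(do_row):
                dV[j][c] += p_ij * g
    return dV

Q = [[0.2, -0.1], [0.7, 0.3]]
K = [[0.5, 0.5], [-0.4, 0.1], [0.0, 1.0]]
V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
dO = [[1.0, 0.5], [-0.5, 2.0]]  # upstream gradient with respect to the attention output
out, lse = forward_save_lse(Q, K, V)
dV = grad_v_recompute(Q, K, dO, lse)
print(len(lse), len(dV))  # 2 3
```

Because every row of P sums to 1, the columns of dV must sum to the columns of dO, which is a cheap sanity check on the recomputed probabilities.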

FlashAttention-2

The second revision improves work partitioning across warps and thread blocks, reduces non-matmul FLOPs in the inner loop, and sequences operations so more time is spent in Tensor Core matrix multiplies. Reported speedups over FlashAttention-1 are often another ~2× on A100-class hardware. FlashAttention-3 targets Hopper (H100) with FP8 and asynchronous memory copies. In practice, check which version your framework and GPU generation actually ship — the naming moves faster than blog posts.
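One FlashAttention-2 change is easy to model with arithmetic. Besides parallelizing over batch and heads, its forward pass also splits the query sequence across thread blocks, which keeps the GPU busy when batch × heads is small (typical of long-context, small-batch runs). The sketch counts independent thread blocks under a simplified model; `grid_blocks` is a hypothetical helper, and `block_m=128` is an assumed tile size (real kernels pick tile sizes per head dimension and GPU).

```python
import math

def grid_blocks(batch, heads, seq_len, block_m=128, split_sequence=True):
    """Count independent thread blocks for an attention forward pass (simplified model).

    split_sequence=False: one block per (batch, head) pair.
    split_sequence=True: additionally one block per tile of block_m query rows.
    """
    per_pair = math.ceil(seq_len / block_m) if split_sequence else 1
    return batch * heads * per_pair

A100_SMS = 108  # streaming multiprocessors on an A100
per_head_only = grid_blocks(1, 32, 8192, split_sequence=False)
with_seq_split = grid_blocks(1, 32, 8192)
print(per_head_only, per_head_only < A100_SMS)  # 32 True
print(with_seq_split)                           # 2048
```

With batch 1 and 32 heads, one block per head leaves most of an A100's 108 SMs idle, while splitting 8,192 query rows into 64 tiles yields 2,048 blocks to schedule.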

Memory and speed: what changes in practice

Standard attention memory for the score matrix scales as O(n²) per layer per head (ignoring batch for a moment). Flash Attention reduces attention-intermediate memory to O(n) in sequence length: score tiles live only briefly in on-chip SRAM, and HBM holds just the output plus a few softmax statistics per query row. That is what unlocks longer training sequences and larger batch sizes on a fixed amount of GPU VRAM.
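To see the quadratic-versus-linear gap, compare one head's FP16 score matrix with the per-row statistics a fused kernel saves for the backward pass. Assumptions: one FP32 value per query row (as with a saved logsumexp), and the output tensor, which both paths store, is left out. The helper name is illustrative.

```python
def attention_intermediate_bytes(seq_len, score_bytes=2, stat_bytes=4):
    """Per head, per layer: standard n x n score matrix vs Flash-style per-row statistics."""
    standard = seq_len * seq_len * score_bytes  # grows as n^2
    fused = seq_len * stat_bytes                # grows as n
    return standard, fused

for n in (1024, 8192, 32768):
    standard, fused = attention_intermediate_bytes(n)
    print(f"n={n:>6}: standard {standard / 2**20:>8.1f} MiB, fused stats {fused / 2**10:>5.0f} KiB")
```

Each 8× increase in n multiplies the standard score memory by 64 but the saved statistics only by 8.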

Wall-clock speedups come from moving fewer bytes between HBM and on-chip SRAM and from keeping Tensor Cores busier, not from asymptotically fewer FLOPs. Short sequences (e.g. n < 512) may see modest gains; very long contexts (4k, 8k, 32k+) are where Flash Attention routinely turns an OOM into a finished job. Always benchmark your actual sequence-length distribution; micro-benchmarks at n = 128 mislead.

Regime | Standard attention pain | Flash Attention benefit
Long-context fine-tuning | OOM on attention maps | Fits longer n or larger batch
Prefill on 8k+ prompts | Slow, memory-spiky | Lower time to first token
Short chat turns (<1k tokens) | Already fast | Marginal; other ops dominate
Multi-head + many layers | Memory multiplies per layer | Savings compound across depth

Where you get Flash Attention today

You rarely call Flash Attention directly. It ships inside framework and serving integrations:

  • PyTorch 2.x: torch.nn.functional.scaled_dot_product_attention (SDPA) dispatches to the Flash, memory-efficient, or math backend based on dtype, head dimension, GPU capability, and mask type. When debugging, restrict dispatch to Flash with the torch.nn.attention.sdpa_kernel context manager (torch.backends.cuda.sdp_kernel in older 2.x releases).
  • Hugging Face Transformers: attn_implementation="sdpa" or "flash_attention_2" on supported models (Llama, Mistral, etc.); "flash_attention_2" also requires the flash-attn package.
  • vLLM / TensorRT-LLM / TGI: production serving stacks ship their own fused attention kernels for both prefill and decode.
  • Standalone flash-attn package: pip-installable kernels for custom models; requires a compatible NVIDIA GPU, matching CUDA/PyTorch builds, and often head dims divisible by 8.

Compatibility constraints

Flash kernels impose practical limits. FP16 and BF16 are the standard dtypes; FP32 inputs fall back to slower paths. Head dimensions above the kernel's supported maximum (128 in some versions; flash-attn 2.x supports up to 256) or sizes not divisible by 8 can disable fusion. Support for causal masking, dropout, and attention bias has improved across versions but still varies, so check the release notes for the kernel version you actually run. AMD ROCm builds and Apple Metal have separate fused-attention paths (not always called “Flash Attention”, but solving the same problem).
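Before a long run, it can help to encode these rules of thumb as an explicit pre-flight check that logs why fusion might be blocked. The sketch below is hypothetical: `AttentionConfig`, `flash_blockers`, and the 128 head-dim cap are illustrative assumptions drawn from the constraints above, not PyTorch's actual dispatch rules. The backend the framework really selects remains the source of truth.

```python
from dataclasses import dataclass

@dataclass
class AttentionConfig:
    dtype: str                   # e.g. "float16", "bfloat16", "float32"
    head_dim: int
    causal_or_no_mask: bool      # True if only causal masking (or none) is needed
    need_attention_weights: bool

def flash_blockers(cfg, max_head_dim=128):
    """Return reasons a fused Flash path is likely unavailable (empty list = no known blocker).

    Hypothetical check mirroring this guide's rules of thumb; NOT PyTorch's dispatch logic.
    max_head_dim is an assumption that varies by kernel version.
    """
    reasons = []
    if cfg.dtype not in ("float16", "bfloat16"):
        reasons.append(f"dtype {cfg.dtype} is not FP16/BF16")
    if cfg.head_dim % 8 != 0 or cfg.head_dim > max_head_dim:
        reasons.append(f"head_dim {cfg.head_dim} is not a multiple of 8 <= {max_head_dim}")
    if not cfg.causal_or_no_mask:
        reasons.append("custom mask or bias may not be supported by the fused kernel")
    if cfg.need_attention_weights:
        reasons.append("fused kernels do not return attention weights")
    return reasons

print(flash_blockers(AttentionConfig("bfloat16", 128, True, False)))  # []
print(flash_blockers(AttentionConfig("float32", 96, True, True)))     # two reasons
```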

Worked example: Harbor Support LLM prefill tuning

Harbor Support runs a 7B instruct model on one A10G (24 GB) to draft ticket replies. Median prompts are 900 tokens, but 5% of enterprise tickets attach 6–12k-token log dumps. Before optimization, prefill on a 10k-token attachment spiked VRAM past 22 GB and triggered batching collapses in the serving queue.

  1. Baseline profiling — PyTorch profiler showed 68% of prefill time in aten::scaled_dot_product_attention with the math backend (Flash disabled due to an older Transformers pin).
  2. Enable SDPA Flash — Upgraded to Transformers 4.40+, set attn_implementation="sdpa", and confirmed torch.backends.cuda.flash_sdp_enabled() is True on A10G. (That flag only shows the Flash backend is allowed; it does not prove a given call used it.)
  3. Chunk long attachments — RAG still retrieves the relevant log sections, but the full 10k-token prefill dropped to about 3.2k effective tokens for 95% of SLA-bound replies. This is a product change, not a kernel change; it pairs with the headroom Flash opened.
  4. Batch size recovery — With attention memory down, continuous batching went from 4 to 7 concurrent prefills without OOM; P95 prefill latency fell 41% on >4k-token inputs.
  5. Regression eval — Same held-out ticket set; ROUGE-L and human spot-checks unchanged — confirming exact attention, not approximate kernels.

Pair with LLM cost optimization (smaller batches were costing GPU idle time) and prompt caching for repeated system prefixes.

Decision table: which attention path to use

Scenario | Recommended path | Why
Training/fine-tuning LLMs on NVIDIA | FlashAttention-2 via SDPA or flash_attn | Memory-linear attention; longer n
Production vLLM/TGI serving | Engine default fused kernels | Already integrated and tuned by the engine
Custom attention (ALiBi, exotic masks) | Verify mask support; may need math fallback | Not every mask ships in fused kernels
CPU / edge inference | Standard attention or the runtime's CPU kernels | The flash-attn package targets NVIDIA GPUs
Need attention weights exported | Math attention (no fusion) | Fused kernels never materialize the weights
Very short sequences | Either; profile first | Kernel overhead may dominate

Common pitfalls

  • Assuming Flash is on — SDPA silently falls back to the memory-efficient or math backend when dtype, head size, or mask constraints fail. Log which backend ran.
  • Chasing FLOPs instead of bytes — Recomputation adds compute; Flash wins when HBM traffic was the bottleneck, not always on compute-bound shapes.
  • Ignoring prefill vs decode — KV-cache decode is a different kernel path; Flash Attention headlines apply most to prefill/training.
  • Version skew — CUDA, PyTorch, and flash-attn wheels must match; pip install errors are common on mismatched driver stacks.
  • Needing attention maps for interpretability — Fused paths do not return full softmax weights; run interpretability passes with the math (eager) attention implementation instead.
  • Benchmarking at wrong n — Gains scale with sequence length; test your production percentile, not toy lengths.

Production checklist

  • Profile whether attention is memory-bound (HBM near limit during prefill).
  • Confirm GPU architecture supports Flash / SDPA fused path for your dtype and head dim.
  • Enable SDPA or flash_attention_2 in training and serving configs.
  • Log attention backend selection in staging (Flash vs mem_efficient vs math).
  • Benchmark P50/P95 prefill latency at 50th and 95th percentile sequence lengths.
  • Run quality regression after enabling — output should match exact attention.
  • Document fallback if exotic masks or head dims block fusion.
  • Align CUDA, PyTorch, and optional flash-attn versions in lockstep.
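For the percentile benchmark item, the sequence lengths to test come straight from logged traffic. The helper below uses the nearest-rank definition (one of several percentile conventions); its name and the sample lengths are made up for illustration.

```python
import math

def nearest_rank_percentile(values, pct):
    """Nearest-rank percentile: smallest value with at least pct% of samples at or below it."""
    ordered = sorted(values)
    rank = max(1, math.ceil(pct * len(ordered) / 100))
    return ordered[rank - 1]

# Made-up prompt lengths (tokens), standing in for lengths pulled from production logs.
prompt_lengths = [420, 610, 880, 900, 905, 930, 1200, 2400, 6100, 11800]
for pct in (50, 95):
    print(f"P{pct} sequence length: {nearest_rank_percentile(prompt_lengths, pct)}")
```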

Key takeaways

  • Attention is often memory-bandwidth bound; the n × n score matrix is the usual VRAM killer.
  • Flash Attention tiles Q/K/V through SRAM, uses online softmax, and recomputes in backward — exact, not approximate.
  • Memory scales linearly in n for attention intermediates; speedups grow with context length.
  • PyTorch SDPA, Transformers, and serving engines expose Flash without writing custom CUDA.
  • Verify the fused kernel actually runs; silent fallback to math attention erases expected gains.

Related reading