Gradient checkpointing explained
Harbor Analytics needed to fine-tune a 13B-parameter document classifier on
8k-token sequences. One A100 80GB should have been enough: the BF16 weights
plus the AdamW states for the LoRA adapters fit with room to spare. Yet the
first training step ran out of memory during the backward pass. The culprit
was not the parameters but activation memory: every intermediate tensor from
40 transformer layers, kept alive for backpropagation.
Enabling gradient checkpointing (also called activation checkpointing) cut
peak activation memory by about 62% and let training proceed at batch size 2,
at the cost of roughly 34% longer training steps.
The technique is standard in LLM fine-tuning stacks (Hugging Face
gradient_checkpointing_enable(), DeepSpeed, FSDP), yet it is often flipped on
as a framework toggle without a clear picture of what it does. This guide
explains what activations cost, how checkpointing trades recomputation for
memory, selective vs full strategies for transformers, PyTorch's
torch.utils.checkpoint API, and how checkpointing pairs with mixed precision
and Flash Attention. It then walks through the Harbor fine-tune refactor, a
technique decision table, common pitfalls, and a production checklist; read it
alongside our transformer architecture guide.
Where training memory actually goes
GPU memory during training has four main buckets: model weights; optimizer states (Adam keeps a momentum and a variance tensor per trainable parameter, usually in FP32, so 8 bytes per parameter, or twice the FP32 parameter size); gradients (one buffer per trainable parameter); and activations (intermediate outputs of every layer saved for the backward pass). For inference you mostly care about weights and KV cache; for training, activations frequently dominate once sequence length or batch size grows.
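A rough static-memory calculator shows why the introduction's setup fits at all. It is a sketch under stated assumptions: BF16 weights and gradients (2 bytes each), FP32 Adam momentum and variance (8 bytes per trainable parameter), no FP32 master copy of the weights, decimal gigabytes, and activations excluded. The LoRA dimensions (hidden size 5120, four square attention projections per layer, 40 layers) are hypothetical; models with grouped-query attention have smaller key/value projections.

```python
# Rough static-memory estimate for training (activations deliberately excluded).
# Assumptions: BF16 weights and gradients (2 bytes each), FP32 Adam momentum +
# variance (2 x 4 = 8 bytes per trainable parameter), no FP32 master weights.

GB = 1e9  # decimal gigabytes


def static_training_memory_gb(frozen_params, trainable_params,
                              weight_bytes=2, grad_bytes=2, optim_bytes=8):
    """Return weights / gradients / optimizer / total memory in GB."""
    weights = (frozen_params + trainable_params) * weight_bytes
    grads = trainable_params * grad_bytes   # only trainable params get gradients
    optim = trainable_params * optim_bytes  # Adam momentum + variance
    return {
        "weights": weights / GB,
        "gradients": grads / GB,
        "optimizer": optim / GB,
        "total": (weights + grads + optim) / GB,
    }


# Full fine-tune of a 13B model: every parameter is trainable.
full = static_training_memory_gb(frozen_params=0, trainable_params=13_000_000_000)

# Hypothetical LoRA setup: rank 64 on four square 5120x5120 attention
# projections in each of 40 layers -> 64 * (5120 + 5120) params per projection.
lora_params = 40 * 4 * 64 * (5120 + 5120)
lora = static_training_memory_gb(frozen_params=13_000_000_000,
                                 trainable_params=lora_params)

for name, est in (("full", full), ("lora", lora)):
    print(name, {k: round(v, 1) for k, v in est.items()})
```

Under these assumptions a full 13B fine-tune needs about 156 GB before a single activation is stored, while the LoRA variant needs about 27 GB, leaving roughly 50 GB of an 80 GB card for activations.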
In a vanilla decoder-only transformer, activation memory scales roughly with
batch × seq_len × hidden_size × num_layers, with multipliers for attention maps
(quadratic in sequence length during training unless you use memory-efficient
kernels). A 13B model at batch 2 and sequence 8192 can accumulate tens of
gigabytes of activations before the optimizer even runs. That is why you can
load weights fine but still OOM on the first backward step.
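A small estimator makes the scaling claim concrete. This is an illustrative sketch, not a profiler: `tensors_per_layer` is a placeholder multiplier standing in for all hidden-size-wide tensors one layer keeps, the hidden size and head count are assumed values, and the quadratic term assumes eager attention stores one BF16 probability matrix per head.

```python
# Illustrative activation-memory scaling for a decoder-only transformer.
# `tensors_per_layer` is a hypothetical multiplier for the hidden-size-wide
# tensors one layer saves; real values depend on architecture and kernels.
# Attention probabilities are counted separately because they are quadratic
# in sequence length unless a fused kernel avoids materializing them.


def activation_bytes(batch, seq_len, hidden, layers, heads,
                     tensors_per_layer=8, bytes_per_elem=2,
                     fused_attention=False):
    linear = batch * seq_len * hidden * tensors_per_layer * bytes_per_elem
    # One (seq_len x seq_len) probability matrix per head, per sequence.
    quadratic = 0 if fused_attention else batch * heads * seq_len * seq_len * bytes_per_elem
    return layers * (linear + quadratic)


base = dict(batch=2, hidden=5120, layers=40, heads=40)  # assumed model shape
for seq_len in (2048, 4096, 8192):
    eager = activation_bytes(seq_len=seq_len, **base)
    fused = activation_bytes(seq_len=seq_len, fused_attention=True, **base)
    print(f"seq {seq_len}: eager {eager / 1e9:.0f} GB, fused attention {fused / 1e9:.0f} GB")
```

Doubling seq_len doubles the linear term and quadruples the attention term, which is why memory-efficient attention kernels matter so much at 8k context.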
Why backprop needs activations
Backpropagation applies the chain rule layer by layer. To compute gradients for
layer L, you need tensors saved from its forward pass: the layer's inputs
(needed for the weight gradients) and the intermediate values its
nonlinearities and attention softmax need for their derivatives. Standard
autograd frameworks store these tensors automatically. Deeper networks and
longer sequences mean more stored tensors and higher peak memory between
forward and backward.
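A toy scalar network makes the dependency visible. In this plain-Python sketch (no framework, purely illustrative), each layer computes `a[i+1] = tanh(w[i] * a[i])`; the backward pass reads both the stored input `a[i]`, for the weight gradient, and the stored output `a[i+1]`, for the tanh derivative. Real autograd engines save analogous tensors per op.

```python
import math


def forward(weights, x):
    """Chain of scalar layers a[i+1] = tanh(w[i] * a[i]), keeping every activation."""
    acts = [x]
    for w in weights:
        acts.append(math.tanh(w * acts[-1]))
    return acts  # acts[i] is the input to layer i; acts[-1] is the output


def backward(weights, acts):
    """Gradients of loss = final output w.r.t. each weight, using stored activations."""
    grads = [0.0] * len(weights)
    upstream = 1.0  # dLoss/d(output)
    for i in reversed(range(len(weights))):
        a_in, a_out = acts[i], acts[i + 1]
        local = upstream * (1.0 - a_out ** 2)  # tanh derivative, computed from its output
        grads[i] = local * a_in                # dLoss/dw[i] needs the stored input
        upstream = local * weights[i]          # dLoss/da[i], passed to layer i - 1
    return grads


weights = [0.5, -1.2, 0.8, 1.5]
acts = forward(weights, x=0.7)
print(len(acts), "activations stored for", len(weights), "layers")
print(backward(weights, acts))
```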
How gradient checkpointing works
Gradient checkpointing breaks the default “store everything” contract. During the forward pass, you mark certain segments (often whole transformer blocks) as checkpoints. The framework saves only the checkpoint segment's input and discards intermediate activations inside the segment. During backward, it re-runs the forward computation for that segment from the saved input to regenerate the discarded activations, then computes gradients normally.
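The same kind of toy chain can be checkpointed by hand to show the mechanism: the forward keeps only each segment's input, and the backward re-runs each segment from that input before applying the chain rule. This is a plain-Python illustration of the idea, not how PyTorch implements it internally; `segment_size` and the helper names are hypothetical.

```python
import math

# Toy chain of scalar layers a[i+1] = tanh(w[i] * a[i]) (illustrative only).


def run_segment(weights, start, end, a):
    """Forward through layers [start, end) and return every activation produced."""
    acts = [a]
    for w in weights[start:end]:
        acts.append(math.tanh(w * acts[-1]))
    return acts


def checkpointed_grads(weights, x, segment_size):
    n = len(weights)
    bounds = list(range(0, n, segment_size))
    # Forward: keep only the input of each segment (the checkpoints).
    saved, a = {}, x
    for s in bounds:
        saved[s] = a
        a = run_segment(weights, s, min(s + segment_size, n), a)[-1]
    # Backward: walk segments in reverse, recomputing each from its checkpoint.
    grads, upstream, recomputed = [0.0] * n, 1.0, 0
    for s in reversed(bounds):
        e = min(s + segment_size, n)
        acts = run_segment(weights, s, e, saved[s])  # regenerate discarded internals
        recomputed += e - s
        for i in reversed(range(s, e)):
            a_in, a_out = acts[i - s], acts[i - s + 1]
            local = upstream * (1.0 - a_out ** 2)
            grads[i] = local * a_in
            upstream = local * weights[i]
        # `acts` is rebound on the next iteration, so only one segment's
        # internals are alive at a time.
    return a, grads, len(saved), recomputed


weights = [0.3, -0.9, 1.1, 0.6, -0.4, 0.8, 1.2, -0.7]
out, grads, n_saved, n_recomputed = checkpointed_grads(weights, x=0.5, segment_size=2)
print(n_saved, "checkpoints saved;", n_recomputed, "layers recomputed")
```

With segment_size=2 on 8 layers, 4 checkpoints are kept and every layer is re-run once during backward, while the gradients match a run with one checkpoint per layer.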
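The tradeoff is explicit: less memory, more compute. If you checkpoint every
layer in an N-layer network, backward re-runs roughly one extra forward pass
worth of work. Because a backward pass costs about twice a forward pass, that
turns roughly 3 units of work per step into 4, which is why ~33% total training
time overhead is often cited for transformers; the exact number depends on which
ops dominate and whether attention is already fused. On the memory side,
per-layer checkpointing keeps only each layer's input plus the internals of the
one layer being recomputed. Chen et al. (“Training Deep Nets with Sublinear
Memory Cost”) show that placing a checkpoint every √N layers reduces activation
memory for an N-layer network from O(N) to O(√N), again for about one extra
forward pass.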
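Both halves of that tradeoff fit in a back-of-envelope model. The sketch assumes backward costs about 2x forward (a rule of thumb, not a measurement) and, for memory, uses Chen et al.'s abstraction of a chain of N equally sized layer outputs, counting each stored or recomputed output as one unit.

```python
import math

# Cost model in units of one forward pass; assumes backward ~= 2x forward.
FORWARD, BACKWARD = 1.0, 2.0


def step_overhead(recompute_fraction):
    """Extra step time when `recompute_fraction` of the forward is re-run during backward."""
    base = FORWARD + BACKWARD
    return (base + recompute_fraction * FORWARD) / base - 1.0


def sqrt_schedule_units(num_layers):
    """Peak stored outputs for a chain of equal-sized layer outputs when
    checkpointing every k ~= sqrt(N) layers: N/k checkpoints plus the k
    outputs of the segment being recomputed."""
    k = max(1, round(math.sqrt(num_layers)))
    return math.ceil(num_layers / k) + k


print(f"every layer recomputed: +{step_overhead(1.0):.0%}")
print(f"half the layers recomputed: +{step_overhead(0.5):.0%}")
print("40-layer chain:", sqrt_schedule_units(40), "units vs 40 without checkpointing")
```

The √N schedule pays off more as depth grows: a 100-layer chain needs 20 units instead of 100.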
Checkpoint granularity
- Per-layer checkpointing: one checkpoint boundary per transformer block; maximum memory savings, highest recompute cost (every block is re-run once during backward).
- Every-k-layers: checkpoint only every 2nd or 4th block and let the others keep their activations; recompute shrinks proportionally while the memory savings shrink too, a common balance in production fine-tunes.
- Selective op checkpointing: checkpoint only attention or FFN submodules; used when one sub-block dominates the activation footprint.
- Full-model checkpointing: treat the entire forward as one segment; minimal memory win, because recomputation rematerializes every activation at once during backward, so it is rarely used alone.
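The granularity options can be compared with a rough cost model. Every assumption here is hypothetical: memory is counted in hidden-state-sized tensors, each block saves `internal_units` of them for backward, a checkpointed block keeps only its input, and backward rematerializes one block at a time. Real ratios depend on the architecture and kernels.

```python
def plan_cost(num_layers, policy, internal_units=8):
    """Return (peak stored tensors, fraction of the forward recomputed).

    Units: one hidden-state-sized tensor. `internal_units` is a hypothetical
    count of such tensors a transformer block saves for backward.
    """
    if policy == "none":
        return num_layers * internal_units, 0.0
    if policy == "whole-model":
        # One segment: recomputation rematerializes every block's internals
        # at once, so the peak matches no checkpointing while paying full recompute.
        return num_layers * internal_units, 1.0
    every = int(policy)  # checkpoint every `every`-th block
    ckpt = sum(1 for i in range(num_layers) if i % every == 0)
    plain = num_layers - ckpt
    stored = ckpt * 1 + plain * internal_units
    return stored + internal_units, ckpt / num_layers  # + one block being recomputed


for policy in ("none", "1", "2", "4", "whole-model"):
    peak, recompute = plan_cost(40, policy)
    print(f"{policy:>11}: peak {peak:3d} units, recompute {recompute:.0%} of forward")
```

Under this model, checkpointing every 2nd block halves the recompute but leaves peak memory about four times higher than per-layer checkpointing, while whole-model checkpointing pays full recompute for no peak reduction.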
Checkpointing is orthogonal to mixed precision: recomputation runs under the
same autocast policy as the original forward. It is also complementary to
Flash Attention, which reduces attention activation memory during the forward
itself; checkpointing addresses what remains stored across layers.
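A minimal PyTorch sketch of the three together: each hypothetical `Block` uses `torch.nn.functional.scaled_dot_product_attention` (PyTorch's fused attention entry point, which dispatches to a Flash or memory-efficient kernel when hardware and dtype allow; it is not the separate Flash Attention 2 package), every block is wrapped in `torch.utils.checkpoint.checkpoint`, and the forward runs under BF16 autocast. It uses CPU autocast so it runs without a GPU; on NVIDIA hardware you would pass `device_type="cuda"`.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


class Block(torch.nn.Module):
    """Minimal attention + MLP block (hypothetical, for illustration only)."""

    def __init__(self, dim, heads):
        super().__init__()
        self.heads = heads
        self.qkv = torch.nn.Linear(dim, 3 * dim)
        self.proj = torch.nn.Linear(dim, dim)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(dim, 4 * dim), torch.nn.GELU(), torch.nn.Linear(4 * dim, dim)
        )

    def forward(self, x):
        b, s, d = x.shape
        q, k, v = self.qkv(x).view(b, s, 3, self.heads, d // self.heads).unbind(2)
        # Fused attention: the (s x s) score matrix need not be stored when a
        # memory-efficient kernel is selected.
        attn = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
        )
        x = x + self.proj(attn.transpose(1, 2).reshape(b, s, d))
        return x + self.mlp(x)


torch.manual_seed(0)
blocks = torch.nn.ModuleList(Block(64, 4) for _ in range(4))
x = torch.randn(2, 128, 64, requires_grad=True)

# BF16 autocast; checkpoint replays each block under the same autocast state.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    h = x
    for blk in blocks:
        h = checkpoint(blk, h, use_reentrant=False)
    loss = h.float().pow(2).mean()
loss.backward()
print(loss.item(), blocks[0].qkv.weight.grad.dtype)
```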
PyTorch and Hugging Face integration
PyTorch exposes checkpointing via
torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=False).
Wrap a forward segment: the segment's inputs are saved, and the activations
computed inside it are not. Pass use_reentrant=False explicitly; recent PyTorch
releases warn when the argument is omitted, and the non-reentrant variant is the
recommended one, with better compatibility with autograd features and
distributed training.
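A minimal sketch of the API on a hypothetical stack of residual MLP blocks: the same inputs and weights run with and without checkpointing, and the parameter gradients are compared. Recomputation should reproduce the discarded activations, so the gradients agree; only memory and time differ.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
# Hypothetical stack of 6 small MLP "layers" standing in for transformer blocks.
layers = torch.nn.ModuleList(
    torch.nn.Sequential(torch.nn.Linear(32, 128), torch.nn.GELU(), torch.nn.Linear(128, 32))
    for _ in range(6)
)


def forward(x, use_checkpoint):
    for layer in layers:
        if use_checkpoint:
            # Saves only `x` for this layer; the 128-wide hidden activations
            # are discarded and recomputed during backward.
            x = x + checkpoint(layer, x, use_reentrant=False)
        else:
            x = x + layer(x)
    return x.pow(2).mean()


def grads(use_checkpoint):
    layers.zero_grad(set_to_none=True)
    inp = torch.randn(4, 16, 32, generator=torch.Generator().manual_seed(1))
    forward(inp, use_checkpoint).backward()
    return [p.grad.clone() for p in layers.parameters()]


plain, ckpt = grads(False), grads(True)
print(all(torch.allclose(a, b) for a, b in zip(plain, ckpt)))
```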
Hugging Face Transformers adds a one-line enable for most LLM classes:
model.gradient_checkpointing_enable()
This inserts checkpoint boundaries at each decoder layer. Combine with
model.config.use_cache = False during training: the KV cache is for inference
only and wastes memory if left on. For PEFT/LoRA fine-tunes, checkpointing the
frozen base model layers is especially valuable. LoRA already shrinks the
trainable parameter count (and with it gradient and optimizer memory), but
gradients still have to flow back through every frozen layer to reach the
adapters, so those layers' activations become the dominant remaining cost.
Distributed training notes
FSDP (Fully Sharded Data Parallel) shards parameters, gradients, and optimizer states across GPUs; activation checkpointing still helps because activations are not sharded by default. DeepSpeed ZeRO stages shard optimizer states (stage 1), gradients (stage 2), and parameters (stage 3); activation checkpointing targets the remaining peak. Use both when fine-tuning large models on modest clusters. Profile before assuming ZeRO alone fixes OOM: long-context fine-tunes often need checkpointing even under ZeRO-3.
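An idealized per-GPU accounting shows why sharding and checkpointing are complementary. This sketch assumes ZeRO-style even sharding (optimizer states from stage 1, gradients from stage 2, weights at stage 3), uses static figures for a full 13B fine-tune with BF16 weights and FP32 Adam, and plugs in hypothetical activation figures of 60 GB (no checkpointing) and 15 GB (checkpointed). It ignores temporary all-gather buffers, FP32 master weights, and fragmentation.

```python
def per_gpu_memory_gb(world_size, weights_gb, grads_gb, optim_gb, activations_gb, zero_stage):
    """Idealized per-rank memory under ZeRO-style sharding."""

    def shard(gb, min_stage):
        # This bucket is split evenly across ranks once the stage is reached.
        return gb / world_size if zero_stage >= min_stage else gb

    static = shard(optim_gb, 1) + shard(grads_gb, 2) + shard(weights_gb, 3)
    return static + activations_gb  # activations are per rank and are not divided


# Full 13B fine-tune (26 GB weights, 26 GB grads, 104 GB Adam) on 8 GPUs;
# the activation figures are hypothetical.
for stage in (0, 3):
    for act in (60, 15):
        print(f"ZeRO-{stage}, activations {act} GB: "
              f"{per_gpu_memory_gb(8, 26, 26, 104, act, stage):.1f} GB per GPU")
```

On 8 GPUs under ZeRO-3 the static buckets shrink to 19.5 GB per rank, but the hypothetical 60 GB of activations stays on every rank; only checkpointing moves that term.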
Harbor Analytics 13B fine-tune refactor (worked example)
Problem. Harbor Analytics fine-tuned a 13B Mistral-class model for multi-label document routing (legal, finance, ops, spam). Target: 8k context, batch size 2, LoRA rank 64 on attention projections, BF16 weights, AdamW on one A100 80GB. Forward completed; backward OOM at ~74 GB peak with standard autograd storage.
Change. Enabled Hugging Face gradient checkpointing on all
40 decoder layers. Disabled use_cache. Kept Flash Attention 2
for the attention forward (already configured). No change to LoRA targets or
learning rate.
Results. Peak activation memory fell from ~48 GB to ~18 GB; total peak VRAM dropped from 74 GB (OOM) to 52 GB (stable). Training step time increased from 4.1s to 5.5s (+34%). After 3 epochs on 120k labeled docs, macro-F1 reached 0.887 vs 0.889 on a smaller-context baseline without checkpointing (within noise). The team shipped the 8k-context model instead of truncating documents to 4k.
Lesson. When sequence length is the bottleneck, checkpointing is often cheaper than buying a second GPU or slashing context. Measure step time vs memory before scaling hardware.
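The headline percentages follow directly from the reported Harbor measurements; a few lines of arithmetic make the calculation reproducible:

```python
# Harbor measurements as reported in the Results paragraph.
act_before_gb, act_after_gb = 48.0, 18.0  # peak activation memory
step_before_s, step_after_s = 4.1, 5.5    # seconds per training step

activation_cut = (act_before_gb - act_after_gb) / act_before_gb
step_overhead = step_after_s / step_before_s - 1
print(f"activation memory: -{activation_cut:.1%}, step time: +{step_overhead:.1%}")
```

That gives a 62.5% activation cut and about 34% step-time overhead, matching the figures in the Results paragraph.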
Technique decision table
| Technique | What it saves | Compute cost | When to choose |
|---|---|---|---|
| Gradient checkpointing | Activation memory | +20–40% step time | Long sequences, deep models, single-GPU fine-tunes |
| Mixed precision (BF16/FP16) | Weight + activation bytes | Often faster (Tensor Cores) | Always on modern NVIDIA GPUs; pair with checkpointing |
| Flash Attention | Attention activation memory | Usually faster attention | Training and inference with long context |
| Micro-batching (grad accumulation) | Per-step activation memory | More steps per effective batch | When batch size drives OOM, not sequence length |
| ZeRO / FSDP sharding | Optimizer + gradient + weight memory | Communication overhead | Multi-GPU; does not replace checkpointing for long ctx |
| Smaller model / LoRA | Weights + optimizer states | Capacity tradeoff | When full fine-tune is unnecessary |
| CPU offloading | GPU parameter/optimizer bytes | Severe slowdown | Last resort; avoid for production training loops |
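One way to turn the table into a repeatable first pass is a small heuristic over a profiled memory breakdown. The helper below is hypothetical: the 40% activation threshold comes from the production checklist in this guide, and the rule ordering is a judgment call rather than a standard.

```python
from dataclasses import dataclass


@dataclass
class MemoryProfile:
    """Peak-memory breakdown in GB, e.g. from a profiling run."""
    budget: float
    weights: float
    optimizer: float
    gradients: float
    activations: float


def suggest_levers(profile, num_gpus=1, long_context=True):
    """Hypothetical heuristic mapping a memory profile onto rows of the table."""
    levers = []
    if profile.activations > 0.4 * profile.budget:
        levers.append("gradient checkpointing")
        # Long sequences point at attention memory; otherwise batch size is the likely driver.
        levers.append("flash attention" if long_context else "micro-batching (grad accumulation)")
    if profile.optimizer + profile.gradients > profile.weights:
        levers.append("ZeRO / FSDP sharding" if num_gpus > 1 else "LoRA / smaller model")
    levers.append("mixed precision (BF16/FP16)")  # the table's default on modern NVIDIA GPUs
    return levers


# Hypothetical profile loosely modeled on a LoRA run before checkpointing.
print(suggest_levers(MemoryProfile(budget=80, weights=26, optimizer=1, gradients=0.2, activations=48)))
```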
Common pitfalls
- Expecting checkpointing to shrink optimizer memory: Adam states still scale with trainable parameters; checkpointing only affects activations.
- Leaving use_cache=True during training: allocates inference KV tensors you do not need; always disable it for fine-tunes.
- Checkpointing layers that contain dropout: the recomputed forward must reproduce the original random masks. torch.utils.checkpoint restores the RNG state before recomputation by default (preserve_rng_state=True), but custom layers that draw from their own generators need care.
- Assuming a uniform speed penalty: attention-heavy layers cost more to recompute; profile per layer, and consider checkpointing every 2nd layer if a ~33% overhead breaks your throughput budget.
- Combining with torch.compile edge cases: some compile modes re-fuse graphs in ways that interact badly with checkpoint boundaries; test compiled and checkpointed runs together.
- Ignoring activation memory in eval: large-batch validation without torch.no_grad() can still OOM; checkpointing is a training technique, not an inference fix.
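The dropout pitfall can be checked directly. This sketch seeds the RNG identically before every run and compares a layer's weight gradient with and without checkpointing; the layer and sizes are arbitrary. With the default `preserve_rng_state=True`, the recomputed forward reuses the original dropout mask. The last line turns that off, which lets recomputation draw a fresh mask, so those gradients should not be trusted.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
# Arbitrary layer with dropout between two linear maps.
layer = torch.nn.Sequential(torch.nn.Linear(16, 64), torch.nn.Dropout(p=0.5), torch.nn.Linear(64, 16))
layer.train()  # dropout active
x = torch.randn(8, 16)


def first_weight_grad(use_checkpoint, preserve_rng_state=True):
    layer.zero_grad(set_to_none=True)
    torch.manual_seed(123)  # identical RNG state before every forward
    if use_checkpoint:
        out = checkpoint(layer, x, use_reentrant=False, preserve_rng_state=preserve_rng_state)
    else:
        out = layer(x)
    out.sum().backward()
    return layer[0].weight.grad.clone()


reference = first_weight_grad(use_checkpoint=False)
print("default checkpoint matches:", torch.allclose(reference, first_weight_grad(use_checkpoint=True)))
print("without RNG restore matches:",
      torch.allclose(reference, first_weight_grad(use_checkpoint=True, preserve_rng_state=False)))
```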
Production checklist
- Profile peak VRAM breakdown (weights, optimizer, activations) before tuning.
- Enable gradient checkpointing when activation memory exceeds ~40% of budget.
- Set use_cache=False for all training and fine-tune runs.
- Pair checkpointing with BF16/FP16 and Flash Attention when available.
- Benchmark step time with and without checkpointing; document overhead %.
- For LoRA/QLoRA, checkpoint the frozen base model layers by default.
- Validate loss curves match non-checkpointed baseline on a small shard.
- In distributed jobs, confirm checkpoint boundaries align across ranks.
- Disable checkpointing for short-sequence experiments where it adds cost without benefit.
- Document checkpoint policy in training configs for reproducibility.
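For the benchmarking and loss-validation items, here is a sketch of a side-by-side harness. It trains a hypothetical residual MLP stack for a few SGD steps on fixed synthetic data, once plain and once checkpointed, then compares per-step losses and mean step time. CPU timings on a toy model are noisy and say little about GPU overhead, so run the same comparison on your real model and hardware; on CUDA, `torch.cuda.max_memory_allocated()` can record the memory side.

```python
import time

import torch
from torch.utils.checkpoint import checkpoint


def make_model(depth=8, dim=128, seed=0):
    """Hypothetical stand-in for a transformer: residual MLP blocks."""
    torch.manual_seed(seed)
    return torch.nn.ModuleList(
        torch.nn.Sequential(torch.nn.Linear(dim, 4 * dim), torch.nn.GELU(), torch.nn.Linear(4 * dim, dim))
        for _ in range(depth)
    )


def train_steps(use_checkpoint, steps=5, dim=128):
    """Run a few SGD steps on fixed synthetic data; return per-step losses and mean step time."""
    model = make_model(dim=dim)
    opt = torch.optim.SGD(model.parameters(), lr=1e-3)
    data = torch.randn(4, 64, dim, generator=torch.Generator().manual_seed(42))
    losses, times = [], []
    for _ in range(steps):
        start = time.perf_counter()
        h = data
        for block in model:
            h = h + (checkpoint(block, h, use_reentrant=False) if use_checkpoint else block(h))
        loss = h.pow(2).mean()
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        times.append(time.perf_counter() - start)
        losses.append(loss.item())
    return losses, sum(times) / len(times)


base_losses, base_time = train_steps(use_checkpoint=False)
ckpt_losses, ckpt_time = train_steps(use_checkpoint=True)
print("max loss difference:", max(abs(a - b) for a, b in zip(base_losses, ckpt_losses)))
print(f"step-time overhead: {ckpt_time / base_time - 1:+.0%}")
```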
Key takeaways
- Gradient checkpointing discards intermediate activations and recomputes them during backward, trading compute for memory.
- Activation storage often dominates LLM fine-tuning OOMs, not parameter count.
- Per-layer checkpointing in transformers is the default lever; expect ~20–40% slower steps for large memory wins.
- Stack with mixed precision and Flash Attention; use ZeRO/FSDP for optimizer sharding, not as a substitute.
- Hugging Face gradient_checkpointing_enable() plus use_cache=False is the standard fine-tune pairing.
Related reading
- Backpropagation explained — why activations must be stored or recomputed
- Mixed precision training explained — BF16/FP16 memory savings that complement checkpointing
- Flash Attention explained — memory-efficient attention kernels during forward
- Transformer architecture explained — where checkpoint boundaries typically land per layer