Gradient checkpointing explained

Harbor Analytics needed to fine-tune a 13B-parameter document classifier on 8k-token sequences. One A100 80GB should have been enough: the BF16 weights plus the Adam states for the LoRA adapters fit comfortably. Yet the first training step ran out of memory during the backward pass. The culprit was not parameters but activation memory: every intermediate tensor from 40 transformer layers, kept alive for backpropagation. Enabling gradient checkpointing (also called activation checkpointing) cut peak activation memory by about 62% and let training proceed at batch size 2, at the cost of a roughly 34% longer training step. The technique is standard in LLM fine-tuning stacks (Hugging Face gradient_checkpointing_enable(), DeepSpeed, FSDP) but poorly understood beyond the framework toggle. This guide explains what activations cost, how checkpointing trades recomputation for memory, selective vs full strategies for transformers, PyTorch's torch.utils.checkpoint API, pairing with mixed precision and Flash Attention, the Harbor fine-tune refactor, a technique decision table, pitfalls, and a production checklist alongside our transformer architecture guide.

Where training memory actually goes

GPU memory during training has four main buckets: model weights, optimizer states (Adam stores momentum and variance — often 2× parameter bytes in FP32), gradients (one buffer per parameter), and activations (intermediate outputs of every layer saved for the backward pass). For inference you mostly care about weights and KV cache; for training, activations frequently dominate once sequence length or batch size grows.
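As a rough illustration of the three parameter-dependent buckets, the sketch below adds up weights, gradients, and Adam states for a full fine-tune and for a LoRA run. The byte sizes (BF16 weights and gradients, two FP32 Adam states per trainable parameter) and the 50M adapter size are assumptions chosen for the example, not figures from any particular framework.

```python
GB = 10 ** 9

def static_training_bytes(total_params, trainable_params,
                          weight_bytes=2, grad_bytes=2, adam_state_bytes=8):
    """Rough bytes for the buckets that scale with parameter count.

    Assumed sizes: BF16 weights and gradients (2 bytes each) and two FP32
    Adam states per trainable parameter (2 x 4 = 8 bytes).
    """
    return {
        "weights": total_params * weight_bytes,
        "gradients": trainable_params * grad_bytes,
        "optimizer": trainable_params * adam_state_bytes,
    }

full = static_training_bytes(13_000_000_000, 13_000_000_000)
# 50M trainable parameters is a hypothetical LoRA adapter size.
lora = static_training_bytes(13_000_000_000, 50_000_000)

for name, buckets in (("full fine-tune", full), ("LoRA", lora)):
    total = sum(buckets.values())
    print(f"{name}: {total / GB:.1f} GB before any activations")
```

Whatever the exact byte sizes, none of these buckets depends on batch size or sequence length. That is the part activations add.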

In a vanilla decoder-only transformer, activation memory scales roughly with batch × seq_len × hidden_size × num_layers, with multipliers for attention maps (quadratic in sequence length during training unless you use memory-efficient kernels). A 13B model at batch 2 and sequence 8192 can accumulate tens of gigabytes of activations before the optimizer even runs. That is why you can load weights fine but still OOM on the first backward step.
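The scaling rule above can be turned into a back-of-the-envelope estimate. This is a sketch only: the hidden size (5120) and head count (40) are assumed values typical of 13B-class models rather than numbers from the Harbor setup, and real blocks save several hidden-sized tensors each, so the first figure is a per-tensor lower bound.

```python
GB = 10 ** 9

def hidden_tensor_bytes(batch, seq_len, hidden, layers, bytes_per_value=2):
    """Bytes to keep ONE hidden-sized tensor per layer (BF16 by default)."""
    return batch * seq_len * hidden * layers * bytes_per_value

def attention_map_bytes(batch, heads, seq_len, bytes_per_value=2):
    """Bytes for one layer's full attention matrix if it is materialised."""
    return batch * heads * seq_len * seq_len * bytes_per_value

# hidden=5120 and heads=40 are assumptions, not values from the guide.
per_tensor = hidden_tensor_bytes(batch=2, seq_len=8192, hidden=5120, layers=40)
per_map = attention_map_bytes(batch=2, heads=40, seq_len=8192)

print(f"one saved tensor per layer, all layers: {per_tensor / GB:.1f} GB")
print(f"one materialised attention map, one layer: {per_map / GB:.1f} GB")
```

A handful of saved tensors per block is enough to reach tens of gigabytes, and the quadratic attention term shows why memory-efficient attention kernels matter at 8k context.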

Why backprop needs activations

Backpropagation applies the chain rule layer by layer. To compute gradients for layer L, you need the input activations to that layer from the forward pass — the values before the nonlinearity or attention softmax. Standard autograd frameworks store these tensors automatically. Deeper networks and longer sequences mean more stored tensors and higher peak memory between forward and backward.

How gradient checkpointing works

Gradient checkpointing breaks the default “store everything” contract. During the forward pass, you mark certain segments (often whole transformer blocks) as checkpoints. The framework saves only the checkpoint segment's input and discards intermediate activations inside the segment. During backward, it re-runs the forward computation for that segment from the saved input to regenerate the discarded activations, then computes gradients normally.
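The mechanism can be shown without any framework. The toy below is a minimal sketch, not how PyTorch implements it: a chain of scalar layers y = tanh(w * x), one backward that stores every layer input and one that stores only segment inputs and recomputes the rest. All function names are made up for illustration.

```python
import math

def layer_forward(w, x):
    return math.tanh(w * x)

def layer_backward(w, x, grad_out):
    """Return (grad_w, grad_x) for y = tanh(w * x); needs the layer input x."""
    y = math.tanh(w * x)
    local = grad_out * (1.0 - y * y)
    return local * x, local * w

def backward_store_all(weights, x0):
    # Default autograd behaviour: keep every layer input from the forward pass.
    inputs, x = [], x0
    for w in weights:
        inputs.append(x)
        x = layer_forward(w, x)
    grad, grads_w = 1.0, [0.0] * len(weights)
    for i in reversed(range(len(weights))):
        grads_w[i], grad = layer_backward(weights[i], inputs[i], grad)
    return grads_w, len(inputs)

def backward_checkpointed(weights, x0, segment=2):
    # Forward: keep only the input of each segment, drop everything inside it.
    saved, x = {}, x0
    for i, w in enumerate(weights):
        if i % segment == 0:
            saved[i] = x
        x = layer_forward(w, x)
    grad, grads_w = 1.0, [0.0] * len(weights)
    peak_stored = len(saved)
    for start in sorted(saved, reverse=True):
        end = min(start + segment, len(weights))
        # Backward: re-run this segment's forward from its saved input.
        inputs, x = [], saved[start]
        for i in range(start, end):
            inputs.append(x)
            x = layer_forward(weights[i], x)
        # The first recomputed input is the checkpoint itself, hence the -1.
        peak_stored = max(peak_stored, len(saved) + len(inputs) - 1)
        for i in reversed(range(start, end)):
            grads_w[i], grad = layer_backward(weights[i], inputs[i - start], grad)
    return grads_w, peak_stored

weights = [0.5, -1.2, 0.8, 1.1, -0.7, 0.9, 1.3, -0.4]
full_grads, full_stored = backward_store_all(weights, 0.3)
ckpt_grads, ckpt_stored = backward_checkpointed(weights, 0.3, segment=2)
print(full_grads == ckpt_grads, full_stored, ckpt_stored)
```

Both paths produce the same gradients; the checkpointed one peaks at 5 stored values instead of 8. In this toy each layer has a single saved value, so savings only appear with segments longer than one layer. A real transformer block holds many internal tensors, which is why checkpointing even a single block pays off.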

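The tradeoff is explicit: less memory, more compute. If you checkpoint every layer in an N-layer network, the backward pass re-runs roughly one full forward pass. With backward costing about twice the forward, that is one extra unit of work on top of three, which is where the commonly cited ~33% step-time overhead for transformers comes from; the exact number depends on which ops dominate and whether attention is already fused. On the memory side, per-layer checkpointing keeps one input tensor per block plus the internals of the single block currently being recomputed, instead of the internals of all N blocks. Chen et al. (“Training Deep Nets with Sublinear Memory Cost”) show that checkpointing roughly every √N layers brings activation memory down to O(√N) for the cost of one extra forward pass.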
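The overhead and memory arithmetic can be checked in a few lines. This sketch assumes the common rule of thumb that a backward pass costs about twice a forward pass, and it counts memory in abstract "layer units" (a saved segment input and a recomputed layer are treated as equal in size), which is a simplification.

```python
import math

def step_overhead(forward_cost=1.0, backward_cost=2.0, recomputed_fraction=1.0):
    """Relative step-time increase when part of the forward is re-run.

    Assumes backward is about 2x forward (rule of thumb, not a measurement).
    """
    base = forward_cost + backward_cost
    return forward_cost * recomputed_fraction / base

def stored_units(num_layers, segment):
    """Saved segment inputs plus the layers rematerialised inside one segment."""
    return math.ceil(num_layers / segment) + segment

print(f"full recompute overhead: {step_overhead():.0%}")

layers = 40
best = min(range(1, layers + 1), key=lambda k: stored_units(layers, k))
print(f"segment length {best} stores {stored_units(layers, best)} units; "
      f"sqrt({layers}) is about {math.sqrt(layers):.1f}")
```

Under this simplified count the minimum sits near the square root of the layer count, which matches the sublinear-memory argument.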

Checkpoint granularity

  • Per-layer checkpointing — one checkpoint boundary per transformer block; maximum memory savings, highest recompute cost.
  • Every-k-layers — checkpoint every 2 or 4 blocks; balances memory and speed; common in production fine-tunes.
  • Selective op checkpointing — checkpoint only attention or FFN submodules; used when one sub-block dominates activation footprint.
  • Full-model checkpointing — treat the entire forward as one segment; minimal memory win, rarely used alone.

Checkpointing is orthogonal to mixed precision: recomputation runs under the same autocast policy as the original forward. It is also complementary to Flash Attention, which reduces attention activation memory during the forward itself; checkpointing addresses what remains stored across layers.

PyTorch and Hugging Face integration

PyTorch exposes checkpointing via torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=False). Wrap a forward segment: the segment's inputs are saved, its internals are not. Pass use_reentrant=False explicitly; the PyTorch documentation recommends the non-reentrant variant, recent releases warn when the argument is omitted, and it works better with autograd features and distributed training.
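A minimal sketch of the call, assuming a small stand-in module in place of a real transformer block (the layer sizes are arbitrary). It runs the same block with and without checkpointing and compares the resulting gradients.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Stand-in for one transformer block; the sizes are arbitrary.
block = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
x = torch.randn(4, 16, requires_grad=True)

# Baseline: autograd keeps every intermediate tensor inside the block.
block(x).sum().backward()
baseline_input_grad = x.grad.clone()
baseline_weight_grad = block[0].weight.grad.clone()

# Checkpointed: only the block input is kept; the block is re-run in backward.
x.grad = None
block.zero_grad()
checkpoint(block, x, use_reentrant=False).sum().backward()

input_grads_match = torch.allclose(baseline_input_grad, x.grad)
weight_grads_match = torch.allclose(baseline_weight_grad, block[0].weight.grad)
print(input_grads_match, weight_grads_match)
```

The gradients should agree because the recomputed forward is the same computation. The difference shows up only in peak memory and step time, which is worth measuring on the real model rather than on a toy like this.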

Hugging Face Transformers adds a one-line enable for most LLM classes:

model.gradient_checkpointing_enable()

This inserts checkpoint boundaries at each decoder layer. Combine it with model.config.use_cache = False during training: the KV cache is for inference only and wastes memory if left on. For PEFT/LoRA fine-tunes, checkpointing is especially valuable: the adapters already shrink gradient and optimizer memory, but activations still flow through every frozen base layer, so activation storage is the main cost left to cut.

Distributed training notes

FSDP (Fully Sharded Data Parallel) shards parameters across GPUs; activation checkpointing still helps because activations are not sharded by default. DeepSpeed ZeRO stages reduce optimizer and gradient memory; activation checkpointing targets the remaining peak. Use both when fine-tuning large models on modest clusters. Profile before assuming ZeRO alone fixes OOM — long-context fine-tunes often need checkpointing even under ZeRO-3.

Harbor Analytics 13B fine-tune refactor (worked example)

Problem. Harbor Analytics fine-tuned a 13B Mistral-class model for multi-label document routing (legal, finance, ops, spam). Target: 8k context, batch size 2, LoRA rank 64 on attention projections, BF16 weights, AdamW on one A100 80GB. Forward completed; backward OOM at ~74 GB peak with standard autograd storage.

Change. Enabled Hugging Face gradient checkpointing on all 40 decoder layers. Disabled use_cache. Kept Flash Attention 2 for the attention forward (already configured). No change to LoRA targets or learning rate.

Results. Peak activation memory fell from ~48 GB to ~18 GB; total peak VRAM dropped from 74 GB (OOM) to 52 GB (stable). Training step time increased from 4.1s to 5.5s (+34%). After 3 epochs on 120k labeled docs, macro-F1 reached 0.887 vs 0.889 on a smaller-context baseline without checkpointing (within noise). The team shipped the 8k-context model instead of truncating documents to 4k.

Lesson. When sequence length is the bottleneck, checkpointing is often cheaper than buying a second GPU or slashing context. Measure step time vs memory before scaling hardware.
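The percentages quoted for the Harbor run follow directly from the reported before and after figures. A short check, using only the numbers stated in the results:

```python
def pct_change(before, after):
    """Signed percentage change from before to after."""
    return (after - before) / before * 100.0

activation_change = pct_change(48, 18)   # peak activation memory, GB
peak_vram_change = pct_change(74, 52)    # total peak VRAM, GB
step_time_change = pct_change(4.1, 5.5)  # seconds per training step

print(f"activation memory: {activation_change:+.1f}%")
print(f"peak VRAM:         {peak_vram_change:+.1f}%")
print(f"step time:         {step_time_change:+.1f}%")
```

The same three measurements (activation memory, total peak, step time) are the ones worth recording before and after enabling checkpointing on any other workload.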

Technique decision table

| Technique | What it saves | Compute cost | When to choose |
|---|---|---|---|
| Gradient checkpointing | Activation memory | +20–40% step time | Long sequences, deep models, single-GPU fine-tunes |
| Mixed precision (BF16/FP16) | Weight + activation bytes | Often faster (Tensor Cores) | Always on modern NVIDIA GPUs; pair with checkpointing |
| Flash Attention | Attention activation memory | Usually faster attention | Training and inference with long context |
| Micro-batching (grad accumulation) | Per-step activation memory | More steps per effective batch | When batch size drives OOM, not sequence length |
| ZeRO / FSDP sharding | Optimizer + gradient + weight memory | Communication overhead | Multi-GPU; does not replace checkpointing for long context |
| Smaller model / LoRA | Weights + optimizer states | Capacity tradeoff | When full fine-tune is unnecessary |
| CPU offloading | GPU parameter/optimizer bytes | Severe slowdown | Last resort; avoid for production training loops |

Common pitfalls

  • Expecting checkpointing to shrink optimizer memory — Adam states still scale with trainable parameters; checkpointing only affects activations.
  • Leaving use_cache=True during training — allocates inference KV tensors you do not need; always disable for fine-tune.
  • Checkpointing inside dropout without fixed seeds — the recomputed forward must reproduce the original stochastic ops; torch.utils.checkpoint saves and restores RNG state by default (preserve_rng_state=True), but custom layers that manage their own randomness need care.
  • Assuming a uniform speed penalty — attention-heavy layers cost more to recompute; profile per layer, and consider every-2-layer checkpointing if the ~33% overhead breaks your throughput budget.
  • Combining with torch.compile edge cases — some compile modes re-fuse graphs in ways that interact badly with checkpoint boundaries; test compiled + checkpointed together.
  • Ignoring activation memory in eval — large-batch validation without torch.no_grad() can still OOM; checkpointing is a training technique, not an inference fix.

Production checklist

  • Profile peak VRAM breakdown (weights, optimizer, activations) before tuning.
  • Enable gradient checkpointing when activation memory exceeds ~40% of budget.
  • Set use_cache=False for all training and fine-tune runs.
  • Pair checkpointing with BF16/FP16 and Flash Attention when available.
  • Benchmark step time with and without checkpointing; document overhead %.
  • For LoRA/QLoRA, checkpoint the frozen base model layers by default.
  • Validate loss curves match non-checkpointed baseline on a small shard.
  • In distributed jobs, confirm checkpoint boundaries align across ranks.
  • Disable checkpointing for short-sequence experiments where it adds cost without benefit.
  • Document checkpoint policy in training configs for reproducibility.
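The first two checklist items can be captured in a small helper. This is a hypothetical sketch: the VramProfile class and recommend_checkpointing function are invented for illustration, the 40% threshold is the rule of thumb from the checklist, and the optimizer and gradient figures in the example are placeholders rather than measurements.

```python
from dataclasses import dataclass

@dataclass
class VramProfile:
    """Peak memory breakdown in GB, as measured by your profiler of choice."""
    weights_gb: float
    optimizer_gb: float
    gradients_gb: float
    activations_gb: float

    @property
    def total_gb(self):
        return (self.weights_gb + self.optimizer_gb
                + self.gradients_gb + self.activations_gb)

def recommend_checkpointing(profile, budget_gb, threshold=0.40):
    """True when activations exceed the given share of the memory budget."""
    return profile.activations_gb / budget_gb > threshold

# 48 GB of activations is the figure from the worked example; 26 GB is
# 13B parameters at 2 bytes each; the other two values are placeholders.
profile = VramProfile(weights_gb=26.0, optimizer_gb=0.4,
                      gradients_gb=0.1, activations_gb=48.0)
decision = recommend_checkpointing(profile, budget_gb=80.0)
print(f"activation share: {profile.activations_gb / 80.0:.0%}, "
      f"enable checkpointing: {decision}")
```

Storing the profile and the resulting decision next to the training config also covers the last checklist item, since the checkpoint policy is then documented with the run.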

Key takeaways

  • Gradient checkpointing discards intermediate activations and recomputes them during backward, trading compute for memory.
  • Activation storage often dominates LLM fine-tuning OOMs, not parameter count.
  • Per-layer checkpointing in transformers is the default lever; expect ~20–40% slower steps for large memory wins.
  • Stack with mixed precision and Flash Attention; use ZeRO/FSDP for optimizer sharding, not as a substitute.
  • Hugging Face gradient_checkpointing_enable() plus use_cache=False is the standard fine-tune pairing.

Related reading