Guide
Distributed LLM training explained
Harbor Analytics needed to fine-tune a 13B-parameter decoder-only model on 2.4M merchant-support transcripts with a 4096-token context window. A single A100 80GB could not hold weights, optimizer states, activations, and gradients simultaneously — even with gradient checkpointing and BF16 mixed precision. Their training team deployed distributed LLM training across eight GPUs: PyTorch FSDP with ZeRO-3 sharding for parameters and optimizer states, plus selective Flash Attention kernels. Peak per-GPU memory fell from 78 GB (OOM on one card) to 41 GB, global batch size reached 64, and wall-clock fine-tune time dropped from an estimated 19 days on one GPU to 2.8 days on the cluster. This guide explains the parallelism axes — data, tensor, and pipeline — how FSDP and DeepSpeed ZeRO partition memory, when to combine them into 3D parallelism, the Harbor Analytics refactor, a technique decision table, pitfalls, and a production checklist alongside our LoRA fine-tuning guide.
Why one GPU is not enough for LLM training
Training memory has four major buckets: model weights (parameters), optimizer states (Adam stores first and second moment estimates per parameter, 8 bytes per parameter when both are kept in FP32), gradients (one buffer per parameter during backpropagation), and activations (intermediate tensors saved for the backward pass). A 13B model's BF16 weights alone are ~26 GB; Adam's two FP32 moments add ~104 GB (about 52 GB each); activations at 4096 sequence length with full attention can exceed 30 GB per layer block without checkpointing.
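The bucket arithmetic can be sketched in a few lines. This is a rough estimate under stated assumptions, not a profiler: the function name `training_state_bytes` and its default byte widths (BF16 weights and gradients at 2 bytes, two FP32 Adam moments at 4 bytes each) are illustrative choices, and activations are left out because they depend on batch size, sequence length, and checkpointing.

```python
def training_state_bytes(n_params, weight_bytes=2, grad_bytes=2, moment_bytes=4):
    """Estimate persistent training state in bytes (activations excluded).

    Assumed defaults: BF16 weights and gradients, two FP32 Adam moments.
    """
    weights = n_params * weight_bytes
    gradients = n_params * grad_bytes
    adam_states = 2 * n_params * moment_bytes  # first + second moment
    return {
        "weights": weights,
        "gradients": gradients,
        "adam_states": adam_states,
        "total": weights + gradients + adam_states,
    }


estimate = training_state_bytes(13_000_000_000)
for name, size in estimate.items():
    print(f"{name}: {size / 1e9:.0f} GB")
```

Under these assumptions a 13B model carries about 156 GB of weights, gradients, and Adam moments before any activations are counted, which is well past a single 80 GB card.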
Distributed training splits these tensors across devices and coordinates with collective communication (all-reduce, all-gather, reduce-scatter) over NVLink or InfiniBand. The goal is not only to fit the model but also to increase effective batch size and throughput while keeping GPU utilization high. Inference serving (tensor parallelism in vLLM) solves a different problem; this guide focuses on the training forward/backward loop.
Data parallelism (DDP)
Data parallel training replicates the full model on every GPU.
Each rank receives a different micro-batch; ranks compute gradients locally,
then all_reduce averages gradients across the cluster before the
optimizer step. PyTorch
DistributedDataParallel (DDP) is the standard implementation.
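A toy simulation can show why averaging per-rank gradients is equivalent to training on the combined batch. This is a pure-Python sketch, not PyTorch DDP: the one-parameter model, the helper names `grad_mse` and `all_reduce_mean`, and the data are all made up for illustration, and the equivalence relies on every rank holding the same number of samples.

```python
def grad_mse(w, batch):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2 * x * (w * x - y) for x, y in batch) / len(batch)


def all_reduce_mean(values):
    """Stand-in for an averaging all-reduce: every rank ends with the mean."""
    avg = sum(values) / len(values)
    return [avg] * len(values)


w = 0.5  # the same replicated weight on every rank
micro_batches = [
    [(1.0, 2.0), (2.0, 4.0)],  # rank 0
    [(3.0, 6.0), (4.0, 8.0)],  # rank 1
]

local_grads = [grad_mse(w, batch) for batch in micro_batches]
synced_grads = all_reduce_mean(local_grads)

# Reference: the gradient a single device would compute on all samples.
full_batch = [pair for batch in micro_batches for pair in batch]
reference_grad = grad_mse(w, full_batch)
```

Each rank starts with a different local gradient, but after the averaging step every rank holds the same value as the single-device reference, so the replicas stay identical after the optimizer step.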
Strengths and limits
DDP is simple, scales well when the model fits in one GPU's memory, and near-linearly increases throughput up to communication bandwidth limits. It does not reduce per-GPU memory — every card still holds full weights, optimizer states, and activations. For LLMs above ~7B parameters with long contexts, DDP alone OOMs. DDP remains the outer loop in most multi-GPU setups: FSDP or tensor parallelism fits the model; DDP scales batch size across data-parallel groups.
Tensor parallelism (TP)
Tensor parallelism shards individual layers across GPUs.
In a linear layer Y = XA, column-parallel splits A
along output features so each rank computes a slice of the output columns;
row-parallel splits A along input features, so each rank produces a partial
sum and an all-reduce combines them. Attention
heads are a natural shard boundary: each GPU owns a subset of Q/K/V heads in
multi-head attention blocks.
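The two ways of splitting Y = XA can be checked with small integer matrices. This is a pure-Python sketch with two simulated ranks; the helper names (`matmul`, `split_columns`, `split_rows`) and the matrix values are illustrative, and a real implementation would run each shard on its own GPU and use collectives instead of list operations.

```python
def matmul(X, A):
    """Plain nested-list matrix multiply."""
    return [[sum(x * a for x, a in zip(row, col)) for col in zip(*A)] for row in X]


def split_columns(M, parts):
    width = len(M[0]) // parts
    return [[row[i * width:(i + 1) * width] for row in M] for i in range(parts)]


def split_rows(M, parts):
    height = len(M) // parts
    return [M[i * height:(i + 1) * height] for i in range(parts)]


X = [[1, 2, 3, 4], [5, 6, 7, 8]]
A = [[1, 0, 2, 1], [0, 1, 1, 2], [2, 1, 0, 0], [1, 2, 1, 0]]
full = matmul(X, A)

# Column-parallel: each rank holds a column slice of A and the full X.
# Outputs are concatenated along the feature axis (no summation needed).
col_outputs = [matmul(X, a_shard) for a_shard in split_columns(A, 2)]
col_result = [r0 + r1 for r0, r1 in zip(*col_outputs)]

# Row-parallel: each rank holds a row slice of A and the matching column
# slice of X. Outputs are partial sums, combined by an all-reduce (a sum here).
partials = [matmul(x_shard, a_shard)
            for x_shard, a_shard in zip(split_columns(X, 2), split_rows(A, 2))]
row_result = [[p0 + p1 for p0, p1 in zip(r0, r1)] for r0, r1 in zip(*partials)]
```

Both `col_result` and `row_result` equal `full`. Pairing a column-parallel layer with a row-parallel layer is convenient because the first produces exactly the input slices the second needs, leaving one all-reduce at the end.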
TP keeps per-GPU activation memory roughly proportional to
1 / tp_size for sharded layers, but introduces frequent
all-reduce communication (or all-gather and reduce-scatter pairs) within
every layer. It works best inside a single node with fast NVLink (typical
tp_size 2, 4, or 8). Cross-node TP is rare because latency kills throughput.
Megatron-LM popularized TP for transformer training. Hugging Face
device_map="auto" is a different mechanism: it places whole layers on
different devices for inference rather than splitting individual matrices.
Pipeline parallelism (PP)
Pipeline parallelism assigns contiguous layer groups to different GPUs: GPU 0 runs layers 0–7, GPU 1 runs layers 8–15, and so on. Micro-batches flow through the pipeline like an assembly line. Without optimization, GPUs idle waiting for predecessors (the pipeline bubble). GPipe and PipeDream schedules (1F1B, interleaved virtual stages) overlap forward and backward passes to shrink the bubble.
PP reduces per-GPU memory to roughly 1 / pp_size of layer
activations and parameters, but adds latency proportional to pipeline depth.
It shines when a model is too tall for one GPU even with TP, or when crossing
nodes where intra-node TP and inter-node PP combine. Tuning micro-batch count
vs pipeline stages is essential — too few micro-batches and GPUs starve.
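The cost of too few micro-batches can be estimated with the standard idle-time formula for a non-interleaved schedule. This sketch assumes every stage takes equal time per micro-batch and ignores communication; the function name `bubble_fraction` is illustrative.

```python
from fractions import Fraction


def bubble_fraction(pp_stages, micro_batches):
    """Fraction of each step a stage sits idle in a non-interleaved pipeline.

    Assumes equal per-stage compute time and no communication cost.
    """
    return Fraction(pp_stages - 1, micro_batches + pp_stages - 1)


for m in (4, 8, 32):
    frac = bubble_fraction(4, m)
    print(f"pp_stages=4, micro_batches={m}: idle {float(frac):.0%}")
```

With four stages, going from 4 to 32 micro-batches cuts the idle share from 3/7 to 3/35 of the step under these assumptions. Interleaved schedules shrink it further at the price of more communication.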
FSDP and DeepSpeed ZeRO
Fully Sharded Data Parallel (FSDP, PyTorch native) and
ZeRO (DeepSpeed, stages 1–3) shard optimizer states,
gradients, and/or parameters across data-parallel ranks instead of replicating
them. Each GPU stores only 1 / world_size of the sharded tensors.
When parameters are sharded, ranks all_gather full weights just-in-time for
each wrapped unit's forward and backward computation, then reduce_scatter
gradients during the backward pass.
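The gather and scatter pattern can be simulated without GPUs. The sketch below uses plain Python lists as stand-ins for flattened parameter and gradient tensors; `shard`, `all_gather`, and `reduce_scatter` are hypothetical helpers that mimic the collectives' results, not the `torch.distributed` API.

```python
def shard(flat, world_size):
    """Split a flat tensor into equal contiguous shards, one per rank."""
    n = len(flat) // world_size
    return [flat[r * n:(r + 1) * n] for r in range(world_size)]


def all_gather(shards):
    """Every rank ends up with the concatenation of all shards."""
    full = [v for s in shards for v in s]
    return [list(full) for _ in shards]


def reduce_scatter(per_rank_full, world_size):
    """Sum full-size tensors across ranks, then leave each rank one shard."""
    summed = [sum(vals) for vals in zip(*per_rank_full)]
    return shard(summed, world_size)


world_size = 4
params = list(range(8))
param_shards = shard(params, world_size)   # at rest: 2 values per rank

gathered = all_gather(param_shards)        # before compute: full weights

# Each rank computes a full-size gradient on its own micro-batch.
local_grads = [[rank + 1] * len(params) for rank in range(world_size)]
grad_shards = reduce_scatter(local_grads, world_size)
```

At rest each rank holds two of the eight parameters; after the reduce-scatter each rank holds only the summed gradients for its own two parameters, which is all its slice of the optimizer needs.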
ZeRO stages (mental model)
- Stage 1 — shard optimizer states only; modest savings.
- Stage 2 — shard optimizer states + gradients; meaningful for 7B–13B.
- Stage 3 — shard parameters too; largest memory win, most communication.
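The effect of each stage on per-GPU memory can be estimated with simple division. This sketch assumes a common mixed-precision accounting that is not taken from the text above: 2 bytes per parameter for BF16 weights, 2 for gradients, and 12 for optimizer state (FP32 master weights plus two FP32 Adam moments). The function name `zero_per_gpu_bytes` is illustrative, and activations, buffers, and fragmentation are excluded.

```python
def zero_per_gpu_bytes(n_params, world_size, stage,
                       weight_bytes=2, grad_bytes=2, optim_bytes=12):
    """Per-GPU bytes for weights + gradients + optimizer state by ZeRO stage.

    stage 0 means plain data parallelism (everything replicated).
    """
    weights = n_params * weight_bytes
    grads = n_params * grad_bytes
    optim = n_params * optim_bytes
    if stage >= 1:
        optim //= world_size    # stage 1: shard optimizer states
    if stage >= 2:
        grads //= world_size    # stage 2: also shard gradients
    if stage >= 3:
        weights //= world_size  # stage 3: also shard parameters
    return weights + grads + optim


for stage in range(4):
    size = zero_per_gpu_bytes(13_000_000_000, 8, stage)
    print(f"stage {stage}: {size / 1e9:.2f} GB per GPU")
```

Measured peaks will be higher than these figures because activations and temporary all-gather buffers sit on top of the sharded state.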
FSDP's ShardingStrategy.FULL_SHARD maps to ZeRO-3.
SHARD_GRAD_OP maps roughly to ZeRO-2. Pair FSDP with
activation checkpointing applied per transformer layer: sharding cuts
parameter and optimizer memory while checkpointing cuts activation memory,
so the savings compound. DeepSpeed adds ZeRO-Offload (CPU) and ZeRO-Infinity
(NVMe) for teams without enough GPU memory. Offloading is slower but can
train 70B on modest clusters.
3D parallelism and expert parallelism
Large foundation-model training (100B+) often combines all three axes: 3D parallelism = data parallel × tensor parallel × pipeline parallel. A typical layout: TP=8 within a node, PP=4 across nodes, DP=16 for global batch. World size = 8 × 4 × 16 = 512 GPUs. Process groups must be configured so collectives hit the right ranks — bugs here cause silent wrong gradients.
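One way to reason about process groups is to map each global rank to a (dp, pp, tp) coordinate. The ordering below, with TP varying fastest so that a TP group occupies eight consecutive ranks, is an assumed convention for illustration; frameworks differ, and the helper names are hypothetical.

```python
TP, PP, DP = 8, 4, 16
WORLD_SIZE = TP * PP * DP


def rank_to_coords(rank, tp=TP, pp=PP):
    """Map a global rank to (dp_rank, pp_rank, tp_rank), TP fastest-varying."""
    tp_rank = rank % tp
    pp_rank = (rank // tp) % pp
    dp_rank = rank // (tp * pp)
    return dp_rank, pp_rank, tp_rank


def tp_group(rank, tp=TP):
    """Ranks that jointly hold one sharded layer (consecutive ranks)."""
    start = rank - rank % tp
    return list(range(start, start + tp))


def dp_group(rank, tp=TP, pp=PP, dp=DP):
    """Ranks holding the same model shard in different replicas."""
    offset = rank % (tp * pp)
    return [offset + i * tp * pp for i in range(dp)]


print(WORLD_SIZE, rank_to_coords(37), tp_group(37), dp_group(37)[:3])
```

Gradient all-reduces should run over `dp_group`, while per-layer TP collectives run over `tp_group`. Writing the mapping down and asserting on a few ranks is a cheap guard against the misconfigured groups described above.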
Expert parallelism (for Mixture-of-Experts models) routes tokens to different expert MLPs on different GPUs. It is orthogonal to TP/PP/DP but adds all-to-all token dispatch communication. Load imbalance (some experts overloaded) degrades throughput; auxiliary load-balancing loss terms during training mitigate hot experts.
Harbor Analytics 13B fine-tune refactor (worked example)
Problem. Fine-tune a Llama-class 13B model on 2.4M transcripts with 4096 context, using full fine-tuning (not LoRA) for domain vocabulary. A single A100 ran out of memory at batch size 1 even with BF16 and checkpointing. An eight-GPU A100 80GB node was available, but naive DDP still ran out of memory on every GPU.
Change. Wrapped the model in FSDP FULL_SHARD with
auto_wrap_policy on transformer blocks. Enabled
gradient_checkpointing_enable(). Used
torch.cuda.amp.autocast(dtype=torch.bfloat16). Global batch 64 =
micro-batch 1 × 8 GPUs × gradient accumulation 8. Learning rate
scaled linearly with global batch (base LR 2e-5 × 64/32 = 4e-5). NCCL backend
over NVLink; limit_all_gathers=True to reduce peak memory spikes.
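The batch and learning-rate arithmetic in this configuration can be written down explicitly. The helper names are illustrative, and the square-root variant is included only for comparison; the text states that linear scaling was used.

```python
import math


def global_batch(micro_batch, num_gpus, accum_steps):
    """Samples contributing to one optimizer step."""
    return micro_batch * num_gpus * accum_steps


def linear_scaled_lr(base_lr, base_batch, new_batch):
    return base_lr * new_batch / base_batch


def sqrt_scaled_lr(base_lr, base_batch, new_batch):
    return base_lr * math.sqrt(new_batch / base_batch)


batch = global_batch(micro_batch=1, num_gpus=8, accum_steps=8)
lr_linear = linear_scaled_lr(2e-5, base_batch=32, new_batch=batch)
lr_sqrt = sqrt_scaled_lr(2e-5, base_batch=32, new_batch=batch)
print(batch, lr_linear, lr_sqrt)
```

Keeping these three factors in one function makes it harder to change the GPU count or accumulation steps without noticing that the effective batch, and therefore the appropriate learning rate, has moved.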
Results. Peak memory per GPU: 41 GB (down from 78 GB attempted DDP). Throughput: 1,240 tokens/sec cluster-wide vs estimated 180 tokens/sec single-GPU. Fine-tune completed in 2.8 days. Validation perplexity matched single-GPU reference run within 0.3% — confirming correct gradient synchronization. Lesson: FSDP ZeRO-3 is the first knob for multi-GPU fine-tune when LoRA is not an option; add TP only if FSDP still OOMs on long contexts.
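The reported figures imply a speedup and a scaling efficiency that can be checked directly. The variable names are illustrative, the numbers are the ones quoted above, and the single-GPU values are estimates because that configuration ran out of memory.

```python
cluster_tokens_per_sec = 1240
single_gpu_tokens_per_sec = 180  # estimated, not measured
num_gpus = 8

throughput_speedup = cluster_tokens_per_sec / single_gpu_tokens_per_sec
scaling_efficiency = throughput_speedup / num_gpus
wall_clock_speedup = 19 / 2.8    # estimated days on one GPU / days on cluster

print(f"throughput speedup: {throughput_speedup:.2f}x")
print(f"scaling efficiency: {scaling_efficiency:.1%}")
print(f"wall-clock speedup: {wall_clock_speedup:.2f}x")
```

The throughput and wall-clock ratios agree closely, and both sit below the ideal 8x, which is consistent with FSDP spending part of each step on all-gather and reduce-scatter traffic.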
Technique decision table
| Technique | What is sharded | Per-GPU memory | When to choose |
|---|---|---|---|
| Single GPU + checkpointing | Activations (recompute) | Full model on one card | ≤7B models, short context, prototyping |
| DDP | Nothing (replicated) | Full model per GPU | Model fits one GPU; scale batch/throughput |
| FSDP / ZeRO-2 | Optimizer + gradients | Moderate reduction | 7B–30B fine-tune, default multi-GPU starting point |
| FSDP / ZeRO-3 | Params + optimizer + grads | Largest DP-style savings | 13B+ full fine-tune, long context |
| Tensor parallelism | Layer matrices / heads | ~1/tp_size per layer | Model too wide; intra-node NVLink; pairs with FSDP |
| Pipeline parallelism | Layer groups | ~1/pp_size layers | Very deep models; cross-node; combine with TP |
| LoRA / QLoRA | Nothing sharded; only low-rank adapters are trained | Frozen base weights plus small adapter gradients and optimizer states | When full fine-tune is unnecessary; see LoRA guide |
| DeepSpeed Offload | States to CPU/NVMe | GPU holds minimal tensors | Budget clusters; accept slower steps |
Common pitfalls
- Using DDP alone on models that do not fit — every rank OOMs identically; switch to FSDP or LoRA first.
- Wrong learning rate after global batch change — linear or square-root scaling rules need tuning; monitor loss spikes.
- FSDP wrap policy too coarse — wrapping the entire model in one unit defeats per-layer gather overlap; wrap transformer blocks.
- Ignoring gradient accumulation math — global batch = micro_batch × num_gpus × accum_steps; document all three.
- Tensor parallel across slow links — TP on Ethernet clusters often loses to FSDP + PP layout.
- Pipeline bubble with too few micro-batches — rule of thumb: micro-batches ≥ 2 × pp_stages for 1F1B schedules.
- Checkpoint corruption on rank 0 only: use FSDP state_dict_type=FULL_STATE_DICT or sharded checkpoint APIs; rank-0-only saves miss shards.
- Mixed-precision dtype mismatches: keep master weights FP32 in the optimizer while the forward pass uses BF16; do not cast optimizer states blindly.
Production checklist
- Estimate memory: weights + optimizer + gradients + activations at target seq length.
- Try single-GPU + checkpointing + BF16 before distributing.
- Start with FSDP FULL_SHARD (ZeRO-3) for multi-GPU fine-tune above 7B.
- Set auto_wrap_policy on transformer block modules.
- Configure global batch, micro-batch, accumulation, and LR scaling explicitly.
- Enable gradient_checkpointing when activations dominate memory.
- Verify NCCL / InfiniBand health with a small all-reduce benchmark before long runs.
- Log per-GPU memory and throughput; watch for stragglers and thermal throttle.
- Add TP only if FSDP OOMs; keep TP inside NVLink domains.
- Use sharded checkpoint save/load APIs compatible with your parallelism layout.
- Regression-test loss curve against a known single-GPU or smaller run.
Key takeaways
- LLM training memory is weights + optimizer + gradients + activations — single-GPU training hits walls quickly above 7B parameters.
- DDP replicates the model and scales batch size; FSDP/ZeRO shards states and parameters to cut per-GPU memory.
- Tensor parallelism shards layers within a node; pipeline parallelism splits depth across stages — combine as 3D parallelism at scale.
- FSDP ZeRO-3 plus gradient checkpointing is the default multi-GPU fine-tune stack before reaching for TP/PP.
- LoRA and quantization reduce training scope when full fine-tune is not required — distributed training is not always the first answer.
Related reading
- Gradient checkpointing explained — activation memory tradeoffs paired with FSDP
- Mixed precision training explained — BF16/FP16 forward with FP32 optimizer states
- LoRA fine-tuning explained — when adapter training beats full sharded fine-tune
- Transformer architecture explained — what TP shards inside each block