Layer normalization explained
Convolutional networks typically normalize across the batch; transformers cannot rely on that. When the batch is a single token stream or sequences vary in length, statistics computed over the batch dimension are meaningless or unstable. Layer normalization (LayerNorm) fixes this by normalizing within each sample: it computes the mean and variance across the feature dimension of every token independently. It is the normalization layer inside GPT and BERT, and it (or its RMSNorm variant) sits inside Llama and virtually every modern language model; it also appears in RNN stacks, where time steps cannot share batch statistics either. This guide explains the LayerNorm formula, the learnable affine parameters, the pre-norm vs post-norm debate that shapes how deep transformers train, the lighter RMSNorm variant used in Llama-class models, and how layer norm compares to batch normalization, with links to transformer blocks and deep learning fundamentals.
What layer normalization computes
Given a feature vector x of dimension d (one token's
hidden state, or one time step in an RNN), layer norm computes:
- Mean: μ = (1/d) Σ xᵢ
- Variance: σ² = (1/d) Σ (xᵢ − μ)²
- Normalized: x̂ᵢ = (xᵢ − μ) / √(σ² + ε)
- Output: yᵢ = γ x̂ᵢ + β
The small constant ε (typically 1e-5) prevents division by zero.
γ (scale) and β (shift) are learnable
parameters per feature: after standardization, the network can recover any
mean and variance it needs. Without γ and β, forcing zero mean and unit
variance would overly constrain representational capacity.
Crucially, statistics are computed over the feature dimension for a single example. A batch of 32 tokens produces 32 independent normalizations. This makes layer norm invariant to batch size — batch size 1 at inference works identically to batch size 512 at training.
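The formula above can be sketched in a few lines of NumPy. This is a minimal illustration rather than a reference implementation: the function name `layer_norm`, the toy shapes, and the random inputs are assumptions made for the example.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize over the last (feature) axis, then apply scale and shift."""
    mu = x.mean(axis=-1, keepdims=True)                  # per-token mean
    var = ((x - mu) ** 2).mean(axis=-1, keepdims=True)   # per-token variance (1/d)
    x_hat = (x - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta

rng = np.random.default_rng(0)
d = 8
tokens = rng.normal(loc=3.0, scale=2.0, size=(32, d))   # 32 tokens, d features each
gamma, beta = np.ones(d), np.zeros(d)                    # identity affine for the demo

out_batch = layer_norm(tokens, gamma, beta)       # all 32 tokens at once
out_single = layer_norm(tokens[:1], gamma, beta)  # the first token on its own
```

With the identity affine parameters, every row of `out_batch` has approximately zero mean and unit variance, and the first row matches `out_single`, which is the batch-size invariance described above.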
Why transformers use layer norm, not batch norm
Transformers process sequences where each position is a token embedding. Batch norm would normalize across the batch dimension at each position — but sequence lengths differ between samples, padding masks distort statistics, and autoregressive inference runs one token at a time with no batch dimension to aggregate over.
Layer norm also avoids the train/eval split that complicates batch norm.
Batch norm uses running population statistics at inference; layer norm applies
the same per-sample formula in both modes. There is no
model.eval() behavioral change specific to LayerNorm — one less
footgun when exporting to ONNX or serving in production.
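The batch dependence can be made concrete with a small NumPy sketch. It assumes training-mode batch statistics without affine terms or running averages, and the helper names `batch_stat_norm` and `per_sample_norm` are hypothetical, chosen only for this illustration.

```python
import numpy as np

def batch_stat_norm(x, eps=1e-5):
    """Batch-norm-style statistics: computed over the batch axis (axis 0)."""
    mu = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def per_sample_norm(x, eps=1e-5):
    """Layer-norm-style statistics: computed over the feature axis (last axis)."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
sample = rng.normal(size=(1, 4))
# The same sample placed in two batches with different companions.
batch_a = np.concatenate([sample, rng.normal(size=(3, 4))])
batch_b = np.concatenate([sample, rng.normal(loc=5.0, size=(3, 4))])

bn_changes = not np.allclose(batch_stat_norm(batch_a)[0], batch_stat_norm(batch_b)[0])
ln_changes = not np.allclose(per_sample_norm(batch_a)[0], per_sample_norm(batch_b)[0])
print(bn_changes, ln_changes)
```

The batch-statistics output for the shared sample changes with its companions, while the per-sample output does not. That is why batch norm needs separate running statistics at inference and layer norm does not.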
The original "Attention Is All You Need" transformer used post-norm: attention and feed-forward sublayers, then layer norm. Modern large language models overwhelmingly use pre-norm (norm before each sublayer), which we cover below. Either way, the normalization primitive is layer norm, not batch norm.
Pre-norm vs post-norm transformer blocks
Post-norm (original Transformer, BERT)
Structure: Sublayer(x) → Dropout → Add(x) → LayerNorm. The
residual connection adds the sublayer output to the input, then normalizes.
Post-norm stacks can be harder to train at extreme depth without careful
learning-rate warmup — gradients through many layer norm operations can
destabilize early training.
Pre-norm (GPT-2/3, Llama, most modern LLMs)
Structure: x + Sublayer(LayerNorm(x)). Normalization happens
before attention and feed-forward, so each sublayer receives
normalized inputs. Pre-norm makes much deeper transformers trainable with less
reliance on elaborate warmup schedules and is the default in new architecture design.
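The two layouts can be compared side by side in a short NumPy sketch. It omits dropout and the affine parameters for brevity, and `toy_sublayer` is a hypothetical stand-in for attention or the feed-forward network, not a real transformer component.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def post_norm_block(x, sublayer):
    # Post-norm: add the residual first, then normalize the sum.
    return layer_norm(x + sublayer(x))

def pre_norm_block(x, sublayer):
    # Pre-norm: normalize only the sublayer input; the residual path is untouched.
    return x + sublayer(layer_norm(x))

rng = np.random.default_rng(0)
W = rng.normal(scale=0.1, size=(16, 16))

def toy_sublayer(h):
    return np.tanh(h @ W)   # stand-in for attention or the FFN

x = rng.normal(size=(4, 16))
h_post, h_pre = x, x
for _ in range(12):                      # a 12-block toy stack
    h_post = post_norm_block(h_post, toy_sublayer)
    h_pre = pre_norm_block(h_pre, toy_sublayer)
```

One way to read the difference: if a sublayer outputs zero, the pre-norm block reduces to the identity, so the residual stream passes through unchanged, whereas the post-norm block still renormalizes its input at every layer.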
Practical implication
When porting weights between architectures, pre-norm vs post-norm placement is not interchangeable — you cannot drop BERT weights into a pre-norm stack without retraining. When fine-tuning, match the normalization placement of the base checkpoint exactly. See LLM fine-tuning for how architecture details affect transfer.
RMSNorm: the modern lightweight variant
Root Mean Square Layer Normalization (RMSNorm) drops mean centering and re-scales by the root-mean-square of activations:
yᵢ = γ · xᵢ / √( (1/d) Σ xⱼ² + ε )
Skipping the mean subtraction saves computation and empirically performs as
well as full LayerNorm in large language models. Llama, Mistral, Gemma, and
many open-weight LLMs use RMSNorm instead of classic LayerNorm. The learnable
scale γ remains; the bias β is typically omitted.
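A minimal NumPy sketch of the RMSNorm formula, next to a LayerNorm for comparison. The function names and the choice of ε = 1e-6 are assumptions for the example, not values taken from any particular checkpoint.

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-6):
    """Divide by the root mean square over features; no centering, no bias."""
    rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)
    return gamma * x / rms

def layer_norm(x, gamma, beta, eps=1e-6):
    mu = x.mean(axis=-1, keepdims=True)
    var = ((x - mu) ** 2).mean(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
d = 16
x = rng.normal(loc=3.0, size=(4, d))           # inputs with a nonzero mean
gamma = np.ones(d)

y_rms = rms_norm(x, gamma)
centered = x - x.mean(axis=-1, keepdims=True)  # zero-mean version of the same inputs
```

On zero-mean inputs such as `centered`, RMSNorm and LayerNorm with β = 0 give the same result, because the variance then equals the mean square. On uncentered inputs they differ: `y_rms` has unit mean square but keeps a nonzero mean.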
When reading model cards or Hugging Face configs, rms_norm_eps
is the ε equivalent. Quantization and distillation pipelines must preserve
RMSNorm epsilon and gamma values — small numerical differences compound across
dozens of layers.
Layer norm in RNNs and sequence models
Before transformers dominated NLP, layer norm stabilized LSTM and GRU training. Applied per time step across hidden units, it reduces internal covariate shift along the temporal dimension without requiring batch statistics across sequences of different lengths.
Variants include LayerNorm applied to LSTM gates (normalize each gate's pre-activation) and LayerNorm on the hidden state output. The pattern mirrors transformers: normalize before the nonlinear gate computations so gradients flow more evenly through long unrolled sequences during backpropagation through time.
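The per-step pattern can be sketched with a plain tanh RNN cell, used here as a simplification of the LSTM and GRU variants described above. The function `ln_rnn_step`, the weight shapes, and the random inputs are all hypothetical choices for illustration.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

def ln_rnn_step(x_t, h_prev, W_x, W_h, gamma, beta):
    """One recurrent step: normalize the pre-activation, then apply tanh."""
    pre_act = x_t @ W_x + h_prev @ W_h
    return np.tanh(layer_norm(pre_act, gamma, beta))

rng = np.random.default_rng(0)
d_in, d_h, steps = 5, 8, 50
W_x = rng.normal(size=(d_in, d_h))
W_h = rng.normal(size=(d_h, d_h))
gamma, beta = np.ones(d_h), np.zeros(d_h)

h = np.zeros((1, d_h))                    # a single sequence: no batch statistics needed
for _ in range(steps):
    x_t = rng.normal(size=(1, d_in))      # toy input for this time step
    h = ln_rnn_step(x_t, h, W_x, W_h, gamma, beta)
```

Because the statistics come from the hidden units of one time step, the same code works for a single sequence of any length.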
Comparison table: normalization layers at a glance
| Method | Axes normalized | Batch-size dependent? | Train/eval difference? | Primary use |
|---|---|---|---|---|
| LayerNorm | Features per sample | No | No | Transformers, RNNs, LLMs |
| RMSNorm | Features per sample (RMS only, no mean centering) | No | No | Modern LLMs (Llama, Mistral) |
| BatchNorm | Batch + spatial per channel | Yes | Yes (running stats) | CNNs, large-batch vision |
| GroupNorm | Channel groups per sample | No | No | Small-batch detection/segmentation |
| InstanceNorm | Spatial per channel per sample | No | No | Style transfer, GANs |
Full treatment of batch norm mechanics, running statistics, and SyncBN is in the dedicated batch normalization guide. The key decision: if your model processes sequences or runs at batch size 1, layer norm (or RMSNorm) is the correct default.
Interaction with activations and optimizers
Normalization layers sit adjacent to activation functions: in pre-norm transformers, LayerNorm precedes attention and the GELU feed-forward, while the residual path bypasses both. This ordering means the sublayers operate on normalized inputs with roughly unit variance, which helps keep dot-product attention scores from growing too large.
Layer norm changes the effective gradient scale, so learning rates tuned for unnormalized networks generally do not transfer. AdamW with warmup remains standard for transformer training; see neural network optimizers for schedule details. When fine-tuning with LoRA or adapters, base model layer norm weights usually stay frozen; only the attention and FFN adapter weights update.
Mixed-precision training (BF16/FP16) typically keeps layer norm in FP32 for numerical stability, similar to batch norm. Frameworks like PyTorch handle this when using automatic mixed precision; custom kernels should not cast norm weights to half precision blindly.
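A contrived NumPy sketch of why the statistics are usually computed in FP32. It emulates the upcast by hand and is not a description of how any framework implements automatic mixed precision; both function names and the input values are assumptions picked to trigger float16 overflow.

```python
import numpy as np

def layer_norm_fp32_stats(x_half, gamma, beta, eps=1e-5):
    """Upcast to float32 for the arithmetic, then cast back to the input dtype."""
    x = x_half.astype(np.float32)
    mu = x.mean(axis=-1, keepdims=True)
    var = ((x - mu) ** 2).mean(axis=-1, keepdims=True)
    y = gamma.astype(np.float32) * (x - mu) / np.sqrt(var + eps) + beta.astype(np.float32)
    return y.astype(x_half.dtype)

def layer_norm_all_half(x_half, gamma, beta, eps=1e-5):
    """Same formula kept in float16; squared deviations overflow above ~65504."""
    mu = x_half.mean(axis=-1, keepdims=True)
    with np.errstate(over="ignore"):       # silence the expected overflow warning
        var = ((x_half - mu) ** 2).mean(axis=-1, keepdims=True)
    return gamma * (x_half - mu) / np.sqrt(var + np.float16(eps)) + beta

x_half = np.array([[-600.0, -200.0, 200.0, 600.0]], dtype=np.float16)
gamma = np.ones(4, dtype=np.float16)
beta = np.zeros(4, dtype=np.float16)

y_good = layer_norm_fp32_stats(x_half, gamma, beta)
y_naive = layer_norm_all_half(x_half, gamma, beta)
```

For this input the float16 variance overflows to infinity, so the all-half version returns zeros, while the FP32-statistics version returns the expected normalized values in float16.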
Common mistakes and debugging tips
- Wrong axis — implementing layer norm over the sequence dimension instead of the hidden dimension destroys token independence.
- Confusing pre-norm and post-norm — loading weights into the wrong block layout produces garbage outputs despite matching parameter counts.
- Forgetting ε in custom implementations (rare, but catastrophic when a token's feature vector has zero variance, such as an all-constant vector).
- Applying batch norm to transformers — works in toy setups, fails at inference batch size 1 and with variable sequence lengths.
- Not matching RMSNorm vs LayerNorm — swapping formulas when porting checkpoints causes silent quality collapse.
- Over-normalizing — stacking multiple norm layers without residual paths over-constrains the representation.
If loss diverges early in transformer training, verify: pre-norm placement, learning rate warmup, and that layer norm epsilon matches the reference implementation (commonly 1e-5 for LayerNorm, though BERT checkpoints use 1e-12; 1e-6 for some RMSNorm configs).
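Two of the mistakes listed above, the wrong axis and the missing ε, are easy to reproduce in NumPy. The helper `normalize` and the toy shapes are assumptions for this sketch.

```python
import numpy as np

def normalize(x, axis, eps):
    mu = x.mean(axis=axis, keepdims=True)
    var = x.var(axis=axis, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
seq = rng.normal(size=(6, 8))      # (sequence length, hidden size)
edited = seq.copy()
edited[3] += 10.0                  # change only token 3

# Correct axis (hidden): token 0 is unaffected by the edit to token 3.
same_hidden = np.allclose(normalize(seq, -1, 1e-5)[0], normalize(edited, -1, 1e-5)[0])
# Wrong axis (sequence): token 0 changes because statistics mix tokens.
same_seq = np.allclose(normalize(seq, 0, 1e-5)[0], normalize(edited, 0, 1e-5)[0])

# Missing epsilon: a constant feature vector has zero variance, so 0/0 gives NaN.
constant = np.full((1, 8), 2.5)
with np.errstate(invalid="ignore", divide="ignore"):
    no_eps = normalize(constant, -1, 0.0)
with_eps = normalize(constant, -1, 1e-5)
```

Here `same_hidden` is true and `same_seq` is false, `no_eps` is all NaN, and `with_eps` is all zeros. A quick check like this against a known-good implementation catches both bugs before training starts.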
Decision table: which normalization to use
| Architecture | Recommended norm | Notes |
|---|---|---|
| Transformer / LLM | LayerNorm or RMSNorm | Pre-norm placement; match base checkpoint |
| CNN (batch ≥ 16) | BatchNorm | Conv → BN → ReLU default stack |
| CNN (batch < 8) | GroupNorm | Per-sample; no running stats needed |
| LSTM / GRU | LayerNorm per step | Stabilizes BPTT through long sequences |
| Style transfer / GAN | InstanceNorm | Removes instance-specific contrast |
| Edge inference (batch 1) | LayerNorm / RMSNorm / GroupNorm | Avoid batch-dependent norms |
Practitioner checklist
- Confirm normalization axis: features per token/sample, not batch or sequence.
- Match pre-norm vs post-norm to your base architecture when fine-tuning.
- Use RMSNorm only when the checkpoint or paper specifies it — do not swap freely.
- Keep layer norm in FP32 during mixed-precision training unless the framework already handles it.
- Verify ε matches reference (1e-5 LayerNorm, model-specific for RMSNorm).
- Do not use batch norm in transformer blocks processing variable-length sequences.
- When exporting to ONNX/TorchScript, confirm norm layers trace correctly at batch 1.
- Document norm type and placement in model cards for downstream quantization.
- Retune learning rate when adding or removing normalization from an existing stack.
- Compare against a known-good implementation before writing custom LayerNorm kernels.
Key takeaways
- LayerNorm normalizes across features within each sample — independent of batch size.
- Transformers and RNNs rely on layer norm because batch statistics are unavailable or misleading.
- Pre-norm (norm before sublayer) is the modern default for deep transformer training.
- RMSNorm simplifies LayerNorm by dropping mean centering; dominant in Llama-class LLMs.
- Normalization choice is architectural — match the reference model exactly when fine-tuning or distilling.
Related reading
- Transformer architecture explained — attention blocks where pre-norm layer norm lives
- Batch normalization explained — batch statistics, running means, and when CNNs need BatchNorm instead
- Activation functions explained — GELU and ReLU placement relative to normalization
- Neural network optimizers explained — AdamW, warmup, and learning rates for normalized stacks