Multi-Head Latent Attention (MLA) explained

Harbor Support's on-prem 70B assistant already ran Grouped-Query Attention (GQA) with eight shared K/V head groups, which unlocked 32k-token RAG sessions on two H100 80GB GPUs. Product then asked for 64k context so agents could paste full incident timelines without chunking. GQA alone did not fit: the cache still stored full d_head-wide key and value vectors for every past token across 80 layers.

Multi-Head Latent Attention (MLA), the mechanism behind DeepSeek-V2 and DeepSeek-V3, solved the next bottleneck. Instead of caching separate K and V tensors per head, MLA projects each token's hidden state into a low-rank latent vector, caches only that latent, and up-projects to full keys and values inside the attention kernel when scores are computed. Harbor's migration to an MLA layout cut KV bytes per token by roughly 70% versus the prior GQA checkpoint and raised sustainable concurrent 64k sessions from 4 to 11 without new hardware.

MLA is orthogonal to Flash Attention (compute tiling) and prefill-decode disaggregation (pool splitting): it changes what you store in the KV cache, not how you schedule GPUs. This guide covers the down/up projection math, RoPE decoupling, MLA vs MHA vs GQA vs MQA, the Harbor 64k RAG refactor, an architecture decision table, pitfalls, and a production checklist.

Why head sharing is not always enough

Standard multi-head attention caches two tensors per token per layer: keys and values, each with shape [num_kv_heads, d_head], where num_kv_heads equals the query head count h. GQA reduces num_kv_heads from h to g, shrinking the cache to g/h of its MHA size. With h = 64 and g = 8 that is an 8× win, typical of Llama-class 70B models, but the per-head dimension d_head stays large (128–256 on 70B stacks). At sequence length 65536, even eight K/V heads per layer still allocate gigabytes per concurrent user.

MLA attacks the other axis: representation width. A bottleneck latent c_t with dimension d_c ≪ h × d_head is what actually gets written to cache. Full keys and values are reconstructed on the fly via learned up-projection matrices. Attention quality depends on whether the low-rank subspace captures the information K and V normally carry. DeepSeek's results suggest it can, when the model is trained end-to-end with MLA from the start or converted with careful continued pretraining.

MLA vs quantization and sliding windows

INT8/FP8 KV cache quantization also shrinks bytes but can destabilize long-context recall if scales drift. Sliding-window attention caps cache length by dropping distant tokens. MLA preserves full context while compressing stored state — you still pay compute to up-project latents during attention, but memory bandwidth (often the decode bottleneck) drops sharply.

How MLA works: down-project, cache, up-project

At token position t, let h_t be the hidden state entering an attention layer. MLA introduces a joint down-projection:

$$c_t = h_t W^{DKV}, \qquad c_t \in \mathbb{R}^{d_c}$$

Only c_t is appended to the KV cache during decode. When computing attention for the current query q_t, the latents for all past positions t′ ≤ t are up-projected:

$$k_{t'} = c_{t'} W^{UK}, \qquad v_{t'} = c_{t'} W^{UV}$$

Queries may use a separate projection q_t = h_t W^Q split across heads as in standard MHA. Attention scores then follow the usual scaled dot-product recipe over reconstructed K and V. The critical serving insight: cache footprint scales with d_c, not with 2 × g × d_head.
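
The down-project, cache, up-project flow above can be sketched in plain Python. This is a toy single-head illustration, not DeepSeek's or Harbor's implementation: the sizes, the random weights, and the helper names (`matvec`, `rand_matrix`) are assumptions made for the example. The last lines also check the identity that lets W^UK be folded into the query, which the DeepSeek-V2 paper describes as a way to avoid materializing keys at inference time.

```python
import math
import random

def matvec(x, W):
    """Row vector x (length n) times matrix W (n rows, m columns)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
d_model, d_c, d_head = 16, 4, 8   # toy sizes, chosen only for illustration

W_DKV = rand_matrix(d_model, d_c, rng)   # joint down-projection
W_UK = rand_matrix(d_c, d_head, rng)     # key up-projection (one head)
W_UV = rand_matrix(d_c, d_head, rng)     # value up-projection (one head)
W_Q = rand_matrix(d_model, d_head, rng)  # query projection (one head)

# Decode: only the latent c_t is stored in the cache for each token.
hidden_states = [[rng.gauss(0.0, 1.0) for _ in range(d_model)] for _ in range(5)]
latent_cache = [matvec(h_t, W_DKV) for h_t in hidden_states]

# Attention for the newest token: up-project every cached latent on the fly.
q_t = matvec(hidden_states[-1], W_Q)
keys = [matvec(c, W_UK) for c in latent_cache]
values = [matvec(c, W_UV) for c in latent_cache]
scores = [dot(q_t, k) / math.sqrt(d_head) for k in keys]

# Softmax over past positions, then a weighted sum of reconstructed values.
peak = max(scores)
exp_scores = [math.exp(s - peak) for s in scores]
weights = [e / sum(exp_scores) for e in exp_scores]
output = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(d_head)]

# Same scores without materializing keys: q . (c W_UK) == (q W_UK^T) . c
W_UK_T = [list(col) for col in zip(*W_UK)]
q_absorbed = matvec(q_t, W_UK_T)
scores_absorbed = [dot(q_absorbed, c) / math.sqrt(d_head) for c in latent_cache]

print("cached floats per token:", len(latent_cache[0]))
print("max score difference:", max(abs(a - b) for a, b in zip(scores, scores_absorbed)))
```

Each cache entry holds d_c floats regardless of d_head, which is the property the serving arithmetic later in this guide relies on.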

Decoupled RoPE (DeepSeek pattern)

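Rotary position embeddings (RoPE) normally rotate both query and key vectors before the dot product. DeepSeek-V2 decouples positional and content components: a dedicated low-dimensional key slice carries RoPE while the latent carries position-free semantic content. The RoPE slice is shared across heads and cached next to the latent, so it adds a small per-token term to the cache (64 dimensions next to a 512-dimension latent in DeepSeek-V2). Keeping the rotation off the compressed latent matters because a position-dependent rotation between c_t and the reconstructed key would stop W^UK from being folded into the query projection at inference time, and it can also hurt length extrapolation. If you port MLA into an existing RoPE model, match the paper's split or re-validate perplexity at every context doubling from 4k upward.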
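
A minimal sketch of the decoupled score may help here. It assumes the DeepSeek-V2 form in which the attention logit is the sum of a position-free content term and a rotated low-dimensional term; the function names and toy vectors are hypothetical, and `k_content` stands in for a key already up-projected from the latent.

```python
import math

def rope(vec, position, base=10000.0):
    """Rotate consecutive pairs of vec by angles that grow with position."""
    d = len(vec)
    out = []
    for i in range(0, d, 2):
        theta = position * base ** (-i / d)
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        out.append(vec[i] * cos_t - vec[i + 1] * sin_t)
        out.append(vec[i] * sin_t + vec[i + 1] * cos_t)
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def decoupled_score(q_content, q_rope, q_pos, k_content, k_rope_rotated):
    """Content term carries no rotation; only the small slice is position-aware."""
    width = len(q_content) + len(q_rope)
    content = dot(q_content, k_content)
    positional = dot(rope(q_rope, q_pos), k_rope_rotated)
    return (content + positional) / math.sqrt(width)

# Toy vectors (assumed values for illustration only).
q_content, k_content = [0.2, -0.1, 0.4, 0.3], [0.5, 0.1, -0.2, 0.7]
q_rope, k_rope = [0.3, -0.6], [0.9, 0.2]

# What would sit in the cache for a key at position 7: latent plus rotated slice.
k_rope_at_7 = rope(k_rope, 7)
score_near = decoupled_score(q_content, q_rope, 10, k_content, k_rope_at_7)

# Shifting query and key positions by the same offset leaves the score unchanged.
k_rope_at_100 = rope(k_rope, 100)
score_far = decoupled_score(q_content, q_rope, 103, k_content, k_rope_at_100)
print(abs(score_near - score_far) < 1e-9)
```

Because the rotation touches only the small slice, the content path stays a plain matrix product of the latent, which is what keeps the up-projection foldable.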

Training vs retrofitting

MLA is easiest when baked into pretraining: DeepSeek-V2/V3 trained with MLA from scratch. Retrofitting a released MHA or GQA checkpoint requires initializing W^DKV (and the matching up-projections) from an SVD or PCA of the stacked K/V weights, followed by continued pretraining. Harbor ran three epochs over 120B tokens of mixed code, math, and support transcripts after PCA init; MMLU dropped 0.7 points before recovering to within 0.2 points of baseline.
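
One way to read "initialize from an SVD of the stacked K/V weights" is the truncated factorization below. It is a sketch under assumptions, not Harbor's documented recipe: it assumes the existing key and value projections W^K and W^V are concatenated column-wise, that d is the model width, and that only the top d_c singular directions are kept.

```latex
\begin{align*}
W^{KV} &= \begin{bmatrix} W^{K} & W^{V} \end{bmatrix} \in \mathbb{R}^{d \times 2 g d_{head}},
  & W^{KV} &\approx U_{d_c} \Sigma_{d_c} V_{d_c}^{\top} \\
W^{DKV} &\leftarrow U_{d_c} \Sigma_{d_c} \in \mathbb{R}^{d \times d_c},
  & \begin{bmatrix} W^{UK} & W^{UV} \end{bmatrix} &\leftarrow V_{d_c}^{\top} \in \mathbb{R}^{d_c \times 2 g d_{head}} \\
h_t W^{DKV} W^{UK} &\approx h_t W^{K},
  & h_t W^{DKV} W^{UV} &\approx h_t W^{V}
\end{align*}
```

The residual of this approximation is what continued pretraining has to recover. A PCA over K/V activations rather than weights would be a different, data-dependent variant.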

KV cache arithmetic: MLA vs GQA vs MHA

Per layer, per token, FP16 bytes cached during autoregressive decode:

  • MHA: 4 × h × d_head (K and V tensors, 2 bytes per element)
  • GQA: 4 × g × d_head
  • MLA: 2 × d_c (single latent, 2 bytes per element)

Example: h = 64, g = 8, d_head = 128, d_c = 512, L = 80 layers, n = 65536 tokens:

  • MHA cache ≈ 172 GB per sequence
  • GQA cache ≈ 21.5 GB per sequence
  • MLA cache ≈ 21.5 GB × (512 / 2048) ≈ 5.4 GB per sequence, because the 512-wide latent replaces the 2 × 8 × 128 = 2048 values GQA stores per token per layer

Exact ratios depend on how aggressively d_c is chosen. DeepSeek-V2 reports a 93.3% KV cache reduction relative to its predecessor, DeepSeek 67B. You trade extra FLOPs on up-projection at each decode step, which is usually cheaper than reading a fat cache from HBM.
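
The per-layer formulas can be turned into a small calculator so the arithmetic is easy to rerun for other shapes. This is a sketch: the function names are made up for illustration, the defaults mirror the example parameters in this section, and the optional `d_rope` term is an assumption that accounts for a cached decoupled-RoPE key slice if your layout has one.

```python
def kv_bytes_per_token_layer(kind, h=64, g=8, d_head=128, d_c=512, d_rope=0,
                             bytes_per_elem=2):
    """Bytes cached per token per layer for one attention layout."""
    if kind == "mha":
        elems = 2 * h * d_head      # K and V for every head
    elif kind == "gqa":
        elems = 2 * g * d_head      # K and V for g shared groups
    elif kind == "mla":
        elems = d_c + d_rope        # latent, plus optional RoPE key slice
    else:
        raise ValueError(f"unknown attention kind: {kind}")
    return elems * bytes_per_elem

def cache_gb(kind, layers=80, tokens=65536, **kwargs):
    """Whole-sequence cache in decimal gigabytes (1 GB = 1e9 bytes)."""
    return kv_bytes_per_token_layer(kind, **kwargs) * layers * tokens / 1e9

for kind in ("mha", "gqa", "mla"):
    print(f"{kind}: {kv_bytes_per_token_layer(kind)} B/token/layer, "
          f"{cache_gb(kind):.1f} GB per 64k sequence")

# Ratio of MLA to GQA cache for these defaults.
print(cache_gb("mla") / cache_gb("gqa"))
```

Passing `d_rope=64` shows how much a decoupled RoPE slice adds on top of the latent, which is worth including when you compare against a measured cache size.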

Harbor Support 64k RAG refactor (worked example)

Problem. Harbor Support served a GQA 70B model (TP=2 on H100 80GB) with max context 32768. Product required 65536 tokens for consolidated incident exports. vLLM preemption kicked in above ~22k average tokens with batch 6; p99 time-to-first-token exceeded 8s.

Change. Evaluated DeepSeek-V2 MLA recipes; adopted d_c = 512 with decoupled RoPE, PCA init from existing GQA K/V weights, and a 72-hour continued pretrain. Serving: same TP=2, FP16 weights, PagedAttention block size 16, max_model_len 65536, and speculative decoding with an unchanged 7B draft model.

Results. KV bytes per token fell from ~0.13 MB/layer-equiv (GQA) to ~0.04 MB. Sustainable concurrent 64k sessions rose from 4 to 11 at p95 TTFT < 2.5s. Per-token decode latency improved ~22% from reduced HBM traffic despite up-project FLOPs. Internal RAG exact-match: 90.8% vs 91.1% at 32k (within noise); at 64k, MLA held 89.4% vs 88.9% for truncated-chunk baseline.

Lesson. When GQA is insufficient for the next context tier, MLA compresses width instead of cutting KV heads further or buying more GPUs. Profile bytes and bandwidth before scaling replicas.
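
To connect the serving settings above (block size 16, max_model_len 65536, 80 layers) to capacity, the sketch below counts paged cache blocks per sequence and how many full-length sequences fit in a cache pool. The pool size is a hypothetical figure, not Harbor's measured budget, and the sketch ignores tensor-parallel sharding, fragmentation, and any RoPE slice; the per-token byte counts come from the 4 × g × d_head and 2 × d_c formulas with g = 8, d_head = 128, d_c = 512.

```python
import math

def blocks_per_sequence(tokens, block_size=16):
    """Number of fixed-size cache blocks a sequence of this length occupies."""
    return math.ceil(tokens / block_size)

def bytes_per_block(bytes_per_token_layer, layers=80, block_size=16):
    """Bytes one block holds across all layers."""
    return bytes_per_token_layer * layers * block_size

def max_full_sequences(kv_budget_bytes, tokens, bytes_per_token_layer,
                       layers=80, block_size=16):
    """How many sequences of `tokens` length fit entirely in the pool."""
    per_sequence = (blocks_per_sequence(tokens, block_size)
                    * bytes_per_block(bytes_per_token_layer, layers, block_size))
    return kv_budget_bytes // per_sequence

KV_BUDGET = 64 * 1024**3  # hypothetical cache pool, not Harbor's real figure

for name, per_token_layer in (("gqa", 4096), ("mla", 1024)):
    print(name,
          blocks_per_sequence(65536),
          bytes_per_block(per_token_layer),
          max_full_sequences(KV_BUDGET, 65536, per_token_layer))
```

Smaller blocks per slot mean the same pool admits more sequences before the scheduler has to preempt, which is the effect to look for in your own traces.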

Architecture decision table

| Pattern | What is cached | Typical cache vs MHA | Compute cost | When to choose |
| --- | --- | --- | --- | --- |
| MHA | Full K, V per head | 1× (baseline) | Lowest per attention step | Research, short context, maximum quality |
| GQA / MQA | Reduced head-count K, V | g/h (GQA) down to 1/h (MQA) | Low | Llama/Mistral-class serving, moderate context |
| MLA | Low-rank latent c | d_c / (2 × h × d_head) | Higher (up-project each step) | 64k+ context, memory-bound decode, DeepSeek-style stacks |
| KV quantization | Quantized K, V or latent | 2×–4× smaller | Dequant overhead | Add-on after GQA/MLA; validate long-context recall |
| Sliding window + GQA | Bounded window of K, V | Fixed cap | Low | Streaming chat where old tokens are disposable |

Common pitfalls

  • Applying RoPE to raw latents: breaks position extrapolation; use decoupled RoPE or re-derive keys before rotation.
  • Assuming MLA is a drop-in for any MHA checkpoint: without continued pretrain, perplexity and tool-calling accuracy can crater.
  • Ignoring up-project FLOPs in capacity planning: MLA saves HBM traffic but adds matmuls; when decode is already compute-bound, or when small batches are dominated by weight reads rather than KV reads, the net speedup may be modest.
  • Confusing MLA with MoE: MoE sparsifies feed-forward layers; MLA compresses the attention cache. DeepSeek-V3 uses both.
  • Mismatched serving kernels: verify that your vLLM/SGLang version supports the MLA layout and decoupled RoPE before production cutover.
  • Over-aggressive d_c: shrinking the latent too far hurts multi-hop reasoning in long RAG threads; sweep d_c against your eval suite.
  • Skipping 64k evals: quality loss often appears only past 16k tokens, where cache pressure used to truncate context silently.
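
For the up-project FLOPs pitfall, a rough count can show why the way keys and values are reconstructed matters for capacity planning. The sketch below is an estimate under stated assumptions, not a kernel model: it counts a multiply-add as 2 FLOPs, ignores softmax, any RoPE slice, and the per-step query and output projections, and compares naively rebuilding K and V for every cached token against scoring directly in latent space with the up-projections folded into the query and output weights. The function names and default shapes are illustrative.

```python
def naive_flops_per_token_layer(d_c, kv_heads, q_heads, d_head):
    """Rebuild K and V from the latent, then run ordinary attention."""
    up_project = 2 * (2 * d_c * kv_heads * d_head)  # K matmul + V matmul
    attend = 2 * (2 * q_heads * d_head)             # scores + weighted sum
    return up_project + attend

def absorbed_flops_per_token_layer(d_c, q_heads):
    """Score and aggregate in latent space (up-projections folded away)."""
    return 2 * (2 * q_heads * d_c)

def decode_step_flops(per_token_layer, tokens, layers):
    """Attention FLOPs for one decode step over the whole cached context."""
    return per_token_layer * tokens * layers

naive = naive_flops_per_token_layer(d_c=512, kv_heads=8, q_heads=64, d_head=128)
absorbed = absorbed_flops_per_token_layer(d_c=512, q_heads=64)

for name, per_token in (("naive", naive), ("absorbed", absorbed)):
    total = decode_step_flops(per_token, tokens=65536, layers=80)
    print(f"{name}: {per_token} FLOPs/token/layer, {total / 1e12:.2f} TFLOPs/step")
```

Comparing these totals against the bytes saved per step gives a first-order view of whether a given batch shape ends up compute-bound or bandwidth-bound after migration.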

Production checklist

  • Measure KV bytes per token at target context for MHA, GQA, and MLA candidates.
  • Confirm inference engine MLA support (kernel version, RoPE split layout).
  • If retrofitting, budget continued pretrain and hold quality gates on MMLU plus domain RAG.
  • Benchmark p95/p99 TTFT and tokens/sec at 32k and 64k under concurrent load, not single-stream only.
  • Profile HBM read bandwidth before and after MLA on representative decode batches.
  • Pair with PagedAttention and continuous batching; MLA reduces block bytes per slot.
  • Validate that speculative decoding acceptance rates are unchanged after MLA migration.
  • Document d_c and the RoPE decoupling layout in model cards for fine-tuners.
  • Run long-context needle-in-haystack tests after any latent-dimension change.
  • Monitor preemption and cache fragmentation under multi-tenant 64k load.
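
The needle-in-a-haystack item in the checklist can be prototyped with a few lines of plain Python before wiring it to a real endpoint. This is a minimal sketch: `stub_generate` is a hypothetical stand-in for your serving client, the filler text and passphrase are invented, and sentence counts are used as a crude proxy for token length.

```python
import re

FILLER = "The queue was quiet that hour."

def build_prompt(needle, total_sentences, depth):
    """Place needle at a fractional depth (0.0 = start, 1.0 = end) of filler."""
    sentences = [FILLER] * total_sentences
    sentences.insert(int(depth * total_sentences), needle)
    return " ".join(sentences) + "\n\nWhat is the incident passphrase?"

def needle_recall(generate, secret, total_sentences, depths):
    """Fraction of depths at which the model's answer contains the secret."""
    needle = f"The incident passphrase is {secret}."
    hits = sum(
        secret in generate(build_prompt(needle, total_sentences, depth))
        for depth in depths
    )
    return hits / len(depths)

def stub_generate(prompt):
    """Hypothetical model call; replace with a request to your inference server."""
    match = re.search(r"passphrase is ([\w-]+)\.", prompt)
    return match.group(1) if match else "unknown"

depths = [0.0, 0.25, 0.5, 0.75, 1.0]
print(needle_recall(stub_generate, "harbor-7421", total_sentences=2000, depths=depths))
```

Running the same harness against a client that silently truncates the prompt makes the failure mode visible as recall that drops for shallow depths, so it is worth repeating after every change to d_c or max_model_len.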

Key takeaways

  • MLA caches a low-rank latent per token instead of full K and V tensors, shrinking KV memory along the representation-width axis.
  • GQA reduces head count; MLA compresses each token's stored state — they solve different bottlenecks and can be combined conceptually in new architectures.
  • DeepSeek-V2/V3 popularized MLA with decoupled RoPE; retrofitting older checkpoints needs init plus continued pretrain.
  • Decode gains come from lower HBM traffic; budget extra up-project FLOPs when sizing GPUs.
  • For 64k+ RAG and multi-tenant serving, MLA can be the difference between fitting context and silently truncating prompts.

Related reading