Group Query Attention (GQA) explained
Harbor Support's on-prem 70B assistant served twenty concurrent RAG sessions with 32k-token context windows. Prefill fit in VRAM, but decode crashed once the accumulated KV cache tensors exceeded GPU memory. The weights had not grown; every attention head was storing its own key and value vectors for every token. The team could not afford to halve context or buy another H100 row. Group Query Attention (GQA) fixed the bottleneck: instead of 64 independent K/V head pairs (multi-head attention), the model uses 64 query heads grouped onto eight shared K/V heads. Cache bytes per token dropped by roughly 8×, with modest quality loss after a short continued-pretrain conversion. GQA sits between full multi-head attention (MHA) and extreme multi-query attention (MQA), and it is the design pattern behind Llama 2 70B, Llama 3, Mistral, and many production inference stacks. This guide covers how shared K/V heads work, cache arithmetic, uptraining from MHA, pairing with Flash Attention and RoPE, the Harbor decode refactor, an architecture decision table, pitfalls, and a production checklist alongside our transformer architecture guide.
Multi-head, multi-query, and group-query attention
Standard scaled dot-product attention projects hidden states into query Q, key K, and value V tensors. Multi-head attention (MHA) runs h independent attention operations in parallel, each with its own W_Q, W_K, and W_V projections. During autoregressive decode, every past token's keys and values for every head are cached. Memory scales as 2 × h × d_head × seq_len per layer (times batch and layer count).
Multi-query attention (MQA) shares a single K/V head across
all query heads. Cache memory drops by a factor of h because only
one K and one V vector exist per token per layer. Throughput on long contexts
improves dramatically, but some models lose fluency or factual recall when
pushed to a single shared K/V pair.
Group Query Attention (GQA) is the compromise: query heads are partitioned into g groups (often 8), and each group shares one K/V head pair. With h = 64 and g = 8, you get eight K/V heads instead of 64: an 8× cache reduction vs MHA, while keeping eight times as many K/V heads as MQA's single pair. Llama 2 70B uses GQA with 8 groups; Mistral 7B uses GQA with 8 KV heads and 32 query heads. The attention math per head is unchanged; only the projection layout and cache layout differ.
Broadcasting K/V across query heads
Implementation-wise, GQA computes Q with shape
[batch, seq, h, d] as usual. K and V are
computed with only g heads, then broadcast (or
repeated via reshape) so each group of h/g query heads attends
against the same K/V slice. Frameworks like Hugging Face Transformers expose
this as num_key_value_heads vs num_attention_heads.
Inference engines (vLLM, TensorRT-LLM) pack the smaller K/V tensors into
paged cache blocks — the win is physical bytes stored, not FLOPs on the
query side.
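As a minimal sketch of the broadcast step, the NumPy function below repeats each of the g K/V heads h/g times and then runs ordinary causal attention. The name `grouped_query_attention` and the `[batch, seq, heads, d]` input layout are assumptions for illustration only; optimized kernels can index the shared K/V head directly instead of copying it.

```python
import numpy as np

def grouped_query_attention(q, keys, values):
    """Causal GQA. q: [batch, seq, h, d]; keys/values: [batch, seq, g, d]."""
    b, n, h, d = q.shape
    g = keys.shape[2]
    if h % g != 0:
        raise ValueError("number of query heads must be divisible by KV heads")
    group_size = h // g
    # Query head i reads K/V head i // group_size.
    keys = np.repeat(keys, group_size, axis=2)
    values = np.repeat(values, group_size, axis=2)
    qh = q.transpose(0, 2, 1, 3)        # [b, h, n, d]
    kh = keys.transpose(0, 2, 3, 1)     # [b, h, d, n]
    vh = values.transpose(0, 2, 1, 3)   # [b, h, n, d]
    scores = qh @ kh / np.sqrt(d)       # [b, h, n, n]
    # Mask out future positions, then apply a numerically stable softmax.
    future = np.triu(np.ones((n, n), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return (weights @ vh).transpose(0, 2, 1, 3)  # back to [b, n, h, d]

rng = np.random.default_rng(0)
q = rng.standard_normal((2, 5, 8, 4))   # h = 8 query heads
k = rng.standard_normal((2, 5, 2, 4))   # g = 2 shared K/V heads
v = rng.standard_normal((2, 5, 2, 4))
print(grouped_query_attention(q, k, v).shape)  # (2, 5, 8, 4)
```

Only the g-head `keys` and `values` tensors would need to live in the cache; the repeated copies here exist just for the duration of the call.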
KV cache arithmetic (why GQA matters at decode)
Model weights are fixed at load time; the KV cache grows linearly with the number of concurrent sequences (batch size) and with context length. For a decoder-only stack with L layers, hidden size d, h heads, head dim d_h = d/h, sequence length n, and FP16 (2 bytes per element):
KV_bytes ≈ 2 × L × n × h × d_h × 2
for MHA (the leading factor 2 covers K and V; the trailing 2 is bytes per element).
Replace h with g (number of KV heads) for GQA:
KV_bytes_GQA ≈ KV_bytes_MHA × (g / h)
Example: 80 layers, n = 32768, h = 64, d_h = 128, FP16. The MHA cache is 80 × 32768 × 64 × 128 × 4 bytes ≈ 86 GB per sequence, more than a single 80 GB GPU holds even before weights are loaded. GQA with g = 8 cuts that to ≈ 10.7 GB, which is the difference between OOM and serving many sessions in parallel. Prefill still computes full attention over the prompt; savings accrue on every decode step as cached K/V accumulate. This is why serving benchmarks quote “tokens/sec at 32k context” separately from weight memory.
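The formula is easy to check numerically. The sketch below plugs in this section's example dimensions; the helper name `kv_cache_bytes` and its parameter names are illustrative, not part of any library.

```python
def kv_cache_bytes(layers, seq_len, kv_heads, head_dim, bytes_per_elem=2, batch=1):
    """Bytes of cached K and V: the leading 2 counts the K and V tensors."""
    return 2 * layers * seq_len * kv_heads * head_dim * bytes_per_elem * batch

# 80 layers, 32k context, head dim 128, FP16 (2 bytes per element).
mha = kv_cache_bytes(layers=80, seq_len=32768, kv_heads=64, head_dim=128)
gqa = kv_cache_bytes(layers=80, seq_len=32768, kv_heads=8, head_dim=128)

print(f"MHA: {mha / 1e9:.1f} GB, GQA: {gqa / 1e9:.1f} GB, ratio: {mha // gqa}x")
# MHA: 85.9 GB, GQA: 10.7 GB, ratio: 8x
```

Passing `batch` greater than 1 scales the result linearly, which is the quantity that matters for multi-tenant capacity planning.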
GQA does not replace Flash Attention
Flash Attention optimizes how attention is computed (tiling, SRAM, fused softmax). GQA optimizes what you store between steps. They stack: FlashAttention kernels accept grouped K/V layouts; paged KV managers allocate fewer bytes per slot. Do not expect GQA to fix quadratic prefill cost — long prompts still benefit from chunked prefill, sliding windows, or context pruning.
Converting MHA checkpoints to GQA
You rarely train GQA from scratch unless the base architecture already specifies
num_key_value_heads. Common path: start from an MHA checkpoint and
uptrain (continued pretrain or distillation) after merging
K/V head weights.
- Mean merge: average K/V weights of heads in each group as initialization; fastest, may need more recovery steps.
- Strided selection: keep every h/g-th head's K/V weights; simple for power-of-two group counts.
- Distillation: teacher MHA model supervises student GQA on a small high-quality corpus; best quality, higher cost.
Query projections stay at full head count; only W_K and
W_V shrink. After conversion, run perplexity and
downstream evals (MMLU slices, your RAG hit rate) before shipping. Harbor
used 200B tokens of domain support logs plus public code/text for two epochs
after mean-merge init — MMLU dropped 0.4 points, but internal ticket
resolution accuracy recovered within noise.
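The two weight-based initializations can be sketched in a few lines of NumPy. This assumes a K or V projection stored as a `[h * d_head, d_model]` matrix with rows grouped head by head (the usual layout of a linear layer's weight, but verify it for your checkpoint); `merge_kv_heads` is a hypothetical helper, not a library function.

```python
import numpy as np

def merge_kv_heads(w, h, g, method="mean"):
    """Shrink a K or V projection from h heads to g heads.

    w: [h * d_head, d_model], rows assumed to be grouped by head.
    """
    if h % g != 0 or w.shape[0] % h != 0:
        raise ValueError("h must be divisible by g and rows by h")
    d_head = w.shape[0] // h
    per_head = w.reshape(h, d_head, w.shape[1])
    if method == "mean":
        # Average the h/g consecutive heads that form each group.
        merged = per_head.reshape(g, h // g, d_head, -1).mean(axis=1)
    elif method == "strided":
        # Keep the first head of each group and drop the rest.
        merged = per_head[:: h // g]
    else:
        raise ValueError(f"unknown method: {method}")
    return merged.reshape(g * d_head, w.shape[1])

w_k = np.random.default_rng(0).standard_normal((64 * 128, 8192))
print(merge_kv_heads(w_k, h=64, g=8).shape)  # (1024, 8192)
```

Either result is only a starting point; the recovery training and evals described above still apply.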
Harbor Support LLM decode refactor (worked example)
Problem. Harbor Support served a 70B MHA model on four A100 80GB GPUs with tensor parallelism. At 32k context and batch 4, vLLM reported KV cache usage exceeding available blocks after ~18k tokens average per session — tail latency spiked from eviction and preemption.
Change. Converted to GQA with g = 8 (matching
Llama-2-70B layout), mean-merged K/V init, 48 hours continued pretrain on
Harbor's curated corpus. Serving config: same TP=4, FP16 weights,
PagedAttention block size 16, max model len 32768.
Results. KV bytes per token fell from ~2.6 MB to ~0.33 MB summed over all 80 layers in FP16 (8× on the K/V side). Sustainable concurrent sessions rose from ~12 to ~28 at p95 TTFT under 2s. Per-token decode latency improved ~15% because each step reads a smaller cache and so uses less memory bandwidth. Quality: internal RAG exact-match on gold answers 91.2% vs 91.6% pre-conversion (within eval variance).
Lesson. GQA is a serving architecture lever, not a training novelty. Measure cache bytes and concurrent capacity before buying hardware.
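A rough capacity estimate shows why the cache, not the weights, set the ceiling. The sketch below is a back-of-envelope calculation under loud assumptions: roughly 140 GB of FP16 weights for a 70B model, four 80 GB GPUs, the Llama-2-70B layout (80 layers, head dim 128), and no allowance for activations, fragmentation, or runtime overhead. `max_cached_tokens` is a hypothetical helper.

```python
def max_cached_tokens(budget_bytes, layers, kv_heads, head_dim, bytes_per_elem=2):
    """Tokens of K/V cache that fit in budget_bytes, summed over all layers."""
    per_token = 2 * layers * kv_heads * head_dim * bytes_per_elem
    return budget_bytes // per_token

# Assumed budget: 4 x 80 GB of VRAM minus ~140 GB of FP16 weights.
budget = (4 * 80 - 140) * 10**9

mha_tokens = max_cached_tokens(budget, layers=80, kv_heads=64, head_dim=128)
gqa_tokens = max_cached_tokens(budget, layers=80, kv_heads=8, head_dim=128)
print(mha_tokens, gqa_tokens)  # 68664 549316
```

Under these assumptions, four sessions averaging ~18k tokens (about 72k tokens in total) already exceed the MHA budget, which is in line with the block exhaustion described in the problem statement. Real capacity will be lower once the ignored overheads are counted.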
Architecture decision table
| Pattern | KV heads | Cache vs MHA | Typical quality | When to choose |
|---|---|---|---|---|
| MHA | h | 1× (baseline) | Highest | Research, short context, quality-first fine-tunes |
| GQA | g (e.g. 8) | g/h | Near-MHA after uptrain | Production LLMs, long context, multi-tenant serving |
| MQA | 1 | 1/h | Can degrade on hard tasks | Extreme throughput, edge devices, smaller models |
| Low-rank KV compression | Varies | Learned | Task-dependent | Experimental; when you cannot change arch |
| Sliding-window attention | Full or GQA | Bounded window | Local context only | Very long streams where distant past is disposable |
Common pitfalls
- Assuming GQA fixes prefill OOM: quadratic attention during prefill still scales with sequence length; GQA mainly helps decode cache growth.
- Skipping uptrain after head merge: naive weight averaging without recovery training often hurts reasoning benchmarks.
- Mismatched config in serving: num_key_value_heads must match the checkpoint; silent broadcast bugs produce garbage logits.
- Ignoring query-head count in FLOPs estimates: Q projections and output projections still scale with h; GQA is not a full 8× speedup on every op.
- Confusing GQA with MoE: MoE sparsifies the FFN; GQA shrinks K/V storage by sharing heads. The two techniques are orthogonal.
- Wrong group count: h must be evenly divisible by g; odd splits require architecture changes.
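The config-mismatch and group-count pitfalls can be caught before serving with a small sanity check. The sketch below assumes head dim equals hidden size divided by the number of query heads and that the K projection weight is stored as `[out_features, in_features]`; both hold for Llama-style checkpoints but are assumptions here, and `check_gqa_config` is a hypothetical helper.

```python
def check_gqa_config(num_attention_heads, num_key_value_heads, hidden_size, k_proj_shape):
    """Raise ValueError if head counts and the K projection shape disagree."""
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    head_dim = hidden_size // num_attention_heads
    expected = (num_key_value_heads * head_dim, hidden_size)
    if tuple(k_proj_shape) != expected:
        raise ValueError(f"K projection is {tuple(k_proj_shape)}, config implies {expected}")
    return head_dim

# Llama-2-70B-like layout: 64 query heads, 8 KV heads, hidden size 8192.
print(check_gqa_config(64, 8, 8192, (1024, 8192)))  # 128
```

Running the same check against the V projection, and failing the deploy when it raises, turns a silent garbage-logits bug into a loud startup error.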
Production checklist
- Profile KV cache bytes per token at target context before scaling replicas.
- Confirm num_attention_heads and num_key_value_heads in config match weights.
- Validate inference engine supports GQA layout (vLLM, TGI, TensorRT-LLM version matrix).
- If converting MHA to GQA, plan uptrain budget and quality eval gates.
- Benchmark concurrent sessions at p95/p99 latency, not single-stream tokens/sec.
- Pair GQA with FP8/INT8 KV cache quantization only after accuracy tests.
- Keep RoPE and position scaling settings unchanged unless context extension requires NTK/YaRN.
- Monitor cache fragmentation and preemption under bursty multi-tenant load.
- Document group count in model cards for downstream fine-tuners.
- Re-run RAG retrieval evals — attention pattern changes can shift long-context recall.
Key takeaways
- GQA shares K/V heads across groups of query heads, shrinking KV cache by factor h/g vs full MHA.
- MQA is the extreme case (g = 1); GQA trades a little memory for better quality than MQA.
- Decode and multi-tenant serving benefit most; prefill still needs Flash Attention or context limits.
- Converting MHA checkpoints requires merge init plus short continued pretrain or distillation.
- Llama, Mistral, and many open weights already ship GQA — configure serving to exploit it.
Related reading
- LLM KV cache explained — what gets stored per token during autoregressive decode
- Flash Attention explained — IO-aware attention kernels that complement smaller caches
- Transformer architecture explained — where multi-head attention sits in the full stack
- Rotary position embeddings (RoPE) explained — positional encoding paired with modern GQA models