
Grouped-Query Attention (GQA) explained

Harbor Support's on-prem 70B assistant served twenty concurrent RAG sessions with 32k-token context windows. Prefill fit in VRAM, but decode crashed once the accumulated KV cache exceeded GPU memory. The weights had not grown; the problem was that every attention head stored its own key and value vectors for every token. The team could not afford to halve context or buy another H100 row. Grouped-Query Attention (GQA) fixed the bottleneck: instead of 64 independent K/V head pairs (multi-head attention), the model uses 64 query heads grouped onto eight shared K/V heads. Cache bytes per token dropped 8× (from 64 K/V heads to 8), with modest quality loss after a short continued-pretrain conversion. GQA sits between full multi-head attention (MHA) and extreme multi-query attention (MQA), and it is the design used by Llama 2 70B, Llama 3, Mistral 7B, and many production inference stacks. This guide covers how shared K/V heads work, cache arithmetic, uptraining from MHA, pairing with Flash Attention and RoPE, the Harbor decode refactor, an architecture decision table, pitfalls, and a production checklist. It is meant to be read alongside our transformer architecture guide.

Multi-head, multi-query, and grouped-query attention

Standard scaled dot-product attention projects hidden states into query (Q), key (K), and value (V) tensors. Multi-head attention (MHA) runs h independent attention operations in parallel, each with its own $W_Q$, $W_K$, and $W_V$ projections. During autoregressive decode, the keys and values of every past token are cached for every head, so each layer stores $2 \cdot h \cdot d_h \cdot n$ values per sequence (K and V, head dimension $d_h$, sequence length $n$). Multiply by layer count, batch size, and bytes per element to get total cache memory.

Multi-query attention (MQA) shares a single K/V head across all query heads. Cache memory drops by a factor of h because only one K and one V vector exist per token per layer. Throughput on long contexts improves dramatically, but some models lose fluency or factual recall when pushed to a single shared K/V pair.

Grouped-Query Attention (GQA) is the compromise: the h query heads are partitioned into g groups (often g = 8), and each group shares one K/V head pair. With h = 64 and g = 8, you get eight K/V heads instead of 64: an 8× cache reduction vs MHA, while keeping eight times as many K/V heads as MQA. Llama 2 70B uses GQA with 64 query heads and 8 KV heads; Mistral 7B uses 32 query heads and 8 KV heads. The attention math per query head is unchanged; only the projection layout and cache layout differ.
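To make the MHA/GQA/MQA spectrum concrete, the sketch below maps each query head to the K/V head it reads. It assumes contiguous grouping (query head j uses K/V head j // (h/g)), which is the convention followed by the repeat_kv helper in Hugging Face's Llama modeling code; the function name kv_head_for_query is hypothetical. Setting g = h gives MHA and g = 1 gives MQA.

```python
def kv_head_for_query(query_head: int, num_query_heads: int, num_kv_heads: int) -> int:
    """Return the K/V head index a query head attends with.

    Assumes contiguous grouping: query heads [0, h/g) share K/V head 0,
    [h/g, 2h/g) share K/V head 1, and so on.
    """
    if num_query_heads % num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    group_size = num_query_heads // num_kv_heads  # query heads per shared K/V head
    return query_head // group_size


h = 64
for name, g in [("MHA", 64), ("GQA", 8), ("MQA", 1)]:
    mapping = [kv_head_for_query(j, h, g) for j in range(h)]
    print(name, "distinct K/V heads:", len(set(mapping)), "first 10:", mapping[:10])
```

With g = 8, query heads 0 to 7 all read K/V head 0, heads 8 to 15 read K/V head 1, and so on.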

Broadcasting K/V across query heads

Implementation-wise, GQA computes Q with shape [batch, seq, h, d_h] as usual. K and V are computed with only g heads, then broadcast (or repeated via reshape) so that each group of h/g query heads attends against the same K/V slice. Hugging Face Transformers exposes this through the num_key_value_heads and num_attention_heads config fields. Inference engines such as vLLM and TensorRT-LLM store the smaller K/V tensors in paged cache blocks; the win is fewer physical bytes stored and read per decode step, not fewer FLOPs on the query side.
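The broadcast step can be sketched in NumPy as a minimal reference, not an optimized kernel. It assumes the [batch, seq, heads, d_h] layout described above, contiguous grouping, and a causal mask; gqa_attention is a hypothetical name. np.repeat materializes the copies for readability, whereas production kernels index the shared K/V head directly so the smaller cache is all that is ever read.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along one axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def gqa_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Causal grouped-query attention.

    q: [batch, seq, h, d_h]; k, v: [batch, seq, g, d_h] with h divisible by g.
    Returns an array shaped like q.
    """
    _, n, h, d_h = q.shape
    g = k.shape[2]
    if h % g:
        raise ValueError("query heads must split evenly into K/V groups")
    # Repeat each K/V head h/g times so query head j reads K/V head j // (h/g).
    k = np.repeat(k, h // g, axis=2)
    v = np.repeat(v, h // g, axis=2)
    # Put heads before sequence: [batch, heads, seq, d_h].
    q, k, v = (t.transpose(0, 2, 1, 3) for t in (q, k, v))
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_h)  # [batch, h, n, n]
    future = np.triu(np.ones((n, n), dtype=bool), k=1)    # True above the diagonal
    scores = np.where(future, -np.inf, scores)            # causal mask
    out = softmax(scores) @ v                              # [batch, h, n, d_h]
    return out.transpose(0, 2, 1, 3)


rng = np.random.default_rng(0)
q = rng.standard_normal((1, 5, 8, 4))  # h = 8 query heads, d_h = 4
k = rng.standard_normal((1, 5, 2, 4))  # g = 2 shared K/V heads
v = rng.standard_normal((1, 5, 2, 4))
out = gqa_attention(q, k, v)
print(out.shape)  # (1, 5, 8, 4)
```

Feeding the same function K/V tensors that already have h heads (each shared head duplicated) yields ordinary MHA, which is a handy correctness check when porting a kernel.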

KV cache arithmetic (why GQA matters at decode)

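Model weights are fixed at load time; the KV cache grows linearly with concurrent users, context length, and batch size. For a decoder-only stack with $L$ layers, hidden size $d$, $h$ heads, head dimension $d_h = d/h$, sequence length $n$, and FP16 storage (2 bytes per element), the MHA cache for one sequence is:

$$\text{KV}_{\text{MHA}} \approx 2 \cdot L \cdot n \cdot h \cdot d_h \cdot 2 \ \text{bytes}$$

where the leading 2 counts K and V and the trailing 2 is bytes per FP16 element. For GQA, replace $h$ with $g$, the number of KV heads:

$$\text{KV}_{\text{GQA}} \approx 2 \cdot L \cdot n \cdot g \cdot d_h \cdot 2 \ \text{bytes} = \text{KV}_{\text{MHA}} \cdot \frac{g}{h}$$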

Example: with $L = 80$, $n = 32768$, $h = 64$, $d_h = 128$ and FP16, the MHA cache is 80 × 32768 × 64 × 128 × 4 bytes ≈ 85.9 GB (80 GiB) per sequence, more than one 80 GB GPU can hold. GQA with $g = 8$ cuts that to ≈ 10.7 GB, which turns an immediate OOM into several full-length sessions sharing the KV budget of a multi-GPU node. Prefill still computes full attention over the prompt; the savings accrue on every decode step as cached K/V accumulate and are read back. This is why serving benchmarks quote “tokens/sec at 32k context” separately from weight memory.
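Because this arithmetic is easy to get wrong by a large factor, a small helper is worth keeping next to capacity plans. The sketch below simply evaluates the formula above; the function name and its defaults (FP16, batch 1) are assumptions, and it ignores allocator overhead such as paged-block padding.

```python
def kv_cache_bytes(num_layers: int, seq_len: int, num_kv_heads: int,
                   head_dim: int, bytes_per_elem: int = 2, batch: int = 1) -> int:
    """KV cache size: 2 (K and V) x layers x tokens x KV heads x head_dim x bytes x batch."""
    return 2 * num_layers * seq_len * num_kv_heads * head_dim * bytes_per_elem * batch


mha = kv_cache_bytes(num_layers=80, seq_len=32768, num_kv_heads=64, head_dim=128)
gqa = kv_cache_bytes(num_layers=80, seq_len=32768, num_kv_heads=8, head_dim=128)
print(f"MHA: {mha / 1e9:.1f} GB, GQA (g=8): {gqa / 1e9:.1f} GB, ratio {mha // gqa}x")
# MHA: 85.9 GB, GQA (g=8): 10.7 GB, ratio 8x
```

Passing seq_len=1 gives bytes per token (about 2.6 MB for this MHA layout, about 0.33 MB with g = 8), which is the number to compare against free memory per GPU.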

GQA does not replace Flash Attention

Flash Attention optimizes how attention is computed (tiling, SRAM, fused softmax). GQA optimizes what you store between steps. They stack: FlashAttention kernels accept grouped K/V layouts; paged KV managers allocate fewer bytes per slot. Do not expect GQA to fix quadratic prefill cost — long prompts still benefit from chunked prefill, sliding windows, or context pruning.
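The split between compute and storage shows up in a rough cost model: attention-score FLOPs during prefill scale with the number of query heads, while cache bytes scale with the number of K/V heads. The sketch below is a back-of-envelope estimate under stated assumptions (dense QK^T and softmax(QK^T)V only, no causal-mask savings, no Q/K/V/output projections); prefill_attention_cost is a hypothetical name.

```python
def prefill_attention_cost(num_layers: int, seq_len: int, num_query_heads: int,
                           num_kv_heads: int, head_dim: int, bytes_per_elem: int = 2):
    """Return (attention-score FLOPs, KV cache bytes) for one sequence's prefill.

    Each of the two n x n x d_h matmuls costs about 2 * n^2 * d_h FLOPs per query
    head per layer. The FLOP term uses num_query_heads; only the cache term uses
    num_kv_heads.
    """
    flops = num_layers * num_query_heads * 2 * (2 * seq_len**2 * head_dim)
    cache = 2 * num_layers * seq_len * num_kv_heads * head_dim * bytes_per_elem
    return flops, cache


for name, g in [("MHA", 64), ("GQA", 8)]:
    flops, cache = prefill_attention_cost(80, 32768, 64, g, 128)
    print(f"{name}: {flops / 1e15:.2f} PFLOPs of attention math, {cache / 1e9:.1f} GB of KV cache")
# MHA: 2.81 PFLOPs of attention math, 85.9 GB of KV cache
# GQA: 2.81 PFLOPs of attention math, 10.7 GB of KV cache
```

The identical FLOP counts are the point: shrinking K/V heads does nothing for the quadratic term, which is why Flash Attention, chunked prefill, or context limits are still needed for long prompts.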

Converting MHA checkpoints to GQA

You rarely train GQA from scratch unless the base architecture already specifies num_key_value_heads. Common path: start from an MHA checkpoint and uptrain (continued pretrain or distillation) after merging K/V head weights.

  • Mean merge — average K/V weights of heads in each group as initialization; fastest, may need more recovery steps.
  • Strided selection — keep every h/g-th head's K/V weights; simple for power-of-two group counts.
  • Distillation — teacher MHA model supervises student GQA on a small high-quality corpus; best quality, higher cost.
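The first two initialization options differ only in how each group's K/V heads collapse to one. The sketch below applies mean merge or strided selection to a single K or V projection matrix. It assumes a PyTorch nn.Linear-style weight layout of [num_heads × head_dim, hidden] with rows grouped by head, and contiguous grouping of query heads; merge_kv_heads is a hypothetical name, biases are not handled, and the output is only a starting point for uptraining.

```python
import numpy as np


def merge_kv_heads(w: np.ndarray, num_heads: int, num_kv_heads: int,
                   head_dim: int, method: str = "mean") -> np.ndarray:
    """Shrink a K or V projection from num_heads to num_kv_heads heads.

    w: [num_heads * head_dim, hidden], rows grouped by head.
    After merging, query head j shares K/V head j // (num_heads // num_kv_heads).
    """
    if num_heads % num_kv_heads:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    group = num_heads // num_kv_heads
    hidden = w.shape[1]
    # [num_kv_heads, heads_per_group, head_dim, hidden]
    per_group = w.reshape(num_kv_heads, group, head_dim, hidden)
    if method == "mean":
        merged = per_group.mean(axis=1)  # average the heads in each group
    elif method == "strided":
        merged = per_group[:, 0]         # keep heads 0, h/g, 2h/g, ...
    else:
        raise ValueError(f"unknown method: {method}")
    return merged.reshape(num_kv_heads * head_dim, hidden)


rng = np.random.default_rng(0)
w_k = rng.standard_normal((8 * 4, 16))  # 8 heads, head_dim 4, hidden 16
w_k_gqa = merge_kv_heads(w_k, num_heads=8, num_kv_heads=2, head_dim=4)
print(w_k_gqa.shape)  # (8, 16)
```

The same call is applied to W_V; W_Q is left untouched, since query heads stay at full count.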

Query projections stay at full head count; only $W_K$ and $W_V$ shrink. After conversion, run perplexity and downstream evals (MMLU slices, your RAG hit rate) before shipping. Harbor used 200B tokens of domain support logs plus public code/text for two epochs after mean-merge init; MMLU dropped 0.4 points, but internal ticket-resolution accuracy recovered to within noise.

Harbor Support LLM decode refactor (worked example)

Problem. Harbor Support served a 70B MHA model on four A100 80GB GPUs with tensor parallelism. At 32k context and batch 4, vLLM reported KV cache usage exceeding available blocks after ~18k tokens average per session — tail latency spiked from eviction and preemption.

Change. Converted to GQA with g = 8 (matching the Llama-2-70B layout), mean-merged K/V init, 48 hours of continued pretraining on Harbor's curated corpus. Serving config: same TP=4, FP16 weights, PagedAttention block size 16, max model len 32768.

Results. KV cache per token fell from ~2.6 MB to ~0.33 MB across all 80 layers in FP16 (8× on the K/V side). Sustainable concurrent sessions rose from ~12 to ~28 at p95 TTFT under 2s. Per-token decode latency improved ~15% because each decode step reads a smaller cache from memory. Quality: internal RAG exact-match on gold answers was 91.2% vs 91.6% pre-conversion (within eval variance).

Lesson. GQA is a serving architecture lever, not a training novelty. Measure cache bytes and concurrent capacity before buying hardware.
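A back-of-envelope capacity check in the spirit of this lesson subtracts weights from usable GPU memory and divides by KV bytes per token. Every input below is an assumption: four 80 GB GPUs, roughly 140 GB of FP16 weights for 70B parameters, a 0.9 memory-utilization fraction (vLLM's default for gpu_memory_utilization), and the Llama-2-70B-style layout used above. kv_token_capacity is a hypothetical name.

```python
GB = 1e9


def kv_token_capacity(num_gpus: int, gpu_mem_gb: float, weight_gb: float,
                      kv_bytes_per_token: int, utilization: float = 0.9) -> int:
    """Rough count of cached tokens that fit after loading weights.

    Ignores activation workspace, CUDA graphs, and paged-block rounding,
    so treat the result as an upper bound.
    """
    budget = num_gpus * gpu_mem_gb * GB * utilization - weight_gb * GB
    return int(budget // kv_bytes_per_token)


per_token_mha = 2 * 80 * 64 * 128 * 2  # K+V, 80 layers, 64 KV heads, d_h 128, FP16
per_token_gqa = 2 * 80 * 8 * 128 * 2   # same layout with 8 KV heads

for name, per_tok in [("MHA", per_token_mha), ("GQA", per_token_gqa)]:
    tokens = kv_token_capacity(num_gpus=4, gpu_mem_gb=80, weight_gb=140,
                               kv_bytes_per_token=per_tok)
    print(f"{name}: ~{tokens:,} cached tokens, ~{tokens // 32768} full 32k sessions")
```

Compare the estimate with the KV block count your serving engine reports at startup; a large gap usually means activation or graph-capture overhead that also has to be budgeted.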

Architecture decision table

| Pattern | KV heads | Cache vs MHA | Typical quality | When to choose |
| --- | --- | --- | --- | --- |
| MHA | h | 1× (baseline) | Highest | Research, short context, quality-first fine-tunes |
| GQA | g (e.g. 8) | g/h | Near-MHA after uptraining | Production LLMs, long context, multi-tenant serving |
| MQA | 1 | 1/h | Can degrade on hard tasks | Extreme throughput, edge devices, smaller models |
| Low-rank KV compression | Varies | Depends on learned rank | Task-dependent | Experimental; when you cannot change the architecture |
| Sliding-window attention | Full or GQA | Bounded by window size | Local context only | Very long streams where the distant past is disposable |

Common pitfalls

  • Assuming GQA fixes prefill OOM — quadratic attention during prefill still scales with sequence length; GQA mainly helps decode cache growth.
  • Skipping uptrain after head merge — naive weight averaging without recovery training often hurts reasoning benchmarks.
  • Mismatched config in serving: num_key_value_heads must match the checkpoint; silent broadcast bugs produce garbage logits.
  • Ignoring query-head count in FLOPs estimates — Q projections and output projections still scale with h; GQA is not a full 8× speedup on every op.
  • Confusing GQA with MoE — MoE sparsifies FFN; GQA sparsifies K/V storage. Orthogonal tricks.
  • Wrong group count: h must be divisible by g; uneven splits require architecture changes.
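Both config pitfalls can be caught before serving with a small load-time check. The sketch below is hypothetical (AttentionConfig and check_gqa_config are illustrative names); its fields mirror the Hugging Face config keys hidden_size, num_attention_heads, and num_key_value_heads, and it assumes head_dim = hidden_size / num_attention_heads, which some models override with an explicit head dimension.

```python
from dataclasses import dataclass


@dataclass
class AttentionConfig:
    hidden_size: int
    num_attention_heads: int
    num_key_value_heads: int


def check_gqa_config(cfg: AttentionConfig, k_proj_rows: int) -> int:
    """Validate head counts against a K projection weight; return head_dim.

    k_proj_rows is the first dimension of a [num_kv_heads * head_dim, hidden]
    K projection weight read from the checkpoint.
    """
    h, g = cfg.num_attention_heads, cfg.num_key_value_heads
    if cfg.hidden_size % h:
        raise ValueError(f"hidden_size {cfg.hidden_size} not divisible by {h} query heads")
    if h % g:
        raise ValueError(f"{h} query heads cannot be split evenly into {g} K/V groups")
    head_dim = cfg.hidden_size // h
    if k_proj_rows != g * head_dim:
        raise ValueError(
            f"k_proj has {k_proj_rows} rows but config implies {g} x {head_dim} = {g * head_dim}"
        )
    return head_dim


llama2_70b_like = AttentionConfig(hidden_size=8192, num_attention_heads=64, num_key_value_heads=8)
print(check_gqa_config(llama2_70b_like, k_proj_rows=1024))  # 128
```

Loading an unconverted MHA checkpoint (8192 K rows) under this GQA config fails loudly here instead of producing garbage logits at serve time.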

Production checklist

  • Profile KV cache bytes per token at target context before scaling replicas.
  • Confirm num_attention_heads and num_key_value_heads in config match weights.
  • Validate inference engine supports GQA layout (vLLM, TGI, TensorRT-LLM version matrix).
  • If converting MHA to GQA, plan uptrain budget and quality eval gates.
  • Benchmark concurrent sessions at p95/p99 latency, not single-stream tokens/sec.
  • Pair GQA with FP8/INT8 KV cache quantization only after accuracy tests.
  • Keep RoPE and position scaling settings unchanged unless context extension requires NTK/YaRN.
  • Monitor cache fragmentation and preemption under bursty multi-tenant load.
  • Document group count in model cards for downstream fine-tuners.
  • Re-run RAG retrieval evals — attention pattern changes can shift long-context recall.

Key takeaways

  • GQA shares K/V heads across groups of query heads, shrinking KV cache by factor h/g vs full MHA.
  • MQA is the extreme case (g = 1); GQA trades a little memory for better quality than MQA.
  • Decode and multi-tenant serving benefit most; prefill still needs Flash Attention or context limits.
  • Converting MHA checkpoints requires merge init plus short continued pretrain or distillation.
  • Llama, Mistral, and many open weights already ship GQA — configure serving to exploit it.

Related reading