Multi-task learning explained
Harbor Support routed incoming tickets through three separate BERT classifiers: issue category (12 labels), urgency tier (4 labels), and team queue (8 labels). Each classifier carried its own fine-tuned copy of the encoder, which tripled GPU memory at inference, and the three drifted out of sync whenever one was retrained without the others. The team merged them into a single multi-task learning (MTL) model: one shared transformer backbone with three lightweight task heads, trained jointly on every labeled ticket. Category F1 rose from 0.84 to 0.87, urgency accuracy held steady, and inference latency dropped by roughly 60%.
Multi-task learning trains one model on multiple related objectives simultaneously, forcing a shared representation that captures structure useful across tasks. Done well, MTL improves data efficiency, reduces serving cost, and regularizes small tasks with signal from larger ones. Done poorly, tasks compete for capacity and negative transfer makes every head worse.
This guide covers hard vs soft parameter sharing, loss weighting strategies, a worked example of the Harbor Support ticket router, when MTL beats separate models, an approach decision table, common pitfalls, and a production checklist.
What multi-task learning is
In single-task training you minimize one loss $\mathcal{L}_{\text{task}}(\theta)$ over parameters $\theta$. In MTL you minimize a combined objective
$$\mathcal{L} = \sum_i w_i \, \mathcal{L}_i(\theta_{\text{shared}}, \theta_i)$$
where $\theta_{\text{shared}}$ is a backbone (encoder layers, embedding table), $\theta_i$ are the parameters of task $i$'s head, and $w_i$ is that task's loss weight. All tasks see the same input (or inputs derived from a common source) during training; at inference you may use one head, several, or all.
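A minimal sketch of the combined objective in plain Python, assuming each task's loss for the current batch has already been computed as a float; the task names, weights, and loss values are hypothetical. `loss_shares` reports how much of the total each task contributes, a rough first signal of which task will dominate updates to the shared parameters.

```python
from typing import Dict


def combined_loss(losses: Dict[str, float], weights: Dict[str, float]) -> float:
    """Weighted multi-task objective: L = sum_i w_i * L_i."""
    missing = set(losses) - set(weights)
    if missing:
        raise KeyError(f"no weight for tasks: {sorted(missing)}")
    return sum(weights[task] * loss for task, loss in losses.items())


def loss_shares(losses: Dict[str, float], weights: Dict[str, float]) -> Dict[str, float]:
    """Fraction of the combined objective contributed by each task."""
    total = combined_loss(losses, weights)
    return {task: weights[task] * loss / total for task, loss in losses.items()}


# Hypothetical per-batch losses for three heads.
losses = {"category": 1.2, "urgency": 0.8, "queue": 1.5}
weights = {"category": 1.0, "urgency": 0.5, "queue": 1.0}
print(combined_loss(losses, weights))
print(loss_shares(losses, weights))
```

Here the total is 1.2 + 0.4 + 1.5 = 3.1, and queue contributes the largest share.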
MTL differs from transfer learning in sequencing: transfer learning trains on a source task, then fine-tunes on a target. MTL optimizes tasks together from the start (or after a brief shared pretrain). It also differs from ensembling separate models — one forward pass, one set of shared weights, one deployment artifact.
The inductive bias: related tasks share low-level structure. Sentiment, topic, and language ID on the same text all benefit from syntax and entity features; in vision, depth estimation and surface-normal estimation share edge detectors. When that assumption holds, the shared trunk learns more robust features than any single task could learn from its own limited labels.
Hard vs soft parameter sharing
Hard parameter sharing
The default in modern deep learning: all tasks use the same lower layers, and only the final layers (classification heads, regression outputs) differ. BERT with multiple linear heads on top of the [CLS] representation is hard sharing. Benefits: few extra parameters per task, strong regularization, and a simple implementation in PyTorch or Hugging Face Transformers, typically a custom model whose forward pass returns per-task logits or a dict of per-task losses.
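The structure of hard sharing can be sketched without a deep learning framework. In this hypothetical `HardSharedModel`, a single bias-free linear layer with ReLU stands in for the transformer encoder and each head is a bias-free linear map; the dimensions are illustrative, not DistilBERT's. The point is that the trunk runs once per input and each head adds few parameters relative to it.

```python
import random
from typing import Dict, List

Vector = List[float]
Matrix = List[List[float]]


def init_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


def matvec(m: Matrix, v: Vector) -> Vector:
    return [sum(w * x for w, x in zip(row, v)) for row in m]


class HardSharedModel:
    """Hard parameter sharing: one shared trunk, one bias-free linear head per task."""

    def __init__(self, in_dim: int, hidden: int, task_labels: Dict[str, int], seed: int = 0):
        rng = random.Random(seed)
        self.trunk = init_matrix(hidden, in_dim, rng)  # shared by every task
        self.heads = {task: init_matrix(n, hidden, rng) for task, n in task_labels.items()}
        self.trunk_calls = 0  # counts how often the shared encoder runs

    def encode(self, x: Vector) -> Vector:
        self.trunk_calls += 1
        return [max(0.0, h) for h in matvec(self.trunk, x)]  # linear layer + ReLU

    def forward(self, x: Vector) -> Dict[str, Vector]:
        shared = self.encode(x)  # one trunk pass feeds every head
        return {task: matvec(w, shared) for task, w in self.heads.items()}

    def param_counts(self) -> Dict[str, int]:
        counts = {"shared": len(self.trunk) * len(self.trunk[0])}
        counts.update({task: len(w) * len(w[0]) for task, w in self.heads.items()})
        return counts


model = HardSharedModel(in_dim=256, hidden=64, task_labels={"category": 12, "urgency": 4, "queue": 8})
outputs = model.forward([0.1] * 256)
print({task: len(logits) for task, logits in outputs.items()})  # {'category': 12, 'urgency': 4, 'queue': 8}
print(model.param_counts())  # {'shared': 16384, 'category': 768, 'urgency': 256, 'queue': 512}
```

In PyTorch the same shape is usually an `nn.Module` holding one encoder plus an `nn.ModuleDict` of heads.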
Soft parameter sharing
Each task has its own network, and the networks are coupled either by cross-task penalties that keep corresponding weights similar (for example, an L2 distance between matching layers) or by learned exchange of activations between task networks (cross-stitch units, sluice networks). Soft sharing is more flexible when tasks need different capacities, at the cost of more parameters and tuning overhead. Use it when hard sharing shows negative transfer on at least one head.
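One common form of soft sharing adds an L2 penalty between corresponding layers of the task networks to the sum of task losses. The sketch below assumes each network's parameters are available as a dict from layer name to a flattened list of floats; `soft_sharing_penalty`, the layer names, and `lam` are hypothetical. Layers present in only one network (the task-private heads) are not penalized.

```python
from typing import Dict, List

# Each task network's parameters, keyed by layer name, flattened to a list of floats.
Params = Dict[str, List[float]]


def soft_sharing_penalty(task_params: List[Params], lam: float) -> float:
    """lam * sum over task pairs of squared L2 distance between layers both networks have."""
    penalty = 0.0
    for i in range(len(task_params)):
        for j in range(i + 1, len(task_params)):
            a, b = task_params[i], task_params[j]
            for layer in sorted(a.keys() & b.keys()):  # task-private heads are skipped
                if len(a[layer]) != len(b[layer]):
                    raise ValueError(f"shape mismatch in layer {layer!r}")
                penalty += sum((x - y) ** 2 for x, y in zip(a[layer], b[layer]))
    return lam * penalty


category_net = {"encoder.0": [1.0, 2.0], "head.category": [0.3, -0.1, 0.7]}
urgency_net = {"encoder.0": [1.5, 1.0], "head.urgency": [0.2, 0.4]}
penalty = soft_sharing_penalty([category_net, urgency_net], lam=0.1)
print(penalty)
```

Only `encoder.0` is compared: (1.0 - 1.5)² + (2.0 - 1.0)² = 1.25, scaled by 0.1. A larger `lam` pulls the networks toward hard sharing; `lam = 0` gives fully independent models.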
Task-specific adapters
A middle ground popular in LLM stacks: a frozen or lightly tuned backbone plus small per-task adapter modules (LoRA low-rank updates, bottleneck layers). Adapters share most of the compute while giving each task a private adjustment channel; the idea is related to LoRA fine-tuning, but the adapters are trained jointly rather than sequentially.
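A LoRA-style adapter per task can be sketched as a frozen shared matrix W plus a private low-rank correction (alpha / r) * B @ A, where A is r x d_in and B is d_out x r. As in LoRA, B starts at zero so each adapter initially leaves the shared output unchanged. The class and function names, the rank, and alpha are illustrative assumptions, and the training of A and B is omitted.

```python
import random
from typing import Dict, List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    return [sum(w * x for w, x in zip(row, v)) for row in m]


class LowRankAdapter:
    """Per-task update (alpha / r) * B @ A added on top of a frozen shared weight matrix."""

    def __init__(self, d_in: int, d_out: int, rank: int, alpha: float, rng: random.Random):
        self.A = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(rank)]
        self.B = [[0.0] * rank for _ in range(d_out)]  # zero init: starts as a no-op
        self.scale = alpha / rank

    def delta(self, x: Vector) -> Vector:
        return [self.scale * v for v in matvec(self.B, matvec(self.A, x))]

    def num_params(self) -> int:
        return len(self.A) * len(self.A[0]) + len(self.B) * len(self.B[0])


def adapted_forward(w_frozen: Matrix, adapters: Dict[str, LowRankAdapter], task: str, x: Vector) -> Vector:
    """Shared frozen projection plus the selected task's private low-rank correction."""
    base = matvec(w_frozen, x)
    return [b + d for b, d in zip(base, adapters[task].delta(x))]


rng = random.Random(0)
d = 32
w_frozen = [[rng.gauss(0.0, 0.1) for _ in range(d)] for _ in range(d)]
adapters = {t: LowRankAdapter(d, d, rank=4, alpha=8.0, rng=rng) for t in ("category", "urgency", "queue")}
x = [0.5] * d
print(adapted_forward(w_frozen, adapters, "urgency", x) == matvec(w_frozen, x))  # True at init
print(adapters["urgency"].num_params(), d * d)  # 256 1024
```

At realistic sizes the savings grow: a rank-8 adapter on a 768 x 768 projection adds 8 x (768 + 768) = 12,288 parameters against 589,824 in the frozen matrix, about 2%.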
Loss weighting and training dynamics
Naively summing losses treats a regression task with MSE scale 0.01 the same as a classification task with cross-entropy 2.3. Dominant tasks hijack gradients; rare tasks never converge. Practical strategies:
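- Manual weights: set $w_i$ by business priority or by inverse gradient norm. Fast to ship; brittle when data distributions shift.
- Uncertainty weighting: learn a homoscedastic uncertainty $\sigma_i$ per task and minimize $\mathcal{L} = \sum_i \left( \frac{1}{2\sigma_i^2} \mathcal{L}_i + \log \sigma_i \right)$. Tasks with high noise get down-weighted automatically, while the $\log \sigma_i$ term keeps every $\sigma_i$ from growing without bound.
- Gradient balancing: GradNorm tunes task weights so per-task gradient magnitudes stay balanced and tasks train at similar rates; PCGrad projects away the component of one task's gradient that conflicts with another's, so one task cannot reverse another's update direction. Helps when tasks are only loosely related.
- Dynamic sampling: each batch mixes examples from all tasks; oversample scarce tasks or sample in proportion to dataset size.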
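To see why uncertainty weighting down-weights noisy tasks, parameterize s_i = log sigma_i and minimize sum_i [L_i / (2 sigma_i^2) + log sigma_i] over s_i with the task losses held fixed. Setting the derivative 1 - L_i * exp(-2 s_i) to zero gives sigma_i^2 = L_i, so each task's effective weight settles at 1 / (2 L_i): a task stuck at high loss gets a smaller weight. The sketch assumes fixed, hypothetical losses; in real training the s_i are extra trainable parameters updated jointly with the network while the losses move.

```python
import math
from typing import Dict


def uncertainty_weighted_loss(losses: Dict[str, float], log_sigma: Dict[str, float]) -> float:
    """sum_i [ L_i / (2 sigma_i^2) + log sigma_i ], parameterized by s_i = log sigma_i."""
    return sum(0.5 * math.exp(-2.0 * log_sigma[t]) * losses[t] + log_sigma[t] for t in losses)


def grad_log_sigma(losses: Dict[str, float], log_sigma: Dict[str, float]) -> Dict[str, float]:
    """Partial derivative of the objective w.r.t. each s_i: 1 - L_i * exp(-2 s_i)."""
    return {t: 1.0 - losses[t] * math.exp(-2.0 * log_sigma[t]) for t in losses}


def effective_weights(log_sigma: Dict[str, float]) -> Dict[str, float]:
    """The weight 1 / (2 sigma_i^2) that currently multiplies each task loss."""
    return {t: 0.5 * math.exp(-2.0 * s) for t, s in log_sigma.items()}


def fit_log_sigma(losses: Dict[str, float], lr: float = 0.05, steps: int = 2000) -> Dict[str, float]:
    """Gradient descent on s_i with the task losses held fixed (illustration only)."""
    log_sigma = {t: 0.0 for t in losses}
    for _ in range(steps):
        grads = grad_log_sigma(losses, log_sigma)
        log_sigma = {t: log_sigma[t] - lr * grads[t] for t in log_sigma}
    return log_sigma


# Hypothetical steady-state losses: urgency is noisier, so its loss stays higher.
losses = {"category": 0.4, "urgency": 1.6}
log_sigma = fit_log_sigma(losses)
print(effective_weights(log_sigma))
print(uncertainty_weighted_loss(losses, log_sigma))
```

The printed weights approach 1 / (2 x 0.4) = 1.25 for category and 1 / (2 x 1.6) = 0.3125 for urgency.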
Monitor per-task learning curves separately. A flat urgency head while category improves is a sign of imbalance or negative transfer, not proof MTL failed globally.
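A per-task plateau check along these lines can back an alert during training. The function name, the thresholds (`patience`, `min_delta`), and the history values are hypothetical; a flagged head is a prompt to inspect loss weights and single-task baselines, not a verdict.

```python
from typing import Dict, List


def flat_heads(history: Dict[str, List[float]], patience: int = 3, min_delta: float = 0.005) -> List[str]:
    """Tasks whose best validation score in the last `patience` epochs fails to beat
    their earlier best by at least `min_delta` (higher scores are better)."""
    flagged = []
    for task, scores in history.items():
        if len(scores) <= patience:
            continue  # too few epochs to judge this head yet
        best_before = max(scores[:-patience])
        best_recent = max(scores[-patience:])
        if best_recent < best_before + min_delta:
            flagged.append(task)
    return flagged


# Hypothetical per-epoch validation scores (macro-F1 or accuracy).
history = {
    "category": [0.80, 0.83, 0.85, 0.86, 0.87],
    "urgency": [0.90, 0.91, 0.91, 0.91, 0.91],
}
print(flat_heads(history))  # ['urgency']
```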
Worked example: Harbor Support ticket router
Harbor’s support queue receives ~4,000 tickets per week. Labels exist for:
- Category (12 classes) — billing, technical, account, feature request, etc. ~35k labeled examples.
- Urgency (4 tiers) — P0–P3 from SLA policy. ~28k labels; noisy because agents disagree on borderline cases.
- Queue (8 teams) — which on-call rotation owns the ticket. ~22k labels; overlaps with category but not deterministically.
Architecture: a DistilBERT encoder (66M parameters), frozen for the first epoch and then fully fine-tuned. Three linear heads sit on the [CLS] representation, and each head uses class-weighted cross-entropy to handle label imbalance. The combined loss uses equal task weights for a 500-step warmup, then switches to learned uncertainty weights.
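The loss and schedule pieces of this architecture can be sketched in plain Python. `weighted_cross_entropy` computes the weighted mean that PyTorch's `nn.CrossEntropyLoss(weight=...)` uses with its default mean reduction (the sum of w[y] * NLL divided by the sum of w[y]). `balanced_class_weights` uses inverse-frequency weights N / (C * n_c), one common choice; the guide does not say which weights Harbor used. `training_phase` encodes the freeze and warmup switches, with `steps_per_epoch` as a hypothetical parameter.

```python
import math
from typing import List, Sequence, Tuple


def log_softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


def weighted_cross_entropy(
    batch_logits: Sequence[Sequence[float]], targets: Sequence[int], class_weights: Sequence[float]
) -> float:
    """Weighted mean NLL: sum_n w[y_n] * nll_n / sum_n w[y_n]."""
    num, den = 0.0, 0.0
    for logits, y in zip(batch_logits, targets):
        w = class_weights[y]
        num += -w * log_softmax(logits)[y]
        den += w
    return num / den


def balanced_class_weights(counts: Sequence[int]) -> List[float]:
    """Inverse-frequency weights N / (C * n_c): rare classes count for more."""
    total, n_classes = sum(counts), len(counts)
    return [total / (n_classes * c) for c in counts]


def training_phase(step: int, steps_per_epoch: int, warmup_steps: int = 500) -> Tuple[bool, str]:
    """(encoder_frozen, loss_weighting) for a global training step."""
    encoder_frozen = step < steps_per_epoch  # first epoch: train heads only
    weighting = "equal" if step < warmup_steps else "uncertainty"
    return encoder_frozen, weighting


print(balanced_class_weights([90, 10]))
print(training_phase(600, steps_per_epoch=1000))  # (True, 'uncertainty')
```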
Training: batches of 32 with 12 category, 10 urgency, and 10 queue examples, so queue is oversampled relative to its share of the labeled data. Early stopping on macro-averaged validation F1 across tasks. Total training: 4 epochs on one A10 GPU (~45 min).
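With the label counts above (~35k, ~28k, ~22k), purely proportional sampling would give roughly 13 category, 11 urgency, and 8 queue slots per batch of 32, so the 12/10/10 split gives queue two extra slots. A minimal sketch of a fixed-quota sampler, assuming each task has its own pool of example IDs; `sample_mixed_batch`, `proportional_quotas`, and the pool contents are hypothetical.

```python
import random
from typing import Dict, List, Tuple


def proportional_quotas(label_counts: Dict[str, int], batch_size: int) -> Dict[str, int]:
    """Per-task slots if batches mirrored dataset sizes (simple rounding)."""
    total = sum(label_counts.values())
    return {task: round(batch_size * n / total) for task, n in label_counts.items()}


def sample_mixed_batch(
    pools: Dict[str, List[str]], quotas: Dict[str, int], rng: random.Random
) -> List[Tuple[str, str]]:
    """Draw a fixed number of examples per task, then shuffle them into one batch."""
    batch: List[Tuple[str, str]] = []
    for task, k in quotas.items():
        pool = pools[task]
        # Without replacement when possible; with replacement if the pool is tiny.
        picks = rng.sample(pool, k) if k <= len(pool) else rng.choices(pool, k=k)
        batch.extend((task, example) for example in picks)
    rng.shuffle(batch)
    return batch


counts = {"category": 35_000, "urgency": 28_000, "queue": 22_000}
print(proportional_quotas(counts, 32))  # {'category': 13, 'urgency': 11, 'queue': 8}

quotas = {"category": 12, "urgency": 10, "queue": 10}
pools = {task: [f"{task}-{i}" for i in range(100)] for task in quotas}
batch = sample_mixed_batch(pools, quotas, random.Random(0))
print(len(batch))  # 32
```

Simple rounding can miss the batch size by one for other counts; fixing the quotas explicitly, as Harbor did, avoids that.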
Results vs three single-task DistilBERT models:
- Category macro-F1: 0.84 → 0.87 (main win — largest task regularized by urgency signal).
- Urgency accuracy: 0.91 → 0.91 (unchanged; noisy labels cap ceiling).
- Queue macro-F1: 0.79 → 0.81 (smallest dataset benefited most from shared features).
- Inference: one forward pass (~18 ms) vs three (~48 ms) on the same hardware.
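The intro's "roughly 60%" latency reduction follows from these two timings:

$$\frac{48\ \text{ms} - 18\ \text{ms}}{48\ \text{ms}} = \frac{30}{48} = 0.625$$

that is, a 62.5% cut in inference latency on the same hardware.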
They kept a single-task category model as a shadow for two weeks; disagreement rate was 4.2% with MTL winning on ambiguous billing/technical boundary tickets.
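A shadow comparison like this reduces to aligning the two models' predictions ticket by ticket. The sketch below is hypothetical: it computes the disagreement rate and counts which label pairs the models disagree on, which is how a cluster such as billing vs technical would surface. Deciding which model is right on those tickets still needs human-adjudicated labels.

```python
from collections import Counter
from typing import Sequence


def disagreement_rate(primary: Sequence[str], shadow: Sequence[str]) -> float:
    """Fraction of tickets where the two models predict different labels."""
    if len(primary) != len(shadow):
        raise ValueError("predictions must be aligned ticket-for-ticket")
    if not primary:
        return 0.0
    return sum(p != s for p, s in zip(primary, shadow)) / len(primary)


def disagreement_pairs(primary: Sequence[str], shadow: Sequence[str]) -> Counter:
    """Count (primary_label, shadow_label) pairs where the models disagree."""
    return Counter((p, s) for p, s in zip(primary, shadow) if p != s)


# Hypothetical predictions for four tickets.
mtl = ["billing", "technical", "account", "billing"]
single = ["technical", "technical", "account", "billing"]
print(disagreement_rate(mtl, single))  # 0.25
print(disagreement_pairs(mtl, single))  # Counter({('billing', 'technical'): 1})
```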
When MTL helps vs hurts
MTL tends to work when tasks are related, at least one task has abundant labels, and you need multiple predictions per input at serving time. It tends to fail when tasks need incompatible features (RGB object detection and grayscale OCR on different inputs), when one task is much noisier and is not down-weighted, or when a single task dominates the product metric and the others are throwaway auxiliaries.
Negative transfer is the failure mode where joint training underperforms single-task baselines. Diagnostic steps: train single-task models first; add tasks one at a time; try gradient surgery or separate adapters for the conflicting head. Sometimes the right answer is transfer learning (pretrain on the big task, fine-tune per small task) rather than true MTL.
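Gradient surgery in the PCGrad style can be sketched on flattened gradient vectors for the shared parameters: whenever task i's gradient has a negative dot product with task j's, subtract its projection onto task j's gradient, then sum the adjusted gradients. This assumes the per-task gradients were already computed (in a framework, one backward pass per task); `pcgrad` and `dot` are illustrative names, not a library API.

```python
import random
from typing import List, Optional

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def pcgrad(task_grads: List[Vector], rng: Optional[random.Random] = None) -> Vector:
    """Combine per-task gradients of the shared parameters with PCGrad-style projection."""
    rng = rng or random.Random(0)
    projected = []
    for i, g in enumerate(task_grads):
        g_pc = list(g)
        others = [j for j in range(len(task_grads)) if j != i]
        rng.shuffle(others)  # visit the other tasks in random order
        for j in others:
            g_j = task_grads[j]  # project against the original, unprojected gradient
            d = dot(g_pc, g_j)
            norm_sq = dot(g_j, g_j)
            if d < 0 and norm_sq > 0:
                # Remove the component of g_pc that points against g_j.
                g_pc = [x - d / norm_sq * y for x, y in zip(g_pc, g_j)]
        projected.append(g_pc)
    return [sum(components) for components in zip(*projected)]


print(pcgrad([[1.0, 0.0], [-1.0, 1.0]]))  # [0.5, 1.5]
```

For the conflicting pair [1, 0] and [-1, 1], the plain sum is [0, 1]; after projection each adjusted gradient is orthogonal to the other task's original gradient, and the combined update is [0.5, 1.5]. Non-conflicting gradients pass through unchanged.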
Approach decision table
| Scenario | Recommended approach | Why |
|---|---|---|
| Multiple labels per input, related semantics | Hard-sharing MTL | One encoder, multiple heads; best latency and data efficiency |
| One large task + tiny auxiliary tasks | MTL with uncertainty or low auxiliary weights | Auxiliaries regularize without dominating gradients |
| Tasks need different input modalities | Separate encoders + fusion layer, or separate models | Hard sharing on incompatible inputs forces bad compromises |
| Only one task needed at inference | Transfer learning or single-task + distillation | MTL serving benefit disappears; sequential training may be simpler |
| Conflicting gradient directions between tasks | PCGrad / GradNorm / per-task adapters | Reduces negative transfer without abandoning shared trunk |
| LLM with many downstream skills | Instruction tuning or adapter MTL | Joint instruction mix approximates MTL at scale |
| Tabular targets with shared features | Multi-output regression or gradient boosting multi-target | Tree models handle heterogeneous targets with less tuning than deep MTL |
Common pitfalls
- Unevaluated single-task baselines — MTL must beat independent models on each metric that matters; joint training is not free.
- Silent label leakage — queue labels derived deterministically from category make MTL look brilliant but add no generalization; audit label pipelines like any leakage check.
- Fixed loss weights forever — class distribution and task difficulty shift; re-tune weights or use learned uncertainty.
- Missing examples per task — not every ticket has all three labels; use masked loss (ignore unlabeled heads) rather than dropping rows.
- Over-sharing capacity — twelve tasks on a tiny MLP trunk; upgrade backbone or move rare tasks to soft sharing.
- Deployment drift — fine-tuning the shared encoder for one task after MTL ships shifts the representation every other head depends on; version and deploy the full multi-head artifact atomically.
- Ignoring calibration — urgency tiers drive paging; check per-task calibration even when accuracy looks fine.
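One way to audit the label-leakage risk is to measure how often each ticket's queue equals the most common queue for its category. The sketch assumes (category, queue) label pairs from the training set; `majority_purity` is a hypothetical helper, and what purity counts as suspicious is a judgment call. An overall value of 1.0 means queue is a deterministic function of category in that data.

```python
from collections import Counter, defaultdict
from typing import Dict, Iterable, Tuple


def majority_purity(pairs: Iterable[Tuple[str, str]]) -> Tuple[float, Dict[str, float]]:
    """Share of tickets whose queue is the majority queue for their category.

    Returns (overall purity, per-category purity).
    """
    by_category: Dict[str, Counter] = defaultdict(Counter)
    for category, queue in pairs:
        by_category[category][queue] += 1
    if not by_category:
        raise ValueError("no label pairs given")
    majority_total, n_total = 0, 0
    per_category: Dict[str, float] = {}
    for category, counts in by_category.items():
        top = counts.most_common(1)[0][1]  # size of the majority queue
        n = sum(counts.values())
        per_category[category] = top / n
        majority_total += top
        n_total += n
    return majority_total / n_total, per_category


# Hypothetical labels: billing always maps to one queue, technical splits.
pairs = [("billing", "finance")] * 3 + [("technical", "eng")] * 2 + [("technical", "support")] * 2
overall, per_category = majority_purity(pairs)
print(overall, per_category)
```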
Production checklist
- List all tasks, label counts, noise level, and which predictions are required at inference.
- Train strong single-task baselines for every head before enabling MTL.
- Choose hard vs soft sharing based on ablation, not architecture fashion.
- Implement masked multi-task loss for partially labeled data.
- Log per-task metrics every epoch; alert on negative transfer early.
- Tune loss weights (manual, uncertainty, or gradient methods) on validation.
- Export one ONNX/TorchScript bundle with all heads for consistent serving.
- Shadow-deploy against legacy single-task models before cutover.
- Document which tasks are primary vs auxiliary for future retrains.
- Re-evaluate MTL vs separate models when a new task is added or data mix shifts.
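A minimal sketch of a masked multi-task loss, assuming each batch stores one label per task per ticket, with None where that label is missing; the function names and numbers are hypothetical. A ticket with only a category label still trains the category head and the shared encoder, while heads with no label for it receive no loss. In PyTorch, marking missing labels with the `ignore_index` value of `nn.CrossEntropyLoss` (or `F.cross_entropy`) gives the same per-head masking.

```python
import math
from typing import Dict, List, Optional, Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[target]


def masked_multitask_loss(
    logits: Dict[str, List[List[float]]],
    labels: Dict[str, List[Optional[int]]],
    weights: Dict[str, float],
) -> float:
    """Weighted sum of per-task mean CE, skipping rows whose label for that task is None."""
    total = 0.0
    for task, task_logits in logits.items():
        labeled = [(z, y) for z, y in zip(task_logits, labels[task]) if y is not None]
        if not labeled:
            continue  # no labels for this head in the batch: it contributes no loss
        task_loss = sum(cross_entropy(z, y) for z, y in labeled) / len(labeled)
        total += weights[task] * task_loss
    return total


# Two hypothetical tickets: ticket 0 has only a category label, ticket 1 has both.
logits = {
    "category": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
    "urgency": [[50.0, -50.0], [0.0, 0.0]],  # row 0 is ignored because its label is None
}
labels = {"category": [1, 2], "urgency": [None, 1]}
weights = {"category": 1.0, "urgency": 1.0}
loss = masked_multitask_loss(logits, labels, weights)
print(loss)
```

With uniform logits the category term is log 3 and the urgency term is log 2, so the total is log 6 whatever the masked row's logits are.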
Key takeaways
- MTL shares a backbone across related tasks, improving sample efficiency and cutting inference cost.
- Hard parameter sharing is the default; soft sharing and adapters address negative transfer.
- Loss weighting is not optional — dominant tasks will starve the rest without explicit balancing.
- Always benchmark single-task models; MTL is a hypothesis, not a default.
- Production MTL is one deployable unit — partial retrains break the shared representation.
Related reading
- Transfer learning explained — sequential pretrain-then-finetune when tasks are not jointly labeled
- Text classification explained — single-task NLP baselines and evaluation metrics
- Class imbalance explained — per-task weighting when label frequencies diverge
- Knowledge distillation explained — compress a multi-head teacher into a smaller student