Federated learning explained
A mobile keyboard wants to predict your next word from how you actually type — but uploading every keystroke to a central server violates privacy law, burns bandwidth, and creates a honeypot for attackers. Federated learning flips the data flow: instead of sending raw examples to the cloud, the cloud sends a model to each device, the device trains locally on its own data, and only model updates (weight deltas) travel back. A coordinator aggregates those updates into a better global model without ever seeing your messages, photos, or health readings. This guide explains cross-device vs cross-silo federated learning, the FedAvg algorithm, client selection and stragglers, non-IID data challenges, secure aggregation, differential privacy, and how federated learning fits alongside machine learning fundamentals, on-device inference, and MLOps.
What federated learning is — and what it is not
Federated learning (FL) is a distributed training paradigm where many clients (phones, browsers, hospital EHR systems, factory sensors) collaborate to improve a shared model while keeping training data decentralized. Google popularized cross-device FL for Gboard; healthcare and finance use cross-silo FL when a handful of institutions refuse to pool patient or transaction records.
It is not:
- Perfect privacy by default — model updates can reveal whether a given example was in a client's training data (membership inference) or even let an attacker reconstruct that example, unless you add secure aggregation and differential privacy.
- A replacement for centralized training when you already own the data — if you can legally and cheaply move data to one warehouse, centralized training is simpler, faster to debug, and easier to monitor for drift.
- Edge inference alone — running a frozen model on-device is inference; FL additionally updates that model from local data without centralizing it.
FL sits at the intersection of distributed systems, optimization theory, and privacy engineering. The hard problems are statistical (data is not independent and identically distributed across clients) and operational (phones go offline mid-round).
Cross-device vs cross-silo federated learning
Two deployment shapes dominate production:
Cross-device FL
Millions of consumer devices — smartphones, wearables, browsers — participate intermittently. Each device holds a tiny, private dataset (your photos, typing history, voice commands). Clients are unreliable: batteries die, networks drop, OS kills background jobs. Training rounds must tolerate massive dropout; only a fraction of eligible devices complete each round.
Cross-silo FL
A small number of powerful nodes — hospitals, banks, telecom carriers — each hold large siloed datasets. Clients are fewer but more capable: GPUs, stable networks, SLAs. Coordination often uses secure enclaves or homomorphic encryption because gradient leakage across institutions carries regulatory risk even when raw rows never move.
The algorithmic core is similar; the systems engineering diverges. Cross-device optimizes for battery, uplink cost, and fair participation; cross-silo optimizes for throughput, audit trails, and contractual data-governance clauses.
FedAvg: the canonical training loop
Federated Averaging (FedAvg), introduced by McMahan et al., is the baseline most teams start from. One round looks like this:
- Server broadcasts the current global model weights $w_t$ to a sampled subset of clients $S_t$.
- Each selected client $k$ downloads $w_t$, runs several epochs of SGD on its local data, and computes an update $\Delta w_k = w_k - w_t$ (or sends $w_k$ directly).
- Server aggregates the client updates, typically as an average weighted by local sample count $n_k$: $w_{t+1} = w_t + \sum_{k \in S_t} \frac{n_k}{N} \Delta w_k$, where $N = \sum_{k \in S_t} n_k$.
- Repeat for hundreds or thousands of rounds until validation loss plateaus.
Key hyperparameters beyond the usual learning rate:
- Clients per round (C) — fraction of total population sampled. Too few increases variance; too many strains the coordinator.
- Local epochs (E) — how many passes each client makes before uploading. Higher E reduces communication but risks client drift when data is non-IID.
- Batch size (B) — on-device memory limits often cap this at 16–64.
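The loop and its three knobs can be made concrete with a small simulation. This is a minimal sketch, not a production implementation: it assumes a toy linear model trained with squared error, synthetic IID clients, and hypothetical helper names (`local_sgd`, `fedavg_round`, `make_client`). `fraction`, `epochs`, and `batch_size` play the roles of C, E, and B, and the server applies the $n_k / N$ weighting from the update rule.

```python
import random

Vector = list[float]
Dataset = list[tuple[Vector, float]]


def predict(w: Vector, x: Vector) -> float:
    """Linear model y_hat = w . x (no bias term, to keep the sketch short)."""
    return sum(wi * xi for wi, xi in zip(w, x))


def local_sgd(w_global: Vector, data: Dataset, epochs: int, batch_size: int,
              lr: float, rng: random.Random) -> Vector:
    """Client side: E epochs of minibatch SGD on squared error, starting from w_t."""
    w = list(w_global)
    data = list(data)  # copy so shuffling does not reorder the client's stored dataset
    for _ in range(epochs):
        rng.shuffle(data)
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            grad = [0.0] * len(w)
            for x, y in batch:
                err = predict(w, x) - y
                for j, xj in enumerate(x):
                    grad[j] += 2.0 * err * xj / len(batch)
            w = [wj - lr * gj for wj, gj in zip(w, grad)]
    return w


def fedavg_round(w_t: Vector, clients: list[Dataset], fraction: float, epochs: int,
                 batch_size: int, lr: float, rng: random.Random) -> Vector:
    """Server side: sample a fraction C of clients, then add their deltas weighted by n_k / N."""
    m = max(1, round(fraction * len(clients)))
    selected = rng.sample(clients, m)
    total = sum(len(d) for d in selected)  # N: examples held by this round's clients
    w_next = list(w_t)
    for data in selected:
        w_k = local_sgd(w_t, data, epochs, batch_size, lr, rng)
        weight = len(data) / total
        w_next = [wn + weight * (a - b) for wn, a, b in zip(w_next, w_k, w_t)]
    return w_next


def make_client(rng: random.Random, true_w: Vector, n: int) -> Dataset:
    """Synthetic IID client: n noiseless samples from y = true_w . x."""
    xs = [[rng.uniform(-1.0, 1.0) for _ in true_w] for _ in range(n)]
    return [(x, predict(true_w, x)) for x in xs]


rng = random.Random(0)
true_w = [2.0, -1.0]
clients = [make_client(rng, true_w, rng.randint(10, 50)) for _ in range(20)]
w = [0.0, 0.0]
for _ in range(30):
    w = fedavg_round(w, clients, fraction=0.25, epochs=2, batch_size=8, lr=0.1, rng=rng)
print([round(v, 3) for v in w])  # should land close to true_w on this easy IID problem
```

Because these synthetic clients are IID and noiseless, averaging converges easily; replacing `make_client` with a skewed split is where client drift starts to show.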
FedAvg assumes clients are representative when weighted by n_k. That assumption breaks when power users dominate updates or when only charging, Wi-Fi-connected devices participate. This participation bias silently skews the global model.
Non-IID data: the central statistical challenge
Classical ML theory assumes training examples are drawn independently from the same distribution. Federated clients violate both assumptions:
- Label skew — one hospital sees mostly pediatric cases; another sees geriatric. Local models optimize for different label priors.
- Feature skew — keyboards in different languages have different character frequencies; cameras in different regions see different lighting and objects.
- Quantity skew — a few clients hold millions of examples; most hold dozens. Weighting by sample count, as FedAvg does, lets these whales dominate the average.
- Temporal drift — user behavior shifts seasonally; stale clients upload outdated gradients.
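When prototyping in simulation, label skew is often produced by splitting each class across clients with proportions drawn from a Dirichlet distribution. That is a common simulation convention, not something this guide prescribes, and the function names below are hypothetical. Smaller `alpha` concentrates each label on fewer clients; large `alpha` approaches an even, IID-like split.

```python
import random
from collections import Counter, defaultdict


def dirichlet(alpha: float, k: int, rng: random.Random) -> list[float]:
    """Sample Dirichlet(alpha, ..., alpha) by normalizing independent Gamma(alpha, 1) draws."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(k)]
    total = sum(draws)
    return [d / total for d in draws]


def label_skew_split(labels: list[int], num_clients: int, alpha: float,
                     rng: random.Random) -> list[list[int]]:
    """Assign every example index to one client; each label's examples are divided
    among clients according to a fresh Dirichlet draw."""
    by_label: dict[int, list[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)
    shards: list[list[int]] = [[] for _ in range(num_clients)]
    for idxs in by_label.values():
        rng.shuffle(idxs)
        props = dirichlet(alpha, num_clients, rng)
        start, cumulative = 0, 0.0
        for client, p in enumerate(props):
            cumulative += p
            # The last client takes the remainder so each index is assigned exactly once.
            end = len(idxs) if client == num_clients - 1 else int(cumulative * len(idxs))
            shards[client].extend(idxs[start:end])
            start = end
    return shards


rng = random.Random(42)
labels = [i % 10 for i in range(1000)]  # 10 balanced classes, 100 examples each
for alpha in (100.0, 0.1):
    shards = label_skew_split(labels, num_clients=5, alpha=alpha, rng=rng)
    # Share of each client's data taken by its most common label: higher means more skew.
    top_share = [max(Counter(labels[i] for i in s).values()) / len(s) if s else 0.0
                 for s in shards]
    print(alpha, [len(s) for s in shards], [round(t, 2) for t in top_share])
```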
Mitigations include:
- FedProx — adds a proximal term penalizing deviation from the global model, limiting local overfitting.
- SCAFFOLD — corrects for client drift using control variates.
- Personalization layers — shared backbone + per-client head (similar in spirit to transfer learning).
- Clustered FL — train separate sub-models for client cohorts discovered via embedding similarity.
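FedProx changes only the client's local objective to $F_k(w) + \frac{\mu}{2}\lVert w - w_t \rVert^2$, so each local gradient step gains an extra $\mu (w - w_t)$ term. The sketch below assumes the same toy linear model as a FedAvg simulation, a hypothetical `fedprox_local` helper, and illustrative values for `mu` and the learning rate; it compares how far one skewed client drifts from the global model with and without the proximal term.

```python
import random

Vector = list[float]


def fedprox_local(w_global: Vector, data: list[tuple[Vector, float]], epochs: int,
                  lr: float, mu: float) -> Vector:
    """Per-example SGD on squared error plus (mu / 2) * ||w - w_global||^2."""
    w = list(w_global)
    for _ in range(epochs):
        for x, y in data:
            err = sum(wj * xj for wj, xj in zip(w, x)) - y
            # Loss gradient plus the proximal gradient mu * (w - w_global).
            grad = [2.0 * err * xj + mu * (wj - gj) for xj, wj, gj in zip(x, w, w_global)]
            w = [wj - lr * g for wj, g in zip(w, grad)]
    return w


def distance(a: Vector, b: Vector) -> float:
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b)) ** 0.5


rng = random.Random(7)
# A skewed client whose local optimum [3, 0] is far from the current global model [0, 0].
xs = [[rng.uniform(-1.0, 1.0), rng.uniform(-1.0, 1.0)] for _ in range(50)]
data = [(x, 3.0 * x[0]) for x in xs]
w_global = [0.0, 0.0]
drift_plain = distance(fedprox_local(w_global, data, epochs=20, lr=0.05, mu=0.0), w_global)
drift_prox = distance(fedprox_local(w_global, data, epochs=20, lr=0.05, mu=1.0), w_global)
print(round(drift_plain, 2), round(drift_prox, 2))  # the proximal run stays closer to w_global
```

With `mu=0.0` the function reduces to plain local SGD, which is what FedAvg clients run; larger `mu` trades local fit for staying near $w_t$.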
Always evaluate on a held-out centralized test set and per-client slices. A global accuracy of 92% can hide a cohort stuck at 60%.
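The arithmetic behind "92% can hide 60%" is easy to reproduce. The cohort names and counts below are made up to mirror that example, not measured results, and `sliced_accuracy` is a hypothetical helper: a large cohort at 96% and a small one at 60% average out to about 92%.

```python
from collections import defaultdict


def sliced_accuracy(records: list[tuple[str, bool]]) -> tuple[float, dict[str, float]]:
    """Return overall accuracy plus accuracy per cohort from (cohort, was_correct) pairs."""
    totals: dict[str, int] = defaultdict(int)
    hits: dict[str, int] = defaultdict(int)
    for cohort, correct in records:
        totals[cohort] += 1
        hits[cohort] += int(correct)
    overall = sum(hits.values()) / sum(totals.values())
    return overall, {c: hits[c] / totals[c] for c in totals}


# Illustrative numbers: 900 "flagship" predictions at 96%, 100 "budget" predictions at 60%.
records = [("flagship", i < 864) for i in range(900)] + [("budget", i < 60) for i in range(100)]
overall, per_cohort = sliced_accuracy(records)
print(f"overall={overall:.3f}", per_cohort)  # overall=0.924 while budget sits at 0.6
print("below 0.8:", [c for c, acc in per_cohort.items() if acc < 0.8])
```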
Client selection, stragglers, and systems constraints
Production FL is as much distributed systems as statistics. The coordinator must:
- Sample clients fairly — round-robin by device ID hash avoids always picking the same fast phones. Cap per-client contribution per day to prevent dominance.
- Set round timeouts — wait for the first K completions, not all N invitations. Stragglers on 3G should not block the round.
- Compress updates — quantize floats to 8-bit, sparsify small gradients, or use sketching. Mobile uplink is the bottleneck, not server GPU.
- Respect device policy — train only on Wi-Fi, while charging, and with explicit user consent toggles. iOS and Android background limits kill naive implementations.
- Version models atomically — clients mid-download during a weight swap produce corrupt updates. Use version tags and reject stale uploads.
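Client sampling, daily caps, and round timeouts fit in a few functions. This is a minimal sketch under stated assumptions: it reads "round-robin by device ID hash" as ranking devices by a hash salted with the round number (so the order rotates every round), it invites more devices than the K updates it needs, and every name here is hypothetical rather than part of any FL framework.

```python
import hashlib


def selection_key(device_id: str, round_id: int) -> int:
    """Stable pseudo-random rank for a device in a given round; salting with round_id rotates the order."""
    digest = hashlib.sha256(f"{device_id}:{round_id}".encode()).digest()
    return int.from_bytes(digest[:8], "big")


def select_clients(eligible: list[str], round_id: int, invite_count: int,
                   uploads_today: dict[str, int], daily_cap: int) -> list[str]:
    """Invite the lowest-ranked devices that are still under their per-day contribution cap."""
    under_cap = [d for d in eligible if uploads_today.get(d, 0) < daily_cap]
    return sorted(under_cap, key=lambda d: selection_key(d, round_id))[:invite_count]


def close_round(finish_times: dict[str, float], target_k: int,
                deadline_s: float) -> tuple[list[str], bool]:
    """Keep the first K on-time completions; stragglers past the deadline are ignored.
    The bool reports whether the round reached K and can be aggregated."""
    on_time = sorted((t, d) for d, t in finish_times.items() if t <= deadline_s)
    accepted = [d for _, d in on_time[:target_k]]
    return accepted, len(accepted) >= target_k


devices = [f"device-{i}" for i in range(100)]
uploads_today = {"device-3": 5}  # device-3 has already hit today's cap
invited = select_clients(devices, round_id=12, invite_count=13,
                         uploads_today=uploads_today, daily_cap=5)
# Hypothetical completion times in seconds, in invitation order.
finish = {d: 30.0 + 10.0 * i for i, d in enumerate(invited)}
accepted, ok = close_round(finish, target_k=10, deadline_s=300.0)
print(len(invited), len(accepted), ok)
```

Inviting 13 devices to collect 10 updates leaves slack for dropouts; the right over-selection ratio depends on your observed completion rate.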
Observability differs from centralized training: you cannot inspect a unified training set. Log per-round participation rate, upload size distribution, local loss histograms, and anomaly scores on updates (poisoned gradients spike norm).
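Update validation and norm-based anomaly scoring can be sketched as a single gate in front of aggregation. The version, shape, and finiteness checks follow the text; flagging updates whose L2 norm exceeds a multiple of the round's median norm is one simple choice assumed here (with few clients, or many colluding ones, a median can itself be fooled). Function names and `spike_factor` are hypothetical.

```python
import math
from statistics import median


def l2_norm(v: list[float]) -> float:
    return math.sqrt(sum(x * x for x in v))


def validate_updates(updates: dict[str, tuple[int, list[float]]], expected_len: int,
                     expected_version: int, spike_factor: float = 10.0):
    """Split uploads into accepted deltas and rejected client ids with a reason."""
    accepted: dict[str, list[float]] = {}
    rejected: dict[str, str] = {}
    norms: dict[str, float] = {}
    for cid, (version, delta) in updates.items():
        if version != expected_version:
            rejected[cid] = "stale model version"
        elif len(delta) != expected_len:
            rejected[cid] = "shape mismatch"
        elif not all(math.isfinite(x) for x in delta):
            rejected[cid] = "non-finite values"
        else:
            norms[cid] = l2_norm(delta)
    if norms:
        typical = median(norms.values())  # the round's typical update size
        for cid, n in norms.items():
            if n > spike_factor * typical:
                rejected[cid] = "norm spike"
            else:
                accepted[cid] = updates[cid][1]
    return accepted, rejected


updates = {
    "a": (7, [0.1, -0.2]),
    "b": (7, [0.2, 0.1]),
    "c": (7, [30.0, 40.0]),         # norm 50, far above the others
    "d": (6, [0.1, 0.1]),           # trained against an older model version
    "e": (7, [float("nan"), 0.0]),
    "f": (7, [0.1]),                # wrong length
}
accepted, rejected = validate_updates(updates, expected_len=2, expected_version=7)
print(sorted(accepted), rejected)
```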
Privacy: secure aggregation and differential privacy
Raw data stays local, but model updates are not automatically private. Research has shown that inputs can be reconstructed from gradients, especially with small batch sizes and in early training rounds. Production systems layer defenses:
Secure aggregation
Clients mask their updates with random masks derived from pairwise shared seeds, arranged so that the masks cancel in the sum. The server learns only the aggregate, not any individual Δw_k. The protocol requires cryptographic setup per round and cannot recover the sum if more clients drop out than its threshold tolerates, so plan dropout thresholds up front.
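The cancellation trick at the heart of pairwise masking fits in a toy example. This sketch is deliberately insecure and incomplete: it assumes updates are already quantized to integers, derives each pair's seed in the clear (a real protocol would agree on it with a key exchange between the two clients), uses Python's non-cryptographic `random` as the mask generator, and has no dropout recovery. All names are hypothetical.

```python
import random

MOD = 2 ** 32  # arithmetic ring; updates are assumed already quantized to integers


def pair_mask(i: int, j: int, round_id: int, length: int) -> list[int]:
    """Mask shared by clients i and j. Illustration only: the seed is computed in the clear
    here, whereas a real protocol derives it from a key agreement between the two clients."""
    lo, hi = min(i, j), max(i, j)
    rng = random.Random(f"{lo}:{hi}:{round_id}")
    return [rng.randrange(MOD) for _ in range(length)]


def mask_update(client_id: int, update: list[int], peers: list[int], round_id: int) -> list[int]:
    """Add each pairwise mask shared with a higher-id peer, subtract it for a lower-id peer."""
    masked = [u % MOD for u in update]
    for peer in peers:
        if peer == client_id:
            continue
        sign = 1 if client_id < peer else -1
        mask = pair_mask(client_id, peer, round_id, len(update))
        masked = [(m + sign * r) % MOD for m, r in zip(masked, mask)]
    return masked


def aggregate(masked_updates: list[list[int]]) -> list[int]:
    """Server side: sum mod MOD so the pairwise masks cancel, then map back to signed ints."""
    total = [sum(col) % MOD for col in zip(*masked_updates)]
    return [t - MOD if t >= MOD // 2 else t for t in total]


clients = [0, 1, 2]
updates = {0: [5, -10], 1: [1, 2], 2: [-3, 4]}
masked = [mask_update(c, updates[c], clients, round_id=1) for c in clients]
print(masked[0])          # looks like random 32-bit integers
print(aggregate(masked))  # the true sum of the three updates
```

If client 2 vanished after client 0 and client 1 uploaded, their masks involving client 2 would no longer cancel; that is why the production protocol secret-shares seeds and enforces a dropout threshold.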
Differential privacy (DP)
Clip per-client update norms, add calibrated Gaussian noise before aggregation, and track a privacy budget (ε, δ) across rounds. Stronger privacy usually costs accuracy; tune ε on a validation slice with legal/compliance input, not only ML metrics.
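The clip-then-noise step can be sketched as follows, assuming a central model in which the aggregator adds the noise to the sum of clipped updates. `clip_norm` and `noise_multiplier` are illustrative parameter names; the sketch does not compute (ε, δ), which needs a separate privacy accountant across rounds, and Python's `random` is not a cryptographically secure noise source.

```python
import math
import random


def clip(update: list[float], clip_norm: float) -> list[float]:
    """Scale an update down so its L2 norm is at most clip_norm (unchanged if already smaller)."""
    norm = math.sqrt(sum(x * x for x in update))
    scale = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    return [x * scale for x in update]


def dp_average(updates: list[list[float]], clip_norm: float, noise_multiplier: float,
               rng: random.Random) -> list[float]:
    """Clip each client's update, add Gaussian noise with std noise_multiplier * clip_norm
    to the sum, then divide by the number of clients in this round."""
    clipped = [clip(u, clip_norm) for u in updates]
    summed = [sum(col) for col in zip(*clipped)]
    sigma = noise_multiplier * clip_norm  # noise scales with the per-client sensitivity
    noisy = [s + rng.gauss(0.0, sigma) for s in summed]
    return [x / len(updates) for x in noisy]


updates = [[3.0, 4.0], [0.0, 0.5], [0.2, -0.1]]
print(clip([3.0, 4.0], 1.0))  # rescaled to norm 1
print(dp_average(updates, clip_norm=1.0, noise_multiplier=1.0, rng=random.Random(0)))
```

Clipping bounds how much any one client can move the sum, which is what lets a fixed noise scale mask individual contributions. Some DP-FedAvg variants divide by a fixed expected client count instead of the observed one so the denominator itself reveals nothing.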
Trusted execution environments
Cross-silo setups sometimes aggregate inside hardware enclaves such as Intel SGX or inside cloud confidential VMs, so the coordinator operator never sees plaintext sums.
Privacy guarantees must be paired with governance: document what is collected, retention periods, and opt-out behavior. A user who disables FL should still receive the latest global model, not a broken feature.
When to use federated learning
Good fits:
- Personalization on private user data (keyboard, camera, health, finance).
- Regulatory barriers to data pooling (HIPAA, GDPR data-minimization).
- Edge bandwidth too expensive to upload raw media at scale.
- Continual adaptation where centralized retraining lags real-world drift.
Poor fits:
- You already own a clean centralized dataset and need fast iteration.
- Model is huge relative to device RAM (multi-billion-parameter LLMs without aggressive compression).
- Labels require expert annotation unavailable on-device.
- Debugging complexity outweighs privacy benefit for a small user base.
Many products combine modes: pretrain centrally, fine-tune federated on-device, serve with on-device inference, and monitor centrally with federated evaluation metrics only.
Common anti-patterns
- Claiming FL is "fully private" without secure aggregation or DP — legal and security teams will correctly object.
- Ignoring participation bias — models trained only on flagship phones mis-serve budget devices.
- Unbounded local epochs — clients overfit their tiny datasets and upload destructive gradients.
- No update validation — accepting NaN weights or 100× norm spikes poisons the global model.
- Shipping identical rounds to all app versions — architecture mismatches corrupt aggregation.
- Evaluating only global metrics — fairness regressions across demographics stay invisible.
- Blocking the UI thread — on-device training must run in background workers with thermal throttling.
Production checklist
- Define threat model: honest-but-curious server vs malicious clients vs both.
- Prototype FedAvg in simulation with synthetic non-IID splits before device tests.
- Measure uplink bytes per round; set compression targets for cellular users.
- Implement client sampling, round timeouts, and per-client daily caps.
- Add secure aggregation or DP before any production rollout with sensitive data.
- Validate updates (finite norms, shape match, version tag) before aggregation.
- Log participation rate, straggler fraction, and per-cohort eval metrics.
- Provide user-visible consent, opt-out, and clear privacy policy language.
- Plan rollback: keep last-known-good global weights if a round degrades eval.
- Integrate with MLOps for model registry, staged rollout, and drift alerts.
Key takeaways
- Federated learning trains a shared model across decentralized data by aggregating local updates, not raw examples.
- FedAvg is the baseline loop — tune clients per round, local epochs, and weighting to balance communication vs convergence.
- Non-IID data is the norm; use FedProx, personalization, or clustering when local drift hurts global quality.
- Privacy requires engineering — secure aggregation and differential privacy, not just keeping data on-device.
- Combine FL with centralized pretraining, on-device inference, and federated evaluation for most real products.
Related reading
- Machine learning fundamentals explained — supervised learning, loss functions, bias-variance, and train/test splits
- Edge AI and on-device inference explained — model compression, NPU runtimes, latency budgets, and offline serving
- MLOps explained — experiment tracking, model registries, deployment, and production monitoring
- Model drift and concept drift explained — detecting when production data diverges from training assumptions