Focal loss explained
Harbor Analytics trained a ResNet-based classifier to flag chargebacks before settlement. Only 0.4% of transactions were positive, a classic case of class imbalance. Standard cross-entropy with inverse-frequency class weights produced 99.1% accuracy and a validation ROC-AUC of 0.91, but at a precision of 0.5, recall on held-out fraud was only 0.38: the model learned to be cautious on rare positives while confidently labeling millions of easy negatives. Analysts still reviewed thousands of false alarms daily. Switching to focal loss (Lin et al., 2017) with γ=2 and α=0.75 down-weighted well-classified negatives, concentrated gradients on hard, misclassified examples, and lifted recall to 0.71 at the same precision operating point, without resampling the training set. Focal loss is now a standard choice for imbalanced vision tasks, one-stage object detection (RetinaNet), and many tabular neural classification heads. This guide covers the math, how γ and α interact, PyTorch and Keras patterns, when to prefer focal loss over class weights or SMOTE, a Harbor Analytics worked example, a method decision table, common pitfalls, and a production checklist.
What focal loss fixes
In balanced classification, cross-entropy assigns similar gradient magnitude to every misclassified example. On imbalanced data, easy negatives dominate: a model that outputs p(fraud)=0.01 on 99.6% of rows contributes tiny per-example loss but enormous aggregate gradient because there are so many of them. The network spends most of each epoch learning to push already-easy negatives even lower instead of fixing the few hard positives and borderline cases that actually matter.
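A back-of-the-envelope sketch makes the imbalance concrete. It assumes hypothetical round numbers (2,390,400 easy negatives scored at p(fraud)=0.01 and 9,600 hard positives scored at p(fraud)=0.3, roughly the Harbor counts used later) and compares each group's summed loss under cross-entropy and under the focal modulating factor with γ=2, leaving α out:

```python
import math

def ce(p_t: float) -> float:
    """Cross-entropy for one example, given the probability of its true class."""
    return -math.log(p_t)

def focal(p_t: float, gamma: float = 2.0) -> float:
    """Focal loss without the alpha term: (1 - p_t)^gamma * CE."""
    return (1.0 - p_t) ** gamma * ce(p_t)

# Hypothetical scores: easy negatives at p(fraud)=0.01 -> p_t = 0.99,
# hard positives at p(fraud)=0.3 -> p_t = 0.3.
n_neg, p_t_neg = 2_390_400, 0.99
n_pos, p_t_pos = 9_600, 0.30

ce_neg, ce_pos = n_neg * ce(p_t_neg), n_pos * ce(p_t_pos)
fl_neg, fl_pos = n_neg * focal(p_t_neg), n_pos * focal(p_t_pos)

print(f"CE:    negatives {ce_neg:,.0f}  positives {ce_pos:,.0f}")
print(f"Focal: negatives {fl_neg:,.1f}  positives {fl_pos:,.0f}")
```

Under cross-entropy the easy negatives carry roughly twice the total loss of all positives combined; with γ=2 their total drops to about 2.4, while the positives keep about half of theirs.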
Focal loss adds a modulating factor (1 − p_t)^γ that shrinks the loss contribution of well-classified examples. When the model is confident and correct (p_t near 1), the factor approaches zero and the example barely updates the weights. When the model is wrong or uncertain (p_t low), the factor stays large and the example drives learning. γ controls how aggressively easy examples are suppressed; γ=0 recovers weighted cross-entropy.
The focal loss formula
For binary classification, define p_t as the model's estimated probability for the true class:
p_t = p if y = 1
p_t = 1 − p if y = 0
Focal loss is:
FL(p_t) = −α_t · (1 − p_t)^γ · log(p_t)
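α_t is a class-balancing weight (typically α for the positive class and 1 − α for the negative class). γ ≥ 0 is the focusing parameter. The paper's defaults for RetinaNet object detection were γ=2 and α=0.25 on the rare foreground class. Note that α=0.25 actually gives positives less weight than negatives: the authors found that α should be lowered slightly as γ increases, which is consistent with γ already suppressing the flood of easy background anchors. Tabular and fraud use cases often end up with a higher α on the minority class (0.5–0.75), so treat α as a hyperparameter to tune rather than a constant to copy.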
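Transcribing the formula into plain Python gives a reference to unit-test a framework implementation against. A minimal sketch, assuming one logit per example and labels in {0, 1}; `focal_loss_from_logit` is an illustrative name, not a library function, and the naive sigmoid is not hardened against extreme logits:

```python
import math

def sigmoid(z: float) -> float:
    # Naive sigmoid; fine for moderate logits, may overflow for z below about -700.
    return 1.0 / (1.0 + math.exp(-z))

def focal_loss_from_logit(z: float, y: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t) for one binary example."""
    p = sigmoid(z)
    p_t = p if y == 1 else 1.0 - p              # probability of the true class
    alpha_t = alpha if y == 1 else 1.0 - alpha  # alpha for positives, 1 - alpha for negatives
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

# With gamma=0 and alpha=0.5 the result is half of ordinary binary cross-entropy.
z, y = 0.8, 1
bce = -math.log(sigmoid(z))
print(focal_loss_from_logit(z, y, alpha=0.5, gamma=0.0), 0.5 * bce)
```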
Multiclass focal loss applies the same modulating factor using softmax probabilities: p_t is the softmax probability assigned to each example's true class, and α_t can be set per class. The key intuition is unchanged: reduce the influence of examples the model already gets right.
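The softmax variant can be sketched the same way. This plain-Python example assumes one integer class label per example; the optional per-class `alpha` list is a hypothetical parameter (with `alpha=None` every class gets weight 1):

```python
import math
from typing import Optional, Sequence

def softmax(logits: Sequence[float]) -> list[float]:
    m = max(logits)                          # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_focal_loss(logits: Sequence[float], target: int,
                       gamma: float = 2.0,
                       alpha: Optional[Sequence[float]] = None) -> float:
    """-alpha_t * (1 - p_t)^gamma * log(p_t), where p_t is the
    softmax probability of the true class."""
    p_t = softmax(logits)[target]
    alpha_t = 1.0 if alpha is None else alpha[target]
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

logits = [2.0, 0.5, -1.0]
confident = softmax_focal_loss(logits, target=0)  # model already favors class 0
wrong = softmax_focal_loss(logits, target=2)      # true class has the lowest score
print(confident, wrong)
```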
How γ and α interact
γ (gamma) controls hard-example emphasis. At γ=0, focal loss equals α-weighted cross-entropy. At γ=2 (the most common starting point), an example with p_t=0.9 contributes 100× less loss than under plain cross-entropy; an example with p_t=0.5 contributes 4× less. Higher γ (3–5) can help when easy negatives absolutely swamp training, but values above 5 often destabilize optimization and produce collapsed logits.
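Because focal loss is cross-entropy multiplied by (1 − p_t)^γ (with α held fixed), the reduction relative to cross-entropy is 1 / (1 − p_t)^γ. A quick plain-Python check of the figures above:

```python
# How many times smaller the focal loss is than cross-entropy (same alpha).
def reduction_vs_ce(p_t: float, gamma: float) -> float:
    return 1.0 / (1.0 - p_t) ** gamma

for gamma in (0, 1, 2, 5):
    row = [round(reduction_vs_ce(p_t, gamma), 1) for p_t in (0.5, 0.9, 0.99)]
    print(f"gamma={gamma}: p_t=0.5/0.9/0.99 -> {row}x less loss than CE")
```

At γ=5 an example with p_t=0.99 is scaled down by a factor of about 10^10, one way to see why very large γ can starve training of gradient signal.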
α (alpha) reweights each class's contribution to the loss, similar to `class_weight` in scikit-learn or `weight` in PyTorch cross-entropy. It addresses imbalance at the class level; γ addresses imbalance at the level of example difficulty. You usually need both on severely skewed data: α sets each class's baseline contribution, while γ keeps the sheer number of easy majority-class examples from drowning out the gradient.
Tune γ and α on a validation set using precision-recall metrics, not accuracy. Plot PR curves at several (γ, α) pairs; the operating point your business cares about (e.g. recall ≥ 0.65 at precision ≥ 0.40) should drive selection.
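Selecting by an operating point means computing recall subject to a precision floor rather than a single global score. A minimal plain-Python sketch, assuming validation scores and 0/1 labels as lists; it sweeps every distinct score as a threshold and returns the best recall whose precision meets the floor. `recall_at_precision` is an illustrative name, and the quadratic loop is for clarity only (on real data, `sklearn.metrics.precision_recall_curve` gives the same curve efficiently):

```python
from typing import Sequence

def recall_at_precision(scores: Sequence[float], labels: Sequence[int],
                        min_precision: float) -> float:
    """Highest recall at any threshold whose precision >= min_precision."""
    n_pos = sum(labels)
    if n_pos == 0:
        return 0.0
    best = 0.0
    for t in sorted(set(scores)):
        # Flag every example scoring at or above the candidate threshold.
        tp = sum(1 for s, y in zip(scores, labels) if s >= t and y == 1)
        fp = sum(1 for s, y in zip(scores, labels) if s >= t and y == 0)
        if tp + fp and tp / (tp + fp) >= min_precision:
            best = max(best, tp / n_pos)
    return best
```

For the business target in the text, the acceptance check would be whether `recall_at_precision(val_scores, val_labels, 0.40)` reaches 0.65, where `val_scores` and `val_labels` are your (hypothetical) validation outputs for one (γ, α) pair.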
PyTorch and TensorFlow patterns
PyTorch does not ship focal loss in `torch.nn`, but `torchvision.ops.sigmoid_focal_loss` implements the binary version used in detection. For general classification:
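```python
import torch
import torch.nn.functional as F

def binary_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    # targets: float tensor of 0/1 labels with the same shape as logits
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')  # -log(p_t), computed stably
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)              # probability of the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # alpha for positives, 1 - alpha for negatives
    focal = alpha_t * (1 - p_t) ** gamma * bce
    return focal.mean()                                      # mean over the batch (classification)
```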
Apply α_t per class by weighting the positive and negative branches separately (α for y = 1, 1 − α for y = 0). In Keras, community implementations wrap `binary_crossentropy` with the same modulating factor; TensorFlow Addons historically shipped `tfa.losses.SigmoidFocalCrossEntropy`, and recent Keras releases include `keras.losses.BinaryFocalCrossentropy`. Always verify the reduction mode: detection training often sums over all anchors and then normalizes by the number of positives, while classification usually takes the mean over the batch.
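The two reduction conventions differ only in the denominator. A minimal plain-Python sketch, assuming per-anchor (or per-example) focal losses and 0/1 targets are already computed; the mode names are illustrative, and clamping the positive count to at least 1 is a common guard for images with no positives:

```python
from typing import Sequence

def reduce_focal(losses: Sequence[float], targets: Sequence[int], mode: str) -> float:
    total = sum(losses)
    if mode == "mean":                        # classification: average over the batch
        return total / len(losses)
    if mode == "sum_over_positives":          # detection: sum over anchors / number of positives
        return total / max(1, sum(targets))   # avoid dividing by zero when no anchor is positive
    raise ValueError(f"unknown mode: {mode}")

losses = [0.5, 0.1, 0.1, 0.1]
targets = [1, 0, 0, 0]
print(reduce_focal(losses, targets, "mean"), reduce_focal(losses, targets, "sum_over_positives"))
```

With tens of thousands of anchors and only a few positives, dividing by the anchor count instead of the positive count shrinks the loss by orders of magnitude, which changes the effective learning rate.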
When using mixed precision, compute focal loss in float32 even if the forward pass runs in float16: very small p_t values can underflow to zero in half precision, and the (1 − p_t)^γ term with γ=2 amplifies numerical error from low-precision logits.
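Python's `struct` module can round values through IEEE 754 half precision (format code `'e'`) and single precision (`'f'`), which makes the underflow visible without a GPU. This sketch assumes a hypothetical badly misclassified example with p_t = 1e-8; real mixed-precision behavior also depends on which ops your framework's autocast already runs in float32:

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack('e', struct.pack('e', x))[0]

def to_fp32(x: float) -> float:
    """Round a Python float to IEEE 754 single precision and back."""
    return struct.unpack('f', struct.pack('f', x))[0]

# Hypothetical badly misclassified example: true-class probability of 1e-8.
p_t = 1e-8
half = to_fp16(p_t)    # below the smallest fp16 subnormal (~6e-8), so it rounds to 0.0
single = to_fp32(p_t)  # comfortably representable in float32

print(half, single)
# -log(0) is undefined (math.log raises; tensor libraries return inf),
# whereas the float32 value gives a finite loss of about 18.4.
print(-math.log(single))
```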
RetinaNet and object detection
Focal loss was introduced to fix extreme foreground-background imbalance in one-stage object detectors. A single 800×600 image can yield 100,000 anchor boxes with only a handful of positive (object) matches. Without focal loss, cross-entropy training collapses: the detector learns to predict background everywhere. RetinaNet paired focal loss with a feature pyramid network and matched two-stage detector accuracy at faster inference.
Detection typically uses sigmoid focal loss per class, treating each class at each anchor as an independent binary decision (multi-label style) rather than applying a softmax. No separate background class is needed: an anchor is background when every class score is low. The same principle applies to semantic segmentation with heavy class skew (rare defect pixels in industrial vision) and to CNN heads in medical imaging, where lesion pixels are a tiny fraction of the image.
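The per-class sigmoid formulation can be sketched for a single anchor: each of K classes gets its own sigmoid and binary focal term, and a background anchor simply has an all-zero target vector. A plain-Python sketch assuming the RetinaNet defaults α=0.25 and γ=2; `anchor_focal_loss` is an illustrative name, not torchvision's API:

```python
import math
from typing import Optional, Sequence

def anchor_focal_loss(logits: Sequence[float], gt_class: Optional[int],
                      alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Sum of per-class sigmoid focal losses for one anchor.
    gt_class=None marks a background anchor: every class target is 0."""
    loss = 0.0
    for k, z in enumerate(logits):
        y = 1 if k == gt_class else 0
        p = 1.0 / (1.0 + math.exp(-z))           # independent sigmoid per class
        p_t = p if y == 1 else 1.0 - p
        alpha_t = alpha if y == 1 else 1.0 - alpha
        loss += -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)
    return loss

# Background anchor whose class scores are all low: nearly zero loss.
easy_bg = anchor_focal_loss([-5.0, -6.0, -4.5], gt_class=None)
# Anchor that contains class 1 but scores it low: much larger loss.
missed = anchor_focal_loss([-5.0, -3.0, -4.5], gt_class=1)
print(easy_bg, missed)
```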
Harbor Analytics chargeback classifier
Harbor's dataset: 2.4M card transactions, 9,600 chargebacks (0.4% positive). Features: merchant category, amount z-score, device fingerprint cluster, velocity counters. Baseline: ResNet-style MLP with weighted cross-entropy (weights 1 and 249). Training accuracy 99.1%, PR-AUC 0.44, recall@precision=0.5 was 0.38.
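The baseline's class weights follow directly from the dataset counts, and the same counts give the accuracy of a model that never flags anything. A short arithmetic check using only the figures stated above:

```python
total = 2_400_000
positives = 9_600
negatives = total - positives

prevalence = positives / total             # fraction of chargebacks
weight_ratio = negatives / positives       # inverse-frequency weight for positives (negatives = 1)
all_negative_accuracy = negatives / total  # accuracy of a model that never flags fraud

print(f"prevalence={prevalence:.3%}, positive weight={weight_ratio:.0f}, "
      f"all-negative accuracy={all_negative_accuracy:.1%}")
```

A classifier that never flags fraud scores 99.6% accuracy, higher than both the baseline's 99.1% and the focal model's 97.2%, which is why accuracy is the wrong yardstick here.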
After switching to focal loss (γ=2, α=0.75 on positives), training accuracy dropped to 97.2%, which is expected because the model stopped chasing easy negatives; PR-AUC rose to 0.58, and recall@precision=0.5 reached 0.71. The team also applied label smoothing with ε=0.05 to the positive class only, which trimmed overconfident false positives on borderline merchants. The production threshold was moved to 0.35 based on a validation cost matrix (a false negative costs 8× as much as a false positive). No SMOTE or oversampling was used: focal loss plus light label smoothing was sufficient for this architecture.
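Choosing a threshold from a cost matrix amounts to minimizing total misclassification cost on validation data. If the scores were calibrated probabilities, the cost-minimizing cutoff would be fp_cost / (fp_cost + fn_cost) = 1/9 for an 8:1 ratio; scores from an α-weighted, focal-modulated loss are not guaranteed to be calibrated, which is one reason to search empirically instead. A minimal plain-Python sketch on a hypothetical toy validation set (not Harbor data); the function names are illustrative:

```python
from typing import Sequence

def expected_cost(scores: Sequence[float], labels: Sequence[int], threshold: float,
                  fn_cost: float = 8.0, fp_cost: float = 1.0) -> float:
    """Total misclassification cost when flagging every score >= threshold."""
    fp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 0)
    fn = sum(1 for s, y in zip(scores, labels) if s < threshold and y == 1)
    return fn_cost * fn + fp_cost * fp

def pick_threshold(scores: Sequence[float], labels: Sequence[int],
                   fn_cost: float = 8.0, fp_cost: float = 1.0) -> float:
    """Return the candidate threshold (an observed score) with the lowest total cost."""
    candidates = sorted(set(scores))
    return min(candidates, key=lambda t: expected_cost(scores, labels, t, fn_cost, fp_cost))

# Toy validation set (hypothetical scores).
val_scores = [0.05, 0.2, 0.3, 0.4, 0.6, 0.9]
val_labels = [0, 0, 1, 0, 1, 1]
print(pick_threshold(val_scores, val_labels))  # prints 0.3
```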
Focal loss vs alternatives
| Technique | What it does | Best when |
|---|---|---|
| Class-weighted CE | Scales loss by inverse frequency | Mild imbalance (10:1), linear models, quick baseline |
| SMOTE / oversampling | Synthetic minority examples | Tabular + tree models; risky on high-dim vision without care |
| Focal loss | Down-weights easy examples dynamically | Deep nets, detection, severe imbalance, hard-example mining |
| Label smoothing | Softens targets | Overconfidence / calibration; combine with focal at low ε |
| Threshold tuning | Post-hoc decision boundary | Always — no loss replaces picking the right cutoff |
See the broader loss functions guide for Huber, hinge, and ranking objectives. Focal loss is not a substitute for clean labels, stratified splits, or proper cross-validation.
Common pitfalls
- γ too high. Values above 5 often zero out most gradients and stall learning; start at 2.
- Ignoring α. γ alone does not fix class prior skew; set α from prevalence or tune on validation.
- Optimizing accuracy. Focal loss will lower accuracy while improving PR-AUC — that is often correct.
- Double-counting imbalance. Using focal loss + heavy SMOTE + extreme class weights can over-correct.
- Wrong reduction in detection. Normalizing by batch size instead of positive count skews anchor training.
- FP16 without loss scaling. Small p_t values underflow; keep loss computation in float32.
- Skipping threshold tuning. Focal loss changes score distributions; re-derive the production cutoff.
- Applying to regression. Focal loss is for classification; use Huber or quantile loss for continuous targets.
Production checklist
- Establish baseline PR-AUC and recall@business-precision with weighted cross-entropy.
- Grid-search γ ∈ {0, 1, 2, 3} and α ∈ {0.25, 0.5, 0.75} on stratified validation only.
- Plot PR curves for top three (γ, α) pairs; pick by cost matrix, not AUC alone.
- Verify focal implementation against a reference (torchvision unit test or paper appendix).
- Compute loss in float32 when using automatic mixed precision.
- Log per-epoch minority-class recall and majority-class specificity separately.
- Re-tune decision threshold after switching loss; do not reuse the old 0.5 cutoff.
- Ablation: focal only vs focal + label smoothing vs focal + light oversampling.
- Monitor score distribution drift in production; refit threshold quarterly.
- Document γ, α, and threshold in the model card for reproducibility.
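Logging minority-class recall and majority-class specificity separately keeps a gain on one class from hiding a loss on the other. A minimal sketch at a fixed threshold, assuming scores and 0/1 labels as lists; the function name and the toy data are illustrative, and 0.35 is the Harbor threshold from the worked example:

```python
from typing import Sequence, Tuple

def recall_and_specificity(scores: Sequence[float], labels: Sequence[int],
                           threshold: float) -> Tuple[float, float]:
    """Return (recall on y=1, specificity on y=0) when flagging s >= threshold."""
    tp = sum(1 for s, y in zip(scores, labels) if y == 1 and s >= threshold)
    fn = sum(1 for s, y in zip(scores, labels) if y == 1 and s < threshold)
    tn = sum(1 for s, y in zip(scores, labels) if y == 0 and s < threshold)
    fp = sum(1 for s, y in zip(scores, labels) if y == 0 and s >= threshold)
    recall = tp / (tp + fn) if tp + fn else float("nan")
    specificity = tn / (tn + fp) if tn + fp else float("nan")
    return recall, specificity

# Hypothetical per-epoch validation scores.
scores = [0.9, 0.3, 0.4, 0.2, 0.1]
labels = [1, 1, 0, 0, 0]
print(recall_and_specificity(scores, labels, threshold=0.35))
```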
Key takeaways
- Focal loss down-weights easy examples via (1 − p_t)^γ, focusing gradients on hard misclassifications.
- γ=2 and tuned α are strong defaults for imbalanced deep classifiers; detection often uses α=0.25, fraud may need α=0.75.
- Complements class weights rather than replacing them — α handles class frequency, γ handles example difficulty.
- RetinaNet proved it at scale for object detection; the same math applies to fraud, medical screening, and defect vision.
- Always re-tune thresholds and evaluate with precision-recall, not accuracy.
Related reading
- Cross-entropy explained — the base loss focal loss modulates
- Class imbalance explained — SMOTE, weights, and evaluation metrics
- Loss functions explained — choosing objectives across task types
- Label smoothing explained — complementary calibration technique