JAX fundamentals explained
Your research notebook runs fine on a laptop CPU, but scaling the same NumPy
code to eight TPU cores requires a rewrite in another framework — unless you
start in JAX. Google Research built JAX as a minimal layer over
XLA: write plain array code, then compose functional transformations for
gradients, compilation, batching, and multi-device parallelism. It powers
AlphaFold, PaLM pretraining experiments, and a growing slice of the Hugging
Face ecosystem via Flax. This guide covers JAX arrays and devices, the four
core transforms (grad, jit, vmap, pmap), a Flax + Optax training loop, a
Harbor Fleet demand forecaster worked example, a framework decision table,
common pitfalls, and a practitioner checklist. Read it alongside our PyTorch
fundamentals guide, TensorFlow fundamentals guide, and gradient descent guide.
What JAX is (and is not)
JAX is not a full deep-learning framework like PyTorch or TensorFlow.
It is a numerical computing library with a NumPy-compatible API
(jax.numpy) plus composable function transformations. You
typically pair JAX with Flax (neural network modules),
Optax (optimizers), and Orbax (checkpointing)
for end-to-end training, similar to how many PyTorch users reach for
higher-level libraries such as Hugging Face Transformers rather than raw torch.nn alone.
Three properties define JAX's niche:
- Pure functions: transforms assume no hidden state or side effects; randomness flows through explicit PRNG keys.
- XLA compilation: jit traces Python functions and compiles them with XLA into fused accelerator kernels, often competitive with hand-tuned CUDA.
- Composable transforms: jit(vmap(grad(f))) (compiled per-example gradients) is valid; order matters and is part of the design.
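JAX arrays are immutable. In-place item assignment (x[0] = 1) raises a
TypeError; use the functional form x = x.at[0].set(1), which returns a new
array. This is a cultural shock for PyTorch users accustomed to mutating tensors
in place, although under jit XLA can often perform such updates in place when
the old buffer is no longer used.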
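A minimal sketch of the difference, using only the documented `.at[...]` update API. The variable names are illustrative; the point is that `x` keeps its original values after every update.

```python
import jax.numpy as jnp

x = jnp.zeros(3)

# Item assignment on a JAX array raises TypeError because arrays are immutable.
try:
    x[0] = 1.0
    mutated = True
except TypeError:
    mutated = False

# The functional update API returns a new array; x itself is unchanged.
y = x.at[0].set(1.0)
z = y.at[1:].add(5.0)   # other update ops include .add, .mul, .min, .max

print(x)  # [0. 0. 0.]
print(y)  # [1. 0. 0.]
print(z)  # [1. 5. 5.]
```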
Arrays, dtypes, and devices
Import jax.numpy as jnp for NumPy-style operations. New arrays land on the
default device: x = jnp.ones((4, 4), dtype=jnp.float32) is placed on the first
accelerator (GPU or TPU) if one is available, else on the CPU. Use
jax.device_put(x, jax.devices('gpu')[0]) to move data explicitly.
jax.devices() lists the devices of the default backend;
jax.process_count() reports the number of host processes in multi-host setups such as TPU pods.
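A short sketch of inspecting and placing data on devices. It picks `jax.devices()[0]` so it also runs on a CPU-only machine; on a GPU host you would pass `jax.devices('gpu')[0]` instead, which raises an error when no GPU backend is present.

```python
import jax
import jax.numpy as jnp

print(jax.devices())          # visible devices, e.g. a single CPU device on a laptop
print(jax.default_backend())  # 'cpu', 'gpu', or 'tpu'
print(jax.process_count())    # 1 unless running a multi-host job

x = jnp.ones((4, 4), dtype=jnp.float32)    # placed on the default device
target = jax.devices()[0]                  # choose a device explicitly
x_on_target = jax.device_put(x, target)    # copies, or is a no-op if already there
print(x_on_target.devices())               # set of devices holding the array
```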
Random numbers
JAX has no global RNG state. Every random function takes an explicit key and
returns samples; the key itself is never updated, so you derive fresh keys with
jax.random.split:
key, subkey = jax.random.split(key); noise = jax.random.normal(subkey, shape)
Inside jit-compiled code, pass keys in as function arguments and split them, or
derive per-step keys with jax.random.fold_in(key, step). Reusing one key (for
example, a key created once and captured as a constant) produces identical
"random" batches every epoch, a common first-week bug.
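The sketch below shows both the bug and the fix: the same key always yields the same samples, while split or folded-in keys yield fresh ones. `noisy_step` is a hypothetical helper used only for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)

# Reusing the same key gives the same "random" numbers.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# Split to get a fresh subkey; keep the new `key` for future splits.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))

@jax.jit
def noisy_step(key, step, x):
    # Derive a per-step key from a base key and a (traced) step counter.
    step_key = jax.random.fold_in(key, step)
    return x + 0.1 * jax.random.normal(step_key, x.shape)

x = jnp.zeros(3)
out0 = noisy_step(key, 0, x)
out1 = noisy_step(key, 1, x)   # different step -> different noise, no recompilation
```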
Dtypes and precision
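JAX defaults to float32 on every backend; float64 requires
jax.config.update('jax_enable_x64', True). The flag
jax.config.update('jax_default_matmul_precision', 'bfloat16') controls the
internal precision of float32 matrix multiplies, not storage dtypes; for true
mixed precision, cast parameters or activations to jnp.bfloat16 explicitly.
TPU hardware favors bfloat16. float16 on NVIDIA GPUs needs loss scaling, which
Optax does not provide on its own (DeepMind's jmp library offers loss-scale
utilities); Ampere and newer GPUs also support bfloat16, which usually avoids
loss scaling.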
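A small sketch of the default dtypes and an explicit bfloat16 cast. It assumes `jax_enable_x64` has not been turned on; with x64 enabled, the float64 input below would stay float64. The per-call `precision=` argument is shown as an alternative to changing the global matmul-precision flag.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.ones((2, 2))
print(x.dtype)                       # float32 by default

# float64 NumPy input is converted to float32 while x64 is disabled.
y = jnp.asarray(np.ones(2, dtype=np.float64))
print(y.dtype)                       # float32

# Explicit bfloat16 cast for storage/compute (e.g. activations).
x_bf16 = x.astype(jnp.bfloat16)
w = jnp.ones((2, 2))
out_bf16 = jnp.matmul(x_bf16, w.astype(jnp.bfloat16))

# Request full float32 matmul precision for a single call.
out_fp32 = jnp.matmul(x, w, precision=jax.lax.Precision.HIGHEST)
```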
The four core transformations
grad — automatic differentiation
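grad(f) returns a function that computes the gradient of a scalar-valued f
with respect to its first argument (select others with argnums). For vector
outputs, use jax.vjp (reverse-mode) or jax.jvp (forward-mode), or
jax.jacrev / jax.jacfwd for full Jacobians.
Higher-order derivatives compose naturally: grad(grad(f)) gives the second
derivative of a scalar function, and jax.jvp(grad(f), (x,), (v,)) gives
Hessian-vector products for meta-learning and second-order methods.
Unlike PyTorch's dynamic autograd graph, JAX traces the function into an
intermediate representation (a jaxpr) and differentiates that trace. Under grad
alone, ordinary Python control flow works because values are concrete. Under
jit, a Python if on an array value raises a ConcretizationTypeError because the
value is unknown at trace time; use jax.lax.cond, jnp.where, or loop primitives
such as jax.lax.scan and jax.lax.fori_loop instead.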
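A minimal sketch of the points above: first and second derivatives, a forward-over-reverse Hessian-vector product, and why a Python `if` fails under `jit` while `jax.lax.cond` works. The functions `f`, `loss`, `hvp`, and the two relu variants are illustrative.

```python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 3                          # scalar -> scalar

df = jax.grad(f)                           # 3x^2
d2f = jax.grad(jax.grad(f))                # 6x
print(df(2.0), d2f(2.0))                   # 12.0 12.0

def loss(w):
    return jnp.sum(w ** 2) + w[0] * w[1]   # Hessian is [[2, 1], [1, 2]]

def hvp(f, w, v):
    # Forward-over-reverse: the JVP of grad(f) at w along v equals H(w) @ v.
    return jax.jvp(jax.grad(f), (w,), (v,))[1]

hv = hvp(loss, jnp.array([1.0, 2.0]), jnp.array([1.0, 0.0]))
print(hv)                                  # [2. 1.]

@jax.jit
def relu_python_if(x):
    if x > 0:                              # fails: x is an abstract tracer under jit
        return x
    return 0.0 * x

try:
    relu_python_if(3.0)
    python_if_ok = True
except jax.errors.ConcretizationTypeError:
    python_if_ok = False

@jax.jit
def relu_cond(x):
    # lax.cond selects a branch at run time from a traced predicate.
    return jax.lax.cond(x > 0, lambda v: v, lambda v: 0.0 * v, x)
```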
jit — XLA compilation
@jax.jit compiles a function to an XLA executable on first call (the "warmup"
step). Subsequent calls with the same input shapes, dtypes, and static
arguments reuse the cached binary. Shape changes trigger recompilation, so
batch-size sweeps during hyperparameter search can be slow unless you pad to
fixed shapes.
Use jax.block_until_ready(out) when benchmarking; JAX dispatches
asynchronously, so naive time.time() loops measure how fast work is enqueued,
not kernel time.
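A benchmarking sketch that separates the compile-plus-run first call from steady-state calls and waits on the device before stopping the clock. The function `step` and the matrix size are arbitrary, and printed timings depend entirely on your hardware.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    return jnp.tanh(x @ x.T).sum()

x = jax.random.normal(jax.random.PRNGKey(0), (1024, 1024))

t0 = time.perf_counter()
jax.block_until_ready(step(x))           # first call: trace + compile + run
first_call = time.perf_counter() - t0

t0 = time.perf_counter()
for _ in range(10):
    out = step(x)                        # returns immediately (async dispatch)
jax.block_until_ready(out)               # wait for the device to finish the queued work
steady_state = (time.perf_counter() - t0) / 10

print(f"first call {first_call:.3f}s, steady state {steady_state * 1e3:.2f} ms/call")
```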
vmap — automatic batching
vmap(f) maps f over a batch axis (the leading axis by default; in_axes and
out_axes choose others) without writing explicit loops. A single-example forward pass becomes a batched
forward pass — the same function serves inference and training. Combine with
grad for per-example gradients (useful in differential privacy
and some meta-learning setups).
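A sketch of both uses on a toy linear model: `in_axes=(None, 0)` shares the weights while mapping over examples, and wrapping `jax.grad` in `vmap` yields one gradient row per example. The model and data are made up for illustration.

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Written for a single example: x has shape (features,).
    return jnp.dot(w, x)

def loss(w, x, y):
    return (predict(w, x) - y) ** 2

w = jnp.array([0.5, -1.0, 2.0])
xs = jnp.arange(12.0).reshape(4, 3)      # batch of 4 examples
ys = jnp.array([1.0, 2.0, 3.0, 4.0])

# Batched forward pass: map over axis 0 of xs, broadcast w.
preds = jax.vmap(predict, in_axes=(None, 0))(w, xs)           # shape (4,)

# Per-example gradients: shape (batch, features).
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(w, xs, ys)
print(per_example_grads.shape)                                # (4, 3)
```

Averaging the rows recovers the ordinary gradient of the mean loss; keeping them separate is what enables per-example clipping in differential privacy.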
pmap — data parallelism
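pmap(f) replicates f across local devices, shards inputs along the leading
axis, and runs SPMD (single program, multiple data). Gradients are not
averaged automatically: call jax.lax.pmean (or psum) with the pmap axis_name
inside the step to all-reduce them. On multi-host TPU slices every host runs the
same program (after jax.distributed.initialize() where required), and
jax.experimental.multihost_utils provides helpers for syncing and gathering
across hosts. For finer control in JAX 0.4+, migrate to jax.jit with
jax.sharding annotations (GSPMD), which recent JAX documentation recommends
for new code.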
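A data-parallel sketch with an explicit gradient all-reduce. It runs on however many local devices exist (one on a plain CPU machine); to simulate several CPU devices you can set `XLA_FLAGS=--xla_force_host_platform_device_count=8` before importing JAX. The toy regression and learning rate are illustrative.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

def train_step(w, x, y):
    grads = jax.grad(loss)(w, x, y)
    # pmap does not average gradients for you: all-reduce over the named axis.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return w - 0.1 * grads

p_train_step = jax.pmap(train_step, axis_name="devices")

w_repl = jnp.zeros((n_dev, 3))       # parameters replicated once per device
x = jnp.ones((n_dev, 8, 3))          # each device gets a shard of 8 examples
y = jnp.ones((n_dev, 8))
w_repl = p_train_step(w_repl, x, y)  # every replica applies the same averaged update
```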
Flax modules and an Optax training loop
Flax models subclass nn.Module and declare submodules such as nn.Dense and
nn.LayerNorm either in setup() or inline in a method decorated with
@nn.compact. Initialization is explicit:
params = model.init(rng, dummy_input)['params']
A minimal training step in pure JAX style:
- Define loss_fn(params, batch) returning a scalar.
- Wrap with grad_loss = jax.grad(loss_fn).
- Create an Optax optimizer: tx = optax.adamw(learning_rate); opt_state = tx.init(params).
- Inside @jax.jit, compute grads = grad_loss(params, batch), then updates, opt_state = tx.update(grads, opt_state, params), then params = optax.apply_updates(params, updates).
Flax's TrainState dataclass bundles params, optimizer state, and
step counter — the idiomatic pattern in official examples. Checkpoint with
Orbax: orbax.checkpoint.PyTreeCheckpointer saves the full pytree
(nested dict of arrays) portably.
For PyTorch-like mutable modules, Flax NNX (flax.nnx, introduced in 2024) is
the newer API; Hugging Face's FlaxAutoModel classes load pretrained transformer
weights. The ecosystem is smaller than PyTorch's but sufficient for research
reproduction and TPU-scale pretraining.
GPU, TPU, and sharding
On a single NVIDIA GPU, JAX behaves like accelerated NumPy once functions are
jit-compiled. Multi-GPU on one host uses pmap over
jax.local_devices() — typically 2–8 GPUs per machine.
TPU pods are JAX's home turf. Colab TPU runtimes have historically exposed 8
TPU cores, and pmap across those 8 cores is the standard
tutorial pattern. For large models, sharding via
jax.sharding.NamedSharding and jax.experimental.mesh_utils
partitions parameters across devices (tensor parallelism) instead of only
batching inputs (data parallelism).
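A sketch of combining both strategies with `NamedSharding`: the batch axis is split over a "data" mesh axis and the weight's output axis over a "model" axis. The mesh here puts every device on "data" and uses a size-1 "model" axis so the code runs on a single device; on real hardware you would choose a 2-D shape (for example with `jax.experimental.mesh_utils.create_device_mesh`) whose sizes divide the array dimensions.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = len(jax.devices())
# 2-D mesh: all devices on the "data" axis, a size-1 "model" axis.
mesh = Mesh(np.array(jax.devices()).reshape(n, 1), axis_names=("data", "model"))

x_sharding = NamedSharding(mesh, P("data", None))   # split batch rows: data parallelism
w_sharding = NamedSharding(mesh, P(None, "model"))  # split output features: tensor parallelism

x = jax.device_put(jnp.arange(n * 4 * 3, dtype=jnp.float32).reshape(n * 4, 3), x_sharding)
w = jax.device_put(jnp.ones((3, 2), dtype=jnp.float32), w_sharding)

@jax.jit
def forward(x, w):
    # jit propagates the input shardings (GSPMD) and inserts any needed collectives.
    return x @ w

y = forward(x, w)
print(y.sharding)
```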
JAX does not ship a production serving stack comparable to TorchServe or TF Serving. Common routes are converting to a TensorFlow SavedModel with jax2tf (then serving with TF Serving, or converting onward, for example to ONNX with tf2onnx) or deploying Python workers behind batching proxies. Research and training are the sweet spot; inference at scale often converts checkpoints to another runtime.
Worked example: Harbor Fleet demand forecaster
Harbor Fleet operates 120 delivery vans across three cities. Operations wants a 7-day ahead demand forecast per depot to right-size staffing. A data scientist builds a small temporal model in JAX + Flax:
- Data: daily order counts per depot, 18 months of history; features include day-of-week, holidays, and lag-7 values, stored as a jnp.ndarray of shape (depots, days, features).
- Model: a 2-layer GRU implemented in Flax (nn.GRUCell unrolled with jax.lax.scan for jit compatibility), outputting scalar demand per depot-day.
- Loss: Huber loss on log-transformed counts, robust to spike days (Black Friday) without exploding gradients.
- Training: batches of 64 depots via vmap over the depot axis inside a jit-compiled step; AdamW with lr=1e-3, 200 epochs, PRNG key split per epoch for dropout masks.
- Evaluation: MAPE on the held-out last 30 days; 8.2% average vs 11.4% for their legacy Prophet baseline.
- Deployment: a nightly cron job on a TPU v4-8 slice retrains in 4 minutes; exported NumPy weights feed a lightweight CPU inference script for the ops dashboard.
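The following sketch shows the scan-plus-vmap structure of such a model. It is not the team's code: it swaps Flax's nn.GRUCell for a hand-rolled single-layer GRU cell in pure JAX, trims the model from two layers to one, and uses random toy features with a constant count of 120 orders. The parameter names (`W`, `U`, `b`, `head`) are hypothetical.

```python
import jax
import jax.numpy as jnp

def init_params(key, n_features, n_hidden):
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "W": 0.1 * jax.random.normal(k1, (n_features, 3 * n_hidden)),  # input weights (z, r, candidate)
        "U": 0.1 * jax.random.normal(k2, (n_hidden, 3 * n_hidden)),    # recurrent weights
        "b": jnp.zeros(3 * n_hidden),
        "head": 0.1 * jax.random.normal(k3, (n_hidden,)),              # hidden state -> scalar
    }

def gru_step(params, h, x_t):
    n = h.shape[-1]
    xw = x_t @ params["W"] + params["b"]
    hu = h @ params["U"]
    z = jax.nn.sigmoid(xw[:n] + hu[:n])                # update gate
    r = jax.nn.sigmoid(xw[n:2 * n] + hu[n:2 * n])      # reset gate
    h_cand = jnp.tanh(xw[2 * n:] + r * hu[2 * n:])     # candidate state
    h_new = (1.0 - z) * h + z * h_cand
    return h_new, h_new @ params["head"]               # (carry, per-day prediction)

def forecast_depot(params, series):
    # series: (days, features) for ONE depot; lax.scan loops over days inside jit.
    h0 = jnp.zeros(params["U"].shape[0])
    _, preds = jax.lax.scan(lambda h, x_t: gru_step(params, h, x_t), h0, series)
    return preds                                       # (days,) predicted log demand

def huber(err, delta=1.0):
    abs_err = jnp.abs(err)
    quad = jnp.minimum(abs_err, delta)
    return 0.5 * quad ** 2 + delta * (abs_err - quad)  # quadratic near 0, linear in the tails

@jax.jit
def loss_fn(params, features, counts):
    # features: (depots, days, features); counts: (depots, days) raw order counts
    preds = jax.vmap(forecast_depot, in_axes=(None, 0))(params, features)
    return jnp.mean(huber(preds - jnp.log1p(counts)))

k_init, k_data = jax.random.split(jax.random.PRNGKey(0))
params = init_params(k_init, n_features=5, n_hidden=16)
features = jax.random.normal(k_data, (64, 28, 5))     # toy data: 64 depots x 28 days
counts = jnp.full((64, 28), 120.0)
loss, grads = jax.value_and_grad(loss_fn)(params, features, counts)
```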
The team chose JAX because their Google Cloud contract includes TPU quota and the same code runs on a local GPU for debugging. The functional style forced explicit RNG handling — annoying at first, but it eliminated the "why did validation loss jump on rerun?" bugs they had with implicit global seeds.
Framework decision table
| Need | JAX | PyTorch | TensorFlow |
|---|---|---|---|
| Research flexibility, imperative style | Functional; steeper learning curve | Best fit; eager by default | Keras 3 eager mode |
| TPU training at scale | First-class pmap / sharding | Supported via PyTorch/XLA | Native TPU integration |
| LLM / transformer ecosystem | Flax + Hugging Face Flax models | Hugging Face, vLLM; largest hub | Smaller hub |
| Custom gradients / meta-learning | Composable grad, vjp, custom_vjp | torch.autograd.Function | tf.custom_gradient |
| Production serving | Export via jax2tf (SavedModel) or wrap in Python | TorchServe, TensorRT, Triton | TF Serving, TFLite |
| Mobile / edge | Limited native tooling | ExecuTorch, ONNX | TFLite mature on Android |
Choose JAX when you need TPU-scale training, composable transforms for research (differentiable physics, meta-learning, Bayesian methods), or your team already writes NumPy and wants accelerators without learning a new tensor API. Choose PyTorch for the broadest ecosystem and production path.
Common pitfalls
- Python control flow on traced values: inside jit, a plain if x > 0 on arrays raises a ConcretizationTypeError; use jax.lax.cond, jnp.where, or jax.lax.scan.
- Reusing PRNG keys: produces identical noise every call; always split keys.
- Measuring time without block_until_ready: async dispatch makes naive timers lie.
- Shape-changing recompilation: variable batch sizes in production cause constant re-jit; pad to fixed shapes.
- In-place mutation: x[0] = 1 raises an error, and x.at[0].set(1) returns a new array while the old reference is unchanged.
- Silent dtype changes: with jax_enable_x64 on, mixing float64 NumPy arrays with float32 JAX arrays promotes results to float64, which is slow on GPUs and very slow on TPUs; with x64 off (the default), float64 inputs are silently truncated to float32. Cast explicitly either way.
- Debugging inside jit: print shows tracers, not values; use jax.debug.print, or disable jit with jax.disable_jit() while debugging.
- Host-device sync in tight loops: calling np.array(jax_array) every step kills throughput.
- Assuming pmap means multi-node: pmap spans only local devices unless you configure multi-host initialization.
- Skipping warmup: the first jit call compiles; benchmark the second call onward.
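A sketch of the padding fix for shape-driven recompilation. It relies on the fact that Python side effects inside a jitted function run only while tracing, so a counter shows how many compilations happened. `pad_batch`, the padded size of 8, and the masked mean-of-squares are hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def masked_mean_sq(x, mask):
    global trace_count
    trace_count += 1   # Python side effect: runs only while tracing, not on cached calls
    return jnp.sum(mask * x ** 2) / jnp.maximum(jnp.sum(mask), 1.0)

def pad_batch(x, size):
    # Pad a 1-D batch to a fixed length and build a 0/1 validity mask.
    n = x.shape[0]
    padded = jnp.zeros(size, x.dtype).at[:n].set(x)
    mask = (jnp.arange(size) < n).astype(x.dtype)
    return padded, mask

results = []
for n in (3, 5, 7):                      # variable true batch sizes
    batch = jnp.arange(1.0, n + 1.0)
    padded, mask = pad_batch(batch, size=8)
    results.append(masked_mean_sq(padded, mask))

print(trace_count)   # 1: every call had shape (8,), so the cached executable was reused
```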
Practitioner checklist
- Pin jax, jaxlib, and CUDA/TPU driver versions; wheels are hardware-specific.
- Structure code as pure functions taking params and batch; keep side effects outside jit.
- Thread PRNG keys explicitly through every stochastic operation.
- Run a CPU-only smoke test in CI with JAX_PLATFORMS=cpu.
- Profile with jax.profiler or TensorBoard after warmup, not on the first call.
- Use Orbax for reproducible checkpoint saves; Flax's older flax.training.checkpoints module is being phased out in its favor.
- Validate exported predictions against a non-jit reference on one batch before scaling.
- Document sharding strategy when moving from single-GPU debug to TPU pod.
- Plan an export path (ONNX, NumPy weights, or SavedModel via jax2tf) before production deadlines.
- Read the official "JAX - The Sharp Bits" documentation page before shipping; it collects the most common footguns in one place.
Key takeaways
- JAX is NumPy on accelerators plus composable transforms — not a batteries-included DL framework.
- grad, jit, vmap, and pmap are the four primitives to master before Flax.
- Functional purity (explicit RNG, no mutation) enables whole-program compilation and transform composition that are harder to achieve in imperative frameworks.
- Pair JAX with Flax + Optax for neural network training; use Orbax for checkpoints.
- TPU training is JAX's strongest advantage; production serving usually requires export to another runtime.
Related reading
- PyTorch fundamentals explained — imperative alternative with the largest ecosystem
- TensorFlow fundamentals explained — Keras API and TFLite for mobile
- Deep learning explained — neural network concepts underlying all three frameworks
- Gradient descent explained — how optimizers use gradients that JAX computes