
JAX fundamentals explained

Your research notebook runs fine on a laptop CPU, but scaling the same NumPy code to eight TPU cores requires a rewrite in another framework — unless you start in JAX. Google Research built JAX as a minimal layer over XLA: write plain array code, then compose functional transformations for gradients, compilation, batching, and multi-device parallelism. It powers AlphaFold, PaLM pretraining experiments, and a growing slice of the Hugging Face ecosystem via Flax. This guide covers JAX arrays and devices, the four core transforms (grad, jit, vmap, pmap), a Flax + Optax training loop, a Harbor Fleet demand forecaster worked example, a framework decision table, common pitfalls, and a practitioner checklist alongside our PyTorch fundamentals guide, TensorFlow fundamentals guide, and gradient descent guide.

What JAX is (and is not)

JAX is not a full deep-learning framework like PyTorch or TensorFlow. It is a numerical computing library with a NumPy-compatible API (jax.numpy) plus composable function transformations. You typically pair JAX with Flax (neural network modules), Optax (optimizers), and Orbax (checkpointing) for end-to-end training — similar to how most PyTorch users depend on Hugging Face Transformers rather than raw torch.nn alone.

Three properties define JAX's niche:

  • Pure functions: transforms assume no hidden state or side effects; randomness flows through explicit PRNG keys.
  • XLA compilation: jit traces a Python function into an XLA program and compiles it into fused accelerator kernels, which for many workloads approach the speed of hand-tuned CUDA.
  • Composable transforms: grad(jit(vmap(f))) is valid; order matters and is part of the design.

JAX arrays are immutable. In-place mutation such as x[0] = 1 raises a TypeError; the functional equivalent, x.at[0].set(1), returns a new array and leaves x unchanged. This is a culture shock for PyTorch users accustomed to mutating tensors in place.
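A minimal sketch of the difference, using only jax.numpy (the variable names are illustrative):

```python
import jax.numpy as jnp

x = jnp.zeros(3)

# Direct item assignment is rejected because JAX arrays are immutable.
try:
    x[0] = 1.0
    assignment_failed = False
except TypeError:
    assignment_failed = True

# Functional updates return a NEW array; `x` itself is never modified.
y = x.at[0].set(1.0)
z = y.at[1].add(5.0)  # other update methods include .mul, .min, .max
```

Because updates are out-of-place, XLA is free to reuse buffers under jit, so the extra copy usually disappears in compiled code.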

Arrays, dtypes, and devices

Import NumPy operations from jax.numpy as jnp. Arrays are placed on the default device unless you say otherwise:

x = jnp.ones((4, 4), dtype=jnp.float32) lands on the default backend (TPU if available, then GPU, else CPU). Use jax.device_put(x, jax.devices('gpu')[0]) to move data explicitly. jax.devices() lists all visible devices across processes, jax.local_devices() lists the ones attached to the current host, and jax.process_count() reports the number of host processes in a multi-host TPU pod.
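The device calls above can be exercised in a short sketch. It places data on local_devices[0] rather than a GPU so that it also runs on a CPU-only machine; that substitution is an assumption for portability.

```python
import jax
import jax.numpy as jnp

# All devices across every process, and the subset attached to this host.
all_devices = jax.devices()
local_devices = jax.local_devices()
print(jax.default_backend(), all_devices, jax.process_count())

# Created on the default device because no placement was requested.
x = jnp.ones((4, 4), dtype=jnp.float32)

# Explicit placement. On a GPU host you could pass jax.devices("gpu")[0] instead.
x_on_device = jax.device_put(x, local_devices[0])
print(x_on_device.devices())  # the set of devices holding this array
```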

Random numbers

JAX has no global RNG state. Random functions take an explicit key and return only samples; jax.random.split derives fresh keys from an existing one:

key, subkey = jax.random.split(key); noise = jax.random.normal(subkey, shape)

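Inside jit-compiled code, pass keys in as function arguments: a key captured from the enclosing Python scope is baked into the compiled program as a constant. jax.random.fold_in(key, step) derives a distinct key from a step counter. Forgetting this produces identical "random" batches every epoch, a common first-week bug.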
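A minimal sketch of both patterns, key splitting per epoch and fold_in per step. The function noisy_batch is a hypothetical stand-in for any stochastic step such as dropout or augmentation.

```python
import jax
import jax.numpy as jnp

@jax.jit
def noisy_batch(key, x):
    # The key arrives as an argument, so every call can get fresh randomness.
    return x + jax.random.normal(key, x.shape)

key = jax.random.PRNGKey(0)
x = jnp.zeros(3)

batches = []
for epoch in range(3):
    key, subkey = jax.random.split(key)  # consume one subkey, carry the other
    batches.append(noisy_batch(subkey, x))

# Alternative: derive per-step keys from one base key and a step counter.
base = jax.random.PRNGKey(42)
step_keys = [jax.random.fold_in(base, step) for step in range(3)]

# Reusing a key reproduces exactly the same "random" numbers.
same_a = noisy_batch(subkey, x)
same_b = noisy_batch(subkey, x)
```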

Dtypes and precision

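JAX defaults to float32 on every backend; 64-bit types require jax.config.update('jax_enable_x64', True). For reduced precision, cast explicitly to jnp.bfloat16, or set jax.config.update('jax_default_matmul_precision', 'bfloat16') to let XLA compute float32 matrix multiplications at lower internal precision (this changes compute precision, not the dtype of stored arrays). TPU hardware favors bfloat16. On NVIDIA GPUs, float16 training needs loss scaling, which Optax does not provide out of the box; third-party libraries such as DeepMind's jmp implement dynamic loss scaling, and Ampere-class and newer GPUs also support bfloat16, which usually makes loss scaling unnecessary.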
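A small sketch of the default dtype behaviour and explicit bfloat16 casts. It assumes 64-bit mode is left off (the default) and uses the jax.default_matmul_precision context manager to scope a precision setting to one block instead of changing it globally.

```python
import jax
import jax.numpy as jnp
import numpy as np

# With 64-bit mode off, float arrays default to float32...
a = jnp.ones(3)
# ...and a float64 NumPy array is downcast to float32 on conversion.
b = jnp.asarray(np.ones(3, dtype=np.float64))

# Explicit bfloat16 cast: storage and matmul both run in bfloat16.
w = jnp.ones((4, 4), dtype=jnp.float32)
w_bf16 = w.astype(jnp.bfloat16)
out = w_bf16 @ w_bf16

# Scoped control over the internal precision of float32 matmuls.
with jax.default_matmul_precision("float32"):
    exact = w @ w
```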

The four core transformations

grad — automatic differentiation

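grad(f) returns a function that computes the gradient of a scalar-valued f with respect to its first argument (pass argnums to choose others). For vector outputs, use jax.vjp (reverse-mode) or jax.jvp (forward-mode), or jax.jacrev and jax.jacfwd for full Jacobians. Higher-order derivatives compose naturally: grad(grad(f)) gives the second derivative of a scalar function, and the Hessian-vector products used in meta-learning come from composing jvp with grad (forward-over-reverse).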
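The compositions described above, sketched on toy functions (f, g, and the helper hvp are illustrative names, not library APIs):

```python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 3

df = jax.grad(f)             # 3x^2
d2f = jax.grad(jax.grad(f))  # 6x

def g(x):
    # Scalar function of a vector: sum of cubes, so the Hessian is diag(6x).
    return jnp.sum(x ** 3)

def hvp(fun, x, v):
    # Forward-over-reverse: push direction v through grad(fun) with jvp.
    return jax.jvp(jax.grad(fun), (x,), (v,))[1]

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, -1.0])
hv = hvp(g, x, v)  # diag(6x) @ v = [6, 0, -18]
```

Forward-over-reverse never materializes the full Hessian, which is what makes it practical for large parameter vectors.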

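Unlike PyTorch's dynamic autograd graph, JAX differentiates by tracing the function and transforming the trace. Outside jit, grad handles ordinary Python control flow on concrete values. Under jit, array arguments become abstract tracers with no concrete value, so a plain Python if on them raises a ConcretizationTypeError. Express data-dependent control flow with jax.lax.cond, jax.lax.scan, or jax.lax.fori_loop (or jnp.where), which compile every branch or loop body into the program.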
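A minimal sketch contrasting jax.lax.cond with a Python if under jit. The function names are hypothetical; the error class is caught through jax.errors.ConcretizationTypeError, which newer releases subclass as TracerBoolConversionError.

```python
import jax
import jax.numpy as jnp

@jax.jit
def piecewise(x):
    # Both branches are traced and compiled; the runtime value of x > 0
    # selects which one executes.
    return jax.lax.cond(x > 0, lambda v: v ** 2, lambda v: -v, x)

d_piecewise = jax.grad(piecewise)  # differentiates through lax.cond

@jax.jit
def broken(x):
    if x > 0:  # Python if on a tracer: fails while tracing
        return x ** 2
    return -x

try:
    broken(3.0)
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    python_if_failed = True
```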

jit — XLA compilation

@jax.jit compiles a function to an XLA executable on its first call (the "warmup" step). Subsequent calls with the same input shapes and dtypes reuse the cached executable. Any shape or dtype change triggers retracing and recompilation, so batch-size sweeps during hyperparameter search can be slow unless you pad to fixed shapes.

Use jax.block_until_ready(out) when benchmarking: JAX dispatches work asynchronously, so a naive time.time() loop mostly measures how quickly Python enqueues kernels, not how long they take to run.
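A sketch that makes caching and recompilation visible and times a jitted call correctly. The trace_count global is hypothetical instrumentation: a Python side effect inside a jitted function runs only while tracing, so it counts compilations rather than calls.

```python
import time
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def layer(w, x):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    return jnp.tanh(x @ w)

w = jnp.ones((256, 256)) * 0.01
x32 = jnp.ones((32, 256))
x64 = jnp.ones((64, 256))

layer(w, x32).block_until_ready()  # first call: trace + compile
layer(w, x32).block_until_ready()  # same shapes: cached executable
layer(w, x64).block_until_ready()  # new batch size: retrace + recompile

def benchmark(fn, *args, iters=10):
    fn(*args).block_until_ready()  # warmup call, excluded from timing
    start = time.perf_counter()
    for _ in range(iters):
        out = fn(*args)
    jax.block_until_ready(out)  # wait for the queued computations to finish
    return (time.perf_counter() - start) / iters
```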

vmap — automatic batching

vmap(f) maps f over a leading batch axis without writing explicit loops. A single-example forward pass becomes a batched forward pass — the same function serves inference and training. Combine with grad for per-example gradients (useful in differential privacy and some meta-learning setups).
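A minimal sketch of per-example gradients for a hypothetical linear model: grad is taken for a single example, then vmap maps it over the batch axis of the data while in_axes=None keeps the weights shared.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error for ONE example: x has shape (features,), y is a scalar.
    return (jnp.dot(w, x) - y) ** 2

# Gradient w.r.t. w for one example, mapped over axis 0 of xs and ys only.
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))

w = jnp.array([0.5, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 0.0], [0.0, 1.0]])
ys = jnp.array([1.0, 0.0, 2.0])

g = per_example_grads(w, xs, ys)  # shape (3, 2): one gradient row per example
```

Averaging the rows recovers the ordinary batch gradient; keeping them separate is what per-example clipping in differential privacy needs.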

pmap — data parallelism

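pmap(f) replicates f across local devices, shards inputs along the leading axis, and runs SPMD (single program, multiple data). It does not synchronize gradients on its own: give the mapped axis a name (axis_name='batch') and call jax.lax.pmean(grads, axis_name='batch') inside the mapped function to all-reduce them. For multi-host TPU slices, initialize each process (for example with jax.distributed.initialize(); helpers live in jax.experimental.multihost_utils), or move to jax.sharding (GSPMD) for finer control in JAX 0.4+. Recent JAX documentation recommends jax.jit with shardings over pmap for new code.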
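A data-parallel gradient step with an explicit all-reduce, sketched for a hypothetical linear regression. It runs on however many local devices exist (a single CPU device works), broadcasting the weights with in_axes=None and sharding the data along the leading axis.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

def grad_step(w, x, y):
    g = jax.grad(loss)(w, x, y)
    # pmap does not average gradients for you; pmean all-reduces across devices.
    return jax.lax.pmean(g, axis_name="devices")

# w is broadcast to every device; x and y are split along axis 0.
p_grad = jax.pmap(grad_step, axis_name="devices", in_axes=(None, 0, 0))

per_device = 4
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (n_dev, per_device, 3))
y = jax.random.normal(key_y, (n_dev, per_device))
w = jnp.zeros(3)

g = p_grad(w, x, y)  # shape (n_dev, 3); every row holds the same averaged gradient
```

With equal-sized shards, the mean of per-device mean gradients equals the full-batch gradient, which is what the pmean buys you.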

Flax modules and an Optax training loop

Flax (linen) models subclass nn.Module and either declare submodules such as nn.Dense and nn.LayerNorm in setup() or create them inline inside a method decorated with @nn.compact. Initialization is explicit:

params = model.init(rng, dummy_input)['params']

A minimal training step in pure JAX style:

  1. Define loss_fn(params, batch) returning a scalar.
  2. Wrap with grad_loss = jax.grad(loss_fn).
  3. Create an Optax optimizer: tx = optax.adamw(learning_rate); opt_state = tx.init(params).
  4. Inside @jax.jit, compute grads = grad_loss(params, batch), then updates, opt_state = tx.update(grads, opt_state, params), then params = optax.apply_updates(params, updates).

Flax's TrainState dataclass bundles params, optimizer state, and step counter — the idiomatic pattern in official examples. Checkpoint with Orbax: orbax.checkpoint.PyTreeCheckpointer saves the full pytree (nested dict of arrays) portably.

For transformers, use flax.nnx (Flax NNX, 2024+) for PyTorch-like mutable modules, or Hugging Face FlaxAutoModel for pretrained weights. The ecosystem is smaller than PyTorch but sufficient for research reproduction and TPU-scale pretraining.

GPU, TPU, and sharding

On a single NVIDIA GPU, JAX behaves like accelerated NumPy once functions are jit-compiled. Multi-GPU on one host uses pmap over jax.local_devices() — typically 2–8 GPUs per machine.

TPU pods are JAX's home turf. Google Colab TPU runtimes typically expose eight TPU cores through jax.devices(), and pmap across those 8 cores is the default tutorial pattern. For large models, sharding via jax.sharding.NamedSharding and jax.experimental.mesh_utils partitions parameters across devices (tensor parallelism) instead of only batching inputs (data parallelism).
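A minimal sketch of parameter sharding with a 1-D device mesh. It builds the mesh directly from np.array(jax.devices()) rather than mesh_utils to stay short, and the axis name "model" and array sizes are arbitrary choices. The weight matrix's output columns are split across devices while the input stays replicated.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("model",))  # 1-D mesh over every device
n = len(devices)

# Split the weight's output columns across the "model" axis (tensor parallelism);
# replicate the input activations on every device.
w_sharding = NamedSharding(mesh, P(None, "model"))
replicated = NamedSharding(mesh, P())

x = jnp.arange(8 * 4, dtype=jnp.float32).reshape(8, 4)
w = jnp.ones((4, 2 * n))

w_sharded = jax.device_put(w, w_sharding)
x_replicated = jax.device_put(x, replicated)

@jax.jit
def forward(w, x):
    # jit follows the input shardings; XLA (GSPMD) inserts any needed communication.
    return x @ w

y = forward(w_sharded, x_replicated)  # columns of y stay split across devices
```

Switching the PartitionSpec to shard x along its batch axis instead would give data parallelism with the same mechanism.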

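JAX does not ship a production serving stack comparable to TorchServe or TF Serving. Common paths are converting models to a TensorFlow SavedModel with jax2tf and serving them with TF Serving, exporting to ONNX through third-party converters, or deploying Python workers behind batching proxies. Research and training are the sweet spot; inference at scale often converts checkpoints to another runtime.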

Worked example: Harbor Fleet demand forecaster

Harbor Fleet operates 120 delivery vans across three cities. Operations wants a 7-day ahead demand forecast per depot to right-size staffing. A data scientist builds a small temporal model in JAX + Flax:

  1. Data: daily order counts per depot, 18 months history; features include day-of-week, holidays, and lag-7 values. Stored as a jnp.ndarray of shape (depots, days, features).
  2. Model: a 2-layer GRU implemented in Flax (nn.GRUCell stepped over days with nn.scan, Flax's lifted version of jax.lax.scan, or the nn.RNN wrapper, so the loop compiles under jit), outputting scalar demand per depot-day.
  3. Loss: Huber loss on log-transformed counts — robust to spike days (Black Friday) without exploding gradients.
  4. Training: batch size 64 depots via vmap over the depot axis inside a jit-compiled step; AdamW lr=1e-3, 200 epochs, PRNG key split per epoch for dropout masks.
  5. Evaluation: MAPE on a held-out last 30 days; 8.2% average vs 11.4% for their legacy Prophet baseline.
  6. Deployment: nightly cron on a TPU v4-8 slice retrains in 4 minutes; exported NumPy weights feed a lightweight CPU inference script for the ops dashboard.
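The modelling steps above can be approximated in pure JAX. This is a hypothetical, simplified sketch rather than the team's code: one hand-rolled GRU layer instead of two Flax nn.GRUCell layers, a hand-written Huber loss, and random stand-in data. It shows the same structure: jax.lax.scan over days, vmap over depots, and Huber loss on log1p-transformed counts.

```python
import jax
import jax.numpy as jnp

def init_params(key, n_features, hidden):
    k_w, k_u, k_out = jax.random.split(key, 3)
    return {
        "W": 0.1 * jax.random.normal(k_w, (n_features, 3 * hidden)),  # input weights
        "U": 0.1 * jax.random.normal(k_u, (hidden, 3 * hidden)),      # recurrent weights
        "b": jnp.zeros(3 * hidden),
        "w_out": 0.1 * jax.random.normal(k_out, (hidden,)),
        "b_out": jnp.zeros(()),
    }

def gru_step(params, h, x_t):
    # Standard GRU gates: reset r, update z, candidate n.
    hid = h.shape[-1]
    gx = x_t @ params["W"] + params["b"]
    gh = h @ params["U"]
    r = jax.nn.sigmoid(gx[:hid] + gh[:hid])
    z = jax.nn.sigmoid(gx[hid:2 * hid] + gh[hid:2 * hid])
    n = jnp.tanh(gx[2 * hid:] + r * gh[2 * hid:])
    h_new = (1.0 - z) * n + z * h
    return h_new, h_new @ params["w_out"] + params["b_out"]  # (carry, output)

def forecast_one_depot(params, series):
    # series: (days, features) -> one prediction (log1p scale) per day.
    h0 = jnp.zeros(params["U"].shape[0])
    _, preds = jax.lax.scan(lambda h, x: gru_step(params, h, x), h0, series)
    return preds

def huber(residual, delta=1.0):
    a = jnp.abs(residual)
    return jnp.where(a <= delta, 0.5 * residual ** 2, delta * (a - 0.5 * delta))

def loss_fn(params, features, counts):
    # vmap over the depot axis; params are shared across depots.
    preds = jax.vmap(forecast_one_depot, in_axes=(None, 0))(params, features)
    return jnp.mean(huber(preds - jnp.log1p(counts)))

value_and_grad_step = jax.jit(jax.value_and_grad(loss_fn))

k_feat, k_cnt, k_init = jax.random.split(jax.random.PRNGKey(0), 3)
features = jax.random.normal(k_feat, (4, 30, 5))  # (depots, days, features)
counts = jnp.abs(50.0 + 10.0 * jax.random.normal(k_cnt, (4, 30)))
params = init_params(k_init, n_features=5, hidden=16)

preds = jax.vmap(forecast_one_depot, in_axes=(None, 0))(params, features)
loss, grads = value_and_grad_step(params, features, counts)
```

An optimizer step (for example AdamW via Optax) would consume grads exactly as in the training loop earlier in this guide.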

The team chose JAX because their Google Cloud contract includes TPU quota and the same code runs on a local GPU for debugging. The functional style forced explicit RNG handling — annoying at first, but it eliminated the "why did validation loss jump on rerun?" bugs they had with implicit global seeds.

Framework decision table

Need | JAX | PyTorch | TensorFlow
Research flexibility, imperative style | Functional; steeper learning curve | Best fit; eager by default | Keras 3 eager mode
TPU training at scale | First-class pmap / sharding | Supported via PyTorch/XLA | Native TPU integration
LLM / transformer ecosystem | Flax + Hugging Face Flax models | Hugging Face, vLLM; largest hub | Smaller hub
Custom gradients / meta-learning | Composable grad, vjp, custom_vjp | torch.autograd.Function | tf.custom_gradient
Production serving | Export (jax2tf, ONNX) or wrap in Python | TorchServe, TensorRT, Triton | TF Serving, TFLite
Mobile / edge | Limited native tooling | ExecuTorch, ONNX | TFLite mature on Android

Choose JAX when you need TPU-scale training, composable transforms for research (differentiable physics, meta-learning, Bayesian methods), or your team already writes NumPy and wants accelerators without learning a new tensor API. Choose PyTorch for the broadest ecosystem and production path.

Common pitfalls

  • Python control flow on traced values: inside jit, a plain if x > 0 on an array raises a ConcretizationTypeError; use jax.lax.cond, jax.lax.scan, or jnp.where instead.
  • Reusing PRNG keys: produces identical noise every call; always split keys.
  • Measuring time without block_until_ready: async dispatch makes naive timers lie.
  • Shape-changing recompilation: variable batch sizes in production cause constant re-jit; pad to fixed shapes.
  • In-place mutation: x[0] = 1 raises an error; x.at[0].set(1) returns a new array and leaves the old reference unchanged.
  • Unintended float64: with jax_enable_x64 on, float64 NumPy inputs promote computations to float64, which is slow on GPUs and TPUs; with it off (the default), they are silently downcast to float32. Cast explicitly either way.
  • Debugging inside jit: print shows tracers, not values; use jax.debug.print, or run the function under jax.disable_jit() while debugging.
  • Host-device sync in tight loops: calling np.array(jax_array) every step forces a device-to-host copy and a sync, which kills throughput.
  • Assuming pmap means multi-node: pmap spans only one host's local devices unless you set up multi-host initialization (for example jax.distributed.initialize()).
  • Skipping warmup: the first jit call compiles; benchmark the second call onward.
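A small sketch of the debugging pitfall: a plain print inside jit fires once at trace time and shows a tracer, jax.debug.print reports runtime values, and jax.disable_jit() runs the same function eagerly. The normalize function is a hypothetical example.

```python
import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    print("trace-time print:", x)  # prints a tracer, once per compilation
    jax.debug.print("runtime mean: {m}", m=x.mean())  # prints the actual value
    return (x - x.mean()) / (x.std() + 1e-6)

out = normalize(jnp.array([1.0, 2.0, 3.0]))

# Temporarily run eagerly: the plain print now shows concrete values.
with jax.disable_jit():
    eager_out = normalize(jnp.array([1.0, 2.0, 3.0]))
```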

Practitioner checklist

  • Pin jax, jaxlib, and CUDA/TPU driver versions — wheels are hardware-specific.
  • Structure code as pure functions taking params and batch; keep side effects outside jit.
  • Thread PRNG keys explicitly through every stochastic operation.
  • Run a CPU-only smoke test in CI with JAX_PLATFORMS=cpu.
  • Profile with jax.profiler or TensorBoard after warmup, not on first call.
  • Use Orbax for reproducible checkpoint saves; Flax now recommends it over the older flax.training.checkpoints module.
  • Validate exported predictions against a non-jit reference on one batch before scaling.
  • Document sharding strategy when moving from single-GPU debug to TPU pod.
  • Plan an export path (ONNX, NumPy weights, or SavedModel via jax2tf) before production deadlines.
  • Read the official JAX documentation page "JAX - The Sharp Bits" before shipping; it collects the most common footguns in one place.

Key takeaways

  • JAX is NumPy on accelerators plus composable transforms — not a batteries-included DL framework.
  • grad, jit, vmap, and pmap are the four primitives to master before Flax.
  • Functional purity (explicit RNG, no mutation) is what lets JAX trace whole functions, compile them with XLA, and compose transforms safely.
  • Pair JAX with Flax + Optax for neural network training; use Orbax for checkpoints.
  • TPU training is JAX's strongest advantage; production serving usually requires export to another runtime.

Related reading