TensorFlow fundamentals explained
Warehouse pickers need a defect classifier that runs on Android handhelds with
no network latency. Training in Python is one problem; shipping a 4 MB model that
scores a photo in 30 ms on a Snapdragon chip is another.
TensorFlow is Google's open-source machine learning stack, built for exactly that
path: from research notebook to TensorFlow Lite on edge devices, with first-class
TPU support and the mature Keras high-level API that most production teams touch
daily. This guide covers TensorFlow 2's eager execution model, building models with
Keras layers, feeding data through tf.data, the compile / fit training loop, GPU and
mixed-precision basics, exporting SavedModel and TFLite artifacts, a Harbor Supply
shelf-audit classifier worked example, a framework comparison table, common pitfalls,
and a practitioner checklist. It complements our PyTorch fundamentals guide, deep
learning primer, and Python fundamentals explainer.
What TensorFlow 2 provides
TensorFlow is a numerical computation library with automatic differentiation, optimized CPU/GPU/TPU kernels, and a production toolchain (TFX, TF Serving, TFLite). Since TensorFlow 2.0, eager execution is the default: operations run immediately like normal Python, which makes debugging with print statements and breakpoints straightforward. Keras — now the official high-level API — sits on top for layer composition, training loops, and callbacks.
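A minimal sketch of eager execution, assuming TensorFlow 2.x is installed: each op runs as soon as Python reaches it, so results can be printed or inspected in a debugger with no graph-building or session step.

```python
import tensorflow as tf

# Eager execution is the TF2 default: ops run immediately, much like NumPy.
print(tf.executing_eagerly())  # True outside of tf.function

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
y = tf.matmul(x, x)            # computed right away, no session or graph
print(y.numpy())               # values [[7, 10], [15, 22]]
```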
The stack splits into layers you will encounter in order of abstraction:
Core modules you will touch daily
- tensorflow: tensors, math ops, GradientTape for custom training
- keras: Model, Sequential, layers, losses, optimizers, metrics
- tf.data: Dataset pipelines with map, batch, prefetch, cache
- tf.function: trace Python into optimized graphs for speed
- tf.lite: convert and quantize models for mobile and embedded
Unlike the graph-first TensorFlow 1.x era, you rarely build static graphs by
hand today. When you need speed, wrap hot paths in @tf.function and
TensorFlow traces them into optimized graphs; XLA compilation is opt-in via jit_compile=True.
Tensors: shape, dtype, and placement
A tensor is an n-dimensional array with a fixed
shape and dtype. A batch of 64 grayscale
defect images at 128×128 is (64, 128, 128, 1) in NHWC
order (batch, height, width, channels) — TensorFlow's default channel ordering,
opposite of PyTorch's NCHW convention. Always verify layout when porting
weights between frameworks.
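To make the layout difference concrete, the sketch below builds the (64, 128, 128, 1) batch described above and permutes it to NCHW with tf.transpose. The kernel permutation at the end is an illustrative assumption about what a weight port needs, based on Keras storing Conv2D kernels as (height, width, in_channels, out_channels) and PyTorch storing them as (out, in, height, width).

```python
import tensorflow as tf

# A batch of 64 grayscale 128x128 images in TensorFlow's default NHWC layout.
batch_nhwc = tf.zeros((64, 128, 128, 1), dtype=tf.float32)
print(batch_nhwc.shape, batch_nhwc.dtype)    # (64, 128, 128, 1) <dtype: 'float32'>

# Reorder to NCHW (PyTorch's convention) before handing arrays across frameworks.
batch_nchw = tf.transpose(batch_nhwc, perm=[0, 3, 1, 2])
print(batch_nchw.shape)                       # (64, 1, 128, 128)

# Conv kernels differ too: Keras (kh, kw, in, out) vs PyTorch (out, in, kh, kw).
keras_kernel = tf.random.normal((3, 3, 1, 32))
torch_style_kernel = tf.transpose(keras_kernel, perm=[3, 2, 0, 1])
print(torch_style_kernel.shape)               # (32, 1, 3, 3)
```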
Create tensors with tf.constant, tf.zeros, or convert
from NumPy via tf.convert_to_tensor. Place computation on GPU with
tf.device('/GPU:0') or let TensorFlow auto-place. Mixed precision
(tf.keras.mixed_precision) runs matmuls and convolutions in float16 while keeping
variables and numerically sensitive ops in float32. On NVIDIA GPUs with Tensor Cores
this often yields a 1.5–2× speedup with minimal accuracy loss.
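The creation, placement, and mixed-precision calls from the paragraph above look roughly like this. It is a sketch: the CPU fallback is an assumption so it also runs on machines without a GPU, and the policy object is only inspected here rather than set globally.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Three ways to create tensors.
a = tf.constant([1.0, 2.0, 3.0])
b = tf.zeros((2, 3), dtype=tf.float32)
c = tf.convert_to_tensor(np.arange(6, dtype=np.float32).reshape(2, 3))

# Pin work to a device explicitly; fall back to CPU when no GPU is visible.
device = '/GPU:0' if tf.config.list_physical_devices('GPU') else '/CPU:0'
with tf.device(device):
    d = b + c            # runs on the chosen device
print(d.device)          # full device string ending in CPU:0 or GPU:0

# Mixed-precision policy: compute in float16, keep variables in float32.
policy = keras.mixed_precision.Policy('mixed_float16')
print(policy.compute_dtype, policy.variable_dtype)   # float16 float32
# keras.mixed_precision.set_global_policy(policy) would apply it to layers built afterwards.
```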
Automatic differentiation with GradientTape
Keras handles gradients inside model.fit, but custom training loops
use tf.GradientTape: record forward ops inside a with
block, then call tape.gradient(loss, model.trainable_variables).
This is TensorFlow's equivalent of PyTorch autograd — explicit, flexible, and
necessary when you need GANs, reinforcement learning, or non-standard loss
combinations that Keras cannot express in one line.
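A minimal GradientTape sketch: first a scalar whose derivative can be checked by hand, then the same pattern against a small Keras model's trainable_variables, which yields one gradient per variable. The toy model is an illustrative assumption.

```python
import tensorflow as tf
from tensorflow import keras

# Scalar case: d(w^2 + 2w)/dw = 2w + 2, which is 8 at w = 3.
w = tf.Variable(3.0)
with tf.GradientTape() as tape:
    loss = w * w + 2.0 * w         # forward ops are recorded on the tape
grad = tape.gradient(loss, w)
print(grad.numpy())                 # 8.0

# Model case: one gradient per trainable variable (Dense kernel and bias).
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(2)])
x = tf.ones((5, 4))
with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x)))
grads = tape.gradient(loss, model.trainable_variables)
print([tuple(g.shape) for g in grads])   # [(4, 2), (2,)]
```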
Building models with Keras
The two entry points are Sequential (linear stack) and the Functional API (DAG with multiple inputs/outputs). For a simple CNN classifier:
```python
from tensorflow import keras

# 3-class CNN for 128x128 grayscale images (NHWC input)
model = keras.Sequential([
    keras.layers.Input(shape=(128, 128, 1)),
    keras.layers.Conv2D(32, 3, activation='relu'),
    keras.layers.MaxPooling2D(),
    keras.layers.Conv2D(64, 3, activation='relu'),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(3, activation='softmax')   # one probability per class
])
```
Each layer is a callable that transforms tensor shape. Inspect output shapes
with model.summary() before training — a mismatch at the first
Dense layer is the most common startup failure. For transfer learning, swap the
head of a pretrained keras.applications model (MobileNetV3,
EfficientNet) and freeze base layers with layer.trainable = False
during initial fine-tuning.
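A transfer-learning sketch with keras.applications.MobileNetV3Small. To stay runnable offline it uses weights=None, an assumption you would replace with weights='imagenet' in practice; ImageNet weights expect three input channels, so grayscale images would need to be tiled to three channels first.

```python
import tensorflow as tf
from tensorflow import keras

# weights=None keeps the sketch offline; use weights='imagenet' for real fine-tuning.
base = keras.applications.MobileNetV3Small(
    input_shape=(128, 128, 3), include_top=False, weights=None, pooling='avg')
base.trainable = False    # freeze the backbone for the first fine-tuning phase

inputs = keras.Input(shape=(128, 128, 3))
features = base(inputs, training=False)    # keep BatchNorm layers in inference mode
outputs = keras.layers.Dense(3, activation='softmax')(features)   # new 3-class head
model = keras.Model(inputs, outputs)

model.summary()   # only the Dense head should report trainable parameters
```

Once the head converges, a common second phase unfreezes some or all of the base and continues at a lower learning rate.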
Subclasses and custom layers
Subclassing keras.Model gives full control over call()
— useful for research prototypes. Production code usually prefers Functional API
or Sequential because SavedModel export and TFLite conversion work more reliably
on declarative graphs than on Python subclass methods with dynamic control flow.
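For comparison, a small subclassed model with a hypothetical class name, ShelfClassifier. The training flag inside call() is ordinary Python, which is the kind of logic exporters must trace rather than read from a declarative layer graph.

```python
import tensorflow as tf
from tensorflow import keras

class ShelfClassifier(keras.Model):
    """Hypothetical subclassed model: layers built in __init__, forward pass in call()."""

    def __init__(self, num_classes=3, **kwargs):
        super().__init__(**kwargs)
        self.conv = keras.layers.Conv2D(32, 3, activation='relu')
        self.pool = keras.layers.GlobalAveragePooling2D()
        self.dropout = keras.layers.Dropout(0.2)
        self.head = keras.layers.Dense(num_classes, activation='softmax')

    def call(self, x, training=False):
        x = self.pool(self.conv(x))
        x = self.dropout(x, training=training)   # Python-level behavior switch
        return self.head(x)

model = ShelfClassifier()
probs = model(tf.random.uniform((2, 128, 128, 1)))
print(probs.shape)   # (2, 3)
```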
tf.data: efficient input pipelines
GPUs starve when the CPU cannot load and preprocess fast enough.
tf.data.Dataset solves this with a composable pipeline:
- Source: from_tensor_slices, list_files, or TFRecordDataset
- Map: decode JPEG, resize, normalize (dataset.map(fn, num_parallel_calls=AUTOTUNE))
- Batch: group examples; use drop_remainder=True for TPU
- Prefetch: overlap CPU preprocessing with GPU training (prefetch(AUTOTUNE))
Cache small datasets in RAM with .cache() after the expensive decode
step. For large image corpora, write TFRecords once and read
shards in parallel — random access from thousands of loose JPEG files on disk
becomes the bottleneck long before the model does.
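The pipeline stages and the TFRecord advice above can be sketched end to end. This is a sketch under stated assumptions: the feature keys 'image' and 'label', the shard file names, and the random 128×128 grayscale images are hypothetical stand-ins for a real labeled corpus.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE
FEATURES = {   # hypothetical schema: JPEG bytes plus an integer class label
    'image': tf.io.FixedLenFeature([], tf.string),
    'label': tf.io.FixedLenFeature([], tf.int64),
}

def write_shard(path, images, labels):
    """Write (uint8 HxWx1 image, int label) pairs as JPEG bytes in tf.train.Example records."""
    with tf.io.TFRecordWriter(path) as writer:
        for image, label in zip(images, labels):
            jpeg = tf.io.encode_jpeg(image).numpy()
            example = tf.train.Example(features=tf.train.Features(feature={
                'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[jpeg])),
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)])),
            }))
            writer.write(example.SerializeToString())

def parse_and_preprocess(record):
    parsed = tf.io.parse_single_example(record, FEATURES)
    image = tf.io.decode_jpeg(parsed['image'], channels=1)   # decode
    image = tf.image.resize(image, (128, 128)) / 255.0        # resize, normalize to [0, 1]
    return image, parsed['label']

# Write once: two shards of eight synthetic images each.
tmp = tempfile.mkdtemp()
rng = np.random.default_rng(0)
for shard in range(2):
    imgs = rng.integers(0, 256, size=(8, 128, 128, 1), dtype=np.uint8)
    write_shard(os.path.join(tmp, f'train-{shard}.tfrecord'), imgs, rng.integers(0, 3, size=8))

# Read many times: shards in parallel, then map, cache, shuffle, batch, prefetch.
files = tf.data.Dataset.list_files(os.path.join(tmp, 'train-*.tfrecord'), shuffle=True)
train_ds = (
    files.interleave(tf.data.TFRecordDataset, num_parallel_calls=AUTOTUNE)
    .map(parse_and_preprocess, num_parallel_calls=AUTOTUNE)
    .cache()                       # cache after the expensive decode step
    .shuffle(64)
    .batch(4, drop_remainder=True)
    .prefetch(AUTOTUNE)            # overlap input work with training
)

images, labels = next(iter(train_ds))
print(images.shape, labels.shape)  # (4, 128, 128, 1) (4,)
```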
Pass the finished Dataset directly to model.fit(train_ds, validation_data=val_ds).
Keras iterates a finite dataset once per epoch; if the dataset repeats indefinitely
(for example after .repeat()), set steps_per_epoch so each epoch knows when to end.
The compile / fit training loop
Keras collapses the manual seven-step PyTorch loop into two calls:
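```python
from tensorflow import keras

# Assumes `model` from the Keras section and batched `train_ds` / `val_ds` from the
# tf.data section. Callback settings below are illustrative values.
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
checkpoint = keras.callbacks.ModelCheckpoint('best.keras', monitor='val_loss', save_best_only=True)
tensorboard = keras.callbacks.TensorBoard(log_dir='logs')

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss='sparse_categorical_crossentropy',   # integer labels, not one-hot
    metrics=['accuracy']
)
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=20,
    callbacks=[early_stop, checkpoint, tensorboard]
)
```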
compile wires the optimizer, loss function, and metrics.
fit runs the epochs, calling the model's train_step on every batch to compute the
loss, update the weights, and log metrics to history.history. Use
callbacks for production hygiene:
- EarlyStopping: halt when validation loss plateaus
- ModelCheckpoint: save the best weights by a monitored metric
- ReduceLROnPlateau: decay the learning rate on stagnation
- TensorBoard: scalar and histogram logging
For fine-grained control (GANs, gradient accumulation), write a custom loop with
GradientTape and optimizer.apply_gradients — the
pattern mirrors PyTorch's explicit loop but stays in TensorFlow ops for XLA
fusion benefits.
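A minimal custom training step with GradientTape and optimizer.apply_gradients, wrapped in @tf.function. The toy data (the label is the index of the largest of the first three features), layer sizes, and learning rate are illustrative assumptions.

```python
import tensorflow as tf
from tensorflow import keras

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(16, activation='relu'),
    keras.layers.Dense(3),                 # raw logits
])
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam(learning_rate=1e-2)

@tf.function   # traced on the first call, reused for same-shaped batches
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = loss_fn(y, logits)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

# Toy learnable task: label = argmax of the first three features.
x = tf.random.stateless_normal((256, 4), seed=(1, 2))
y = tf.argmax(x[:, :3], axis=1)

first = float(train_step(x, y))
for _ in range(100):
    last = float(train_step(x, y))
print(f'loss {first:.3f} -> {last:.3f}')   # should drop noticeably
```

Gradient accumulation or multiple optimizers (as in GANs) slot into the same structure by changing what happens between tape.gradient and apply_gradients.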
GPU, TPU, and distributed training
TensorFlow detects GPUs automatically. List devices with
tf.config.list_physical_devices('GPU'). Enable memory growth
(set_memory_growth(gpu, True)) to avoid grabbing all VRAM at import
time — critical when multiple processes share one server.
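Device discovery and memory growth as a short sketch; on a CPU-only machine the GPU list is empty and the loop does nothing.

```python
import tensorflow as tf

gpus = tf.config.list_physical_devices('GPU')
print(f'{len(gpus)} GPU(s) visible')

# Must run before any op initializes the GPUs, i.e. near the top of the program.
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Raised when the GPUs were already initialized in this process.
        print('memory growth not applied:', err)
```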
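TPU training uses TPUStrategy. Cloud TPU jobs typically read training data from
GCS buckets, and TPUs need static shapes (hence drop_remainder=True) and run most
efficiently with large batches; sizes divisible by 128 are the usual guideline rather
than a hard requirement. This is TensorFlow's clearest advantage over PyTorch for
teams on Google Cloud. Multi-GPU on a single machine uses MirroredStrategy;
multi-worker uses MultiWorkerMirroredStrategy with coordinated checkpointing.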
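A MirroredStrategy sketch: model and optimizer variables are created inside strategy.scope(), and the global batch is the per-replica batch times the replica count. PER_REPLICA_BATCH = 32 is an assumed value; with no GPUs visible, the strategy falls back to a single CPU replica.

```python
import tensorflow as tf
from tensorflow import keras

strategy = tf.distribute.MirroredStrategy()   # all visible GPUs, or the CPU if none
print('replicas:', strategy.num_replicas_in_sync)

PER_REPLICA_BATCH = 32                         # hypothetical per-device batch size
global_batch = PER_REPLICA_BATCH * strategy.num_replicas_in_sync

# Variables (weights, optimizer state) must be created inside the scope.
with strategy.scope():
    model = keras.Sequential([
        keras.Input(shape=(128, 128, 1)),
        keras.layers.Conv2D(32, 3, activation='relu'),
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dense(3, activation='softmax'),
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

# Batch the dataset with global_batch; model.fit splits each batch across replicas.
```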
tf.function and XLA
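Decorate training steps with @tf.function to trace Python into a
static graph. The first call is slow (tracing); later calls with the same input
signature reuse the graph, while new shapes, dtypes, or Python argument values
trigger a retrace. Enable XLA (jit_compile=True on tf.function or in model.compile,
provided every op has an XLA kernel) for additional kernel fusion, which helps
especially on TPU and for inference serving.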
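Tracing is easiest to see by counting traces. In this sketch the list append is a Python side effect, so it runs only while TensorFlow traces; scaled_sum and dense_relu are hypothetical helpers. The XLA half assumes a TensorFlow build with XLA support, which standard pip builds include.

```python
import tensorflow as tf

traces = []   # Python side effects run only during tracing

@tf.function
def scaled_sum(x, factor):
    traces.append(1)
    return tf.reduce_sum(x) * factor

scaled_sum(tf.constant([1.0, 2.0]), tf.constant(2.0))   # first call: trace
scaled_sum(tf.constant([3.0, 4.0]), tf.constant(3.0))   # same shapes/dtypes: graph reused
scaled_sum(tf.constant([3.0, 4.0]), 2.0)                # Python float: new trace
result = scaled_sum(tf.constant([3.0, 4.0]), 3.0)       # new Python value: another trace
print(len(traces), float(result))                       # 3 21.0

@tf.function(jit_compile=True)   # ask XLA to compile and fuse the whole function
def dense_relu(x, w, b):
    return tf.nn.relu(tf.matmul(x, w) + b)

out = dense_relu(tf.ones((2, 3)), tf.ones((3, 4)), tf.zeros((4,)))
print(out.numpy())               # every entry is 3.0
```

Passing the changing value as a tensor (tf.constant(3.0)) instead of a Python float avoids the extra traces.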
SavedModel, TFLite, and serving
Training produces weights; production needs a portable artifact:
- SavedModel: full model plus serving signatures for TF Serving; export with model.export('path/') (Keras 3) or tf.saved_model.save
- TFLite: mobile and embedded; convert with TFLiteConverter.from_keras_model, optionally applying post-training quantization or quantization-aware training (QAT); int8 quantization shrinks weights roughly 4× relative to float32 and typically speeds inference
- TensorFlow Serving: gRPC/REST server that loads a SavedModel for datacenter batch or online inference
Always run numerical parity tests after conversion: feed identical inputs to the Keras model and the TFLite interpreter; max absolute diff should stay below 1e-5 for float32 or within expected quantization error for int8. TFLite op coverage is not universal — stick to supported layers (Conv2D, DepthwiseConv, FullyConnected, ReLU, Softmax) when mobile deployment is the goal from day one.
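A conversion and parity-test sketch, assuming a recent TensorFlow where Model.export exists (tf.keras 2.13+ or Keras 3). It exports a SavedModel, converts it with TFLiteConverter.from_saved_model, feeds the same random input to Keras and the TFLite interpreter, then builds a float16-quantized variant to compare file sizes. The model is an untrained copy of the small CNN from the Keras section.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow import keras

model = keras.Sequential([
    keras.Input(shape=(128, 128, 1)),
    keras.layers.Conv2D(32, 3, activation='relu'),
    keras.layers.MaxPooling2D(),
    keras.layers.Conv2D(64, 3, activation='relu'),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(3, activation='softmax'),
])

# Export a SavedModel; the same directory is what TF Serving would load.
export_dir = os.path.join(tempfile.mkdtemp(), 'shelf_model')
model.export(export_dir)

def convert(float16=False):
    converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
    if float16:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.target_spec.supported_types = [tf.float16]
    return converter.convert()   # flatbuffer bytes, e.g. written to model.tflite

def run_tflite(tflite_bytes, x):
    interpreter = tf.lite.Interpreter(model_content=tflite_bytes)
    interpreter.allocate_tensors()
    input_detail = interpreter.get_input_details()[0]
    output_detail = interpreter.get_output_details()[0]
    interpreter.set_tensor(input_detail['index'], x)
    interpreter.invoke()
    return interpreter.get_tensor(output_detail['index'])

# Parity test: identical input through Keras and the float32 TFLite model.
x = np.random.default_rng(0).random((1, 128, 128, 1), dtype=np.float32)
keras_out = model(x).numpy()
fp32_bytes = convert()
max_diff = float(np.max(np.abs(keras_out - run_tflite(fp32_bytes, x))))
print(f'float32 max abs diff: {max_diff:.2e}')   # expect roughly 1e-5 or less

# Float16 weights roughly halve the file; rerun the parity check before shipping it.
fp16_bytes = convert(float16=True)
print(len(fp32_bytes), len(fp16_bytes))
```

Converting from the exported SavedModel rather than directly from the Keras object is a design choice: it tests the same artifact the serving path uses.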
Worked example: Harbor Supply shelf-audit classifier
Harbor Supply runs 40 warehouse lanes. Pickers photograph shelf fronts; a model flags missing facings, wrong SKU placement, and damaged packaging. Requirements: run offline on Android scanners, under 50 ms inference, model under 5 MB.
- Data: 18,000 labeled photos (3 classes: OK, misplaced, damaged) stored as TFRecords with embedded JPEG bytes.
- Model: MobileNetV3Small pretrained on ImageNet, fine-tuned head with 3-class softmax; input 128×128 grayscale to match scanner camera.
- Training: tf.data pipeline with random flip/brightness augmentation; Adam lr=1e-4; 15 epochs with EarlyStopping on val_accuracy; mixed precision enabled.
- Metrics: 94.2% val accuracy; the "damaged" class is weighted 2× in the loss to raise recall on safety-critical defects.
- Export: TFLite float16 quantization → 3.8 MB bundle; 28 ms median inference on Pixel 6a via NNAPI delegate.
- Deployment: APK loads model.tflite from assets; no server round-trip during aisle walks.
The team chose TensorFlow over PyTorch specifically for TFLite maturity and NNAPI integration — PyTorch's ExecuTorch is improving but Android deployment docs and tooling still lag for this use case in mid-2026.
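A scaled-down sketch of the recipe above, not Harbor Supply's actual code: random tensors stand in for the TFRecords, a tiny CNN stands in for MobileNetV3Small, mixed precision is left off, and the label encoding is hypothetical. It shows the two settings that carry the design intent: class_weight doubles the loss contribution of "damaged" examples, and EarlyStopping watches val_accuracy and restores the best weights.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

AUTOTUNE = tf.data.AUTOTUNE
NUM_CLASSES = 3                           # hypothetical: 0 = OK, 1 = misplaced, 2 = damaged
CLASS_WEIGHT = {0: 1.0, 1: 1.0, 2: 2.0}   # 'damaged' examples count double in the loss
rng = np.random.default_rng(0)

def augment(image, label):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.2)
    return tf.clip_by_value(image, 0.0, 1.0), label   # keep pixels in [0, 1]

def fake_split(n, training):
    """Random tensors standing in for decoded TFRecord examples."""
    x = rng.random((n, 128, 128, 1), dtype=np.float32)
    y = rng.integers(0, NUM_CLASSES, size=n)
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    if training:   # shuffle and augment the training split only
        ds = ds.shuffle(n).map(augment, num_parallel_calls=AUTOTUNE)
    return ds.batch(16).prefetch(AUTOTUNE)

train_ds, val_ds = fake_split(64, training=True), fake_split(32, training=False)

model = keras.Sequential([
    keras.Input(shape=(128, 128, 1)),
    keras.layers.Conv2D(16, 3, activation='relu'),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(NUM_CLASSES, activation='softmax'),
])
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-4),
              loss='sparse_categorical_crossentropy', metrics=['accuracy'])

early_stop = keras.callbacks.EarlyStopping(
    monitor='val_accuracy', mode='max', patience=3, restore_best_weights=True)
history = model.fit(train_ds, validation_data=val_ds, epochs=15,
                    class_weight=CLASS_WEIGHT, callbacks=[early_stop], verbose=0)
print(len(history.history['loss']), 'epochs run')
```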
Framework decision table
| Need | TensorFlow / Keras | PyTorch | JAX |
|---|---|---|---|
| Android / iOS edge deployment | TFLite mature, NNAPI/Core ML delegates | ExecuTorch growing; ONNX bridge | Limited mobile tooling |
| Google TPU training | Native TPU strategy, best docs | XLA bridge; workable | First-class pmap |
| LLM / transformer ecosystem | KerasHub, smaller hub vs HF PyTorch | Hugging Face default | Flax + HF; research-heavy |
| Beginner-friendly high-level API | compile/fit very approachable | Explicit loop; more control | Functional; steep curve |
| Production ML pipelines (batch + serving) | TFX end-to-end (ExampleGen → Trainer → Pusher) | Custom + MLflow/W&B | Research-first |
| Dynamic research graphs | Eager + subclassing; export harder | Eager default; best flexibility | jit retracing constraints |
Pick TensorFlow when mobile TFLite, TPU, or TFX pipeline integration is on the critical path. Pick PyTorch when LLM fine-tuning, research agility, or the Hugging Face ecosystem dominates.
Common pitfalls
- NHWC vs NCHW confusion — wrong channel order silently hurts accuracy when porting weights.
- Forgetting to normalize inputs — ImageNet pretrained models expect specific mean/std scaling.
- steps_per_epoch mismatch — infinite datasets need explicit step counts or training runs forever.
- TFLite unsupported ops — exotic layers fail at convert time; design for mobile early.
- GPU memory not released — TensorFlow holds VRAM after OOM; restart kernel or enable memory growth.
- Validation shuffle — shuffling val data makes epoch metrics noisy; shuffle train only.
- Saving only weights vs full model: save_weights loses the architecture; prefer SavedModel or model.save.
- tf.function retracing: passing Python scalars whose values change each call, or tensors whose shapes change, triggers expensive retraces.
- Label dtype mismatch — sparse categorical crossentropy expects integer labels, not one-hot.
- Quantization without representative dataset — int8 post-training quant needs calibration samples.
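For the last pitfall, a sketch of int8 post-training quantization with a calibration generator, assuming the same Model.export path as above (tf.keras 2.13+ or Keras 3). Random arrays stand in for calibration images; with real data you would feed a few hundred preprocessed training samples. Without the commented-out lines, inputs and outputs stay float32 while internal tensors are quantized.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow import keras

model = keras.Sequential([
    keras.Input(shape=(128, 128, 1)),
    keras.layers.Conv2D(16, 3, activation='relu'),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(3, activation='softmax'),
])
export_dir = os.path.join(tempfile.mkdtemp(), 'int8_model')
model.export(export_dir)

# Calibration samples: random stand-ins for real, preprocessed training images.
calibration = np.random.default_rng(0).random((100, 128, 128, 1), dtype=np.float32)

def representative_dataset():
    # The converter runs these through the float model to choose int8 ranges.
    for sample in calibration:
        yield [sample[np.newaxis, ...]]   # list of model inputs, batch dim of 1

converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
# For integer-only accelerators, additionally set:
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# converter.inference_input_type = tf.int8
# converter.inference_output_type = tf.int8
int8_bytes = converter.convert()

interpreter = tf.lite.Interpreter(model_content=int8_bytes)
interpreter.allocate_tensors()
int8_tensors = [d for d in interpreter.get_tensor_details() if d['dtype'] == np.int8]
print(len(int8_tensors), 'int8 tensors')

# Inputs and outputs are still float32, so float images can be fed directly.
input_detail = interpreter.get_input_details()[0]
output_detail = interpreter.get_output_details()[0]
interpreter.set_tensor(input_detail['index'], calibration[:1])
interpreter.invoke()
probs = interpreter.get_tensor(output_detail['index'])
print(probs.shape)   # (1, 3)
```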
Practitioner checklist
- Pin the tensorflow version in requirements; CUDA/cuDNN wheels are version-locked.
- Run model.summary() and a single-batch forward pass before multi-hour training.
- Build tf.data with prefetch(tf.data.AUTOTUNE) and profile the input pipeline separately (for example with the TensorBoard Profiler's input-pipeline analysis).
- Log train and validation metrics every epoch; save the best checkpoint by validation metric, not the final epoch.
- Enable mixed precision only after the float32 baseline converges correctly.
- Export SavedModel and run parity inference tests before declaring training done.
- If targeting mobile, prototype TFLite conversion in week one, not after full training.
- Set random seeds (tf.random.set_seed, or keras.utils.set_random_seed to also seed Python and NumPy) and document data splits for reproducibility.
- Use TensorBoard or MLflow for experiment tracking across hyperparameter sweeps.
- Monitor GPU utilization; if it stays below 80%, the input pipeline is the likely bottleneck.
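The seeding and single-batch items above, as a sketch: keras.utils.set_random_seed seeds Python, NumPy, and TensorFlow together, and one forward pass plus one train_on_batch call catches shape and wiring errors in seconds. The random batch is an assumed stand-in for next(iter(train_ds)).

```python
import numpy as np
from tensorflow import keras

keras.utils.set_random_seed(42)   # seeds Python's random, NumPy, and TensorFlow

model = keras.Sequential([
    keras.Input(shape=(128, 128, 1)),
    keras.layers.Conv2D(32, 3, activation='relu'),
    keras.layers.MaxPooling2D(),
    keras.layers.Conv2D(64, 3, activation='relu'),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(3, activation='softmax'),
])
model.compile(optimizer=keras.optimizers.Adam(1e-3), loss='sparse_categorical_crossentropy')
model.summary()

# One batch standing in for next(iter(train_ds)).
x = np.random.random((8, 128, 128, 1)).astype('float32')
y = np.random.randint(0, 3, size=(8,))

probs = model(x, training=False).numpy()
assert probs.shape == (8, 3), probs.shape
assert np.allclose(probs.sum(axis=1), 1.0, atol=1e-5)   # softmax rows sum to 1

loss = model.train_on_batch(x, y)   # one real optimizer step
assert np.isfinite(float(loss)), loss
print(f'single-batch loss: {float(loss):.3f}')
```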
Key takeaways
- TensorFlow 2 defaults to eager execution with Keras as the primary modeling API.
- tf.data pipelines with prefetch keep GPUs fed; TFRecords scale better than loose files.
- compile + fit + callbacks cover most supervised training without a custom loop.
- TFLite and TPU support are TensorFlow's strongest differentiators for edge and Google Cloud.
- Always validate SavedModel and TFLite exports with numerical parity tests before deployment.
Related reading
- PyTorch fundamentals explained — the other major deep learning framework
- Deep learning explained — neural network concepts both frameworks implement
- Gradient descent explained — how optimizers update weights each step
- Edge AI and on-device inference explained — deploying models where TFLite shines