
Recurrent neural networks (RNN, LSTM, GRU) explained

Recurrent neural networks (RNNs) were the dominant architecture for sequence data for two decades — language modeling, speech recognition, machine translation, and time series forecasting all leaned on networks that carry a hidden state from one timestep to the next. Vanilla RNNs are elegant but suffer from vanishing gradients on long dependencies. LSTM (Long Short-Term Memory) and GRU (Gated Recurrent Unit) networks introduced gating mechanisms that preserve information across hundreds of steps. Today, transformers dominate large-scale NLP, but RNNs remain relevant for low-latency edge inference, streaming sensor data, and understanding how modern deep learning evolved. This guide covers the RNN recurrence equation, why gradients vanish, LSTM and GRU gate mechanics, bidirectional and stacked variants, backpropagation through time (BPTT), and a practical decision framework for RNN vs transformer vs classical forecasting models.

Why sequence models need recurrence

Feed-forward networks treat each input independently. Sequences (sentences, audio waveforms, stock prices, sensor readings) have order and context: the meaning of word five depends on words one through four. A recurrent layer processes input $x_t$ at timestep $t$ together with the previous hidden state $h_{t-1}$, producing an updated hidden state $h_t$ and optionally an output $y_t$.

The same weight matrices are shared across timesteps (weight tying), which makes RNNs parameter-efficient: a 2-layer LSTM with 256 hidden units has the same parameter count whether the sequence is 10 or 10,000 steps long. That sharing is also what lets the network generalize patterns regardless of position — a phrase detector trained on 50-token windows can run on 500-token streams at inference, modulo memory limits.

Three common RNN modes map to different tasks:

  • Many-to-one: sequence in, single label out (sentiment classification, sequence-level anomaly score).
  • One-to-many: single seed in, sequence out (image captioning decoder, music generation).
  • Many-to-many: sequence in, sequence out, either aligned step by step (video frame labeling, token tagging) or with different input and output lengths through an encoder-decoder (machine translation).

Vanilla RNN and the hidden state

At each timestep, a vanilla RNN computes:

$h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b)$

The hidden state $h_t$ is a compressed summary of everything seen so far. Outputs can be read from $h_t$ directly ($y_t = W_{hy} h_t$) or passed to downstream layers. Stacking multiple RNN layers, feeding the output sequence of layer $L$ into layer $L+1$, builds hierarchical representations similar to depth in CNNs.
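The recurrence can be sketched in a few lines of NumPy. This is a minimal illustration, not a reference implementation: the function name `rnn_forward`, the layer sizes, and the random weights are assumptions made for the example.

```python
import numpy as np

def rnn_forward(xs, W_xh, W_hh, b, h0=None):
    """Apply h_t = tanh(W_hh h_{t-1} + W_xh x_t + b) over a sequence.

    xs has shape (T, input_size); the result has shape (T, hidden_size).
    """
    hidden_size = W_hh.shape[0]
    h = np.zeros(hidden_size) if h0 is None else h0
    states = []
    for x_t in xs:  # timesteps are processed strictly in order
        h = np.tanh(W_hh @ h + W_xh @ x_t + b)
        states.append(h)
    return np.stack(states)

# Illustrative sizes and random (untrained) weights.
rng = np.random.default_rng(0)
T, input_size, hidden_size = 6, 3, 4
W_xh = rng.normal(scale=0.5, size=(hidden_size, input_size))
W_hh = rng.normal(scale=0.5, size=(hidden_size, hidden_size))
b = np.zeros(hidden_size)

hs = rnn_forward(rng.normal(size=(T, input_size)), W_xh, W_hh, b)
print(hs.shape)  # (6, 4)
```

The same three arrays handle a sequence of any length, which is the weight sharing described earlier; only the loop runs longer.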

The problem is vanishing and exploding gradients. During backpropagation through time, gradients flow backward through a chain of tanh derivatives multiplied by $W_{hh}$. Because the tanh derivative never exceeds 1, a largest singular value of $W_{hh}$ below 1 guarantees that gradients shrink exponentially with sequence length; above 1, they can explode. Either way, vanilla RNNs struggle to learn dependencies separated by more than ~10–20 steps: "the cat, which already ate, was hungry" requires linking subject and verb across a relative clause.
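The chain of derivatives can be written out explicitly. The following is a standard bound, sketched here for the tanh recurrence above; $\sigma_{\max}$ denotes the largest singular value, and the factors of the product are ordered from $k = T$ down to $k = t+1$.

```latex
\frac{\partial h_T}{\partial h_t}
  = \prod_{k=t+1}^{T} \frac{\partial h_k}{\partial h_{k-1}}
  = \prod_{k=t+1}^{T} \operatorname{diag}\!\left(1 - h_k^{2}\right) W_{hh},
\qquad
\left\lVert \frac{\partial h_T}{\partial h_t} \right\rVert_2
  \le \sigma_{\max}(W_{hh})^{\,T-t}
```

The inequality holds because every entry of $1 - h_k^2$ lies in $[0, 1]$. With $\sigma_{\max}(W_{hh}) = 0.9$ and a gap of 50 steps, the bound is $0.9^{50} \approx 0.005$, so the signal from 50 steps back is at most about half a percent of its original size. The bound only caps the gradient from above, which is why a singular value above 1 permits explosion without guaranteeing it.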

Practical mitigations for vanilla RNNs include gradient clipping (cap norm at 1–5), careful initialization (orthogonal $W_{hh}$), and ReLU activations, but gated architectures solved the problem more reliably.
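Gradient clipping by global norm is simple enough to sketch directly. The helper name `clip_by_global_norm` and the toy gradients are assumptions for illustration; deep learning frameworks ship their own equivalents.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    # Scale is 1.0 when the norm is already within budget, so small gradients pass through.
    scale = min(1.0, max_norm / (total_norm + 1e-12))
    return [g * scale for g in grads], total_norm

# Toy gradients with joint norm sqrt(3^2 + 4^2) = 5.
grads = [np.array([3.0, 0.0]), np.array([[0.0, 4.0]])]
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(float(np.sum(g ** 2)) for g in clipped))
print(norm_before, round(norm_after, 6))  # 5.0 1.0
```

Clipping rescales all gradients by the same factor, so the update direction is preserved and only its length changes.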

LSTM: the cell state and three gates

Hochreiter and Schmidhuber's LSTM (1997) adds a separate cell state $c_t$, a conveyor belt that runs through the sequence with only minimal linear interactions. In its now-standard form (the forget gate was added by Gers et al. in 2000), three sigmoid gates control information flow:

Forget gate

$f_t = \sigma(W_f [h_{t-1}, x_t] + b_f)$. Decides what fraction of the previous cell state to keep. Values near 0 erase; near 1 preserve. This is how LSTMs learn to drop irrelevant context ("reset" when a new sentence starts).

Input gate and candidate cell

$i_t = \sigma(W_i [h_{t-1}, x_t] + b_i)$ controls how much new information to write. $\tilde{c}_t = \tanh(W_c [h_{t-1}, x_t] + b_c)$ is the candidate content. Updated cell: $c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$.

Output gate

$o_t = \sigma(W_o [h_{t-1}, x_t] + b_o)$; $h_t = o_t \odot \tanh(c_t)$. The hidden state exposed to downstream layers is a filtered view of the cell state.
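Put together, the three gates and the cell update form one timestep. The sketch below follows the equations in this section term by term; the parameter dictionary layout, the helper `init_lstm_params`, and the sizes are assumptions chosen for readability, not how any particular library stores its weights.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, p):
    """One LSTM timestep; p maps names like 'W_f' and 'b_f' to arrays."""
    hx = np.concatenate([h_prev, x_t])             # [h_{t-1}, x_t]
    f = sigmoid(p["W_f"] @ hx + p["b_f"])          # forget gate
    i = sigmoid(p["W_i"] @ hx + p["b_i"])          # input gate
    c_tilde = np.tanh(p["W_c"] @ hx + p["b_c"])    # candidate cell content
    o = sigmoid(p["W_o"] @ hx + p["b_o"])          # output gate
    c = f * c_prev + i * c_tilde                   # additive cell update
    h = o * np.tanh(c)                             # filtered view of the cell
    return h, c

def init_lstm_params(input_size, hidden_size, rng):
    """Hypothetical helper: small random weights, zero biases."""
    p = {}
    for name in ("f", "i", "c", "o"):
        p[f"W_{name}"] = rng.normal(scale=0.1, size=(hidden_size, hidden_size + input_size))
        p[f"b_{name}"] = np.zeros(hidden_size)
    return p

rng = np.random.default_rng(0)
params = init_lstm_params(input_size=3, hidden_size=4, rng=rng)
h, c = np.zeros(4), np.zeros(4)
for x_t in rng.normal(size=(5, 3)):
    h, c = lstm_step(x_t, h, c, params)
print(h.shape, c.shape)  # (4,) (4,)
```

When the forget gate saturates near 1 and the input gate near 0, the update reduces to copying the previous cell state forward unchanged, which is the conveyor-belt behavior that keeps gradients alive over long spans.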

Because $c_t$ updates are largely additive (not repeated matrix multiplications through tanh), gradients can flow across long spans without vanishing. LSTMs became the workhorse for speech recognition, translation (pre-transformer Google NMT), and early language models.

GRU: fewer gates, similar performance

The GRU (Cho et al., 2014) merges forget and input gates into a single update gate $z_t$ and uses a reset gate $r_t$ to control how much past hidden state influences the candidate. A GRU has three weight blocks where an LSTM has four, so it carries roughly three-quarters of the parameters of an equivalent LSTM and often trains faster with comparable accuracy on medium-length sequences.
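One way to check the parameter ratio is to count weight blocks. The sketch assumes a single layer in which each block has one weight matrix over the concatenated hidden state and input plus one bias vector; the function name and the sizes are illustrative, and some libraries keep two bias vectors per block, which shifts the totals slightly but not the ratio.

```python
def recurrent_param_count(input_size, hidden_size, n_blocks):
    """Parameters in one recurrent layer with n_blocks gate/candidate blocks."""
    per_block = hidden_size * (hidden_size + input_size) + hidden_size
    return n_blocks * per_block

# LSTM: forget, input, candidate, output. GRU: update, reset, candidate.
lstm_params = recurrent_param_count(input_size=128, hidden_size=256, n_blocks=4)
gru_params = recurrent_param_count(input_size=128, hidden_size=256, n_blocks=3)
print(lstm_params, gru_params, gru_params / lstm_params)  # 394240 295680 0.75
```

Sequence length does not appear anywhere in the count, which is the weight sharing across timesteps at work.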

Rule of thumb: start with GRU when prototyping sequence models on tabular or short-text data; switch to LSTM if you need slightly better long-range retention on noisy sequences (financial ticks, multi-speaker audio). On modern GPU stacks with large batches, the speed difference is usually minor — pick whichever hyperparameter-tunes more cleanly on your validation set.

Bidirectional and stacked RNNs

A bidirectional RNN runs two independent RNNs, one forward and one backward, and concatenates their hidden states at each timestep. That gives every position context from both past and future, which is invaluable for tagging tasks (named entity recognition, part-of-speech) where $y_t$ depends on words after $t$. Bidirectional layers are not causal: you cannot use them in autoregressive decoding without cheating (seeing future tokens).
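The non-causal behavior is easy to see in a small sketch. Here a plain tanh RNN stands in for each direction; the function names, sizes, and random weights are assumptions for illustration.

```python
import numpy as np

def run_rnn(xs, W_xh, W_hh, b):
    """Plain tanh RNN over xs of shape (T, input_size); returns (T, hidden_size)."""
    h = np.zeros(W_hh.shape[0])
    out = []
    for x_t in xs:
        h = np.tanh(W_hh @ h + W_xh @ x_t + b)
        out.append(h)
    return np.stack(out)

def bidirectional(xs, fwd_params, bwd_params):
    h_fwd = run_rnn(xs, *fwd_params)
    # Run on the reversed sequence, then flip back so row t lines up with timestep t.
    h_bwd = run_rnn(xs[::-1], *bwd_params)[::-1]
    return np.concatenate([h_fwd, h_bwd], axis=1)

rng = np.random.default_rng(0)
T, D, H = 5, 3, 4

def make_params():
    return (rng.normal(scale=0.5, size=(H, D)),
            rng.normal(scale=0.5, size=(H, H)),
            np.zeros(H))

fwd, bwd = make_params(), make_params()
xs = rng.normal(size=(T, D))
out = bidirectional(xs, fwd, bwd)
print(out.shape)  # (5, 8)

# Perturb only the LAST input and look at the FIRST output row.
xs_changed = xs.copy()
xs_changed[-1] += 1.0
out_changed = bidirectional(xs_changed, fwd, bwd)
```

Comparing `out[0]` with `out_changed[0]` shows that the forward half is untouched while the backward half moves, even though only the final timestep was edited. That dependence on future inputs is exactly what makes the layer unsuitable for forecasting or autoregressive decoding.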

Stacked (deep) RNNs place layer $k+1$ on top of the full output sequence of layer $k$. Two to four layers is typical; beyond that, training becomes unstable without residual connections or layer normalization. Transformers replaced deep RNN stacks with self-attention partly because attention parallelizes across sequence length, whereas RNNs are inherently sequential, which limits GPU utilization on long sequences.

Training: BPTT, truncation, and teacher forcing

Backpropagation through time (BPTT) unrolls the RNN across T timesteps and applies standard backprop on the unrolled graph. Memory grows linearly with T, so practitioners use truncated BPTT — unroll only k steps (e.g. 35–70), detach the hidden state, and continue forward. This approximates full BPTT and works well when dependencies shorter than k dominate.
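The windowing and state carry-over of truncated BPTT can be sketched without a training framework. NumPy has no autograd, so the example below shows only the forward side: comments mark where an autograd library would run the backward pass and cut the graph. All names and sizes are assumptions for illustration.

```python
import numpy as np

def step(h, x_t, W_xh, W_hh, b):
    return np.tanh(W_hh @ h + W_xh @ x_t + b)

def tbptt_windows(xs, k):
    """Yield consecutive windows of at most k timesteps."""
    for start in range(0, len(xs), k):
        yield xs[start:start + k]

rng = np.random.default_rng(0)
T, input_size, hidden_size, k = 100, 3, 4, 35
xs = rng.normal(size=(T, input_size))
W_xh = rng.normal(scale=0.5, size=(hidden_size, input_size))
W_hh = rng.normal(scale=0.5, size=(hidden_size, hidden_size))
b = np.zeros(hidden_size)

h = np.zeros(hidden_size)
window_lengths = []
for window in tbptt_windows(xs, k):
    # Stand-in for "detach": the VALUE of h carries over, the graph history would not.
    h = h.copy()
    for x_t in window:
        h = step(h, x_t, W_xh, W_hh, b)
    # In an autograd framework, the loss for this window would be backpropagated here,
    # so gradients reach back at most len(window) steps.
    window_lengths.append(len(window))

print(window_lengths)  # [35, 35, 30]
```

Truncation changes only how far gradients travel. The forward computation is identical to a single pass over all 100 steps, because the hidden state is carried across window boundaries.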

Teacher forcing feeds the ground-truth previous token as input during training instead of the model's own prediction. It speeds convergence for sequence generation but creates exposure bias: at inference the model sees its own (possibly wrong) outputs. Scheduled sampling — gradually mixing model predictions into training inputs — mitigates this for decoders.
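The mixing step of scheduled sampling comes down to a coin flip per decoder input. The sketch below assumes a linear decay of the teacher-forcing probability, which is one illustrative schedule among several; the function names and token strings are hypothetical.

```python
import random

def teacher_forcing_prob(step, total_steps, start=1.0, end=0.0):
    """Linearly decay the probability of feeding the ground-truth token."""
    frac = min(step / total_steps, 1.0)
    return start + (end - start) * frac

def choose_decoder_input(ground_truth_token, predicted_token, prob, rng):
    """Feed the ground truth with probability prob, else the model's own prediction."""
    return ground_truth_token if rng.random() < prob else predicted_token

rng = random.Random(0)
for training_step in (0, 50, 100):
    p = teacher_forcing_prob(training_step, total_steps=100)
    token = choose_decoder_input("truth", "prediction", p, rng)
    print(training_step, p, token)
```

Early in training the decoder sees mostly ground truth (pure teacher forcing at step 0); by the end it sees mostly its own outputs, which is closer to what happens at inference.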

Regularization matches other deep networks: dropout on recurrent connections (variational dropout locks the same mask across timesteps), weight decay, and early stopping on validation loss. Normalize input features for time series (z-score per channel) and use walk-forward validation — random train/test splits leak future information into past windows.
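Walk-forward validation and leak-free normalization can be sketched together. The example assumes an expanding training window and fixed-size test blocks at the end of the series; the function names, fold count, and synthetic data are illustrative.

```python
import numpy as np

def walk_forward_splits(n_samples, n_folds, test_size):
    """Yield (train_idx, test_idx) pairs where every test block follows its training data."""
    for fold in range(n_folds):
        test_end = n_samples - (n_folds - 1 - fold) * test_size
        test_start = test_end - test_size
        yield np.arange(0, test_start), np.arange(test_start, test_end)

def zscore_fit(train):
    """Per-channel mean and std computed on TRAINING rows only."""
    return train.mean(axis=0), train.std(axis=0) + 1e-8

rng = np.random.default_rng(0)
series = rng.normal(loc=5.0, scale=2.0, size=(100, 2))  # (timesteps, channels)

for train_idx, test_idx in walk_forward_splits(len(series), n_folds=3, test_size=10):
    mean, std = zscore_fit(series[train_idx])
    train_z = (series[train_idx] - mean) / std
    test_z = (series[test_idx] - mean) / std   # reuse training statistics
    print(train_idx[-1], test_idx[0], test_idx[-1])
# 69 70 79
# 79 80 89
# 89 90 99
```

Fitting the scaler on the training window alone matters as much as the split itself: statistics computed over the full series would let information from the test period leak into the inputs.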

RNN vs CNN vs transformer

| Architecture | Strengths | Weaknesses | Typical use today |
| --- | --- | --- | --- |
| Vanilla RNN / GRU / LSTM | Low memory per step, streaming-friendly, small models | Sequential training, limited parallelization, weaker on very long context | Edge IoT, on-device keystroke biometrics, simple forecasting baselines |
| 1D CNN on sequences | Parallelizable, local pattern detection | Fixed receptive field per layer; dilated convs needed for long range | Audio classification, ECG anomaly, as transformer patch embedders |
| Transformer | Global attention, scales with data and compute, SOTA NLP/vision | O(n²) attention cost, needs large data, KV cache memory at inference | LLMs, speech (Whisper), long-document QA, most greenfield NLP |
| ARIMA / Prophet / XGBoost | Interpretable, fast on small tabular series, no GPU required | Weak on multivariate nonlinear dynamics, unstructured text | Ops forecasting, KPI dashboards, baseline before neural models |

Where RNNs still earn their place

Transformers won the benchmark war, but RNNs are not obsolete:

  • Streaming inference — process one timestep at a time with constant memory; transformers need the full context window (or a growing KV cache).
  • Tiny models on microcontrollers — a GRU with 32 units fits where a 7B-parameter transformer cannot.
  • Encoder bottlenecks in hybrid systems — compress a long sensor stream into a fixed vector before a downstream classifier.
  • Teaching and debugging — RNNs make the notion of "memory across time" explicit; understanding gates clarifies why attention was invented.
  • Legacy production pipelines — many speech and fraud systems shipped LSTMs before 2020; migration cost may exceed marginal accuracy gains.
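The streaming case from the list above can be sketched as a small stateful GRU that keeps only its current hidden state between calls. The class name `StreamingGRU`, the sizes, and the random weights are assumptions for illustration, and the update uses the convention $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$ (some references swap the roles of $z_t$ and $1 - z_t$).

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

class StreamingGRU:
    """Processes one timestep per call; memory use does not grow with stream length."""

    def __init__(self, input_size, hidden_size, seed=0):
        rng = np.random.default_rng(seed)
        shape = (hidden_size, hidden_size + input_size)
        # Random weights stand in for trained ones in this sketch.
        self.W_z = rng.normal(scale=0.1, size=shape)
        self.W_r = rng.normal(scale=0.1, size=shape)
        self.W_h = rng.normal(scale=0.1, size=shape)
        self.b_z = np.zeros(hidden_size)
        self.b_r = np.zeros(hidden_size)
        self.b_h = np.zeros(hidden_size)
        self.h = np.zeros(hidden_size)

    def reset(self):
        """Zero the state, e.g. at the start of an unrelated sequence."""
        self.h = np.zeros_like(self.h)

    def step(self, x_t):
        hx = np.concatenate([self.h, x_t])
        z = sigmoid(self.W_z @ hx + self.b_z)                     # update gate
        r = sigmoid(self.W_r @ hx + self.b_r)                     # reset gate
        h_cand = np.tanh(self.W_h @ np.concatenate([r * self.h, x_t]) + self.b_h)
        self.h = (1.0 - z) * self.h + z * h_cand
        return self.h

gru = StreamingGRU(input_size=8, hidden_size=32)
stream = np.random.default_rng(1).normal(size=(1000, 8))
for x_t in stream:
    h = gru.step(x_t)
print(h.shape)  # (32,)
```

After 1,000 steps the object still holds a single 32-element state vector, whereas an attention model serving the same stream would have to retain or re-read earlier inputs.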

For new NLP projects, default to pretrained transformers and fine-tune. Reach for RNNs when latency, memory, or true streaming constraints rule out attention-based models.

Common mistakes

  • Not shuffling sequences correctly: for independent series, shuffle series IDs, not individual timesteps within a series.
  • Using bidirectional RNNs for forecasting: future data leaks into past predictions; use unidirectional or causal convolutions.
  • Ignoring input scale: unnormalized features make gate saturation likely, because tanh and sigmoid flatten out at extreme inputs.
  • Evaluating on shuffled time splits: always validate with chronologically held-out windows.
  • Expecting RNNs to beat transformers on raw text at scale: pretrained transformers with long context windows and subword tokenization dominate modern NLP.
  • Forgetting hidden-state initialization: carry state across batch boundaries only when sequences are truly continuous; otherwise zero-init.

Decision table: which sequence model?

| Your situation | Recommended starting point |
| --- | --- |
| Greenfield NLP, >10k labeled examples, GPU available | Pretrained transformer (BERT/GPT-class) + fine-tune or prompt |
| Multivariate time series, <5k rows, need interpretability | Gradient boosting on lag features, then try GRU if nonlinear gaps remain |
| Real-time sensor stream, <1 ms per step on CPU | Small GRU or 1D CNN; quantize to INT8 |
| Sequence labeling (NER, POS) with medium text | Bidirectional LSTM-CRF or transformer token classifier |
| Autoregressive generation (chars, MIDI) | Transformer decoder if compute allows; else stacked GRU with scheduled sampling |
| Very long documents (>8k tokens) without fine-tune budget | Chunk + retrieve (RAG) with transformer embeddings, not vanilla RNN |

Production checklist

  • Confirm task is truly sequential — tabular IID rows may not need RNNs at all.
  • Normalize inputs; store scaler params for inference parity.
  • Use walk-forward or grouped splits; never leak future timesteps into training.
  • Clip gradients (max norm 1–5) and monitor for NaN activations early in training.
  • Pick hidden size and layers via validation loss, not parameter count alone.
  • Benchmark against a naive baseline (last-value, moving average, logistic regression).
  • For deployment, export with fixed sequence chunking or stateful inference API.
  • Log hidden-state norm distributions in production to catch drift.
  • Document max trained sequence length; behavior beyond it is extrapolation.
  • Re-evaluate against transformers when latency budget or hardware allows.
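The naive baselines named in the checklist take only a few lines, which is part of why they are worth running first. The sketch assumes one-step-ahead forecasting on a univariate series; the function names and the toy data are illustrative.

```python
import numpy as np

def last_value_forecast(series):
    """Predict each value as the one before it; predictions cover series[1:]."""
    return series[:-1]

def moving_average_forecast(series, window):
    """Predict each value as the mean of the previous `window` values; covers series[window:]."""
    return np.array([series[t - window:t].mean() for t in range(window, len(series))])

def mae(y_true, y_pred):
    return float(np.mean(np.abs(y_true - y_pred)))

series = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])  # toy upward trend
mae_last = mae(series[1:], last_value_forecast(series))
mae_ma = mae(series[2:], moving_average_forecast(series, window=2))
print(mae_last, mae_ma)  # 1.0 1.5
```

A trained GRU or LSTM that cannot beat these numbers on chronologically held-out data is not yet earning its complexity.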

Key takeaways

  • RNNs maintain a hidden state across timesteps with shared weights.
  • Vanilla RNNs suffer vanishing gradients; LSTM/GRU gates preserve long-range information.
  • Bidirectional layers use future context — great for tagging, wrong for causal forecasting.
  • BPTT trains unrolled graphs; truncated BPTT trades exactness for memory.
  • Transformers dominate large-scale NLP, but RNNs remain valuable for streaming, edge, and legacy pipelines.
