LSTM Training Process
Every gate covered so far — forget, input, candidate memory, output — has weights that need to be learned, and learning them means running backpropagation through the same unrolled structure that made vanilla RNNs fail. The mechanics of training an LSTM are exactly BPTT (backpropagation through time), the same procedure introduced for vanilla RNNs — what changes is why it now works.
Anchor: 10 days of stock prices, training a model to predict each next day from the one before.
```python
prices = [100,102,105,103,108,110,107,112,115,113]
# target: predict day 11
```
BPTT (Backpropagation Through Time)
Training unrolls the LSTM into T copies of the same cell — one per timestep — sharing the same weight matrices (Wf, Wi, WC, Wo). A forward pass runs through all T steps, caching every gate value along the way. The loss is computed (at the final timestep, or accumulated across timesteps), and backprop then flows backward through all T unrolled copies, accumulating gradients into the shared weights at every step it passes through.
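Here is a minimal sketch of that forward-cache-then-backward loop. It uses a deliberately simplified recurrence, not a real LSTM: assume a scalar state sₜ = w·sₜ₋₁ + xₜ with one shared weight w, and take the final state as the loss. The hypothetical helper `bptt_linear` caches every state on the way forward. It then walks from T back to 1 and adds each unrolled copy's contribution into the same `dw`.

```python
def bptt_linear(w, xs):
    """Full BPTT for the toy recurrence s_t = w * s_{t-1} + x_t, with loss L = s_T."""
    # Forward: unroll one copy of the same step per input, caching every state
    states = [0.0]                      # s_0 = 0
    for x in xs:
        states.append(w * states[-1] + x)
    loss = states[-1]

    # Backward: walk from t = T down to t = 1
    dw = 0.0
    ds = 1.0                            # dL/ds_T
    for t in range(len(xs), 0, -1):
        dw += ds * states[t - 1]        # this copy's contribution to the shared w
        ds *= w                         # dL/ds_{t-1} = dL/ds_t * w
    return loss, dw

loss, dw = bptt_linear(0.5, [1.0, 1.0, 1.0])
print(loss, dw)  # 1.75 2.0
```

With these inputs s₃ = w² + w + 1, so the exact derivative is 2w + 1 = 2 at w = 0.5. That matches the accumulated sum 1.5 + 0.5 + 0 from the three unrolled copies.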
Why LSTM BPTT Doesn't Vanish
The key gradient path runs through the cell state, as established in post 04. Along the direct additive path (holding the gate values fixed), ∂Cₜ/∂Cₜ₋₁ = fₜ. That is a single multiplication by the forget gate's value. It is not a product of a weight matrix and a tanh derivative, as in the vanilla RNN's ∂hₜ/∂hₜ₋₁ = Wₕ × tanh'(zₜ).
Because fₜ is learned and can sit close to 1 when long-range memory matters, the cumulative product across many steps decays far more slowly. Compare directly, both over 5 timesteps:
- LSTM cell-state path, fₜ ≈ 0.9: 0.9⁵ = 0.5905
- Vanilla RNN, factor ≈ 0.7: 0.7⁵ = 0.1681
The LSTM's gradient after 5 steps retains roughly 59% of its original magnitude; the vanilla RNN's retains under 17%. Over longer sequences the gap compounds dramatically — this is the mechanism, not just a claim, behind why LSTMs handle long-range dependencies that vanilla RNNs can't.
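The same comparison can be tabulated for other sequence lengths. This sketch simply raises each idealized per-step factor from the list above to the power T, with fₜ fixed at 0.9 and the vanilla RNN factor fixed at 0.7. Real networks vary these factors from step to step, so treat the output as an illustration of compounding, not a measurement.

```python
# Idealized per-step gradient factors from the comparison above (assumed constant)
lstm_factor = 0.9   # forget gate value on the cell-state path
rnn_factor = 0.7    # stands in for W_h * tanh'(z) in the vanilla RNN

for T in (1, 5, 10, 50):
    print(f"T={T:>3}  LSTM: {lstm_factor ** T:.4f}  RNN: {rnn_factor ** T:.4f}")
```

At T=50 the RNN column already prints 0.0000 at four decimals. The LSTM path still keeps about 0.5% of the signal.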
Gradient to a Weight — Chain Rule Expansion
The cell-state path explains why gradients survive across timesteps, but training actually needs ∂L/∂Wf — the gradient with respect to a real weight — not just ∂Cₜ/∂Cₜ₋₁. Chain rule expands that all the way from the loss down to Wf:
∂L/∂Wf = ∂L/∂h × ∂h/∂C × ∂C/∂f × ∂f/∂zf × ∂zf/∂Wf
Using the anchor's simplified 1-dim LSTM at t=2 (0-indexed, i.e. day 3 of the sequence), with seed=42 weights, predicting day 4's normalized price y = −0.9455 from h = 0.0079:
- ∂L/∂h = 2(pred − y) = 2(0.0079 − (−0.9455)) = 1.9068 (MSE derivative)
- ∂h/∂C = o × (1 − tanh(C)²) = 0.4903 × (1 − 0.0161²) ≈ 0.4902 → ∂L/∂C = 1.9068 × 0.4902 = 0.9347
- ∂C/∂f = Cₜ₋₁ = 0.0204 (since Cₜ = fₜCₜ₋₁ + iₜC̃ₜ, the derivative with respect to the forget gate is just the previous cell state) → ∂L/∂f = 0.9347 × 0.0204 = 0.0191
- ∂f/∂zf = f × (1 − f) = 0.5019 × (1 − 0.5019) = 0.2500 → ∂L/∂zf = 0.0191 × 0.2500 = 0.0048
- ∂zf/∂Wf = inp = [hₜ₋₁, xₜ] = [0.0098, −0.5253] → ∂L/∂Wf = 0.0048 × [0.0098, −0.5253] = [0.0000467, −0.002510] (computed from unrounded intermediates)
Weight update with lr=0.01: Wf_new = Wf − lr × ∂L/∂Wf = [0.049671, −0.013826] − 0.01×[0.0000467, −0.002510] = [0.049671, −0.013801]
The second component of Wf shifts from −0.013826 to −0.013801 — a small nudge in the direction that reduces this timestep's error, exactly what gradient descent is supposed to do.
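This is the trace of a single weight at a single timestep, and only along the direct path through Cₜ. Full BPTT also carries gradient back into earlier timesteps through Cₜ₋₁ and hₜ₋₁. Every timestep's contribution accumulates into the same shared Wf before the update, and the training loop below performs this kind of trace at each of the T steps per epoch.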
| Phase | Formula | Values substituted | Result |
|---|---|---|---|
| Output error | ∂L/∂h = 2(pred − y) | 2(0.0079 − (−0.9455)) | 1.9068 |
| Through output gate | ∂L/∂C = ∂L/∂h × o(1−tanh²C) | 1.9068 × 0.4903×(1−0.0161²) | 0.9347 |
| Through forget gate value | ∂L/∂f = ∂L/∂C × Cₜ₋₁ | 0.9347 × 0.0204 | 0.0191 |
| Through sigmoid derivative | ∂L/∂zf = ∂L/∂f × f(1−f) | 0.0191 × 0.5019×(1−0.5019) | 0.0048 |
| Into the weight | ∂L/∂Wf = ∂L/∂zf × inp | 0.0048 × [0.0098, −0.5253] | [4.7e-5, −2.51e-3] |
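The numbers in the table can be reproduced with a short script. This sketch reruns the anchor's forward pass up to t = 2, using the same normalization, `np.random.seed(42)` weights, and zero biases. It then applies the five chain-rule factors by hand. Like the worked example, it traces only the direct path through Cₜ at this single timestep and treats hₜ₋₁ as a constant input, so it is not the full BPTT gradient.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

prices = np.array([100, 102, 105, 103, 108, 110, 107, 112, 115, 113], dtype=float)
prices_n = (prices - prices.mean()) / prices.std()
X, y = prices_n[:-1], prices_n[1:]

np.random.seed(42)
W = np.random.randn(4, 2) * 0.1     # rows: Wf, Wi, WC, Wo; columns: [h, x]
b = np.zeros(4)

# Forward pass for t = 0, 1, 2; the loop leaves the t = 2 values in place
h, C = 0.0, 0.0
for t in range(3):
    inp = np.array([h, X[t]])                  # [h_{t-1}, x_t]
    f = sigmoid(W[0] @ inp + b[0])
    i = sigmoid(W[1] @ inp + b[1])
    C_tilde = np.tanh(W[2] @ inp + b[2])
    o = sigmoid(W[3] @ inp + b[3])
    C_prev = C
    C = f * C_prev + i * C_tilde
    h = o * np.tanh(C)

# Chain rule for dL/dWf at t = 2, with L = (h - y)^2
dL_dh = 2 * (h - y[2])
dh_dC = o * (1 - np.tanh(C) ** 2)
dL_dC = dL_dh * dh_dC
dL_df = dL_dC * C_prev                         # dC/df = C_{t-1}
dL_dzf = dL_df * f * (1 - f)                   # sigmoid derivative
dL_dWf = dL_dzf * inp                          # dzf/dWf = [h_{t-1}, x_t]

lr = 0.01
Wf_new = W[0] - lr * dL_dWf
print(f"dL/dh={dL_dh:.4f}  dh/dC={dh_dC:.4f}  dL/dWf={dL_dWf}  Wf_new={Wf_new}")
```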
Training Loop
- Initialize h₀ = 0, C₀ = 0
- For each timestep t = 1 to T: run the LSTM cell forward, caching all gate values (fₜ, iₜ, C̃ₜ, oₜ, Cₜ, hₜ) — they're needed for the backward pass
- Compute the loss at timestep T (single-value prediction) or accumulated across all timesteps (sequence prediction)
- BPTT: compute gradients backward from timestep T to timestep 1, accumulating into the shared weight matrices
- Update all weights (Wf, Wi, WC, Wo and their biases) with an optimizer — Adam is the standard choice
- Repeat the forward, loss, and BPTT steps for every sequence in the mini-batch, averaging their gradients so that the weight update happens once per mini-batch
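The last two steps, averaging per-sequence gradients and applying the optimizer, can be sketched in NumPy as follows. This is a minimal hand-written Adam update using the commonly cited defaults (β₁ = 0.9, β₂ = 0.999, ε = 1e-8). The helper `adam_step` is hypothetical, and the per-sequence gradient values are placeholders standing in for what BPTT would produce.

```python
import numpy as np

def adam_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update; t is the 1-based update count used for bias correction."""
    m = beta1 * m + (1 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2     # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias-corrected estimates
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v

# Placeholder BPTT gradients for the 4x2 weight matrix from 3 sequences in a mini-batch
per_sequence_grads = [np.full((4, 2), 0.2), np.full((4, 2), -0.1), np.full((4, 2), 0.5)]
grad = np.mean(per_sequence_grads, axis=0)      # average before the single update

W = np.zeros((4, 2))
m, v = np.zeros_like(W), np.zeros_like(W)
W, m, v = adam_step(W, grad, m, v, t=1)
print(W)
```

On the first step, Adam's bias correction makes the update roughly lr × sign(gradient). Every entry therefore moves by about 0.001 here, whatever the gradient's magnitude.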
Truncated BPTT (Practical)
Full BPTT across T=1000 timesteps requires storing all 1000 sets of cached activations simultaneously — memory cost that scales linearly with sequence length and becomes prohibitive for very long sequences. The practical fix is truncated BPTT: process k₁ steps forward, backpropagate through only the most recent k₂ steps, then detach the hidden and cell state from the computation graph before continuing.
Concretely, on a 10-step sequence: process all 10 steps forward, but only backpropagate through the last 5 — the earlier 5 steps still contributed to the forward computation (and therefore to what h and C carry forward), but their gradient contribution is dropped rather than tracked. This is exactly what PyTorch's .detach() does — it lets the forward computation continue using a tensor's value while cutting off the backward graph at that point.
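The 10-step, backpropagate-through-the-last-5 example can be mimicked on a toy. The assumptions: a scalar linear recurrence sₜ = w·sₜ₋₁ + xₜ stands in for the LSTM, the loss is the final state, and `tbptt_grad` is a hypothetical helper. Stopping the backward loop after k₂ steps has the same effect as detaching the state at step T − k₂. The forward values still use all T steps, but no gradient flows past the cut.

```python
def tbptt_grad(w, xs, k2):
    """dL/dw for s_t = w * s_{t-1} + x_t with L = s_T, backpropagating only k2 steps."""
    states = [0.0]
    for x in xs:                          # the forward pass always runs over every step
        states.append(w * states[-1] + x)
    T = len(xs)
    dw, ds = 0.0, 1.0                     # ds = dL/ds_t, starting at t = T
    for t in range(T, max(T - k2, 0), -1):
        dw += ds * states[t - 1]          # s_{t-1} is still used as a value
        ds *= w
    # The loop stopped: s_{T-k2} acts like a detached constant, so earlier
    # steps shape the forward values but contribute nothing to dw.
    return states[-1], dw

xs = [1.0] * 10
_, full = tbptt_grad(0.9, xs, k2=10)
_, truncated = tbptt_grad(0.9, xs, k2=5)
print(full, truncated)
```

Because every term in this toy is positive, the truncated gradient is strictly smaller than the full one. The loss value itself is identical in both cases.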
Loss for Sequence Prediction
Predict a single value (e.g., day 11's price only): loss = MSE(prediction_at_T, target) — one scalar loss from the final timestep's output.
Predict the full sequence (e.g., every day's next-day price): loss = Σₜ MSE(prediction_t, target_t) — sum (or mean) of per-timestep squared errors across all T steps, giving the network a training signal at every position, not just the last one.
On the anchor, single-value prediction trains only on the final step's error: predicting day 10 from days 1 through 9, since day 11 is the unknown being forecast. Full-sequence prediction trains on all 9 next-day predictions (day 2 through day 10) simultaneously, which is what the code below actually computes.
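The two loss definitions can be compared on the anchor prices. To keep the arithmetic checkable, this sketch uses a naive persistence forecast (tomorrow's price = today's) as a hypothetical stand-in for the LSTM's outputs. It works in raw price units.

```python
import numpy as np

prices = np.array([100, 102, 105, 103, 108, 110, 107, 112, 115, 113], dtype=float)
preds = prices[:-1]      # persistence stand-in: predicted day t+1 = price on day t
targets = prices[1:]     # true next-day prices, day 2 through day 10

single_loss = (preds[-1] - targets[-1]) ** 2      # final step only (day 10)
full_loss = np.mean((preds - targets) ** 2)       # mean over all 9 steps
print(single_loss, full_loss)  # 4.0 10.333333333333334
```

The single-value loss depends entirely on the day 10 error. The full-sequence loss also picks up the larger misses on days 5 and 8.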
Code
```python
import numpy as np

# Simplified 1-dim LSTM training demo
def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def sigmoid_d(z):
    # Sigmoid derivative: unused here, but a backward pass would need it
    s = sigmoid(z)
    return s * (1 - s)

prices = np.array([100, 102, 105, 103, 108, 110, 107, 112, 115, 113], dtype=float)
prices_n = (prices - prices.mean()) / prices.std()  # normalize

# Build X (9 inputs) and y (9 targets: predict the next day)
X, y = prices_n[:-1], prices_n[1:]

# Single-dim LSTM (for illustration, not a full implementation)
np.random.seed(42)
lr = 0.01  # learning rate; unused, because no update step exists yet

# Weights (4 gates x 2 params: [h, x])
W = np.random.randn(4, 2) * 0.1  # rows: [Wf, Wi, WC, Wo]
b = np.zeros(4)

losses = []
for epoch in range(3):
    h, C = 0.0, 0.0
    loss = 0.0
    for t in range(len(X)):
        inp = np.array([h, X[t]])               # [h_{t-1}, x_t]
        f = sigmoid(W[0] @ inp + b[0])          # forget gate
        i = sigmoid(W[1] @ inp + b[1])          # input gate
        C_tilde = np.tanh(W[2] @ inp + b[2])    # candidate memory
        C = f * C + i * C_tilde                 # cell-state update
        o = sigmoid(W[3] @ inp + b[3])          # output gate
        h = o * np.tanh(C)
        pred = h                                # output is the hidden state (simplified)
        loss += (pred - y[t]) ** 2
    losses.append(loss / len(X))
    print(f"Epoch {epoch+1}: MSE = {loss/len(X):.4f}")
```

```
Epoch 1: MSE = 0.8451
Epoch 2: MSE = 0.8451
Epoch 3: MSE = 0.8451
```

This loop is a forward-only illustration. The weights W are never updated because no backward pass is implemented, so the loss is identical every epoch. A real training loop would add two steps: the BPTT step, which computes gradients backward from timestep T to timestep 1, and the optimizer update that follows it. Because the whole sequence forms one batch here, the update would be applied after every epoch. Those two steps are what would actually change the loss over time.
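Filling in the missing backward pass turns the demo into an actual training loop. The sketch below is one way to do it, under these assumptions:

- exact BPTT for this simplified 1-dim LSTM (output = hidden state, full-sequence MSE averaged over the 9 steps);
- plain gradient descent instead of Adam, to keep the code short;
- a learning rate of 0.1, picked for the demo rather than tuned.

The helper names `forward` and `loss_and_grads` are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def forward(W, b, X):
    """Unrolled forward pass, caching every value the backward pass needs."""
    h, C = 0.0, 0.0
    cache = []
    for x in X:
        inp = np.array([h, x])                        # [h_{t-1}, x_t]
        f = sigmoid(W[0] @ inp + b[0])
        i = sigmoid(W[1] @ inp + b[1])
        C_tilde = np.tanh(W[2] @ inp + b[2])
        o = sigmoid(W[3] @ inp + b[3])
        C_prev = C
        C = f * C_prev + i * C_tilde
        h = o * np.tanh(C)
        cache.append((inp, f, i, C_tilde, o, C_prev, C, h))
    return cache

def loss_and_grads(W, b, X, y):
    """Full-sequence MSE and its BPTT gradients with respect to W and b."""
    cache = forward(W, b, X)
    T = len(X)
    preds = np.array([step[-1] for step in cache])
    loss = np.mean((preds - y) ** 2)
    dW, db = np.zeros_like(W), np.zeros_like(b)
    dh_next, dC_next = 0.0, 0.0                      # gradient arriving from step t+1
    for t in reversed(range(T)):
        inp, f, i, C_tilde, o, C_prev, C, h = cache[t]
        dh = 2 * (h - y[t]) / T + dh_next            # this step's loss + later steps
        tanh_C = np.tanh(C)
        dC = dh * o * (1 - tanh_C ** 2) + dC_next    # via h_t, plus the cell-state path
        # Gradients at each gate's pre-activation z (rows match W: f, i, C~, o)
        dz = np.array([
            dC * C_prev * f * (1 - f),
            dC * C_tilde * i * (1 - i),
            dC * i * (1 - C_tilde ** 2),
            dh * tanh_C * o * (1 - o),
        ])
        dW += np.outer(dz, inp)                      # shared weights: accumulate every step
        db += dz
        dh_next = dz @ W[:, 0]                       # h_{t-1} enters every gate via column 0
        dC_next = dC * f                             # additive path: dC_t/dC_{t-1} = f_t
    return loss, dW, db

prices = np.array([100, 102, 105, 103, 108, 110, 107, 112, 115, 113], dtype=float)
prices_n = (prices - prices.mean()) / prices.std()
X, y = prices_n[:-1], prices_n[1:]

np.random.seed(42)
W = np.random.randn(4, 2) * 0.1
b = np.zeros(4)
lr = 0.1
losses = []
for epoch in range(200):
    loss, dW, db = loss_and_grads(W, b, X, y)
    losses.append(loss)
    W -= lr * dW                                     # plain gradient descent update
    b -= lr * db
print(f"epoch 1 MSE = {losses[0]:.4f}, epoch 200 MSE = {losses[-1]:.4f}")
```

The line `dC_next = dC * f` is the cell-state path from the "Why LSTM BPTT Doesn't Vanish" section in code form. Swapping the two update lines for an Adam step would match the optimizer the training loop recommends.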
Related Concepts
Where this builds from: All four gates (posts 03–05) supply the forward computation that gets unrolled here. BPTT itself was introduced for vanilla RNNs in post 01 — this post shows why the same backward procedure behaves fundamentally differently once the cell state's additive update replaces the RNN's multiplicative one. Adam, the standard optimizer for LSTM training, is covered in the optimizers section of this series.
Where this leads: LSTM variants (post 07) modify this same trained structure — peephole connections, coupled forget/input gates — while GRU (post 08) simplifies the whole cell to reduce the number of weights that need this training loop applied to them.
Honest Limitations
Full BPTT on sequences longer than roughly T=1000 timesteps requires storing all T sets of cached gate activations in memory simultaneously — memory cost that grows linearly with sequence length, often exceeding available GPU memory well before compute becomes the bottleneck. Truncated BPTT is the standard practical fix, at the cost of gradients that only ever see a limited window of recent history.
Models trained with teacher forcing (always fed the true previous value during training) can suffer from exposure bias. At inference, the model feeds its own, possibly imperfect, predictions back in as the next input. That is a distribution it never saw during training, so errors can compound in ways the training loss never captured.
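The compounding half of that story can be seen with a toy. The sketch below uses a hypothetical one-step predictor that assumes yesterday's price plus 1, and scores it two ways on the anchor prices. Teacher-forced scoring starts every prediction from the true previous price. Free-running scoring starts each prediction from the model's own previous output, as inference would. It is not a trained LSTM; it only shows how a training-time number can understate inference-time error.

```python
import numpy as np

prices = np.array([100, 102, 105, 103, 108, 110, 107, 112, 115, 113], dtype=float)
targets = prices[1:]

def model(prev_price):
    # Hypothetical one-step predictor: assumes a constant +1 drift per day
    return prev_price + 1.0

# Teacher forcing: every prediction starts from the TRUE previous price
teacher_forced = np.array([model(p) for p in prices[:-1]])

# Free running (inference): each prediction starts from the model's OWN previous output
free_running = []
current = prices[0]
for _ in range(len(targets)):
    current = model(current)
    free_running.append(current)
free_running = np.array(free_running)

tf_mse = np.mean((teacher_forced - targets) ** 2)
fr_mse = np.mean((free_running - targets) ** 2)
print("teacher-forced MSE:", tf_mse)
print("free-running MSE:  ", fr_mse)
```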
LSTM training parallelizes far less well than Transformer training on the same data. Each timestep's forward pass depends on the previous timestep's hidden and cell state, so the T steps of a sequence cannot be computed in parallel on a GPU. A Transformer's attention mechanism has no such recurrence and processes all T positions simultaneously. On long sequences, this typically makes Transformer training substantially faster even when parameter counts are comparable.
Test Your Understanding
- Why does the cell-state gradient path (∂Cₜ/∂Cₜ₋₁ = fₜ) not suffer the same 0.7¹⁰⁰ ≈ 0 collapse that the vanilla RNN's hidden-state path does, even though both are recurrent structures unrolled over T timesteps?
- Using the comparison in this post (0.9⁵ = 0.5905 for LSTM vs 0.7⁵ = 0.1681 for vanilla RNN), extrapolate both to T=20 timesteps. Which architecture's gradient is closer to fully vanished at that point?
- In the anchor's training loop, MSE stays exactly 0.8451 across all 3 epochs. Identify specifically which step from the 6-step training loop described earlier is missing from the code, and explain why its absence causes the loss to stay constant.
- A team training an LSTM on 5,000-timestep sensor sequences runs out of GPU memory during backpropagation. Using truncated BPTT with k₁=100, k₂=20, describe what happens to the gradient signal for a sensor event that occurred 80 steps before the current truncation window's start.
- A model trained entirely with teacher forcing achieves very low training loss but performs poorly at inference time, where its own predictions compound into increasingly wrong inputs over a long generated sequence. What does this suggest about a mismatch between the training procedure's loss and the metric that actually matters at deployment?