Exploding Gradient Problem

Jul 1, 2026 · 7 min read · By Mohammed Vasim
Tags: deep-learning, neural-networks, machine-learning, representation-learning

The vanishing gradient problem (section 3, post 01) described what happens when weights are too small: gradients shrink layer by layer until the first layers receive no signal. The exploding gradient problem is the mirror image: weights are too large, and gradients grow exponentially through the network until they overflow — weights become NaN or ±∞ and training collapses.

Both problems arise from the same multiplication structure of backpropagation. When backpropagation multiplies gradient by weight at every layer, the compound effect of many multiplications either shrinks everything (|w| < 1) or amplifies everything (|w| > 1).

Anchor: 5-layer network with all weights initialized to 2.0. Input x=1.0.


Why It Happens — Numerical Example

Forward pass (no activation function, for clarity):

a₁ = 2.0 × 1 = 2

a₂ = 2.0 × 2 = 4

a₃ = 2.0 × 4 = 8

a₄ = 2.0 × 8 = 16

a₅ = 2.0 × 16 = 32

Output = 32.

Backward pass (assume δ₅ = 1.0 at the output):

δ₄ = w × δ₅ = 2.0 × 1.0 = 2.0

δ₃ = w × δ₄ = 2.0 × 2.0 = 4.0

δ₂ = w × δ₃ = 2.0 × 4.0 = 8.0

δ₁ = w × δ₂ = 2.0 × 8.0 = 16.0

∂L/∂W₁ = δ₁ × x = 16.0 × 1.0 = 16.0

With a learning rate of 0.1, the weight update for W₁ is 0.1 × 16 = 1.6 — changing the weight by 80% of its initial value in one step.

Extrapolation to 20 layers: the gradient at layer 1 is 2¹⁹ = 524,288, so with the same learning rate of 0.1 the weight update is about 52,429. A weight that started at 2.0 changes by roughly four orders of magnitude in a single step, and the network diverges.
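The extrapolation can be checked with a few lines of Python. This is a minimal sketch under the same assumptions as the worked example (one shared scalar weight, no activation function, δ = 1.0 at the output); the helper name first_layer_gradient is made up for illustration.

```python
def first_layer_gradient(w, n_layers, delta_out=1.0):
    """Gradient reaching layer 1 in a linear chain with one shared scalar weight.

    The output delta is multiplied by w once per layer crossed,
    i.e. (n_layers - 1) times, matching the worked example.
    """
    return delta_out * w ** (n_layers - 1)

learning_rate = 0.1  # same value as in the worked example
for n_layers in (5, 10, 20):
    grad = first_layer_gradient(2.0, n_layers)
    update = learning_rate * grad
    print(f"{n_layers:>2} layers: gradient = {grad:,.0f}, update = {update:,.1f}")
```

For 5 layers this reproduces the gradient of 16 and the update of 1.6; for 20 layers it gives 524,288 and an update of about 52,428.8.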

[Figure: Exploding gradient, growing layer by layer (w = 2.0). δ at each layer, moving from output toward input: L5 δ=1, L4 δ=2, L3 δ=4, L2 δ=8, L1 δ=16. With 20 layers, δ₁ = 2¹⁹ = 524,288, which leads to NaN weights.]

When Exploding Gradients Occur

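Deep feedforward networks with large initialization. A network deeper than ~10 layers whose weights are sampled from N(0, 1), rather than with He or Xavier scaling, is very likely to encounter this: with n inputs per unit, each layer multiplies the typical activation and gradient magnitude by a factor on the order of √n.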

RNNs and LSTMs. The most notorious victims. Backpropagation through time (BPTT) unrolls the recurrent connections across time steps, so a sequence of length 100 behaves like a 100-layer network. Because the same weight matrix is applied at every step, the gradient scales like w¹⁰⁰ (for a scalar weight w). In practice, RNNs on long sequences are very hard to train without gradient clipping.
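The effect of reusing one weight matrix across time steps can be sketched in NumPy. This is an illustration under strong assumptions, not a real RNN: the recurrence is linear (h_t = W h_{t-1}, no nonlinearity), and W is a scaled orthogonal matrix so that every backward step multiplies the gradient norm by exactly `scale`. The function name and parameters are hypothetical.

```python
import numpy as np

def bptt_gradient_norms(scale, n_steps=100, hidden_size=8, seed=0):
    """Gradient norm as it is propagated back through n_steps time steps.

    Assumes a linear recurrence h_t = W h_{t-1}, where W = scale * Q and
    Q is orthogonal, so each backward step scales the norm by `scale`.
    """
    rng = np.random.default_rng(seed)
    q, _ = np.linalg.qr(rng.standard_normal((hidden_size, hidden_size)))
    W = scale * q
    grad = np.ones(hidden_size) / np.sqrt(hidden_size)  # unit-norm gradient at the last step
    norms = [float(np.linalg.norm(grad))]
    for _ in range(n_steps):
        grad = W.T @ grad  # one step of backpropagation through time
        norms.append(float(np.linalg.norm(grad)))
    return norms

for scale in (0.9, 1.0, 1.1):
    norms = bptt_gradient_norms(scale)
    print(f"scale={scale}: gradient norm after 100 steps = {norms[-1]:.3e}")
```

Even a scale only slightly away from 1 compounds into a very small or very large norm after 100 steps, which is the same mechanism as in the 5-layer example, just with many more multiplications.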

Training without gradient clipping or weight regularization. Both are standard in RNN training, and leaving them out removes the usual safeguards. Feedforward networks with proper initialization (He/Xavier) are less affected.


Diagnosing Exploding Gradients

Symptom 1: Training loss suddenly goes to NaN or ∞ after a few epochs. The model was learning, then weights became so large that outputs overflow.

Symptom 2: Weights grow to very large magnitudes in the first 2–3 epochs. |W| increases monotonically rather than oscillating around a solution.

Symptom 3: Gradient norm grows consistently across epochs rather than stabilizing. Monitor the gradient norm per batch: ||∇J||₂ = √(Σᵢ (∂J/∂wᵢ)²).

To check the gradient norm directly in code, run this after loss.backward():

python
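# Run after loss.backward() so that p.grad is populated
total_norm = 0.0
for p in model.parameters():
    if p.grad is None:  # skip parameters that received no gradient
        continue
    total_norm += p.grad.detach().norm(2).item() ** 2
total_norm = total_norm ** 0.5
print(f"Gradient norm: {total_norm:.4f}")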

If the gradient norm grows from ~1 to ~1000 over 10 epochs, exploding gradients are the likely cause.
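Once gradient norms are logged, the "grows rather than stabilizes" check can be automated. The helper below is a hypothetical sketch: the name looks_like_explosion, the window of 5, and the 10× growth factor are illustrative assumptions, not established thresholds, and the two norm histories are made-up values.

```python
import math

def looks_like_explosion(norm_history, growth_factor=10.0, window=5):
    """Flag sustained growth in a logged list of gradient norms.

    Returns True if any norm is NaN or infinite, or if the mean of the
    last `window` norms is at least `growth_factor` times the mean of
    the first `window` norms.
    """
    if any(math.isnan(n) or math.isinf(n) for n in norm_history):
        return True
    if len(norm_history) < 2 * window:
        return False  # not enough data to compare early and late norms
    early = sum(norm_history[:window]) / window
    late = sum(norm_history[-window:]) / window
    return late >= growth_factor * early

# Made-up histories: one oscillating around 1, one growing from ~1 to ~1000.
stable = [1.0, 1.2, 0.9, 1.1, 1.0, 0.95, 1.05, 1.1, 0.9, 1.0]
growing = [1.0, 2.0, 4.0, 9.0, 20.0, 45.0, 100.0, 220.0, 480.0, 1000.0]
print(looks_like_explosion(stable))   # False
print(looks_like_explosion(growing))  # True
```

The thresholds would need adjusting to the gradient scale of a given task; the point is only that the comparison is between early and late norms, not a single absolute value.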


Solution 1 — Gradient Clipping

Algorithm: if the gradient norm exceeds a threshold, rescale the gradient to that threshold while preserving direction.

if ||g|| > threshold: g ← g × (threshold / ||g||)

Example: g = [16, 12, −8], threshold = 5.

||g|| = √(16² + 12² + (−8)²) = √(256 + 144 + 64) = √464 ≈ 21.54

g_clipped = [16, 12, −8] × (5 / 21.54) = [16 × 0.2321, 12 × 0.2321, −8 × 0.2321]

≈ [3.71, 2.79, −1.86]

The gradient direction is preserved exactly. Only the magnitude is reduced. The optimizer still moves in the right direction, just not as far.

PyTorch: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

Common max_norm values: 1.0 (tight), 5.0 (moderate), 10.0 (loose). For RNNs, 1.0 is a common starting point.
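clip_grad_norm_ computes a single norm over the gradients of all parameters together and rescales them by one shared factor. The NumPy sketch below mimics that behavior for a list of gradient arrays; the function name clip_global_norm and the small eps guard against division by zero are assumptions for illustration, not PyTorch's actual code.

```python
import numpy as np

def clip_global_norm(grads, max_norm, eps=1e-6):
    """Rescale a list of gradient arrays so their combined L2 norm is at most max_norm."""
    total_norm = float(np.sqrt(sum(np.sum(g ** 2) for g in grads)))
    scale = min(1.0, max_norm / (total_norm + eps))  # never scale gradients up
    return [g * scale for g in grads], total_norm

# Same values as the worked example, split across two parameter tensors.
grads = [np.array([16.0, 12.0]), np.array([-8.0])]
clipped, norm_before = clip_global_norm(grads, max_norm=5.0)
norm_after = float(np.sqrt(sum(np.sum(g ** 2) for g in clipped)))
print(f"norm before: {norm_before:.2f}, norm after: {norm_after:.2f}")
print("clipped:", [g.round(3) for g in clipped])
```

Because every array is multiplied by the same factor, the overall direction of the update is unchanged, and the ratio between the gradients of different parameters is preserved.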


Solution 2 — Weight Initialization (Preview)

Initializing weights from a zero-mean normal distribution with standard deviation √(2/n_in) (He) or √(2/(n_in+n_out)) (Xavier) keeps the initial gradient magnitudes in a stable range. The derivation is in the next post. When weights start at the right scale, the per-layer scaling factor stays near 1, so the compound multiplication through layers neither shrinks to 0 nor grows to ∞.
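The difference in scale can be seen empirically before working through the derivation. The sketch below pushes a random input through 10 ReLU layers and tracks the root-mean-square activation, once with N(0, 1) weights and once with He scaling. It measures the forward pass only, as a proxy for the per-layer scaling factor; the width of 256, the depth of 10, and the function name are arbitrary choices for illustration.

```python
import numpy as np

def activation_scale(std_fn, n_layers=10, width=256, seed=0):
    """RMS activation after each ReLU layer, for weights drawn with std = std_fn(fan_in)."""
    rng = np.random.default_rng(seed)
    a = rng.standard_normal(width)
    scales = []
    for _ in range(n_layers):
        W = rng.standard_normal((width, width)) * std_fn(width)
        a = np.maximum(0.0, W @ a)  # ReLU layer
        scales.append(float(np.sqrt(np.mean(a ** 2))))
    return scales

unit_std = activation_scale(lambda n: 1.0)             # N(0, 1) weights
he_std = activation_scale(lambda n: np.sqrt(2.0 / n))  # He initialization
print(f"N(0, 1): RMS activation after 10 layers = {unit_std[-1]:.3e}")
print(f"He:      RMS activation after 10 layers = {he_std[-1]:.3e}")
```

With unit-variance weights the activations grow by roughly an order of magnitude per layer at this width, while the He-scaled network stays within a small factor of its starting scale.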


Solution 3 — Batch Normalization (Preview)

Batch normalization rescales the activations after each layer to have mean ≈ 0 and standard deviation ≈ 1 over the batch. Keeping activations in a fixed range makes exponential growth of the activations, and of the gradients that depend on them, much less likely. Batch norm is primarily used for training stability, but limiting exploding gradients is one of its effects.
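The normalization step itself is short. The following is a simplified sketch of the training-time computation only: it omits the learnable scale and shift parameters and the running statistics used at test time, and the input values are synthetic.

```python
import numpy as np

def batch_norm(x, eps=1e-5):
    """Normalize each feature (column) over the batch dimension (rows)."""
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.standard_normal((64, 4)) * 1000.0 + 50.0  # deliberately large activations
y = batch_norm(x)
print("std before:", x.std(axis=0).round(1))
print("mean after:", y.mean(axis=0).round(3))
print("std after: ", y.std(axis=0).round(3))
```

However large the incoming activations are, the output of this step has roughly zero mean and unit standard deviation per feature, so the next layer starts from a controlled scale.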


Code

python
import numpy as np

def forward_backward(w_init, n_layers=5, x=1.0):
    """Return the deltas for layers 1..n in a linear chain with one shared scalar weight."""
    # Forward pass: linear (no activation, for illustration).
    # Not needed for the deltas in this linear case; kept to mirror the worked example.
    a = x
    activations = [a]
    for _ in range(n_layers):
        a = w_init * a
        activations.append(a)

    # Backward pass: start from delta = 1.0 at the output layer, then
    # multiply by w once for each layer crossed on the way to the input.
    delta = 1.0
    grads = []
    for _ in range(n_layers):
        grads.append(delta)
        delta = w_init * delta
    return list(reversed(grads))  # ordered from layer 1 to layer n

for w in [0.5, 1.0, 2.0]:
    grads = forward_backward(w)
    print(f"w={w}: gradients at each layer = {[round(g, 4) for g in grads]}")

# Gradient clipping: rescale g to max_norm if its L2 norm exceeds max_norm
def clip_grad(g, max_norm):
    norm = np.linalg.norm(g)
    return g * (max_norm / norm) if norm > max_norm else g

g = np.array([16.0, 12.0, -8.0])
print(f"\nOriginal gradient: {g}")
print(f"Gradient norm:     {np.linalg.norm(g):.2f}")
print(f"Clipped (max=5):   {clip_grad(g, 5.0).round(3)}")
text
w=0.5: gradients at each layer = [0.0625, 0.125, 0.25, 0.5, 1.0]
w=1.0: gradients at each layer = [1.0, 1.0, 1.0, 1.0, 1.0]
w=2.0: gradients at each layer = [16.0, 8.0, 4.0, 2.0, 1.0]

Original gradient: [16. 12. -8.]
Gradient norm:     21.54
Clipped (max=5):   [ 3.714  2.785 -1.857]

w=0.5: gradient shrinks toward the input layer (vanishing gradient). w=1.0: gradient stays constant (the neutral case that good initialization aims for). w=2.0: gradient grows 16× toward the input layer (exploding gradient). In this scalar linear model, the crossover point is exactly |w|=1.


Where this builds from: The vanishing gradient problem (section 3, post 01) established that the product of weights and activation derivatives through layers either decays or grows exponentially. Exploding gradient is the same mechanism with |w| > 1 instead of |w| < 1.

Where this leads: Weight initialization (next post) shows how to choose the initial scale of weights to avoid both problems. Gradient clipping is standard in most RNN and LSTM training. Batch normalization (an advanced technique) helps mitigate both vanishing and exploding gradients by normalizing activations between layers.


Honest Limitations

Gradient clipping does not solve the root cause. It prevents divergence but does not ensure the gradient is accurate — a clipped gradient is a smaller gradient, not a better one. In RNNs with very long sequences (500+ timesteps), clipping prevents NaN but the gradient still loses meaningful signal from distant timesteps.

The threshold for gradient clipping requires tuning. A threshold too small clips frequently and the optimizer takes tiny steps. A threshold too large fails to prevent occasional large spikes. Common values (1.0, 5.0) are starting points; the right value depends on the gradient scale of your specific task.
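One way to see this trade-off is to log gradient norms for a while without clipping and count how often each candidate threshold would have triggered. The sketch below does that on a made-up log; the values and the helper name clip_fraction are purely illustrative.

```python
import numpy as np

def clip_fraction(grad_norms, threshold):
    """Fraction of logged steps whose gradient norm would be clipped at `threshold`."""
    grad_norms = np.asarray(grad_norms, dtype=float)
    return float(np.mean(grad_norms > threshold))

# Synthetic log: mostly moderate norms with two large spikes (made-up values).
grad_norms = [0.8, 1.1, 0.9, 1.3, 25.0, 1.0, 1.2, 0.7, 40.0, 1.1]
for threshold in (0.5, 1.0, 5.0, 50.0):
    frac = clip_fraction(grad_norms, threshold)
    print(f"threshold={threshold}: clips {frac:.0%} of steps")
```

On this log, a threshold of 0.5 clips every step, 5.0 clips only the two spikes, and 50.0 never triggers, so the spikes pass through unchanged.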

Batch normalization adds its own complications. The normalization statistics (mean and variance) are computed per batch during training and replaced by running averages at test time. On small batches (< 8) the per-batch estimates are noisy, and in recurrent settings the statistics differ from one time step to the next and sequence lengths vary, so batch norm is awkward to apply. Layer normalization, which normalizes each sample over its own features, is preferred for RNNs and transformers.
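The practical difference is the axis over which statistics are computed. A minimal layer normalization sketch (again without the learnable scale and shift, and with synthetic inputs) shows that the result for one sample does not depend on what else is in the batch:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Normalize each sample (row) over its own features (columns)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
batch = rng.standard_normal((4, 16)) * 100.0
single = batch[:1]  # the first sample on its own, as a batch of size 1

# The first sample is normalized identically whether it is alone or in a batch.
print(np.allclose(layer_norm(batch)[0], layer_norm(single)[0]))  # True
```

Since no statistic is shared across samples, batch size and sequence position do not affect the normalization, which is why it suits recurrent and transformer models.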


Test Your Understanding

  1. With w=2.0 and 5 layers, the gradient at layer 1 is 16. With w=3.0 and 5 layers, what is the gradient at layer 1? With w=0.9 and 5 layers, what is the gradient at layer 1? At what weight value does the gradient stay exactly constant across all layers?

  2. Gradient clipping preserves direction but reduces magnitude. For g=[3, 4] and threshold=2, compute ||g||, the clipping scale factor, and g_clipped. Verify that ||g_clipped|| ≤ threshold. If two separate gradients g₁=[3,4] and g₂=[6,8] are both clipped to max_norm=2, are their clipped versions in the same or different directions?

  3. RNNs unroll through time steps. A sequence of length 100 with the same weight matrix W applied at each step means the gradient at timestep 1 is proportional to W¹⁰⁰. If W is a scalar with value 1.01 (just slightly above 1), compute the gradient magnitude at timestep 1. If W=0.99, compute it. What does this reveal about why RNNs struggle with long-range dependencies even with "safe" weight values?

  4. Gradient clipping is typically applied to the global gradient norm (across all model parameters), not per-parameter. Suppose a network has two weight matrices W₁ and W₂ where ∂L/∂W₁=[100, 0] and ∂L/∂W₂=[0, 0.1]. Compute the global norm. After clipping to max_norm=1, what is the effective gradient for each weight? Is this desirable?

  5. You observe that training loss decreases for 3 epochs, then suddenly spikes to NaN. Gradient clipping is not applied. List three diagnostics you would run (specific metrics to compute and at what granularity) to confirm exploding gradients are the cause, and state what threshold value you would expect to see in each diagnostic.
