
Batch Normalization

Jul 3, 2026 · 8 min read · By Mohammed Vasim
Tags: deep-learning, neural-networks, machine-learning, representation-learning

Training a deep network without normalization means that every weight update in layer 3 shifts the distribution of inputs to layer 4, which shifts what layer 4 needs to learn, which shifts layer 5's inputs — on and on through every layer simultaneously. This is called internal covariate shift. The deeper the network, the more these shifts compound. The practical consequence: training slows down because each layer must constantly re-adapt to a moving distribution, and you're forced to use small learning rates to prevent divergence.

Batch Normalization (Ioffe & Szegedy, 2015) addresses this by normalizing each layer's pre-activation values to zero mean and unit variance across the mini-batch. It then lets the network optionally undo that normalization with learnable parameters γ and β. After BN, the mean and variance of each feature a layer receives are set by γ and β, not by whatever the upstream weights happen to be at that step.

Anchor: mini-batch of 4 samples, 3 neurons per layer (pre-activation values):

```text
X = [[2.0, 0.5, -1.0],
     [1.5, 0.8, -0.5],
     [2.5, 0.2, -1.5],
     [1.0, 1.0,  0.0]]
```

γ = 1.0, β = 0.0 for every neuron (their initial values).


Internal Covariate Shift

Imagine neuron 1 in layer 5 receives inputs ranging from [−2, 2] at epoch 1. After a few gradient steps, neuron 3 in layer 4 shifts its output distribution — now layer 5 neuron 1 sees inputs from [−5, 8]. Its weights were calibrated for [−2, 2]. It must now re-learn.

[Figure: Without BN, the input distribution shifts each epoch. With BN, it stays stable: centered, unit variance.]

The Algorithm

For each feature (neuron) j, across all m samples in the mini-batch:

Step 1 — Batch mean: μ_j = (1/m) Σᵢ xᵢⱼ

Step 2 — Batch variance: σ²_j = (1/m) Σᵢ (xᵢⱼ − μ_j)²

Step 3 — Normalize: x̂ᵢⱼ = (xᵢⱼ − μ_j) / √(σ²_j + ε) where ε = 1e-5

Step 4 — Scale and shift: yᵢⱼ = γ_j · x̂ᵢⱼ + β_j
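The four steps translate almost line for line into code. Below is a plain-Python sketch for a single feature. It is written to mirror the Σ notation rather than for speed. The function name `batch_norm_feature` is hypothetical, and the code uses the biased 1/m variance from Step 2. Running it on neuron 1 of the anchor reproduces the hand calculation in the next section.

```python
import math

def batch_norm_feature(column, gamma=1.0, beta=0.0, eps=1e-5):
    """Apply Steps 1-4 to one feature j, given its m values x_1j..x_mj."""
    m = len(column)
    # Step 1: batch mean
    mu = sum(column) / m
    # Step 2: biased batch variance (divide by m, not m - 1)
    var = sum((x - mu) ** 2 for x in column) / m
    # Step 3: normalize; eps guards against division by zero
    x_hat = [(x - mu) / math.sqrt(var + eps) for x in column]
    # Step 4: learnable scale and shift
    y = [gamma * xh + beta for xh in x_hat]
    return mu, var, x_hat, y

# Neuron 1 of the anchor batch (column 0)
mu, var, x_hat, y = batch_norm_feature([2.0, 1.5, 2.5, 1.0])
print(mu, var)                       # 1.75 0.3125
print([round(v, 4) for v in x_hat])  # [0.4472, -0.4472, 1.3416, -1.3416]
```

With the default γ = 1 and β = 0, `y` is identical to `x_hat`.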


Computing on the Anchor

Neuron 1 (column 0: values 2.0, 1.5, 2.5, 1.0):

μ₁ = (2.0 + 1.5 + 2.5 + 1.0) / 4 = 7.0 / 4 = 1.75

σ²₁ = [(2.0−1.75)² + (1.5−1.75)² + (2.5−1.75)² + (1.0−1.75)²] / 4 = [0.0625 + 0.0625 + 0.5625 + 0.5625] / 4 = 1.25 / 4 = 0.3125

x̂ for neuron 1:

  • x̂₁₁ = (2.0 − 1.75) / √(0.3125 + 1e-5) = 0.25 / 0.5590 = 0.4472
  • x̂₂₁ = (1.5 − 1.75) / 0.5590 = −0.25 / 0.5590 = −0.4472
  • x̂₃₁ = (2.5 − 1.75) / 0.5590 = 0.75 / 0.5590 = 1.3416
  • x̂₄₁ = (1.0 − 1.75) / 0.5590 = −0.75 / 0.5590 = −1.3416

With γ₁ = 1, β₁ = 0: y = x̂ (identity at initialization).

| Step | Formula | Neuron 1 result |
|---|---|---|
| μ | (2.0 + 1.5 + 2.5 + 1.0) / 4 | 1.75 |
| σ² | (0.25² + 0.25² + 0.75² + 0.75²) / 4 | 0.3125 |
| x̂ (4 samples) | (xᵢ − 1.75) / 0.5590 | [0.447, −0.447, 1.342, −1.342] |
| y (γ = 1, β = 0) | 1.0 · x̂ + 0.0 | [0.447, −0.447, 1.342, −1.342] |

All 3 neurons:

μ = [1.75, 0.625, −0.75], σ² = [0.3125, 0.0919, 0.3125] (neuron 2: 0.3675 / 4 = 0.091875)


Learnable Parameters γ and β

At initialization (γ = 1, β = 0), BN is a pure normalization: every feature has zero mean and unit variance over the batch. That constraint might hurt the network. For example, a downstream activation might work better with inputs spread over roughly [−2, 2] than with unit-variance inputs.

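γ and β are trained by backpropagation like any other weight. If γ_j → √(σ²_j + ε) and β_j → μ_j, then y ≈ x and BN undoes itself. This sounds counterproductive, but it means BN does not shrink what the layer can represent: the identity transform stays reachable. It is only approximate, because μ_B and σ²_B vary from batch to batch. The network learns to use as much (or as little) normalization as the loss requires.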
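Here is a quick NumPy check of the claim above, as a minimal sketch. Setting γ to the batch's √(σ² + ε) and β to its μ makes BN return its input exactly. The `batch_norm` helper is a simplified version written for this example. The second check shows the caveat: the same γ and β do not reproduce a different batch, because that batch is normalized with its own statistics.

```python
import numpy as np

def batch_norm(X, gamma, beta, eps=1e-5):
    """Training-mode BN: normalize with this batch's statistics, then scale and shift."""
    x_hat = (X - X.mean(axis=0)) / np.sqrt(X.var(axis=0) + eps)
    return gamma * x_hat + beta

X = np.array([[2.0, 0.5, -1.0],
              [1.5, 0.8, -0.5],
              [2.5, 0.2, -1.5],
              [1.0, 1.0,  0.0]])
eps = 1e-5

# gamma = sqrt(var + eps), beta = mu: exactly undoes Steps 1-3 for THIS batch
gamma = np.sqrt(X.var(axis=0) + eps)
beta = X.mean(axis=0)
print(np.allclose(batch_norm(X, gamma, beta, eps), X))               # True

# A shifted batch is normalized with its own mean, so the same gamma, beta
# map it back to X rather than to X + 1
print(np.allclose(batch_norm(X + 1.0, gamma, beta, eps), X + 1.0))   # False
```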


Training vs Inference

During training, BN uses the current batch's μ_B and σ²_B. During inference, you might have a single sample — no batch to compute statistics from.

The solution: track running averages during training:

μ_running = momentum · μ_running + (1 − momentum) · μ_B

With momentum = 0.9, start from μ_running = 0 and feed three training batches that each have μ_B = [1.75, 0.625, −0.75]. The running mean for neuron 1 evolves as:

  • After step 1: μ_running = 0.9 × 0 + 0.1 × 1.75 = 0.175
  • After step 2: μ_running = 0.9 × 0.175 + 0.1 × 1.75 = 0.1575 + 0.175 = 0.3325
  • After step 3: μ_running = 0.9 × 0.3325 + 0.1 × 1.75 = 0.29925 + 0.175 = 0.47425

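At inference, this running mean replaces the batch mean. A running variance maintained with the same update replaces the batch variance. Because μ_running starts at 0, it approaches the batch mean slowly: after three batches it has reached only 1 − 0.9³ = 27.1% of 1.75. BN therefore needs many training steps before its running statistics are trustworthy.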
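The running-average update and the inference path can be sketched as follows. This rests on a few assumptions:

- momentum follows this post's convention, where it is the weight on the old value;
- the running variance starts at 1 and is updated with the biased batch variance;
- the helper names are hypothetical.

PyTorch's `BatchNorm` layers use the opposite convention: `momentum=0.1` is the weight on the new batch statistic. PyTorch also initializes the running variance to 1 and updates it with the unbiased variance.

```python
import numpy as np

def update_running(running, batch_stat, momentum=0.9):
    """Exponential moving average; momentum weights the OLD value (this post's convention)."""
    return momentum * running + (1.0 - momentum) * batch_stat

def batch_norm_inference(x, gamma, beta, running_mean, running_var, eps=1e-5):
    """Normalize with stored statistics only: works for a single sample and is deterministic."""
    return gamma * (x - running_mean) / np.sqrt(running_var + eps) + beta

running_mean = np.zeros(3)
running_var = np.ones(3)                        # assumption: variance starts at 1
batch_mean = np.array([1.75, 0.625, -0.75])     # anchor batch statistics
batch_var = np.array([0.3125, 0.091875, 0.3125])

for step in range(1, 4):
    running_mean = update_running(running_mean, batch_mean)
    running_var = update_running(running_var, batch_var)
    print(step, round(float(running_mean[0]), 5))   # neuron 1: 0.175, 0.3325, 0.47425

x_single = np.array([2.0, 0.5, -1.0])           # one sample, no batch required
print(batch_norm_inference(x_single, np.ones(3), np.zeros(3), running_mean, running_var))
```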

[Figure: BN training vs inference path. The input x is normalized with μ_B and σ²_B from the batch during training, or with μ_run and σ²_run at inference, then passed through γ · x̂ + β to the output.]

Where BN Is Placed

Original placement (Ioffe & Szegedy): Conv → BN → ReLU

BN normalizes before the activation. With sigmoid or tanh, this keeps inputs out of the saturated tails: after normalization they sit near zero, where these activations have strong gradients.

Alternative placement: Conv → ReLU → BN

Some architectures use this order, which normalizes the ReLU outputs that the next layer actually consumes. In practice, placing BN before the activation remains the dominant convention.

[Figure: Original (recommended) order Conv → BN → ReLU: inputs to ReLU are near zero, in the strong-gradient region. Alternative order Conv → ReLU → BN: normalizes post-activation values, less common.]
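The original paper notes one practical consequence of the Conv → BN → ReLU order: a bias added before BN is wasted. Step 1 subtracts it right back out, and β takes over its role. The sketch below checks this with a dense layer standing in for the convolution. For a real conv layer, BN computes statistics per channel over the batch and spatial positions. The shapes, weights, and `bn_train` helper are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 5))     # hypothetical input: 32 samples, 5 features
W = rng.normal(size=(5, 4))      # hypothetical dense weights (stand-in for a conv)
b = rng.normal(size=4)           # bias that BN will cancel

def bn_train(Z, gamma, beta, eps=1e-5):
    """Training-mode BN over a (batch, features) array."""
    mu, var = Z.mean(axis=0), Z.var(axis=0)
    return gamma * (Z - mu) / np.sqrt(var + eps) + beta

gamma, beta = np.ones(4), np.zeros(4)

# Original order: Linear/Conv -> BN -> ReLU, with and without the bias
with_bias = np.maximum(bn_train(X @ W + b, gamma, beta), 0.0)
without_bias = np.maximum(bn_train(X @ W, gamma, beta), 0.0)
print(np.allclose(with_bias, without_bias))   # True: the bias is subtracted out in Step 1
```

With the alternative Conv → ReLU → BN order, the bias is not cancelled. ReLU acts on the biased values before BN ever sees them.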

Gradient Flow Through BN

BN also reshapes the gradients that flow back through it. Write σ_j = √(σ²_j + ε). Applying the chain rule through Steps 1–4 then gives:

∂L/∂xᵢⱼ = (γ_j / (m · σ_j)) · [m · ∂L/∂yᵢⱼ − Σₖ ∂L/∂yₖⱼ − x̂ᵢⱼ · Σₖ (∂L/∂yₖⱼ · x̂ₖⱼ)]

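The key factor is γ_j / σ_j. If σ_j is large (a high-variance feature), gradients are scaled down; if σ_j is small, they are amplified.

The two subtracted terms remove the parts of the upstream gradient that would only shift the batch mean or rescale the batch spread. BN cancels those changes in the forward pass anyway.

This gives BN a self-correcting tendency. If upstream weights grow and produce large activations, σ grows with them. That shrinks the gradient reaching those weights, which damps further growth.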
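The closed-form gradient is easy to get wrong. Below is a minimal NumPy sketch that implements it and compares it against central finite differences. The loss L = Σ G ⊙ y, with a fixed random G, is a hypothetical stand-in chosen so that ∂L/∂y = G. The helper names are illustrative. The last lines check the γ_j / σ_j effect: scaling the input by 10 leaves x̂ essentially unchanged but divides the gradient by roughly 10.

```python
import numpy as np

def bn_forward(X, gamma, beta, eps=1e-5):
    mu = X.mean(axis=0)
    sigma = np.sqrt(X.var(axis=0) + eps)       # sigma_j = sqrt(var_j + eps)
    x_hat = (X - mu) / sigma
    return gamma * x_hat + beta, x_hat, sigma

def bn_backward(dY, x_hat, sigma, gamma):
    """dL/dX from dL/dY, term by term as in the formula above."""
    m = dY.shape[0]
    return (gamma / (m * sigma)) * (
        m * dY
        - dY.sum(axis=0)
        - x_hat * (dY * x_hat).sum(axis=0)
    )

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 3))
gamma, beta = rng.normal(size=3), rng.normal(size=3)
G = rng.normal(size=(4, 3))          # hypothetical upstream gradient dL/dy

def loss(X_in):
    y, _, _ = bn_forward(X_in, gamma, beta)
    return float((y * G).sum())      # L = sum(G * y), so dL/dy = G

_, x_hat, sigma = bn_forward(X, gamma, beta)
analytic = bn_backward(G, x_hat, sigma, gamma)

# Central finite differences, one input entry at a time
h = 1e-6
numeric = np.zeros_like(X)
for idx in np.ndindex(*X.shape):
    Xp, Xm = X.copy(), X.copy()
    Xp[idx] += h
    Xm[idx] -= h
    numeric[idx] = (loss(Xp) - loss(Xm)) / (2 * h)

print(np.allclose(analytic, numeric, atol=1e-6))   # True

# Scaling the input by 10 divides the gradient by about 10
_, x_hat10, sigma10 = bn_forward(10 * X, gamma, beta)
grad10 = bn_backward(G, x_hat10, sigma10, gamma)
print(np.linalg.norm(analytic) / np.linalg.norm(grad10))
```

Each column of `analytic` also sums to zero. This reflects the fact that shifting every sample of a feature by the same amount does not change BN's output.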


Code

```python
import numpy as np

def batch_norm(X, gamma, beta, eps=1e-5):
    """Training-mode BN over a (batch, features) array."""
    mu = X.mean(axis=0)                      # Step 1: per-feature batch mean
    var = X.var(axis=0)                      # Step 2: biased (1/m) batch variance
    x_hat = (X - mu) / np.sqrt(var + eps)    # Step 3: normalize
    y = gamma * x_hat + beta                 # Step 4: scale and shift
    return y, mu, var, x_hat

# Anchor mini-batch: 4 samples x 3 neurons
X = np.array([[2.0, 0.5, -1.0],
              [1.5, 0.8, -0.5],
              [2.5, 0.2, -1.5],
              [1.0, 1.0,  0.0]])
gamma = np.ones(3)    # initial scale
beta = np.zeros(3)    # initial shift
y, mu, var, x_hat = batch_norm(X, gamma, beta)
print("Batch mean:    ", mu.round(4))
print("Batch var:     ", var.round(4))
print("Normalized x̂:\n", x_hat.round(4))
print("Output y:\n", y.round(4))
```

```text
Batch mean:     [ 1.75   0.625 -0.75 ]
Batch var:      [0.3125 0.0919 0.3125]
Normalized x̂:
 [[ 0.4472 -0.4124 -0.4472]
 [-0.4472  0.5773  0.4472]
 [ 1.3416 -1.4021 -1.3416]
 [-1.3416  1.2371  1.3416]]
Output y:
 [[ 0.4472 -0.4124 -0.4472]
 [-0.4472  0.5773  0.4472]
 [ 1.3416 -1.4021 -1.3416]
 [-1.3416  1.2371  1.3416]]
```

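BN connects to earlier posts in this series:

- Exploding gradients (01-exploding-gradient.md): both deal with unstable training, but gradient clipping treats the symptom while BN removes one of its causes.
- Dropout (03-dropout.md): dropout is a regularizer, and BN partly overlaps with it because batch statistics inject sample-dependent noise. Ioffe & Szegedy report that BN can in some cases eliminate the need for dropout.
- Layer Normalization (05-layer-normalization.md): it replaces the batch-dimension statistics with feature-dimension statistics, so its output does not depend on batch size. That is one reason Transformers use it, since autoregressive decoding often runs at batch size 1.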
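The difference between BN and Layer Normalization comes down to which axis the statistics are taken over. Consider this minimal sketch, with hypothetical helper names and γ and β omitted. BN on a batch of one collapses every feature to 0, because each value equals its own batch mean. LN still produces a useful normalized vector, and gives the same result for a sample whether or not other samples share the batch.

```python
import numpy as np

def batch_stats_norm(X, eps=1e-5):
    # BN-style: statistics per feature, across the batch (axis 0)
    return (X - X.mean(axis=0)) / np.sqrt(X.var(axis=0) + eps)

def layer_stats_norm(X, eps=1e-5):
    # LN-style: statistics per sample, across its features (axis 1)
    mu = X.mean(axis=1, keepdims=True)
    var = X.var(axis=1, keepdims=True)
    return (X - mu) / np.sqrt(var + eps)

x_single = np.array([[2.0, 0.5, -1.0]])    # batch size 1 (first anchor sample)
print(batch_stats_norm(x_single))          # [[0. 0. 0.]]
print(layer_stats_norm(x_single).round(4))

# LN gives the first sample the same output with or without a batch mate
X = np.array([[2.0, 0.5, -1.0],
              [1.5, 0.8, -0.5]])
print(np.allclose(layer_stats_norm(X)[:1], layer_stats_norm(x_single)))  # True
```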

Honest Limitations

BN statistics become unreliable at small batch sizes, roughly below 8 samples per device. μ_B and σ²_B computed from 2 or 4 samples are high-variance estimates that make training noisy. In object detection and segmentation, memory limits often force batch sizes of 1–4 per GPU. There, models that replace BN with Group Normalization regularly outperform their BN-based baselines.
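To see how noisy small-batch statistics are, the sketch below draws many synthetic batches from a hypothetical feature distribution (normal, mean 3, standard deviation 2). It measures how much μ_B varies from batch to batch. For independent samples, the spread of μ_B should shrink like σ/√m. The biased σ²_B should average (m − 1)/m · σ², so tiny batches are both noisier and systematically low.

```python
import numpy as np

rng = np.random.default_rng(0)
true_mean, true_std = 3.0, 2.0          # hypothetical feature distribution
n_batches = 2000                        # batches drawn per batch size

results = {}
for m in [2, 4, 8, 32, 128]:
    batches = rng.normal(true_mean, true_std, size=(n_batches, m))
    mu_B = batches.mean(axis=1)         # one batch mean per batch
    var_B = batches.var(axis=1)         # biased (1/m) variance, as BN uses
    results[m] = (mu_B.std(), var_B.mean())
    print(f"m={m:3d}  std(mu_B)={mu_B.std():.3f}  "
          f"expected={true_std / np.sqrt(m):.3f}  mean(var_B)={var_B.mean():.2f}")
```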

BN is a poor fit for variable-length sequences. Positions in a sequence of length 12 can have different statistics than positions in a sequence of length 24 in the same batch. Padding the shorter sequence also adds values that are not data at all. Mixing them under a shared μ_B produces statistics that describe neither sequence well. This is one reason Transformers use Layer Normalization instead.

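BN adds per-layer state: running μ and σ² for every feature, updated at every training step. That state is small (two vectors the size of γ and β). The larger memory cost is that each BN layer must cache its normalized activations for the backward pass, which adds up on large feature maps.

BN also introduces a dependency between samples in the same mini-batch, which complicates some training setups. In contrastive learning, for example, shared batch statistics can leak information between positive and negative pairs. The model can then lower the loss without learning good representations.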
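The cross-sample dependency is easy to demonstrate. In training mode, the same input row produces a different output when a single batch mate changes. The sketch below reuses the anchor batch and swaps out its last row for hypothetical values. It normalizes without γ and β, which is equivalent to γ = 1, β = 0. It shows only the dependency itself, not a contrastive-learning setup.

```python
import numpy as np

def bn_train(X, eps=1e-5):
    """Training-mode BN without gamma/beta (equivalently gamma=1, beta=0)."""
    return (X - X.mean(axis=0)) / np.sqrt(X.var(axis=0) + eps)

X = np.array([[2.0, 0.5, -1.0],
              [1.5, 0.8, -0.5],
              [2.5, 0.2, -1.5],
              [1.0, 1.0,  0.0]])

# Same first sample, but one batch mate replaced (hypothetical values)
X_other = X.copy()
X_other[3] = [4.0, -2.0, 3.0]

out_a = bn_train(X)[0]
out_b = bn_train(X_other)[0]
print(out_a.round(4))   # [ 0.4472 -0.4124 -0.4472]
print(out_b.round(4))   # differs, although row 0 of the input is unchanged
```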


Test Your Understanding

  1. Compute σ² for neuron 2 (column 1: values 0.5, 0.8, 0.2, 1.0) using the anchor data. Then compute x̂ for the first sample in that neuron.

  2. If γ₁ = 2.0 and β₁ = −0.5, what are the outputs y for neuron 1 using the x̂ values computed above?

  3. A model is trained with batch size 128 but deployed with batch size 1. Without running statistics, BN has no meaningful mean and variance to use for a single sample. Why can't you just reuse the statistics of one training batch (say, the last one) at inference time?

  4. You observe that after adding BN, training loss decreases faster but validation loss is slightly higher than without BN. Propose an explanation involving the running statistics estimate.

  5. BN is placed between a convolutional layer and ReLU. The batch variance σ² for one feature channel approaches zero during training. What happens to (a) the normalized values x̂, (b) the gradients through this channel, and (c) what should you do about it?
