
Layer Normalization

Jul 3, 2026 · 7 min read · By Mohammed Vasim
deep-learning · neural-networks · machine-learning · representation-learning

Batch Normalization normalizes across the samples in a batch: it averages over the batch dimension N. That design breaks down when there is only one sample. With batch size 1 there is nothing to average over, so at inference BN has to fall back on running statistics stored during training. It also breaks for variable-length sequences: a batch mixing a 10-token sequence with a 50-token sequence would compute statistics that mix positions with fundamentally different content. Transformers need to work at batch size 1 during autoregressive decoding, and they process variable-length sequences constantly. So Batch Normalization was never a practical option.

Layer Normalization (Ba et al., 2016) rotates the normalization axis: instead of normalizing across the batch for each feature, it normalizes across features for each sample. Every single sample has its own μ and σ, computed from its own feature vector. Batch size becomes irrelevant.

Anchor: single token representation x = [2.0, 0.5, −1.0, 1.5] (4-dimensional hidden state). γ = 1.0, β = 0.0.


BN vs LN — What Gets Normalized

| | Batch Norm | Layer Norm |
|---|---|---|
| Normalize over | N samples, per feature j | d features, per sample i |
| Statistics | μ_j, σ²_j across batch | μᵢ, σ²ᵢ across features |
| Requires batch | Yes | No |
| Works at batch size 1 | No | Yes |
| Works with variable-length sequences | No | Yes |
| Common in | CNNs | Transformers, RNNs |

[Figure: two grids with features along the horizontal axis and samples 1 to 4 along the vertical axis. Batch Norm normalizes a column (one feature across all samples). Layer Norm normalizes a row (all features of one sample, shown for sample 1).]
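
The column-versus-row distinction comes down to which axis the statistics are taken over. A minimal sketch, assuming a small made-up batch whose first row is the anchor (the other three rows and the function names are hypothetical, and γ and β are left out for brevity):

```python
import numpy as np

def batch_norm_axes(X, eps=1e-5):
    # Statistics per feature (column), computed across the batch axis.
    mu = X.mean(axis=0, keepdims=True)
    var = X.var(axis=0, keepdims=True)
    return (X - mu) / np.sqrt(var + eps)

def layer_norm_axes(X, eps=1e-5):
    # Statistics per sample (row), computed across the feature axis.
    mu = X.mean(axis=1, keepdims=True)
    var = X.var(axis=1, keepdims=True)
    return (X - mu) / np.sqrt(var + eps)

# Hypothetical batch: row 0 is the anchor, the other rows are invented.
X = np.array([
    [2.0, 0.5, -1.0, 1.5],
    [0.0, 1.0, 3.0, -2.0],
    [1.0, 1.0, 0.5, 4.0],
    [-1.0, 2.0, 0.0, 0.5],
])

bn_out = batch_norm_axes(X)          # every column is centered
ln_out = layer_norm_axes(X)          # every row is centered

# Batch size 1: LN gives the same answer as before, BN collapses to zeros.
ln_single = layer_norm_axes(X[:1])
bn_single = batch_norm_axes(X[:1])

print("LN row 0, full batch:  ", ln_out[0].round(4))
print("LN row 0, batch size 1:", ln_single[0].round(4))
print("BN row 0, batch size 1:", bn_single[0].round(4))
```

The LN result for the anchor does not change when the other samples are removed, while batch statistics over a single sample leave nothing but zeros.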

The Algorithm

For a single sample x ∈ ℝᵈ (one token, one sequence position, one data point):

Step 1 — Mean across features: μ = (1/d) Σᵢ xᵢ = (2.0 + 0.5 + (−1.0) + 1.5) / 4 = 3.0 / 4 = 0.75

Step 2 — Variance across features: σ² = (1/d) Σᵢ (xᵢ − μ)² = [(2.0−0.75)² + (0.5−0.75)² + (−1.0−0.75)² + (1.5−0.75)²] / 4 = [1.5625 + 0.0625 + 3.0625 + 0.5625] / 4 = 5.25 / 4 = 1.3125

Step 3 — Normalize: σ = √(1.3125 + 1e-5) ≈ 1.1456

  • x̂₁ = (2.0 − 0.75) / 1.1456 = 1.25 / 1.1456 = 1.0911
  • x̂₂ = (0.5 − 0.75) / 1.1456 = −0.25 / 1.1456 = −0.2182
  • x̂₃ = (−1.0 − 0.75) / 1.1456 = −1.75 / 1.1456 = −1.5275
  • x̂₄ = (1.5 − 0.75) / 1.1456 = 0.75 / 1.1456 = 0.6547

Step 4 — Scale and shift: yᵢ = γᵢ · x̂ᵢ + βᵢ = 1.0 · x̂ᵢ + 0.0 = x̂ᵢ (at initialization)

| Step | Formula | Result |
|---|---|---|
| μ | (2.0+0.5+(−1.0)+1.5)/4 | 0.75 |
| σ² | (1.25²+0.25²+1.75²+0.75²)/4 | 1.3125 |
| x̂ | (xᵢ−0.75)/1.1456 | [1.091, −0.218, −1.528, 0.655] |
| y (γ=1, β=0) | 1.0·x̂+0.0 | [1.091, −0.218, −1.528, 0.655] |

Placement in Transformers

In the Pre-LN placement (the modern arrangement, introduced after the original Transformer paper), every Transformer layer applies LN twice: once before attention and once before the FFN.

Pre-LN: x + Attention(LN(x))

Post-LN (original Vaswani 2017): LN(x + Attention(x))

[Figure: two block diagrams. Post-LN (original 2017): x (input) → Attention(x) → x + Attention(x) → LN(x + Attn). LN sits after the residual, which is less stable at depth. Pre-LN (modern): x (input) → LN(x) → Attention(LN(x)) → x + Attn(LN(x)). LN sits before the residual, which is stable at depth.]

Post-LN was the original design but training deep Post-LN Transformers requires careful warm-up schedules. Pre-LN is now standard because the residual stream x passes through un-normalized, preserving the gradient highway through the skip connection.
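
The two placements differ only in where the normalization sits relative to the residual sum. A rough sketch, assuming a toy sublayer (`toy_sublayer`, one linear map plus tanh) as a hypothetical stand-in for attention or the FFN, and γ = 1, β = 0:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the last (feature) axis; gamma=1 and beta=0 for brevity.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def toy_sublayer(x, W):
    # Hypothetical stand-in for Attention or the FFN: linear map plus tanh.
    return np.tanh(x @ W)

def post_ln_block(x, W):
    # Post-LN: LN(x + Sublayer(x)). The residual sum itself gets normalized.
    return layer_norm(x + toy_sublayer(x, W))

def pre_ln_block(x, W):
    # Pre-LN: x + Sublayer(LN(x)). The skip path carries x through unchanged.
    return x + toy_sublayer(layer_norm(x), W)

rng = np.random.default_rng(0)
d = 4
W = rng.normal(scale=0.5, size=(d, d))   # made-up weights
x = np.array([[2.0, 0.5, -1.0, 1.5]])    # the anchor as a 1-token sequence

post_out = post_ln_block(x, W)
pre_out = pre_ln_block(x, W)

print("Post-LN output mean:", post_out.mean(axis=-1))
print("Pre-LN minus input: ", (pre_out - x).round(4))
```

In the Pre-LN version, `pre_out - x` is exactly the sublayer output, so the input survives on the skip path. In the Post-LN version the input is rescaled together with the sublayer output.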


Gradient Flow

The gradient through LN, writing gⱼ = γⱼ · ∂L/∂yⱼ for the gradient with respect to x̂ⱼ:

∂L/∂xᵢ = (1/σ) · [gᵢ − (1/d)Σⱼ gⱼ − x̂ᵢ·(1/d)Σⱼ gⱼ·x̂ⱼ]

The key factor is 1/σ, with γ entering through g. When σ is large (spread-out features), gradients are scaled down. When σ is small, they are amplified. The two Σ terms come from differentiating through μ and σ, and they make each feature's gradient depend on every other feature: LN couples gradients across features within the same sample.
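
One way to sanity-check the backward pass is to compare it against finite differences. The sketch below folds γ into the upstream gradient (g = γ · ∂L/∂y) before applying the bracketed expression. The loss L = Σ wᵢ·yᵢ and its weights `w` are assumptions chosen only to make the check concrete, and γ is deliberately non-uniform:

```python
import numpy as np

def ln_forward(x, gamma, beta, eps=1e-5):
    mu = x.mean()
    sigma = np.sqrt(x.var() + eps)
    x_hat = (x - mu) / sigma
    return gamma * x_hat + beta, x_hat, sigma

def ln_backward(dy, x_hat, sigma, gamma):
    # g_j = gamma_j * dL/dy_j is the gradient with respect to x_hat_j.
    g = gamma * dy
    # (1/sigma) * [g_i - mean(g) - x_hat_i * mean(g * x_hat)]
    return (g - g.mean() - x_hat * (g * x_hat).mean()) / sigma

x = np.array([2.0, 0.5, -1.0, 1.5])
gamma = np.array([1.0, 0.5, 2.0, 1.5])   # non-uniform on purpose
beta = np.zeros(4)
w = np.array([0.3, -1.2, 0.7, 2.0])      # hypothetical loss weights

y, x_hat, sigma = ln_forward(x, gamma, beta)
analytic = ln_backward(w, x_hat, sigma, gamma)   # dL/dy = w for this loss

# Central finite differences on L(x) = sum(w * LN(x)).
h = 1e-6
numeric = np.zeros_like(x)
for i in range(x.size):
    step = np.zeros_like(x)
    step[i] = h
    loss_plus = np.sum(w * ln_forward(x + step, gamma, beta)[0])
    loss_minus = np.sum(w * ln_forward(x - step, gamma, beta)[0])
    numeric[i] = (loss_plus - loss_minus) / (2 * h)

print("analytic:", analytic.round(5))
print("numeric: ", numeric.round(5))
```

The input gradients also sum to zero, which is the mean-subtraction coupling at work: shifting every feature by the same constant leaves the LN output unchanged.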


Code

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # Statistics are taken over the features of one sample (x is a 1-D vector).
    mu = x.mean()
    var = x.var()  # population variance (divides by d), as in Step 2
    x_hat = (x - mu) / np.sqrt(var + eps)  # eps guards against division by zero
    y = gamma * x_hat + beta  # per-feature scale and shift
    return y, mu, var, x_hat

x = np.array([2.0, 0.5, -1.0, 1.5])
gamma = np.ones(4)
beta = np.zeros(4)
y, mu, var, x_hat = layer_norm(x, gamma, beta)
print(f"Input x:    {x}")
print(f"Mean μ:     {mu:.4f}")
print(f"Variance σ²:{var:.4f}")
print(f"Normalized: {x_hat.round(4)}")
print(f"Output y:   {y.round(4)}")
```

```text
Input x:    [ 2.   0.5 -1.   1.5]
Mean μ:     0.7500
Variance σ²:1.3125
Normalized: [ 1.0911 -0.2182 -1.5275  0.6547]
Output y:   [ 1.0911 -0.2182 -1.5275  0.6547]
```
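
In a Transformer the same computation runs independently for every token, so implementations typically normalize over the last axis of a (batch, sequence, d) tensor. A minimal sketch of that batched form; the function name, the tensor shape, and the extra random tokens are assumptions for illustration:

```python
import numpy as np

def layer_norm_lastdim(x, gamma, beta, eps=1e-5):
    # keepdims=True lets the per-token statistics broadcast back over features.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
tokens = rng.normal(size=(2, 3, 4))      # (batch, sequence, d), made-up values
tokens[0, 0] = [2.0, 0.5, -1.0, 1.5]     # place the anchor at one position

gamma = np.ones(4)
beta = np.zeros(4)
out = layer_norm_lastdim(tokens, gamma, beta)

# The anchor token gets the same result as the single-vector version:
print(out[0, 0].round(4))   # [ 1.0911 -0.2182 -1.5275  0.6547]
```

Each of the six tokens is normalized with its own μ and σ, so the anchor's output is unaffected by whatever else sits in the tensor.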

LN solves the batch-size problem that Batch Normalization (04-batch-normalization.md) cannot. Both share the same γ/β parameterization and the same two-step normalize-then-rescale structure. RMSNorm (06-rmsnorm.md) removes the mean-subtraction step from LN — it normalizes only by the root-mean-square of the features, dropping the centering. LLaMA, Mistral, and Gemma use RMSNorm because it is slightly faster and performs comparably.
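
To make the difference concrete, here is a rough sketch of RMSNorm next to LN on the anchor. Placing ε inside the square root and omitting β are assumptions about one common convention, not a reference implementation of any particular model:

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-5):
    # Divide by the root-mean-square of the features: no mean subtraction.
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return gamma * x / rms

def layer_norm(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

x = np.array([2.0, 0.5, -1.0, 1.5])
gamma = np.ones(4)
beta = np.zeros(4)

rms_out = rms_norm(x, gamma)
ln_out = layer_norm(x, gamma, beta)

print("RMSNorm:", rms_out.round(4), " mean:", rms_out.mean().round(4))
print("LN:     ", ln_out.round(4), " mean:", ln_out.mean().round(4))

# On an input that is already centered, the two coincide.
centered = x - x.mean()
print(np.allclose(rms_norm(centered, gamma), layer_norm(x, gamma, beta)))
```

The RMSNorm output keeps a non-zero mean because nothing is subtracted. It saves one reduction over the features compared with LN.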

Honest Limitations

LN does not benefit from cross-sample statistics. With d = 4 features (as in the anchor), the per-sample mean and variance are estimated from just 4 values — unreliable in the same way that BN is unreliable at batch size 2. In practice, Transformers use d = 512 to 4096 hidden dimensions, where the statistics are stable. For very small hidden dimensions (d < 32), LN can be noisy.
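
A quick simulation gives a feel for how the reliability of σ² scales with d. It assumes i.i.d. standard-normal features, which real hidden states are not, so treat the numbers as an illustration of the trend only; the function name and trial count are made up for this sketch:

```python
import numpy as np

rng = np.random.default_rng(0)

def variance_estimate_spread(d, trials=500):
    # Draw `trials` feature vectors of width d from N(0, 1) and measure how
    # much the per-sample variance estimate fluctuates around the true value 1.
    samples = rng.normal(size=(trials, d))
    return samples.var(axis=1).std()

spread_small = variance_estimate_spread(4)
spread_large = variance_estimate_spread(4096)

for d in (4, 32, 512, 4096):
    print(f"d={d:5d}  std of sigma^2 estimate: {variance_estimate_spread(d):.4f}")
```

Under this assumption the spread shrinks roughly like 1/√d, so the estimate at d = 4 is far noisier than at d = 4096.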

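LN adds 2d parameters per LN module (γ and β, one per feature dimension). In a GPT-3-scale model with 96 layers and d = 12288, that is 96 × 2 × 12288 ≈ 2.4M parameters for one LN module per layer, or about 4.7M with the two LN modules per layer described under Placement in Transformers. Either way it is small relative to billions of total parameters, but worth knowing.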
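
The arithmetic is easy to reproduce. The total depends on how many LN modules each layer has, so this sketch reports one module per layer and, as an assumption based on the placement described earlier (before attention and before the FFN), two per layer:

```python
def ln_param_count(d):
    # One gamma and one beta per feature dimension.
    return 2 * d

n_layers = 96
d_model = 12288

per_module = ln_param_count(d_model)
one_ln_per_layer = n_layers * per_module
# Assumption: two LN modules per Transformer layer.
two_ln_per_layer = n_layers * 2 * per_module

print(f"per LN module:        {per_module:,}")        # 24,576
print(f"one LN per layer:     {one_ln_per_layer:,}")  # 2,359,296
print(f"two LN per layer:     {two_ln_per_layer:,}")  # 4,718,592
```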

LN normalizes across features within each token. It does not pool statistics across spatial positions or across the batch. BN in a CNN computes per-channel statistics over the whole batch and all spatial positions, which is one reason CNNs have traditionally done better with BN. Swapping BN for LN in a standard CNN often degrades performance, because each channel loses that shared, spatially pooled scale.


Test Your Understanding

  1. Using the anchor x = [2.0, 0.5, −1.0, 1.5], compute x̂₂ (the normalized value for feature index 1, value 0.5) step-by-step.

  2. You set γ₃ = 0 for the third feature dimension in every LN layer of a Transformer. What happens to (a) the output at that dimension and (b) the gradient flowing back through that dimension?

  3. A model trained with Post-LN requires a 4000-step learning rate warm-up but the Pre-LN version converges without it. Using the gradient flow formula, explain mechanistically why Post-LN needs the warm-up.

  4. LN is computed over d features for each sample independently. If you increase d from 512 to 4096 while keeping everything else equal, how does the variance estimate σ² change in reliability? What does this imply for very small language models?

  5. You are converting a BN-based image classifier to use LN so it can process single images at inference without stored running statistics. After the switch, validation accuracy drops slightly. What is the most likely cause and how would you investigate?
