
Group Normalization

Jul 3, 2026 · 8 min read · By Mohammed Vasim
deep-learning · neural-networks · machine-learning · representation-learning

Batch Normalization computes statistics across all samples in a mini-batch. That works fine at batch size 128 where you have many samples to average over. At batch size 2 — which is common in object detection and segmentation because high-resolution images consume most GPU memory — the mean and variance estimates are computed from just 2 samples. These estimates are noisy: the sample mean can be far from the true mean, and the sample variance can vary wildly across batches. Training becomes unstable.
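A small simulation shows how noisy batch statistics become. It assumes a single channel whose activations are independent draws from a standard normal distribution. Real feature maps do not behave that way, so read it as an illustration of how estimator noise scales with batch size, not as a model of any network. `batch_stat_spread` is a hypothetical helper name.

```python
import numpy as np

# Assumption: one channel's activations are i.i.d. standard normal
# (true mean 0, true variance 1). This only illustrates estimator noise.
rng = np.random.default_rng(0)

def batch_stat_spread(batch_size, trials=10_000):
    """Std. dev. of the per-batch mean and variance estimates across many batches."""
    samples = rng.standard_normal((trials, batch_size))
    means = samples.mean(axis=1)   # BN-style batch mean, one per simulated batch
    varis = samples.var(axis=1)    # BN-style (biased, ddof=0) batch variance
    return means.std(), varis.std()

for n in (2, 128):
    m_sd, v_sd = batch_stat_spread(n)
    print(f"batch={n:>3}: std of batch mean = {m_sd:.3f}, std of batch var = {v_sd:.3f}")
```

For i.i.d. data the spread of the batch mean shrinks roughly like 1/√(batch size). Theory predicts about 0.71 at batch size 2 and about 0.09 at batch size 128. The batch variance is likewise far noisier at batch size 2.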

Group Normalization (Wu & He, 2018) breaks channels into groups and computes statistics within each group, for each sample independently. It never looks across the batch dimension at all. Batch size 1, batch size 2, batch size 128 — the statistics are identical.
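To check the batch-independence claim, you can place the anchor sample into two different batches and compare its normalized output. The sketch below uses minimal NumPy versions of GN and training-mode BN without γ and β. The helper names and the random batch mates are illustrative assumptions.

```python
import numpy as np

def group_norm_2d(x, G, eps=1e-5):
    """GN on an (N, C) array: statistics per sample, per group of C // G channels."""
    N, C = x.shape
    xg = x.reshape(N, G, C // G)
    mu = xg.mean(axis=2, keepdims=True)
    var = xg.var(axis=2, keepdims=True)
    return ((xg - mu) / np.sqrt(var + eps)).reshape(N, C)

def batch_norm_2d(x, eps=1e-5):
    """Training-mode BN on an (N, C) array: statistics per channel, across the batch."""
    mu = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
anchor = np.array([[1.2, 0.4, -0.8, 2.1, 0.9, -0.3, 1.5, -1.1]])

# Put the same anchor sample into batches with different random batch mates.
batch_a = np.vstack([anchor, rng.standard_normal((1, 8))])    # batch size 2
batch_b = np.vstack([anchor, rng.standard_normal((127, 8))])  # batch size 128

gn_a = group_norm_2d(batch_a, G=2)[0]
gn_b = group_norm_2d(batch_b, G=2)[0]
bn_a = batch_norm_2d(batch_a)[0]
bn_b = batch_norm_2d(batch_b)[0]

print("GN output identical across batches:", np.allclose(gn_a, gn_b))  # True
print("BN output identical across batches:", np.allclose(bn_a, bn_b))  # False
```

GN's output for the anchor is the same no matter which samples share the batch, but BN's output changes. At batch size 2 the effect is extreme. Each channel's BN statistics come from just two values, so every normalized entry lands near ±1 (up to the ε term), whatever the actual activations were.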

Anchor: single sample (N=1), 8 channels at one spatial position: [1.2, 0.4, −0.8, 2.1, 0.9, −0.3, 1.5, −1.1]. G=2 groups of 4 channels each.


The Normalization Zoo

Four normalization methods, same 4D input tensor (N samples, C channels, H height, W width):

| Method | Normalize over | Per sample? | Batch-independent? |
|---|---|---|---|
| Batch Norm | N, H, W (across samples), per C | No | No |
| Layer Norm | C, H, W (all features), per N | Yes | Yes |
| Group Norm | C/G channels plus H, W, per group, per N | Yes | Yes |
| Instance Norm | H, W (spatial), per N, C | Yes | Yes |
[Figure: Normalization Zoo, what gets averaged (blue) in Batch Norm, Layer Norm, Group Norm, and Instance Norm.]

The Algorithm

For sample n and group g (the channels c that belong to group g):

  1. Partition the C channels into G groups of C/G channels each (G must divide C)
  2. Compute μ_{n,g} = mean over the channels in group g and all spatial positions
  3. Compute σ²_{n,g} = variance over the same channels and positions
  4. x̂ = (x − μ_{n,g}) / √(σ²_{n,g} + ε)
  5. y = γ · x̂ + β (γ and β are per-channel, not per-group)
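The anchor has no spatial extent. In a CNN, though, the input is (N, C, H, W), and steps 2 and 3 average over the group's channels and all H×W positions together. Below is a minimal NumPy sketch of that case. It assumes NCHW layout, and the function name `group_norm_nchw` is illustrative, not a library API.

```python
import numpy as np

def group_norm_nchw(x, G, gamma, beta, eps=1e-5):
    """Group Normalization for an (N, C, H, W) array, following steps 1-5 above.

    gamma and beta have shape (C,): one scale and one shift per channel.
    """
    N, C, H, W = x.shape
    assert C % G == 0, "C must be divisible by G"
    # Step 1: split the channels into G groups of C // G channels each.
    xg = x.reshape(N, G, C // G, H, W)
    # Steps 2-3: statistics over the group's channels AND all spatial positions,
    # separately for every (sample n, group g) pair. N is never reduced over.
    mu = xg.mean(axis=(2, 3, 4), keepdims=True)
    var = xg.var(axis=(2, 3, 4), keepdims=True)
    # Step 4: normalize, then restore the (N, C, H, W) layout.
    x_hat = ((xg - mu) / np.sqrt(var + eps)).reshape(N, C, H, W)
    # Step 5: per-channel affine transform, broadcast over N, H, W.
    return gamma.reshape(1, C, 1, 1) * x_hat + beta.reshape(1, C, 1, 1)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 4, 4))
y = group_norm_nchw(x, G=2, gamma=np.ones(8), beta=np.zeros(8))
# With gamma=1, beta=0, each (sample, group) block of y has mean ~0 and variance ~1.
print(y.reshape(2, 2, -1).mean(axis=2).round(6))
print(y.reshape(2, 2, -1).var(axis=2).round(4))
```

Reshaping to (N, G, C/G, H, W) turns every (sample, group) pair into one contiguous block. GN then needs just one mean and one variance per block.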

Computing on the Anchor

Channels: [1.2, 0.4, −0.8, 2.1, 0.9, −0.3, 1.5, −1.1]

With G=2:
  • Group 1 = channels 0–3 = [1.2, 0.4, −0.8, 2.1]
  • Group 2 = channels 4–7 = [0.9, −0.3, 1.5, −1.1]

Group 1: μ₁ = (1.2 + 0.4 + (−0.8) + 2.1) / 4 = 2.9 / 4 = 0.725

σ²₁ = [(1.2−0.725)² + (0.4−0.725)² + (−0.8−0.725)² + (2.1−0.725)²] / 4 = [0.225625 + 0.105625 + 2.325625 + 1.890625] / 4 = 4.5475 / 4 = 1.136875 ≈ 1.1369

σ₁ = √(1.136875 + 1e-5) ≈ 1.06625

x̂ for group 1:

  • x̂₀ = (1.2 − 0.725) / 1.06625 = 0.4455
  • x̂₁ = (0.4 − 0.725) / 1.06625 = −0.3048
  • x̂₂ = (−0.8 − 0.725) / 1.06625 = −1.4302
  • x̂₃ = (2.1 − 0.725) / 1.06625 = 1.2896

Group 2: μ₂ = (0.9 + (−0.3) + 1.5 + (−1.1)) / 4 = 1.0 / 4 = 0.25

σ²₂ = [(0.9−0.25)² + (−0.3−0.25)² + (1.5−0.25)² + (−1.1−0.25)²] / 4 = [0.4225 + 0.3025 + 1.5625 + 1.8225] / 4 = 4.11 / 4 = 1.0275

σ₂ = √(1.0275 + 1e-5) ≈ 1.01366

x̂ for group 2:

  • x̂₄ = (0.9 − 0.25) / 1.01366 = 0.6412
  • x̂₅ = (−0.3 − 0.25) / 1.01366 = −0.5426
  • x̂₆ = (1.5 − 0.25) / 1.01366 = 1.2332
  • x̂₇ = (−1.1 − 0.25) / 1.01366 = −1.3318
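
| Quantity | Formula | Result |
|---|---|---|
| Group 1 μ | (1.2 + 0.4 + (−0.8) + 2.1) / 4 | 0.725 |
| Group 1 σ² | sum of squared deviations / 4 | 1.1369 |
| Group 1 x̂ | (xᵢ − 0.725) / 1.06625 | [0.4455, −0.3048, −1.4302, 1.2896] |
| Group 2 μ | (0.9 + (−0.3) + 1.5 + (−1.1)) / 4 | 0.25 |
| Group 2 σ² | sum of squared deviations / 4 | 1.0275 |
| Group 2 x̂ | (xᵢ − 0.25) / 1.01366 | [0.6412, −0.5426, 1.2332, −1.3318] |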

G=1 and G=C Special Cases

G=1: all 8 channels form one group, so we normalize over all 8 channels of this sample. This is exactly Layer Normalization, which normalizes each sample over (C, H, W). With no spatial extent here, that is just the 8 channels.

μ = (1.2 + 0.4 − 0.8 + 2.1 + 0.9 − 0.3 + 1.5 − 1.1) / 8 = 3.9 / 8 = 0.4875

σ² = variance over all 8 values
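Here is a numerical check of the G=1 case on the anchor. It uses minimal NumPy versions of GN and LN without γ and β, and the helper names are illustrative.

```python
import numpy as np

def group_norm_2d(x, G, eps=1e-5):
    """GN on an (N, C) array: statistics per sample, per group of C // G channels."""
    N, C = x.shape
    xg = x.reshape(N, G, C // G)
    mu = xg.mean(axis=2, keepdims=True)
    var = xg.var(axis=2, keepdims=True)
    return ((xg - mu) / np.sqrt(var + eps)).reshape(N, C)

def layer_norm_2d(x, eps=1e-5):
    """LN without affine parameters: normalize each sample over all of its features."""
    mu = x.mean(axis=1, keepdims=True)
    var = x.var(axis=1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

anchor = np.array([[1.2, 0.4, -0.8, 2.1, 0.9, -0.3, 1.5, -1.1]])
print(round(anchor.mean(), 4), round(anchor.var(), 4))  # 0.4875 1.1386
print(np.allclose(group_norm_2d(anchor, G=1), layer_norm_2d(anchor)))  # True
```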

G=C=8: each channel is its own group, so each channel is normalized only over its own spatial positions (H×W). That is Instance Normalization.

With only one spatial position, as in the anchor, each group holds a single value. Then x − μ = 0 and σ² = 0, so x̂ = 0 / √ε = 0 for every channel and the output collapses to β. Instance Norm only makes sense with spatial dimensions (H×W > 1).

| G | Equivalent to | Channels per group |
|---|---|---|
| 1 | Layer Norm | 8 (all of C) |
| 2 | Group Norm | 4 |
| 8 | Instance Norm | 1 |
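The G=C case can be checked the same way. The sketch assumes NCHW layout with several spatial positions, where GN with G=C matches a per-sample, per-channel Instance Norm. It then treats the anchor as a 1×1 image to show the degenerate case. Function names are illustrative, and γ and β are omitted.

```python
import numpy as np

def group_norm_nchw(x, G, eps=1e-5):
    """GN (no affine) on (N, C, H, W): stats over C // G channels and all H*W positions."""
    N, C, H, W = x.shape
    xg = x.reshape(N, G, C // G, H, W)
    mu = xg.mean(axis=(2, 3, 4), keepdims=True)
    var = xg.var(axis=(2, 3, 4), keepdims=True)
    return ((xg - mu) / np.sqrt(var + eps)).reshape(N, C, H, W)

def instance_norm_nchw(x, eps=1e-5):
    """Instance Norm (no affine): stats over H*W, separately for each (n, c)."""
    mu = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 4, 4))  # H*W = 16 positions per channel
print(np.allclose(group_norm_nchw(x, G=8), instance_norm_nchw(x)))  # True

# The anchor as a 1x1 "image": each single-channel group holds one value,
# so x - mu = 0 and var = 0, and every output is 0 / sqrt(eps) = 0.
anchor = np.array([1.2, 0.4, -0.8, 2.1, 0.9, -0.3, 1.5, -1.1]).reshape(1, 8, 1, 1)
print(group_norm_nchw(anchor, G=8).ravel())  # all zeros
```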

GN vs BN at Small Batch Sizes

[Figure: Validation error vs. batch size (ImageNet, Wu & He 2018), GN vs. BN; BN's statistics are noisy at batch sizes 1 and 2.]

At batch size 2, BN's ImageNet error (ResNet-50 in Wu & He) is roughly 10 points higher than GN's. At batch size 32 the two are close, with BN slightly ahead. GN's error stays nearly flat across batch sizes because its statistics never use the batch dimension.


Code

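```python
import numpy as np

def group_norm(x, G, gamma, beta, eps=1e-5):
    """Group Normalization for an (N, C) array (no spatial dimensions).

    Returns the output y plus the per-group means and variances for inspection.
    """
    N, C = x.shape
    assert C % G == 0, "C must be divisible by G"
    # Split the C channels of every sample into G groups of C // G channels.
    x_grouped = x.reshape(N, G, C // G)
    # Mean and biased (ddof=0) variance within each group, per sample.
    mu = x_grouped.mean(axis=2, keepdims=True)
    var = x_grouped.var(axis=2, keepdims=True)
    # Normalize within each group, then restore the (N, C) layout.
    x_hat = (x_grouped - mu) / np.sqrt(var + eps)
    x_hat = x_hat.reshape(N, C)
    # Per-channel scale and shift (gamma and beta have shape (C,)).
    return gamma * x_hat + beta, mu.squeeze(), var.squeeze()

x = np.array([[1.2, 0.4, -0.8, 2.1, 0.9, -0.3, 1.5, -1.1]])
gamma = np.ones(8)
beta = np.zeros(8)
y, mu, var = group_norm(x, G=2, gamma=gamma, beta=beta)
print("Input:      ", x[0].round(4))
print("Group means:", mu.round(4))
print("Group vars: ", var.round(4))
print("Output y:   ", y[0].round(4))
```

```text
Input:       [ 1.2  0.4 -0.8  2.1  0.9 -0.3  1.5 -1.1]
Group means: [0.725 0.25 ]
Group vars:  [1.1369 1.0275]
Output y:    [ 0.4455 -0.3048 -1.4302  1.2896  0.6412 -0.5426  1.2332 -1.3318]
```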

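Group Normalization interpolates between Layer Normalization (05-layer-normalization.md), which is the G=1 case, and Instance Normalization, which is the G=C case. Like LN, and unlike Batch Normalization (04-batch-normalization.md), it never uses batch statistics. BN works best at large batch sizes, and LN is the natural choice when batch statistics are unavailable. GN keeps its statistics stable at any batch size by partitioning channels. G=32 is the original paper's default. GN is a common choice in detection and segmentation models such as Mask R-CNN, where per-GPU batches are small; Detectron2, for example, offers GN as a normalization option.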

Honest Limitations

The number of groups G is a hyperparameter. G=32 is the original paper's default, but it only works when C is divisible by 32. For channel counts such as C=48 or C=80, G=32 does not divide C, so you must choose another G. The NumPy `group_norm` above asserts C % G == 0, and PyTorch's `torch.nn.GroupNorm` also requires G to divide C. A poorly chosen G can also create groups that mix semantically unrelated channels.
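One simple heuristic is to use the largest G no greater than 32 that divides C. This is an assumption offered for illustration, not the paper's recommendation, and the helper `choose_num_groups` below is hypothetical.

```python
def choose_num_groups(num_channels, preferred=32):
    """Hypothetical helper: the largest G <= preferred that divides num_channels."""
    for g in range(min(preferred, num_channels), 0, -1):
        if num_channels % g == 0:
            return g
    return 1  # only reached when num_channels < 1; 1 divides every positive C

for C in (64, 48, 80, 96):
    G = choose_num_groups(C)
    print(f"C={C:>3}: G={G:>2}, {C // G} channels per group")
```

Another option is to fix the number of channels per group instead of G. Which choice works better depends on the architecture, so G is something to validate rather than a constant.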

GN's advantage is accuracy at small batches, not speed. BN benefits from highly optimized kernels, and fused batch-norm operations on modern GPUs are typically faster than GN's per-group statistics. So at large batch sizes (≥ 32), where BN's statistics are already reliable, GN usually adds cost without improving accuracy.

GN cannot exploit batch-level statistics. In some tasks, such as contrastive learning, the diversity of samples within a batch is a signal that BN uses implicitly, because the other samples (including negatives) shape the mean and variance BN normalizes by. GN normalizes each sample independently, so it cannot benefit from this.


Test Your Understanding

  1. Compute Group 2's μ and σ² for the anchor's channels 4–7, [0.9, −0.3, 1.5, −1.1], using the formulas. Verify against the code output.

  2. If G=1, show that Group Normalization reduces to Layer Normalization on the anchor. Compute the mean and variance over all 8 channels and compare to what LN would produce.

  3. A ResNet-50 normally uses batch size 256 for training. You switch to batch size 2 for a medical imaging task with large images. You observe training divergence with BN. After switching to GN with G=32, training stabilizes. Explain mechanistically why BN diverges at batch=2 while GN does not.

  4. Instance Normalization (G=C) normalizes each channel independently. With only one spatial position (as in our 1D anchor), Instance Norm cannot compute variance. Why? And what does this imply about when Instance Norm should be used?

  5. A model uses GN with G=4 on 64-channel feature maps. You double the channel count to 128. Should you also change G? What is the trade-off between larger and smaller G in terms of statistics quality and channel grouping semantics?
