Adam with L2 regularization does not regularize the way you might expect. When you add λw to the gradient before Adam processes it, the adaptive learning rate scales the regularization term along with the gradient, so weights with a history of large gradients get less regularization than weights with a history of small gradients. The effective decay coefficient is λ/(√v̂ + ε), not λ. For weights with large v̂ (large recent gradients), the decay is small. For weights with small v̂, the decay is large. This inconsistency is the bug AdamW fixes.
Anchor: weight w = 0.8, gradient g = −0.3, λ = 0.01, lr = 0.001, β₁ = 0.9, β₂ = 0.999, ε = 1e-8. Five update steps.
Adam + L2 vs AdamW
Adam with L2 (the broken version):
g̃_t = g_t + λ·w_{t-1} ← add decay to gradient
m_t = β₁·m_{t-1} + (1-β₁)·g̃_t ← first moment of modified gradient
v_t = β₂·v_{t-1} + (1-β₂)·g̃_t² ← second moment of modified gradient
m̂_t = m_t/(1-β₁ᵗ) ← bias correction, as in standard Adam
v̂_t = v_t/(1-β₂ᵗ)
w_t = w_{t-1} - lr · m̂_t/(√v̂_t + ε)

The problem: m and v accumulate the regularization term. When v̂ is large, dividing by √v̂ shrinks the effective decay. Weights with large gradients (large v̂) receive less regularization than weights with small gradients. This breaks the intent of L2 regularization, which is to shrink every weight at the same relative rate.
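The scaling can be checked numerically. The sketch below is a minimal illustration, not a reference implementation: `adam_l2_step` and its `effective_decay` return value are names introduced here, and the two gradient values (3.0 and 0.03) are arbitrary choices meant to produce one large and one small v̂.

```python
import math

def adam_l2_step(w, g, m, v, t, lr=0.001, beta1=0.9, beta2=0.999,
                 eps=1e-8, lam=0.01):
    """One Adam step with the L2 term folded into the gradient."""
    g_tilde = g + lam * w                      # decay enters the gradient
    m = beta1 * m + (1 - beta1) * g_tilde
    v = beta2 * v + (1 - beta2) * g_tilde ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    denom = math.sqrt(v_hat) + eps
    w_new = w - lr * m_hat / denom
    # At t = 1, m_hat equals g_tilde, so the part of the step that comes
    # from lam * w is exactly this; at later steps it is an approximation.
    effective_decay = lr * lam * w / denom
    return w_new, m, v, effective_decay

# Same weight and lam, different gradient magnitudes (illustrative values).
_, _, _, decay_large_grad = adam_l2_step(w=0.8, g=3.0, m=0.0, v=0.0, t=1)
_, _, _, decay_small_grad = adam_l2_step(w=0.8, g=0.03, m=0.0, v=0.0, t=1)
decoupled_decay = 0.001 * 0.01 * 0.8           # what AdamW would subtract

print(f"large-gradient weight: {decay_large_grad:.2e}")
print(f"small-gradient weight: {decay_small_grad:.2e}")
print(f"AdamW (decoupled):     {decoupled_decay:.2e}")
```

With these values the small-gradient weight is decayed roughly 80 times harder than the large-gradient one, even though both share the same w and λ.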
AdamW (the fix — Loshchilov & Hutter, 2019):
m_t = β₁·m_{t-1} + (1-β₁)·g_t ← first moment of ORIGINAL gradient only
v_t = β₂·v_{t-1} + (1-β₂)·g_t² ← second moment of ORIGINAL gradient only
m̂_t = m_t/(1-β₁ᵗ)
v̂_t = v_t/(1-β₂ᵗ)
w_t = w_{t-1} - lr·(m̂_t/(√v̂_t+ε)) - lr·λ·w_{t-1}

The decay lr·λ·w_{t-1} is a separate subtraction after the adaptive step. It is not scaled by the adaptive learning rate. Every weight decays by exactly lr·λ·w per step, regardless of gradient history.
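A vectorized sketch makes the "regardless of gradient history" part concrete. `adamw_step_vec` is a name introduced for this example, and the three gradient values are arbitrary; the only point is that they differ by two orders of magnitude while the decay does not.

```python
import numpy as np

def adamw_step_vec(w, g, m, v, t, lr=0.001, beta1=0.9, beta2=0.999,
                   eps=1e-8, lam=0.01):
    """Vectorized AdamW step; also returns the decay component on its own."""
    m = beta1 * m + (1 - beta1) * g            # moments see the raw gradient only
    v = beta2 * v + (1 - beta2) * g ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    adaptive = lr * m_hat / (np.sqrt(v_hat) + eps)
    decay = lr * lam * w                       # not divided by sqrt(v_hat)
    return w - adaptive - decay, m, v, decay

w = np.array([0.8, 0.8, 0.8])
g = np.array([3.0, -0.3, 0.03])                # very different gradient scales
w_new, m, v, decay = adamw_step_vec(w, g, np.zeros(3), np.zeros(3), t=1)

print(decay)        # identical entries: the gradient scale does not matter
print(decay / w)    # fraction removed per step, lr * lam for every weight
```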
Step-by-Step: First Two Updates
t = 1:
- m₁ = 0.9 × 0 + 0.1 × (−0.3) = −0.03
- v₁ = 0.999 × 0 + 0.001 × (0.09) = 0.00009
- m̂₁ = −0.03 / (1 − 0.9¹) = −0.03 / 0.1 = −0.3
- v̂₁ = 0.00009 / (1 − 0.999¹) = 0.00009 / 0.001 = 0.09
- adam_step = 0.001 × (−0.3) / (√0.09 + 1e-8) = 0.001 × (−0.3) / 0.3 = −0.001
- decay = 0.001 × 0.01 × 0.8 = 0.000008
- w₁ = 0.8 − (−0.001) − 0.000008 = 0.800992
t = 2:
- m₂ = 0.9 × (−0.03) + 0.1 × (−0.3) = −0.027 − 0.03 = −0.057
- v₂ = 0.999 × 0.00009 + 0.001 × 0.09 = 0.00008991 + 0.00009 = 0.00017991
- m̂₂ = −0.057 / (1 − 0.81) = −0.057 / 0.19 = −0.3
- v̂₂ = 0.00017991 / (1 − 0.998001) = 0.00017991 / 0.001999 = 0.09 (exact, since 0.001999 × 0.09 = 0.00017991)
- adam_step ≈ 0.001 × (−0.3) / 0.3 = −0.001
- decay = 0.001 × 0.01 × 0.800992 ≈ 0.000008
- w₂ ≈ 0.801984
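The bias-corrected m̂ and v̂ stay at −0.3 and 0.09 on every step. This is not a momentum effect: the gradient is constant, so m_t = (1 − β₁ᵗ)·g, and the bias correction divides that factor back out, leaving m̂_t = g (and likewise v̂_t = g²). The weight moves steadily upward because g = −0.3 (negative gradient → weight should increase).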
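The constancy of m̂ and v̂ can be confirmed without floating-point noise. This is a small check using Python's `fractions` module, assuming the same constant gradient g = −0.3 as the trace; the `history` list is just a convenience introduced here.

```python
from fractions import Fraction

beta1, beta2 = Fraction(9, 10), Fraction(999, 1000)
g = Fraction(-3, 10)

m, v = Fraction(0), Fraction(0)
history = []
for t in range(1, 6):
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g ** 2
    m_hat = m / (1 - beta1 ** t)     # bias correction undoes the zero init
    v_hat = v / (1 - beta2 ** t)
    history.append((t, m_hat, v_hat))
    print(t, m_hat, v_hat)           # prints -3/10 and 9/100 on every step
```

With a constant gradient the raw moment after t steps is (1 − β₁ᵗ)·g, so the correction returns g exactly. With a gradient that changes from step to step, m̂ and v̂ would drift like any exponential moving average.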
| t | w | m̂ | v̂ | adam_upd | decay | w_new |
|---|---|---|---|---|---|---|
| 1 | 0.800000 | −0.300000 | 0.09000000 | −0.001000 | 0.000008 | 0.800992 |
| 2 | 0.800992 | −0.300000 | 0.09000000 | −0.001000 | 0.000008 | 0.801984 |
| 3 | 0.801984 | −0.300000 | 0.09000000 | −0.001000 | 0.000008 | 0.802976 |
| 4 | 0.802976 | −0.300000 | 0.09000000 | −0.001000 | 0.000008 | 0.803968 |
| 5 | 0.803968 | −0.300000 | 0.09000000 | −0.001000 | 0.000008 | 0.804960 |
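The decay term (0.000008 per step) is tiny relative to the adam_update (0.001 per step) at this λ and lr: under 1% of the step. At λ = 0.1, the decay becomes 10× larger (0.00008 per step, about 8% of the adam_update), which is still minor here but becomes significant once the normalized gradient signal m̂/√v̂ drops well below 1.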
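The relative size of the two terms is simple arithmetic. A quick sketch, assuming the constant-gradient setup of the trace (so the normalized Adam step m̂/√v̂ has magnitude 1) and w = 0.8; the `ratios` dictionary is introduced only for this example.

```python
lr, w = 0.001, 0.8
adam_step = lr * 1.0   # |m_hat / sqrt(v_hat)| is 1 for a constant gradient

ratios = {}
for lam in (0.01, 0.1, 1.0):
    decay = lr * lam * w
    ratios[lam] = decay / adam_step
    print(f"lam={lam:<5} decay={decay:.6f}  decay/adam_step={ratios[lam]:.3f}")
```

The ratio reduces to λ·w, so it grows linearly with λ and shrinks as the weight itself shrinks.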
Why LLMs Use AdamW
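Most major LLMs use AdamW or a close variant with decoupled weight decay: GPT-2, GPT-3, LLaMA, and BERT among them. T5 is a notable exception, trained with AdaFactor. The typical configuration: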
- β₁ = 0.9, β₂ = 0.95 or 0.999
- lr = 3e-4 for pre-training, 1e-5 for fine-tuning
- λ = 0.01 to 0.1
- λ = 0 (no weight decay) for biases and layer norm parameters (common convention)
The consistent weight decay is what makes AdamW the default: at billions of parameters across trillions of tokens, L2 regularization whose strength depends on each weight's gradient history would create wildly inconsistent effective regularization across the network. AdamW ensures every decayed weight shrinks by the same fraction, lr·λ, per step.
Hyperparameter Sensitivity
    import numpy as np

    def adamw_step(w, g, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):
        # Moment estimates use the raw gradient only (no decay term mixed in).
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g**2
        # Bias correction for the zero-initialized moments.
        m_hat = m / (1 - beta1**t)
        v_hat = v / (1 - beta2**t)
        adam_update = lr * m_hat / (np.sqrt(v_hat) + eps)
        decay = lr * wd * w  # decoupled: not divided by sqrt(v_hat)
        w_new = w - adam_update - decay
        return w_new, m, v, m_hat, v_hat

    w0, g = 0.8, -0.3

    # Five-step trace with the default hyperparameters.
    print(f"{'t':>2} | {'w':>8} | {'m_hat':>8} | {'v_hat':>10} | {'adam_upd':>9} | {'decay':>7}")
    w, m, v = w0, 0.0, 0.0
    for t in range(1, 6):
        w_new, m, v, mh, vh = adamw_step(w, g, m, v, t)
        # Recompute both update terms for display (lr=0.001, eps=1e-8, wd=0.01).
        adam_upd = 0.001 * mh / (np.sqrt(vh) + 1e-8)
        decay = 0.001 * 0.01 * w
        print(f"{t:>2} | {w:8.6f} | {mh:8.6f} | {vh:10.8f} | {adam_upd:9.6f} | {decay:7.6f}")
        w = w_new

    # Sweep the weight-decay coefficient with everything else fixed.
    print("\nLambda sensitivity (5 steps, w0=0.8, g=-0.3):")
    for wd in [0.0, 0.01, 0.1, 1.0]:
        w, m, v = w0, 0.0, 0.0
        for t in range(1, 6):
            w, m, v, _, _ = adamw_step(w, g, m, v, t, wd=wd)
        print(f" λ={wd:4.2f}: final w = {w:.6f}")

t | w | m_hat | v_hat | adam_upd | decay
1 | 0.800000 | -0.300000 | 0.09000000 | -0.001000 | 0.000008
2 | 0.800992 | -0.300000 | 0.09000000 | -0.001000 | 0.000008
3 | 0.801984 | -0.300000 | 0.09000000 | -0.001000 | 0.000008
4 | 0.802976 | -0.300000 | 0.09000000 | -0.001000 | 0.000008
5 | 0.803968 | -0.300000 | 0.09000000 | -0.001000 | 0.000008
Lambda sensitivity (5 steps, w0=0.8, g=-0.3):
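λ=0.00: final w = 0.805000
λ=0.01: final w = 0.804960
λ=0.10: final w = 0.804599
λ=1.00: final w = 0.800998

At λ = 0.0, AdamW is identical to Adam: there is no decay at all, so the weight gains exactly 0.001 per step and ends at 0.805. At λ = 1.0, the decay term is lr·λ·w ≈ 0.0008 per step, which cancels about 80% of the adam_update of 0.001, so the weight barely moves. It still does not collapse: with a constant gradient the normalized step m̂/√v̂ has magnitude 1, so the two terms balance at w = 1/λ, and the weight drifts toward 1.0 rather than 0. For the weight to shrink from 0.8 in this setup you would need λ > 1.25, so that 1/λ < 0.8. A higher lr changes how fast the weight moves, not where it settles. At ordinary values of λ, the decay becomes dominant only when the gradient signal weakens or turns noisy (loss plateaus); at that point m̂/√v̂ falls well below 1 and λ·w pulls the weight toward 0 continuously.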
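Running the same constant-gradient setup for many more steps shows where the two terms balance. The sketch below is an illustration under the toy assumptions of this post (scalar weight, constant g = −0.3); `run_adamw` is a helper name introduced here. Because the normalized Adam step has magnitude 1 for a constant gradient, the update stops changing w once lr·λ·w equals lr, that is near w = 1/λ.

```python
import math

def run_adamw(lam, steps, w=0.8, g=-0.3, lr=0.001,
              beta1=0.9, beta2=0.999, eps=1e-8):
    """Iterate scalar AdamW with a constant gradient; return the final weight."""
    m = v = 0.0
    for t in range(1, steps + 1):
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g ** 2
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        w = w - lr * m_hat / (math.sqrt(v_hat) + eps) - lr * lam * w
    return w

finals = {lam: run_adamw(lam, steps=20_000) for lam in (1.0, 2.0, 5.0)}
for lam, w_final in finals.items():
    print(f"lam={lam}: final w = {w_final:.4f}, 1/lam = {1 / lam:.4f}")
```

In real training the gradient is neither constant nor perfectly normalized, so this balance point is a property of the toy problem rather than a prediction for a network.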
Related Concepts
AdamW builds directly on Adam (07-adam.md): it keeps the same first-moment (momentum) and second-moment (adaptive learning rate) mechanics, changing only where the weight decay is applied. The L1/L2 regularization post (06-regularization/08-l1-l2-regularization.md) covers why weight decay works as a regularizer. Beyond AdamW, the optimizer landscape includes AdaFactor (used in T5, reduces memory by storing factored row and column statistics instead of a full v matrix) and Lion (Google, applies the sign of a momentum-based update instead of its magnitude; it keeps a single momentum buffer, so its optimizer state is roughly half the size of AdamW's).
Honest Limitations
AdamW still has four hyperparameters that need tuning: lr, β₁, β₂, and λ. Fixing β₁ = 0.9, β₂ = 0.999, λ = 0.01 works as a starting point but is not guaranteed to be optimal. In particular, β₂ is often lowered to 0.95 for LLM pre-training to reduce the lag in tracking recent gradient magnitudes.
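The lag that β₂ controls can be made visible with a toy experiment. The sketch below is an assumption-laden illustration: the gradient magnitude jumps from 0.1 to 1.0 at step 100 (arbitrary values), and `steps_to_track` is a helper introduced here that counts how many steps v̂ needs to get within 10% of the new g².

```python
def steps_to_track(beta2, jump_at=100, g_before=0.1, g_after=1.0,
                   tol=0.1, max_steps=20_000):
    """Steps after a gradient-scale jump until v_hat is within tol of g_after**2."""
    v = 0.0
    for t in range(1, max_steps + 1):
        g = g_before if t <= jump_at else g_after
        v = beta2 * v + (1 - beta2) * g ** 2
        v_hat = v / (1 - beta2 ** t)
        if t > jump_at and abs(v_hat - g_after ** 2) <= tol * g_after ** 2:
            return t - jump_at
    return None  # did not get within tol in max_steps

lag_fast = steps_to_track(beta2=0.95)
lag_slow = steps_to_track(beta2=0.999)
print(f"beta2=0.95 : {lag_fast} steps")
print(f"beta2=0.999: {lag_slow} steps")
```

As a rule of thumb, an exponential moving average with coefficient β₂ averages over roughly 1/(1 − β₂) steps, which is about 20 for 0.95 and about 1000 for 0.999.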
The common convention of setting weight decay to 0 for biases and normalization layer parameters (γ, β) is not enforced by AdamW itself — it requires the user to pass separate parameter groups with different wd settings. Forgetting this means regularizing the layer norm scale and shift parameters, which can cause underfitting in deep Transformer models.
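One way to build those groups is to split on parameter shape. The sketch below is a heuristic, not something AdamW prescribes: it assumes that any array with fewer than two dimensions is a bias or a normalization scale/shift. The function name `split_decay_groups` and the parameter names and shapes are hypothetical.

```python
import numpy as np

def split_decay_groups(named_params, lam=0.01):
    """Split parameters into a decayed group and a no-decay group.

    Heuristic (an assumption, not part of AdamW): arrays with fewer than
    two dimensions are treated as biases or normalization scale/shift.
    """
    decay, no_decay = [], []
    for name, p in named_params.items():
        (decay if p.ndim >= 2 else no_decay).append(name)
    return [
        {"params": decay, "weight_decay": lam},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# Hypothetical parameter names and shapes for one Transformer block.
params = {
    "attn.qkv.weight": np.zeros((192, 64)),
    "attn.qkv.bias": np.zeros(192),
    "ln1.weight": np.zeros(64),   # layer norm scale (gamma)
    "ln1.bias": np.zeros(64),     # layer norm shift (beta)
    "mlp.fc.weight": np.zeros((256, 64)),
}
groups = split_decay_groups(params)
for group in groups:
    print(group["weight_decay"], group["params"])
```

The list-of-dicts layout mirrors the parameter-group format that PyTorch optimizers accept, where `params` would hold the tensors themselves rather than their names. Note that embedding matrices are 2-D, so this heuristic decays them; whether that is desirable depends on the model.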
AdamW's decoupled decay does not correspond to placing a Gaussian prior on weights in the Bayesian sense — L2 regularization in the gradient does. This distinction matters if you are using the optimizer within a probabilistic framework or trying to connect the regularization to a prior.
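The correspondence for the coupled L2 term can be written out. As a sketch, assume the loss L(w) is a negative log-likelihood and the prior is an isotropic zero-mean Gaussian with variance σ² (σ is a symbol introduced here, not used elsewhere in this post):

```latex
\begin{aligned}
p(w) &\propto \exp\!\left(-\frac{\lVert w \rVert^2}{2\sigma^2}\right) \\
-\log p(w \mid \mathcal{D}) &= L(w) + \frac{1}{2\sigma^2}\lVert w \rVert^2 + \text{const} \\
\nabla_w\!\left[-\log p(w \mid \mathcal{D})\right] &= \nabla_w L(w) + \lambda w,
\qquad \lambda = \frac{1}{\sigma^2}
\end{aligned}
```

The last line is the modified gradient g + λ·w that Adam + L2 feeds into its moment estimates, which is why that variant can be read as following a log-posterior. AdamW's separate −lr·λ·w term bypasses the adaptive scaling applied to the gradient, so the same reading does not carry over directly.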
Test Your Understanding
- Adam+L2 applies decay to the gradient before computing m and v. AdamW applies decay after. For a weight with v̂ = 0.09 and λ = 0.01, compute the effective decay per step for Adam+L2 vs AdamW (use lr = 0.001 and w = 0.8). Which one is larger, and by what factor? How large would v̂ have to be for the Adam+L2 decay to be the smaller of the two?
- At t = 1 in the trace above, m̂₁ = −0.3 exactly (not approximately). Why? Show the math using β₁ = 0.9 and g = −0.3.
- You train a 7B parameter LLM with AdamW at λ = 0.1 and notice that embedding weights collapse to near-zero while attention weights remain large. What is the most likely cause, and what change would fix it?
- In the lambda sensitivity sweep, the final weight after 5 steps stays within 0.005 of w₀ = 0.8 for all four λ values. Why doesn't λ = 1.0 collapse the weight to zero? What conditions would make the decay term dominant?
- AdaFactor drops the full v matrix and uses a factored approximation. What property of Adam/AdamW makes this possible, and what is the cost in terms of optimization dynamics?