KL Divergence

Jul 3, 2026 · 7 min read · By Mohammed Vasim
Tags: deep-learning, neural-networks, machine-learning, representation-learning

When a model outputs a probability distribution, cross-entropy measures how surprised you are on average when you use the model's distribution to encode events from the true distribution. But cross-entropy mixes two things together: the irreducible surprise in the true distribution itself, and the extra surprise from using the wrong distribution. KL divergence isolates the extra part.

H(P,Q) = H(P) + KL(P‖Q)

H(P) is entropy — how much surprise is inherent in the true distribution regardless of what model you use. H(P,Q) is cross-entropy — the average surprise when you use model Q to encode samples from P. KL(P‖Q) is the difference: how much extra surprise you incur by approximating P with Q. If your model is perfect (Q=P), KL=0 and cross-entropy equals entropy.

Anchor: two probability distributions over 4 weather classes.

python
classes = ["sunny", "cloudy", "rainy", "stormy"]
P = [0.5, 0.3, 0.15, 0.05]   # true distribution
Q = [0.4, 0.35, 0.2, 0.05]   # model distribution

The Formula

KL(P‖Q) = Σ P(x) · log(P(x) / Q(x))

Equivalently: Σ P(x) · (log P(x) − log Q(x))
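
Splitting the logarithm this way is also what ties the formula back to H(P,Q) = H(P) + KL(P‖Q). A short derivation, assuming the standard definitions H(P) = −Σ P(x)·log P(x) and H(P,Q) = −Σ P(x)·log Q(x):

```latex
\begin{aligned}
\mathrm{KL}(P\,\|\,Q)
  &= \sum_x P(x)\,\bigl(\log P(x) - \log Q(x)\bigr) \\
  &= \Bigl(-\sum_x P(x)\log Q(x)\Bigr) - \Bigl(-\sum_x P(x)\log P(x)\Bigr) \\
  &= H(P,Q) - H(P)
\end{aligned}
```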

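KL ≥ 0 always (Gibbs' inequality), and KL = 0 if and only if P = Q. KL is not a metric: it is asymmetric (KL(P‖Q) ≠ KL(Q‖P) in general) and does not satisfy the triangle inequality. All logarithms in this post are natural logs, so values are in nats.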
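
These properties are easy to spot-check numerically. The sketch below is an illustration rather than a proof: it draws random 4-class distributions from a Dirichlet (an arbitrary choice, as are the seed and sample count) and checks non-negativity, KL(P‖P) = 0, and asymmetry. The helper name `kl` is made up for this example.

```python
import numpy as np

def kl(p, q):
    """KL(p||q) in nats for strictly positive probability vectors."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return float(np.sum(p * np.log(p / q)))

rng = np.random.default_rng(0)  # arbitrary seed, only for reproducibility

# 1000 random pairs of 4-class distributions.
alpha = np.full(4, 2.0)
pairs = [(rng.dirichlet(alpha), rng.dirichlet(alpha)) for _ in range(1000)]

min_kl = min(kl(p, q) for p, q in pairs)
self_kl = kl(pairs[0][0], pairs[0][0])
n_asymmetric = sum(not np.isclose(kl(p, q), kl(q, p)) for p, q in pairs)

print(f"smallest KL over all pairs:       {min_kl:.6f}")
print(f"KL(P||P):                         {self_kl:.6f}")
print(f"pairs with KL(P||Q) != KL(Q||P):  {n_asymmetric} of {len(pairs)}")
```

The smallest value should stay at or above zero, and nearly every random pair should turn out asymmetric.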


Trace Table

| class | P(x) | Q(x) | P(x)/Q(x) | log(P/Q) | P·log(P/Q) |
|---|---|---|---|---|---|
| sunny | 0.50 | 0.40 | 1.2500 | +0.2231 | +0.1116 |
| cloudy | 0.30 | 0.35 | 0.8571 | −0.1542 | −0.0462 |
| rainy | 0.15 | 0.20 | 0.7500 | −0.2877 | −0.0432 |
| stormy | 0.05 | 0.05 | 1.0000 | 0.0000 | 0.0000 |
| KL(P‖Q) | | | | | = 0.0222 |

The model assigns too little probability to "sunny" (0.4 vs 0.5), the most frequent class, which gives the one positive contribution. Over-assigning to "cloudy" and "rainy" produces negative terms that offset part of it, but the net is positive: 0.0222 nats of extra surprise (0.02217 before rounding).


H(P), H(P,Q), and KL

Entropy H(P):
H(P) = −Σ P·log P
= −(0.5·log 0.5 + 0.3·log 0.3 + 0.15·log 0.15 + 0.05·log 0.05)
= −(0.5·(−0.6931) + 0.3·(−1.2040) + 0.15·(−1.8971) + 0.05·(−2.9957))
= −(−0.3466 − 0.3612 − 0.2846 − 0.1498)
≈ 1.1421

Cross-entropy H(P,Q):
H(P,Q) = −Σ P·log Q
= −(0.5·log 0.4 + 0.3·log 0.35 + 0.15·log 0.2 + 0.05·log 0.05)
= −(0.5·(−0.9163) + 0.3·(−1.0498) + 0.15·(−1.6094) + 0.05·(−2.9957))
= −(−0.4581 − 0.3149 − 0.2414 − 0.1498)
≈ 1.1643

Both totals are computed from unrounded terms, so they can differ by 0.0001 from the sum of the four-decimal terms shown.

Verification: H(P,Q) − H(P) = 1.1643 − 1.1421 = 0.0222 = KL(P‖Q)

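This shows why minimizing cross-entropy loss in classification is equivalent to minimizing the KL divergence between the label distribution and the model output: the entropy H(P) is a constant that does not depend on the model, so both objectives have the same gradients with respect to Q and the same minimizer. With one-hot labels, H(P) = 0 and the two losses are numerically equal.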
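
To see the constant offset concretely, the sketch below scores a few candidate model outputs against the anchor P. The candidates (other than the anchor Q) and the helper names are made up for illustration. For every candidate, CE − KL should come out as the same number, H(P).

```python
import numpy as np

P = np.array([0.5, 0.3, 0.15, 0.05])  # fixed label distribution (the anchor)

def entropy(p):
    return float(-np.sum(p * np.log(p)))

def cross_entropy(p, q):
    return float(-np.sum(p * np.log(q)))

def kl(p, q):
    return float(np.sum(p * np.log(p / q)))

# Hypothetical model outputs, from poor to exact.
candidates = {
    "uniform": np.array([0.25, 0.25, 0.25, 0.25]),
    "anchor Q": np.array([0.4, 0.35, 0.2, 0.05]),
    "exact": P.copy(),
}

gaps = {}
for name, q in candidates.items():
    ce, d = cross_entropy(P, q), kl(P, q)
    gaps[name] = ce - d
    print(f"{name:>8}: CE = {ce:.4f}  KL = {d:.4f}  CE - KL = {ce - d:.4f}")

print(f"H(P) = {entropy(P):.4f}")
```

Because the gap never changes, ranking models by cross-entropy and ranking them by KL always gives the same order.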


Forward vs Reverse KL

KL(Q‖P) on anchor:

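| class | Q(x) | P(x) | Q(x)/P(x) | log(Q/P) | Q·log(Q/P) |
|---|---|---|---|---|---|
| sunny | 0.40 | 0.50 | 0.8000 | −0.2231 | −0.0893 |
| cloudy | 0.35 | 0.30 | 1.1667 | +0.1542 | +0.0540 |
| rainy | 0.20 | 0.15 | 1.3333 | +0.2877 | +0.0575 |
| stormy | 0.05 | 0.05 | 1.0000 | 0.0000 | 0.0000 |
| KL(Q‖P) | | | | | = 0.0222 |

On this anchor the two directions are nearly identical (0.02217 forward vs 0.02223 reverse, both 0.0222 to four decimals) because the distributions are close. In general they differ substantially: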

[Figure: a wide, bimodal P fitted by a unimodal Q. Under forward KL(P‖Q), Q fits the mean and covers both modes, but imprecisely. Under reverse KL(Q‖P), Q fits one mode and misses the other entirely.]

Forward KL(P‖Q): P is in the numerator. The term P(x)·log(P(x)/Q(x)) blows up if Q(x)→0 where P(x)>0, so Q is pushed to cover every region where P has probability mass, even low-density ones. This makes Q spread out: mean-seeking.

Reverse KL(Q‖P): Q is in the numerator. The term Q(x)·log(Q(x)/P(x)) blows up if P(x)→0 where Q(x)>0, so Q avoids placing mass where P has none. When Q is not flexible enough to match P, it tends to lock onto a single mode of P that it can model well and ignore the others: mode-seeking.
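
The difference can be reproduced with a small brute-force experiment. This is a sketch under several assumptions that are not from the text above: the target is an equal mixture of N(−4, 1) and N(+4, 1), Q is restricted to a single Gaussian, both are discretized on a grid, and the fit is a plain grid search over (mu, sigma). All names are illustrative.

```python
import numpy as np

x = np.linspace(-20.0, 20.0, 2001)  # evaluation grid (step 0.02)

def log_normalize(log_w):
    """Turn unnormalized log-weights into log-probabilities on the grid."""
    m = log_w.max()
    return log_w - (m + np.log(np.sum(np.exp(log_w - m))))

def log_gauss(mu, sigma):
    """Log-probabilities of N(mu, sigma) discretized onto the grid."""
    return log_normalize(-0.5 * ((x - mu) / sigma) ** 2)

def kl(log_a, log_b):
    """KL(a||b) in nats from log-probabilities (safe when a underflows to 0)."""
    return float(np.sum(np.exp(log_a) * (log_a - log_b)))

# Bimodal target P: equal mixture of N(-4, 1) and N(+4, 1).
log_p = log_normalize(np.logaddexp(-0.5 * (x + 4.0) ** 2, -0.5 * (x - 4.0) ** 2))

best_f = best_r = None  # (kl value, mu, sigma) for forward and reverse fits
for mu in np.arange(-6.0, 6.01, 0.25):
    for sigma in np.arange(0.5, 6.01, 0.1):
        log_q = log_gauss(mu, sigma)
        f, r = kl(log_p, log_q), kl(log_q, log_p)
        if best_f is None or f < best_f[0]:
            best_f = (f, mu, sigma)
        if best_r is None or r < best_r[0]:
            best_r = (r, mu, sigma)

print(f"forward KL(P||Q) fit: mu = {best_f[1]:+.2f}, sigma = {best_f[2]:.2f}")
print(f"reverse KL(Q||P) fit: mu = {best_r[1]:+.2f}, sigma = {best_r[2]:.2f}")
```

On this setup the forward fit should land near mu = 0 with a large sigma (covering both modes), while the reverse fit should sit on one of the modes at ±4 with sigma close to 1. Which of the two modes the reverse fit picks is an artifact of the search order, since the target is symmetric.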


Where KL Divergence Is Used

Variational Autoencoders (VAE): Loss = reconstruction_loss + β · KL(q(z|x) ‖ p(z))

Here P = p(z) = N(0, I), the standard normal prior, and Q = q(z|x), the encoder's learned approximate posterior. Note the direction: with these roles the term is KL(Q‖P), the reverse form. The KL term regularizes the latent space toward the prior; without it, the encoder would be free to memorize exact codes rather than learn a continuous latent distribution.
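
When the encoder outputs a diagonal Gaussian (a common choice, assumed here rather than stated above), this KL term has a closed form: 0.5 · Σ (μ² + σ² − 1 − log σ²), summed over latent dimensions. A minimal sketch, with hypothetical encoder outputs and a placeholder reconstruction loss:

```python
import numpy as np

def gaussian_kl_to_standard_normal(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ) in nats, summed over latent dims."""
    mu = np.asarray(mu, dtype=float)
    log_var = np.asarray(log_var, dtype=float)
    return float(0.5 * np.sum(mu**2 + np.exp(log_var) - 1.0 - log_var))

# Hypothetical encoder outputs for one input x, 3-dimensional latent.
mu = np.array([0.5, -1.0, 0.0])
log_var = np.array([0.0, -0.5, 0.2])   # log of the per-dimension variance

kl_term = gaussian_kl_to_standard_normal(mu, log_var)
kl_at_prior = gaussian_kl_to_standard_normal(np.zeros(3), np.zeros(3))

beta = 1.0                   # assumed weight on the KL term
reconstruction_loss = 0.25   # placeholder value, not from a real model
loss = reconstruction_loss + beta * kl_term

print(f"KL term        = {kl_term:.4f}")
print(f"KL at prior    = {kl_at_prior:.4f}")
print(f"total loss     = {loss:.4f}")
```

The term is zero only when the encoder outputs exactly the prior (mu = 0, log_var = 0), and it grows as the mean moves away from zero or the variance moves away from one.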

Knowledge Distillation: Loss = KL(teacher_softmax ‖ student_softmax) + α · CE(y, student)

P = teacher's soft probabilities. Q = student's soft probabilities. The student is trained to match not just the hard label but the full soft distribution — information in the teacher's probability assigned to wrong classes ("dark knowledge") helps the student generalize.
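
A minimal sketch of that loss for a single example, assuming made-up logits and an arbitrary α. Temperature scaling, which is common in practice, is left out because the loss as written above does not include it.

```python
import numpy as np

def softmax(logits):
    z = np.asarray(logits, dtype=float)
    z = z - z.max()          # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

def kl(p, q):
    """KL(p||q) in nats for strictly positive probability vectors."""
    return float(np.sum(p * np.log(p / q)))

def hard_label_ce(label, q):
    """Cross-entropy against a one-hot label: -log q[label]."""
    return float(-np.log(q[label]))

# Hypothetical logits for one 3-class example.
teacher_logits = np.array([4.0, 1.5, 0.5])
student_logits = np.array([2.5, 1.0, 1.0])
label = 0      # index of the true class
alpha = 0.5    # assumed weight on the hard-label term

teacher = softmax(teacher_logits)
student = softmax(student_logits)

soft_term = kl(teacher, student)            # KL(teacher || student)
hard_term = hard_label_ce(label, student)   # CE(y, student)
distill_loss = soft_term + alpha * hard_term

print(f"KL(teacher||student) = {soft_term:.4f}")
print(f"CE(y, student)       = {hard_term:.4f}")
print(f"distillation loss    = {distill_loss:.4f}")
```

The soft term is sensitive to how the student spreads probability over the wrong classes, which the hard-label term alone cannot see.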

RLHF / PPO: PPO penalty = KL(π_new ‖ π_ref)

Here P = π_ref, the reference policy (the SFT model), and Q = π_new, the policy being trained, so this is again the reverse direction KL(Q‖P). Penalizing large KL keeps the policy from drifting too far from the reference during reward optimization, which limits reward hacking and degenerate behaviors.
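
As a toy illustration, the penalty can be computed exactly for a single next-token distribution. Everything numeric below is hypothetical (a 4-token vocabulary, the two policies, the reward, and the coefficient β), and subtracting β·KL from the reward is one common way to apply the penalty, assumed here rather than taken from the text.

```python
import numpy as np

def kl(p, q):
    """KL(p||q) in nats for strictly positive probability vectors."""
    return float(np.sum(p * np.log(p / q)))

# Hypothetical next-token distributions over a 4-token vocabulary.
pi_ref = np.array([0.40, 0.30, 0.20, 0.10])   # frozen reference (SFT) policy
pi_new = np.array([0.70, 0.15, 0.10, 0.05])   # policy after some reward optimization

beta = 0.1      # assumed penalty coefficient
reward = 1.0    # placeholder reward-model score

penalty = kl(pi_new, pi_ref)               # KL(pi_new || pi_ref), as in the text
shaped_reward = reward - beta * penalty    # drift from the reference costs reward

print(f"KL(pi_new||pi_ref) = {penalty:.4f}")
print(f"shaped reward      = {shaped_reward:.4f}")
```

If pi_new equals pi_ref the penalty is zero and the reward passes through unchanged.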


Code

python
import numpy as np

P = np.array([0.5, 0.3, 0.15, 0.05])
Q = np.array([0.4, 0.35, 0.2, 0.05])
classes = ["sunny", "cloudy", "rainy", "stormy"]

def kl_divergence(p, q, eps=1e-10):
    return np.sum(p * np.log((p + eps) / (q + eps)))

def entropy(p, eps=1e-10):
    return -np.sum(p * np.log(p + eps))

def cross_entropy(p, q, eps=1e-10):
    return -np.sum(p * np.log(q + eps))

print(f"{'Class':>8} | {'P':>6} | {'Q':>6} | {'P/Q':>8} | {'log(P/Q)':>10} | {'P*log(P/Q)':>12}")
for c, p, q in zip(classes, P, Q):
    ratio = p / q
    print(f"{c:>8} | {p:>6.2f} | {q:>6.2f} | {ratio:>8.4f} | {np.log(ratio):>10.4f} | {p*np.log(ratio):>12.4f}")

kl_fwd = kl_divergence(P, Q)
kl_rev = kl_divergence(Q, P)
h_p = entropy(P)
h_pq = cross_entropy(P, Q)

print(f"\nH(P)         = {h_p:.4f}")
print(f"H(P,Q)       = {h_pq:.4f}")
print(f"KL(P‖Q)      = {kl_fwd:.4f}  [forward]")
print(f"KL(Q‖P)      = {kl_rev:.4f}  [reverse]")
print(f"H(P,Q)-H(P)  = {h_pq - h_p:.4f}  (should equal KL(P‖Q))")
text
   Class |      P |      Q |      P/Q |   log(P/Q) |   P*log(P/Q)
   sunny |   0.50 |   0.40 |   1.2500 |     0.2231 |       0.1116
  cloudy |   0.30 |   0.35 |   0.8571 |    -0.1542 |      -0.0462
   rainy |   0.15 |   0.20 |   0.7500 |    -0.2877 |      -0.0432
  stormy |   0.05 |   0.05 |   1.0000 |     0.0000 |       0.0000

H(P)         = 1.1421
H(P,Q)       = 1.1643
KL(P‖Q)      = 0.0222  [forward]
KL(Q‖P)      = 0.0222  [reverse]
H(P,Q)-H(P)  = 0.0222  (should equal KL(P‖Q))

KL divergence connects directly to cross-entropy loss (03-classification-losses.md) — minimizing CE with fixed true labels is identical to minimizing KL. The softmax function (03-activations/07-softmax.md) is typically applied before computing KL divergence to convert logits to a probability distribution. VAE, knowledge distillation, and PPO are three of the most important real applications of KL in modern deep learning.

Honest Limitations

KL(P‖Q) is infinite when Q(x) = 0 for any class where P(x) > 0. In practice a small ε (1e-10) is added to prevent log(0), but this is a numerical hack. If the model confidently assigns zero probability to a class that appears in the true distribution, it has made an infinitely bad prediction; adding ε replaces that infinity with a large finite number and hides the failure.
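
The effect is easy to demonstrate. The sketch below compares the ε-smoothed formula with a strict version that returns infinity when Q rules out a class that P contains. The name `kl_strict` and the zeroed-out model `Q_zero` are made up for this example.

```python
import numpy as np

def kl_eps(p, q, eps=1e-10):
    """Smoothed KL with epsilon added to both arguments: always finite."""
    return float(np.sum(p * np.log((p + eps) / (q + eps))))

def kl_strict(p, q):
    """KL(p||q) with the conventions 0*log(0/q) = 0 and p*log(p/0) = inf."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    support = p > 0
    if np.any(q[support] == 0):
        return float("inf")
    return float(np.sum(p[support] * np.log(p[support] / q[support])))

P = np.array([0.5, 0.3, 0.15, 0.05])
Q_zero = np.array([0.5, 0.3, 0.2, 0.0])   # hypothetical model that rules out "stormy"

smoothed = kl_eps(P, Q_zero)
strict = kl_strict(P, Q_zero)

print(f"smoothed KL = {smoothed:.4f}")
print(f"strict KL   = {strict}")
```

The smoothed value is finite and depends directly on the choice of ε, so it says more about the hack than about the model. Returning infinity (or raising an error) makes the failure visible.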

KL divergence is not symmetric: KL(P‖Q) ≠ KL(Q‖P) in general, so it is not a metric in the mathematical sense. When you need a symmetric measure, Jensen-Shannon divergence, JSD = (KL(P‖M) + KL(Q‖M))/2 with M = (P+Q)/2, is a better choice. JSD is symmetric and bounded in [0, log 2] (with natural logs), and its square root is a true metric.
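
A minimal sketch of JSD built from the formula above, checking symmetry on the anchor distributions and the log 2 upper bound on two made-up distributions with disjoint support:

```python
import numpy as np

def kl(p, q):
    """KL(p||q) in nats; terms with p = 0 contribute 0."""
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def jsd(p, q):
    """Jensen-Shannon divergence with mixture M = (P + Q) / 2."""
    m = 0.5 * (p + q)   # m > 0 wherever p > 0 or q > 0, so both KL terms are finite
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

P = np.array([0.5, 0.3, 0.15, 0.05])
Q = np.array([0.4, 0.35, 0.2, 0.05])

# Disjoint supports: the worst case, where JSD reaches log 2.
A = np.array([1.0, 0.0, 0.0, 0.0])
B = np.array([0.0, 0.0, 0.0, 1.0])

print(f"JSD(P, Q) = {jsd(P, Q):.6f}   JSD(Q, P) = {jsd(Q, P):.6f}")
print(f"JSD(A, B) = {jsd(A, B):.6f}   log 2     = {np.log(2):.6f}")
```

Unlike KL, JSD never needs an ε: the mixture M is nonzero wherever either input is.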

When P and Q have disjoint support (no overlap in their support sets), KL(P‖Q) = ∞. This is a fundamental problem for training GANs with KL-type divergences: early generator outputs may have essentially no overlap with the real data distribution, so the divergence is infinite or constant and provides no useful gradient. Wasserstein GANs use the Wasserstein (earth mover's) distance to address exactly this problem.
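
A toy version of the problem, assuming one-dimensional histograms with unit-spaced bins and point masses standing in for "real" and "generated" data. The helper names are illustrative, and the Wasserstein-1 distance is computed from the difference of cumulative sums, which is valid only in this 1D setting.

```python
import numpy as np

def kl_strict(p, q):
    """KL(p||q); returns inf when q is zero somewhere p is not."""
    support = p > 0
    if np.any(q[support] == 0):
        return float("inf")
    return float(np.sum(p[support] * np.log(p[support] / q[support])))

def jsd(p, q):
    m = 0.5 * (p + q)
    return 0.5 * kl_strict(p, m) + 0.5 * kl_strict(q, m)

def wasserstein_1d(p, q):
    """Earth mover's distance between histograms on bins 0, 1, 2, ... (unit spacing)."""
    return float(np.sum(np.abs(np.cumsum(p) - np.cumsum(q))))

def point_mass(position, n_bins=10):
    h = np.zeros(n_bins)
    h[position] = 1.0
    return h

real = point_mass(0)
for shift in (8, 4, 1):   # generated mass moving toward the real data
    fake = point_mass(shift)
    print(f"shift = {shift}: KL = {kl_strict(real, fake)}, "
          f"JSD = {jsd(real, fake):.4f}, W1 = {wasserstein_1d(real, fake):.1f}")
```

KL stays infinite and JSD stays pinned at log 2 no matter how close the generated mass gets, so neither tells the generator which way to move. The Wasserstein distance shrinks with the shift, which is the signal a gradient-based learner needs.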


Test Your Understanding

  1. Compute KL(P‖Q) for P=[0.7, 0.3] and Q=[0.5, 0.5]. Show each term. Why is KL positive even though Q seems like a "reasonable" model?

  2. The trace table shows that "cloudy" and "rainy" contribute negative values to KL(P‖Q) because Q assigns more probability than P. How can individual terms in a KL sum be negative if KL ≥ 0 overall?

  3. In VAE training, the KL term KL(q(z|x)‖p(z)) often collapses to zero early in training (all latent codes collapse to the prior). What does this mean for the encoder, and what does it imply about the reconstruction loss?

  4. A knowledge distillation student is trained to match teacher probabilities [0.95, 0.04, 0.01] but outputs [0.90, 0.05, 0.05]. Compute KL(teacher‖student) for this 3-class example. Which class contributes most and why?

  5. PPO uses KL(π_new‖π_ref) as a penalty rather than KL(π_ref‖π_new). Given the mode-seeking vs mean-seeking interpretation, why is this the right choice? What would happen if you used the reverse direction?
