Decision Tree Pruning: Pre-Pruning and Post-Pruning
An unconstrained decision tree will grow until every leaf is pure, down to one sample per leaf if necessary. On training data this achieves 100% accuracy. On test data it does measurably worse, because the tree has memorized noise along with the generalizable pattern. Pruning limits tree growth to prevent this.
Anchor dataset: Breast Cancer Wisconsin — 30 features, binary classification (malignant/benign), 569 samples.
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np
data = load_breast_cancer()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
Why Trees Overfit Without Constraints
dt_full = DecisionTreeClassifier(random_state=42)
dt_full.fit(X_train, y_train)
print(f"Max depth: {dt_full.get_depth()}")
print(f"Leaf count: {dt_full.get_n_leaves()}")
print(f"Train accuracy: {dt_full.score(X_train, y_train):.4f}")
print(f"Test accuracy: {dt_full.score(X_test, y_test):.4f}")

Max depth: 7
Leaf count: 43
Train accuracy: 1.0000
Test accuracy: 0.9298
43 leaves for 455 training samples, roughly 10 samples per leaf on average. Some leaves contain 1–2 samples that happen to land there by quirks of the training data. Train accuracy is 100% (memorized) but test accuracy is 92.98%, a gap of about 7 percentage points from overfitting.
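The average hides how uneven the leaves are. A minimal sketch for inspecting leaf sizes directly, assuming the same split as above and reading the fitted tree's `tree_` arrays (the summary statistics printed here are an illustrative choice, not part of the original walkthrough):

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

dt_full = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

tree = dt_full.tree_
# In scikit-learn's tree arrays, a leaf is a node whose left child index is -1.
is_leaf = tree.children_left == -1
# Number of training samples that end up in each leaf.
leaf_sizes = tree.n_node_samples[is_leaf]

print(f"Leaves: {is_leaf.sum()}")
print(f"Smallest leaf: {leaf_sizes.min()} samples")
print(f"Median leaf: {np.median(leaf_sizes):.1f} samples")
print(f"Leaves with <= 2 samples: {(leaf_sizes <= 2).sum()}")
```

Leaves holding only one or two samples are the ones most likely to encode noise, which is what the pruning parameters below target.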
Pre-Pruning Strategy 1: max_depth
Limit how deep the tree can grow before stopping.
print(f"{'depth':>6} | {'train':>8} | {'test':>8} | {'leaves':>7}")
for d in [1, 2, 3, 4, 5, 6, None]:
    dt = DecisionTreeClassifier(max_depth=d, random_state=42)
    dt.fit(X_train, y_train)
    tr = dt.score(X_train, y_train)
    te = dt.score(X_test, y_test)
    lv = dt.get_n_leaves()
    print(f"{str(d):>6} | {tr:>8.4f} | {te:>8.4f} | {lv:>7}")

depth | train | test | leaves
1 | 0.8967 | 0.8860 | 2 ← underfit
2 | 0.9385 | 0.9298 | 4
3 | 0.9560 | 0.9561 | 7
4 | 0.9714 | 0.9649 | 13 ← sweet spot
5 | 0.9824 | 0.9561 | 22
6 | 0.9978 | 0.9386 | 37
None | 1.0000 | 0.9298 | 43 ← overfit
[Figure: train and test accuracy versus max_depth (1 through None), with the depth=4 sweet spot marked.]
Train accuracy (orange dashed) rises monotonically to 1.0. Test accuracy (blue) peaks at depth=4 (96.5%) then falls as the tree starts memorizing training quirks. The gap between train and test widens after depth=4.
Pre-Pruning Strategy 2: min_samples_split
A node is only split if it contains at least min_samples_split samples. This prevents creating splits on very small groups.
print(f"{'min_split':>10} | {'train':>8} | {'test':>8} | {'leaves':>7}")
for ms in [2, 5, 10, 20, 30, 50, 100]:
    dt = DecisionTreeClassifier(min_samples_split=ms, random_state=42)
    dt.fit(X_train, y_train)
    print(f"{ms:>10} | {dt.score(X_train, y_train):>8.4f} | {dt.score(X_test, y_test):>8.4f} | {dt.get_n_leaves():>7}")

min_split | train | test | leaves
2 | 1.0000 | 0.9298 | 43
5 | 1.0000 | 0.9298 | 43
10 | 0.9978 | 0.9386 | 37
20 | 0.9956 | 0.9561 | 27
30 | 0.9824 | 0.9561 | 22
50 | 0.9736 | 0.9649 | 17
100 | 0.9429 | 0.9298 | 9
Increasing min_samples_split from 2 to 50 improves test accuracy from 92.98% to 96.49% by preventing splits on small, noisy groups. At 100, underfitting sets in: test accuracy falls back to 92.98% with only 9 leaves.
Pre-Pruning Strategy 3: min_samples_leaf
A leaf must contain at least min_samples_leaf samples. This is stricter than min_samples_split — it ensures both children of any split have enough samples.
for ml in [1, 2, 5, 10, 20]:
    dt = DecisionTreeClassifier(min_samples_leaf=ml, random_state=42)
    dt.fit(X_train, y_train)
    print(f"min_samples_leaf={ml:>3}: leaves={dt.get_n_leaves():>3}, test={dt.score(X_test, y_test):.4f}")

min_samples_leaf= 1: leaves= 43, test=0.9298
min_samples_leaf= 2: leaves= 32, test=0.9386
min_samples_leaf= 5: leaves= 19, test=0.9474
min_samples_leaf= 10: leaves= 12, test=0.9561
min_samples_leaf= 20: leaves= 7, test=0.9649
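Within the values tested, each increase in min_samples_leaf removes small leaves and improves test accuracy, from 92.98% at 1 to 96.49% at 20. The trend cannot continue indefinitely: a large enough minimum forces the tree toward a single node, and it underfits.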
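The practical difference between the two sample-count parameters is what they bound: min_samples_split limits the size of the node being split, while min_samples_leaf limits the size of every child that results. The sketch below, which assumes the same split as above and uses a hypothetical helper named `smallest_leaf`, prints the smallest leaf each setting produces for the same threshold value:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)


def smallest_leaf(**params):
    """Fit a tree with the given constraints and return its smallest leaf size."""
    clf = DecisionTreeClassifier(random_state=42, **params).fit(X_train, y_train)
    t = clf.tree_
    is_leaf = t.children_left == -1  # leaves have no left child
    return int(t.n_node_samples[is_leaf].min())


for k in [5, 10, 20]:
    by_leaf = smallest_leaf(min_samples_leaf=k)
    by_split = smallest_leaf(min_samples_split=k)
    print(f"k={k:>2}: min_samples_leaf -> smallest leaf {by_leaf:>2}, "
          f"min_samples_split -> smallest leaf {by_split:>2}")
```

With min_samples_leaf=k, no leaf can fall below k samples. With min_samples_split=k, a node of k samples may still be divided unevenly, so small leaves remain possible.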
Finding the Best Combination with GridSearchCV
from sklearn.model_selection import GridSearchCV
param_grid = {
    'max_depth': [3, 4, 5, 6, None],
    'min_samples_split': [2, 5, 10, 20],
    'min_samples_leaf': [1, 5, 10],
}
gs = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid, cv=10, scoring='accuracy', n_jobs=-1
)
gs.fit(X_train, y_train)
print(f"Best params: {gs.best_params_}")
print(f"Best CV accuracy: {gs.best_score_:.4f}")
print(f"Test accuracy: {gs.best_estimator_.score(X_test, y_test):.4f}")

Best params: {'max_depth': 4, 'min_samples_leaf': 5, 'min_samples_split': 10}
Best CV accuracy: 0.9647
Test accuracy: 0.9737
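The best pre-pruned tree: max_depth=4, min_samples_leaf=5, min_samples_split=10. The depth matches the max_depth sweep. The two sample-count settings are smaller than the single-parameter winners (min_samples_split=50, min_samples_leaf=20), but those values were not in this grid and were tuned without a depth limit.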
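The single "best" combination can hide near-ties. One way to look past it is to rank the entries of `cv_results_`, sketched below with the same grid (run single-process here; the hypothetical helper `mean_score` is only for lookup). It also checks one structural fact: with min_samples_leaf=10 a node needs at least 20 samples to be split at all, so every min_samples_split value up to 20 builds the same tree.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

param_grid = {
    "max_depth": [3, 4, 5, 6, None],
    "min_samples_split": [2, 5, 10, 20],
    "min_samples_leaf": [1, 5, 10],
}
gs = GridSearchCV(
    DecisionTreeClassifier(random_state=42), param_grid, cv=10, scoring="accuracy"
)
gs.fit(X_train, y_train)

results = gs.cv_results_
# Indices of the five best-ranked parameter combinations.
top = np.argsort(results["rank_test_score"])[:5]
for i in top:
    print(f"{results['mean_test_score'][i]:.4f} "
          f"+/- {results['std_test_score'][i]:.4f}  {results['params'][i]}")


def mean_score(depth, split, leaf):
    """Look up the mean CV accuracy for one parameter combination."""
    wanted = {"max_depth": depth, "min_samples_split": split, "min_samples_leaf": leaf}
    for params, score in zip(results["params"], results["mean_test_score"]):
        if params == wanted:
            return score
    raise KeyError(wanted)


# Redundant settings: identical trees, hence identical CV scores.
print(np.isclose(mean_score(4, 2, 10), mean_score(4, 20, 10)))
```

When several rows sit within one standard deviation of each other, the choice among them is closer to a preference for simpler trees than a measured difference.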
Post-Pruning: Cost Complexity Pruning (CCP)
Pre-pruning stops growth early. Post-pruning grows the full tree first, then removes subtrees that don't justify their complexity.
CART's cost complexity pruning adds a penalty per leaf. For a given $\alpha \ge 0$, the selected subtree is the one that minimizes:
$$R_\alpha(T) = R(T) + \alpha\,|T|$$
where $R(T)$ is the total leaf impurity and $|T|$ is the number of leaves. As $\alpha$ increases, leaves become more expensive and subtrees get pruned until only the root remains.
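Taking the objective as $R_\alpha(T) = R(T) + \alpha\,|T|$, both terms can be read off a fitted tree. The sketch below computes $R(T)$ as the sample-weighted sum of leaf impurities and compares it with the first entry of `path.impurities`. The depth-3 tree is an assumption made only so that the leaves are impure and $R(T)$ is nonzero, and the three alpha values are arbitrary illustrations.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Assumption for illustration: a shallow tree, so leaves are not pure.
clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X_train, y_train)
t = clf.tree_

is_leaf = t.children_left == -1           # leaves have no left child
n_total = t.weighted_n_node_samples[0]    # samples at the root
# R(T): each leaf's impurity, weighted by the share of samples reaching it.
r_t = np.sum(t.weighted_n_node_samples[is_leaf] / n_total * t.impurity[is_leaf])
n_leaves = int(is_leaf.sum())

for alpha in [0.0, 0.005, 0.05]:
    r_alpha = r_t + alpha * n_leaves
    print(f"alpha={alpha:<5}: R_alpha = {r_t:.4f} + {alpha} * {n_leaves} = {r_alpha:.4f}")

# The pruning path's first entry is the unpruned tree, so it should agree.
path = DecisionTreeClassifier(max_depth=3, random_state=42).cost_complexity_pruning_path(
    X_train, y_train
)
print("Matches path.impurities[0]:", np.isclose(r_t, path.impurities[0]))
```

For a fixed tree the penalty term grows linearly in alpha, which is why a smaller subtree with slightly higher impurity eventually wins.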
dt_full = DecisionTreeClassifier(random_state=42)
path = dt_full.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas
print(f"Number of alpha values: {len(ccp_alphas)}")
print(f"Alpha range: [{ccp_alphas[0]:.6f}, {ccp_alphas[-1]:.4f}]")

Number of alpha values: 44
Alpha range: [0.000000, 0.4997]
clfs, train_scores, test_scores, leaf_counts = [], [], [], []
for alpha in ccp_alphas:
    clf = DecisionTreeClassifier(random_state=42, ccp_alpha=alpha)
    clf.fit(X_train, y_train)
    clfs.append(clf)
    train_scores.append(clf.score(X_train, y_train))
    test_scores.append(clf.score(X_test, y_test))
    leaf_counts.append(clf.get_n_leaves())
# Note: alpha is picked by test accuracy here, so the reported test score is optimistic.
best_idx = np.argmax(test_scores)
print(f"Best alpha: {ccp_alphas[best_idx]:.6f}")
print(f"Best test accuracy: {test_scores[best_idx]:.4f}")
print(f"Leaves at best alpha: {leaf_counts[best_idx]}")

Best alpha: 0.005000
Best test accuracy: 0.9737
Leaves at best alpha: 12
At $\alpha = 0.005$: 12 leaves, 97.37% test accuracy, the same test accuracy as the pre-pruned GridSearch result. Pre-pruning and CCP reach the same score here, although equal accuracy alone does not show that the two trees are identical.
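The loop above picks alpha by looking at test accuracy, which leaks test information into model selection. A hedged alternative, sketched here and not part of the original walkthrough, is to choose ccp_alpha by cross-validation on the training set and touch the test set once at the end. Clipping and de-duplicating the alphas is a defensive assumption against tiny negative values from floating-point error.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(
    X_train, y_train
)
# Drop the last alpha (it prunes down to the root alone), clip any tiny
# negative values, and remove duplicates.
alphas = np.unique(np.clip(path.ccp_alphas[:-1], 0.0, None))

gs_alpha = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    {"ccp_alpha": alphas},
    cv=10,
    scoring="accuracy",
)
gs_alpha.fit(X_train, y_train)

best_alpha = gs_alpha.best_params_["ccp_alpha"]
test_acc = gs_alpha.best_estimator_.score(X_test, y_test)
print(f"CV-selected alpha: {best_alpha:.6f}")
print(f"CV accuracy: {gs_alpha.best_score_:.4f}")
print(f"Leaves: {gs_alpha.best_estimator_.get_n_leaves()}")
print(f"Test accuracy (evaluated once): {test_acc:.4f}")
```

The alpha selected this way may differ from the test-set optimum, and its test accuracy is the more honest estimate of how the pruned tree generalizes.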
[Figure: three panels showing leaf count versus $\alpha$, train and test accuracy versus $\alpha$, and test accuracy versus leaf count.]
Left: leaf count decreases as $\alpha$ increases. Center: test accuracy (blue) peaks at $\alpha = 0.005$, train (orange) decreases monotonically. Right: test accuracy peaks at 12 leaves and falls off toward both extremes.
Pre-Pruning vs Post-Pruning
| Aspect | Pre-Pruning | Post-Pruning (CCP) |
|---|---|---|
| When applied | During tree building | After full tree is built |
| Parameters | max_depth, min_samples_* | ccp_alpha |
| Computational cost | Cheap — stops early | Expensive — builds full tree first |
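| Main trade-off | May stop too early: a weak split that would enable a strong one below it is never tried | Sees the full tree before cutting, so subtrees are judged by what they actually contribute |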
| Best test accuracy (this data) | 0.9737 | 0.9737 |
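Equal accuracy in the last row does not mean the two procedures built the same tree. A rough way to compare them is sketched below. It assumes the parameters reported earlier (max_depth=4, min_samples_leaf=5, min_samples_split=10 for pre-pruning, and ccp_alpha=0.005 for post-pruning, where 0.005 is the rounded printed value and may not match the exact alpha from the pruning path).

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.2, random_state=42, stratify=data.target
)

# Assumed settings, taken from the results reported above.
pre = DecisionTreeClassifier(
    max_depth=4, min_samples_leaf=5, min_samples_split=10, random_state=42
).fit(X_train, y_train)
post = DecisionTreeClassifier(ccp_alpha=0.005, random_state=42).fit(X_train, y_train)

for name, clf in [("pre-pruned", pre), ("post-pruned", post)]:
    print(f"{name}: depth={clf.get_depth()}, leaves={clf.get_n_leaves()}")

# Same tree only if every node splits on the same feature at the same threshold.
same_structure = np.array_equal(pre.tree_.feature, post.tree_.feature) and np.array_equal(
    pre.tree_.threshold, post.tree_.threshold
)
# Fraction of test samples on which the two trees predict the same class.
agreement = float(np.mean(pre.predict(X_test) == post.predict(X_test)))

print(f"Identical structure: {same_structure}")
print(f"Prediction agreement on test set: {agreement:.4f}")

# First levels of the pre-pruned tree, for a visual comparison.
print(export_text(pre, feature_names=list(data.feature_names), max_depth=2))
```

Two trees can share a leaf count and a test score while splitting on different features, so the structural check and the per-sample agreement are more informative than the headline accuracy.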
Pruning Parameter Summary
| Parameter | Small value | Large value | Typical range |
|---|---|---|---|
| max_depth | Shallow (underfit) | Deep (overfit) | 3–10 |
| min_samples_split | Many tiny nodes (overfit) | Fewer splits (underfit) | 10–20 |
| min_samples_leaf | Tiny leaves (overfit) | Large leaves (underfit) | 5–20 |
| ccp_alpha | No pruning (overfit) | Heavily pruned (underfit) | Find via CV |
Test Your Understanding
- The fully grown tree has 43 leaves for 455 training samples (~10 samples/leaf). If a leaf contains exactly 2 samples with different classes (1 Yes, 1 No), what is its entropy? What class does it predict, and what is its local training accuracy?
- At max_depth=4: train=97.14%, test=96.49%. At max_depth=5: train=98.24%, test=95.61%. The test accuracy drops from depth 4 to 5 despite higher training accuracy. Describe what changes in the tree between depth=4 and depth=5 that could explain this.
- CCP starts with the full tree (alpha=0) and increases alpha. At each alpha value, which subtrees are pruned first: the ones at the top of the tree or the leaves at the bottom? Why?
- Both pre-pruning (GridSearchCV) and post-pruning (CCP) found 97.37% test accuracy with approximately 12 leaves. Does this guarantee they found the same tree? What would you compare to check?
- You have 10,000 training samples and want to choose between min_samples_leaf=10 and min_samples_leaf=100. The dataset has a rare class that appears in only 2% of samples (200 examples). How does this class imbalance affect your choice of min_samples_leaf?