Decision Tree Pruning: Pre-Pruning and Post-Pruning

Jun 26, 2026 · 7 min read · By Mohammed Vasim
Machine Learning · AI · Data Science

An unconstrained decision tree keeps growing until every leaf is pure, down to one sample per leaf if necessary. On training data this achieves 100% accuracy. On test data it does noticeably worse, because the tree has memorized noise rather than learned a generalizable pattern. Pruning limits tree growth to prevent this.

Anchor dataset: Breast Cancer Wisconsin — 30 features, binary classification (malignant/benign), 569 samples.

python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np

data = load_breast_cancer()
X, y = data.data, data.target

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

Why Trees Overfit Without Constraints

python
dt_full = DecisionTreeClassifier(random_state=42)
dt_full.fit(X_train, y_train)

print(f"Max depth:  {dt_full.get_depth()}")
print(f"Leaf count: {dt_full.get_n_leaves()}")
print(f"Train accuracy: {dt_full.score(X_train, y_train):.4f}")
print(f"Test accuracy:  {dt_full.score(X_test, y_test):.4f}")
Max depth:  7
Leaf count: 43
Train accuracy: 1.0000
Test accuracy:  0.9298

43 leaves for 455 training samples, or roughly 10.6 samples per leaf on average. Some leaves contain 1–2 samples that happen to land there by quirks of the training data. Train accuracy is 100% (memorized) but test accuracy is 92.98%: a gap of about 7 percentage points, caused by overfitting.
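The average hides how uneven the leaves are. As a minimal sketch (assuming the same split and `random_state` as above), the leaf sizes can be read from the fitted tree's `tree_` arrays, where a node with `children_left == -1` is a leaf:

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

dt_full = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

tree = dt_full.tree_
# In scikit-learn's tree arrays, a leaf is a node with no left child (-1).
is_leaf = tree.children_left == -1
leaf_sizes = tree.n_node_samples[is_leaf]

print(f"Leaves:                   {leaf_sizes.size}")
print(f"Mean samples per leaf:    {leaf_sizes.mean():.1f}")
print(f"Median samples per leaf:  {np.median(leaf_sizes):.0f}")
print(f"Leaves with <= 2 samples: {(leaf_sizes <= 2).sum()}")
```

A median far below the mean would indicate that a few large leaves carry most of the samples while many small leaves fit individual points.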

Pre-Pruning Strategy 1: max_depth

Limit how deep the tree can grow before stopping.

python
print(f"{'depth':>6} | {'train':>8} | {'test':>8} | {'leaves':>7}")
for d in [1, 2, 3, 4, 5, 6, None]:
    dt = DecisionTreeClassifier(max_depth=d, random_state=42)
    dt.fit(X_train, y_train)
    tr = dt.score(X_train, y_train)
    te = dt.score(X_test, y_test)
    lv = dt.get_n_leaves()
    print(f"{str(d):>6} | {tr:>8.4f} | {te:>8.4f} | {lv:>7}")
 depth |    train |     test |  leaves
     1 |   0.8967 |   0.8860 |       2   ← underfit
     2 |   0.9385 |   0.9298 |       4
     3 |   0.9560 |   0.9561 |       7
     4 |   0.9714 |   0.9649 |      13   ← sweet spot
     5 |   0.9824 |   0.9561 |      22
     6 |   0.9978 |   0.9386 |      37
  None |   1.0000 |   0.9298 |      43   ← overfit

[Figure: train and test accuracy vs. max_depth, with depth=4 marked as the sweet spot]

Train accuracy (orange dashed) rises monotonically to 1.0. Test accuracy (blue) peaks at depth=4 (96.5%) then falls as the tree starts memorizing training quirks. The gap between train and test widens after depth=4.
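The sweep above reads the best depth off the test set. A possible alternative, sketched here under the assumption of 5-fold cross-validation (the fold count is an illustrative choice, not one from the sweep), picks the depth using training data only so the test set stays untouched:

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

depths = [1, 2, 3, 4, 5, 6, None]
cv_means = []
for d in depths:
    dt = DecisionTreeClassifier(max_depth=d, random_state=42)
    # Accuracy on each of 5 held-out folds of the training set.
    scores = cross_val_score(dt, X_train, y_train, cv=5, scoring="accuracy")
    cv_means.append(scores.mean())
    print(f"{str(d):>6} | cv mean={scores.mean():.4f} | cv std={scores.std():.4f}")

# Depth with the highest mean cross-validated accuracy.
best_depth = depths[int(np.argmax(cv_means))]
print(f"Depth selected by CV: {best_depth}")
```

The depth chosen this way may differ from the one that looks best on the test set, which is expected: neighbouring depths often score within one standard deviation of each other.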

Pre-Pruning Strategy 2: min_samples_split

A node is only split if it contains at least min_samples_split samples. This prevents creating splits on very small groups.

python
print(f"{'min_split':>10} | {'train':>8} | {'test':>8} | {'leaves':>7}")
for ms in [2, 5, 10, 20, 30, 50, 100]:
    dt = DecisionTreeClassifier(min_samples_split=ms, random_state=42)
    dt.fit(X_train, y_train)
    print(f"{ms:>10} | {dt.score(X_train, y_train):>8.4f} | {dt.score(X_test, y_test):>8.4f} | {dt.get_n_leaves():>7}")
 min_split |    train |     test |  leaves
         2 |   1.0000 |   0.9298 |      43
         5 |   1.0000 |   0.9298 |      43
        10 |   0.9978 |   0.9386 |      37
        20 |   0.9956 |   0.9561 |      27
        30 |   0.9824 |   0.9561 |      22
        50 |   0.9736 |   0.9649 |      17
       100 |   0.9429 |   0.9298 |       9

Increasing min_samples_split from 2 to 50 improves test accuracy from 92.98% to 96.49% by preventing splits on small, noisy groups. At 100, underfitting sets in and test accuracy falls back to 92.98%.
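To see what the constraint does and does not guarantee, here is a small sketch (the value 20 is an arbitrary pick from the sweep) that checks the node sizes of a fitted tree. Every split node must hold at least `min_samples_split` samples, but the leaves produced by those splits can still be much smaller:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

ms = 20  # illustrative value taken from the sweep above
dt = DecisionTreeClassifier(min_samples_split=ms, random_state=42)
dt.fit(X_train, y_train)

tree = dt.tree_
is_leaf = tree.children_left == -1          # -1 marks a node without children
internal_sizes = tree.n_node_samples[~is_leaf]
leaf_sizes = tree.n_node_samples[is_leaf]

print(f"Smallest node that was split: {internal_sizes.min()} samples")
print(f"Smallest leaf:                {leaf_sizes.min()} samples")
```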

Pre-Pruning Strategy 3: min_samples_leaf

A leaf must contain at least min_samples_leaf samples. This is stricter than min_samples_split — it ensures both children of any split have enough samples.

python
for ml in [1, 2, 5, 10, 20]:
    dt = DecisionTreeClassifier(min_samples_leaf=ml, random_state=42)
    dt.fit(X_train, y_train)
    print(f"min_samples_leaf={ml:>3}: leaves={dt.get_n_leaves():>3}, test={dt.score(X_test, y_test):.4f}")
min_samples_leaf=  1: leaves= 43, test=0.9298
min_samples_leaf=  2: leaves= 32, test=0.9386
min_samples_leaf=  5: leaves= 19, test=0.9474
min_samples_leaf= 10: leaves= 12, test=0.9561
min_samples_leaf= 20: leaves=  7, test=0.9649

Each increase in min_samples_leaf removes small leaves and improves test accuracy up to a point.
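The difference between the two constraints shows up in the smallest leaf each one allows. The sketch below compares them at the same threshold; `smallest_leaf` is a hypothetical helper written for this illustration, not a scikit-learn function:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

def smallest_leaf(clf):
    """Hypothetical helper: sample count of the smallest leaf in a fitted tree."""
    t = clf.tree_
    return int(t.n_node_samples[t.children_left == -1].min())

results = {}
for k in [5, 10, 20]:
    by_split = DecisionTreeClassifier(min_samples_split=k, random_state=42)
    by_leaf = DecisionTreeClassifier(min_samples_leaf=k, random_state=42)
    by_split.fit(X_train, y_train)
    by_leaf.fit(X_train, y_train)
    results[k] = (smallest_leaf(by_split), smallest_leaf(by_leaf))
    print(f"k={k:>2}: smallest leaf with min_samples_split={results[k][0]:>3}, "
          f"with min_samples_leaf={results[k][1]:>3}")
```

With `min_samples_leaf=k` no leaf can fall below `k` samples, whereas `min_samples_split=k` only bounds the size of the node being split.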

Finding the Best Combination with GridSearchCV

python
from sklearn.model_selection import GridSearchCV

param_grid = {
    'max_depth': [3, 4, 5, 6, None],
    'min_samples_split': [2, 5, 10, 20],
    'min_samples_leaf': [1, 5, 10],
}
gs = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid, cv=10, scoring='accuracy', n_jobs=-1
)
gs.fit(X_train, y_train)
print(f"Best params: {gs.best_params_}")
print(f"Best CV accuracy: {gs.best_score_:.4f}")
print(f"Test accuracy: {gs.best_estimator_.score(X_test, y_test):.4f}")
Best params: {'max_depth': 4, 'min_samples_leaf': 5, 'min_samples_split': 10}
Best CV accuracy: 0.9647
Test accuracy: 0.9737

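The best pre-pruned tree: depth=4, min_samples_leaf=5, min_samples_split=10. The selected depth matches the max_depth sweep. The grid did not include the larger values (min_samples_split=50, min_samples_leaf=20) that scored best in the individual sweeps, so the other two settings are the best among the candidates tried.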
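A single "best params" line hides how close the runners-up are. As a sketch, assuming the same grid and 10-fold CV as above (run here without `n_jobs` for portability), the top candidates and the shape of the winning tree can be read from `cv_results_` and `best_estimator_`:

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

param_grid = {
    "max_depth": [3, 4, 5, 6, None],
    "min_samples_split": [2, 5, 10, 20],
    "min_samples_leaf": [1, 5, 10],
}
gs = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid, cv=10, scoring="accuracy",
)
gs.fit(X_train, y_train)

res = gs.cv_results_
# Indices of the five best-ranked parameter combinations (rank 1 is best).
order = np.argsort(res["rank_test_score"])[:5]
for i in order:
    print(f"{res['mean_test_score'][i]:.4f} +/- {res['std_test_score'][i]:.4f}  "
          f"{res['params'][i]}")

best = gs.best_estimator_
print(f"Best tree: depth={best.get_depth()}, leaves={best.get_n_leaves()}")
```

If several combinations sit within one standard deviation of the top score, the simplest of them is usually a reasonable choice.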

Post-Pruning: Cost Complexity Pruning (CCP)

Pre-pruning stops growth early. Post-pruning grows the full tree first, then removes subtrees that don't justify their complexity.

CART's cost complexity pruning adds a penalty per leaf. The effective tree is the one that minimizes:

$$R_\alpha(T) = R(T) + \alpha\,|T|$$

where $R(T)$ is the total leaf impurity and $|T|$ is the number of leaves. As $\alpha$ increases, leaves become more expensive, and subtrees get pruned until only the root remains.
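To make the trade-off concrete, the sketch below computes this cost (total leaf impurity plus alpha times the number of leaves) for two trees. `total_leaf_impurity` and `cost_complexity` are hypothetical helpers written for this illustration, and a depth-1 stump stands in for a heavily pruned tree:

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

def total_leaf_impurity(clf):
    """Impurity of each leaf, weighted by its share of the training samples."""
    t = clf.tree_
    is_leaf = t.children_left == -1
    weights = t.weighted_n_node_samples[is_leaf] / t.weighted_n_node_samples[0]
    return float(np.sum(weights * t.impurity[is_leaf]))

def cost_complexity(clf, alpha):
    """Total leaf impurity plus a penalty of alpha per leaf."""
    return total_leaf_impurity(clf) + alpha * clf.get_n_leaves()

full = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
stump = DecisionTreeClassifier(max_depth=1, random_state=42).fit(X_train, y_train)

for alpha in [0.0, 0.005, 0.1]:
    print(f"alpha={alpha:<5}: full tree cost={cost_complexity(full, alpha):.4f}, "
          f"stump cost={cost_complexity(stump, alpha):.4f}")
```

At alpha=0 the full tree wins because only impurity counts. As alpha grows, the per-leaf penalty dominates and the small tree becomes cheaper.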

python
dt_full = DecisionTreeClassifier(random_state=42)
path = dt_full.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas

print(f"Number of alpha values: {len(ccp_alphas)}")
print(f"Alpha range: [{ccp_alphas[0]:.6f}, {ccp_alphas[-1]:.4f}]")
Number of alpha values: 44
Alpha range: [0.000000, 0.4997]
python
clfs, train_scores, test_scores, leaf_counts = [], [], [], []
for alpha in ccp_alphas:
    clf = DecisionTreeClassifier(random_state=42, ccp_alpha=alpha)
    clf.fit(X_train, y_train)
    clfs.append(clf)
    train_scores.append(clf.score(X_train, y_train))
    test_scores.append(clf.score(X_test, y_test))
    leaf_counts.append(clf.get_n_leaves())

best_idx = np.argmax(test_scores)
print(f"Best alpha: {ccp_alphas[best_idx]:.6f}")
print(f"Best test accuracy: {test_scores[best_idx]:.4f}")
print(f"Leaves at best alpha: {leaf_counts[best_idx]}")
Best alpha: 0.005000
Best test accuracy: 0.9737
Leaves at best alpha: 12

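At $\alpha = 0.005$: 12 leaves, 97.37% test accuracy, matching the test accuracy of the pre-pruned GridSearch result. Because this alpha was picked by its test score, the 97.37% figure is optimistic; choosing alpha by cross-validation on the training set gives a fairer estimate.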
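The summary table later in this post recommends finding `ccp_alpha` via CV. One way to do that, sketched here under the assumption of 5-fold cross-validation on the training set, is to pass the candidate alphas from the pruning path to `GridSearchCV` and score the test set only once at the end:

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(
    X_train, y_train
)
# Drop the last alpha (it prunes the tree down to the root). The clip is a
# defensive guard against tiny negative values from floating-point error.
alphas = np.unique(np.clip(path.ccp_alphas[:-1], 0.0, None))

gs = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    {"ccp_alpha": alphas}, cv=5, scoring="accuracy",
)
gs.fit(X_train, y_train)

best = gs.best_estimator_
print(f"Alpha selected by CV: {gs.best_params_['ccp_alpha']:.6f}")
print(f"CV accuracy:          {gs.best_score_:.4f}")
print(f"Leaves:               {best.get_n_leaves()}")
# The test set is used once, after the alpha has been fixed.
print(f"Test accuracy:        {best.score(X_test, y_test):.4f}")
```

The alpha selected this way need not equal the one that maximizes test accuracy, and the reported test score may be somewhat lower as a result.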

[Figure: three panels. Left: Leaves vs ccp_alpha. Center: Accuracy vs ccp_alpha (train and test), with α=0.005 marked. Right: Accuracy vs Leaves, with the peak at 12 leaves marked.]

Left: leaf count decreases as $\alpha$ increases. Center: test accuracy (blue) peaks at $\alpha = 0.005$, train (orange) decreases monotonically. Right: test accuracy peaks at 12 leaves, then drops at both extremes.

Pre-Pruning vs Post-Pruning

| Aspect | Pre-Pruning | Post-Pruning (CCP) |
| --- | --- | --- |
| When applied | During tree building | After full tree is built |
| Parameters | max_depth, min_samples_* | ccp_alpha |
| Computational cost | Cheap: stops early | Expensive: builds full tree first |
| Risk | May stop splitting too early at depth limit | More principled: uses actual tree structure |
| Best test accuracy (this data) | 0.9737 | 0.9737 |

Pruning Parameter Summary

| Parameter | Small value | Large value | Typical range |
| --- | --- | --- | --- |
| max_depth | Deep (overfit) | Shallow (underfit) | 3–10 |
| min_samples_split | Many tiny nodes | Fewer splits | 10–20 |
| min_samples_leaf | Tiny leaves (overfit) | Large leaves (underfit) | 5–20 |
| ccp_alpha | No pruning | Heavily pruned | Find via CV |

Test Your Understanding

  1. The fully grown tree has 43 leaves for 455 training samples (~10 samples/leaf). If a leaf contains exactly 2 samples with different classes (1 Yes, 1 No), what is its entropy? What class does it predict, and what is its local training accuracy?

  2. At max_depth=4: train=97.14%, test=96.49%. At max_depth=5: train=98.24%, test=95.61%. Test accuracy drops from depth 4 to 5 even though training accuracy rises. Describe what changes in the tree between depth=4 and depth=5 that could explain this.

  3. CCP starts with the full tree (alpha=0) and increases alpha. At each alpha value, which subtrees are pruned first — the ones at the top of the tree or the leaves at the bottom? Why?

  4. Both pre-pruning (GridSearchCV) and post-pruning (CCP) found 97.37% test accuracy with approximately 12 leaves. Does this guarantee they found the same tree? What would you compare to check?

  5. You have 10,000 training samples and want to choose between min_samples_leaf=10 and min_samples_leaf=100. The dataset has a rare class that appears in only 2% of samples (200 examples). How does this class imbalance affect your choice of min_samples_leaf?
