KD-Tree and Ball Tree: KNN Optimization
Brute-force KNN computes the distance from every query to every training sample: $n$ distance computations per prediction. At one million training samples and 1000 features, that's $10^6 \times 10^3 = 10^9$ multiplications per query. The KD-tree and Ball tree reduce this by organizing training points spatially so that large portions of the dataset can be skipped without computing any distances to them.
Anchor dataset: 7 points from the house dataset, normalized for clean spatial construction.
```python
import numpy as np

X = np.array([
    [1.5, 2],  # A
    [2.0, 2],  # B
    [2.5, 3],  # C
    [3.5, 3],  # D
    [4.0, 4],  # E
    [4.5, 4],  # F
    [5.0, 5],  # G
])
q = np.array([3.0, 3])  # query point
```

The Problem: Brute Force Is O(n)
For each query, brute-force KNN:
- Computes the distance to all $n$ training samples: $O(nd)$
- Sorts by distance: $O(n \log n)$
- Returns the top $k$
At $n = 10^6$ training samples and $d = 1000$ features: $10^9$ multiplications per query. At 1 ns per multiplication, that is 1 second per prediction. For a production API handling 1000 queries/second, this is far too slow.
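The three brute-force steps can be sketched on the anchor dataset as follows. This is a minimal illustration, not a production implementation: the helper name `brute_force_knn` and the `labels` list are assumptions introduced here, and a stable sort is used so that tied distances keep dataset order.

```python
import numpy as np

# Anchor dataset from above (rows A..G) and the query point.
X = np.array([
    [1.5, 2], [2.0, 2], [2.5, 3], [3.5, 3],
    [4.0, 4], [4.5, 4], [5.0, 5],
])
labels = ["A", "B", "C", "D", "E", "F", "G"]  # illustrative names for the rows
q = np.array([3.0, 3.0])


def brute_force_knn(X, q, k):
    """Return (indices, distances) of the k rows of X nearest to q."""
    # Step 1: one Euclidean distance per training sample, O(n * d) work.
    dists = np.sqrt(((X - q) ** 2).sum(axis=1))
    # Step 2: sort by distance, O(n log n); stable so ties keep row order.
    order = np.argsort(dists, kind="stable")
    # Step 3: keep the top k.
    top = order[:k]
    return top, dists[top]


idx, dist = brute_force_knn(X, q, k=3)
for i, d_i in zip(idx, dist):
    print(labels[i], round(float(d_i), 3))
```

Every one of the 7 rows is touched regardless of where `q` lies, which is the cost the tree structures below try to avoid.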
The key insight: if we know a query is close to region R₁, we can skip all points in region R₂ that is far from the query — without computing individual distances to points in R₂.
Building the KD-Tree
A KD-tree partitions the space with axis-aligned hyperplanes. The construction:
- At depth $\ell$: split on dimension $(\ell \bmod d) + 1$ (cycling through the $d$ dimensions $x_1, \dots, x_d$)
- Split at the median along the chosen dimension
- Points left of median → left subtree; points right → right subtree
- Recurse until leaf nodes have at most `leaf_size` points
Building on 7 points, sorted by $x_1$ (first feature):
- Sorted: A(1.5), B(2.0), C(2.5), D(3.5), E(4.0), F(4.5), G(5.0)
- Median index = 3 (0-based) → D=(3.5, 3) is the root. Split: $x_1 \le 3.5$ → left, $x_1 > 3.5$ → right
Left subtree {A, B, C}, split on $x_2$ (next dimension):
- Sorted by $x_2$: A(2), B(2), C(3). Median = B=(2.0, 2). Split: $x_2 \le 2$ → left, $x_2 > 2$ → right
- Left child: A=(1.5, 2), a leaf. Right child: C=(2.5, 3), a leaf.
Right subtree {E, F, G}, split on $x_2$:
- Sorted by $x_2$: E(4), F(4), G(5). Median = F=(4.5, 4). Split: $x_2 \le 4$ → left, $x_2 > 4$ → right
- Left child: E=(4.0, 4), a leaf. Right child: G=(5.0, 5), a leaf.
```text
                 D (3.5, 3)                 ← split on x₁ = 3.5
                /          \
       B (2.0, 2)          F (4.5, 4)       ← split on x₂
        /      \            /      \
  A(1.5,2)  C(2.5,3)   E(4.0,4)  G(5.0,5)
```
[Figure: the KD-tree drawn as a diagram (root D = (3.5, 3) splitting on x₁ = 3.5; children B = (2.0, 2) and F = (4.5, 4) splitting on x₂ = 2 and x₂ = 4; leaves A, C, E, G), followed by the resulting partition of the plane with the seven points and the query q = (3, 3) marked.]
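In the partition plot, the blue vertical line is the root split (D) and the green horizontal lines are the left (B) and right (F) subtree splits. Together they carve the plane into four rectangular cells, one per leaf (A, C, E, G); the plot additionally labels the node regions for B and F.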
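The construction steps above can be sketched in a few lines. This is a minimal illustration under stated assumptions: the `Node` class and `build` function are names introduced here, each node stores exactly one point (so leaves hold a single point rather than up to `leaf_size`), and ties at the median are broken by Python's stable sort.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    name: str                       # point label, e.g. "D"
    point: tuple                    # coordinates (x1, x2)
    axis: int                       # 0-based dimension this node splits on
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def build(points, depth=0):
    """Build a KD-tree from a list of (name, coordinates) pairs."""
    if not points:
        return None
    axis = depth % len(points[0][1])                    # cycle through dimensions
    points = sorted(points, key=lambda p: p[1][axis])   # stable sort on that axis
    mid = len(points) // 2                              # median index
    name, point = points[mid]
    return Node(
        name, point, axis,
        left=build(points[:mid], depth + 1),
        right=build(points[mid + 1:], depth + 1),
    )


POINTS = [
    ("A", (1.5, 2)), ("B", (2.0, 2)), ("C", (2.5, 3)), ("D", (3.5, 3)),
    ("E", (4.0, 4)), ("F", (4.5, 4)), ("G", (5.0, 5)),
]
root = build(POINTS)
print(root.name, root.left.name, root.right.name)
print(root.left.left.name, root.left.right.name,
      root.right.left.name, root.right.right.name)
```

On the anchor data this reproduces the tree drawn above: D at the root, B and F below it, and A, C, E, G as leaves.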
KD-Tree Search for Nearest Neighbor of q=(3.0, 3)
Phase 1: Descend to find the leaf candidate
- At root D=(3.5, 3): $q_1 = 3.0 \le 3.5$ → go left
- At B=(2.0, 2): $q_2 = 3 > 2$ → go right
- Reach leaf C=(2.5, 3): compute $d(q, C) = \sqrt{(3.0 - 2.5)^2 + (3 - 3)^2} = 0.5$. Best so far: C at 0.5.
Phase 2: Backtrack and prune
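- Back at B=(2.0, 2): $d(q, B) = \sqrt{1^2 + 1^2} \approx 1.41$. B is not better. Check left child (A)?
  - A is in the region $x_2 \le 2$. The closest point in that region to $q$ is on the boundary $x_2 = 2$, at distance $|q_2 - 2| = 1.0 > 0.5$.
  - PRUNE: the entire left subtree of B is more than 0.5 away. Skip A.
- Back at D=(3.5, 3): $d(q, D) = |3.0 - 3.5| = 0.5$. D ties C, so the best distance stays 0.5.
- Check right subtree of D (F's subtree). Distance to the split plane ($x_1 = 3.5$) from $q$: $|3.0 - 3.5| = 0.5$. Cannot prune: the boundary is at distance 0.5, which is not strictly greater than the best distance. Descend into F.
  - At F=(4.5, 4): $d(q, F) = \sqrt{1.5^2 + 1^2} \approx 1.80$. F is not better. A node's own distance being worse than the best does not rule out its children; only the split-plane test does.
  - $q$ lies on E's side of F's split ($q_2 = 3 \le 4$), so E is checked: $d(q, E) = \sqrt{1^2 + 1^2} \approx 1.41$. E is not better.
  - G lies on the far side of F's split, at least $|q_2 - 4| = 1.0 > 0.5$ away. PRUNE: skip G.
Result: Nearest neighbor is C=(2.5, 3), tied with D=(3.5, 3). 5 distance computations (C, B, D, F, E) instead of 7; A and G are skipped.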
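The descend-then-backtrack search can be sketched as a short recursive function. This is an illustration under assumptions rather than a reference implementation: the names `build` and `nearest` and the nested-tuple tree layout are introduced here, each node holds one point, and the far side of a split is skipped only when the plane distance is strictly greater than the best distance so far. The function also records which points had their distance computed.

```python
import math

POINTS = {"A": (1.5, 2), "B": (2.0, 2), "C": (2.5, 3), "D": (3.5, 3),
          "E": (4.0, 4), "F": (4.5, 4), "G": (5.0, 5)}


def build(names, depth=0):
    """Return a nested tuple (name, axis, left, right), or None if empty."""
    if not names:
        return None
    axis = depth % 2                                   # two dimensions here
    names = sorted(names, key=lambda n: POINTS[n][axis])
    mid = len(names) // 2
    return (names[mid], axis,
            build(names[:mid], depth + 1),
            build(names[mid + 1:], depth + 1))


def nearest(node, q, best=None, visited=None):
    """Return (best, visited): best is (distance, name); visited lists the
    points whose distance to q was actually computed."""
    if visited is None:
        visited = []
    if node is None:
        return best, visited
    name, axis, left, right = node
    split = POINTS[name][axis]
    # Phase 1: descend first into the child on the query's side of the split.
    near, far = (left, right) if q[axis] <= split else (right, left)
    best, visited = nearest(near, q, best, visited)
    # Phase 2: while backtracking, measure this node's own point.
    dist = math.dist(q, POINTS[name])
    visited.append(name)
    if best is None or dist < best[0]:
        best = (dist, name)
    # Visit the far side unless the split plane is strictly farther than best.
    if not abs(q[axis] - split) > best[0]:
        best, visited = nearest(far, q, best, visited)
    return best, visited


tree = build(list(POINTS))
q = (3.0, 3.0)
best, visited = nearest(tree, q)
print("nearest:", best[1], "at distance", best[0])
print("distances computed for:", visited)
```

With the strict `>` test, the tie at D forces a visit to F's side of the tree, while A and G are never measured.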
The Pruning Condition
When backtracking at a node with split value $s$ on dimension $j$, prune the unexplored child if:

$$|q_j - s| > d_{\text{best}}$$
The distance from the query to the split hyperplane is a lower bound on the distance to any point in the unexplored subtree. If even the closest possible point in that subtree (on the boundary) is farther than the current best, skip the entire subtree.
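The lower-bound argument follows from dropping all but one coordinate: for any point $x$ beyond the plane, $\|q - x\| \ge |q_j - x_j| \ge |q_j - s|$. The sketch below applies the test to the two decisions made for $q = (3.0, 3)$ and checks the bound numerically on the points right of D's split. The function names `plane_distance` and `can_prune` are assumptions introduced for illustration.

```python
import math


def plane_distance(q, axis, split):
    """Distance from q to the axis-aligned hyperplane x[axis] = split."""
    return abs(q[axis] - split)


def can_prune(q, axis, split, best_dist):
    """True when no point beyond the split plane can beat best_dist."""
    return plane_distance(q, axis, split) > best_dist


q = (3.0, 3.0)
best_dist = 0.5  # distance to C, the best candidate found so far

# Far side of B's split (x2 = 2) holds A: plane distance is 1.0.
prune_a = can_prune(q, axis=1, split=2.0, best_dist=best_dist)
# Far side of D's split (x1 = 3.5) holds E, F, G: plane distance is 0.5.
prune_right = can_prune(q, axis=0, split=3.5, best_dist=best_dist)
print(prune_a, prune_right)

# The plane distance never exceeds the true distance to a point beyond it.
far_side = [(4.0, 4.0), (4.5, 4.0), (5.0, 5.0)]  # E, F, G
bound_holds = all(plane_distance(q, 0, 3.5) <= math.dist(q, p) for p in far_side)
print(bound_holds)
```

The second call returns `False` only because the comparison is strict; a rule using `>=` would also skip D's right subtree here, at the cost of possibly missing a tied neighbor.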
Ball Tree — Spherical Partitioning
KD-trees use axis-aligned boxes. Ball trees partition using nested hyperspheres (balls). Each node stores a center $c$ and radius $r$ such that all points $x$ in the subtree satisfy $\|x - c\| \le r$.
Pruning condition: prune a ball at center $c$, radius $r$, if:

$$\|q - c\| - r > d_{\text{best}}$$

The quantity $\|q - c\| - r$ is the minimum possible distance from $q$ to any point in the ball (zero if $q$ lies inside it). If this minimum exceeds the best distance found so far, the entire ball is skipped.
Ball trees tend to degrade less in high dimensions than KD-trees because the pruning condition is a direct distance comparison: it doesn't rely on axis-aligned bounding boxes, which become loose in high dimensions.
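As a rough illustration of the ball test, the sketch below wraps the points E, F, G in a single ball and checks whether it can be skipped for $q = (3.0, 3)$ with a best distance of 0.5. Using the centroid as the center is an assumption made here for simplicity (real implementations choose centers and splits their own way), and `make_ball` and `ball_lower_bound` are hypothetical helper names.

```python
import numpy as np


def make_ball(points):
    """Bounding ball: centroid as center, farthest member sets the radius."""
    center = points.mean(axis=0)
    radius = np.linalg.norm(points - center, axis=1).max()
    return center, radius


def ball_lower_bound(q, center, radius):
    """Smallest possible distance from q to any point inside the ball."""
    return max(0.0, float(np.linalg.norm(q - center) - radius))


ball_pts = np.array([[4.0, 4.0], [4.5, 4.0], [5.0, 5.0]])  # E, F, G
q = np.array([3.0, 3.0])
best_dist = 0.5  # distance to C from the KD-tree walkthrough

center, radius = make_ball(ball_pts)
bound = ball_lower_bound(q, center, radius)
true_min = float(np.linalg.norm(ball_pts - q, axis=1).min())

print("lower bound:", round(bound, 3), "true minimum:", round(true_min, 3))
print("prune ball:", bound > best_dist)
```

The bound is looser than the true minimum distance but still exceeds 0.5, so none of the three distances inside the ball would need to be computed.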
Speed Benchmark
```python
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import make_classification
import time

X_large, y_large = make_classification(n_samples=100000, n_features=20, random_state=42)
q_large = X_large[:100]  # reuse the first 100 training rows as queries

for algo in ['brute', 'kd_tree', 'ball_tree']:
    knn = KNeighborsClassifier(n_neighbors=5, algorithm=algo)
    knn.fit(X_large, y_large)  # tree construction happens here, outside the timer
    start = time.perf_counter()
    knn.predict(q_large)
    elapsed = time.perf_counter() - start
    print(f"{algo:12s}: {elapsed*1000:.2f}ms for 100 queries")
```

```text
brute       : 450.23ms for 100 queries
kd_tree     : 12.45ms for 100 queries
ball_tree   : 8.71ms for 100 queries
```
On this run: ~36× speedup for KD-tree and ~52× for Ball tree at $n = 100{,}000$, $d = 20$.
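`KNeighborsClassifier` builds these structures internally, but scikit-learn also exposes them directly as `sklearn.neighbors.KDTree` and `sklearn.neighbors.BallTree`. As a cross-check, the sketch below queries the 7-point anchor dataset with both. The `leaf_size=2` value is an arbitrary choice for this tiny example, and the library's internal tree layout is not the same as the hand-built tree above, so only the returned neighbors are comparable.

```python
import numpy as np
from sklearn.neighbors import BallTree, KDTree

X = np.array([
    [1.5, 2], [2.0, 2], [2.5, 3], [3.5, 3],
    [4.0, 4], [4.5, 4], [5.0, 5],
])
q = np.array([[3.0, 3.0]])  # query() expects a 2-D array, one row per query

results = {}
for Tree in (KDTree, BallTree):
    tree = Tree(X, leaf_size=2)
    dist, ind = tree.query(q, k=2)  # two nearest neighbors of the single query
    results[Tree.__name__] = (dist[0], ind[0])
    print(Tree.__name__, "distances:", dist[0], "row indices:", ind[0])
```

Both trees should return rows 2 and 3 (C and D) at distance 0.5; which of the two tied points comes first is not something to rely on.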
The Curse of Dimensionality — When KD-Tree Breaks Down
```python
import time
from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier

for d in [5, 10, 20, 50, 100]:
    X_hd, y_hd = make_classification(n_samples=10000, n_features=d, random_state=42)
    knn_kd = KNeighborsClassifier(n_neighbors=5, algorithm='kd_tree')
    knn_bf = KNeighborsClassifier(n_neighbors=5, algorithm='brute')
    knn_kd.fit(X_hd, y_hd)
    knn_bf.fit(X_hd, y_hd)
    q_hd = X_hd[:50]  # first 50 rows as queries

    # Time prediction only; fitting is excluded
    start = time.perf_counter()
    knn_kd.predict(q_hd)
    t_kd = time.perf_counter() - start

    start = time.perf_counter()
    knn_bf.predict(q_hd)
    t_bf = time.perf_counter() - start

    print(f"d={d:>4}: KD-tree={t_kd*1000:4.1f}ms, Brute={t_bf*1000:4.1f}ms, speedup={t_bf/t_kd:4.1f}x")
```

```text
d=   5: KD-tree= 0.4ms, Brute= 8.1ms, speedup=20.3x
d=  10: KD-tree= 1.2ms, Brute= 8.3ms, speedup= 6.9x
d=  20: KD-tree= 3.8ms, Brute= 8.9ms, speedup= 2.3x
d=  50: KD-tree= 7.1ms, Brute= 9.1ms, speedup= 1.3x
d= 100: KD-tree= 9.0ms, Brute= 9.4ms, speedup= 1.0x
```
At $d = 100$, the KD-tree gives essentially zero speedup. Why? In high dimensions, all pairwise distances concentrate around the same value, so the difference between the nearest and farthest neighbor shrinks toward nothing. With no clear "near" vs "far" distinction, the pruning condition almost never fires and the search degenerates to visiting all nodes.
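Distance concentration is easy to observe directly. The sketch below is a small experiment under assumptions chosen here (1000 points drawn uniformly from the unit hypercube, one random query, a fixed seed), not a measurement on the benchmark data above. It reports the ratio of nearest to farthest distance as the dimension grows; a ratio near 1 means "near" and "far" are almost indistinguishable.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the run is repeatable
n = 1000                        # number of random training points

ratios = {}
for d in [2, 10, 100, 1000]:
    X = rng.random((n, d))      # n points uniform in the unit hypercube
    q = rng.random(d)           # one random query point
    dists = np.linalg.norm(X - q, axis=1)
    ratios[d] = dists.min() / dists.max()
    print(f"d={d:>4}: nearest/farthest = {ratios[d]:.3f}")
```

As the ratio climbs toward 1, the best distance found so far is rarely smaller than the distance to a split plane or ball boundary, which is why the pruning tests stop firing.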
Comparison Table
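| Method | Build Time | Query Time (low $d$) | Query Time (high $d$) | Memory |
|---|---|---|---|---|
| Brute Force | $O(1)$ | $O(nd)$ | $O(nd)$ | $O(nd)$ |
| KD-Tree | $O(dn \log n)$ | $O(d \log n)$ ideal | Degrades to $O(nd)$ for large $d$ | $O(nd)$ |
| Ball Tree | $O(dn \log n)$ | $O(d \log n)$ ideal | Better than KD-tree for large $d$ | $O(nd)$ |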
sklearn's `algorithm='auto'` picks brute force when $d$ is large or $k$ is a large fraction of $n$; a KD-tree for small $d$; a Ball tree otherwise.
Test Your Understanding
- The KD-tree search above skipped part of the tree while finding the nearest neighbor of $q = (3.0, 3)$ among 7 points. In the worst case, a KD-tree still visits all nodes. Describe a query point configuration where the KD-tree degenerates to brute force.
- At D's right subtree, the distance to the split plane was $0.5$, equal to the best distance. The condition "$|q_1 - 3.5| > d_{\text{best}}$" was not satisfied, so the subtree was not pruned. If you move the query slightly to the left, to $q = (3.0 - \varepsilon, 3)$ for a small $\varepsilon > 0$, does the right subtree get pruned? Compute the distance to the split plane.
- Ball tree pruning condition: $\|q - c\| - r > d_{\text{best}}$. This assumes the triangle inequality holds: $d(q, x) \ge d(q, c) - d(c, x)$. Which distance metrics satisfy the triangle inequality? (Euclidean, Manhattan, Hamming, Cosine similarity?)
- At $d = 100$, KD-tree gives no speedup over brute force. A common fix is to reduce dimensions with PCA before building the KD-tree. If PCA reduces 100 features to 10 components while retaining 95% variance, what speedup would you expect for KD-tree on the reduced data?
- sklearn's `KNeighborsClassifier(algorithm='auto')` selects the algorithm automatically. If you have only a small number of samples and features, which algorithm would `auto` likely select, and why is brute force sometimes preferred over tree methods for small datasets?