group-wbl/.venv/lib/python3.13/site-packages/sklearn/tree/tests/test_fenwick.py

52 lines
1.7 KiB
Python
Raw Permalink Normal View History

2026-01-09 09:48:03 +08:00
import numpy as np
from sklearn.tree._utils import PytestWeightedFenwickTree
def test_cython_weighted_fenwick_tree(global_random_seed):
"""
Test Cython's weighted Fenwick tree implementation
"""
rng = np.random.default_rng(global_random_seed)
n = 100
indices = rng.permutation(n)
y = rng.normal(size=n)
w = rng.integers(0, 4, size=n)
y_included_so_far = np.zeros_like(y)
w_included_so_far = np.zeros_like(w)
tree = PytestWeightedFenwickTree(n)
tree.py_reset(n)
for i in range(n):
idx = indices[i]
tree.py_add(idx, y[idx], w[idx])
y_included_so_far[idx] = y[idx]
w_included_so_far[idx] = w[idx]
target = rng.uniform(0, w_included_so_far.sum())
t_idx_low, t_idx, cw, cwy = tree.py_search(target)
# check the aggregates are consistent with the returned idx
assert np.isclose(cw, np.sum(w_included_so_far[:t_idx]))
assert np.isclose(
cwy, np.sum(w_included_so_far[:t_idx] * y_included_so_far[:t_idx])
)
# check if the cumulative weight is less than or equal to the target
# depending on t_idx_low and t_idx
if t_idx_low == t_idx:
assert cw < target
else:
assert cw == target
# check that if we add the next non-null weight, we are above the target:
next_weights = w_included_so_far[t_idx:][w_included_so_far[t_idx:] > 0]
if next_weights.size > 0:
assert cw + next_weights[0] > target
# and not below the target for `t_idx_low`:
next_weights = w_included_so_far[t_idx_low:][w_included_so_far[t_idx_low:] > 0]
if next_weights.size > 0:
assert cw + next_weights[0] >= target