Source code for densitree.steps.downsample

from __future__ import annotations
import numpy as np
from .base import BaseStep


[docs] class DownsampleStep(BaseStep): """Density-normalized downsampling. Cells in dense regions are sampled with lower probability so that rare populations (low density) are preserved after downsampling. Inclusion probability for cell i: ``p_i = min(1, target_count * w_i / sum(w))`` where ``w_i = 1 / density_i``. """ def __init__(self, downsample_target: float = 0.05, random_state: int | None = None) -> None: if not 0 < downsample_target <= 1: raise ValueError(f"downsample_target must be in (0, 1], got {downsample_target}") self.downsample_target = downsample_target self.random_state = random_state
[docs] def run(self, data: np.ndarray, *, density: np.ndarray, **ctx) -> dict: rng = np.random.default_rng(self.random_state) n = len(data) target_count = max(1, int(n * self.downsample_target)) weights = 1.0 / density probs = np.minimum(1.0, target_count * weights / weights.sum()) mask = rng.random(n) < probs down_idx = np.where(mask)[0] return {"X_down": data[down_idx], "down_idx": down_idx}