# Source code for densitree.steps.cluster

from __future__ import annotations
import numpy as np
from sklearn.cluster import AgglomerativeClustering, MiniBatchKMeans
from .base import BaseStep


class ClusterStep(BaseStep):
    """Two-stage clustering on the downsampled cell set.

    Stage 1: Overcluster into ``n_micro`` microclusters using
    MiniBatchKMeans (fast, sees all downsampled cells, captures fine
    structure).

    Stage 2: Merge microclusters into ``n_clusters`` metaclusters using
    agglomerative clustering on the microcluster centroids.

    This approach produces much better cluster boundaries than single-stage
    agglomerative clustering because MiniBatchKMeans scales linearly and
    produces stable microclusters, agglomerative merging on centroids is
    fast, and upsampling to fine-grained microcluster centroids dramatically
    improves cell assignment accuracy.

    Returns micro-level and meta-level labels plus centroids for both.

    Parameters
    ----------
    n_clusters:
        Number of metaclusters to produce.
    n_micro:
        Number of microclusters for stage 1. When ``None`` it is derived in
        :meth:`run` as ``min(10 * n_clusters, n_down // 5)``, but never less
        than ``n_clusters``.
    linkage:
        Linkage criterion passed to the stage-2 agglomerative merge. The
        single-stage fallback does not use it (it always uses ``"ward"``).
    random_state:
        Seed passed to MiniBatchKMeans. The default ``0`` reproduces the
        previously hard-coded seed.
    """

    def __init__(
        self,
        n_clusters: int = 50,
        n_micro: int | None = None,
        linkage: str = "average",
        random_state: int | None = 0,
    ) -> None:
        self.n_clusters = n_clusters
        self.n_micro = n_micro
        self.linkage = linkage
        self.random_state = random_state

    def run(self, data: np.ndarray, *, X_down: np.ndarray, **ctx) -> dict:
        """Cluster ``X_down`` and return labels and centroids.

        Parameters
        ----------
        data:
            Not used by this step; accepted to match the step call signature.
        X_down:
            Downsampled cells, indexed by row (one row per cell).
        **ctx:
            Additional pipeline context; ignored here.

        Returns
        -------
        dict
            ``cluster_labels_down``: metacluster label of every row of
            ``X_down``; ``centroids``: one centroid per metacluster;
            ``micro_centroids``: one centroid per microcluster;
            ``micro_to_meta``: metacluster label of every microcluster.
        """
        n_down = len(X_down)

        # Determine number of microclusters.
        n_micro = self.n_micro
        if n_micro is None:
            # Heuristic: 10x n_clusters, capped by data size.
            n_micro = min(10 * self.n_clusters, n_down // 5)
            n_micro = max(n_micro, self.n_clusters)

        if n_micro <= self.n_clusters or n_down <= n_micro:
            # Overclustering would not add resolution (or there are too few
            # cells for it): fall back to single-stage agglomerative.
            return self._single_stage(X_down)

        # Stage 1: overclustering with MiniBatchKMeans.
        micro_model = MiniBatchKMeans(
            n_clusters=n_micro,
            random_state=self.random_state,
            batch_size=min(1024, n_down),
            n_init=3,
        )
        micro_labels = micro_model.fit_predict(X_down)
        micro_centroids = micro_model.cluster_centers_

        # Stage 2: merge microclusters into metaclusters.
        meta_model = AgglomerativeClustering(
            n_clusters=self.n_clusters,
            linkage=self.linkage,
        )
        meta_of_micro = meta_model.fit_predict(micro_centroids)

        # Map downsampled cells: micro -> meta.
        meta_labels = meta_of_micro[micro_labels]

        # Compute metacluster centroids from the downsampled data. A
        # metacluster built only from microclusters that ended up with no
        # assigned cells has no members here; averaging an empty slice would
        # give a NaN centroid, so use its microcluster centroids instead.
        meta_centroid_rows = []
        for i in range(self.n_clusters):
            members = X_down[meta_labels == i]
            if len(members) > 0:
                meta_centroid_rows.append(members.mean(axis=0))
            else:
                meta_centroid_rows.append(
                    micro_centroids[meta_of_micro == i].mean(axis=0)
                )
        meta_centroids = np.array(meta_centroid_rows)

        return {
            "cluster_labels_down": meta_labels,
            "centroids": meta_centroids,
            "micro_centroids": micro_centroids,
            "micro_to_meta": meta_of_micro,
        }

    def _single_stage(self, X_down: np.ndarray) -> dict:
        """Cluster ``X_down`` directly with agglomerative clustering.

        Used by :meth:`run` when two-stage clustering is not applicable.
        Returns the same keys as :meth:`run`; the metacluster centroids
        double as the microcluster centroids and ``micro_to_meta`` is the
        identity mapping.
        """
        # NOTE(review): linkage is fixed to "ward" here and self.linkage is
        # ignored -- confirm this is intentional.
        model = AgglomerativeClustering(
            n_clusters=self.n_clusters,
            linkage="ward",
        )
        labels = model.fit_predict(X_down)
        centroids = np.array([
            X_down[labels == i].mean(axis=0)
            for i in range(self.n_clusters)
        ])
        return {
            "cluster_labels_down": labels,
            "centroids": centroids,
            "micro_centroids": centroids,
            "micro_to_meta": np.arange(self.n_clusters),
        }