from __future__ import annotations
import numpy as np
import pandas as pd
from sklearn.cluster import MiniBatchKMeans, AgglomerativeClustering
from sklearn.metrics import adjusted_rand_score as _inter_run_ari
from scipy.optimize import linear_sum_assignment
from .result import SPADEResult
from .steps.density import DensityEstimator
from .steps.downsample import DownsampleStep
from .steps.mst import MSTBuilder
from .steps.base import BaseStep
class SPADE:
"""SPADE clustering with scikit-learn-compatible API.
Improved SPADE that combines density-dependent downsampling (for rare
population preservation and tree construction) with consensus
overclustering (for accurate cell assignment).
The algorithm:
1. **Density estimation** (k-NN) on all cells.
2. **Consensus clustering** over multiple runs (overcluster into
``n_micro`` microclusters, merge into ``n_clusters`` metaclusters,
align labels via Hungarian algorithm, filter low-agreement runs,
take majority vote).
3. **Density-dependent downsampling** for tree construction.
4. **MST construction** on metacluster centroids.
Parameters
----------
n_clusters : int
Number of clusters (default 50).
downsample_target : float
Fraction of cells to retain for tree construction (default 0.1).
knn : int
k for k-NN density estimation (default 5).
n_micro : int | None
Number of microclusters. ``None`` uses ``min(10 * n_clusters, n_cells // 10)``.
n_consensus : int
Number of MiniBatchKMeans runs per linkage type for consensus.
Total runs = 2 * n_consensus (ward + average). Default 10.
transform : str | None
``'arcsinh'``, ``'log'``, or ``None``.
cofactor : float
Arcsinh cofactor (default 150.0).
backend : str
Default plotting backend.
density_estimator : BaseStep | None
Custom density estimator step.
random_state : int | None
Seed for reproducibility.
"""
def __init__(
self,
n_clusters: int = 50,
downsample_target: float = 0.1,
knn: int = 5,
n_micro: int | None = None,
n_consensus: int = 10,
transform: str | None = "arcsinh",
cofactor: float = 150.0,
backend: str = "matplotlib",
density_estimator: BaseStep | None = None,
random_state: int | None = None,
) -> None:
if not 0 < downsample_target <= 1:
raise ValueError(f"downsample_target must be in (0, 1], got {downsample_target}")
self.n_clusters = n_clusters
self.downsample_target = downsample_target
self.knn = knn
self.n_micro = n_micro
self.n_consensus = n_consensus
self.transform = transform
self.cofactor = cofactor
self.backend = backend
self.random_state = random_state
self._density_estimator = density_estimator or DensityEstimator(knn=knn)
self.labels_: np.ndarray | None = None
self.result_: SPADEResult | None = None
[docs]
def fit(self, X: np.ndarray | pd.DataFrame) -> "SPADE":
"""Fit SPADE to data."""
feature_names: list[str] | None = None
if isinstance(X, pd.DataFrame):
feature_names = list(X.columns)
X = X.values
X = np.asarray(X, dtype=float)
n_cells, n_features = X.shape
if n_cells < self.n_clusters:
raise ValueError(
f"n_clusters ({self.n_clusters}) must be <= number of cells ({n_cells})."
)
X_t = self._apply_transform(X)
seed = self.random_state or 0
# Step 1: Density estimation
density_ctx = self._density_estimator.run(X_t)
density = density_ctx["density"]
# Step 2: Consensus clustering
n_micro = self.n_micro
if n_micro is None:
n_micro = min(10 * self.n_clusters, n_cells // 10)
n_micro = max(n_micro, self.n_clusters + 1)
if self.n_consensus <= 1:
labels = self._single_run_cluster(X_t, n_micro, seed)
else:
labels = self._consensus_cluster(X_t, n_micro, n_cells, seed)
# Step 3: Density-dependent downsampling (for tree construction)
down_ctx = DownsampleStep(
downsample_target=self.downsample_target,
random_state=seed,
).run(X_t, density=density)
# Step 4: MST on metacluster centroids (original space for medians)
centroids = np.array([
X_t[labels == c].mean(axis=0)
if (labels == c).sum() > 0 else np.zeros(n_features)
for c in range(self.n_clusters)
])
mst_ctx = MSTBuilder().run(X, centroids=centroids, labels_=labels)
self.labels_ = labels
self.result_ = SPADEResult(
labels_=labels,
tree_=mst_ctx["tree_"],
X_down=down_ctx["X_down"],
down_idx=down_ctx["down_idx"],
n_features=n_features,
feature_names=feature_names,
)
return self
def _single_run_cluster(
self, X_t: np.ndarray, n_micro: int, seed: int,
) -> np.ndarray:
"""Single-run two-stage clustering."""
micro = MiniBatchKMeans(
n_clusters=n_micro, random_state=seed,
batch_size=min(2048, len(X_t)), n_init=5,
)
micro_labels = micro.fit_predict(X_t)
meta = AgglomerativeClustering(
n_clusters=self.n_clusters, linkage="average",
)
meta_of_micro = meta.fit_predict(micro.cluster_centers_)
return meta_of_micro[micro_labels]
def _consensus_cluster(
self, X_t: np.ndarray, n_micro: int, n_cells: int, base_seed: int,
) -> np.ndarray:
"""Consensus clustering: multiple overclustering runs with both
ward and average linkage, aligned via Hungarian algorithm, filtered
by inter-run agreement, and combined via majority vote.
"""
# Generate candidate label arrays
runs: list[np.ndarray] = []
for r in range(self.n_consensus):
micro = MiniBatchKMeans(
n_clusters=n_micro, random_state=base_seed + r,
batch_size=min(1024, n_cells), n_init=5,
)
micro_labels = micro.fit_predict(X_t)
centroids = micro.cluster_centers_
for linkage in ("ward", "average"):
meta = AgglomerativeClustering(
n_clusters=self.n_clusters, linkage=linkage,
)
meta_of_micro = meta.fit_predict(centroids)
runs.append(meta_of_micro[micro_labels])
total_runs = len(runs)
# Pairwise agreement (ARI between runs)
ari_matrix = np.zeros((total_runs, total_runs))
for i in range(total_runs):
for j in range(i + 1, total_runs):
a = _inter_run_ari(runs[i], runs[j])
ari_matrix[i, j] = a
ari_matrix[j, i] = a
mean_agreement = ari_matrix.sum(axis=1) / (total_runs - 1)
# Keep top 60% of runs by agreement
threshold = np.percentile(mean_agreement, 40)
good_idx = np.where(mean_agreement >= threshold)[0]
# Use highest-agreement run as alignment reference
ref_idx = good_idx[mean_agreement[good_idx].argmax()]
ref = runs[ref_idx]
# Align and vote
aligned: list[np.ndarray] = []
for r in good_idx:
if r == ref_idx:
aligned.append(runs[r])
continue
# Build confusion matrix and solve assignment
conf = np.zeros((self.n_clusters, self.n_clusters), dtype=int)
for i in range(n_cells):
conf[ref[i], runs[r][i]] += 1
_, col_idx = linear_sum_assignment(-conf)
mapping = np.zeros(self.n_clusters, dtype=int)
mapping[col_idx] = np.arange(self.n_clusters)
aligned.append(mapping[runs[r]])
# Majority vote
votes = np.zeros((n_cells, self.n_clusters), dtype=np.int32)
for labels_r in aligned:
for c in range(self.n_clusters):
votes[:, c] += (labels_r == c).astype(np.int32)
return votes.argmax(axis=1)
[docs]
def fit_predict(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
"""Fit and return cluster labels for all cells."""
self.fit(X)
return self.labels_
def _apply_transform(self, X: np.ndarray) -> np.ndarray:
match self.transform:
case "arcsinh":
return np.arcsinh(X / self.cofactor)
case "log":
return np.log1p(np.clip(X, 0, None))
case None:
return X.copy()
case _:
raise ValueError(f"Unknown transform '{self.transform}'. Use 'arcsinh', 'log', or None.")