Skip to content

SPADE

densitree.spade.SPADE(n_clusters=50, downsample_target=0.1, knn=5, n_micro=None, n_consensus=10, transform='arcsinh', cofactor=150.0, backend='matplotlib', density_estimator=None, random_state=None)

SPADE clustering with scikit-learn-compatible API.

Improved SPADE that combines density-dependent downsampling (for rare population preservation and tree construction) with consensus overclustering (for accurate cell assignment).

The algorithm:

  1. Density estimation (k-NN) on all cells.
  2. Consensus clustering over multiple runs:
     a. Overcluster all cells into n_micro microclusters (MiniBatchKMeans).
     b. Merge microclusters into n_clusters metaclusters using both Ward and average-linkage agglomerative clustering.
     c. Align labels across runs (Hungarian algorithm) and take a majority vote.
     d. Filter out low-agreement runs before voting.
  3. Density-dependent downsampling for tree construction.
  4. MST construction on metacluster centroids.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n_clusters` | `int` | Number of clusters (default 50). | `50` |
| `downsample_target` | `float` | Fraction of cells to retain for tree construction (default 0.1). | `0.1` |
| `knn` | `int` | k for k-NN density estimation (default 5). | `5` |
| `n_micro` | `int \| None` | Number of microclusters. None uses min(10 * n_clusters, n_cells // 10). | `None` |
| `n_consensus` | `int` | Number of MiniBatchKMeans runs per linkage type for consensus. Total runs = 2 * n_consensus (ward + average). Default 10. | `10` |
| `transform` | `str \| None` | 'arcsinh', 'log', or None. | `'arcsinh'` |
| `cofactor` | `float` | Arcsinh cofactor (default 150.0). | `150.0` |
| `backend` | `str` | Default plotting backend. | `'matplotlib'` |
| `density_estimator` | `BaseStep \| None` | Custom density estimator step. | `None` |
| `random_state` | `int \| None` | Seed for reproducibility. | `None` |
Source code in densitree/spade.py
def __init__(
    self,
    n_clusters: int = 50,
    downsample_target: float = 0.1,
    knn: int = 5,
    n_micro: int | None = None,
    n_consensus: int = 10,
    transform: str | None = "arcsinh",
    cofactor: float = 150.0,
    backend: str = "matplotlib",
    density_estimator: BaseStep | None = None,
    random_state: int | None = None,
) -> None:
    """Store the SPADE configuration; no data is processed here.

    Args:
        n_clusters: Number of clusters.
        downsample_target: Fraction of cells to retain for tree
            construction; must lie in (0, 1].
        knn: k passed to the default density estimator.
        n_micro: Number of microclusters, or None to let ``fit`` derive it.
        n_consensus: Number of consensus runs; ``fit`` uses a single
            clustering run when this is <= 1.
        transform: 'arcsinh', 'log', or None.
        cofactor: Arcsinh cofactor.
        backend: Default plotting backend.
        density_estimator: Custom density estimator step. When None, a
            ``DensityEstimator(knn=knn)`` is created.
        random_state: Seed for reproducibility.

    Raises:
        ValueError: If ``downsample_target`` is outside (0, 1].
    """
    if not 0 < downsample_target <= 1:
        raise ValueError(f"downsample_target must be in (0, 1], got {downsample_target}")

    self.n_clusters = n_clusters
    self.downsample_target = downsample_target
    self.knn = knn
    self.n_micro = n_micro
    self.n_consensus = n_consensus
    self.transform = transform
    self.cofactor = cofactor
    self.backend = backend
    self.random_state = random_state

    # Compare against None explicitly: a user-supplied estimator that happens
    # to evaluate falsy (e.g. defines __len__ or __bool__) must not be
    # silently swapped for the default.
    if density_estimator is None:
        density_estimator = DensityEstimator(knn=knn)
    self._density_estimator = density_estimator

    # Populated by fit().
    self.labels_: np.ndarray | None = None
    self.result_: SPADEResult | None = None

fit(X)

Fit SPADE to data.

Source code in densitree/spade.py
def fit(self, X: np.ndarray | pd.DataFrame) -> "SPADE":
    """Fit SPADE to data."""
    feature_names: list[str] | None = None
    if isinstance(X, pd.DataFrame):
        feature_names = list(X.columns)
        X = X.values

    X = np.asarray(X, dtype=float)
    n_cells, n_features = X.shape

    if n_cells < self.n_clusters:
        raise ValueError(
            f"n_clusters ({self.n_clusters}) must be <= number of cells ({n_cells})."
        )

    X_t = self._apply_transform(X)
    seed = self.random_state or 0

    # Step 1: Density estimation
    density_ctx = self._density_estimator.run(X_t)
    density = density_ctx["density"]

    # Step 2: Consensus clustering
    n_micro = self.n_micro
    if n_micro is None:
        n_micro = min(10 * self.n_clusters, n_cells // 10)
        n_micro = max(n_micro, self.n_clusters + 1)

    if self.n_consensus <= 1:
        labels = self._single_run_cluster(X_t, n_micro, seed)
    else:
        labels = self._consensus_cluster(X_t, n_micro, n_cells, seed)

    # Step 3: Density-dependent downsampling (for tree construction)
    down_ctx = DownsampleStep(
        downsample_target=self.downsample_target,
        random_state=seed,
    ).run(X_t, density=density)

    # Step 4: MST on metacluster centroids (original space for medians)
    centroids = np.array([
        X_t[labels == c].mean(axis=0)
        if (labels == c).sum() > 0 else np.zeros(n_features)
        for c in range(self.n_clusters)
    ])
    mst_ctx = MSTBuilder().run(X, centroids=centroids, labels_=labels)

    self.labels_ = labels
    self.result_ = SPADEResult(
        labels_=labels,
        tree_=mst_ctx["tree_"],
        X_down=down_ctx["X_down"],
        down_idx=down_ctx["down_idx"],
        n_features=n_features,
        feature_names=feature_names,
    )
    return self

fit_predict(X)

Fit and return cluster labels for all cells.

Source code in densitree/spade.py
def fit_predict(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
    """Fit and return cluster labels for all cells."""
    self.fit(X)
    return self.labels_