# Source code for densitree.steps.mst

from __future__ import annotations
import numpy as np
import networkx as nx
from scipy.spatial.distance import cdist
from scipy.sparse.csgraph import minimum_spanning_tree
from scipy.sparse import csr_matrix
from .base import BaseStep


class MSTBuilder(BaseStep):
    """Build a minimum spanning tree connecting cluster centroids.

    Each node in the resulting ``networkx.Graph`` represents one cluster.
    Node attributes:

    - ``size``: number of cells assigned to that cluster
    - ``median``: per-feature median of cells in that cluster (ndarray);
      the cluster centroid is used instead when the cluster has no cells
      or when no data is supplied

    Edge weights are Euclidean distances between centroids.
    """

    def run(
        self,
        data: np.ndarray,
        *,
        centroids: np.ndarray,
        labels_: np.ndarray,
        **ctx,
    ) -> dict:
        """Compute the centroid MST and annotate its nodes.

        Parameters
        ----------
        data
            Observations indexed along the first axis by a boolean mask
            derived from ``labels_``. May be ``None``, in which case every
            node's ``median`` falls back to its centroid.
        centroids
            One row per cluster; passed to ``scipy.spatial.distance.cdist``,
            which requires a 2-D array.
        labels_
            Cluster id of each observation; ids are compared against
            ``range(len(centroids))``.
        **ctx
            Additional pipeline context; not used by this step.

        Returns
        -------
        dict
            ``{"tree_": G}`` where ``G`` is an undirected ``networkx.Graph``
            with one node per cluster and the MST edges between them.
        """
        n_clusters = len(centroids)

        G = nx.Graph()
        G.add_nodes_from(range(n_clusters))

        # A spanning tree only has edges when there are at least two nodes.
        if n_clusters > 1:
            # Pairwise distances between centroids
            dist_matrix = cdist(centroids, centroids, metric="euclidean")

            # SciPy's csgraph routines read a zero entry as "no edge", so two
            # coincident centroids (distance 0) would never be linked and the
            # result would be a forest instead of a tree. Substitute the
            # smallest positive normal float for off-diagonal zeros: it keeps
            # those edges present and cheapest without changing the relative
            # order of any other edge. The real weight is restored below.
            graph_weights = dist_matrix.copy()
            off_diagonal = ~np.eye(n_clusters, dtype=bool)
            zero_edges = off_diagonal & (dist_matrix == 0)
            graph_weights[zero_edges] = np.finfo(graph_weights.dtype).tiny

            # Compute MST on the full distance graph
            mst = minimum_spanning_tree(csr_matrix(graph_weights)).tocoo()

            # Walk only the stored MST entries rather than the dense matrix.
            # Each edge is normalised to (low, high) and sorted so insertion
            # order matches a row-major scan of the upper triangle.
            edges = sorted(
                (min(r, c), max(r, c), float(dist_matrix[r, c]))
                for r, c in zip(mst.row.tolist(), mst.col.tolist())
            )
            for i, j, weight in edges:
                G.add_edge(i, j, weight=weight)

        # Add node attributes
        labels = np.asarray(labels_)
        for cluster_id in range(n_clusters):
            mask = labels == cluster_id
            size = int(mask.sum())
            G.nodes[cluster_id]["size"] = size
            if size > 0 and data is not None:
                G.nodes[cluster_id]["median"] = np.median(data[mask], axis=0)
            else:
                # Empty cluster or no data: the centroid is the only
                # available representative point.
                G.nodes[cluster_id]["median"] = centroids[cluster_id]

        return {"tree_": G}