Source code for densitree.result

from __future__ import annotations
from dataclasses import dataclass, field

import numpy as np
import pandas as pd
import networkx as nx


@dataclass
class SPADEResult:
    """Rich output object from a SPADE run.

    Parameters
    ----------
    labels_ : ndarray[int], shape (n_cells,)
        Cluster assignment for every original cell.
    tree_ : networkx.Graph
        MST connecting cluster centroids. Each node has ``size`` (int)
        and ``median`` (ndarray) attributes.
    X_down : ndarray, shape (n_down, n_features)
        Downsampled cells used for clustering.
    down_idx : ndarray[int], shape (n_down,)
        Indices into the original array for the downsampled cells.
    n_features : int
        Number of features in the input data.
    feature_names : list[str] or None
        Feature names. Auto-generated as ``feature_0``, ``feature_1``, ...
        if ``None``.

    Attributes
    ----------
    cluster_stats_ : pandas.DataFrame
        Per-cluster summary table, built lazily on first access and then
        cached on the instance.
    """

    labels_: np.ndarray
    tree_: nx.Graph
    X_down: np.ndarray
    down_idx: np.ndarray
    n_features: int
    feature_names: list[str] | None = None
    # Lazily-populated cache for ``cluster_stats_``; kept out of repr/eq.
    _cluster_stats: pd.DataFrame | None = field(
        default=None, repr=False, compare=False
    )

    def __post_init__(self) -> None:
        """Fill in default feature names when none were supplied."""
        if self.feature_names is None:
            self.feature_names = [
                f"feature_{i}" for i in range(self.n_features)
            ]

    @property
    def cluster_stats_(self) -> pd.DataFrame:
        """Per-cluster statistics table, indexed by ``cluster``.

        The table is computed from ``tree_`` on first access and the same
        object is returned on every later access.
        """
        if self._cluster_stats is None:
            self._cluster_stats = self._build_stats()
        return self._cluster_stats

    def _build_stats(self) -> pd.DataFrame:
        """Build the per-cluster statistics table from the tree nodes.

        Returns
        -------
        pandas.DataFrame
            One row per tree node (in sorted node order), indexed by
            ``cluster``, with a ``size`` column and one
            ``median_<feature>`` column per feature name. Nodes without a
            ``size`` attribute get ``0``; nodes without a ``median``
            attribute get ``NaN`` medians. An empty tree yields an empty
            frame with the same columns and index name.
        """
        rows = []
        for node in sorted(self.tree_.nodes):
            attrs = self.tree_.nodes[node]
            row: dict = {"cluster": node, "size": attrs.get("size", 0)}
            median = attrs.get("median", np.full(self.n_features, np.nan))
            # zip stops at the shorter of the two sequences.
            for fname, val in zip(self.feature_names, median):
                row[f"median_{fname}"] = float(val)
            rows.append(row)

        if not rows:
            # pd.DataFrame([]) has no "cluster" column, so set_index would
            # raise KeyError; build the empty table with its schema instead.
            columns = [
                "size",
                *(f"median_{fname}" for fname in self.feature_names),
            ]
            return pd.DataFrame(
                columns=columns, index=pd.Index([], name="cluster")
            )
        return pd.DataFrame(rows).set_index("cluster")

    def plot_tree(
        self,
        color_by: int | str | None = None,
        size_by: str = "count",
        backend: str = "matplotlib",
    ):
        """Visualize the SPADE tree.

        Parameters
        ----------
        color_by:
            Feature index (int) or name (str) to color nodes by median
            expression. ``None`` colors all nodes the same.
        size_by:
            ``'count'`` scales node size by cell count. Any other value
            uses uniform size.
        backend:
            ``'matplotlib'`` for static plots, ``'plotly'`` for interactive.

        Returns
        -------
        object
            Whatever the selected backend's ``plot_tree`` function returns.

        Raises
        ------
        ImportError
            If ``backend='plotly'`` and the plotly backend cannot be
            imported.
        ValueError
            If ``backend`` is neither ``'matplotlib'`` nor ``'plotly'``.
        """
        # Backends are imported lazily so that the optional plotting
        # dependency is only needed when it is actually requested.
        if backend == "matplotlib":
            from densitree.plot.matplotlib import plot_tree as _plot
        elif backend == "plotly":
            try:
                from densitree.plot.plotly import plot_tree as _plot
            except ImportError as e:
                raise ImportError(
                    "plotly is required for backend='plotly'. "
                    "Install it with: pip install plotly"
                ) from e
        else:
            raise ValueError(
                f"Unknown backend '{backend}'. Use 'matplotlib' or 'plotly'."
            )
        return _plot(self, color_by=color_by, size_by=size_by)