Source code for densitree.plot.matplotlib
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.cm as cm
if TYPE_CHECKING:
from densitree.result import SPADEResult
[docs]
def plot_tree(
    result: "SPADEResult",
    color_by: int | str | None = None,
    size_by: str = "count",
) -> plt.Figure:
    """Draw the SPADE MST as a static matplotlib figure.

    Parameters
    ----------
    result : SPADEResult
        Result object; ``result.tree_`` is the graph that gets drawn and
        ``result.feature_names`` is used to resolve ``color_by`` and to
        label the colorbar.
    color_by : int | str | None
        Feature index or name to color nodes by median expression (read
        from each node's ``"median"`` attribute). ``None`` draws every
        node in a single flat color and omits the colorbar.
    size_by : str
        ``'count'`` scales node area by cell count (each node's ``"size"``
        attribute, defaulting to 1 when absent); anything else → uniform.

    Returns
    -------
    matplotlib.figure.Figure
        A new figure. An empty tree yields an empty, titled figure without
        a colorbar instead of raising.

    Raises
    ------
    ValueError
        If ``color_by`` is a name that cannot be found in
        ``result.feature_names``.
    """
    G = result.tree_
    n_nodes = G.number_of_nodes()
    # A fixed seed keeps the layout identical between calls.
    pos = nx.spring_layout(G, seed=42, weight="weight")

    # Node sizes
    if size_by == "count":
        sizes = np.array([G.nodes[n].get("size", 1) for n in G.nodes])
        # ``ndarray.max`` raises on a zero-size array, so an empty tree is
        # special-cased; ``or 1`` additionally avoids dividing by zero when
        # every node reports a size of 0.
        max_size = (sizes.max() if sizes.size else 0) or 1
        node_sizes = (sizes / max_size * 800 + 100).tolist()
    else:
        node_sizes = [300] * n_nodes

    # Node colors
    # ``color_by`` is validated even for an empty tree so a bad feature name
    # is still reported.
    feature_idx = _resolve_feature(color_by, result.feature_names)
    norm = None
    if feature_idx is not None and n_nodes > 0:
        values = np.array([
            G.nodes[n]["median"][feature_idx]
            for n in G.nodes
        ])
        norm = plt.Normalize(vmin=values.min(), vmax=values.max())
        node_colors = [cm.viridis(norm(v)) for v in values]
    else:
        node_colors = ["steelblue"] * n_nodes

    fig, ax = plt.subplots(figsize=(8, 6))
    nx.draw_networkx(
        G, pos=pos, ax=ax,
        node_size=node_sizes,
        node_color=node_colors,
        edge_color="gray",
        with_labels=True,
        font_size=8,
    )

    # Only add a colorbar when a color scale was actually built above.
    if norm is not None:
        # This mappable exists solely to drive the colorbar; it holds no data.
        sm = cm.ScalarMappable(cmap=cm.viridis, norm=norm)
        sm.set_array([])
        label = result.feature_names[feature_idx] if result.feature_names else str(feature_idx)
        fig.colorbar(sm, ax=ax, label=f"Median {label}")

    ax.set_title("SPADE Tree")
    ax.axis("off")
    return fig
def _resolve_feature(color_by: int | str | None, feature_names: list[str] | None) -> int | None:
    """Translate a ``color_by`` argument into a feature index.

    Parameters
    ----------
    color_by : int | str | None
        ``None`` (no coloring), an integer index (Python ``int`` or a NumPy
        integer scalar), or a feature name.
    feature_names : list[str] | None
        Names used to look up a string ``color_by``; may be ``None`` or empty.

    Returns
    -------
    int | None
        ``None`` when ``color_by`` is ``None``; otherwise the index as a
        plain ``int``. Integer inputs are returned without range checking.

    Raises
    ------
    ValueError
        If ``color_by`` is not an integer and is not found in
        ``feature_names``.
    """
    if color_by is None:
        return None
    # NumPy integer scalars (e.g. ``np.int64``) are not ``int`` subclasses;
    # without this they would fall through to the name lookup and be
    # reported as an unknown feature.
    if isinstance(color_by, (int, np.integer)):
        return int(color_by)
    if feature_names and color_by in feature_names:
        return feature_names.index(color_by)
    raise ValueError(f"Feature '{color_by}' not found in feature_names={feature_names}")