Skip to content

Plotting

Matplotlib backend

densitree.plot.matplotlib.plot_tree(result, color_by=None, size_by='count')

Draw the SPADE MST as a static matplotlib figure.

Parameters:

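| Name | Type | Description | Default |
|------|------|-------------|---------|
| `result` | `SPADEResult` | SPADE result whose `tree_` graph is drawn. | *required* |
| `color_by` | `int \| str \| None` | Feature index or name to color nodes by median expression. | `None` |
| `size_by` | `str` | `'count'` scales node area by cell count; any other value gives uniform node sizes. | `'count'` |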

Returns:

| Type | Description |
|------|-------------|
| `matplotlib.figure.Figure` | Figure containing the drawn tree. |
Source code in densitree/plot/matplotlib.py
def plot_tree(
    result: "SPADEResult",
    color_by: int | str | None = None,
    size_by: str = "count",
) -> plt.Figure:
    """Draw the SPADE MST as a static matplotlib figure.

    Nodes are positioned with a spring layout using a fixed seed, so the same
    tree is laid out the same way on every call, and each node is labelled.

    Parameters
    ----------
    result : SPADEResult
        Result object whose ``tree_`` graph is drawn.
    color_by : int | str | None
        Feature index or name to color nodes by median expression. When no
        feature is selected, all nodes are drawn in ``"steelblue"`` and no
        colorbar is added.
    size_by : str
        ``'count'`` scales node area by cell count; anything else → uniform.

    Returns
    -------
    matplotlib.figure.Figure
        Figure with a single axes holding the tree, plus a colorbar when a
        coloring feature is selected.

    Notes
    -----
    An empty tree yields an empty plot instead of raising ``ValueError``
    from ``max``/``min`` reductions over zero-size arrays.
    """
    G = result.tree_
    # Fixed seed keeps the layout reproducible across calls.
    pos = nx.spring_layout(G, seed=42, weight="weight")
    n_nodes = G.number_of_nodes()

    # Node sizes: the largest node gets area 900, a zero-size node gets 100.
    if size_by == "count":
        sizes = np.array([G.nodes[n].get("size", 1) for n in G.nodes], dtype=float)
        # `sizes.size` guards the empty tree (max() on a zero-size array
        # raises); `or 1` avoids dividing by zero when every size is 0.
        max_size = (sizes.max() if sizes.size else 0) or 1
        node_sizes = (sizes / max_size * 800 + 100).tolist()
    else:
        node_sizes = [300] * n_nodes

    # Node colors: map each node's median of the chosen feature onto viridis.
    feature_idx = _resolve_feature(color_by, result.feature_names)
    if feature_idx is not None:
        values = np.array([
            G.nodes[n]["median"][feature_idx]
            for n in G.nodes
        ])
        if values.size:
            norm = plt.Normalize(vmin=values.min(), vmax=values.max())
        else:
            # Empty tree: nothing to color, but the colorbar still needs a norm.
            norm = plt.Normalize(vmin=0.0, vmax=1.0)
        node_colors = [cm.viridis(norm(v)) for v in values]
    else:
        node_colors = ["steelblue"] * n_nodes

    fig, ax = plt.subplots(figsize=(8, 6))
    nx.draw_networkx(
        G, pos=pos, ax=ax,
        node_size=node_sizes,
        node_color=node_colors,
        edge_color="gray",
        with_labels=True,
        font_size=8,
    )

    if feature_idx is not None:
        # Colors were computed by hand, so build a mappable for the colorbar.
        sm = cm.ScalarMappable(cmap=cm.viridis, norm=norm)
        sm.set_array([])
        label = result.feature_names[feature_idx] if result.feature_names else str(feature_idx)
        fig.colorbar(sm, ax=ax, label=f"Median {label}")

    ax.set_title("SPADE Tree")
    ax.axis("off")
    return fig

Plotly backend

densitree.plot.plotly.plot_tree(result, color_by=None, size_by='count')

Draw the SPADE MST as an interactive plotly figure.

Parameters:

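| Name | Type | Description | Default |
|------|------|-------------|---------|
| `result` | `SPADEResult` | SPADE result whose `tree_` graph is drawn. | *required* |
| `color_by` | `int \| str \| None` | Feature index or name to color nodes by median expression. | `None` |
| `size_by` | `str` | `'count'` scales node size by cell count; any other value gives uniform node sizes. | `'count'` |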

Returns:

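| Type | Description |
|------|-------------|
| `plotly.graph_objects.Figure` | Interactive figure containing the drawn tree. |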
Source code in densitree/plot/plotly.py
def plot_tree(
    result: "SPADEResult",
    color_by: int | str | None = None,
    size_by: str = "count",
):
    """Draw the SPADE MST as an interactive plotly figure.

    Nodes are positioned with a spring layout using a fixed seed, so the same
    tree is laid out the same way on every call. Hovering a node shows its
    id and its ``size`` attribute.

    Parameters
    ----------
    result : SPADEResult
        Result object whose ``tree_`` graph is drawn.
    color_by : int | str | None
        Feature index or name to color nodes by median expression. When no
        feature is selected, all nodes are drawn in ``"steelblue"`` and no
        colorbar is shown.
    size_by : str
        ``'count'`` scales node size by cell count; anything else → uniform.

    Returns
    -------
    plotly.graph_objects.Figure
        Figure with one line trace for the edges and one marker trace for
        the nodes.

    Notes
    -----
    An empty tree yields an empty figure instead of raising ``ValueError``
    from ``max`` over a zero-size array.
    """
    # Imported lazily so plotly is only required when this backend is used.
    import plotly.graph_objects as go

    G = result.tree_
    # Fixed seed keeps the layout reproducible across calls.
    pos = nx.spring_layout(G, seed=42, weight="weight")
    n_nodes = G.number_of_nodes()

    feature_idx = _resolve_feature(color_by, result.feature_names)

    # Edge traces: all edges go into one trace; a None coordinate breaks the
    # line between consecutive segments.
    edge_x, edge_y = [], []
    for u, v in G.edges:
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        mode="lines",
        line=dict(width=1, color="lightgray"),
        hoverinfo="none",
    )

    # Node traces
    node_x = [pos[n][0] for n in G.nodes]
    node_y = [pos[n][1] for n in G.nodes]

    # Marker diameter: the largest node gets 50, a zero-size node gets 10.
    if size_by == "count":
        counts = np.array([G.nodes[n].get("size", 1) for n in G.nodes], dtype=float)
        # `counts.size` guards the empty tree (max() on a zero-size array
        # raises); `or 1` avoids dividing by zero when every count is 0.
        max_count = (counts.max() if counts.size else 0) or 1
        node_sizes = (counts / max_count * 40 + 10).tolist()
    else:
        node_sizes = [20] * n_nodes

    if feature_idx is not None:
        # Numeric colors plus a colorscale let plotly draw a colorbar.
        node_colors = [G.nodes[n]["median"][feature_idx] for n in G.nodes]
        colorscale = "Viridis"
        label = result.feature_names[feature_idx] if result.feature_names else str(feature_idx)
        colorbar = dict(title=f"Median {label}")
    else:
        node_colors = ["steelblue"] * n_nodes
        colorscale = None
        colorbar = None

    hover_text = [
        f"Cluster {n}<br>Size: {G.nodes[n].get('size', '?')}"
        for n in G.nodes
    ]

    marker_kwargs: dict = dict(
        size=node_sizes,
        color=node_colors,
        line=dict(width=1, color="white"),
    )
    # Colorscale/colorbar only apply when colors are numeric feature values.
    if colorscale:
        marker_kwargs["colorscale"] = colorscale
        marker_kwargs["showscale"] = True
        marker_kwargs["colorbar"] = colorbar

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode="markers+text",
        text=[str(n) for n in G.nodes],
        textposition="top center",
        hovertext=hover_text,
        hoverinfo="text",
        marker=marker_kwargs,
    )

    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            title="SPADE Tree",
            showlegend=False,
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            plot_bgcolor="white",
        ),
    )
    return fig