Tutorial: Comparing Conditions¶
A common use case for SPADE is comparing cell populations between experimental conditions (e.g., healthy vs. disease, pre vs. post treatment).
Strategy¶
- Fit SPADE on the combined dataset from all conditions
- Split cluster assignments by condition and compare population sizes
- Visualize differences on the tree
This ensures cells from different conditions are assigned to the same clusters, making comparison meaningful.
Example: Two conditions¶
import numpy as np
import pandas as pd
from densitree import SPADE
# Simulate two conditions with different rare population abundances
rng = np.random.default_rng(0)
# Shared populations
common_a = rng.normal(loc=[0, 0, 5, 3], scale=0.5, size=(2000, 4))
common_b = rng.normal(loc=[4, 4, 1, 6], scale=0.5, size=(2000, 4))
# Rare population: present in disease, nearly absent in healthy
rare_healthy = rng.normal(loc=[2, 8, 2, 2], scale=0.3, size=(20, 4))
rare_disease = rng.normal(loc=[2, 8, 2, 2], scale=0.3, size=(200, 4))
X_healthy = np.vstack([common_a, common_b, rare_healthy])
X_disease = np.vstack([common_a, common_b, rare_disease])
# Combine
X_combined = np.vstack([X_healthy, X_disease])
condition = np.array(
["healthy"] * len(X_healthy) + ["disease"] * len(X_disease)
)
print(f"Combined: {len(X_combined)} cells")
print(f" Healthy: {(condition == 'healthy').sum()}")
print(f" Disease: {(condition == 'disease').sum()}")
Fit SPADE on combined data¶
spade = SPADE(
n_clusters=15,
downsample_target=0.15,
transform=None, # data is already on a reasonable scale
random_state=42,
)
spade.fit(X_combined)
Compare cluster composition¶
labels = spade.labels_
# Count cells per cluster per condition
comparison = []
for cluster_id in range(15):
mask = labels == cluster_id
n_healthy = ((condition == "healthy") & mask).sum()
n_disease = ((condition == "disease") & mask).sum()
total = mask.sum()
# Fold change (disease / healthy), handling zeros
if n_healthy > 0:
fold_change = (n_disease / (condition == "disease").sum()) / \
(n_healthy / (condition == "healthy").sum())
else:
fold_change = float("inf")
comparison.append({
"cluster": cluster_id,
"healthy": n_healthy,
"disease": n_disease,
"total": total,
"fold_change": fold_change,
})
df_comp = pd.DataFrame(comparison).set_index("cluster")
print(df_comp.sort_values("fold_change", ascending=False))
Visualize on the tree¶
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import networkx as nx
tree = spade.result_.tree_
pos = nx.spring_layout(tree, seed=42, weight="weight")
# Color nodes by fold change
fold_changes = df_comp["fold_change"].values
# Cap infinite fold changes for visualization
fc_capped = np.clip(fold_changes, 0.1, 10)
log_fc = np.log2(fc_capped)
norm = plt.Normalize(vmin=-3, vmax=3)
colors = [cm.RdBu_r(norm(v)) for v in log_fc]
sizes = [tree.nodes[n].get("size", 1) for n in tree.nodes]
max_size = max(sizes) or 1
node_sizes = [s / max_size * 800 + 100 for s in sizes]
fig, ax = plt.subplots(figsize=(10, 8))
nx.draw_networkx(
tree, pos=pos, ax=ax,
node_size=node_sizes,
node_color=colors,
edge_color="gray",
with_labels=True,
font_size=8,
)
sm = cm.ScalarMappable(cmap=cm.RdBu_r, norm=norm)
sm.set_array([])
fig.colorbar(sm, ax=ax, label="log2(fold change disease/healthy)")
ax.set_title("SPADE Tree: Disease vs Healthy")
ax.axis("off")
fig.savefig("condition_comparison.png", dpi=150, bbox_inches="tight")
Blue nodes are enriched in healthy, red nodes in disease. The rare population cluster should appear as a bright red node.

Statistical testing (advanced)¶
For rigorous comparison with biological replicates, use a per-cluster test:
from scipy.stats import fisher_exact
for cluster_id in range(15):
mask = labels == cluster_id
n_h = ((condition == "healthy") & mask).sum()
n_d = ((condition == "disease") & mask).sum()
n_h_other = (condition == "healthy").sum() - n_h
n_d_other = (condition == "disease").sum() - n_d
table = [[n_h, n_d], [n_h_other, n_d_other]]
odds_ratio, p_value = fisher_exact(table)
if p_value < 0.05:
direction = "enriched in disease" if odds_ratio < 1 else "enriched in healthy"
print(f"Cluster {cluster_id}: p={p_value:.4f}, OR={odds_ratio:.2f} ({direction})")