summaryrefslogtreecommitdiff
path: root/scripts/generate_plots.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-02-24 08:40:49 +0000
committerYurenHao0426 <blackhao0426@gmail.com>2026-02-24 08:40:49 +0000
commit8f63cf9f41bbdb8d55cd4679872d2b4ae2129324 (patch)
treeab5c95888849e854f2346db856c7edece7c8b8a7 /scripts/generate_plots.py
EC-SBM community detection analysis: full pipeline and writeup
Implement community detection on 3 EC-SBM networks (polblogs, topology, internet_as) using 5 methods (Leiden-Mod, Leiden-CPM at 0.1 and 0.01, Infomap, graph-tool SBM). Compute AMI/ARI/NMI accuracy, cluster statistics, and generate figures and LaTeX report.
Diffstat (limited to 'scripts/generate_plots.py')
-rw-r--r--scripts/generate_plots.py324
1 files changed, 324 insertions, 0 deletions
diff --git a/scripts/generate_plots.py b/scripts/generate_plots.py
new file mode 100644
index 0000000..be5db9f
--- /dev/null
+++ b/scripts/generate_plots.py
@@ -0,0 +1,324 @@
+"""Generate all figures and LaTeX tables for the EC-SBM analysis."""
+
+import sys
+import os
+import json
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+import numpy as np
+import pandas as pd
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+from matplotlib.colors import Normalize
+from matplotlib.cm import ScalarMappable
+
+from config import NETWORKS, METHODS, RESULTS_DIR
+
# All figures and LaTeX tables produced by this script land under
# <RESULTS_DIR>/figures; create it eagerly so every generator can write freely.
FIGURES_DIR = os.path.join(RESULTS_DIR, "figures")
os.makedirs(FIGURES_DIR, exist_ok=True)

# Method identifiers, in the order declared in config.METHODS; this order is
# reused for heatmap columns, subplot panels, and bar-group positions.
METHOD_NAMES = [m["name"] for m in METHODS]
# Human-readable method names for axis labels, legends, and table rows.
METHOD_LABELS = {
    "leiden_mod": "Leiden-Mod",
    "leiden_cpm_01": "Leiden-CPM(0.1)",
    "leiden_cpm_001": "Leiden-CPM(0.01)",
    "infomap": "Infomap",
    "graphtool_sbm": "graph-tool SBM",
}
# Network display names with underscores backslash-escaped for LaTeX text mode
# (matplotlib labels do not need the escaping, LaTeX tables do).
NET_LABELS = {
    "polblogs": "polblogs",
    "topology": "topology",
    "internet_as": "internet\\_as",
}
+
+
def plot_accuracy_heatmap():
    """Render one network-by-method heatmap per accuracy metric (AMI/ARI/NMI).

    Reads the accuracy CSV from the results directory and writes
    ``heatmap_<metric>.pdf`` files into FIGURES_DIR. Skips silently (with a
    console note) when the CSV has not been produced yet.
    """
    acc_path = os.path.join(RESULTS_DIR, "accuracy", "accuracy_table.csv")
    if not os.path.exists(acc_path):
        print("No accuracy table found, skipping heatmap")
        return
    df = pd.read_csv(acc_path)

    for metric in ("ami", "ari", "nmi"):
        fig, ax = plt.subplots(figsize=(8, 3.5))
        # Pivot to network rows x method columns, forcing the canonical order
        # from config (missing combinations become NaN).
        grid = (
            df.pivot(index="network", columns="method", values=metric)
            .reindex(index=list(NETWORKS.keys()), columns=METHOD_NAMES)
        )

        image = ax.imshow(grid.values, cmap="YlOrRd", aspect="auto",
                          vmin=0, vmax=1)
        ax.set_xticks(range(len(METHOD_NAMES)))
        ax.set_xticklabels(
            [METHOD_LABELS.get(m, m) for m in METHOD_NAMES],
            rotation=30, ha="right", fontsize=9,
        )
        ax.set_yticks(range(len(NETWORKS)))
        ax.set_yticklabels(list(NETWORKS.keys()), fontsize=10)

        # Annotate each populated cell; dark (high-value) cells get white text.
        n_rows, n_cols = grid.shape
        for r in range(n_rows):
            for c in range(n_cols):
                cell = grid.values[r, c]
                if np.isnan(cell):
                    continue
                ax.text(c, r, f"{cell:.3f}", ha="center", va="center",
                        fontsize=9, color="black" if cell < 0.6 else "white")

        plt.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
        ax.set_title(f"{metric.upper()} Accuracy", fontsize=12)
        plt.tight_layout()
        plt.savefig(os.path.join(FIGURES_DIR, f"heatmap_{metric}.pdf"),
                    bbox_inches="tight")
        plt.close()
        print(f" Saved heatmap_{metric}.pdf")
+
+
def plot_cluster_size_distributions():
    """Per-network 2x3 grid of cluster-size histograms.

    One panel for the ground truth plus one per detection method; panels whose
    ``cluster_details.json`` is missing show a "No data" placeholder. Output is
    ``cluster_sizes_<network>.pdf`` in FIGURES_DIR.
    """
    stats_dir = os.path.join(RESULTS_DIR, "stats")

    for net_name in NETWORKS:
        panel_labels = ["ground_truth"] + METHOD_NAMES
        fig, axes = plt.subplots(2, 3, figsize=(14, 8))
        axes = axes.flatten()

        for idx, label in enumerate(panel_labels):
            ax = axes[idx]
            det_path = os.path.join(stats_dir, net_name, label,
                                    "cluster_details.json")
            if not os.path.exists(det_path):
                ax.set_title(METHOD_LABELS.get(label, label))
                ax.text(0.5, 0.5, "No data", ha="center", va="center",
                        transform=ax.transAxes)
                continue

            with open(det_path) as fh:
                details = json.load(fh)
            sizes = [entry["n"] for entry in details]

            if sizes:
                # Bin count tracks the number of distinct sizes, clamped to [10, 50].
                n_bins = min(50, max(10, len(set(sizes))))
                ax.hist(sizes, bins=n_bins, edgecolor="black", alpha=0.7,
                        color="steelblue")
            ax.set_title(METHOD_LABELS.get(label, label), fontsize=10)
            ax.set_xlabel("Cluster size")
            ax.set_ylabel("Count")
            # Heavy-tailed distributions read better on a log x-axis.
            if sizes and max(sizes) > 100:
                ax.set_xscale("log")

        # Drop any axes beyond the number of panels actually used.
        for extra in range(len(panel_labels), len(axes)):
            fig.delaxes(axes[extra])

        fig.suptitle(f"Cluster Size Distribution — {net_name}", fontsize=13)
        plt.tight_layout()
        plt.savefig(os.path.join(FIGURES_DIR, f"cluster_sizes_{net_name}.pdf"),
                    bbox_inches="tight")
        plt.close()
        print(f" Saved cluster_sizes_{net_name}.pdf")
+
+
def plot_edge_density_boxplots():
    """One boxplot figure per network comparing edge density across clusterings.

    Each available clustering (ground truth plus detection methods) contributes
    one box; methods without a ``cluster_details.json`` are omitted. Writes
    ``edge_density_<network>.pdf`` into FIGURES_DIR.
    """
    stats_dir = os.path.join(RESULTS_DIR, "stats")

    for net_name in NETWORKS:
        series = []
        series_labels = []

        for label in ["ground_truth"] + METHOD_NAMES:
            det_path = os.path.join(stats_dir, net_name, label,
                                    "cluster_details.json")
            if not os.path.exists(det_path):
                continue
            with open(det_path) as fh:
                details = json.load(fh)
            densities = [entry["edge_density"] for entry in details]
            if densities:
                series.append(densities)
                series_labels.append(METHOD_LABELS.get(label, label))

        # Nothing to draw for this network.
        if not series:
            continue

        fig, ax = plt.subplots(figsize=(9, 4))
        bp = ax.boxplot(series, tick_labels=series_labels, patch_artist=True,
                        showfliers=False)
        for box in bp["boxes"]:
            box.set_facecolor("lightblue")
        ax.set_ylabel("Edge Density")
        ax.set_title(f"Edge Density Distribution — {net_name}")
        plt.xticks(rotation=20, ha="right")
        plt.tight_layout()
        plt.savefig(os.path.join(FIGURES_DIR, f"edge_density_{net_name}.pdf"),
                    bbox_inches="tight")
        plt.close()
        print(f" Saved edge_density_{net_name}.pdf")
+
+
def plot_mixing_parameter_comparison():
    """Grouped bar chart of mean mixing parameter, grouped by network.

    Missing (network, method) rows in the stats summary are drawn as
    zero-height bars. Writes ``mixing_parameter.pdf`` into FIGURES_DIR.
    """
    stats_path = os.path.join(RESULTS_DIR, "stats", "cluster_stats_summary.csv")
    if not os.path.exists(stats_path):
        print("No stats summary found, skipping mixing param plot")
        return

    df = pd.read_csv(stats_path)

    fig, ax = plt.subplots(figsize=(10, 4.5))
    net_names = list(NETWORKS.keys())
    all_methods = ["ground_truth"] + METHOD_NAMES
    x = np.arange(len(net_names))
    width = 0.13
    # Symmetric offsets so the bar group is centered on each network tick.
    offsets = np.arange(len(all_methods)) - len(all_methods) / 2 + 0.5
    colors = plt.cm.Set2(np.linspace(0, 1, len(all_methods)))

    for i, method in enumerate(all_methods):
        heights = []
        for net in net_names:
            match = df[(df["network"] == net) & (df["method"] == method)]
            heights.append(0 if match.empty else match["mean_mixing_param"].values[0])
        ax.bar(x + offsets[i] * width, heights, width,
               label=METHOD_LABELS.get(method, method), color=colors[i])

    ax.set_xticks(x)
    ax.set_xticklabels(net_names)
    ax.set_ylabel("Mean Mixing Parameter")
    ax.set_title("Mean Mixing Parameter by Network and Method")
    ax.legend(fontsize=7, ncol=2)
    plt.tight_layout()
    plt.savefig(os.path.join(FIGURES_DIR, "mixing_parameter.pdf"), bbox_inches="tight")
    plt.close()
    print(" Saved mixing_parameter.pdf")
+
+
def plot_node_coverage_comparison():
    """Grouped bar chart of node coverage, grouped by network.

    Mirrors plot_mixing_parameter_comparison but plots the ``node_coverage``
    column with the y-axis pinned to [0, 1.05]. Silently returns when the
    stats summary CSV is absent. Writes ``node_coverage.pdf`` into FIGURES_DIR.
    """
    stats_path = os.path.join(RESULTS_DIR, "stats", "cluster_stats_summary.csv")
    if not os.path.exists(stats_path):
        return

    df = pd.read_csv(stats_path)

    fig, ax = plt.subplots(figsize=(10, 4.5))
    net_names = list(NETWORKS.keys())
    all_methods = ["ground_truth"] + METHOD_NAMES
    x = np.arange(len(net_names))
    width = 0.13
    # Symmetric offsets so the bar group is centered on each network tick.
    offsets = np.arange(len(all_methods)) - len(all_methods) / 2 + 0.5
    colors = plt.cm.Set2(np.linspace(0, 1, len(all_methods)))

    for i, method in enumerate(all_methods):
        heights = []
        for net in net_names:
            match = df[(df["network"] == net) & (df["method"] == method)]
            heights.append(0 if match.empty else match["node_coverage"].values[0])
        ax.bar(x + offsets[i] * width, heights, width,
               label=METHOD_LABELS.get(method, method), color=colors[i])

    ax.set_xticks(x)
    ax.set_xticklabels(net_names)
    ax.set_ylabel("Node Coverage")
    ax.set_title("Node Coverage by Network and Method")
    ax.legend(fontsize=7, ncol=2)
    # Coverage is a fraction; leave a little headroom above 1.0.
    ax.set_ylim(0, 1.05)
    plt.tight_layout()
    plt.savefig(os.path.join(FIGURES_DIR, "node_coverage.pdf"), bbox_inches="tight")
    plt.close()
    print(" Saved node_coverage.pdf")
+
+
def generate_latex_accuracy_table():
    r"""Generate a booktabs LaTeX table of AMI/ARI/NMI accuracy.

    Writes ``accuracy_table.tex`` into FIGURES_DIR; silently returns when the
    accuracy CSV has not been produced yet.

    Fixes over the original:
    - Network names are rendered through NET_LABELS, so underscores
      (e.g. ``internet_as``) are escaped for LaTeX text mode instead of
      breaking compilation.
    - A ``\midrule`` is appended only after groups that actually emitted rows,
      avoiding doubled rules when a network has no results.
    """
    acc_path = os.path.join(RESULTS_DIR, "accuracy", "accuracy_table.csv")
    if not os.path.exists(acc_path):
        return

    df = pd.read_csv(acc_path)
    lines = [
        r"\begin{table}[htbp]",
        r"\centering",
        r"\caption{Community detection accuracy (AMI, ARI, NMI) on EC-SBM networks.}",
        r"\label{tab:accuracy}",
        r"\begin{tabular}{llrrr}",
        r"\toprule",
        r"Network & Method & AMI & ARI & NMI \\",
        r"\midrule",
    ]

    for net_name in NETWORKS:
        first = True
        for _, row in df[df["network"] == net_name].iterrows():
            # Only the first row of a group shows the (LaTeX-escaped) network name.
            net_disp = NET_LABELS.get(net_name, net_name) if first else ""
            m_label = METHOD_LABELS.get(row["method"], row["method"])
            lines.append(
                f"{net_disp} & {m_label} & {row['ami']:.4f} & {row['ari']:.4f} & {row['nmi']:.4f} \\\\"
            )
            first = False
        if not first:  # at least one row emitted for this network
            lines.append(r"\midrule")

    # The trailing \midrule (or the header rule, if no data) becomes \bottomrule.
    lines[-1] = r"\bottomrule"
    lines.append(r"\end{tabular}")
    lines.append(r"\end{table}")

    out_path = os.path.join(FIGURES_DIR, "accuracy_table.tex")
    with open(out_path, "w") as f:
        f.write("\n".join(lines))
    print(" Saved accuracy_table.tex")
+
+
def generate_latex_stats_table():
    r"""Generate a booktabs LaTeX table of cluster statistics per network/method.

    Writes ``cluster_stats_table.tex`` into FIGURES_DIR; silently returns when
    the stats summary CSV has not been produced yet.

    Fixes over the original:
    - Network names are rendered through NET_LABELS, so underscores
      (e.g. ``internet_as``) are escaped for LaTeX text mode instead of
      breaking compilation.
    - A ``\midrule`` is appended only after groups that actually emitted rows,
      avoiding doubled rules when a network has no results.
    """
    stats_path = os.path.join(RESULTS_DIR, "stats", "cluster_stats_summary.csv")
    if not os.path.exists(stats_path):
        return

    df = pd.read_csv(stats_path)
    lines = [
        r"\begin{table}[htbp]",
        r"\centering",
        r"\caption{Cluster statistics summary for each network and method.}",
        r"\label{tab:cluster_stats}",
        r"\footnotesize",
        r"\begin{tabular}{llrrrrrr}",
        r"\toprule",
        r"Network & Method & \#Clusters & Node Cov. & Mean Size & Mean Density & Mean Cond. & Mean Mix. \\",
        r"\midrule",
    ]

    for net_name in NETWORKS:
        first = True
        for _, row in df[df["network"] == net_name].iterrows():
            # Only the first row of a group shows the (LaTeX-escaped) network name.
            net_disp = NET_LABELS.get(net_name, net_name) if first else ""
            m_label = METHOD_LABELS.get(row["method"], row["method"])
            lines.append(
                f"{net_disp} & {m_label} & {int(row['n_clusters_non_singleton'])} & "
                f"{row['node_coverage']:.3f} & {row['mean_cluster_size']:.1f} & "
                f"{row['mean_edge_density']:.3f} & {row['mean_conductance']:.3f} & "
                f"{row['mean_mixing_param']:.3f} \\\\"
            )
            first = False
        if not first:  # at least one row emitted for this network
            lines.append(r"\midrule")

    # The trailing \midrule (or the header rule, if no data) becomes \bottomrule.
    lines[-1] = r"\bottomrule"
    lines.append(r"\end{tabular}")
    lines.append(r"\end{table}")

    out_path = os.path.join(FIGURES_DIR, "cluster_stats_table.tex")
    with open(out_path, "w") as f:
        f.write("\n".join(lines))
    print(" Saved cluster_stats_table.tex")
+
+
def generate_all():
    """Run every figure generator, then the LaTeX table writers, with progress logs."""
    figure_steps = [
        ("Generating accuracy heatmaps...", plot_accuracy_heatmap),
        ("Generating cluster size distributions...", plot_cluster_size_distributions),
        ("Generating edge density boxplots...", plot_edge_density_boxplots),
        ("Generating mixing parameter comparison...", plot_mixing_parameter_comparison),
        ("Generating node coverage comparison...", plot_node_coverage_comparison),
    ]
    for message, step in figure_steps:
        print(message)
        step()
    print("Generating LaTeX tables...")
    generate_latex_accuracy_table()
    generate_latex_stats_table()
    print("All plots and tables generated.")
+
+
# Script entry point: build every figure and LaTeX table in one pass.
if __name__ == "__main__":
    generate_all()