diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-24 08:40:49 +0000 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-24 08:40:49 +0000 |
| commit | 8f63cf9f41bbdb8d55cd4679872d2b4ae2129324 (patch) | |
| tree | ab5c95888849e854f2346db856c7edece7c8b8a7 /scripts/generate_plots.py | |
EC-SBM community detection analysis: full pipeline and writeup
Implement community detection on 3 EC-SBM networks (polblogs, topology,
internet_as) using 5 methods (Leiden-Mod, Leiden-CPM at 0.1 and 0.01,
Infomap, graph-tool SBM). Compute AMI/ARI/NMI accuracy, cluster statistics,
and generate figures and LaTeX report.
Diffstat (limited to 'scripts/generate_plots.py')
| -rw-r--r-- | scripts/generate_plots.py | 324 |
1 files changed, 324 insertions, 0 deletions
"""Generate all figures and LaTeX tables for the EC-SBM analysis.

Inputs are read from ``RESULTS_DIR`` (``accuracy/accuracy_table.csv``,
``stats/cluster_stats_summary.csv`` and the per-network, per-method
``stats/<network>/<method>/cluster_details.json`` files).  All outputs (PDF
figures and ``.tex`` tables) are written to ``RESULTS_DIR/figures``.
"""

import sys
import os
import json
# Put this script's directory first on sys.path before importing ``config``.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: figures are only saved to disk
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable

from config import NETWORKS, METHODS, RESULTS_DIR

FIGURES_DIR = os.path.join(RESULTS_DIR, "figures")
os.makedirs(FIGURES_DIR, exist_ok=True)

METHOD_NAMES = [m["name"] for m in METHODS]
METHOD_LABELS = {
    "leiden_mod": "Leiden-Mod",
    "leiden_cpm_01": "Leiden-CPM(0.1)",
    "leiden_cpm_001": "Leiden-CPM(0.01)",
    "infomap": "Infomap",
    "graphtool_sbm": "graph-tool SBM",
}
# LaTeX-ready network names (underscores already escaped); used by the table
# generators only, the figures use the raw network keys.
NET_LABELS = {
    "polblogs": "polblogs",
    "topology": "topology",
    "internet_as": "internet\\_as",
}

# Characters that break a LaTeX tabular cell when written verbatim.
_LATEX_SPECIALS = str.maketrans(
    {"_": r"\_", "&": r"\&", "%": r"\%", "#": r"\#", "$": r"\$"}
)


def _latex_escape(text):
    """Return ``text`` with ``_ & % # $`` backslash-escaped for LaTeX text mode."""
    return str(text).translate(_LATEX_SPECIALS)


def _latex_network_label(net_name):
    """Return the LaTeX display name of a network.

    ``NET_LABELS`` entries are used verbatim (they are already escaped); any
    other name is escaped on the fly so it cannot break the table.
    """
    if net_name in NET_LABELS:
        return NET_LABELS[net_name]
    return _latex_escape(net_name)


def _latex_method_label(method):
    """Return the LaTeX display name of a method.

    Methods without a ``METHOD_LABELS`` entry (e.g. ``ground_truth``) fall back
    to their escaped raw name.
    """
    if method in METHOD_LABELS:
        return METHOD_LABELS[method]
    return _latex_escape(method)


def _load_cluster_values(net_name, label, key):
    """Load one field from every entry of a ``cluster_details.json`` file.

    Args:
        net_name: Network directory name under ``RESULTS_DIR/stats``.
        label: Method directory name (or ``"ground_truth"``).
        key: Key to extract from each entry of the JSON list.

    Returns:
        A list with ``entry[key]`` for every entry, or ``None`` when the file
        does not exist.
    """
    det_path = os.path.join(
        RESULTS_DIR, "stats", net_name, label, "cluster_details.json"
    )
    if not os.path.exists(det_path):
        return None
    with open(det_path, encoding="utf-8") as f:
        details = json.load(f)
    return [d[key] for d in details]


def plot_accuracy_heatmap():
    """Create one heatmap (network x method) per accuracy metric.

    Writes ``heatmap_ami.pdf``, ``heatmap_ari.pdf`` and ``heatmap_nmi.pdf`` to
    ``FIGURES_DIR``.  Does nothing but print a notice when the accuracy table
    is missing.
    """
    acc_path = os.path.join(RESULTS_DIR, "accuracy", "accuracy_table.csv")
    if not os.path.exists(acc_path):
        print("No accuracy table found, skipping heatmap")
        return
    df = pd.read_csv(acc_path)

    net_names = list(NETWORKS.keys())
    method_ticks = [METHOD_LABELS.get(m, m) for m in METHOD_NAMES]

    for metric in ["ami", "ari", "nmi"]:
        fig, ax = plt.subplots(figsize=(8, 3.5))
        pivot = df.pivot(index="network", columns="method", values=metric)
        # Fix row/column order; combinations absent from the CSV become NaN.
        pivot = pivot.reindex(index=net_names, columns=METHOD_NAMES)
        values = pivot.values

        im = ax.imshow(values, cmap="YlOrRd", aspect="auto", vmin=0, vmax=1)
        ax.set_xticks(range(len(METHOD_NAMES)))
        ax.set_xticklabels(method_ticks, rotation=30, ha="right", fontsize=9)
        ax.set_yticks(range(len(net_names)))
        ax.set_yticklabels(net_names, fontsize=10)

        for (i, j), val in np.ndenumerate(values):
            if np.isnan(val):
                continue  # leave cells without data unannotated
            # Black text on low-valued cells, white on high-valued ones.
            ax.text(j, i, f"{val:.3f}", ha="center", va="center",
                    fontsize=9, color="black" if val < 0.6 else "white")

        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        ax.set_title(f"{metric.upper()} Accuracy", fontsize=12)
        fig.tight_layout()
        fig.savefig(os.path.join(FIGURES_DIR, f"heatmap_{metric}.pdf"),
                    bbox_inches="tight")
        plt.close(fig)
        print(f" Saved heatmap_{metric}.pdf")


def plot_cluster_size_distributions():
    """Histogram of cluster sizes per network/method.

    One figure per network (``cluster_sizes_<network>.pdf``) with one panel
    for the ground truth and one per method.  Panels whose details file is
    missing show a "No data" placeholder.
    """
    all_labels = ["ground_truth"] + METHOD_NAMES
    # Three panels per row; the row count grows with the number of methods so
    # that adding a method to the config cannot index past the grid.
    n_cols = 3
    n_rows = max(1, -(-len(all_labels) // n_cols))

    for net_name in NETWORKS:
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 4 * n_rows),
                                 squeeze=False)
        axes = axes.flatten()

        for ax, label in zip(axes, all_labels):
            sizes = _load_cluster_values(net_name, label, "n")
            if sizes is None:
                ax.set_title(METHOD_LABELS.get(label, label))
                ax.text(0.5, 0.5, "No data", ha="center", va="center",
                        transform=ax.transAxes)
                continue

            if sizes:
                ax.hist(sizes, bins=min(50, max(10, len(set(sizes)))),
                        edgecolor="black", alpha=0.7, color="steelblue")
            ax.set_title(METHOD_LABELS.get(label, label), fontsize=10)
            ax.set_xlabel("Cluster size")
            ax.set_ylabel("Count")
            if sizes and max(sizes) > 100:
                # Wide size ranges are easier to read on a log axis.
                ax.set_xscale("log")

        # Remove the unused panels of the last row, if any.
        for ax in axes[len(all_labels):]:
            fig.delaxes(ax)

        fig.suptitle(f"Cluster Size Distribution — {net_name}", fontsize=13)
        fig.tight_layout()
        fig.savefig(os.path.join(FIGURES_DIR, f"cluster_sizes_{net_name}.pdf"),
                    bbox_inches="tight")
        plt.close(fig)
        print(f" Saved cluster_sizes_{net_name}.pdf")


def plot_edge_density_boxplots():
    """Boxplots of edge density across methods for each network.

    Writes ``edge_density_<network>.pdf`` per network.  Methods without a
    details file or without any cluster are left out; a network with no data
    at all produces no figure.
    """
    all_labels = ["ground_truth"] + METHOD_NAMES

    for net_name in NETWORKS:
        data = []
        labels = []
        for label in all_labels:
            densities = _load_cluster_values(net_name, label, "edge_density")
            if densities:
                data.append(densities)
                labels.append(METHOD_LABELS.get(label, label))

        if not data:
            continue

        fig, ax = plt.subplots(figsize=(9, 4))
        bp = ax.boxplot(data, patch_artist=True, showfliers=False)
        for patch in bp["boxes"]:
            patch.set_facecolor("lightblue")
        # Label the boxes through the axis rather than the boxplot keyword,
        # whose name differs between Matplotlib releases.  Boxes sit at the
        # default positions 1..N.
        ax.set_xticks(list(range(1, len(labels) + 1)))
        ax.set_xticklabels(labels, rotation=20, ha="right")
        ax.set_ylabel("Edge Density")
        ax.set_title(f"Edge Density Distribution — {net_name}")
        fig.tight_layout()
        fig.savefig(os.path.join(FIGURES_DIR, f"edge_density_{net_name}.pdf"),
                    bbox_inches="tight")
        plt.close(fig)
        print(f" Saved edge_density_{net_name}.pdf")


def _plot_grouped_bars(column, ylabel, title, filename, ylim=None,
                       missing_msg=None):
    """Draw a grouped bar chart of one summary-table column.

    Networks form the groups on the x axis; each group holds one bar for the
    ground truth and one per method.  A (network, method) pair without a row
    in the summary table is drawn as a zero-height bar.

    Args:
        column: Column of ``cluster_stats_summary.csv`` to plot.
        ylabel: Y-axis label.
        title: Figure title.
        filename: Output file name inside ``FIGURES_DIR``.
        ylim: Optional ``(low, high)`` y-axis limits.
        missing_msg: Message printed when the summary CSV does not exist;
            ``None`` skips silently.
    """
    stats_path = os.path.join(RESULTS_DIR, "stats", "cluster_stats_summary.csv")
    if not os.path.exists(stats_path):
        if missing_msg is not None:
            print(missing_msg)
        return

    df = pd.read_csv(stats_path)

    fig, ax = plt.subplots(figsize=(10, 4.5))
    net_names = list(NETWORKS.keys())
    all_methods = ["ground_truth"] + METHOD_NAMES
    x = np.arange(len(net_names))
    # 0.13 is the width used for the default six bars per group; shrink it
    # when more methods are configured so that groups never overlap.
    width = min(0.13, 0.8 / len(all_methods))
    offsets = np.arange(len(all_methods)) - len(all_methods) / 2 + 0.5
    colors = plt.cm.Set2(np.linspace(0, 1, len(all_methods)))

    for i, method in enumerate(all_methods):
        vals = []
        for net in net_names:
            row = df[(df["network"] == net) & (df["method"] == method)]
            vals.append(row[column].values[0] if len(row) > 0 else 0)
        ax.bar(x + offsets[i] * width, vals, width,
               label=METHOD_LABELS.get(method, method), color=colors[i])

    ax.set_xticks(x)
    ax.set_xticklabels(net_names)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(fontsize=7, ncol=2)
    if ylim is not None:
        ax.set_ylim(*ylim)
    fig.tight_layout()
    fig.savefig(os.path.join(FIGURES_DIR, filename), bbox_inches="tight")
    plt.close(fig)
    print(f" Saved {filename}")


def plot_mixing_parameter_comparison():
    """Bar chart of mean mixing parameter per method/network.

    Writes ``mixing_parameter.pdf``; prints a notice and returns when the
    stats summary is missing.
    """
    _plot_grouped_bars(
        column="mean_mixing_param",
        ylabel="Mean Mixing Parameter",
        title="Mean Mixing Parameter by Network and Method",
        filename="mixing_parameter.pdf",
        missing_msg="No stats summary found, skipping mixing param plot",
    )


def plot_node_coverage_comparison():
    """Bar chart of node coverage per method/network.

    Writes ``node_coverage.pdf`` with the y axis fixed to ``[0, 1.05]``;
    returns silently when the stats summary is missing.
    """
    _plot_grouped_bars(
        column="node_coverage",
        ylabel="Node Coverage",
        title="Node Coverage by Network and Method",
        filename="node_coverage.pdf",
        ylim=(0, 1.05),
    )


def _write_latex_table(df, preamble, format_cells, out_name):
    """Write a booktabs table with one row group per network.

    Row groups follow the order of ``NETWORKS`` and are separated by
    ``\\midrule``; networks without rows in ``df`` are skipped so that no
    empty group (two consecutive rules) is produced.  The network name is
    printed only on the first row of its group.

    Args:
        df: DataFrame with at least the columns ``network`` and ``method``.
        preamble: Lines from ``\\begin{table}`` up to and including the
            column-header row.
        format_cells: Callable mapping a DataFrame row to the text of the
            value cells (already joined with `` & ``).
        out_name: Output file name inside ``FIGURES_DIR``.
    """
    body = []
    for net_name in NETWORKS:
        group = df[df["network"] == net_name]
        if group.empty:
            continue
        if body:
            body.append(r"\midrule")
        net_disp = _latex_network_label(net_name)
        for _, row in group.iterrows():
            m_label = _latex_method_label(row["method"])
            body.append(f"{net_disp} & {m_label} & {format_cells(row)} \\\\")
            net_disp = ""  # only the first row of a group names the network

    lines = list(preamble)
    if body:
        lines.append(r"\midrule")
        lines.extend(body)
    lines.append(r"\bottomrule")
    lines.append(r"\end{tabular}")
    lines.append(r"\end{table}")

    out_path = os.path.join(FIGURES_DIR, out_name)
    with open(out_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
    print(f" Saved {out_name}")


def generate_latex_accuracy_table():
    """Generate a LaTeX accuracy table (``accuracy_table.tex``).

    Returns silently when the accuracy CSV is missing.
    """
    acc_path = os.path.join(RESULTS_DIR, "accuracy", "accuracy_table.csv")
    if not os.path.exists(acc_path):
        return

    df = pd.read_csv(acc_path)
    preamble = [
        r"\begin{table}[htbp]",
        r"\centering",
        r"\caption{Community detection accuracy (AMI, ARI, NMI) on EC-SBM networks.}",
        r"\label{tab:accuracy}",
        r"\begin{tabular}{llrrr}",
        r"\toprule",
        r"Network & Method & AMI & ARI & NMI \\",
    ]

    def format_cells(row):
        return f"{row['ami']:.4f} & {row['ari']:.4f} & {row['nmi']:.4f}"

    _write_latex_table(df, preamble, format_cells, "accuracy_table.tex")


def generate_latex_stats_table():
    """Generate a LaTeX cluster stats table (``cluster_stats_table.tex``).

    Returns silently when the stats summary CSV is missing.
    """
    stats_path = os.path.join(RESULTS_DIR, "stats", "cluster_stats_summary.csv")
    if not os.path.exists(stats_path):
        return

    df = pd.read_csv(stats_path)
    preamble = [
        r"\begin{table}[htbp]",
        r"\centering",
        r"\caption{Cluster statistics summary for each network and method.}",
        r"\label{tab:cluster_stats}",
        r"\footnotesize",
        r"\begin{tabular}{llrrrrrr}",
        r"\toprule",
        r"Network & Method & \#Clusters & Node Cov. & Mean Size & Mean Density & Mean Cond. & Mean Mix. \\",
    ]

    def format_cells(row):
        return (
            f"{int(row['n_clusters_non_singleton'])} & "
            f"{row['node_coverage']:.3f} & {row['mean_cluster_size']:.1f} & "
            f"{row['mean_edge_density']:.3f} & {row['mean_conductance']:.3f} & "
            f"{row['mean_mixing_param']:.3f}"
        )

    _write_latex_table(df, preamble, format_cells, "cluster_stats_table.tex")


def generate_all():
    """Run every figure and table generator in sequence."""
    print("Generating accuracy heatmaps...")
    plot_accuracy_heatmap()
    print("Generating cluster size distributions...")
    plot_cluster_size_distributions()
    print("Generating edge density boxplots...")
    plot_edge_density_boxplots()
    print("Generating mixing parameter comparison...")
    plot_mixing_parameter_comparison()
    print("Generating node coverage comparison...")
    plot_node_coverage_comparison()
    print("Generating LaTeX tables...")
    generate_latex_accuracy_table()
    generate_latex_stats_table()
    print("All plots and tables generated.")


if __name__ == "__main__":
    generate_all()
