"""Generate all figures and LaTeX tables for the EC-SBM analysis."""

import sys
import os
import json

# Make the sibling ``config`` module importable when run as a script.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import numpy as np
import pandas as pd
import matplotlib

matplotlib.use("Agg")  # must be selected before pyplot is imported
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable

from config import NETWORKS, METHODS, RESULTS_DIR

FIGURES_DIR = os.path.join(RESULTS_DIR, "figures")
os.makedirs(FIGURES_DIR, exist_ok=True)

METHOD_NAMES = [m["name"] for m in METHODS]

# Display labels for methods (used in figures and LaTeX tables).
METHOD_LABELS = {
    "leiden_mod": "Leiden-Mod",
    "leiden_cpm_01": "Leiden-CPM(0.1)",
    "leiden_cpm_001": "Leiden-CPM(0.01)",
    "infomap": "Infomap",
    "graphtool_sbm": "graph-tool SBM",
}

# LaTeX-safe network names (underscores already escaped).
NET_LABELS = {
    "polblogs": "polblogs",
    "topology": "topology",
    "internet_as": "internet\\_as",
}

# Characters that have a special meaning in LaTeX text mode.
_LATEX_SPECIALS = {
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
}


def _latex_escape(text):
    """Return ``str(text)`` with LaTeX special characters escaped."""
    # Character-wise mapping avoids double-escaping already-replaced output.
    return "".join(_LATEX_SPECIALS.get(ch, ch) for ch in str(text))


def _latex_network_label(net_name):
    """Return a LaTeX-safe label for a network name."""
    if net_name in NET_LABELS:
        return NET_LABELS[net_name]
    return _latex_escape(net_name)


def _latex_method_label(method):
    """Return a LaTeX-safe label for a method name (e.g. ``ground_truth``)."""
    if method in METHOD_LABELS:
        return METHOD_LABELS[method]
    return _latex_escape(method)


def _read_csv_if_present(path):
    """Return the CSV at *path* as a DataFrame, or ``None`` if it is missing."""
    if not os.path.exists(path):
        return None
    return pd.read_csv(path)


def _load_cluster_details(net_name, label):
    """Load ``cluster_details.json`` for one network/method, or ``None``."""
    det_path = os.path.join(
        RESULTS_DIR, "stats", net_name, label, "cluster_details.json"
    )
    if not os.path.exists(det_path):
        return None
    with open(det_path, encoding="utf-8") as f:
        return json.load(f)


def _save_figure(fig, filename):
    """Tighten, save *fig* into ``FIGURES_DIR`` and close that figure."""
    fig.tight_layout()
    fig.savefig(os.path.join(FIGURES_DIR, filename), bbox_inches="tight")
    # Close this specific figure rather than whatever is "current".
    plt.close(fig)
    print(f" Saved {filename}")


def _count(row, column):
    """Return ``row[column]`` as an int, or 0 if it is absent or NaN."""
    value = row.get(column, np.nan)
    return 0 if pd.isna(value) else int(value)


def _wellconnected_fraction(row):
    """Return n_wellconnected / n_clusters_non_singleton (0 if no clusters)."""
    nc = _count(row, "n_clusters_non_singleton")
    nwc = _count(row, "n_wellconnected")
    return nwc / nc if nc > 0 else 0


def _plot_grouped_bars(value_of, *, ylabel, title, filename, what, ylim=None):
    """Draw one grouped bar chart (networks on x, one bar per method).

    ``value_of`` maps a summary row (pandas Series) to the bar height;
    network/method pairs with no row in the summary are drawn as 0.
    ``what`` is only used in the message printed when the summary is missing.
    """
    stats_path = os.path.join(RESULTS_DIR, "stats", "cluster_stats_summary.csv")
    df = _read_csv_if_present(stats_path)
    if df is None:
        print(f"No stats summary found, skipping {what}")
        return

    net_names = list(NETWORKS.keys())
    all_methods = ["ground_truth"] + METHOD_NAMES
    x = np.arange(len(net_names))
    # 0.13 is the historical width; shrink it if more methods are configured
    # so neighbouring network groups never overlap.
    width = min(0.13, 0.8 / len(all_methods))
    offsets = np.arange(len(all_methods)) - len(all_methods) / 2 + 0.5
    colors = plt.cm.Set2(np.linspace(0, 1, len(all_methods)))

    fig, ax = plt.subplots(figsize=(10, 4.5))
    for i, method in enumerate(all_methods):
        vals = []
        for net in net_names:
            rows = df[(df["network"] == net) & (df["method"] == method)]
            vals.append(value_of(rows.iloc[0]) if len(rows) > 0 else 0)
        ax.bar(
            x + offsets[i] * width,
            vals,
            width,
            label=METHOD_LABELS.get(method, method),
            color=colors[i],
        )
    ax.set_xticks(x)
    ax.set_xticklabels(net_names)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(fontsize=7, ncol=2)
    if ylim is not None:
        ax.set_ylim(*ylim)
    _save_figure(fig, filename)


def _plot_metric_boxplots(
    metric_key, *, facecolor, ylabel, title, file_prefix, threshold=None
):
    """Draw, per network, boxplots of a per-cluster metric across methods.

    Clusters lacking ``metric_key`` (or holding ``null``) are ignored, and
    methods without any usable value are left out of the figure. When
    ``threshold`` is given, a dashed red reference line is drawn at that y.
    """
    all_labels = ["ground_truth"] + METHOD_NAMES
    for net_name in NETWORKS:
        data = []
        labels = []
        for label in all_labels:
            details = _load_cluster_details(net_name, label)
            if details is None:
                continue
            vals = [d[metric_key] for d in details if d.get(metric_key) is not None]
            if vals:
                data.append(vals)
                labels.append(METHOD_LABELS.get(label, label))
        if not data:
            continue

        fig, ax = plt.subplots(figsize=(9, 4))
        bp = ax.boxplot(data, patch_artist=True, showfliers=False)
        # Label ticks after the fact: the ``tick_labels`` keyword only exists
        # in Matplotlib 3.9+, whereas this works on older releases as well.
        ax.set_xticklabels(labels, rotation=20, ha="right")
        for patch in bp["boxes"]:
            patch.set_facecolor(facecolor)
        if threshold is not None:
            ax.axhline(
                y=threshold,
                color="red",
                linestyle="--",
                linewidth=0.8,
                label="well-connected threshold",
            )
            ax.legend(fontsize=8)
        ax.set_ylabel(ylabel)
        ax.set_title(f"{title} — {net_name}")
        _save_figure(fig, f"{file_prefix}_{net_name}.pdf")


def _write_latex_table(preamble, row_groups, filename):
    """Write a booktabs table: *preamble*, row groups split by ``\\midrule``.

    *preamble* must end with ``\\midrule``. Empty groups are skipped so no
    doubled rules are emitted; the final rule becomes ``\\bottomrule``.
    """
    lines = list(preamble)
    for group in row_groups:
        if group:
            lines.extend(group)
            lines.append(r"\midrule")
    # The last rule (after the final group, or the header rule if there are
    # no rows at all) closes the table body.
    lines[-1] = r"\bottomrule"
    lines.append(r"\end{tabular}")
    lines.append(r"\end{table}")
    with open(os.path.join(FIGURES_DIR, filename), "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
    print(f" Saved {filename}")


def plot_accuracy_heatmap():
    """Create a heatmap of accuracy metrics (network x method).

    Reads ``accuracy/accuracy_table.csv`` under ``RESULTS_DIR`` and writes one
    ``heatmap_<metric>.pdf`` per metric (AMI, ARI, NMI). Cells without a
    value are left unannotated. Does nothing if the table is missing.
    """
    acc_path = os.path.join(RESULTS_DIR, "accuracy", "accuracy_table.csv")
    df = _read_csv_if_present(acc_path)
    if df is None:
        print("No accuracy table found, skipping heatmap")
        return

    net_names = list(NETWORKS.keys())
    method_labels = [METHOD_LABELS.get(m, m) for m in METHOD_NAMES]
    for metric in ["ami", "ari", "nmi"]:
        fig, ax = plt.subplots(figsize=(8, 3.5))
        pivot = df.pivot(index="network", columns="method", values=metric)
        # Fix row/column order to the configured networks and methods.
        pivot = pivot.reindex(index=net_names, columns=METHOD_NAMES)
        values = pivot.to_numpy(dtype=float)
        im = ax.imshow(values, cmap="YlOrRd", aspect="auto", vmin=0, vmax=1)
        ax.set_xticks(range(len(METHOD_NAMES)))
        ax.set_xticklabels(method_labels, rotation=30, ha="right", fontsize=9)
        ax.set_yticks(range(len(net_names)))
        ax.set_yticklabels(net_names, fontsize=10)
        for (i, j), val in np.ndenumerate(values):
            if np.isnan(val):
                continue
            # Switch to white text on the darker end of the colormap.
            ax.text(
                j,
                i,
                f"{val:.3f}",
                ha="center",
                va="center",
                fontsize=9,
                color="black" if val < 0.6 else "white",
            )
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        ax.set_title(f"{metric.upper()} Accuracy", fontsize=12)
        _save_figure(fig, f"heatmap_{metric}.pdf")


def plot_cluster_size_distributions():
    """Histogram of cluster sizes per network/method.

    One figure per network with one panel for the ground truth and one per
    method; panels whose ``cluster_details.json`` is missing show "No data".
    """
    all_labels = ["ground_truth"] + METHOD_NAMES
    # Size the grid from the number of panels instead of assuming 2x3.
    n_cols = 3
    n_rows = (len(all_labels) + n_cols - 1) // n_cols
    for net_name in NETWORKS:
        fig, axes = plt.subplots(
            n_rows, n_cols, figsize=(14, 4 * n_rows), squeeze=False
        )
        axes = axes.flatten()
        for ax, label in zip(axes, all_labels):
            panel_title = METHOD_LABELS.get(label, label)
            details = _load_cluster_details(net_name, label)
            if details is None:
                ax.set_title(panel_title)
                ax.text(
                    0.5,
                    0.5,
                    "No data",
                    ha="center",
                    va="center",
                    transform=ax.transAxes,
                )
                continue
            sizes = [d["n"] for d in details]
            if sizes:
                ax.hist(
                    sizes,
                    bins=min(50, max(10, len(set(sizes)))),
                    edgecolor="black",
                    alpha=0.7,
                    color="steelblue",
                )
            ax.set_title(panel_title, fontsize=10)
            ax.set_xlabel("Cluster size")
            ax.set_ylabel("Count")
            if sizes and max(sizes) > 100:
                ax.set_xscale("log")
        # Remove extra subplot if any
        for ax in axes[len(all_labels):]:
            fig.delaxes(ax)
        fig.suptitle(f"Cluster Size Distribution — {net_name}", fontsize=13)
        _save_figure(fig, f"cluster_sizes_{net_name}.pdf")


def plot_edge_density_boxplots():
    """Boxplots of edge density across methods for each network."""
    _plot_metric_boxplots(
        "edge_density",
        facecolor="lightblue",
        ylabel="Edge Density",
        title="Edge Density Distribution",
        file_prefix="edge_density",
    )


def plot_mixing_parameter_comparison():
    """Bar chart of mean mixing parameter per method/network."""
    _plot_grouped_bars(
        lambda row: row["mean_mixing_param"],
        ylabel="Mean Mixing Parameter",
        title="Mean Mixing Parameter by Network and Method",
        filename="mixing_parameter.pdf",
        what="mixing param plot",
    )


def plot_node_coverage_comparison():
    """Bar chart of node coverage per method/network."""
    _plot_grouped_bars(
        lambda row: row["node_coverage"],
        ylabel="Node Coverage",
        title="Node Coverage by Network and Method",
        filename="node_coverage.pdf",
        what="node coverage plot",
        ylim=(0, 1.05),
    )


def plot_edge_connectivity_boxplots():
    """Boxplots of mincut/log10(n) across methods for each network."""
    _plot_metric_boxplots(
        "mincut_over_log10n",
        facecolor="lightyellow",
        ylabel="Min Edge Cut / log$_{10}$(n)",
        title="Edge Connectivity",
        file_prefix="edge_connectivity",
        threshold=1.0,
    )


def plot_wellconnected_bar():
    """Bar chart of fraction well-connected clusters per method/network."""
    _plot_grouped_bars(
        _wellconnected_fraction,
        ylabel="Fraction Well-Connected",
        title="Fraction of Well-Connected Clusters (mincut > log$_{10}$(n))",
        filename="wellconnected.pdf",
        what="well-connected plot",
        ylim=(0, 1.05),
    )


def generate_latex_accuracy_table():
    """Generate a LaTeX accuracy table (``accuracy_table.tex``).

    Network and method names are emitted LaTeX-safe so that names containing
    underscores do not break compilation.
    """
    acc_path = os.path.join(RESULTS_DIR, "accuracy", "accuracy_table.csv")
    df = _read_csv_if_present(acc_path)
    if df is None:
        print("No accuracy table found, skipping LaTeX accuracy table")
        return

    preamble = [
        r"\begin{table}[htbp]",
        r"\centering",
        r"\caption{Community detection accuracy (AMI, ARI, NMI) on EC-SBM networks.}",
        r"\label{tab:accuracy}",
        r"\begin{tabular}{llrrr}",
        r"\toprule",
        r"Network & Method & AMI & ARI & NMI \\",
        r"\midrule",
    ]
    groups = []
    for net_name in NETWORKS:
        group = []
        for _, row in df[df["network"] == net_name].iterrows():
            # Only the first row of each network shows the network name.
            net_disp = "" if group else _latex_network_label(net_name)
            m_label = _latex_method_label(row["method"])
            group.append(
                f"{net_disp} & {m_label} & {row['ami']:.4f} & "
                f"{row['ari']:.4f} & {row['nmi']:.4f} \\\\"
            )
        groups.append(group)
    _write_latex_table(preamble, groups, "accuracy_table.tex")


def generate_latex_stats_table():
    """Generate a LaTeX cluster stats table (``cluster_stats_table.tex``).

    The ``%WC`` column is the percentage of non-singleton clusters counted in
    ``n_wellconnected``; missing or NaN counts are treated as 0.
    """
    stats_path = os.path.join(RESULTS_DIR, "stats", "cluster_stats_summary.csv")
    df = _read_csv_if_present(stats_path)
    if df is None:
        print("No stats summary found, skipping LaTeX stats table")
        return

    preamble = [
        r"\begin{table}[htbp]",
        r"\centering",
        r"\caption{Cluster statistics summary for each network and method.}",
        r"\label{tab:cluster_stats}",
        r"\footnotesize",
        r"\begin{tabular}{llrrrrrrr}",
        r"\toprule",
        r"Network & Method & \#Clust. & Node Cov. & Mean Size & Mean Dens. & Mean Cond. & Mean Mix. & \%WC \\",
        r"\midrule",
    ]
    groups = []
    for net_name in NETWORKS:
        group = []
        for _, row in df[df["network"] == net_name].iterrows():
            net_disp = "" if group else _latex_network_label(net_name)
            m_label = _latex_method_label(row["method"])
            nc = _count(row, "n_clusters_non_singleton")
            nwc = _count(row, "n_wellconnected")
            pct_wc = 100 * nwc / nc if nc > 0 else 0.0
            group.append(
                f"{net_disp} & {m_label} & {nc} & "
                f"{row['node_coverage']:.3f} & {row['mean_cluster_size']:.1f} & "
                f"{row['mean_edge_density']:.3f} & {row['mean_conductance']:.3f} & "
                f"{row['mean_mixing_param']:.3f} & {pct_wc:.0f}\\% \\\\"
            )
        groups.append(group)
    _write_latex_table(preamble, groups, "cluster_stats_table.tex")


def generate_all():
    """Run every figure and table generator in sequence."""
    print("Generating accuracy heatmaps...")
    plot_accuracy_heatmap()
    print("Generating cluster size distributions...")
    plot_cluster_size_distributions()
    print("Generating edge density boxplots...")
    plot_edge_density_boxplots()
    print("Generating edge connectivity boxplots...")
    plot_edge_connectivity_boxplots()
    print("Generating well-connected fraction bar chart...")
    plot_wellconnected_bar()
    print("Generating mixing parameter comparison...")
    plot_mixing_parameter_comparison()
    print("Generating node coverage comparison...")
    plot_node_coverage_comparison()
    print("Generating LaTeX tables...")
    generate_latex_accuracy_table()
    generate_latex_stats_table()
    print("All plots and tables generated.")


if __name__ == "__main__":
    generate_all()