diff options
Diffstat (limited to 'scripts')
| -rw-r--r-- | scripts/compute_accuracy.py | 82 | ||||
| -rw-r--r-- | scripts/compute_stats.py | 30 | ||||
| -rw-r--r-- | scripts/generate_plots.py | 95 |
3 files changed, 154 insertions, 53 deletions
diff --git a/scripts/compute_accuracy.py b/scripts/compute_accuracy.py index 4aeb6a2..13cf50a 100644 --- a/scripts/compute_accuracy.py +++ b/scripts/compute_accuracy.py @@ -1,76 +1,64 @@ -"""Compute AMI, ARI, NMI for all (network, method) pairs.""" +"""Compute AMI, ARI, NMI for all (network, method) pairs. + +Uses the official network_evaluation scripts from +https://github.com/illinois-or-research-analytics/network_evaluation +""" -import argparse import sys import os sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) -import numpy as np +# Add network_evaluation to path +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, os.path.join(BASE_DIR, "network_evaluation", "commdet_acc")) + import pandas as pd -from sklearn.metrics import adjusted_mutual_info_score, adjusted_rand_score, normalized_mutual_info_score +from compute_cd_accuracy import clustering_accuracy from config import NETWORKS, METHODS, RESULTS_DIR -from load_data import load_edge_list, load_communities - - -def align_labels(gt_com, est_com, edge_path): - """Align ground truth and estimated labels over the full node set from edges. - Nodes missing from a clustering get unique singleton community IDs.""" - edge_df = load_edge_list(edge_path) - all_nodes = sorted(set( - pd.unique(edge_df[["src", "tgt"]].values.ravel("K")) - )) - - gt_labels = [] - est_labels = [] - # For nodes not in GT or EST, assign unique singleton IDs - gt_next = max((int(v) for v in gt_com.values() if v.lstrip('-').isdigit()), default=0) + 1 - est_next = max((int(v) for v in est_com.values() if v.lstrip('-').isdigit()), default=0) + 1 - - for node in all_nodes: - if node in gt_com: - gt_labels.append(gt_com[node]) - else: - gt_labels.append(f"gt_singleton_{gt_next}") - gt_next += 1 - - if node in est_com: - est_labels.append(est_com[node]) - else: - est_labels.append(f"est_singleton_{est_next}") - est_next += 1 - return gt_labels, est_labels - -def compute_accuracy(network_name, method_name): - """Compute AMI, ARI, NMI for a single (network, method) pair.""" +def compute_single_accuracy(network_name, method_name): + """Compute accuracy for a single (network, method) pair using official scripts.""" net = NETWORKS[network_name] - gt_com = load_communities(net["com_gt_tsv"]) - est_path = os.path.join(RESULTS_DIR, network_name, method_name, "com.tsv") if not os.path.exists(est_path): print(f" WARNING: {est_path} not found, skipping") return None - est_com = load_communities(est_path) - gt_labels, est_labels = align_labels(gt_com, est_com, net["edge_tsv"]) + out_prefix = os.path.join(RESULTS_DIR, "accuracy", f"{network_name}_{method_name}") - ami = adjusted_mutual_info_score(gt_labels, est_labels, average_method="arithmetic") - ari = adjusted_rand_score(gt_labels, est_labels) - nmi = normalized_mutual_info_score(gt_labels, est_labels, average_method="arithmetic") + clustering_accuracy( + input_edgelist=net["edge_tsv"], + groundtruth_clustering=net["com_gt_tsv"], + estimated_clustering=est_path, + output_prefix=out_prefix, + num_processors=1, + local=False, + overwrite=True, + ) - return {"ami": ami, "ari": ari, "nmi": nmi} + # Read back the results + result = {} + for metric in ["ami", "ari", "nmi", "node_coverage"]: + fpath = f"{out_prefix}.{metric}" + if os.path.exists(fpath): + with open(fpath) as f: + result[metric] = float(f.read().strip()) + return result def compute_all_accuracy(): """Compute accuracy for all (network, method) pairs and save CSV.""" + out_dir = os.path.join(RESULTS_DIR, "accuracy") + os.makedirs(out_dir, exist_ok=True) + rows = [] for net_name in NETWORKS: for method in METHODS: m_name = method["name"] print(f"Computing accuracy: {net_name} / {m_name}") - result = compute_accuracy(net_name, m_name) + result = compute_single_accuracy(net_name, m_name) if result is not None: rows.append({ "network": net_name, @@ -79,8 +67,6 @@ def compute_all_accuracy(): }) df = pd.DataFrame(rows) - out_dir = os.path.join(RESULTS_DIR, "accuracy") - os.makedirs(out_dir, exist_ok=True) out_path = os.path.join(out_dir, "accuracy_table.csv") df.to_csv(out_path, index=False) print(f"\nAccuracy table saved to {out_path}") diff --git a/scripts/compute_stats.py b/scripts/compute_stats.py index 2e88252..6f21f63 100644 --- a/scripts/compute_stats.py +++ b/scripts/compute_stats.py @@ -8,6 +8,7 @@ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) import numpy as np import pandas as pd +import igraph as ig from config import NETWORKS, METHODS, RESULTS_DIR from load_data import load_edge_list, load_communities @@ -50,7 +51,8 @@ def compute_cluster_stats(network_name, method_name, com_path): # Per-cluster stats cluster_details = [] - for com_id, nodes in sorted(non_singleton_coms.items()): + total_clusters = len(non_singleton_coms) + for ci, (com_id, nodes) in enumerate(sorted(non_singleton_coms.items())): n = len(nodes) # Internal edges m_internal = 0 @@ -67,6 +69,21 @@ def compute_cluster_stats(network_name, method_name, com_path): degree_density = m_internal / n if n > 0 else 0.0 conductance = c_boundary / (2 * m_internal + c_boundary) if (2 * m_internal + c_boundary) > 0 else 0.0 + # Minimum edge cut via igraph + mincut = 0 + if n >= 2 and m_internal >= 1: + node_list = sorted(nodes) + local_map = {nd: i for i, nd in enumerate(node_list)} + edges = [] + for nd in node_list: + for nbr in neighbors.get(nd, set()): + if nbr in nodes and local_map[nd] < local_map[nbr]: + edges.append((local_map[nd], local_map[nbr])) + sg = ig.Graph(n=n, edges=edges, directed=False) + mincut = sg.mincut().value + + mincut_over_log10n = mincut / np.log10(n) if n > 1 else 0.0 + cluster_details.append({ "com_id": com_id, "n": n, @@ -75,8 +92,13 @@ def compute_cluster_stats(network_name, method_name, com_path): "edge_density": edge_density, "degree_density": degree_density, "conductance": conductance, + "mincut": int(mincut), + "mincut_over_log10n": mincut_over_log10n, }) + if (ci + 1) % 500 == 0: + print(f" ... {ci+1}/{total_clusters} clusters processed") + # Per-node mixing parameter mixing_params = [] for node in all_nodes: @@ -107,6 +129,12 @@ def compute_cluster_stats(network_name, method_name, com_path): "median_edge_density": np.median([d["edge_density"] for d in cluster_details]) if cluster_details else 0, "mean_conductance": np.mean([d["conductance"] for d in cluster_details]) if cluster_details else 0, "mean_degree_density": np.mean([d["degree_density"] for d in cluster_details]) if cluster_details else 0, + "mean_mincut": np.mean([d["mincut"] for d in cluster_details]) if cluster_details else 0, + "median_mincut": np.median([d["mincut"] for d in cluster_details]) if cluster_details else 0, + "mean_mincut_over_log10n": np.mean([d["mincut_over_log10n"] for d in cluster_details]) if cluster_details else 0, + "n_connected": sum(1 for d in cluster_details if d["mincut"] > 0), + "n_disconnected": sum(1 for d in cluster_details if d["mincut"] == 0), + "n_wellconnected": sum(1 for d in cluster_details if d["mincut"] > np.log10(d["n"])), } return summary, cluster_details, mixing_params diff --git a/scripts/generate_plots.py b/scripts/generate_plots.py index be5db9f..d844f12 100644 --- a/scripts/generate_plots.py +++ b/scripts/generate_plots.py @@ -223,6 +223,86 @@ def plot_node_coverage_comparison(): print(" Saved node_coverage.pdf") +def plot_edge_connectivity_boxplots(): + """Boxplots of mincut/log10(n) across methods for each network.""" + stats_dir = os.path.join(RESULTS_DIR, "stats") + + for net_name in NETWORKS: + all_labels = ["ground_truth"] + METHOD_NAMES + data = [] + labels = [] + + for label in all_labels: + det_path = os.path.join(stats_dir, net_name, label, "cluster_details.json") + if not os.path.exists(det_path): + continue + with open(det_path) as f: + details = json.load(f) + vals = [d["mincut_over_log10n"] for d in details if "mincut_over_log10n" in d] + if vals: + data.append(vals) + labels.append(METHOD_LABELS.get(label, label)) + + if not data: + continue + + fig, ax = plt.subplots(figsize=(9, 4)) + bp = ax.boxplot(data, tick_labels=labels, patch_artist=True, showfliers=False) + for patch in bp["boxes"]: + patch.set_facecolor("lightyellow") + ax.axhline(y=1.0, color="red", linestyle="--", linewidth=0.8, label="well-connected threshold") + ax.set_ylabel("Min Edge Cut / log$_{10}$(n)") + ax.set_title(f"Edge Connectivity — {net_name}") + ax.legend(fontsize=8) + plt.xticks(rotation=20, ha="right") + plt.tight_layout() + plt.savefig(os.path.join(FIGURES_DIR, f"edge_connectivity_{net_name}.pdf"), + bbox_inches="tight") + plt.close() + print(f" Saved edge_connectivity_{net_name}.pdf") + + +def plot_wellconnected_bar(): + """Bar chart of fraction well-connected clusters per method/network.""" + stats_path = os.path.join(RESULTS_DIR, "stats", "cluster_stats_summary.csv") + if not os.path.exists(stats_path): + return + + df = pd.read_csv(stats_path) + + fig, ax = plt.subplots(figsize=(10, 4.5)) + net_names = list(NETWORKS.keys()) + all_methods = ["ground_truth"] + METHOD_NAMES + x = np.arange(len(net_names)) + width = 0.13 + offsets = np.arange(len(all_methods)) - len(all_methods) / 2 + 0.5 + colors = plt.cm.Set2(np.linspace(0, 1, len(all_methods))) + + for i, method in enumerate(all_methods): + vals = [] + for net in net_names: + row = df[(df["network"] == net) & (df["method"] == method)] + if len(row) > 0: + nc = row["n_clusters_non_singleton"].values[0] + nwc = row["n_wellconnected"].values[0] + vals.append(nwc / nc if nc > 0 else 0) + else: + vals.append(0) + ax.bar(x + offsets[i] * width, vals, width, label=METHOD_LABELS.get(method, method), + color=colors[i]) + + ax.set_xticks(x) + ax.set_xticklabels(net_names) + ax.set_ylabel("Fraction Well-Connected") + ax.set_title("Fraction of Well-Connected Clusters (mincut > log$_{10}$(n))") + ax.legend(fontsize=7, ncol=2) + ax.set_ylim(0, 1.05) + plt.tight_layout() + plt.savefig(os.path.join(FIGURES_DIR, "wellconnected.pdf"), bbox_inches="tight") + plt.close() + print(" Saved wellconnected.pdf") + + def generate_latex_accuracy_table(): """Generate a LaTeX accuracy table.""" acc_path = os.path.join(RESULTS_DIR, "accuracy", "accuracy_table.csv") @@ -274,9 +354,9 @@ def generate_latex_stats_table(): lines.append(r"\caption{Cluster statistics summary for each network and method.}") lines.append(r"\label{tab:cluster_stats}") lines.append(r"\footnotesize") - lines.append(r"\begin{tabular}{llrrrrrr}") + lines.append(r"\begin{tabular}{llrrrrrrr}") lines.append(r"\toprule") - lines.append(r"Network & Method & \#Clusters & Node Cov. & Mean Size & Mean Density & Mean Cond. & Mean Mix. \\") + lines.append(r"Network & Method & \#Clust. & Node Cov. & Mean Size & Mean Dens. & Mean Cond. & Mean Mix. & \%WC \\") lines.append(r"\midrule") for net_name in NETWORKS: @@ -284,11 +364,14 @@ def generate_latex_stats_table(): for _, row in df[df["network"] == net_name].iterrows(): net_disp = net_name if first else "" m_label = METHOD_LABELS.get(row["method"], row["method"]) + nc = int(row['n_clusters_non_singleton']) + nwc = int(row['n_wellconnected']) if 'n_wellconnected' in row and not pd.isna(row.get('n_wellconnected', np.nan)) else 0 + pct_wc = 100 * nwc / nc if nc > 0 else 0.0 lines.append( - f"{net_disp} & {m_label} & {int(row['n_clusters_non_singleton'])} & " + f"{net_disp} & {m_label} & {nc} & " f"{row['node_coverage']:.3f} & {row['mean_cluster_size']:.1f} & " f"{row['mean_edge_density']:.3f} & {row['mean_conductance']:.3f} & " - f"{row['mean_mixing_param']:.3f} \\\\" + f"{row['mean_mixing_param']:.3f} & {pct_wc:.0f}\\% \\\\" ) first = False lines.append(r"\midrule") @@ -310,6 +393,10 @@ def generate_all(): plot_cluster_size_distributions() print("Generating edge density boxplots...") plot_edge_density_boxplots() + print("Generating edge connectivity boxplots...") + plot_edge_connectivity_boxplots() + print("Generating well-connected fraction bar chart...") + plot_wellconnected_bar() print("Generating mixing parameter comparison...") plot_mixing_parameter_comparison() print("Generating node coverage comparison...") |
