# Provenance: scripts/compute_accuracy.py (new file, blob 4aeb6a2)
# Commit:  8f63cf9f41bbdb8d55cd4679872d2b4ae2129324
# Author:  YurenHao0426
# Date:    Tue, 24 Feb 2026 08:40:49 +0000
# Subject: EC-SBM community detection analysis: full pipeline and writeup
#
# Implement community detection on 3 EC-SBM networks (polblogs, topology,
# internet_as) using 5 methods (Leiden-Mod, Leiden-CPM at 0.1 and 0.01,
# Infomap, graph-tool SBM). Compute AMI/ARI/NMI accuracy, cluster statistics,
# and generate figures and LaTeX report.
"""Compute AMI, ARI, NMI for all (network, method) pairs."""

import argparse
import sys
import os
# Make sibling modules (config, load_data) importable when run as a script.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import numpy as np
import pandas as pd
from sklearn.metrics import adjusted_mutual_info_score, adjusted_rand_score, normalized_mutual_info_score

from config import NETWORKS, METHODS, RESULTS_DIR
from load_data import load_edge_list, load_communities


# Column order of the accuracy table; also used when no result is available
# so the CSV always carries a header.
_ACCURACY_COLUMNS = ["network", "method", "ami", "ari", "nmi"]


def _next_singleton_id(communities):
    """Return one more than the largest integer-like community label (1 if none).

    Labels that cannot be converted with ``int()`` are ignored rather than
    raising, so non-string or oddly formatted labels do not abort the run.
    """
    numeric_ids = []
    for label in communities.values():
        try:
            numeric_ids.append(int(label))
        except (TypeError, ValueError):
            continue
    return max(numeric_ids, default=0) + 1


def align_labels(gt_com, est_com, edge_path):
    """Align ground truth and estimated labels over the full node set from edges.

    Nodes missing from a clustering get unique singleton community IDs.

    Args:
        gt_com: Mapping of node -> ground-truth community label.
        est_com: Mapping of node -> estimated community label.
        edge_path: Path handed to ``load_edge_list``; the resulting frame must
            expose ``src`` and ``tgt`` columns.

    Returns:
        A ``(gt_labels, est_labels)`` pair of equal-length lists, ordered by
        the sorted node IDs found in the edge list.
    """
    edge_df = load_edge_list(edge_path)
    # pd.unique already de-duplicates, so sorting its output is enough.
    all_nodes = sorted(pd.unique(edge_df[["src", "tgt"]].values.ravel("K")))

    # NOTE(review): membership tests below assume edge-list node IDs have the
    # same type as the keys of gt_com / est_com -- confirm against load_data.
    gt_labels = []
    est_labels = []
    # For nodes not in GT or EST, assign unique singleton IDs.
    gt_next = _next_singleton_id(gt_com)
    est_next = _next_singleton_id(est_com)

    for node in all_nodes:
        if node in gt_com:
            gt_labels.append(gt_com[node])
        else:
            gt_labels.append(f"gt_singleton_{gt_next}")
            gt_next += 1

        if node in est_com:
            est_labels.append(est_com[node])
        else:
            est_labels.append(f"est_singleton_{est_next}")
            est_next += 1

    return gt_labels, est_labels


def compute_accuracy(network_name, method_name):
    """Compute AMI, ARI, NMI for a single (network, method) pair.

    Args:
        network_name: Key into ``NETWORKS``.
        method_name: Name of the method whose ``com.tsv`` should be scored.

    Returns:
        A dict with keys ``ami``, ``ari`` and ``nmi``, or ``None`` when the
        estimated clustering file does not exist.
    """
    net = NETWORKS[network_name]

    # Check for the estimate first so a missing result costs no I/O.
    est_path = os.path.join(RESULTS_DIR, network_name, method_name, "com.tsv")
    if not os.path.exists(est_path):
        print(f" WARNING: {est_path} not found, skipping")
        return None

    gt_com = load_communities(net["com_gt_tsv"])
    est_com = load_communities(est_path)
    gt_labels, est_labels = align_labels(gt_com, est_com, net["edge_tsv"])

    ami = adjusted_mutual_info_score(gt_labels, est_labels, average_method="arithmetic")
    ari = adjusted_rand_score(gt_labels, est_labels)
    nmi = normalized_mutual_info_score(gt_labels, est_labels, average_method="arithmetic")

    return {"ami": ami, "ari": ari, "nmi": nmi}


def compute_all_accuracy():
    """Compute accuracy for all (network, method) pairs and save CSV.

    Pairs whose estimate is missing are skipped. The table is written to
    ``<RESULTS_DIR>/accuracy/accuracy_table.csv`` and printed to stdout.

    Returns:
        The accuracy table as a ``pandas.DataFrame``.
    """
    rows = []
    for net_name in NETWORKS:
        for method in METHODS:
            m_name = method["name"]
            print(f"Computing accuracy: {net_name} / {m_name}")
            result = compute_accuracy(net_name, m_name)
            if result is not None:
                rows.append({
                    "network": net_name,
                    "method": m_name,
                    **result,
                })

    # Explicit columns keep the header intact even when every pair was skipped.
    df = pd.DataFrame(rows, columns=_ACCURACY_COLUMNS)

    out_dir = os.path.join(RESULTS_DIR, "accuracy")
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, "accuracy_table.csv")
    df.to_csv(out_path, index=False)
    print(f"\nAccuracy table saved to {out_path}")
    print(df.to_string(index=False))
    return df


if __name__ == "__main__":
    compute_all_accuracy()