"""Compute AMI, ARI, NMI for all (network, method) pairs."""
import argparse
import sys
import os

# Put this file's directory at the front of sys.path so that `config` and
# `load_data` resolve from it; this must run before the project imports below.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import numpy as np
import pandas as pd
from sklearn.metrics import (
    adjusted_mutual_info_score,
    adjusted_rand_score,
    normalized_mutual_info_score,
)

from config import NETWORKS, METHODS, RESULTS_DIR
from load_data import load_edge_list, load_communities

# Column order of the accuracy table. Passed explicitly so the CSV keeps its
# header even when no (network, method) pair produced a result.
_ACCURACY_COLUMNS = ["network", "method", "ami", "ari", "nmi"]


def _next_singleton_id(labels):
    """Return one more than the largest integer-like label (1 if there is none).

    Labels that ``int()`` cannot convert are skipped instead of raising, so
    values such as ``"--5"``, ``"a"`` or ``None`` never abort the computation.
    """
    numeric = []
    for label in labels:
        try:
            numeric.append(int(label))
        except (TypeError, ValueError, OverflowError):
            # Not an integer-like label; it plays no role in seeding the counter.
            continue
    return max(numeric, default=0) + 1


def align_labels(gt_com, est_com, edge_path):
    """Align ground truth and estimated labels over the full node set from edges.

    Nodes missing from a clustering get unique singleton community IDs.

    Args:
        gt_com: Mapping from node to ground-truth community label.
        est_com: Mapping from node to estimated community label.
        edge_path: Path handed to ``load_edge_list``; the returned frame must
            expose ``src`` and ``tgt`` columns.

    Returns:
        A ``(gt_labels, est_labels)`` pair of equal-length lists holding one
        label per node, in sorted node order.
    """
    edge_df = load_edge_list(edge_path)
    # NOTE(review): the membership tests below only match when the node IDs
    # produced by load_edge_list and the keys produced by load_communities have
    # the same type (e.g. both str) -- confirm against load_data.
    all_nodes = sorted(pd.unique(edge_df[["src", "tgt"]].values.ravel("K")))

    # Counters for the singleton IDs given to nodes absent from GT or EST.
    # They start past the largest numeric label of the respective clustering.
    gt_next = _next_singleton_id(gt_com.values())
    est_next = _next_singleton_id(est_com.values())

    gt_labels = []
    est_labels = []
    for node in all_nodes:
        if node in gt_com:
            gt_labels.append(gt_com[node])
        else:
            gt_labels.append(f"gt_singleton_{gt_next}")
            gt_next += 1

        if node in est_com:
            est_labels.append(est_com[node])
        else:
            est_labels.append(f"est_singleton_{est_next}")
            est_next += 1

    return gt_labels, est_labels


def compute_accuracy(network_name, method_name):
    """Compute AMI, ARI, NMI for a single (network, method) pair.

    Args:
        network_name: Key into ``NETWORKS``; the entry supplies the
            ``com_gt_tsv`` and ``edge_tsv`` paths.
        method_name: Name of the method whose ``com.tsv`` is read from
            ``RESULTS_DIR/<network_name>/<method_name>/``.

    Returns:
        A dict with keys ``"ami"``, ``"ari"`` and ``"nmi"``, or ``None`` when
        the estimated community file does not exist.
    """
    net = NETWORKS[network_name]
    gt_com = load_communities(net["com_gt_tsv"])

    est_path = os.path.join(RESULTS_DIR, network_name, method_name, "com.tsv")
    if not os.path.exists(est_path):
        print(f"  WARNING: {est_path} not found, skipping")
        return None
    est_com = load_communities(est_path)

    gt_labels, est_labels = align_labels(gt_com, est_com, net["edge_tsv"])

    ami = adjusted_mutual_info_score(
        gt_labels, est_labels, average_method="arithmetic"
    )
    ari = adjusted_rand_score(gt_labels, est_labels)
    nmi = normalized_mutual_info_score(
        gt_labels, est_labels, average_method="arithmetic"
    )
    return {"ami": ami, "ari": ari, "nmi": nmi}


def compute_all_accuracy():
    """Compute accuracy for all (network, method) pairs and save CSV.

    Pairs for which ``compute_accuracy`` returns ``None`` are left out. The
    table is written to ``RESULTS_DIR/accuracy/accuracy_table.csv`` and also
    printed.

    Returns:
        The accuracy table as a ``pandas.DataFrame`` with the columns
        ``network``, ``method``, ``ami``, ``ari`` and ``nmi``.
    """
    rows = []
    for net_name in NETWORKS:
        for method in METHODS:
            m_name = method["name"]
            print(f"Computing accuracy: {net_name} / {m_name}")
            result = compute_accuracy(net_name, m_name)
            if result is not None:
                rows.append({
                    "network": net_name,
                    "method": m_name,
                    **result,
                })

    # Explicit columns keep the header row when `rows` is empty.
    df = pd.DataFrame(rows, columns=_ACCURACY_COLUMNS)

    out_dir = os.path.join(RESULTS_DIR, "accuracy")
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, "accuracy_table.csv")
    df.to_csv(out_path, index=False)

    print(f"\nAccuracy table saved to {out_path}")
    print(df.to_string(index=False))
    return df


if __name__ == "__main__":
    compute_all_accuracy()