"""Compute AMI, ARI, NMI for all (network, method) pairs. Uses the official network_evaluation scripts from https://github.com/illinois-or-research-analytics/network_evaluation """ import sys import os sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) # Add network_evaluation to path BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.join(BASE_DIR, "network_evaluation", "commdet_acc")) import pandas as pd from compute_cd_accuracy import clustering_accuracy from config import NETWORKS, METHODS, RESULTS_DIR def compute_single_accuracy(network_name, method_name): """Compute accuracy for a single (network, method) pair using official scripts.""" net = NETWORKS[network_name] est_path = os.path.join(RESULTS_DIR, network_name, method_name, "com.tsv") if not os.path.exists(est_path): print(f" WARNING: {est_path} not found, skipping") return None out_prefix = os.path.join(RESULTS_DIR, "accuracy", f"{network_name}_{method_name}") clustering_accuracy( input_edgelist=net["edge_tsv"], groundtruth_clustering=net["com_gt_tsv"], estimated_clustering=est_path, output_prefix=out_prefix, num_processors=1, local=False, overwrite=True, ) # Read back the results result = {} for metric in ["ami", "ari", "nmi", "node_coverage"]: fpath = f"{out_prefix}.{metric}" if os.path.exists(fpath): with open(fpath) as f: result[metric] = float(f.read().strip()) return result def compute_all_accuracy(): """Compute accuracy for all (network, method) pairs and save CSV.""" out_dir = os.path.join(RESULTS_DIR, "accuracy") os.makedirs(out_dir, exist_ok=True) rows = [] for net_name in NETWORKS: for method in METHODS: m_name = method["name"] print(f"Computing accuracy: {net_name} / {m_name}") result = compute_single_accuracy(net_name, m_name) if result is not None: rows.append({ "network": net_name, "method": m_name, **result, }) df = pd.DataFrame(rows) out_path = os.path.join(out_dir, "accuracy_table.csv") df.to_csv(out_path, index=False) print(f"\nAccuracy table saved to {out_path}") print(df.to_string(index=False)) return df if __name__ == "__main__": compute_all_accuracy()