diff options
Diffstat (limited to 'scripts/compute_accuracy.py')
| -rw-r--r-- | scripts/compute_accuracy.py | 82 |
1 files changed, 34 insertions, 48 deletions
diff --git a/scripts/compute_accuracy.py b/scripts/compute_accuracy.py index 4aeb6a2..13cf50a 100644 --- a/scripts/compute_accuracy.py +++ b/scripts/compute_accuracy.py @@ -1,76 +1,64 @@ -"""Compute AMI, ARI, NMI for all (network, method) pairs.""" +"""Compute AMI, ARI, NMI for all (network, method) pairs. + +Uses the official network_evaluation scripts from +https://github.com/illinois-or-research-analytics/network_evaluation +""" -import argparse import sys import os sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) -import numpy as np +# Add network_evaluation to path +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, os.path.join(BASE_DIR, "network_evaluation", "commdet_acc")) + import pandas as pd -from sklearn.metrics import adjusted_mutual_info_score, adjusted_rand_score, normalized_mutual_info_score +from compute_cd_accuracy import clustering_accuracy from config import NETWORKS, METHODS, RESULTS_DIR -from load_data import load_edge_list, load_communities - - -def align_labels(gt_com, est_com, edge_path): - """Align ground truth and estimated labels over the full node set from edges. - Nodes missing from a clustering get unique singleton community IDs.""" - edge_df = load_edge_list(edge_path) - all_nodes = sorted(set( - pd.unique(edge_df[["src", "tgt"]].values.ravel("K")) - )) - - gt_labels = [] - est_labels = [] - # For nodes not in GT or EST, assign unique singleton IDs - gt_next = max((int(v) for v in gt_com.values() if v.lstrip('-').isdigit()), default=0) + 1 - est_next = max((int(v) for v in est_com.values() if v.lstrip('-').isdigit()), default=0) + 1 - - for node in all_nodes: - if node in gt_com: - gt_labels.append(gt_com[node]) - else: - gt_labels.append(f"gt_singleton_{gt_next}") - gt_next += 1 - - if node in est_com: - est_labels.append(est_com[node]) - else: - est_labels.append(f"est_singleton_{est_next}") - est_next += 1 - return gt_labels, est_labels - -def compute_accuracy(network_name, method_name): - """Compute AMI, ARI, NMI for a single (network, method) pair.""" +def compute_single_accuracy(network_name, method_name): + """Compute accuracy for a single (network, method) pair using official scripts.""" net = NETWORKS[network_name] - gt_com = load_communities(net["com_gt_tsv"]) - est_path = os.path.join(RESULTS_DIR, network_name, method_name, "com.tsv") if not os.path.exists(est_path): print(f" WARNING: {est_path} not found, skipping") return None - est_com = load_communities(est_path) - gt_labels, est_labels = align_labels(gt_com, est_com, net["edge_tsv"]) + out_prefix = os.path.join(RESULTS_DIR, "accuracy", f"{network_name}_{method_name}") - ami = adjusted_mutual_info_score(gt_labels, est_labels, average_method="arithmetic") - ari = adjusted_rand_score(gt_labels, est_labels) - nmi = normalized_mutual_info_score(gt_labels, est_labels, average_method="arithmetic") + clustering_accuracy( + input_edgelist=net["edge_tsv"], + groundtruth_clustering=net["com_gt_tsv"], + estimated_clustering=est_path, + output_prefix=out_prefix, + num_processors=1, + local=False, + overwrite=True, + ) - return {"ami": ami, "ari": ari, "nmi": nmi} + # Read back the results + result = {} + for metric in ["ami", "ari", "nmi", "node_coverage"]: + fpath = f"{out_prefix}.{metric}" + if os.path.exists(fpath): + with open(fpath) as f: + result[metric] = float(f.read().strip()) + return result def compute_all_accuracy(): """Compute accuracy for all (network, method) pairs and save CSV.""" + out_dir = os.path.join(RESULTS_DIR, "accuracy") + os.makedirs(out_dir, exist_ok=True) + rows = [] for net_name in NETWORKS: for method in METHODS: m_name = method["name"] print(f"Computing accuracy: {net_name} / {m_name}") - result = compute_accuracy(net_name, m_name) + result = compute_single_accuracy(net_name, m_name) if result is not None: rows.append({ "network": net_name, @@ -79,8 +67,6 @@ def compute_all_accuracy(): }) df = pd.DataFrame(rows) - out_dir = os.path.join(RESULTS_DIR, "accuracy") - os.makedirs(out_dir, exist_ok=True) out_path = os.path.join(out_dir, "accuracy_table.csv") df.to_csv(out_path, index=False) print(f"\nAccuracy table saved to {out_path}") |
