summaryrefslogtreecommitdiff
path: root/scripts/compute_accuracy.py
blob: 13cf50a1abc9a7ea2030482f43a530fc2e1f9969 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""Compute AMI, ARI, NMI for all (network, method) pairs.

Uses the official network_evaluation scripts from
https://github.com/illinois-or-research-analytics/network_evaluation
"""

import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Add network_evaluation to path
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.join(BASE_DIR, "network_evaluation", "commdet_acc"))

import pandas as pd
from compute_cd_accuracy import clustering_accuracy

from config import NETWORKS, METHODS, RESULTS_DIR


def compute_single_accuracy(network_name, method_name):
    """Compute accuracy for a single (network, method) pair using official scripts.

    Runs ``clustering_accuracy`` (from network_evaluation/commdet_acc) on the
    estimated clustering ``RESULTS_DIR/<network>/<method>/com.tsv`` against
    the network's ground-truth clustering. It then reads back the per-metric
    files written under the prefix
    ``RESULTS_DIR/accuracy/<network>_<method>``.

    Args:
        network_name: Key into ``NETWORKS``; the entry is read for its
            ``"edge_tsv"`` and ``"com_gt_tsv"`` paths.
        method_name: Name of the method's subdirectory under the network's
            results directory.

    Returns:
        ``None`` if the estimated clustering file does not exist. Otherwise,
        a dict mapping each of ``"ami"``, ``"ari"``, ``"nmi"`` and
        ``"node_coverage"`` to its float value. A metric whose output file is
        missing is omitted, and a warning is printed.

    Raises:
        KeyError: if ``network_name`` is not a key of ``NETWORKS``.
        ValueError: if a metric file's contents cannot be parsed as a float.
    """
    net = NETWORKS[network_name]
    est_path = os.path.join(RESULTS_DIR, network_name, method_name, "com.tsv")
    if not os.path.exists(est_path):
        print(f"  WARNING: {est_path} not found, skipping")
        return None

    out_dir = os.path.join(RESULTS_DIR, "accuracy")
    # compute_all_accuracy() creates this directory, but this function can
    # also be called on its own. Make sure the output location exists rather
    # than relying on clustering_accuracy to create it.
    os.makedirs(out_dir, exist_ok=True)
    out_prefix = os.path.join(out_dir, f"{network_name}_{method_name}")

    clustering_accuracy(
        input_edgelist=net["edge_tsv"],
        groundtruth_clustering=net["com_gt_tsv"],
        estimated_clustering=est_path,
        output_prefix=out_prefix,
        num_processors=1,
        local=False,
        overwrite=True,
    )

    # Read back the results: one file per metric, parsed as a single number.
    result = {}
    for metric in ("ami", "ari", "nmi", "node_coverage"):
        fpath = f"{out_prefix}.{metric}"
        try:
            with open(fpath, encoding="utf-8") as f:
                result[metric] = float(f.read().strip())
        except FileNotFoundError:
            # Don't drop the metric silently: it would show up as a blank
            # cell in the table with no indication of why.
            print(f"  WARNING: {fpath} not found, metric '{metric}' omitted")
    return result


def compute_all_accuracy():
    """Compute accuracy for all (network, method) pairs and save CSV.

    Loops over every network key in ``NETWORKS`` and every method in
    ``METHODS`` (each a dict read for its ``"name"``), calling
    ``compute_single_accuracy``. Pairs for which that returns ``None``
    (missing estimated clustering) are skipped. The collected rows are
    written to ``RESULTS_DIR/accuracy/accuracy_table.csv`` and printed.

    Returns:
        A pandas DataFrame whose leading columns are always ``network``,
        ``method``, ``ami``, ``ari``, ``nmi``, ``node_coverage`` (missing
        metrics are NaN). Any other keys returned by
        ``compute_single_accuracy`` follow as extra columns.
    """
    out_dir = os.path.join(RESULTS_DIR, "accuracy")
    os.makedirs(out_dir, exist_ok=True)

    rows = []
    for net_name in NETWORKS:
        for method in METHODS:
            m_name = method["name"]
            print(f"Computing accuracy: {net_name} / {m_name}")
            result = compute_single_accuracy(net_name, m_name)
            if result is not None:
                rows.append({"network": net_name, "method": m_name, **result})

    # Force a fixed leading column order. Without this, the order depends on
    # which metrics the first row happens to have. Forcing the columns also
    # guarantees a header row when no pair produced results; an empty
    # DataFrame would otherwise be written as a CSV that pd.read_csv rejects.
    base_cols = ["network", "method", "ami", "ari", "nmi", "node_coverage"]
    df = pd.DataFrame(rows)
    df = df.reindex(
        columns=base_cols + [c for c in df.columns if c not in base_cols]
    )

    out_path = os.path.join(out_dir, "accuracy_table.csv")
    df.to_csv(out_path, index=False)
    print(f"\nAccuracy table saved to {out_path}")
    print(df.to_string(index=False))
    return df


if __name__ == "__main__":
    # Run as a script: compute accuracy for every (network, method) pair and
    # write the combined table via compute_all_accuracy().
    compute_all_accuracy()