diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-24 08:40:49 +0000 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-24 08:40:49 +0000 |
| commit | 8f63cf9f41bbdb8d55cd4679872d2b4ae2129324 (patch) | |
| tree | ab5c95888849e854f2346db856c7edece7c8b8a7 /scripts/run_all.py | |
EC-SBM community detection analysis: full pipeline and writeup
Implement community detection on 3 EC-SBM networks (polblogs, topology,
internet_as) using 5 methods (Leiden-Mod, Leiden-CPM at 0.1 and 0.01,
Infomap, graph-tool SBM). Compute AMI/ARI/NMI accuracy, cluster statistics,
and generate figures and LaTeX report.
Diffstat (limited to 'scripts/run_all.py')
| -rw-r--r-- | scripts/run_all.py | 67 |
1 files changed, 67 insertions, 0 deletions
diff --git a/scripts/run_all.py b/scripts/run_all.py new file mode 100644 index 0000000..c98b9a2 --- /dev/null +++ b/scripts/run_all.py @@ -0,0 +1,67 @@ +"""Master orchestration script: run all methods, compute accuracy and stats, generate plots.""" + +import sys +import os +import time +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from config import NETWORKS, METHODS + + +def main(): + start = time.time() + + print("=" * 60) + print("EC-SBM Community Detection Analysis Pipeline") + print("=" * 60) + + # Step 1: Run community detection methods + for net_name in NETWORKS: + for method in METHODS: + m_name = method["name"] + m_type = method["type"] + + print(f"\n{'='*60}") + print(f"Running {m_name} on {net_name}") + print(f"{'='*60}") + + if m_type == "leiden": + from run_leiden import run_leiden + run_leiden(net_name, m_name, method["quality"], + method.get("resolution")) + elif m_type == "infomap": + from run_infomap import run_infomap + run_infomap(net_name) + elif m_type == "graphtool_sbm": + from run_graphtool_sbm import run_graphtool_sbm + run_graphtool_sbm(net_name) + + # Step 2: Compute accuracy + print(f"\n{'='*60}") + print("Computing accuracy metrics") + print(f"{'='*60}") + from compute_accuracy import compute_all_accuracy + compute_all_accuracy() + + # Step 3: Compute cluster stats + print(f"\n{'='*60}") + print("Computing cluster statistics") + print(f"{'='*60}") + from compute_stats import compute_all_stats + compute_all_stats() + + # Step 4: Generate plots and tables + print(f"\n{'='*60}") + print("Generating plots and LaTeX tables") + print(f"{'='*60}") + from generate_plots import generate_all + generate_all() + + elapsed = time.time() - start + print(f"\n{'='*60}") + print(f"Pipeline complete in {elapsed:.1f}s") + print(f"{'='*60}") + + +if __name__ == "__main__": + main() |
