summaryrefslogtreecommitdiff
path: root/scripts/run_graphtool_sbm.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/run_graphtool_sbm.py')
-rw-r--r--scripts/run_graphtool_sbm.py47
1 files changed, 47 insertions, 0 deletions
diff --git a/scripts/run_graphtool_sbm.py b/scripts/run_graphtool_sbm.py
new file mode 100644
index 0000000..f860e89
--- /dev/null
+++ b/scripts/run_graphtool_sbm.py
@@ -0,0 +1,47 @@
+"""Run graph-tool SBM inference for community detection."""
+
+import argparse
+import sys
+import os
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+import graph_tool.all as gt
+import numpy as np
+
+from config import NETWORKS, RESULTS_DIR
+from load_data import load_edge_list, build_graphtool_graph, save_communities
+
+
+def run_graphtool_sbm(network_name):
+ net = NETWORKS[network_name]
+ edge_df = load_edge_list(net["edge_tsv"])
+ g, name_to_idx, idx_to_name = build_graphtool_graph(edge_df)
+
+ print(f" Graph: {g.num_vertices()} nodes, {g.num_edges()} edges")
+
+ # Use minimize_blockmodel_dl for flat (non-nested) SBM
+ np.random.seed(42)
+ gt.seed_rng(42)
+ state = gt.minimize_blockmodel_dl(g)
+
+ # Extract block assignments
+ blocks = state.get_blocks()
+ n_blocks = len(set(blocks.a))
+ print(f" Found {n_blocks} blocks")
+
+ node2com = {}
+ for v in g.vertices():
+ node2com[idx_to_name[int(v)]] = str(blocks[v])
+
+ out_path = os.path.join(RESULTS_DIR, network_name, "graphtool_sbm", "com.tsv")
+ save_communities(node2com, out_path)
+ print(f" Saved to {out_path}")
+ return node2com
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--network", required=True)
+ args = parser.parse_args()
+ print(f"Running graph-tool SBM on {args.network}...")
+ run_graphtool_sbm(args.network)