summaryrefslogtreecommitdiff
path: root/scripts/load_data.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/load_data.py')
-rw-r--r--scripts/load_data.py74
1 files changed, 74 insertions, 0 deletions
diff --git a/scripts/load_data.py b/scripts/load_data.py
new file mode 100644
index 0000000..5a0362b
--- /dev/null
+++ b/scripts/load_data.py
@@ -0,0 +1,74 @@
+"""Shared data loading utilities for EC-SBM analysis."""
+
+import pandas as pd
+import numpy as np
+
+
+def load_edge_list(path):
+ """Load a tab-separated edge list (no header, two columns: node1, node2)."""
+ df = pd.read_csv(path, sep="\t", header=None, names=["src", "tgt"],
+ dtype=str, comment="#")
+ return df
+
+
+def load_communities(path):
+ """Load a tab-separated community file (no header: node, community).
+ Returns dict {node_str: community_str}."""
+ node2com = {}
+ with open(path, "r") as f:
+ for line in f:
+ line = line.strip()
+ if not line or line.startswith("#"):
+ continue
+ parts = line.split("\t")
+ if len(parts) >= 2:
+ node2com[parts[0]] = parts[1]
+ return node2com
+
+
+def build_igraph(edge_df):
+ """Build an igraph Graph from an edge DataFrame.
+ Returns (graph, name_to_idx, idx_to_name)."""
+ import igraph as ig
+
+ all_nodes = pd.unique(edge_df[["src", "tgt"]].values.ravel("K"))
+ name_to_idx = {name: i for i, name in enumerate(all_nodes)}
+ idx_to_name = {i: name for name, i in name_to_idx.items()}
+
+ src_ids = edge_df["src"].map(name_to_idx).values
+ tgt_ids = edge_df["tgt"].map(name_to_idx).values
+
+ n = len(all_nodes)
+ g = ig.Graph(n=n, edges=list(zip(src_ids, tgt_ids)), directed=False)
+ g.simplify() # remove multi-edges and self-loops
+ return g, name_to_idx, idx_to_name
+
+
+def build_graphtool_graph(edge_df):
+ """Build a graph-tool Graph from an edge DataFrame.
+ Returns (graph, name_to_idx, idx_to_name)."""
+ import graph_tool.all as gt
+
+ all_nodes = pd.unique(edge_df[["src", "tgt"]].values.ravel("K"))
+ name_to_idx = {name: i for i, name in enumerate(all_nodes)}
+ idx_to_name = {i: name for name, i in name_to_idx.items()}
+
+ src_ids = edge_df["src"].map(name_to_idx).values.astype(np.int64)
+ tgt_ids = edge_df["tgt"].map(name_to_idx).values.astype(np.int64)
+
+ n = len(all_nodes)
+ g = gt.Graph(directed=False)
+ g.add_vertex(n)
+ g.add_edge_list(np.column_stack([src_ids, tgt_ids]))
+ gt.remove_parallel_edges(g)
+ gt.remove_self_loops(g)
+ return g, name_to_idx, idx_to_name
+
+
+def save_communities(node2com, path):
+ """Save community assignments as TSV (node\tcommunity)."""
+ import os
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ with open(path, "w") as f:
+ for node, com in sorted(node2com.items(), key=lambda x: x[0]):
+ f.write(f"{node}\t{com}\n")