summaryrefslogtreecommitdiff
path: root/scripts/load_data.py
blob: 5a0362bfe2ee7d51c81192ef2d5ab64af8e9b8a6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""Shared data loading utilities for EC-SBM analysis."""

import pandas as pd
import numpy as np


def load_edge_list(path):
    """Load a tab-separated edge list into a two-column DataFrame.

    The file has no header row and two columns (node1, node2). Text after a
    ``#`` on a line is treated as a comment by the parser.

    Args:
        path: File path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        DataFrame with columns ``src`` and ``tgt`` holding the node
        identifiers as strings, one row per parsed line.
    """
    # Node identifiers are opaque strings. Without the two NA options below,
    # pandas converts its default NA sentinels ("NA", "null", "nan", ...) to
    # NaN even when dtype=str, silently corrupting nodes with such names.
    # Only a truly empty field is still reported as missing.
    return pd.read_csv(
        path,
        sep="\t",
        header=None,
        names=["src", "tgt"],
        dtype=str,
        comment="#",
        keep_default_na=False,
        na_values=[""],
    )


def load_communities(path):
    """Read a headerless, tab-separated node-to-community file.

    Blank lines and lines whose first non-whitespace character is ``#`` are
    ignored, as are lines with fewer than two tab-separated fields. Only the
    first two fields of a line are used; if a node appears more than once,
    the last assignment wins.

    Returns a dict mapping node (str) to community (str).
    """
    assignments = {}
    with open(path, "r") as handle:
        for raw in handle:
            record = raw.strip()
            # Skip empty lines and comment lines.
            if record == "" or record[0] == "#":
                continue
            fields = record.split("\t")
            # A line without a community column carries no assignment.
            if len(fields) < 2:
                continue
            node, community = fields[0], fields[1]
            assignments[node] = community
    return assignments


def build_igraph(edge_df):
    """Construct an undirected igraph Graph from an edge DataFrame.

    ``edge_df`` must provide ``src`` and ``tgt`` columns of node names.
    Each distinct name is assigned a consecutive integer vertex index.

    Returns (graph, name_to_idx, idx_to_name), where the two dicts map
    node name -> vertex index and vertex index -> node name.
    """
    import igraph as ig

    # Distinct node names across both endpoint columns, in order of first
    # appearance within the flattened array.
    endpoints = edge_df[["src", "tgt"]].values.ravel("K")
    all_nodes = pd.unique(endpoints)

    name_to_idx = dict(zip(all_nodes, range(len(all_nodes))))
    idx_to_name = {idx: label for label, idx in name_to_idx.items()}

    # Translate every edge endpoint from node name to vertex index.
    sources = edge_df["src"].map(name_to_idx).values
    targets = edge_df["tgt"].map(name_to_idx).values
    edge_pairs = list(zip(sources, targets))

    graph = ig.Graph(n=len(all_nodes), edges=edge_pairs, directed=False)
    graph.simplify()  # drop multi-edges and self-loops
    return graph, name_to_idx, idx_to_name


def build_graphtool_graph(edge_df):
    """Construct an undirected graph-tool Graph from an edge DataFrame.

    ``edge_df`` must provide ``src`` and ``tgt`` columns of node names.
    Each distinct name is assigned a consecutive integer vertex index.

    Returns (graph, name_to_idx, idx_to_name), where the two dicts map
    node name -> vertex index and vertex index -> node name.
    """
    import graph_tool.all as gt

    # Distinct node names across both endpoint columns, in order of first
    # appearance within the flattened array.
    endpoints = edge_df[["src", "tgt"]].values.ravel("K")
    all_nodes = pd.unique(endpoints)

    name_to_idx = dict(zip(all_nodes, range(len(all_nodes))))
    idx_to_name = {idx: label for label, idx in name_to_idx.items()}

    # Two-column int64 array of (source index, target index), one row per edge.
    edge_array = np.column_stack([
        edge_df[column].map(name_to_idx).values.astype(np.int64)
        for column in ("src", "tgt")
    ])

    graph = gt.Graph(directed=False)
    graph.add_vertex(len(all_nodes))
    graph.add_edge_list(edge_array)
    # Clean up duplicates first, then self-loops, as a separate pass each.
    gt.remove_parallel_edges(graph)
    gt.remove_self_loops(graph)
    return graph, name_to_idx, idx_to_name


def save_communities(node2com, path):
    """Save community assignments as a two-column TSV file.

    One line is written per node in the form ``node<TAB>community``, with
    lines ordered by sorted node key.

    Args:
        node2com: Mapping of node -> community.
        path: Destination file. Missing parent directories are created.
    """
    import os

    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when one is named.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    with open(path, "w") as f:
        f.writelines(
            f"{node}\t{node2com[node]}\n" for node in sorted(node2com)
        )