summaryrefslogtreecommitdiff
path: root/diag/datasets.py
blob: cbd236ed3a9d5fae6e85697be6ad9c798d716a19 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""Synthetic graph datasets for the 1-WL diagnosis."""
import numpy as np
import networkx as nx


def _nx_to_edge_index(G):
    """Convert a networkx graph to ``(num_nodes, edge_index)``.

    Nodes are first renumbered to consecutive integers starting at 0 via
    ``nx.convert_node_labels_to_integers``. Every edge ``(u, v)`` reported by
    the graph is then emitted twice, once as ``u -> v`` and once as ``v -> u``.

    Returns:
        A tuple ``(n, ei)`` where ``n`` is the node count and ``ei`` is an
        int64 array of shape ``(2, 2 * E)``; the first ``E`` columns hold the
        edges as listed by the graph, the last ``E`` columns the reversed
        copies. A graph without edges yields an empty ``(2, 0)`` array.
    """
    relabeled = nx.convert_node_labels_to_integers(G)
    num_nodes = relabeled.number_of_nodes()
    edge_list = list(relabeled.edges())
    if not edge_list:
        return num_nodes, np.zeros((2, 0), dtype=np.int64)
    forward = np.array(edge_list, dtype=np.int64).T
    backward = forward[::-1]   # swap source/target rows: undirected -> both directions
    return num_nodes, np.concatenate((forward, backward), axis=1)


def circulant(N, offsets):
    """Return the circulant graph on ``N`` nodes for the given offsets.

    Nodes are ``0 .. N-1`` and node ``i`` is joined to ``(i + s) % N`` for
    every ``s`` in ``offsets``.
    """
    graph = nx.Graph()
    graph.add_nodes_from(range(N))
    graph.add_edges_from(
        (node, (node + step) % N)
        for node in range(N)
        for step in offsets
    )
    return graph


# Skip lengths of the Circular Skip Links benchmark: 10 classes, N=41
# (Murphy et al. 2019). A skip's position in this list is its class label.
CSL_SKIPS = [
    2, 3, 4, 5, 6,
    9, 11, 12, 13, 16,
]


def build_csl(n_per_class=15, N=41, seed=0):
    """Build the Circular Skip Links (CSL) classification dataset.

    For each skip ``s`` in ``CSL_SKIPS`` the base graph is the circulant on
    ``N`` nodes with offsets ``[1, s]``; the class label is the index of ``s``
    in ``CSL_SKIPS``. With the default ``N=41`` all graphs are 4-regular, so
    1-WL collapses to one color (pure H2 anchor).

    Every sample is a randomly node-permuted copy of its class's base graph.
    The permutation is applied directly to the edge index. Relabeling the
    networkx graph and converting it afterwards would cancel the permutation:
    ``nx.convert_node_labels_to_integers`` numbers nodes in iteration order,
    and after ``nx.relabel_nodes`` that order is exactly the permuted one.

    Args:
        n_per_class: Number of permuted copies generated per skip value.
        N: Number of nodes in every graph.
        seed: Seed for the NumPy random generator that draws the permutations.

    Returns:
        A list of dicts with keys ``'n'`` (node count), ``'edge_index'``
        (int64 array of shape ``(2, 2 * E)``, both edge directions) and
        ``'y'`` (integer class label).
    """
    rng = np.random.default_rng(seed)
    data = []
    for cls, s in enumerate(CSL_SKIPS):
        # The base graph depends only on the class, so convert it once.
        n, base_ei = _nx_to_edge_index(circulant(N, [1, s]))
        for _ in range(n_per_class):
            perm = rng.permutation(N).astype(np.int64)
            # Old node i becomes perm[i]; indexing maps both endpoints of
            # every directed edge consistently and yields a fresh array.
            data.append({'n': n, 'edge_index': perm[base_ei], 'y': cls})
    return data


def build_triangle_count(n_graphs=600, n_nodes=20, kind='regular', deg=3, p=0.2, seed=0):
    """Build a graph-level triangle-count regression dataset.

    1-WL cannot count triangles -> measurable H2 floor.

    Args:
        n_graphs: Number of graphs requested.
        n_nodes: Number of nodes per graph.
        kind: ``'regular'`` draws random ``deg``-regular graphs; any other
            value draws G(n, p) random graphs with edge probability ``p``.
        deg: Node degree, used only when ``kind == 'regular'``.
        p: Edge probability, used only when ``kind != 'regular'``.
        seed: Seed for the NumPy generator that derives per-graph seeds.

    Returns:
        A list of dicts with keys ``'n'`` (node count), ``'edge_index'``
        (int64 array, both edge directions) and ``'y'`` (triangle count as a
        float). The list can be shorter than ``n_graphs`` if generation keeps
        failing; a ``RuntimeWarning`` is issued in that case.
    """
    import warnings

    rng = np.random.default_rng(seed)
    data, tries = [], 0
    max_tries = n_graphs * 30
    while len(data) < n_graphs and tries < max_tries:
        tries += 1
        sd = int(rng.integers(1 << 30))
        try:
            if kind == 'regular':
                G = nx.random_regular_graph(deg, n_nodes, seed=sd)
            else:
                G = nx.gnp_random_graph(n_nodes, p, seed=sd)
        except nx.NetworkXError:
            # networkx rejects infeasible parameters (e.g. odd n_nodes * deg);
            # anything else is a genuine bug and must not be swallowed.
            continue
        # nx.triangles reports per-node counts, so each triangle is seen 3 times.
        tri = sum(nx.triangles(G).values()) // 3
        n, ei = _nx_to_edge_index(G)
        data.append({'n': n, 'edge_index': ei, 'y': float(tri)})
    if len(data) < n_graphs:
        warnings.warn(
            f"build_triangle_count produced {len(data)} of {n_graphs} requested "
            f"graphs after {tries} attempts (kind={kind!r}, n_nodes={n_nodes}, "
            f"deg={deg}, p={p})",
            RuntimeWarning,
            stacklevel=2,
        )
    return data


def canonical_pairs():
    """Return small reference graphs for the instrument self-test.

    The graphs (known 1-WL outcomes) are the 6-cycle ``'C6'``, two disjoint
    triangles ``'2C3'``, the 4-node path ``'P4'`` and the star ``'K1,3'``.

    Returns:
        A dict mapping each name to ``{'n', 'edge_index', 'tri'}``, where
        ``'tri'`` is the number of triangles in the graph.
    """
    graphs = {
        'C6': nx.cycle_graph(6),
        '2C3': nx.disjoint_union(nx.cycle_graph(3), nx.cycle_graph(3)),
        'P4': nx.path_graph(4),
        'K1,3': nx.star_graph(3),
    }
    result = {}
    for label, graph in graphs.items():
        num_nodes, edge_index = _nx_to_edge_index(graph)
        # Per-node triangle counts see every triangle three times.
        triangle_total = sum(nx.triangles(graph).values()) // 3
        result[label] = {'n': num_nodes, 'edge_index': edge_index, 'tri': triangle_total}
    return result