diff options
Diffstat (limited to 'diag/datasets.py')
| -rw-r--r-- | diag/datasets.py | 70 |
1 files changed, 70 insertions, 0 deletions
diff --git a/diag/datasets.py b/diag/datasets.py new file mode 100644 index 0000000..cbd236e --- /dev/null +++ b/diag/datasets.py @@ -0,0 +1,70 @@ +"""Synthetic graph datasets for the 1-WL diagnosis.""" +import numpy as np +import networkx as nx + + +def _nx_to_edge_index(G): + G = nx.convert_node_labels_to_integers(G) + n = G.number_of_nodes() + if G.number_of_edges() == 0: + return n, np.zeros((2, 0), dtype=np.int64) + e = np.array(list(G.edges()), dtype=np.int64).T + ei = np.concatenate([e, e[::-1]], axis=1) # undirected -> both directions + return n, ei + + +def circulant(N, offsets): + G = nx.Graph() + G.add_nodes_from(range(N)) + for i in range(N): + for s in offsets: + G.add_edge(i, (i + s) % N) + return G + + +CSL_SKIPS = [2, 3, 4, 5, 6, 9, 11, 12, 13, 16] # 10 classes, N=41 (Murphy et al. 2019) + + +def build_csl(n_per_class=15, N=41, seed=0): + """Circular Skip Links: all graphs 4-regular -> 1-WL collapses to one color (pure H2 anchor).""" + rng = np.random.default_rng(seed) + data = [] + for cls, s in enumerate(CSL_SKIPS): + for _ in range(n_per_class): + G = circulant(N, [1, s]) + perm = rng.permutation(N) + G = nx.relabel_nodes(G, {i: int(perm[i]) for i in range(N)}) + n, ei = _nx_to_edge_index(G) + data.append({'n': n, 'edge_index': ei, 'y': cls}) + return data + + +def build_triangle_count(n_graphs=600, n_nodes=20, kind='regular', deg=3, p=0.2, seed=0): + """Graph-level triangle-count regression. 1-WL cannot count triangles -> measurable H2 floor.""" + rng = np.random.default_rng(seed) + data, tries = [], 0 + while len(data) < n_graphs and tries < n_graphs * 30: + tries += 1 + sd = int(rng.integers(1 << 30)) + try: + G = (nx.random_regular_graph(deg, n_nodes, seed=sd) if kind == 'regular' + else nx.gnp_random_graph(n_nodes, p, seed=sd)) + except Exception: + continue + tri = sum(nx.triangles(G).values()) // 3 + n, ei = _nx_to_edge_index(G) + data.append({'n': n, 'edge_index': ei, 'y': float(tri)}) + return data + + +def canonical_pairs(): + """Graphs for instrument self-test (known 1-WL outcomes).""" + pairs = [('C6', nx.cycle_graph(6)), + ('2C3', nx.disjoint_union(nx.cycle_graph(3), nx.cycle_graph(3))), + ('P4', nx.path_graph(4)), + ('K1,3', nx.star_graph(3))] + out = {} + for name, G in pairs: + n, ei = _nx_to_edge_index(G) + out[name] = {'n': n, 'edge_index': ei, 'tri': sum(nx.triangles(G).values()) // 3} + return out |
