"""Synthetic graph datasets for the 1-WL diagnosis.""" import numpy as np import networkx as nx def _nx_to_edge_index(G): G = nx.convert_node_labels_to_integers(G) n = G.number_of_nodes() if G.number_of_edges() == 0: return n, np.zeros((2, 0), dtype=np.int64) e = np.array(list(G.edges()), dtype=np.int64).T ei = np.concatenate([e, e[::-1]], axis=1) # undirected -> both directions return n, ei def circulant(N, offsets): G = nx.Graph() G.add_nodes_from(range(N)) for i in range(N): for s in offsets: G.add_edge(i, (i + s) % N) return G CSL_SKIPS = [2, 3, 4, 5, 6, 9, 11, 12, 13, 16] # 10 classes, N=41 (Murphy et al. 2019) def build_csl(n_per_class=15, N=41, seed=0): """Circular Skip Links: all graphs 4-regular -> 1-WL collapses to one color (pure H2 anchor).""" rng = np.random.default_rng(seed) data = [] for cls, s in enumerate(CSL_SKIPS): for _ in range(n_per_class): G = circulant(N, [1, s]) perm = rng.permutation(N) G = nx.relabel_nodes(G, {i: int(perm[i]) for i in range(N)}) n, ei = _nx_to_edge_index(G) data.append({'n': n, 'edge_index': ei, 'y': cls}) return data def build_triangle_count(n_graphs=600, n_nodes=20, kind='regular', deg=3, p=0.2, seed=0): """Graph-level triangle-count regression. 1-WL cannot count triangles -> measurable H2 floor.""" rng = np.random.default_rng(seed) data, tries = [], 0 while len(data) < n_graphs and tries < n_graphs * 30: tries += 1 sd = int(rng.integers(1 << 30)) try: G = (nx.random_regular_graph(deg, n_nodes, seed=sd) if kind == 'regular' else nx.gnp_random_graph(n_nodes, p, seed=sd)) except Exception: continue tri = sum(nx.triangles(G).values()) // 3 n, ei = _nx_to_edge_index(G) data.append({'n': n, 'edge_index': ei, 'y': float(tri)}) return data def canonical_pairs(): """Graphs for instrument self-test (known 1-WL outcomes).""" pairs = [('C6', nx.cycle_graph(6)), ('2C3', nx.disjoint_union(nx.cycle_graph(3), nx.cycle_graph(3))), ('P4', nx.path_graph(4)), ('K1,3', nx.star_graph(3))] out = {} for name, G in pairs: n, ei = _nx_to_edge_index(G) out[name] = {'n': n, 'edge_index': ei, 'tri': sum(nx.triangles(G).values()) // 3} return out