summaryrefslogtreecommitdiff
path: root/diag/datasets.py
diff options
context:
space:
mode:
Diffstat (limited to 'diag/datasets.py')
-rw-r--r--diag/datasets.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/diag/datasets.py b/diag/datasets.py
new file mode 100644
index 0000000..cbd236e
--- /dev/null
+++ b/diag/datasets.py
@@ -0,0 +1,70 @@
+"""Synthetic graph datasets for the 1-WL diagnosis."""
+import numpy as np
+import networkx as nx
+
+
def _nx_to_edge_index(G):
    """Convert a networkx graph into ``(num_nodes, edge_index)``.

    Nodes are relabelled to consecutive integers with networkx's default
    ordering, i.e. ids follow the graph's node *iteration order*, not the
    numeric value of the original labels.

    Returns:
        A tuple ``(n, ei)`` where ``n`` is the node count and ``ei`` is an
        int64 array of shape ``(2, 2 * E)``: every edge listed by
        ``G.edges()`` appears once as ``(u, v)`` and once as ``(v, u)``.
        A graph without edges yields an empty ``(2, 0)`` array.
    """
    relabeled = nx.convert_node_labels_to_integers(G)
    num_nodes = relabeled.number_of_nodes()
    if relabeled.number_of_edges() > 0:
        # Rows are (source, target); flipping the rows swaps the endpoints.
        forward = np.array(list(relabeled.edges()), dtype=np.int64).T
        backward = forward[::-1]
        edge_index = np.concatenate([forward, backward], axis=1)
    else:
        edge_index = np.zeros((2, 0), dtype=np.int64)
    return num_nodes, edge_index
+
+
def circulant(N, offsets):
    """Return the undirected circulant graph on nodes ``0..N-1``.

    Node ``i`` is joined to ``(i + s) % N`` for every ``s`` in ``offsets``.
    All ``N`` nodes are present even when ``offsets`` is empty.
    """
    graph = nx.Graph()
    graph.add_nodes_from(range(N))
    graph.add_edges_from(
        (node, (node + step) % N) for node in range(N) for step in offsets
    )
    return graph
+
+
# Skip lengths of the Circular Skip Links benchmark (Murphy et al. 2019):
# one class per entry (10 classes), intended for graphs with N=41 nodes.
CSL_SKIPS = [
    2, 3, 4, 5, 6,
    9, 11, 12, 13, 16,
]
+
+
def build_csl(n_per_class=15, N=41, seed=0):
    """Build the Circular Skip Links (CSL) graph-classification dataset.

    For each skip length in ``CSL_SKIPS`` (the class label is its index),
    ``n_per_class`` copies of the circulant graph with offsets ``[1, skip]``
    are produced, each under its own random node permutation. With the
    default ``N=41`` all graphs are 4-regular, so 1-WL collapses to one
    color (pure H2 anchor).

    Args:
        n_per_class: Number of permuted copies generated per skip length.
        N: Number of nodes in every graph.
        seed: Seed for the NumPy generator that draws the permutations.

    Returns:
        A list of dicts with keys ``'n'`` (node count), ``'edge_index'``
        (int64 array of shape ``(2, 2 * E)``) and ``'y'`` (class index).
    """
    rng = np.random.default_rng(seed)
    data = []
    for cls, s in enumerate(CSL_SKIPS):
        # The unpermuted graph is the same for the whole class; build it once.
        base_edges = list(circulant(N, [1, s]).edges())
        for _ in range(n_per_class):
            perm = rng.permutation(N)
            # Insert nodes in sorted order before adding the permuted edges.
            # _nx_to_edge_index renumbers nodes by iteration order, so a graph
            # produced by nx.relabel_nodes (whose nodes iterate as perm[0],
            # perm[1], ...) would be mapped straight back to the unpermuted
            # labelling and every sample of a class would be identical.
            G = nx.Graph()
            G.add_nodes_from(range(N))
            G.add_edges_from((int(perm[u]), int(perm[v])) for u, v in base_edges)
            n, ei = _nx_to_edge_index(G)
            data.append({'n': n, 'edge_index': ei, 'y': cls})
    return data
+
+
def build_triangle_count(n_graphs=600, n_nodes=20, kind='regular', deg=3, p=0.2, seed=0):
    """Build a graph-level triangle-count regression dataset.

    1-WL cannot count triangles, so this task gives a measurable H2 floor.

    Args:
        n_graphs: Number of graphs requested.
        n_nodes: Number of nodes per graph.
        kind: ``'regular'`` draws random ``deg``-regular graphs; any other
            value draws G(n, p) random graphs.
        deg: Degree used when ``kind == 'regular'``.
        p: Edge probability used for G(n, p) graphs.
        seed: Seed for the NumPy generator that derives per-graph seeds.

    Returns:
        A list of dicts with keys ``'n'`` (node count), ``'edge_index'``
        (int64 array of shape ``(2, 2 * E)``) and ``'y'`` (triangle count as
        a float). The list may hold fewer than ``n_graphs`` entries when
        generation keeps failing; a ``RuntimeWarning`` is issued in that case.
    """
    import warnings  # local import: only needed for the shortfall notice

    rng = np.random.default_rng(seed)
    data, tries = [], 0
    max_tries = n_graphs * 30  # cap retries so invalid parameters cannot loop forever
    while len(data) < n_graphs and tries < max_tries:
        tries += 1
        sd = int(rng.integers(1 << 30))
        try:
            if kind == 'regular':
                G = nx.random_regular_graph(deg, n_nodes, seed=sd)
            else:
                G = nx.gnp_random_graph(n_nodes, p, seed=sd)
        except nx.NetworkXError:
            # Generator rejected the parameters (e.g. n_nodes * deg is odd);
            # unrelated errors such as TypeError are allowed to propagate.
            continue
        # Each triangle is counted once at each of its three corners.
        tri = sum(nx.triangles(G).values()) // 3
        n, ei = _nx_to_edge_index(G)
        data.append({'n': n, 'edge_index': ei, 'y': float(tri)})
    if len(data) < n_graphs:
        warnings.warn(
            f"build_triangle_count produced {len(data)} of {n_graphs} graphs "
            f"after {tries} attempts (kind={kind!r}, n_nodes={n_nodes}, deg={deg}, p={p})",
            RuntimeWarning,
            stacklevel=2,
        )
    return data
+
+
def canonical_pairs():
    """Return small named graphs with known 1-WL outcomes for self-testing.

    The result maps each name ('C6', '2C3', 'P4', 'K1,3') to a dict with the
    node count ``'n'``, the bidirectional ``'edge_index'`` array and the
    number of triangles ``'tri'``.
    """
    named_graphs = {
        'C6': nx.cycle_graph(6),
        '2C3': nx.disjoint_union(nx.cycle_graph(3), nx.cycle_graph(3)),
        'P4': nx.path_graph(4),
        'K1,3': nx.star_graph(3),
    }
    result = {}
    for label, graph in named_graphs.items():
        num_nodes, edge_index = _nx_to_edge_index(graph)
        # nx.triangles counts per node, so each triangle shows up three times.
        triangle_total = sum(nx.triangles(graph).values()) // 3
        result[label] = {'n': num_nodes, 'edge_index': edge_index, 'tri': triangle_total}
    return result