summary refs log tree commit diff
path: root/src/data.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/data.py')
-rw-r--r-- src/data.py 189
1 files changed, 189 insertions, 0 deletions
diff --git a/src/data.py b/src/data.py
new file mode 100644
index 0000000..6e80285
--- /dev/null
+++ b/src/data.py
@@ -0,0 +1,189 @@
+"""Data loading and preprocessing for Graph-GrAPE experiments."""
+
+import torch
+import torch.nn.functional as F
+from torch_geometric.datasets import Planetoid
+import torch_geometric.transforms as T
+
+
def spmm(A, B):
    """Return the product of sparse matrix ``A`` and dense matrix ``B``.

    Thin wrapper around ``torch.sparse.mm``.
    """
    product = torch.sparse.mm(A, B)
    return product
+
+
def build_normalized_adj(edge_index, num_nodes):
    """Return Â = D̃^{-1/2} Ã D̃^{-1/2} as a coalesced sparse COO tensor.

    Ã = A + I: one self-loop is appended for every node before the degrees
    are counted, so D̃ already includes the self-loop.

    Args:
        edge_index: Tensor that unpacks into two index rows (row, col).
        num_nodes: Number of nodes; the result is (num_nodes, num_nodes).

    NOTE(review): self-loops already present in ``edge_index`` are not
    deduplicated; the appended loop is added on top of them and
    ``coalesce()`` sums the duplicate entries.
    """
    dev = edge_index.device
    src, dst = edge_index

    # Ã = A + I: append (i, i) for every node.
    loop_idx = torch.arange(num_nodes, device=dev)
    src = torch.cat([src, loop_idx])
    dst = torch.cat([dst, loop_idx])

    # Degree = number of entries per row index (self-loop included).
    unit = torch.ones_like(src, dtype=torch.float)
    degree = torch.zeros(num_nodes, device=dev).scatter_add_(0, src, unit)

    # D̃^{-1/2}, with 1/sqrt(0) = inf mapped to 0.
    inv_sqrt = degree.pow(-0.5)
    inv_sqrt = inv_sqrt.masked_fill(inv_sqrt == float('inf'), 0.0)

    # Entry (i, j) is weighted by d_i^{-1/2} * d_j^{-1/2}.
    weights = inv_sqrt[src] * inv_sqrt[dst]

    coords = torch.stack([src, dst])
    return torch.sparse_coo_tensor(
        coords, weights, (num_nodes, num_nodes)
    ).coalesce()
+
+
def precompute_traces(A_hat, max_power=4, num_probes=100):
    """Precompute tr(Â^k) for k = 0..max_power.

    * k = 0: N (the matrix size), exact.
    * k = 1: sum of the stored diagonal entries, exact.
    * k = 2: sum of squared entries (||Â||_F^2). This equals tr(Â^2) only
      when Â is symmetric; for a non-symmetric Â it is tr(Â^T Â).
    * k >= 3: Hutchinson estimator tr(M) ≈ (1/K) Σ z^T M z with K standard
      normal probes, so these entries are stochastic and depend on the
      global torch RNG state.

    Entries for k = 1 and k = 2 are always present, even if ``max_power``
    is smaller than 2.

    Args:
        A_hat: Square sparse COO tensor. It is coalesced here, so an
            uncoalesced input is accepted and duplicates are summed first.
        max_power: Highest power k to compute.
        num_probes: Number of random probes K per estimated power. All
            probes are processed as one dense (N, K) batch.

    Returns:
        Dict mapping k to a 0-dim tensor on ``A_hat``'s device.

    Raises:
        ValueError: If an estimate is required and ``num_probes < 1``.
    """
    # indices()/values() are only defined on coalesced tensors, and the
    # Frobenius shortcut below needs duplicate entries merged first.
    A_hat = A_hat.coalesce()
    N = A_hat.size(0)
    device = A_hat.device
    indices = A_hat.indices()
    values = A_hat.values()

    traces = {0: torch.tensor(float(N), device=device)}

    # tr(Â) = sum of diagonal entries
    traces[1] = values[indices[0] == indices[1]].sum()

    # tr(Â^2) = ||Â||_F^2 = sum of squared entries (symmetric Â)
    traces[2] = (values ** 2).sum()

    if max_power >= 3 and num_probes < 1:
        raise ValueError(f"num_probes must be >= 1, got {num_probes}")

    # Higher powers: batched Hutchinson estimator. Fresh probes are drawn
    # for every power so the estimates stay independent of each other.
    for power in range(3, max_power + 1):
        Z = torch.randn(N, num_probes, device=device, dtype=values.dtype)
        AZ = Z
        for _ in range(power):
            AZ = torch.sparse.mm(A_hat, AZ)
        # Σ_k z_k^T Â^power z_k, averaged over the probes.
        traces[power] = ((Z * AZ).sum() / num_probes).detach()

    return traces
+
+
def subsample_train_mask(data, label_rate, seed=0):
    """Draw a class-balanced boolean train mask at the given label rate.

    Each class contributes ``max(1, int(label_rate * N / C))`` nodes (fewer
    if the class has fewer members), picked with a CPU generator seeded by
    ``seed``.

    Args:
        data: Mapping providing ``'y'`` (labels), ``'num_nodes'`` and
            ``'num_classes'``.
        label_rate: Target fraction of all nodes to label.
        seed: Seed for the sampling generator.

    Returns:
        Bool tensor of length ``num_nodes`` on the device of ``data['y']``.

    NOTE(review): candidates are all nodes of a class; nothing here excludes
    validation or test nodes — confirm callers handle any overlap.
    """
    labels = data['y']
    total = data['num_nodes']
    n_cls = data['num_classes']

    # Equal per-class quota, never below one node.
    quota = int(label_rate * total / n_cls)
    if quota < 1:
        quota = 1

    gen = torch.Generator()
    gen.manual_seed(seed)

    train_mask = torch.zeros(total, dtype=torch.bool, device=labels.device)
    for cls in range(n_cls):
        members = torch.nonzero(labels == cls, as_tuple=True)[0]
        order = torch.randperm(len(members), generator=gen)
        train_mask[members[order[:quota]]] = True

    return train_mask
+
+
def build_row_normalized_adj(edge_index, num_nodes):
    """Build D⁻¹Ã (row-normalized) and its transpose, as sparse tensors.

    Ã = A + I (one self-loop appended per node) and D is the per-row entry
    count of Ã, so every row of the first result sums to 1.

    The second result is the exact transpose (D⁻¹Ã)ᵀ = Ãᵀ D⁻¹. For a
    symmetric edge list this equals Ã D⁻¹; building it from swapped indices
    keeps it a true transpose for directed / asymmetric edge lists as well.

    Args:
        edge_index: Tensor that unpacks into two index rows (row, col).
        num_nodes: Number of nodes; results are (num_nodes, num_nodes).

    Returns:
        Tuple ``(A_row, A_row_T)`` of coalesced sparse COO tensors.
    """
    row, col = edge_index
    self_loops = torch.arange(num_nodes, device=edge_index.device)
    row = torch.cat([row, self_loops])
    col = torch.cat([col, self_loops])

    deg = torch.zeros(num_nodes, device=edge_index.device)
    deg.scatter_add_(0, row, torch.ones_like(row, dtype=torch.float))
    deg_inv = deg.pow(-1)
    # Defensive only: the appended self-loops give every node degree >= 1.
    deg_inv[deg_inv == float('inf')] = 0.0

    # D⁻¹Ã: each entry (i, j) is scaled by 1 / deg(i).
    vals = deg_inv[row]
    shape = (num_nodes, num_nodes)
    A_row = torch.sparse_coo_tensor(
        torch.stack([row, col]), vals, shape
    ).coalesce()

    # Transpose: the same weights placed at the mirrored position (j, i).
    A_row_T = torch.sparse_coo_tensor(
        torch.stack([col, row]), vals, shape
    ).coalesce()

    return A_row, A_row_T
+
+
def load_amazon(name, root='./data', device='cuda:0', train_ratio=0.1, val_ratio=0.1, seed=0):
    """Load an Amazon graph (Photo or Computers) with a seeded random split.

    Each class is shuffled with a CPU generator seeded by ``seed``; the
    first ``max(1, int(train_ratio * n))`` nodes go to train, the next
    ``max(1, int(val_ratio * n))`` to validation, and the remainder to test
    (``n`` is the class size). Graph operators and traces are built before
    being moved to ``device``.

    Returns:
        Dict with keys ``X``, ``y``, ``A_hat``, ``A_row``, ``A_row_T``,
        ``train_mask``, ``val_mask``, ``test_mask``, ``num_nodes``,
        ``num_features``, ``num_classes`` and ``traces``.
    """
    from torch_geometric.datasets import Amazon

    dataset = Amazon(root=root, name=name)
    graph = dataset[0]
    n_nodes = graph.num_nodes
    n_classes = dataset.num_classes

    gen = torch.Generator()
    gen.manual_seed(seed)

    splits = {
        key: torch.zeros(n_nodes, dtype=torch.bool)
        for key in ('train_mask', 'val_mask', 'test_mask')
    }
    for cls in range(n_classes):
        members = (graph.y == cls).nonzero(as_tuple=True)[0]
        size = len(members)
        shuffled = members[torch.randperm(size, generator=gen)]
        # Cut points: [0, train_end) train, [train_end, val_end) val, rest test.
        train_end = max(1, int(train_ratio * size))
        val_end = train_end + max(1, int(val_ratio * size))
        splits['train_mask'][shuffled[:train_end]] = True
        splits['val_mask'][shuffled[train_end:val_end]] = True
        splits['test_mask'][shuffled[val_end:]] = True

    edges = graph.edge_index
    sym_adj = build_normalized_adj(edges, n_nodes)
    row_adj, row_adj_t = build_row_normalized_adj(edges, n_nodes)
    trace_table = precompute_traces(sym_adj, max_power=4)

    tensors = {
        'X': graph.x,
        'y': graph.y,
        'A_hat': sym_adj,
        'A_row': row_adj,
        'A_row_T': row_adj_t,
        **splits,
    }
    result = {key: value.to(device) for key, value in tensors.items()}
    result['num_nodes'] = n_nodes
    result['num_features'] = graph.x.shape[1]
    result['num_classes'] = n_classes
    result['traces'] = {k: v.to(device) for k, v in trace_table.items()}
    return result
+
+
def load_dataset(name, root='./data', device='cuda:3'):
    """Load a Planetoid dataset and precompute its graph structures.

    Features are loaded with the ``NormalizeFeatures`` transform and the
    dataset's own train/val/test masks are used as-is. The normalized
    adjacencies and traces are built first and then moved to ``device``.

    Returns:
        Dict with keys ``X``, ``y``, ``A_hat``, ``A_row``, ``A_row_T``,
        ``train_mask``, ``val_mask``, ``test_mask``, ``num_nodes``,
        ``num_features``, ``num_classes`` and ``traces``.

    NOTE(review): the default device is ``'cuda:3'``; callers on machines
    without that GPU must pass ``device`` explicitly.
    """
    dataset = Planetoid(root=root, name=name, transform=T.NormalizeFeatures())
    graph = dataset[0]
    n_nodes = graph.num_nodes
    edges = graph.edge_index

    sym_adj = build_normalized_adj(edges, n_nodes)
    row_adj, row_adj_t = build_row_normalized_adj(edges, n_nodes)
    trace_table = precompute_traces(sym_adj, max_power=4)

    tensors = {
        'X': graph.x,
        'y': graph.y,
        'A_hat': sym_adj,
        'A_row': row_adj,
        'A_row_T': row_adj_t,
        'train_mask': graph.train_mask,
        'val_mask': graph.val_mask,
        'test_mask': graph.test_mask,
    }
    result = {key: value.to(device) for key, value in tensors.items()}
    result['num_nodes'] = graph.num_nodes
    result['num_features'] = dataset.num_features
    result['num_classes'] = dataset.num_classes
    result['traces'] = {k: v.to(device) for k, v in trace_table.items()}
    return result