#!/usr/bin/env python3
"""H4: Forward-Forward with a virtual-node variant (FF+VN; Hinton 2022 + graph adaptation).

Each layer is trained locally to discriminate positive from negative samples
via a "goodness" function (squared activations).

For graph data with a virtual node (VN):
  - Positive sample: augment the graph with a VN connected to all real nodes.
    The VN feature encodes the CORRECT class label (one-hot); each training
    step this is the label of a randomly drawn training node.
  - Negative sample: the same augmentation, but the VN feature encodes a WRONG
    (randomly chosen) class label.
  - Goodness at layer l: g_l = mean over nodes i of sum_j H_l[i, j]^2.
  - Local loss: softplus(θ - g_pos) + softplus(g_neg - θ), i.e. positive
    goodness is pushed above the threshold θ and negative goodness below it.
  - Each layer is trained only by its own local loss: during training the input
    of every layer is detached, so no gradient reaches earlier layers.

Inference: for every candidate label c, put c on the VN, run the network and
compute the final-layer goodness at each real node; each node is predicted as
the label with the highest goodness.

Runs on Cora/CiteSeer/PubMed/DBLP × 20 seeds, GCN L=6.
"""
import torch
import torch.nn.functional as F
import numpy as np
import json
import os
from src.data import load_dataset, spmm
from run_dblp_depth import load_dblp

device = 'cuda:0'
SEEDS = list(range(20))
EPOCHS = 200
OUT_DIR = 'results/ff_baseline_20seeds'


class FFTrainer:
    """FF+VN for GCN: the virtual node carries a label; each layer is a local goodness discriminator."""

    def __init__(self, data, hidden_dim, lr, weight_decay, num_layers=2,
                 residual_alpha=0.0, backbone='gcn', ff_threshold=2.0, **_kw):
        """Create the layer weights, the Adam optimizer and the VN-augmented adjacency.

        Args:
            data: dict; this class reads 'X' (features, one row per real node),
                'A_hat' (N x N adjacency, dense or sparse), 'y', 'num_features',
                'num_classes', 'num_nodes', 'train_mask' and the masks passed
                to evaluate().
            hidden_dim: output width of every layer.
            lr: Adam learning rate.
            weight_decay: Adam weight decay.
            num_layers: number of GCN layers.
            residual_alpha: accepted for interface compatibility; not used.
            backbone: stored only; the layer is always the GCN-style A @ H @ W.
            ff_threshold: goodness threshold θ of the FF loss.
            **_kw: extra options, ignored.
        """
        dev = data['X'].device
        self.data = data
        self.device = dev
        self.lr = lr
        self.wd = weight_decay
        self.num_layers = num_layers
        self.backbone = backbone
        self.theta = ff_threshold

        d_in_orig = data['num_features']
        d_out = data['num_classes']
        # Input = original features + a one-hot label slot (non-zero only on the VN).
        self.d_in = d_in_orig + d_out
        self.d_out = d_out
        self.N_orig = data['num_nodes']

        dims = [self.d_in] + [hidden_dim] * num_layers
        self.weights = []
        for i in range(num_layers):
            w = torch.empty(dims[i], dims[i + 1], device=dev)
            torch.nn.init.xavier_uniform_(w)
            w.requires_grad_(True)
            self.weights.append(w)
        self.optim = torch.optim.Adam(self.weights, lr=lr, weight_decay=weight_decay)

        # The augmented adjacency does not depend on the VN label: build it once.
        self.A_hat_aug = self._build_vn_adj()

    def _build_vn_adj(self):
        """Return a dense (N+1, N+1) adjacency: A_hat plus a virtual node at index N.

        The VN is connected (weight 1) to every real node and has a self-loop.
        The whole matrix is then re-normalized symmetrically as D^-1/2 A D^-1/2,
        with D the row sums. A_hat is taken as given; the original author notes
        it is already self-looped and normalized, so this is a second
        normalization applied on top of it.
        """
        N = self.N_orig
        A = self.data['A_hat']
        # Dense for simplicity: memory is O(N^2).
        A_dense = A.to_dense() if A.is_sparse else A

        A_big = torch.zeros(N + 1, N + 1, device=A.device)
        A_big[:N, :N] = A_dense
        A_big[N, :] = 1.0    # VN -> every real node, plus the VN self-loop at [N, N]
        A_big[:N, N] = 1.0   # symmetric: every real node -> VN

        deg = A_big.sum(dim=1)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0  # isolated rows stay zero
        # Row/column scaling by broadcasting (in place): same result as
        # diag(d) @ A @ diag(d), but O(N^2) work and no extra N x N matrices,
        # instead of two O(N^3) dense matmuls.
        A_big.mul_(deg_inv_sqrt.unsqueeze(1))
        A_big.mul_(deg_inv_sqrt.unsqueeze(0))
        return A_big

    def _make_input(self, label_vec):
        """Build the augmented input X of shape (N+1, d_in_orig + d_out).

        Real nodes (rows 0..N-1) keep their features and have zeros in the label
        slots; the VN (row N) has zero features and `label_vec` (length d_out)
        in the label slots.
        """
        X_orig = self.data['X']
        N = self.N_orig
        X_real = torch.cat([X_orig, torch.zeros(N, self.d_out, device=self.device)], dim=1)
        zeros_feat = torch.zeros(1, X_orig.shape[1], device=self.device)
        X_vn = torch.cat([zeros_feat, label_vec.unsqueeze(0)], dim=1)
        return torch.cat([X_real, X_vn], dim=0)

    def _forward_layer(self, H, l):
        """One GCN layer on the augmented graph: A_hat_aug @ (H @ W_l)."""
        return self.A_hat_aug @ (H @ self.weights[l])

    def _forward_all(self, X_aug, detach_inputs=False):
        """Run all layers and return [X_aug, H_1, ..., H_L].

        ReLU follows every layer except the last, which is linear. With
        detach_inputs=True, each layer's input is detached, so a loss on H_l
        only produces gradients for self.weights[l - 1] (FF locality).
        """
        H = X_aug
        Hs = [H]
        for l in range(self.num_layers):
            H_in = H.detach() if detach_inputs else H
            Z = self._forward_layer(H_in, l)
            H = F.relu(Z) if l < self.num_layers - 1 else Z
            Hs.append(H)
        return Hs

    def _goodness(self, H):
        """Goodness: per-row sum of squared activations, averaged over rows (Hinton 2022)."""
        return (H ** 2).sum(dim=1).mean()

    def train_step(self):
        """Run one FF update with one positive and one negative VN label.

        The positive label is the label of a random training node; the negative
        label is a uniformly chosen different class. Each layer's FF loss only
        updates that layer's weights (inputs are detached).

        Returns:
            (total_loss, 0.0, {}) where total_loss is the sum of per-layer FF
            losses; (0.0, 0.0, {}) when there are no labeled nodes or fewer
            than two classes (no wrong label exists for a negative sample).
        """
        y = self.data['y']
        mask = self.data['train_mask']
        train_labels = y[mask]
        labeled_node_count = train_labels.numel()
        if labeled_node_count == 0 or self.d_out < 2:
            return 0.0, 0.0, {}

        # Positive: the label of a random training node.
        pick = torch.randint(0, labeled_node_count, (1,), device=self.device)
        pos_label_idx = train_labels[pick].item()
        pos_label = F.one_hot(torch.tensor(pos_label_idx, device=self.device), self.d_out).float()

        # Negative: any other class.
        wrong_classes = [c for c in range(self.d_out) if c != pos_label_idx]
        neg_label_idx = wrong_classes[torch.randint(0, len(wrong_classes), (1,)).item()]
        neg_label = F.one_hot(torch.tensor(neg_label_idx, device=self.device), self.d_out).float()

        X_pos = self._make_input(pos_label)
        X_neg = self._make_input(neg_label)

        self.optim.zero_grad()
        Hs_pos = self._forward_all(X_pos, detach_inputs=True)
        Hs_neg = self._forward_all(X_neg, detach_inputs=True)

        # Since every layer input is detached, one backward over the summed
        # losses gives each layer exactly the gradient of its own local loss.
        total_loss = sum(
            F.softplus(self.theta - self._goodness(H_pos))
            + F.softplus(self._goodness(H_neg) - self.theta)
            for H_pos, H_neg in zip(Hs_pos[1:], Hs_neg[1:])  # skip the input
        )
        total_loss.backward()
        self.optim.step()
        return total_loss.item(), 0.0, {}

    @torch.no_grad()
    def predict(self):
        """Return a (N,) tensor of predicted classes for all real nodes.

        For each candidate class c, the VN carries one-hot c; a node's score for
        c is the sum of its squared final-layer activations, and the
        prediction is the argmax over classes.
        """
        goodness_per_class = []
        for c in range(self.d_out):
            lbl = F.one_hot(torch.tensor(c, device=self.device), self.d_out).float()
            H_final = self._forward_all(self._make_input(lbl))[-1][:self.N_orig]  # drop VN row
            goodness_per_class.append((H_final ** 2).sum(dim=1))
        return torch.stack(goodness_per_class, dim=1).argmax(dim=1)  # (N, C) -> (N,)

    @torch.no_grad()
    def evaluate(self, mask_name='test_mask', preds=None):
        """Return the accuracy on data[mask_name].

        Args:
            mask_name: key of the node mask in self.data.
            preds: optional output of predict(); pass it to reuse one set of
                forward passes across several masks.
        """
        if preds is None:
            preds = self.predict()
        mask = self.data[mask_name]
        y = self.data['y']
        return (preds[mask] == y[mask]).float().mean().item()


def train_one(seed, data):
    """Train one FF+VN model for EPOCHS steps and return its test accuracy.

    Validation and test accuracy are checked every 5 epochs; the returned test
    accuracy is the one at the best validation accuracy (0 if validation
    accuracy never exceeds 0).
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    t = FFTrainer(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4, num_layers=6,
                  residual_alpha=0.0, backbone='gcn')
    bv, bt = 0, 0
    for ep in range(EPOCHS):
        t.train_step()
        if ep % 5 == 0:
            preds = t.predict()  # one set of forward passes for both masks
            v = t.evaluate('val_mask', preds=preds)
            te = t.evaluate('test_mask', preds=preds)
            if v > bv:
                bv, bt = v, te
    del t
    torch.cuda.empty_cache()
    return bt


def main():
    """Run all datasets × SEEDS, caching per-seed results in OUT_DIR.

    Completed seeds are stored in per_seed_data.json after each run and are
    skipped on reruns. Failed seeds are reported but not stored, so they are
    retried next time and do not count as 0% in the summary.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    per_seed_file = os.path.join(OUT_DIR, 'per_seed_data.json')
    if os.path.exists(per_seed_file):
        with open(per_seed_file) as f:
            per_seed_data = json.load(f)
    else:
        per_seed_data = {}

    datasets_cfg = {
        'Cora': lambda: load_dataset('Cora', device=device),
        'CiteSeer': lambda: load_dataset('CiteSeer', device=device),
        'PubMed': lambda: load_dataset('PubMed', device=device),
        'DBLP': lambda: load_dblp(),
    }

    for ds_name, loader in datasets_cfg.items():
        data = loader()
        key = f"{ds_name}_FF+VN"
        seed_results = per_seed_data.setdefault(key, {})
        print(f"\n=== {key} (20 seeds, GCN L=6) ===", flush=True)
        for seed in SEEDS:
            sk = str(seed)
            if sk in seed_results:
                print(f" seed {seed}: cached ({seed_results[sk]*100:.1f}%)", flush=True)
                continue
            try:
                acc = train_one(seed, data)
            except Exception as e:
                # Top-level per-seed boundary: report and move on. The failure is
                # not stored, so it neither biases the summary nor gets
                # treated as "cached" on the next run.
                print(f" seed {seed}: FAILED - {e}", flush=True)
                continue
            seed_results[sk] = acc
            print(f" seed {seed}: {acc*100:.1f}%", flush=True)
            with open(per_seed_file, 'w') as f:
                json.dump(per_seed_data, f, indent=2)
        del data
        torch.cuda.empty_cache()

    # Summary over the seeds that actually completed.
    print(f"\n{'=' * 70}\nFF+VN summary (20 seeds, GCN L=6)\n{'=' * 70}")
    results = {}
    for ds in datasets_cfg:
        key = f"{ds}_FF+VN"
        seed_results = per_seed_data.get(key, {})
        done = [s for s in SEEDS if str(s) in seed_results]
        if not done:
            print(f" {ds:<12} no completed seeds")
            continue
        vals = np.array([seed_results[str(s)] for s in done]) * 100
        results[key] = {'mean': float(vals.mean()), 'std': float(vals.std()),
                        'per_seed': vals.tolist(), 'seeds': done}
        print(f" {ds:<12} {vals.mean():5.1f} ± {vals.std():4.1f}")
        if len(done) < len(SEEDS):
            print(f"   warning: only {len(done)}/{len(SEEDS)} seeds completed")
    with open(os.path.join(OUT_DIR, 'results.json'), 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved to {OUT_DIR}/results.json")


if __name__ == '__main__':
    main()