#!/usr/bin/env python3
"""H3: CaFo+CE (Cascaded Forward Learning with Top-Down Feedback, Park et al. 2023).

Greedy layer-wise training for GCN L=6:
- Each hidden layer l has an auxiliary classifier W_aux_l: hidden → num_classes
- Forward through all layers with .detach() between layers (blocks upstream gradient)
- Per-layer CE loss on labeled nodes via auxiliary classifier
- Output layer uses standard cross-entropy
- No global backprop — each W_l only sees its local loss

Tests CaFo on Cora/CiteSeer/PubMed/DBLP × 20 seeds, GCN L=6.
"""
import torch
import torch.nn.functional as F
import numpy as np
import json
import os
from src.data import load_dataset, spmm
from run_dblp_depth import load_dblp

# FIX: fall back to CPU when CUDA is unavailable (was hard-coded 'cuda:0',
# which crashes on CPU-only machines). Unchanged on CUDA hosts.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
SEEDS = list(range(20))
EPOCHS = 200
OUT_DIR = 'results/cafo_baseline_20seeds'


class CaFoTrainer:
    """CaFo+CE: greedy layer-wise training with per-layer CE loss.

    Every hidden layer owns an auxiliary linear classifier. Gradient flow
    between layers is blocked with .detach(), so each weight matrix is
    trained only by its local cross-entropy loss (no global backprop).
    """

    def __init__(self, data, hidden_dim, lr, weight_decay, num_layers=2,
                 residual_alpha=0.0, backbone='gcn', **_kw):
        """
        Args:
            data: dict with 'X', 'y', 'A_hat', the train/val/test masks,
                'num_features' and 'num_classes'; tensors already live on
                the target device.
            hidden_dim: width of every hidden layer.
            lr: Adam learning rate (shared by all parameters).
            weight_decay: Adam weight decay (shared by all parameters).
            num_layers: total depth including the output layer; must be >= 1.
            residual_alpha: accepted for interface compatibility with
                sibling trainers; unused in this baseline.
            backbone: accepted for interface compatibility; only 'gcn'
                behavior is implemented here.
            **_kw: ignored extra config keys.

        Raises:
            ValueError: if num_layers < 1 (would leave the forward pass
                with no output layer).
        """
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        dev = data['X'].device
        self.data = data
        self.device = dev
        self.lr = lr
        self.wd = weight_decay
        self.num_layers = num_layers
        self.residual_alpha = residual_alpha
        self.backbone = backbone
        self._training = True

        d_in = data['num_features']
        d_out = data['num_classes']
        self.d_out = d_out
        dims = [d_in] + [hidden_dim] * (num_layers - 1) + [d_out]

        # Main layer weights — plain leaf tensors with requires_grad, used
        # directly instead of nn.Parameter/nn.Module.
        self.weights = []
        for i in range(num_layers):
            w = torch.empty(dims[i], dims[i + 1], device=dev)
            torch.nn.init.xavier_uniform_(w)
            w.requires_grad_(True)
            self.weights.append(w)

        # Auxiliary classifier per hidden layer: hidden_dim -> d_out
        self.W_aux = []
        for i in range(num_layers - 1):
            w_aux = torch.empty(hidden_dim, d_out, device=dev)
            torch.nn.init.xavier_uniform_(w_aux)
            w_aux.requires_grad_(True)
            self.W_aux.append(w_aux)

        # One optimizer over all layers; detach() between layers ensures
        # each parameter still only receives its local gradient.
        params = self.weights + self.W_aux
        self.optim = torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)

    def _gcn_conv(self, H, W):
        """GCN conv: A_hat @ (H @ W)."""
        return spmm(self.data['A_hat'], H @ W)

    def train_step(self):
        """Run one optimizer step of greedy layer-wise training.

        Each hidden layer's auxiliary CE loss and the output layer's CE loss
        are backpropagated separately; the detach() at the top of each
        iteration blocks gradients from flowing to earlier layers.

        Returns:
            (total_loss, train_acc, {}): sum of per-layer losses, training
            accuracy of the full detached forward pass, and an empty dict
            kept for interface compatibility.
        """
        X = self.data['X']
        y = self.data['y']
        mask = self.data['train_mask']
        self.optim.zero_grad()
        H = X
        total_loss = 0.0
        for l in range(self.num_layers):
            if l > 0:
                H = H.detach()  # block grad flow upstream
            Z = self._gcn_conv(H, self.weights[l])
            if l < self.num_layers - 1:
                H_new = F.relu(Z)
                # Auxiliary classifier (projects hidden to classes)
                Z_aux = H_new @ self.W_aux[l]
                loss_l = F.cross_entropy(Z_aux[mask], y[mask])
                loss_l.backward()
                total_loss += loss_l.item()
                H = H_new
            else:
                # Output layer: standard CE
                loss_final = F.cross_entropy(Z[mask], y[mask])
                loss_final.backward()
                total_loss += loss_final.item()
        self.optim.step()
        with torch.no_grad():
            Z_out = self._forward_full_detached()
            acc = (Z_out[mask].argmax(1) == y[mask]).float().mean().item()
        return total_loss, acc, {}

    def _forward_full_detached(self):
        """Full forward pass on detached weights, for evaluation.

        Uses .detach() on each weight (rather than no_grad), so it is also
        safe to call from inside a grad-enabled context such as train_step.
        """
        X = self.data['X']
        H = X
        Z = None  # num_layers >= 1 is enforced in __init__, so Z is always set
        for l in range(self.num_layers):
            Z = self._gcn_conv(H, self.weights[l].detach())
            if l < self.num_layers - 1:
                H = F.relu(Z)
        return Z

    @torch.no_grad()
    def evaluate(self, mask_name='test_mask'):
        """Return accuracy of the current model on data[mask_name]."""
        self._training = False
        Z = self._forward_full_detached()
        self._training = True
        mask = self.data[mask_name]
        return (Z[mask].argmax(1) == self.data['y'][mask]).float().mean().item()


def train_one(seed, data, num_layers=6):
    """Train one CaFo model for EPOCHS epochs; return test acc at best val acc.

    NOTE(review): validation is only checked every 5 epochs, so the final
    EPOCHS % 5 epochs are trained but never considered for model selection.
    Kept as-is to preserve the original experimental protocol.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op on CPU-only hosts
    t = CaFoTrainer(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4,
                    num_layers=num_layers, residual_alpha=0.0, backbone='gcn')
    bv, bt = 0, 0
    for ep in range(EPOCHS):
        t.train_step()
        if ep % 5 == 0:
            v = t.evaluate('val_mask')
            te = t.evaluate('test_mask')
            if v > bv:
                bv, bt = v, te
    del t
    torch.cuda.empty_cache()
    return bt


def main():
    """Run all datasets × seeds, checkpointing per-seed results to JSON."""
    os.makedirs(OUT_DIR, exist_ok=True)
    per_seed_file = os.path.join(OUT_DIR, 'per_seed_data.json')
    # Resume from a previous partial run if the checkpoint file exists.
    if os.path.exists(per_seed_file):
        with open(per_seed_file) as f:
            per_seed_data = json.load(f)
    else:
        per_seed_data = {}

    datasets_cfg = {
        'Cora': lambda: load_dataset('Cora', device=device),
        'CiteSeer': lambda: load_dataset('CiteSeer', device=device),
        'PubMed': lambda: load_dataset('PubMed', device=device),
        'DBLP': lambda: load_dblp(),
    }

    for ds_name, loader in datasets_cfg.items():
        data = loader()
        key = f"{ds_name}_CaFo+CE"
        if key not in per_seed_data:
            per_seed_data[key] = {}
        print(f"\n=== {key} (20 seeds, GCN L=6) ===", flush=True)
        for seed in SEEDS:
            sk = str(seed)
            if sk in per_seed_data[key]:
                print(f" seed {seed}: cached ({per_seed_data[key][sk]*100:.1f}%)", flush=True)
                continue
            try:
                acc = train_one(seed, data)
                per_seed_data[key][sk] = acc
                print(f" seed {seed}: {acc*100:.1f}%", flush=True)
            except Exception as e:
                # Best-effort sweep: record the failure as 0.0 and keep going
                # so one bad seed doesn't kill the whole run.
                print(f" seed {seed}: FAILED - {e}", flush=True)
                per_seed_data[key][sk] = 0.0
            # Checkpoint after every seed so progress survives interruption.
            with open(per_seed_file, 'w') as f:
                json.dump(per_seed_data, f, indent=2)
        del data
        torch.cuda.empty_cache()

    # Summary
    print(f"\n{'=' * 70}\nCaFo+CE summary (20 seeds, GCN L=6)\n{'=' * 70}")
    results = {}
    for ds in datasets_cfg:
        key = f"{ds}_CaFo+CE"
        vals = np.array([per_seed_data[key][str(s)] for s in SEEDS]) * 100
        results[key] = {'mean': float(vals.mean()), 'std': float(vals.std()),
                        'per_seed': vals.tolist()}
        print(f" {ds:<12} {vals.mean():5.1f} ± {vals.std():4.1f}")
    with open(os.path.join(OUT_DIR, 'results.json'), 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved to {OUT_DIR}/results.json")


if __name__ == '__main__':
    main()