#!/usr/bin/env python3
"""H3: CaFo+CE (Cascaded Forward Learning with Top-Down Feedback, Park et al. 2023).

Greedy layer-wise training for GCN L=6:
  - Each hidden layer l has an auxiliary classifier W_aux_l: hidden → num_classes
  - Forward through all layers with .detach() between layers (blocks upstream gradient)
  - Per-layer CE loss on labeled nodes via auxiliary classifier
  - Output layer uses standard cross-entropy
  - No global backprop — each W_l only sees its local loss

Tests CaFo on Cora/CiteSeer/PubMed/DBLP × 20 seeds, GCN L=6.
"""

import torch
import torch.nn.functional as F
import numpy as np
import json
import os
from src.data import load_dataset, spmm
from run_dblp_depth import load_dblp
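# Expected data contract (assumed from usage of src.data.load_dataset / load_dblp):
# a dict with 'X', 'y', 'A_hat' (sparse propagation matrix, conventionally the
# normalized adjacency), 'train_mask'/'val_mask'/'test_mask', 'num_features',
# and 'num_classes'; spmm(A, H) is a sparse-dense matmul.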

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
SEEDS = list(range(20))
EPOCHS = 200
OUT_DIR = 'results/cafo_baseline_20seeds'


class CaFoTrainer:
    """CaFo+CE: greedy layer-wise training with per-layer CE loss."""

    def __init__(self, data, hidden_dim, lr, weight_decay,
                 num_layers=2, residual_alpha=0.0, backbone='gcn', **_kw):
        dev = data['X'].device
        self.data = data
        self.device = dev
        self.lr = lr
        self.wd = weight_decay
        self.num_layers = num_layers
        self.residual_alpha = residual_alpha
        self.backbone = backbone

        d_in = data['num_features']
        d_out = data['num_classes']
        self.d_out = d_out

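        # Layer widths, e.g. L=6 with hidden 64: [d_in, 64, 64, 64, 64, 64, d_out].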
        dims = [d_in] + [hidden_dim] * (num_layers - 1) + [d_out]
        # Main layer weights — autograd Parameters
        self.weights = []
        for i in range(num_layers):
            w = torch.empty(dims[i], dims[i + 1], device=dev)
            torch.nn.init.xavier_uniform_(w)
            w.requires_grad_(True)
            self.weights.append(w)

        # Auxiliary classifier per hidden layer: hidden_dim -> d_out
        self.W_aux = []
        for i in range(num_layers - 1):
            w_aux = torch.empty(hidden_dim, d_out, device=dev)
            torch.nn.init.xavier_uniform_(w_aux)
            w_aux.requires_grad_(True)
            self.W_aux.append(w_aux)

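        # One Adam over all parameters; gradient isolation between layers comes
        # from the detach() in train_step, not from separate optimizers.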
        params = self.weights + self.W_aux
        self.optim = torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)

    def _gcn_conv(self, H, W):
        """GCN conv: A_hat @ (H @ W)."""
        return spmm(self.data['A_hat'], H @ W)

    def train_step(self):
        X = self.data['X']
        y = self.data['y']
        mask = self.data['train_mask']

        self.optim.zero_grad()

        H = X
        total_loss = 0.0
        for l in range(self.num_layers):
            if l > 0:
                H = H.detach()  # block grad flow upstream

            Z = self._gcn_conv(H, self.weights[l])

            if l < self.num_layers - 1:
                H_new = F.relu(Z)
                # Auxiliary classifier (projects hidden to classes)
                Z_aux = H_new @ self.W_aux[l]
                loss_l = F.cross_entropy(Z_aux[mask], y[mask])
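                # Per-layer backward is safe here: the detach() above keeps each
                # layer's graph disjoint, so gradients land only in weights[l]
                # and W_aux[l].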
                loss_l.backward()
                total_loss += loss_l.item()
                H = H_new
            else:
                # Output layer: standard CE
                loss_final = F.cross_entropy(Z[mask], y[mask])
                loss_final.backward()
                total_loss += loss_final.item()

        self.optim.step()

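        # For logging, report training accuracy of the composed network
        # (full forward pass), not of the per-layer auxiliary heads.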
        with torch.no_grad():
            Z_out = self._forward_full_detached()
            acc = (Z_out[mask].argmax(1) == y[mask]).float().mean().item()
        return total_loss, acc, {}

    @torch.no_grad()
    def _forward_full_detached(self):
        """Full forward pass with gradients disabled, for evaluation."""
        H = self.data['X']
        Z = None
        for l in range(self.num_layers):
            Z = self._gcn_conv(H, self.weights[l])
            if l < self.num_layers - 1:
                H = F.relu(Z)
        return Z

    @torch.no_grad()
    def evaluate(self, mask_name='test_mask'):
        Z = self._forward_full_detached()
        mask = self.data[mask_name]
        return (Z[mask].argmax(1) == self.data['y'][mask]).float().mean().item()


def train_one(seed, data, num_layers=6):
    torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
    t = CaFoTrainer(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4,
                    num_layers=num_layers, residual_alpha=0.0, backbone='gcn')
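    # Model selection: track the test accuracy at the best-validation epoch.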
    bv, bt = 0, 0
    for ep in range(EPOCHS):
        t.train_step()
        if ep % 5 == 0:
            v = t.evaluate('val_mask')
            te = t.evaluate('test_mask')
            if v > bv: bv, bt = v, te
    del t; torch.cuda.empty_cache()
    return bt


def main():
    os.makedirs(OUT_DIR, exist_ok=True)
    per_seed_file = os.path.join(OUT_DIR, 'per_seed_data.json')
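    # Resume support: seeds already present in per_seed_data.json are skipped.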
    if os.path.exists(per_seed_file):
        with open(per_seed_file) as f:
            per_seed_data = json.load(f)
    else:
        per_seed_data = {}

    datasets_cfg = {
        'Cora': lambda: load_dataset('Cora', device=device),
        'CiteSeer': lambda: load_dataset('CiteSeer', device=device),
        'PubMed': lambda: load_dataset('PubMed', device=device),
        'DBLP': lambda: load_dblp(),
    }
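    # Loaders are lazy so only one dataset is resident at a time; each is freed
    # with del + empty_cache() after its seeds finish.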

    for ds_name, loader in datasets_cfg.items():
        data = loader()
        key = f"{ds_name}_CaFo+CE"
        if key not in per_seed_data:
            per_seed_data[key] = {}

        print(f"\n=== {key} (20 seeds, GCN L=6) ===", flush=True)
        for seed in SEEDS:
            sk = str(seed)
            if sk in per_seed_data[key]:
                print(f"  seed {seed}: cached ({per_seed_data[key][sk]*100:.1f}%)", flush=True)
                continue
            try:
                acc = train_one(seed, data)
                per_seed_data[key][sk] = acc
                print(f"  seed {seed}: {acc*100:.1f}%", flush=True)
            except Exception as e:
                print(f"  seed {seed}: FAILED - {e}", flush=True)
                per_seed_data[key][sk] = 0.0

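            # Checkpoint after every seed so a crash loses at most one run.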
            with open(per_seed_file, 'w') as f:
                json.dump(per_seed_data, f, indent=2)

        del data; torch.cuda.empty_cache()

    # Summary
    print(f"\n{'=' * 70}\nCaFo+CE summary (20 seeds, GCN L=6)\n{'=' * 70}")
    results = {}
    for ds in datasets_cfg:
        key = f"{ds}_CaFo+CE"
        vals = np.array([per_seed_data[key][str(s)] for s in SEEDS]) * 100
        results[key] = {'mean': float(vals.mean()), 'std': float(vals.std()),
                         'per_seed': vals.tolist()}
        print(f"  {ds:<12} {vals.mean():5.1f} ± {vals.std():4.1f}")

    with open(os.path.join(OUT_DIR, 'results.json'), 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved to {OUT_DIR}/results.json")


if __name__ == '__main__':
    main()