summaryrefslogtreecommitdiff
path: root/experiments/ep_baseline.py
blob: 36f97f6d89f374aa132ec5166f9e9bf22931f8ed (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
"""
Equilibrium Propagation (Scellier & Bengio 2017) for ResidualMLP on CIFAR-10.
Feedforward EP with energy-based state optimization.

Usage: python ep_baseline.py --method ep --seed 42 --gpu 0
"""
import os, sys, json, argparse, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP
from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation
import torchvision, torchvision.transforms as transforms


def get_cifar10(bs=128):
    """Build the CIFAR-10 (train, test) DataLoaders.

    The training pipeline applies random 32x32 crops with 4-pixel padding and
    horizontal flips before normalisation; the test pipeline only normalises.
    Both datasets live under ./data and are downloaded on first use.

    Args:
        bs: batch size used for both loaders.

    Returns:
        (train_loader, test_loader); only the training loader shuffles.
    """
    # Per-channel mean / std used by both pipelines.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2470, 0.2435, 0.2616)

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, 4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    train_set = torchvision.datasets.CIFAR10('./data', True, download=True, transform=train_transform)
    train_loader = DataLoader(train_set, bs, True, num_workers=4, pin_memory=True)
    test_set = torchvision.datasets.CIFAR10('./data', False, download=True, transform=test_transform)
    test_loader = DataLoader(test_set, bs, False, num_workers=4, pin_memory=True)
    return train_loader, test_loader


def evaluate(m, tl, dev):
    """Return the top-1 accuracy of ``m`` over every batch yielded by ``tl``.

    Each input batch is flattened to (batch, -1) and moved to ``dev``. The
    model is switched to eval mode and is NOT switched back on return. An
    empty ``tl`` raises ZeroDivisionError.
    """
    m.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, labels in tl:
            batch = inputs.size(0)
            inputs = inputs.view(batch, -1).to(dev)
            labels = labels.to(dev)
            predictions = m(inputs).argmax(1)
            n_correct += (predictions == labels).sum().item()
            n_seen += inputs.size(0)
    return n_correct / n_seen


def ep_energy(model, hiddens, lam=1.0):
    """
    Compute the per-sample EP energy

        E = lam * 0.5 * sum_{l=0}^{L-1} ||h_{l+1} - h_l - F_l(h_l)||^2

    where F_l is ``model.blocks[l]`` (the residual branch, i.e. the block
    output without the skip connection) and L = ``model.num_blocks``.

    Args:
        model: object exposing ``num_blocks`` and an indexable ``blocks``.
        hiddens: list of at least L+1 state tensors [h_0, h_1, ..., h_L]; the
            squared norm is reduced over the last dimension.
        lam: weight multiplying the whole state-consistency energy. With the
            default 1.0 the result equals the unweighted energy.

    Returns:
        The energy reduced over the last dimension, i.e. shape (batch,) for
        (batch, dim) states; the Python float 0.0 when ``model.num_blocks == 0``.
    """
    E = 0.0
    for l in range(model.num_blocks):
        f_l = model.blocks[l](hiddens[l])  # residual branch F_l(h_l)
        residual = hiddens[l + 1] - hiddens[l] - f_l
        E = E + 0.5 * (residual ** 2).sum(-1)
    # lam was previously accepted but never applied; apply it as documented.
    return lam * E


def ep_free_phase(model, x):
    """
    Run the EP free phase: one ordinary forward pass with autograd disabled.

    Args:
        model: callable accepting ``return_hidden=True`` and returning a pair
            (output, hidden_states).
        x: model input.

    Returns:
        The hidden-state list [h_0, ..., h_L] from the model; the output is dropped.
    """
    with torch.no_grad():
        _output, hidden_states = model(x, return_hidden=True)
    return hidden_states


def ep_nudged_phase(model, x, y, h_free, beta, T_nudge, alpha_nudge):
    """
    Nudged phase: minimize E(h) + beta * C(h_L, y) w.r.t. hidden states h_1..h_L.

    Starting from the free-phase states, run ``T_nudge`` plain gradient-descent
    steps (the same update rule as torch.optim.SGD without momentum/weight decay)
    of size ``alpha_nudge`` on

        sum_b [ E_b(h) + beta * CE(out_head(out_ln(h_L)), y)_b ].

    h_0 (the embed output) is held fixed.

    The objective is SUMMED over the batch: every sample's states are independent
    variables, so averaging would shrink each sample's step to
    alpha_nudge / batch_size and make the nudge depend on the batch size.

    Gradients are taken only w.r.t. the states (torch.autograd.grad), so no
    parameter ``.grad`` buffers of ``model`` are written.

    Args:
        model: exposes ``num_blocks``, ``blocks``, ``out_ln`` and ``out_head``.
        x: unused here; kept for interface compatibility.
        y: integer class targets, one per sample.
        h_free: list of L+1 free-phase states [h_0, ..., h_L].
        beta: nudge strength on the output cost.
        T_nudge: number of inner descent steps (0 returns the free states).
        alpha_nudge: inner step size.

    Returns:
        List of detached nudged states [h_0, h_1^*, ..., h_L^*].
    """
    L = model.num_blocks
    # Fresh, detached copies of the free states; h_0 stays fixed (no grad).
    h_nudged = [h.detach().clone() for h in h_free]
    free_states = h_nudged[1:]
    for h in free_states:
        h.requires_grad_(True)

    for _ in range(T_nudge):
        E = ep_energy(model, h_nudged)  # (batch,)
        logits = model.out_head(model.out_ln(h_nudged[L]))
        C = F.cross_entropy(logits, y, reduction='none')  # (batch,)
        total = (E + beta * C).sum()
        grads = torch.autograd.grad(total, free_states)
        with torch.no_grad():
            for h, g in zip(free_states, grads):
                h.add_(g, alpha=-alpha_nudge)  # SGD step: h <- h - alpha * dTotal/dh

    return [h.detach() for h in h_nudged]



def train_ep(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01,
             beta=0.5, T_nudge=20, alpha_nudge=0.1, random_targets: bool = False):
    """
    Train ``model`` with an Equilibrium-Propagation-style update.

    Per minibatch:
      1. Free phase: forward pass under no_grad -> h_free = [h_0, ..., h_L].
      2. Nudged phase: ep_nudged_phase relaxes h_1..h_L toward lower E + beta * CE.
      3. Weight gradients from the contrast between the two phases:
           blocks  : (dE/dtheta at nudged - dE/dtheta at free) / beta
           embed   : same contrast, with dE/dh_0 approximated by -res_0 (see below)
           head/LN : gradient of the batch-mean CE at the nudged top state h_L
         then a global grad-norm clip at 1.0 and one AdamW step per optimizer.

    Args:
        model: exposes num_blocks, blocks, embed, out_ln, out_head and accepts
            ``return_hidden=True``.
        trl, tel: train / test loaders yielding (inputs, labels); inputs are
            flattened to (batch, -1).
        dev: device to move each batch to.
        epochs: number of epochs; also the cosine-annealing period (T_max).
        lr, wd: AdamW learning rate and weight decay shared by all optimizers.
        beta, T_nudge, alpha_nudge: forwarded to ep_nudged_phase; beta also
            divides the contrastive gradients.
        random_targets: if True, every minibatch's labels are replaced by
            uniform random integers in [0, 10).

    Returns:
        ``model`` itself, trained in place.
    """
    L = model.num_blocks

    # One AdamW per residual block, one for embed, one for the head (out_head + out_ln).
    block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
    embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
    head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()), lr=lr, weight_decay=wd)
    all_opts = block_opts + [embed_opt, head_opt]
    # Cosine LR decay over `epochs`; stepped once per epoch at the end of the epoch loop.
    schedulers = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in all_opts]

    for ep in range(1, epochs + 1):
        model.train()
        for x, y in trl:
            x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
            if random_targets:
                # Random-label control; the class count (10) is hard-coded.
                y = torch.randint(0, 10, y.shape, device=dev)

            # ---- FREE PHASE ----
            # Plain forward pass (no grad) gives the free states h_0..h_L.
            # NOTE(review): the model is in train mode here; if ResidualMLP uses
            # dropout/batch-norm the free states are stochastic -- TODO confirm.
            with torch.no_grad():
                _, h_free = model(x, return_hidden=True)

            # ---- NUDGED PHASE ----
            # Relax h_1..h_L toward lower E(h) + beta * C(h_L, y), with h_0 held fixed.
            h_nudged = ep_nudged_phase(model, x, y, h_free, beta, T_nudge, alpha_nudge)

            # ---- EP WEIGHT UPDATE ----
            # Estimator: grad_theta ~= (dE/dtheta|nudged - dE/dtheta|free) / beta.

            # Clear any gradients accumulated so far for every optimized parameter.
            for o in all_opts:
                o.zero_grad()

            # Residual blocks. With res_l = h_{l+1} - h_l - F_l(h_l) and
            # E_l = 0.5 * ||res_l||^2:
            #     dE_l/dtheta_l = -res_l^T dF_l/dtheta_l.
            # Autograd reproduces this from the surrogate -(res_l.detach() * F_l(h_l)).sum(),
            # whose only differentiable factor is F_l. The input states are detached, so each
            # backward reaches only block l's parameters.

            for l in range(L):
                h_l_free = h_free[l].detach()
                h_lp1_free = h_free[l + 1].detach()
                h_l_nudge = h_nudged[l].detach()
                h_lp1_nudge = h_nudged[l + 1].detach()

                # Free-phase surrogate. If the forward pass computes h_{l+1} = h_l + F_l(h_l)
                # (presumably what ResidualMLP does -- TODO confirm), res_free is ~0 and this
                # term contributes almost nothing.
                f_l_free = model.blocks[l](h_l_free)
                res_free = h_lp1_free - h_l_free - f_l_free.detach()
                loss_free_l = -(res_free.detach() * f_l_free).sum()

                # Nudged-phase surrogate: same construction at the nudged states.
                f_l_nudge = model.blocks[l](h_l_nudge)
                res_nudge = h_lp1_nudge - h_l_nudge - f_l_nudge.detach()
                loss_nudge_l = -(res_nudge.detach() * f_l_nudge).sum()

                # Backward of (nudged - free) / beta accumulates the EP gradient
                # (dE_nudge/dtheta_l - dE_free/dtheta_l) / beta into block l's .grad.
                ep_loss_l = (loss_nudge_l - loss_free_l) / beta
                ep_loss_l.backward()

            # Embed layer: h_0 = embed(x), so dE/dtheta_embed = (dE/dh_0)^T dh_0/dtheta_embed.
            # Exactly, dE/dh_0 = -(I + dF_0/dh_0)^T res_0. The code keeps only the -res_0 term
            # and drops the Jacobian of F_0, so this is an approximation.
            # The surrogate -(res_0.detach() * embed(x)).sum() yields -res_0^T d embed(x)/dtheta.
            h0_free = model.embed(x)  # differentiable h_0 for the free-phase surrogate
            with torch.no_grad():
                f0_free = model.blocks[0](h0_free.detach())
                res0_free = h_free[1].detach() - h0_free.detach() - f0_free
            embed_loss_free = -(res0_free.detach() * h0_free).sum() / beta  # approximation: -res_0 * dh_0/dtheta, scaled by 1/beta

            # Second embed forward on the same x. It matches h0_free unless embed is stochastic.
            # The nudged residual uses h_nudged[0], the fixed h_0 returned by the nudged phase.
            h0_nudge_rg = model.embed(x)
            with torch.no_grad():
                f0_nudge = model.blocks[0](h_nudged[0].detach())
                res0_nudge = h_nudged[1].detach() - h_nudged[0].detach() - f0_nudge
            embed_loss_nudge = -(res0_nudge.detach() * h0_nudge_rg).sum() / beta

            # Both terms already carry the 1/beta factor.
            embed_ep = (embed_loss_nudge - embed_loss_free)
            embed_ep.backward()

            # Output head (out_ln + out_head). E does not depend on these parameters, so the
            # EP gradient reduces to dC/dtheta_out at the nudged top state. h_nudged[L] is
            # detached, so only the head receives gradient.
            # NOTE(review): the block/embed surrogates above are summed over the batch, but
            # head_loss is a batch mean. The two gradient groups are therefore on different
            # scales when they share the global clip below.
            logits_nudged = model.out_head(model.out_ln(h_nudged[L].detach()))
            head_loss = F.cross_entropy(logits_nudged, y)
            head_loss.backward()

            # Clip the global L2 norm over all model parameters to 1.0, then step every optimizer.
            all_params = list(model.parameters())
            torch.nn.utils.clip_grad_norm_(all_params, 1.0)
            for o in all_opts:
                o.step()

        for s in schedulers:
            s.step()
        # evaluate() switches the model to eval mode; model.train() at the top of
        # the next epoch switches it back.
        if ep % 20 == 0:
            print(f"  Ep {ep}: acc={evaluate(model, tel, dev):.4f}", flush=True)

    return model


def ep_credit_signals(model, x, y, beta, T_nudge, alpha_nudge):
    """
    Estimate per-layer EP credit signals for diagnostics.

    For each layer l in 0..L-1 (L = model.num_blocks) the credit is

        a_l = -(h_l^nudged - h_l^free) / beta.

    The nudge moves states toward lower loss, so the minus sign makes a_l point
    the same way as a BP gradient dC/dh_l.

    NOTE(review): the nudged phase holds h_0 fixed, so the layer-0 credit is
    expected to be identically zero. The top state h_L is not included.

    Returns:
        (credits, h_free, h_nudged). ``credits`` is a list of L tensors;
        h_free and h_nudged are the full state lists from both phases.
    """
    with torch.no_grad():
        _, free_states = model(x, return_hidden=True)
    nudged_states = ep_nudged_phase(model, x, y, free_states, beta, T_nudge, alpha_nudge)

    credits = []
    for layer in range(model.num_blocks):
        displacement = nudged_states[layer] - free_states[layer]
        credits.append(-displacement / beta)
    return credits, free_states, nudged_states


def _tail_loss_fn(model, y, start):
    """Return f(h): per-sample CE after running blocks start..L-1 (with skip connections) on h.

    Evaluated under no_grad; used as the loss oracle for perturbation_correlation.
    """
    def f(h):
        with torch.no_grad():
            c = h
            for i in range(start, model.num_blocks):
                c = c + model.blocks[i](c)
            return F.cross_entropy(model.out_head(model.out_ln(c)), y, reduction='none')
    return f


def compute_diagnostics(model, tel, dev, beta, T_nudge, alpha_nudge):
    """
    Compare EP credit signals against true BP state gradients on the first test batch.

    Metrics (layers l = 0..L-1, L = model.num_blocks):
      Gamma: mean of cosine_similarity_batch(a_l^EP, dC/dh_l).
      rho:   mean of perturbation_correlation between a_l^EP and the loss obtained by
             re-running blocks l..L-1 from a perturbed free state h_l.
      naive_StateErr: mean relative distance ||h_{L//2} - h_L|| / ||h_L|| between free states.

    NOTE(review): EP holds h_0 fixed, so the layer-0 EP credit is expected to be zero.
    That layer still enters the Gamma / rho averages.

    The model is switched to eval mode and left there.

    Raises:
        ValueError: if ``tel`` yields no batches (previously an UnboundLocalError).

    Returns:
        dict with keys 'Gamma', 'rho', 'naive_StateErr', 'gammas_per_layer', 'rhos_per_layer'.
    """
    model.eval()
    L = model.num_blocks

    # Diagnostics use only the first test batch.
    try:
        x, y = next(iter(tel))
    except StopIteration:
        raise ValueError("compute_diagnostics: test loader yielded no batches") from None
    x = x.view(x.size(0), -1).to(dev)
    y = y.to(dev)

    # EP credit signals a_l^EP for l = 0..L-1.
    ep_credits, _, _ = ep_credit_signals(model, x, y, beta, T_nudge, alpha_nudge)

    # Reference BP gradients dC/dh_l. Re-run the residual stack by hand from a detached h_0,
    # so autograd only has to differentiate w.r.t. the states (same values as before).
    h0 = model.embed(x).detach().requires_grad_(True)
    hs = [h0]
    for block in model.blocks:
        hs.append(hs[-1] + block(hs[-1]))
    loss = F.cross_entropy(model.out_head(model.out_ln(hs[-1])), y)
    gs = torch.autograd.grad(loss, hs)
    bp_grads = [gs[l].detach() for l in range(L)]

    # Gamma: per-layer cosine similarity between the EP credit and the BP gradient.
    gammas = [cosine_similarity_batch(ep_credits[l], bp_grads[l]) for l in range(L)]

    # Free states (eval mode, no grad), computed once and shared by rho and the state error.
    with torch.no_grad():
        _, hi = model(x, return_hidden=True)

    # rho: perturbation correlation of the EP credit against the true downstream loss.
    rhos = [
        perturbation_correlation(hi[l].detach(), ep_credits[l].detach(),
                                 _tail_loss_fn(model, y, l), epsilon=1e-3, M=16)
        for l in range(L)
    ]

    # Naive state error: relative distance between the middle and the top free state.
    with torch.no_grad():
        nse = ((hi[L // 2] - hi[-1]).norm(-1) / hi[-1].norm(-1).clamp(min=1e-8)).mean().item()

    return {'Gamma': float(np.mean(gammas)), 'rho': float(np.mean(rhos)),
            'naive_StateErr': nse, 'gammas_per_layer': [float(g) for g in gammas],
            'rhos_per_layer': [float(r) for r in rhos]}


def main():
    """
    Command-line entry point.

    Trains a ResidualMLP on CIFAR-10 with EP and runs the EP-vs-BP diagnostics.
    It then writes ``{method}_s{seed}.pt`` (state dict) and ``{method}_s{seed}.json``
    (metrics) into --output_dir.

    The JSON also records the training configuration, including --random_targets.
    Runs with and without random targets otherwise write identically named files,
    and could not be told apart afterwards.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--method', type=str, default='ep')
    p.add_argument('--seed', type=int, required=True)
    p.add_argument('--gpu', type=int, default=0)
    p.add_argument('--output_dir', type=str, default='results/ep_baseline')
    p.add_argument('--epochs', type=int, default=100)
    p.add_argument('--beta', type=float, default=0.5, help='EP nudge strength')
    p.add_argument('--T_nudge', type=int, default=20, help='Inner optimization steps for nudged phase')
    p.add_argument('--alpha_nudge', type=float, default=0.1, help='Inner step size for nudged phase')
    p.add_argument('--lr', type=float, default=1e-3)
    p.add_argument('--wd', type=float, default=0.01)
    p.add_argument('--d_hidden', type=int, default=256)
    p.add_argument('--random_targets', action='store_true',
                   help='Replace each minibatch label with i.i.d. random class targets (codex round 36 OPTION EP).')
    args = p.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    # Use the requested GPU when CUDA is present; otherwise fall back to CPU instead of crashing.
    dev = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    # torch.cuda.manual_seed_all is silently ignored when CUDA is unavailable.
    torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)

    trl, tel = get_cifar10()
    L, d = 4, args.d_hidden
    model = ResidualMLP(3072, d, 10, L).to(dev)

    print(f"[{args.method} s={args.seed}] Training EP  beta={args.beta} T={args.T_nudge} alpha={args.alpha_nudge}", flush=True)

    model = train_ep(model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd,
                     beta=args.beta, T_nudge=args.T_nudge, alpha_nudge=args.alpha_nudge,
                     random_targets=args.random_targets)

    acc = evaluate(model, tel, dev)
    diag = compute_diagnostics(model, tel, dev, args.beta, args.T_nudge, args.alpha_nudge)

    torch.save(model.state_dict(), os.path.join(args.output_dir, f'{args.method}_s{args.seed}.pt'))

    result = {'method': args.method, 'seed': args.seed, 'acc': acc,
              'Gamma': diag['Gamma'], 'rho': diag['rho'],
              'naive_StateErr': diag['naive_StateErr'],
              'gammas_per_layer': diag['gammas_per_layer'],
              'rhos_per_layer': diag['rhos_per_layer'],
              'beta': args.beta, 'T_nudge': args.T_nudge, 'alpha_nudge': args.alpha_nudge,
              # Full training configuration, so results remain attributable to their run.
              'random_targets': args.random_targets, 'epochs': args.epochs,
              'lr': args.lr, 'wd': args.wd, 'd_hidden': d, 'num_blocks': L}

    with open(os.path.join(args.output_dir, f'{args.method}_s{args.seed}.json'), 'w') as f:
        json.dump(result, f, indent=2, default=float)

    print(f"[{args.method} s={args.seed}] acc={acc:.4f} Γ={diag['Gamma']:.4f} ρ={diag['rho']:.4f} nse={diag['naive_StateErr']:.4f}", flush=True)


if __name__ == '__main__':
    main()