1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
|
"""
Codex round 19's #4 control experiment: train BP with the same
λ ‖f_l(h_l)‖² penalty used in the DFA penalty rescue.
If BP + penalty still clears the frozen baseline (DFA-shallow, 0.349 test
acc in the code below) by a wide margin (e.g., ~25 pp, like normal BP):
→ the penalty itself is not the reason penalized DFA's depth
utilization is capped at +1.4 pp; the cap is intrinsic to DFA's
random-feedback credit signal quality
→ mode 2 (intrinsic credit quality) is real
If BP + penalty drops to ~+1.4 pp margin too:
→ the penalty is the reason for the cap, not credit quality
→ mode 2 might be a regularization artifact, not a real failure mode
→ walk-back #7 would itself need to be retracted (reverting to "one unified mode")
Run:
CUDA_VISIBLE_DEVICES=2 python experiments/bp_with_penalty_control.py --seed 42 --epochs 30 --lam 1e-2
"""
import os
import sys
import argparse
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP
def get_loaders(batch_size=128, data_root="./data", num_workers=2):
    """Build the CIFAR-10 train/test DataLoaders.

    The training split gets random-crop (32, padding 4) and random
    horizontal-flip augmentation; the test split is only converted to a
    tensor. Both splits are normalized with the same per-channel mean/std.
    Datasets are downloaded into ``data_root`` if not already present.

    Args:
        batch_size: Mini-batch size used by both loaders.
        data_root: Directory holding (or receiving) the CIFAR-10 files.
        num_workers: Number of worker processes per DataLoader.

    Returns:
        ``(train_loader, test_loader)``. The train loader shuffles every
        epoch; the test loader iterates in a fixed order.
    """
    # A single Normalize instance is shared: it is stateless, and sharing it
    # keeps the train/test statistics from drifting apart.
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
    )
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    train_set = torchvision.datasets.CIFAR10(
        data_root, train=True, download=True, transform=train_tf
    )
    test_set = torchvision.datasets.CIFAR10(
        data_root, train=False, download=True, transform=test_tf
    )
    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers),
        DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers),
    )
def evaluate(model, loader, dev):
    """Return the top-1 accuracy of ``model`` over all batches of ``loader``.

    Each input batch is flattened to ``(batch, -1)`` and moved to ``dev``
    before the forward pass. The model is put into eval mode (and left
    there), and gradients are disabled for the whole sweep.

    Args:
        model: Module mapping flattened inputs to per-class scores.
        loader: Iterable of ``(inputs, labels)`` batches.
        dev: Device that inputs and labels are moved to.

    Returns:
        Correct predictions divided by the number of examples seen, as a float.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            flat = inputs.view(inputs.size(0), -1).to(dev)
            targets = labels.to(dev)
            predicted = model(flat).argmax(dim=-1)
            correct += predicted.eq(targets).sum().item()
            total += flat.size(0)
    return correct / total
def train_bp_with_penalty(model, train_loader, test_loader, dev, epochs, lr, wd, lam):
"""End-to-end BP training with `lam * sum_l ||f_l(h_l)||^2` added to the
cross-entropy loss. The penalty is applied to the residual branch outputs
of every block."""
opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
log = []
for ep in range(1, epochs + 1):
model.train()
for x, y in train_loader:
x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
# Forward, capturing per-block residual outputs
h = model.embed(x)
penalty = torch.zeros((), device=dev)
for block in model.blocks:
f = block(h)
penalty = penalty + (f ** 2).sum(-1).mean()
h = h + f
logits = model.out_head(model.out_ln(h))
ce = F.cross_entropy(logits, y)
loss = ce + lam * penalty
opt.zero_grad()
loss.backward()
opt.step()
sch.step()
if ep % 5 == 0 or ep == 1 or ep == epochs:
acc = evaluate(model, test_loader, dev)
log.append({"epoch": ep, "test_acc": acc})
print(f" ep {ep}: test_acc={acc:.4f}", flush=True)
return log
def main():
    """CLI entry point: train BP + residual-norm penalty and report the margin.

    Parses CLI args and seeds torch/numpy. It then trains
    ``ResidualMLP(3072, 256, 10, 4)`` with ``train_bp_with_penalty`` and
    prints the final test accuracy next to the hard-coded reference
    accuracies. It then interprets the margin over the DFA-shallow baseline
    and writes config, log and results to
    ``<output_dir>/bp_pen_lam{lam}_s{seed}.json``.
    """
    # Reference test accuracies from earlier runs (hard-coded, not recomputed).
    bp_trainable_acc = 0.609   # BP-trainable, 3-seed mean
    penalized_dfa_acc = 0.363  # penalized DFA, lam=1e-2
    dfa_shallow_acc = 0.349    # DFA-shallow: baseline the margin is measured against
    # Verdict thresholds, in percentage points above DFA-shallow.
    clear_margin_pp = 25
    tiny_margin_pp = 5

    p = argparse.ArgumentParser()
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--wd", type=float, default=0.01)
    p.add_argument("--lam", type=float, default=1e-2)
    p.add_argument("--batch_size", type=int, default=128)
    p.add_argument("--output_dir", type=str, default="results/bp_with_penalty")
    args = p.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Fall back to CPU so the script still runs without CUDA. Under
    # CUDA_VISIBLE_DEVICES=N, "cuda:0" is the selected GPU.
    dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"BP + ‖f‖² penalty: seed={args.seed}, lam={args.lam}, epochs={args.epochs}", flush=True)

    # Seed before building the loaders and the model, so every RNG consumer
    # starts from the same state.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    train_loader, test_loader = get_loaders(batch_size=args.batch_size)
    m = ResidualMLP(3072, 256, 10, 4).to(dev)
    log = train_bp_with_penalty(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, args.lam)

    final_acc = evaluate(m, test_loader, dev)
    print(f"\nFINAL test acc: {final_acc:.4f}", flush=True)
    print("Compare to:")
    print(f" BP-trainable (3-seed mean): {bp_trainable_acc:.3f}")
    print(f" Penalized DFA lam=1e-2: {penalized_dfa_acc:.3f}")
    print(f" DFA-shallow: {dfa_shallow_acc:.3f}")

    margin = (final_acc - dfa_shallow_acc) * 100
    print(f"\nMargin vs DFA-shallow baseline: {margin:+.2f} pp")
    if margin > clear_margin_pp:
        print(f" → BP+penalty still clears shallow by >{clear_margin_pp} pp")
        print(" → mode 2 (intrinsic random-feedback alignment) is REAL")
        print(" → walk-back #7 confirmed: two distinct failure modes")
    elif margin < tiny_margin_pp:
        print(" → BP+penalty drops to a tiny margin like penalized DFA")
        print(" → the penalty itself capped depth utilization")
        print(" → mode 2 might be a regularization artifact")
        print(" → consider walking back walk-back #7")
    else:
        print(" → BP+penalty intermediate; partial capacity loss + residual mode 2")

    out = {"config": vars(args), "final_acc": final_acc, "log": log, "margin_pp": margin}
    out_path = os.path.join(args.output_dir, f"bp_pen_lam{args.lam}_s{args.seed}.json")
    with open(out_path, "w") as f:
        json.dump(out, f, indent=2)
    print(f"Saved {out_path}")


if __name__ == "__main__":
    main()
|