summary | refs | log | tree | commit | diff
path: root/experiments
diff options
context:
space:
mode:
Diffstat (limited to 'experiments')
-rw-r--r-- experiments/ep_baseline.py 9
1 file changed, 7 insertions, 2 deletions
diff --git a/experiments/ep_baseline.py b/experiments/ep_baseline.py
index 7f3d004..36f97f6 100644
--- a/experiments/ep_baseline.py
+++ b/experiments/ep_baseline.py
@@ -90,7 +90,7 @@ def ep_nudged_phase(model, x, y, h_free, beta, T_nudge, alpha_nudge):
def train_ep(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01,
- beta=0.5, T_nudge=20, alpha_nudge=0.1):
+ beta=0.5, T_nudge=20, alpha_nudge=0.1, random_targets: bool = False):
L = model.num_blocks
# Separate optimizers for different parts
@@ -104,6 +104,8 @@ def train_ep(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01,
model.train()
for x, y in trl:
x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
+ if random_targets:
+ y = torch.randint(0, 10, y.shape, device=dev)
# ---- FREE PHASE ----
# Standard forward pass to get free fixed point
@@ -281,6 +283,8 @@ def main():
p.add_argument('--lr', type=float, default=1e-3)
p.add_argument('--wd', type=float, default=0.01)
p.add_argument('--d_hidden', type=int, default=256)
+ p.add_argument('--random_targets', action='store_true',
+ help='Replace each minibatch label with i.i.d. random class targets (codex round 36 OPTION EP).')
args = p.parse_args()
os.makedirs(args.output_dir, exist_ok=True)
@@ -294,7 +298,8 @@ def main():
print(f"[{args.method} s={args.seed}] Training EP beta={args.beta} T={args.T_nudge} alpha={args.alpha_nudge}", flush=True)
model = train_ep(model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd,
- beta=args.beta, T_nudge=args.T_nudge, alpha_nudge=args.alpha_nudge)
+ beta=args.beta, T_nudge=args.T_nudge, alpha_nudge=args.alpha_nudge,
+ random_targets=args.random_targets)
acc = evaluate(model, tel, dev)
diag = compute_diagnostics(model, tel, dev, args.beta, args.T_nudge, args.alpha_nudge)