diff options
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/ep_baseline.py | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/experiments/ep_baseline.py b/experiments/ep_baseline.py index 7f3d004..36f97f6 100644 --- a/experiments/ep_baseline.py +++ b/experiments/ep_baseline.py @@ -90,7 +90,7 @@ def ep_nudged_phase(model, x, y, h_free, beta, T_nudge, alpha_nudge): def train_ep(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01, - beta=0.5, T_nudge=20, alpha_nudge=0.1): + beta=0.5, T_nudge=20, alpha_nudge=0.1, random_targets: bool = False): L = model.num_blocks # Separate optimizers for different parts @@ -104,6 +104,8 @@ def train_ep(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01, model.train() for x, y in trl: x = x.view(x.size(0), -1).to(dev); y = y.to(dev) + if random_targets: + y = torch.randint(0, 10, y.shape, device=dev) # ---- FREE PHASE ---- # Standard forward pass to get free fixed point @@ -281,6 +283,8 @@ def main(): p.add_argument('--lr', type=float, default=1e-3) p.add_argument('--wd', type=float, default=0.01) p.add_argument('--d_hidden', type=int, default=256) + p.add_argument('--random_targets', action='store_true', + help='Replace each minibatch label with i.i.d. random class targets (codex round 36 OPTION EP).') args = p.parse_args() os.makedirs(args.output_dir, exist_ok=True) @@ -294,7 +298,8 @@ def main(): print(f"[{args.method} s={args.seed}] Training EP beta={args.beta} T={args.T_nudge} alpha={args.alpha_nudge}", flush=True) model = train_ep(model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd, - beta=args.beta, T_nudge=args.T_nudge, alpha_nudge=args.alpha_nudge) + beta=args.beta, T_nudge=args.T_nudge, alpha_nudge=args.alpha_nudge, + random_targets=args.random_targets) acc = evaluate(model, tel, dev) diag = compute_diagnostics(model, tel, dev, args.beta, args.T_nudge, args.alpha_nudge) |
