summaryrefslogtreecommitdiff
path: root/experiments/ep_baseline.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/ep_baseline.py')
-rw-r--r--experiments/ep_baseline.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/experiments/ep_baseline.py b/experiments/ep_baseline.py
index e2e9074..de7d853 100644
--- a/experiments/ep_baseline.py
+++ b/experiments/ep_baseline.py
@@ -279,6 +279,7 @@ def main():
p.add_argument('--alpha_nudge', type=float, default=0.1, help='Inner step size for nudged phase')
p.add_argument('--lr', type=float, default=1e-3)
p.add_argument('--wd', type=float, default=0.01)
+ p.add_argument('--d_hidden', type=int, default=256)
args = p.parse_args()
os.makedirs(args.output_dir, exist_ok=True)
@@ -286,7 +287,7 @@ def main():
torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
trl, tel = get_cifar10()
- L, d = 4, 256
+ L, d = 4, args.d_hidden
model = ResidualMLP(3072, d, 10, L).to(dev)
print(f"[{args.method} s={args.seed}] Training EP beta={args.beta} T={args.T_nudge} alpha={args.alpha_nudge}", flush=True)