summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-02 21:15:35 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-02 21:15:35 -0500
commit6a47ae80aeed672f24b8ceb12f79f92d86a6bb2f (patch)
treef3ce1049b31506d319be0b85f071cf1e976a2189 /experiments
parent045385c975e7f6b64678edf8835e771614352f9f (diff)
Add --d_hidden arg to ep_baseline.py for d=512 support
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/ep_baseline.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/experiments/ep_baseline.py b/experiments/ep_baseline.py
index e2e9074..de7d853 100644
--- a/experiments/ep_baseline.py
+++ b/experiments/ep_baseline.py
@@ -279,6 +279,7 @@ def main():
p.add_argument('--alpha_nudge', type=float, default=0.1, help='Inner step size for nudged phase')
p.add_argument('--lr', type=float, default=1e-3)
p.add_argument('--wd', type=float, default=0.01)
+ p.add_argument('--d_hidden', type=int, default=256)
args = p.parse_args()
os.makedirs(args.output_dir, exist_ok=True)
@@ -286,7 +287,7 @@ def main():
torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
trl, tel = get_cifar10()
- L, d = 4, 256
+ L, d = 4, args.d_hidden
model = ResidualMLP(3072, d, 10, L).to(dev)
print(f"[{args.method} s={args.seed}] Training EP beta={args.beta} T={args.T_nudge} alpha={args.alpha_nudge}", flush=True)