From 4715dd0ead4884baf1c0209bb19ec004fc7bef1f Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Fri, 3 Apr 2026 11:18:21 -0500 Subject: Add checkpoint saving to ep_synthetic.py Co-Authored-By: Claude Opus 4.6 (1M context) --- experiments/ep_synthetic.py | 1 + 1 file changed, 1 insertion(+) (limited to 'experiments') diff --git a/experiments/ep_synthetic.py b/experiments/ep_synthetic.py index a2f24df..3685f40 100644 --- a/experiments/ep_synthetic.py +++ b/experiments/ep_synthetic.py @@ -149,6 +149,7 @@ def main(): diag=compute_diagnostics(model,teacher,dev,d,C,L) result={'method':'ep','alpha':args.alpha,'depth':L,'seed':args.seed, 'acc':diag['acc'],'Gamma':diag['Gamma'],'rho':diag['rho']} + torch.save(model.state_dict(),os.path.join(args.output_dir,f'ep_a{args.alpha}_L{L}_s{args.seed}.pt')) out=os.path.join(args.output_dir,f'ep_a{args.alpha}_L{L}_s{args.seed}.json') with open(out,'w') as f:json.dump(result,f,indent=2,default=float) print(f" acc={diag['acc']:.4f} Gamma={diag['Gamma']:.4f} rho={diag['rho']:.4f}",flush=True) -- cgit v1.2.3