diff options
Diffstat (limited to 'experiments/ep_synthetic.py')
| -rw-r--r-- | experiments/ep_synthetic.py | 1 |
1 files changed, 1 insertions, 0 deletions
diff --git a/experiments/ep_synthetic.py b/experiments/ep_synthetic.py index a2f24df..3685f40 100644 --- a/experiments/ep_synthetic.py +++ b/experiments/ep_synthetic.py @@ -149,6 +149,7 @@ def main(): diag=compute_diagnostics(model,teacher,dev,d,C,L) result={'method':'ep','alpha':args.alpha,'depth':L,'seed':args.seed, 'acc':diag['acc'],'Gamma':diag['Gamma'],'rho':diag['rho']} + torch.save(model.state_dict(),os.path.join(args.output_dir,f'ep_a{args.alpha}_L{L}_s{args.seed}.pt')) out=os.path.join(args.output_dir,f'ep_a{args.alpha}_L{L}_s{args.seed}.json') with open(out,'w') as f:json.dump(result,f,indent=2,default=float) print(f" acc={diag['acc']:.4f} Gamma={diag['Gamma']:.4f} rho={diag['rho']:.4f}",flush=True) |
