summaryrefslogtreecommitdiff
path: root/experiments/ep_synthetic.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/ep_synthetic.py')
-rw-r--r--experiments/ep_synthetic.py1
1 files changed, 1 insertions, 0 deletions
diff --git a/experiments/ep_synthetic.py b/experiments/ep_synthetic.py
index a2f24df..3685f40 100644
--- a/experiments/ep_synthetic.py
+++ b/experiments/ep_synthetic.py
@@ -149,6 +149,7 @@ def main():
diag=compute_diagnostics(model,teacher,dev,d,C,L)
result={'method':'ep','alpha':args.alpha,'depth':L,'seed':args.seed,
'acc':diag['acc'],'Gamma':diag['Gamma'],'rho':diag['rho']}
+ torch.save(model.state_dict(),os.path.join(args.output_dir,f'ep_a{args.alpha}_L{L}_s{args.seed}.pt'))
out=os.path.join(args.output_dir,f'ep_a{args.alpha}_L{L}_s{args.seed}.json')
with open(out,'w') as f:json.dump(result,f,indent=2,default=float)
print(f" acc={diag['acc']:.4f} Gamma={diag['Gamma']:.4f} rho={diag['rho']:.4f}",flush=True)