diff options
Diffstat (limited to 'protocol/examples')
| -rw-r--r-- | protocol/examples/temporal_diagnostic_evolution.py | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/protocol/examples/temporal_diagnostic_evolution.py b/protocol/examples/temporal_diagnostic_evolution.py index 35cf720..af0da84 100644 --- a/protocol/examples/temporal_diagnostic_evolution.py +++ b/protocol/examples/temporal_diagnostic_evolution.py @@ -64,7 +64,7 @@ def main(): import argparse p = argparse.ArgumentParser() p.add_argument("--seed", type=int, default=42) - p.add_argument("--arch", type=str, default="resmlp", choices=["resmlp", "vit"]) + p.add_argument("--arch", type=str, default="resmlp", choices=["resmlp", "vit", "no_outln"]) args = p.parse_args() if args.arch == "resmlp": snapshot_path = os.path.join( @@ -72,12 +72,18 @@ def main(): ) h_key = "hidden_norms" g_key = "bp_grad_norms_per_sample_med" - else: + elif args.arch == "vit": snapshot_path = os.path.join( REPO_ROOT, f"results/snapshot_vit_v1/snapshot_vit_s{args.seed}.json" ) h_key = "hidden_norms_cls" g_key = "bp_grad_per_sample_l2_med" + else: # no_outln (synthetic studentnet without terminal LN) + snapshot_path = os.path.join( + REPO_ROOT, f"results/snapshot_no_outln_v1/snapshot_noLN_s{args.seed}.json" + ) + h_key = "hidden_norms" + g_key = "bp_grad_per_sample_l2_med" if not os.path.exists(snapshot_path): print(f"snapshot not found: {snapshot_path}") return |
