summaryrefslogtreecommitdiff
path: root/protocol/examples
diff options
context:
space:
mode:
Diffstat (limited to 'protocol/examples')
-rw-r--r--protocol/examples/temporal_diagnostic_evolution.py10
1 files changed, 8 insertions, 2 deletions
diff --git a/protocol/examples/temporal_diagnostic_evolution.py b/protocol/examples/temporal_diagnostic_evolution.py
index 35cf720..af0da84 100644
--- a/protocol/examples/temporal_diagnostic_evolution.py
+++ b/protocol/examples/temporal_diagnostic_evolution.py
@@ -64,7 +64,7 @@ def main():
import argparse
p = argparse.ArgumentParser()
p.add_argument("--seed", type=int, default=42)
- p.add_argument("--arch", type=str, default="resmlp", choices=["resmlp", "vit"])
+ p.add_argument("--arch", type=str, default="resmlp", choices=["resmlp", "vit", "no_outln"])
args = p.parse_args()
if args.arch == "resmlp":
snapshot_path = os.path.join(
@@ -72,12 +72,18 @@ def main():
)
h_key = "hidden_norms"
g_key = "bp_grad_norms_per_sample_med"
- else:
+ elif args.arch == "vit":
snapshot_path = os.path.join(
REPO_ROOT, f"results/snapshot_vit_v1/snapshot_vit_s{args.seed}.json"
)
h_key = "hidden_norms_cls"
g_key = "bp_grad_per_sample_l2_med"
+ else: # no_outln (synthetic studentnet without terminal LN)
+ snapshot_path = os.path.join(
+ REPO_ROOT, f"results/snapshot_no_outln_v1/snapshot_noLN_s{args.seed}.json"
+ )
+ h_key = "hidden_norms"
+ g_key = "bp_grad_per_sample_l2_med"
if not os.path.exists(snapshot_path):
print(f"snapshot not found: {snapshot_path}")
return