summaryrefslogtreecommitdiff
path: root/research/flossing/run_checkpoint_evolution.sh
blob: c5949dd5812b550f008ac2232a6df2e87a53f88e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#!/usr/bin/env bash
# Checkpoint evolution, part (b): run diagnose_hrm.py on a subset of the
# earlier HRM checkpoints with the same sample pool (N samples) sharded across
# GPUs, then merge each checkpoint's shards into ckpt_evolution/<ckpt>.npz.
# NOTE(review): no seed is passed here; same-seed behaviour presumably relies
# on diagnose_hrm.py's default — confirm.
set -euo pipefail

# Paths may be overridden from the environment; defaults reproduce the
# original run.
REPO="${REPO:-/home/yurenh2/rrm/research/flossing}"
CKPT_ROOT="${CKPT_ROOT:-/home/yurenh2/rrm/hrm/checkpoints/Sudoku-extreme-1k-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV1 righteous-python}"

source "$(conda info --base)/etc/profile.d/conda.sh"
# Conda activation hooks may reference unset variables, which would abort the
# script under `set -u`; relax nounset only for the activation itself.
set +u
conda activate rrm
set -u
cd "$REPO"

N=1024 # samples drawn from the shared pool per checkpoint
K=8    # forwarded to diagnose_hrm.py as --k-lyap
mkdir -p ckpt_evolution

# The training run did eval at steps:
#   2604, 5208, 7812, 10416, 13020, 15624, 18228, 20832, 23436, 26040
# step_26040 is skipped because it already exists; sweep a representative
# subset of the earlier checkpoints.
CKPTS=(step_2604 step_7812 step_13020 step_18228 step_20832)

# For each checkpoint, launch one diagnose_hrm.py process per shard (shard i
# runs on GPU i) in the background, then block until every shard of that
# checkpoint has finished before moving to the next one.
# A bare `wait` always returns 0, so each shard PID is waited on individually;
# if any shard fails the script aborts. Existing shard outputs are skipped, so
# a rerun resumes where the sweep stopped.
NUM_SHARDS=3
for ckpt in "${CKPTS[@]}"; do
  echo "==> $ckpt"
  pids=()
  logs=()
  for (( shard = 0; shard < NUM_SHARDS; shard++ )); do
    LOG=ckpt_evolution/${ckpt}_shard${shard}.log
    OUT=ckpt_evolution/${ckpt}_shard${shard}.npz
    if [[ -f "$OUT" ]]; then echo "skip $OUT"; continue; fi
    nohup env CUDA_VISIBLE_DEVICES="$shard" python diagnose_hrm.py \
      --ckpt-root "$CKPT_ROOT" --ckpt-name "$ckpt" \
      --n-samples "$N" --num-shards "$NUM_SHARDS" --shard-id "$shard" \
      --batch-size 64 --k-lyap "$K" \
      --out "$OUT" > "$LOG" 2>&1 &
    pids+=("$!")
    logs+=("$LOG")
  done
  # Wait for all shards of THIS checkpoint to finish before moving to next,
  # collecting each shard's exit status.
  failed=0
  for (( i = 0; i < ${#pids[@]}; i++ )); do
    if ! wait "${pids[i]}"; then
      echo "shard failed for $ckpt; see ${logs[i]}" >&2
      failed=1
    fi
  done
  if (( failed )); then
    exit 1
  fi
  echo "<== $ckpt done"
done

# Final merge per checkpoint: concatenate the per-shard .npz files (in shard-id
# order) into ckpt_evolution/<ckpt>.npz and print accuracy plus the mean of
# lyap_spec[:, 0] over solved (s) and failed (f) samples.
# The output directory, shard count and checkpoint list are passed in from the
# shell so they cannot drift from what was launched. step_26040 is not in CKPTS;
# to compare apples-to-apples it must be run on the same 1024 samples
# separately. Exits non-zero if any checkpoint lacks a full shard set.
python - "$REPO/ckpt_evolution" "${NUM_SHARDS:-3}" "${CKPTS[@]}" <<'PY'
import os
import sys

import numpy as np

out_dir = sys.argv[1]
n_shards = int(sys.argv[2])
ckpts = sys.argv[3:]


def _mean(x):
    # An empty selection (e.g. no solved samples) reports nan without a
    # "mean of empty slice" RuntimeWarning.
    return float(x.mean()) if x.size else float("nan")


incomplete = []
for c in ckpts:
    # Explicit paths keep concatenation in shard-id order (a lexicographic
    # glob would put shard10 before shard2) and expose missing shards.
    files = [f"{out_dir}/{c}_shard{i}.npz" for i in range(n_shards)]
    absent = [f for f in files if not os.path.exists(f)]
    if absent:
        print(f"missing {c}: {len(absent)}/{n_shards} shard(s) absent",
              file=sys.stderr)
        incomplete.append(c)
        continue
    m = {}
    for f in files:
        # NpzFile keeps the archive open; close it once the arrays are read.
        with np.load(f) as d:
            for k in d.files:
                m.setdefault(k, []).append(d[k])
    m = {k: np.concatenate(v, 0) for k, v in m.items()}
    out = f"{out_dir}/{c}.npz"
    np.savez_compressed(out, **m)
    correct = m["exact_correct"]
    lam_max = m["lyap_spec"][:, 0]
    print(f"{c}: N={len(correct)} acc={correct.mean():.4f} "
          f"λ_max(s)={_mean(lam_max[correct > 0.5]):.3f} "
          f"λ_max(f)={_mean(lam_max[correct < 0.5]):.3f}")

sys.exit(1 if incomplete else 0)
PY