summaryrefslogtreecommitdiff
path: root/research/flossing/launch_trajectory_sampling_long.sh
blob: 4a036505f2b471e17e41491fe712d480f603b120 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#!/usr/bin/env bash
# launch_trajectory_sampling_long.sh -- launch one long training run.
#
# Usage: launch_trajectory_sampling_long.sh CASE [GPU]
#   CASE  name of the run configuration to launch (required; the script
#         aborts with "case required" when it is missing or empty)
#   GPU   device selector, defaults to 0
#
# errexit and pipefail are enabled; nounset is not.
# NOTE(review): presumably nounset is left off because the sourced conda
# script may reference unset variables -- confirm before adding -u.
set -e -o pipefail

# Positional arguments.
CASE=${1:?case required}
GPU=${2:-0}

# Activate the "rrm" conda environment, then work from the experiment
# directory so relative script, log and output paths resolve there.
. /home/yurenh2/miniconda3/etc/profile.d/conda.sh
conda activate rrm
cd /home/yurenh2/rrm/research/flossing

# Checkpoint root directories for the two model families. The HRM path
# contains literal spaces, so it must stay quoted wherever it is expanded.
HRM_ROOT='/home/yurenh2/rrm/hrm/checkpoints/Sudoku-extreme-1k-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV1 righteous-python'
TRM_ROOT='/home/yurenh2/rrm/trm/checkpoints/Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_mlp_t_sudoku_singleGPU'

#######################################
# Run step9_trajectory_perturb_train.py once and capture its output.
# Globals:
#   GPU (read) - exported to the python process as CUDA_VISIBLE_DEVICES
# Arguments:
#   $1   - run tag; stdout and stderr of the python process are written
#          to "<tag>.log" in the current directory
#   $2.. - passed through unchanged to the training script
# Outputs:
#   Timestamped START and DONE lines on stdout; a FAILED line on stderr
#   when the run does not succeed.
# Returns:
#   0 on success, otherwise the exit status of the failed command.
#######################################
run_one() {
  local tag="$1"
  shift
  local rc=0
  echo "[$(date -Is)] START ${tag} on GPU ${GPU}"
  # Capture the status explicitly instead of relying on the caller's
  # `set -e`: errexit is ignored when this function is invoked from a
  # conditional context, which would otherwise report DONE for a failed run.
  CUDA_VISIBLE_DEVICES="${GPU}" python step9_trajectory_perturb_train.py "$@" \
    > "${tag}.log" 2>&1 || rc=$?
  if (( rc != 0 )); then
    echo "[$(date -Is)] FAILED ${tag} (exit ${rc}), see ${tag}.log" >&2
    return "${rc}"
  fi
  echo "[$(date -Is)] DONE ${tag}"
}

# Flags shared by every configuration: seed, evaluation cadence and size,
# and which checkpoints/state to save.
eval_and_save_args=(
  --seed 42
  --eval-every 2500
  --eval-n 1024
  --eval-batch-size 32
  --save-best
  --save-final
  --save-every-eval
  --save-train-state
)

# Noise settings used only by the multi_perturbed_ce configurations.
noise_args=(
  --noise-std 0.001
  --noise-min 0.00003
  --noise-max 0.003
  --noise-sampling loguniform
  --sigma-start 0
  --sigma-ramp-steps 5000
  --perturb both
)

# Each arm only selects the pieces that differ between runs; the single
# run_one call below assembles them in the order:
# model -> mode -> training (+ noise) -> eval/save -> --out.
case "${CASE}" in
  hrm_baseline50k)
    job=step9_E_hrm_baseline_parallel_fixed_26040_50k
    model_args=(--model hrm --ckpt-root "${HRM_ROOT}" --ckpt-name step_26040)
    mode_args=(--mode baseline_clean --rollout-impl parallel_fixed)
    train_args=(--train-steps 50000 --batch-size 8 --lr 1e-5)
    ;;

  hrm_multi4_loguniform50k)
    job=step9_F_hrm_multi4_loguniform_ramp_26040_50k
    model_args=(--model hrm --ckpt-root "${HRM_ROOT}" --ckpt-name step_26040)
    mode_args=(--mode multi_perturbed_ce --rollout-impl parallel_fixed --n-trajectories 4)
    train_args=(--train-steps 50000 --batch-size 8 --lr 1e-5 "${noise_args[@]}")
    ;;

  trm_baseline50k)
    job=step9_G_trm_baseline_parallel_fixed_26041_batch4_50k
    model_args=(--model trm --ckpt-root "${TRM_ROOT}" --ckpt-name step_26041)
    mode_args=(--mode baseline_clean --rollout-impl parallel_fixed)
    train_args=(--train-steps 50000 --batch-size 4 --lr 1e-5)
    ;;

  trm_multi4_loguniform50k)
    job=step9_H_trm_multi4_loguniform_ramp_26041_batch4_50k
    model_args=(--model trm --ckpt-root "${TRM_ROOT}" --ckpt-name step_26041)
    mode_args=(--mode multi_perturbed_ce --rollout-impl parallel_fixed --n-trajectories 4)
    train_args=(--train-steps 50000 --batch-size 4 --lr 1e-5 "${noise_args[@]}")
    ;;

  *)
    echo "unknown case: ${CASE}" >&2
    exit 2
    ;;
esac

# The JSON output file is named after the run tag, as is the log file
# that run_one writes.
run_one "${job}" \
  "${model_args[@]}" \
  "${mode_args[@]}" \
  "${train_args[@]}" \
  "${eval_and_save_args[@]}" \
  --out "${job}.json"