summaryrefslogtreecommitdiff
path: root/research/flossing/launch_engelken_paper_faithful_trm_queue.sh
blob: f0c979dc42c13460ffd5324a501b68d67bea3e99 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#!/usr/bin/env bash
set -euo pipefail

# Fixed locations used by this launcher: the repository checkout, the
# interpreter used to run the experiment script, the flossing research
# directory, the checkpoint root handed to --ckpt-root, and the directory
# that receives every .json / .log / .pid file written below.
ROOT='/home/yurenh2/rrm'
PY='/home/yurenh2/miniconda3/envs/rrm/bin/python'
FLOSS_DIR="$ROOT/research/flossing"
CKPT_ROOT="$ROOT/trm/checkpoints/Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_mlp_t_sudoku_official_gbs768_repro"
OUT_DIR="$FLOSS_DIR/engelken_paper_faithful"

# Create the output directory up front so later redirections cannot fail on a
# missing parent.
mkdir -p "$OUT_DIR"

#######################################
# Block until the process with the given PID can no longer be signalled.
# Liveness is probed with `kill -0`, so this also works for processes that
# are not children of this shell.
# Arguments:
#   $1 - PID to wait on. Empty, unset, or all zeros means "nothing to wait
#        for" and returns immediately.
#   $2 - optional pause between liveness probes, passed to `sleep`
#        (default: 60).
# Outputs:
#   A warning on stderr when $1 is not a plain decimal PID.
# Returns:
#   0 once the wait is over or was skipped.
#######################################
wait_for_pid() {
  local pid="${1:-}"
  local poll_interval="${2:-60}"

  # No PID recorded (or the "0" placeholder): nothing to wait for.
  if [[ -z "${pid}" || "${pid}" =~ ^0+$ ]]; then
    return 0
  fi

  # Only plain decimal PIDs are accepted. A negative value such as "-1" would
  # make `kill -0` address every process and the loop would never finish;
  # other garbage would make `kill` fail and silently skip the wait. Keep the
  # best-effort behaviour (do not abort the queue) but say so on stderr.
  if [[ ! "${pid}" =~ ^[0-9]+$ ]]; then
    printf 'wait_for_pid: ignoring malformed pid %q\n' "${pid}" >&2
    return 0
  fi

  # `kill -0` delivers no signal; it only reports whether the PID can be
  # signalled. Its error output is expected once the process is gone.
  while kill -0 "${pid}" 2>/dev/null; do
    sleep "${poll_interval}"
  done
}

# Command-line flags shared by every run launched from this script. The array
# is expanded verbatim in front of the per-run flags, so element order is kept
# exactly as listed here.
common_args=()

# Model / checkpoint selection and initialisation seed.
common_args+=(--model trm --ckpt-root "${CKPT_ROOT}" --ckpt-name __random__ --init-seed 123)

# Training-length, batch-size and learning-rate flags.
common_args+=(--train-steps 10000 --batch-size 8 --train-lr 1e-4)

# Flossing flags. NOTE(review): run_case passes --floss-steps again after
# these; presumably the later occurrence wins in the Python argument parser
# (confirm against step7_interfloss.py).
common_args+=(--floss-lr 1e-4 --floss-steps 500 --floss-mode engelken_l2)
common_args+=(--lambda-star 0 --k-lyap 4 --lyap-act-steps 4)

# Run seed, evaluation flags and floss logging cadence.
common_args+=(--seed 42)
common_args+=(--eval-every 1000 --eval-n 1000 --eval-batch-size 64)
common_args+=(--floss-log-every 10)

# Puzzle-embedding flags and KL weight.
common_args+=(--train-puzzle-emb --puzzle-emb-lr 1e-4 --puzzle-emb-weight-decay 1.0)
common_args+=(--kl-beta 0)

#######################################
# Run one step7_interfloss.py experiment in the foreground and capture its
# output.
# Globals:
#   ROOT (read)        - directory to cd into before launching
#   PY (read)          - interpreter used to run the script
#   OUT_DIR (read)     - receives <name>.json and <name>.log
#   common_args (read) - flags placed before the per-run flags
# Arguments:
#   $1 - value exported as CUDA_VISIBLE_DEVICES for the run
#   $2 - run name; used as the basename of the .json and .log files
#   $3 - value passed to --interfloss-at (may be empty)
#   $4 - value passed to --floss-steps (default: 500)
# Outputs:
#   stdout and stderr of the run go to ${OUT_DIR}/<name>.log.
# Returns:
#   Exit status of the launched command. Changes the caller's working
#   directory to ROOT as a side effect.
#######################################
run_case() {
  local device="$1" tag="$2" plan="$3"
  local n_floss="${4:-500}"

  # Assemble the full argument vector first: script path, shared flags, then
  # the per-run flags (which come after the shared ones on the command line).
  local -a argv=(
    research/flossing/step7_interfloss.py
    "${common_args[@]}"
    --floss-steps "${n_floss}"
    --interfloss-at "${plan}"
    --out "${OUT_DIR}/${tag}.json"
  )

  # The script path above is relative, so the run must start from ROOT.
  cd "${ROOT}"
  CUDA_VISIBLE_DEVICES="${device}" PYTHONUNBUFFERED=1 "${PY}" "${argv[@]}" \
    &> "${OUT_DIR}/${tag}.log"
}

# PIDs recorded in two pid files under FLOSS_DIR (presumably earlier jobs that
# still occupy GPUs 2 and 3 - confirm). A missing or unreadable pid file is
# tolerated on purpose and yields an empty string, which wait_for_pid treats
# as "nothing to wait for".
base_pid_file="${FLOSS_DIR}/ptrm_official_gbs768_base58590_k100_d64_sigma03_Lonly_n1000_seed0.pid"
multi_pid_file="${FLOSS_DIR}/ptrm_official_gbs768_multi4_35805_k100_d64_sigma03_Lonly_n1000_seed0.pid"
base_wait_pid="$(cat "${base_pid_file}" 2>/dev/null || true)"
multi_wait_pid="$(cat "${multi_pid_file}" 2>/dev/null || true)"

# GPU 2 queue (background): once the first awaited PID is gone, run the
# no-floss baseline (empty schedule, 0 floss steps) and then the "0,500"
# schedule, one after the other.
{
  wait_for_pid "${base_wait_pid}"
  run_case 2 "trm_sudoku_seed123_baseline_nofloss_b8_10k" "" 0
  run_case 2 "trm_sudoku_seed123_pre_interfloss_0_500_b8_k4_10k" "0,500" 500
} &
gpu2_queue_pid=$!
printf '%s\n' "${gpu2_queue_pid}" > "${OUT_DIR}/gpu2_queue.pid"

# GPU 3 queue (background): once the second awaited PID is gone, run the
# single "0" schedule.
{
  wait_for_pid "${multi_wait_pid}"
  run_case 3 "trm_sudoku_seed123_prefloss_0_b8_k4_10k" "0" 500
} &
gpu3_queue_pid=$!
printf '%s\n' "${gpu3_queue_pid}" > "${OUT_DIR}/gpu3_queue.pid"

# Report the queue PIDs; the same values were just written to the pid files.
printf '%s\n' "queued Engelken-faithful TRM runs"
printf 'gpu2 queue pid: %s\n' "${gpu2_queue_pid}"
printf 'gpu3 queue pid: %s\n' "${gpu3_queue_pid}"