blob: f0c979dc42c13460ffd5324a501b68d67bea3e99 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
|
#!/usr/bin/env bash
# Queues background runs of research/flossing/step7_interfloss.py on GPUs 2
# and 3. Each queue first waits for the PID recorded in an existing .pid file
# under FLOSS_DIR, then runs its cases one after another. Per-run .json/.log
# files and the queue PID files are written under OUT_DIR.

# Strict mode: abort on command failure, on use of an unset variable, and on a
# failure in any stage of a pipeline.
set -euo pipefail
# Repository root; run_case changes into this directory before launching.
ROOT="/home/yurenh2/rrm"
# Directory holding the .pid files that are read to decide what to wait for.
FLOSS_DIR="${ROOT}/research/flossing"
# Interpreter used to launch the runs.
PY="/home/yurenh2/miniconda3/envs/rrm/bin/python"
# Value passed to the runs as --ckpt-root.
CKPT_ROOT="${ROOT}/trm/checkpoints/Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_mlp_t_sudoku_official_gbs768_repro"
# Destination for run outputs, logs, and queue PID files.
OUT_DIR="${FLOSS_DIR}/engelken_paper_faithful"
# Create the output directory up front so later redirections into it succeed.
mkdir -p "${OUT_DIR}"
#######################################
# Block until the process with the given PID can no longer be signalled.
# Arguments:
#   $1 - PID to wait for. Unset, empty, or "0" means "nothing to wait for".
#   $2 - optional poll interval handed to `sleep` (default: 60).
# Outputs:
#   A one-line warning on stderr when $1 is not a plain decimal number.
# Returns:
#   0 once the process is gone, or immediately when there is nothing valid
#   to wait for.
#######################################
wait_for_pid() {
  local pid="${1:-}"
  local poll_interval="${2:-60}"

  if [[ -z "${pid}" || "${pid}" == "0" ]]; then
    return 0
  fi

  # Only plain decimal PIDs are polled. `kill` interprets a negative value as
  # a process group (and -1 as "every process the caller may signal"), so a
  # corrupt pid file could otherwise keep this loop spinning forever.
  if [[ ! "${pid}" =~ ^[0-9]+$ ]]; then
    printf 'wait_for_pid: ignoring non-numeric pid %q\n' "${pid}" >&2
    return 0
  fi

  # Signal 0 delivers nothing; it only reports whether the PID can be
  # signalled, which is used here as the "still running" probe.
  while kill -0 "${pid}" 2>/dev/null; do
    sleep "${poll_interval}"
  done
}
common_args=(
--model trm
--ckpt-root "${CKPT_ROOT}"
--ckpt-name __random__
--init-seed 123
--train-steps 10000
--batch-size 8
--train-lr 1e-4
--floss-lr 1e-4
--floss-steps 500
--floss-mode engelken_l2
--lambda-star 0
--k-lyap 4
--lyap-act-steps 4
--seed 42
--eval-every 1000
--eval-n 1000
--eval-batch-size 64
--floss-log-every 10
--train-puzzle-emb
--puzzle-emb-lr 1e-4
--puzzle-emb-weight-decay 1.0
--kl-beta 0
)
#######################################
# Run one step7_interfloss.py case in the foreground on a chosen GPU.
# Globals:
#   ROOT, PY, OUT_DIR, common_args (all read)
# Arguments:
#   $1 - value exported to the run as CUDA_VISIBLE_DEVICES
#   $2 - run name; names the <OUT_DIR>/<name>.json and <name>.log files
#   $3 - value passed as --interfloss-at (may be empty)
#   $4 - value passed as --floss-steps (default: 500)
# Outputs:
#   The run's stdout and stderr both go to <OUT_DIR>/<name>.log.
# Returns:
#   The exit status of the launched command.
#######################################
run_case() {
  local gpu_id="$1" run_name="$2" interfloss_at="$3"
  local floss_steps="${4:-500}"

  # Full command line: shared flags first, then the per-case flags.
  local -a cmd=(
    "${PY}" research/flossing/step7_interfloss.py
    "${common_args[@]}"
    --floss-steps "${floss_steps}"
    --interfloss-at "${interfloss_at}"
    --out "${OUT_DIR}/${run_name}.json"
  )

  # The script path above is relative, so launch from the repository root.
  cd "${ROOT}"
  CUDA_VISIBLE_DEVICES="${gpu_id}" PYTHONUNBUFFERED=1 "${cmd[@]}" \
    > "${OUT_DIR}/${run_name}.log" 2>&1
}
# PID files of the jobs each queue waits on. A missing or unreadable file
# yields an empty PID, which wait_for_pid treats as "nothing to wait for".
base_gate_file="${FLOSS_DIR}/ptrm_official_gbs768_base58590_k100_d64_sigma03_Lonly_n1000_seed0.pid"
multi_gate_file="${FLOSS_DIR}/ptrm_official_gbs768_multi4_35805_k100_d64_sigma03_Lonly_n1000_seed0.pid"
base_wait_pid="$(cat "${base_gate_file}" 2>/dev/null || true)"
multi_wait_pid="$(cat "${multi_gate_file}" 2>/dev/null || true)"

# Where the PID of each background queue is recorded.
gpu2_pid_file="${OUT_DIR}/gpu2_queue.pid"
gpu3_pid_file="${OUT_DIR}/gpu3_queue.pid"

# GPU 2 queue: wait for the "base" job, then run two cases back to back.
{
  wait_for_pid "${base_wait_pid}"
  run_case 2 "trm_sudoku_seed123_baseline_nofloss_b8_10k" "" 0
  run_case 2 "trm_sudoku_seed123_pre_interfloss_0_500_b8_k4_10k" "0,500" 500
} &
printf '%s\n' "$!" > "${gpu2_pid_file}"

# GPU 3 queue: wait for the "multi" job, then run a single case.
{
  wait_for_pid "${multi_wait_pid}"
  run_case 3 "trm_sudoku_seed123_prefloss_0_b8_k4_10k" "0" 500
} &
printf '%s\n' "$!" > "${gpu3_pid_file}"

# Report what was queued, reading the PIDs back from the files just written.
printf '%s\n' "queued Engelken-faithful TRM runs"
printf 'gpu2 queue pid: %s\n' "$(<"${gpu2_pid_file}")"
printf 'gpu3 queue pid: %s\n' "$(<"${gpu3_pid_file}")"
|