summaryrefslogtreecommitdiff
path: root/research/flossing/launch_engelken_paper_faithful_trm_queue.sh
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/launch_engelken_paper_faithful_trm_queue.sh')
-rwxr-xr-xresearch/flossing/launch_engelken_paper_faithful_trm_queue.sh78
1 files changed, 78 insertions, 0 deletions
diff --git a/research/flossing/launch_engelken_paper_faithful_trm_queue.sh b/research/flossing/launch_engelken_paper_faithful_trm_queue.sh
new file mode 100755
index 0000000..f0c979d
--- /dev/null
+++ b/research/flossing/launch_engelken_paper_faithful_trm_queue.sh
@@ -0,0 +1,78 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT="/home/yurenh2/rrm"
+FLOSS_DIR="${ROOT}/research/flossing"
+PY="/home/yurenh2/miniconda3/envs/rrm/bin/python"
+CKPT_ROOT="${ROOT}/trm/checkpoints/Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_mlp_t_sudoku_official_gbs768_repro"
+OUT_DIR="${FLOSS_DIR}/engelken_paper_faithful"
+mkdir -p "${OUT_DIR}"
+
+wait_for_pid() {
+ local pid="${1:-}"
+ if [[ -z "${pid}" || "${pid}" == "0" ]]; then
+ return 0
+ fi
+ while kill -0 "${pid}" 2>/dev/null; do
+ sleep 60
+ done
+}
+
+common_args=(
+ --model trm
+ --ckpt-root "${CKPT_ROOT}"
+ --ckpt-name __random__
+ --init-seed 123
+ --train-steps 10000
+ --batch-size 8
+ --train-lr 1e-4
+ --floss-lr 1e-4
+ --floss-steps 500
+ --floss-mode engelken_l2
+ --lambda-star 0
+ --k-lyap 4
+ --lyap-act-steps 4
+ --seed 42
+ --eval-every 1000
+ --eval-n 1000
+ --eval-batch-size 64
+ --floss-log-every 10
+ --train-puzzle-emb
+ --puzzle-emb-lr 1e-4
+ --puzzle-emb-weight-decay 1.0
+ --kl-beta 0
+)
+
+run_case() {
+ local gpu="$1"
+ local name="$2"
+ local schedule="$3"
+ local extra_floss_steps="${4:-500}"
+ cd "${ROOT}"
+ CUDA_VISIBLE_DEVICES="${gpu}" PYTHONUNBUFFERED=1 "${PY}" research/flossing/step7_interfloss.py \
+ "${common_args[@]}" \
+ --floss-steps "${extra_floss_steps}" \
+ --interfloss-at "${schedule}" \
+ --out "${OUT_DIR}/${name}.json" \
+ > "${OUT_DIR}/${name}.log" 2>&1
+}
+
+base_wait_pid="$(cat "${FLOSS_DIR}/ptrm_official_gbs768_base58590_k100_d64_sigma03_Lonly_n1000_seed0.pid" 2>/dev/null || true)"
+multi_wait_pid="$(cat "${FLOSS_DIR}/ptrm_official_gbs768_multi4_35805_k100_d64_sigma03_Lonly_n1000_seed0.pid" 2>/dev/null || true)"
+
+(
+ wait_for_pid "${base_wait_pid}"
+ run_case 2 "trm_sudoku_seed123_baseline_nofloss_b8_10k" "" 0
+ run_case 2 "trm_sudoku_seed123_pre_interfloss_0_500_b8_k4_10k" "0,500" 500
+) &
+echo $! > "${OUT_DIR}/gpu2_queue.pid"
+
+(
+ wait_for_pid "${multi_wait_pid}"
+ run_case 3 "trm_sudoku_seed123_prefloss_0_b8_k4_10k" "0" 500
+) &
+echo $! > "${OUT_DIR}/gpu3_queue.pid"
+
+echo "queued Engelken-faithful TRM runs"
+echo "gpu2 queue pid: $(cat "${OUT_DIR}/gpu2_queue.pid")"
+echo "gpu3 queue pid: $(cat "${OUT_DIR}/gpu3_queue.pid")"