summary | refs | log | tree | commit | diff
path: root/research/flossing/maze_package/launch_maze_trm.sh
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/maze_package/launch_maze_trm.sh')
-rwxr-xr-xresearch/flossing/maze_package/launch_maze_trm.sh40
1 files changed, 40 insertions, 0 deletions
diff --git a/research/flossing/maze_package/launch_maze_trm.sh b/research/flossing/maze_package/launch_maze_trm.sh
new file mode 100755
index 0000000..093bb1e
--- /dev/null
+++ b/research/flossing/maze_package/launch_maze_trm.sh
@@ -0,0 +1,68 @@
+#!/usr/bin/env bash
+# TRM Maze-Hard 30x30 official recipe, adapted for 2 GPUs. Run on dedicated training cards.
+# Starts pretrain.py under nohup as a background job and prints where to find its log.
+#
+# Usage: bash launch_maze_trm.sh [NGPU] [GBS]
+#   NGPU  process count given to torchrun --nproc-per-node (default 2);
+#         with NGPU=1 pretrain.py is started with plain python instead
+#   GBS   value passed as global_batch_size (default 384)
+# 2x A6000 (48G): bash launch_maze_trm.sh 2 384
+# 2x A5000 (24G): bash launch_maze_trm.sh 2 192 (drop to 128 if OOM)
+# 1x card: bash launch_maze_trm.sh 1 128
+#
+# Output of the job (stdout and stderr) is written to the file named by LOG below.
+set -eo pipefail
+
+NGPU="${1:-2}"
+GBS="${2:-384}"
+
+# Reject anything that is not a positive decimal integer. NGPU is compared
+# numerically further down, where a non-numeric string would be evaluated as an
+# arithmetic expression instead of being reported as a usage error.
+for value in "${NGPU}" "${GBS}"; do
+  if ! [[ "${value}" =~ ^[1-9][0-9]*$ ]]; then
+    echo "error: NGPU and GBS must be positive integers (got '${value}')" >&2
+    echo "usage: bash launch_maze_trm.sh [NGPU] [GBS]" >&2
+    exit 2
+  fi
+done
+
+RUN_NAME="pretrain_att_maze30x30_${NGPU}gpu_gbs${GBS}"
+# Shared by the eval_interval override and the summary message printed at the end.
+EVAL_INTERVAL=5000
+
+source /home/yurenh2/miniconda3/etc/profile.d/conda.sh
+conda activate rrm
+cd /home/yurenh2/rrm/trm
+export WANDB_MODE=offline
+
+# Overrides handed to pretrain.py on both launch paths.
+# NOTE(review): the "+key=value" entries look like Hydra's add-a-new-key syntax; confirm against pretrain.py.
+COMMON_ARGS=(
+  arch=trm
+  "data_paths=[/home/yurenh2/rrm/data/maze-30x30-hard-1k]"
+  "evaluators=[]"
+  epochs=50000 eval_interval="${EVAL_INTERVAL}"
+  lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0
+  global_batch_size="${GBS}"
+  arch.L_layers=2 arch.H_cycles=3 arch.L_cycles=4
+  +run_name="${RUN_NAME}" ema=True
+  +checkpoint_every_eval=true
+)
+
+LOG="/home/yurenh2/rrm/research/flossing/maze_${RUN_NAME}.log"
+# The output redirection is performed inside the background job, so a missing
+# log directory would not abort this script; make sure it exists beforehand.
+mkdir -p -- "${LOG%/*}"
+
+if (( NGPU > 1 )); then
+  # Endpoint port 0 asks torchrun to pick a free port for the rendezvous.
+  nohup torchrun --nproc-per-node "${NGPU}" --rdzv_backend=c10d --rdzv_endpoint=localhost:0 \
+    --nnodes=1 pretrain.py "${COMMON_ARGS[@]}" > "${LOG}" 2>&1 &
+else
+  nohup python pretrain.py "${COMMON_ARGS[@]}" > "${LOG}" 2>&1 &
+fi
+PID=$!
+echo "launched ${RUN_NAME} (pid ${PID}), log: ${LOG}"
+echo "checkpoints -> trm/checkpoints/maze-30x30-hard-1k.../${RUN_NAME}/ (one per ${EVAL_INTERVAL} epochs)"
+echo "monitor: tail -f ${LOG} | grep -E 'accuracy|exact'"