summaryrefslogtreecommitdiff
path: root/research/flossing/maze_package/launch_maze_trm.sh
blob: 093bb1eb440a70b39efda8bf90d23323bfac1f49 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#!/usr/bin/env bash
# TRM Maze-Hard 30x30 official recipe, adapted for 2 GPUs. Run on dedicated training cards.
# Usage: bash launch_maze_trm.sh [NGPU] [GBS]
#   2x A6000 (48G):  bash launch_maze_trm.sh 2 384
#   2x A5000 (24G):  bash launch_maze_trm.sh 2 192     (drop to 128 if OOM)
#   1x card:         bash launch_maze_trm.sh 1 128
set -eo pipefail

# Succeeds only when $1 is a plain positive decimal integer
# (digits only, no sign, no leading zero, no whitespace).
is_positive_int() {
  [[ "$1" =~ ^[1-9][0-9]*$ ]]
}

# Positional arguments, both optional:
#   $1  NGPU - number of GPUs to launch on (default: 2)
#   $2  GBS  - value passed as global_batch_size (default: 384)
NGPU="${1:-2}"
GBS="${2:-384}"

# Validate before anything is launched. NGPU is later compared with `-gt`
# inside [[ ]], which evaluates its operand as an arithmetic expression: a
# typo such as "2x" or "abc" would otherwise silently pick the wrong launch
# path (or evaluate unintended text) and bake a bogus value into RUN_NAME.
if ! is_positive_int "${NGPU}" || ! is_positive_int "${GBS}"; then
  printf 'error: NGPU and GBS must be positive integers (got NGPU=%q GBS=%q)\n' \
    "${NGPU}" "${GBS}" >&2
  printf 'usage: bash %s [NGPU] [GBS]\n' "${0##*/}" >&2
  exit 2
fi

# Run identifier; encodes the GPU count and batch size so runs with different
# settings get distinct names.
RUN_NAME="pretrain_att_maze30x30_${NGPU}gpu_gbs${GBS}"

# Load conda's shell integration into this non-interactive shell so that
# `conda activate` is available, then switch to the `rrm` environment.
. /home/yurenh2/miniconda3/etc/profile.d/conda.sh
conda activate rrm

# pretrain.py is invoked by relative path further down, so run from its directory.
cd /home/yurenh2/rrm/trm

# Exported so the training process inherits it.
# NOTE(review): presumably keeps Weights & Biases from syncing during the run — confirm.
WANDB_MODE=offline
export WANDB_MODE

# key=value overrides handed to pretrain.py; shared by the multi-GPU and
# single-GPU launch commands.
# NOTE(review): the `+key=value` form looks like Hydra's "add a new key"
# syntax — confirm against pretrain.py.
COMMON_ARGS=()

# Architecture, dataset and evaluators. The bracketed values are quoted so the
# shell passes them through literally instead of treating [] as a glob.
COMMON_ARGS+=(
  "arch=trm"
  "data_paths=[/home/yurenh2/rrm/data/maze-30x30-hard-1k]"
  "evaluators=[]"
)

# Epoch count and evaluation interval.
COMMON_ARGS+=("epochs=50000" "eval_interval=5000")

# Learning rates and weight decays (model and puzzle embedding).
COMMON_ARGS+=(
  "lr=1e-4"
  "puzzle_emb_lr=1e-4"
  "weight_decay=1.0"
  "puzzle_emb_weight_decay=1.0"
)

# Batch size comes from the script's second positional argument.
COMMON_ARGS+=("global_batch_size=${GBS}")

# Architecture sub-settings.
COMMON_ARGS+=("arch.L_layers=2" "arch.H_cycles=3" "arch.L_cycles=4")

# Run naming, EMA, and per-evaluation checkpointing.
COMMON_ARGS+=(
  "+run_name=${RUN_NAME}"
  "ema=True"
  "+checkpoint_every_eval=true"
)

# File that receives both stdout and stderr of the training run.
LOG="/home/yurenh2/rrm/research/flossing/maze_${RUN_NAME}.log"

# Command prefix placed in front of pretrain.py: plain python by default,
# torchrun (single node, NGPU processes) when more than one GPU is requested.
# NOTE(review): rdzv endpoint port 0 presumably lets torchrun pick a free
# port — confirm against the torchrun documentation.
LAUNCHER=(python)
if [[ "${NGPU}" -gt 1 ]]; then
  LAUNCHER=(
    torchrun --nproc-per-node "${NGPU}"
    --rdzv_backend=c10d --rdzv_endpoint=localhost:0
    --nnodes=1
  )
fi

# Start in the background under nohup so the run ignores hangup signals;
# all of its output goes to the log file.
nohup "${LAUNCHER[@]}" pretrain.py "${COMMON_ARGS[@]}" > "${LOG}" 2>&1 &

# Report the background PID ($!) plus hints on where to look next.
printf '%s\n' \
  "launched ${RUN_NAME} (pid $!), log: ${LOG}" \
  "checkpoints -> trm/checkpoints/maze-30x30-hard-1k.../${RUN_NAME}/  (one per 5000 epochs)" \
  "monitor:  tail -f ${LOG}   |  grep -E 'accuracy|exact'"