summaryrefslogtreecommitdiff
path: root/scripts/run_trm_sudoku.sh
blob: ae6db211ebd30f67b86d67c98297cf0101cc4c0a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#!/usr/bin/env bash
# Launch TRM Sudoku-Extreme 1k training using the official TRM
# pretrain_mlp_t_sudoku configuration.
# Takes roughly 18h on a single L40S 48GB GPU; an A6000 should be comparable.
#
# Usage: scripts/run_trm_sudoku.sh [extra pretrain.py overrides...]
#   Extra arguments are appended after the defaults below, so they are passed
#   to pretrain.py last.
#
# Environment:
#   CONDA_ENV        conda environment to activate (default: rrm)
#   OMP_NUM_THREADS  OpenMP thread count for the training process (default: 8)
#   WANDB_MODE       Weights & Biases mode (default: online)
set -euo pipefail

REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
readonly REPO_ROOT
readonly CONDA_ENV="${CONDA_ENV:-rrm}"

die() { printf '%s\n' "$*" >&2; exit 1; }

command -v conda >/dev/null 2>&1 \
  || die "error: conda not found on PATH; cannot activate env '$CONDA_ENV'"

conda_base="$(conda info --base)" || die "error: 'conda info --base' failed"

# conda's shell hooks and the activate.d scripts shipped by packages may
# reference unset variables, which would abort this script under `set -u`;
# relax nounset only for conda initialization and activation.
set +u
# shellcheck source=/dev/null
source "$conda_base/etc/profile.d/conda.sh"
conda activate "$CONDA_ENV"
set -u

cd "$REPO_ROOT/trm"

export OMP_NUM_THREADS="${OMP_NUM_THREADS:-8}"
export WANDB_MODE="${WANDB_MODE:-online}"

# `exec` replaces this shell with python, so signals (e.g. SIGTERM from a job
# scheduler, SIGINT) reach the training process directly instead of leaving
# it orphaned, and python's exit status becomes the script's exit status.
exec python pretrain.py \
  arch=trm \
  data_paths="[$REPO_ROOT/data/sudoku-extreme-1k-aug-1000]" \
  evaluators="[]" \
  epochs=50000 eval_interval=5000 \
  lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
  arch.mlp_t=True arch.pos_encodings=none \
  arch.L_layers=2 \
  arch.H_cycles=3 arch.L_cycles=6 \
  +run_name=pretrain_mlp_t_sudoku ema=True "$@"