summary refs log tree commit diff
path: root/research/flossing/launch_trm_official_gbs768_multi4.sh
blob: 5bd9575ce5e0ba4b8075c3ec9c997d081bb5e04e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#!/usr/bin/env bash
#
# Prepares the shell for a TRM pretraining run: activates the `rrm` conda
# environment, selects the visible GPU(s), and switches Weights & Biases to
# offline mode before the training command is launched.
#
# Usage: <this script> [GPU]
#   GPU  value exported as CUDA_VISIBLE_DEVICES; falls back to 0 when the
#        argument is omitted or empty.

# Stop at the first failing command, including failures inside a pipeline.
# NOTE(review): nounset (-u) is not enabled here; conda activation hooks are
# known to trip over it on some versions -- confirm before adding it.
set -e
set -o pipefail

GPU=${1:-0}

# conda.sh provides the `conda` shell function that `conda activate` needs
# in a non-interactive shell.
. /home/yurenh2/miniconda3/etc/profile.d/conda.sh
conda activate rrm

# Exported only after activation, so these values are the ones the training
# process sees.
export WANDB_MODE=offline CUDA_VISIBLE_DEVICES="$GPU"

# pretrain.py is started by relative path, so run from the trm checkout.
cd /home/yurenh2/rrm/trm

# Launch pretrain.py with the full set of key=value overrides for this run.
# The overrides are collected in an array so each one is passed as exactly
# one argument. The `+key=value` form looks like Hydra's "add new key"
# syntax -- presumably pretrain.py is a Hydra app; confirm there.

# The checkpoint directory is built from the project and run names, so they
# are defined once and reused.
project_name='Sudoku-extreme-1k-aug-1000-ACT-torch'
run_name='pretrain_mlp_t_sudoku_official_gbs768_multi4_loguniform_repro'

train_args=(
  # Architecture and data selection.
  'arch=trm'
  'data_paths=[/home/yurenh2/rrm/data/sudoku-extreme-1k-aug-1000]'
  'data_paths_test=[]'
  'evaluators=[]'

  # Run identity and checkpoint location.
  "+project_name=${project_name}"
  "+run_name=${run_name}"
  "+checkpoint_path=checkpoints/${project_name}/${run_name}"
  '+load_checkpoint=null'

  # Training length, evaluation cadence, and batch size.
  'epochs=50000'
  'eval_interval=2500'
  'min_eval_interval=0'
  'checkpoint_every_eval=true'
  'global_batch_size=768'

  # Optimizer and learning-rate settings.
  'lr=0.0001'
  'lr_min_ratio=1.0'
  'lr_warmup_steps=2000'
  'beta1=0.9'
  'beta2=0.95'
  'weight_decay=1.0'
  'puzzle_emb_lr=0.0001'
  'puzzle_emb_weight_decay=1.0'
  'ema=true'
  'ema_rate=0.999'
  'freeze_weights=false'

  # Overrides nested under arch.
  'arch.mlp_t=true'
  'arch.pos_encodings=none'
  'arch.L_layers=2'
  'arch.H_cycles=3'
  'arch.L_cycles=6'

  # trajectory_* options, all supplied with the `+` prefix.
  '+trajectory_augment=true'
  '+trajectory_n=4'
  '+trajectory_noise_std=0.001'
  '+trajectory_noise_min=0.00003'
  '+trajectory_noise_max=0.003'
  '+trajectory_noise_sampling=loguniform'
  '+trajectory_sigma_start=0.0'
  '+trajectory_sigma_ramp_steps=5000'
  '+trajectory_perturb=both'
)

python pretrain.py "${train_args[@]}"