summaryrefslogtreecommitdiff
path: root/research/flossing/launch_multi4_parallel_repro_config.sh
blob: d618cfe511632ec281d8dd39184ec4612671dc0f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env bash
#
# Launch the "multi4 log-uniform parallel" trajectory-augmentation repro run
# on sudoku-extreme-1k-aug-1000 for one of the two model codebases.
#
# Usage: launch_multi4_parallel_repro_config.sh <hrm|trm> [gpu]
#   <hrm|trm>  which codebase's pretrain.py to run (required)
#   [gpu]      value exported as CUDA_VISIBLE_DEVICES (default: 0)
#
# Optional environment overrides (defaults reproduce the original setup):
#   RRM_ROOT   directory containing hrm/, trm/ and data/
#              (default: /home/yurenh2/rrm)
#   CONDA_SH   conda shell hook to source
#              (default: /home/yurenh2/miniconda3/etc/profile.d/conda.sh)
#   CONDA_ENV  conda environment to activate (default: rrm)
#
# Exit status: 2 for an unknown case; otherwise that of pretrain.py.

# -u is deliberately not set yet: conda's activation hooks are not guaranteed
# to be nounset-clean, so it is switched on right after `conda activate`.
set -eo pipefail

CASE="${1:?case required}"
GPU="${2:-0}"

RRM_ROOT="${RRM_ROOT:-/home/yurenh2/rrm}"
CONDA_SH="${CONDA_SH:-/home/yurenh2/miniconda3/etc/profile.d/conda.sh}"
CONDA_ENV="${CONDA_ENV:-rrm}"
readonly CASE GPU RRM_ROOT CONDA_SH CONDA_ENV

data_dir="${RRM_ROOT}/data/sudoku-extreme-1k-aug-1000"

# Optimizer overrides shared by both codebases.
optim_args=(
  lr=0.0001 lr_min_ratio=1.0 lr_warmup_steps=2000
  beta1=0.9 beta2=0.95 weight_decay=1.0
  puzzle_emb_lr=0.0001 puzzle_emb_weight_decay=1.0
)

# Trajectory-augmentation overrides shared by both codebases. Keeping them in
# a single array guarantees the hrm and trm runs cannot drift apart.
traj_args=(
  +trajectory_augment=true
  +trajectory_parallel=true
  +trajectory_n=4
  +trajectory_noise_std=0.001
  +trajectory_noise_min=0.00003
  +trajectory_noise_max=0.003
  +trajectory_noise_sampling=loguniform
  +trajectory_sigma_start=0.0
  +trajectory_sigma_ramp_steps=5000
  +trajectory_perturb=both
)

# Resolve the working directory and full override list up front, so an
# unknown case is rejected before paying for conda activation.
case "${CASE}" in
  hrm)
    workdir="${RRM_ROOT}/hrm"
    run_args=(
      data_path="${data_dir}"
      +project_name='Sudoku-extreme-1k-aug-1000 ACT-torch'
      +run_name='HierarchicalReasoningModel_ACTV1 multi4-loguniform-parallel-repro'
      +checkpoint_path='checkpoints/Sudoku-extreme-1k-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV1 multi4-loguniform-parallel-repro'
      epochs=20000 eval_interval=2000 checkpoint_every_eval=true
      global_batch_size=768
      "${optim_args[@]}"
      "${traj_args[@]}"
    )
    ;;
  trm)
    workdir="${RRM_ROOT}/trm"
    run_args=(
      data_paths="[${data_dir}]"
      data_paths_test='[]'
      evaluators='[]'
      +project_name='Sudoku-extreme-1k-aug-1000-ACT-torch'
      +run_name='pretrain_mlp_t_sudoku_multi4_loguniform_parallel_repro'
      +checkpoint_path='checkpoints/Sudoku-extreme-1k-aug-1000-ACT-torch/pretrain_mlp_t_sudoku_multi4_loguniform_parallel_repro'
      +load_checkpoint=null
      epochs=50000 eval_interval=5000 min_eval_interval=0 checkpoint_every_eval=true
      global_batch_size=192
      "${optim_args[@]}"
      ema=true ema_rate=0.999 freeze_weights=false
      arch.mlp_t=true arch.pos_encodings=none arch.puzzle_emb_len=16 arch.no_ACT_continue=true
      "${traj_args[@]}"
    )
    ;;
  *)
    echo "unknown case: ${CASE}" >&2
    exit 2
    ;;
esac

# shellcheck source=/dev/null
source "${CONDA_SH}"
conda activate "${CONDA_ENV}"
set -u

cd "${workdir}"
export WANDB_MODE=offline
export CUDA_VISIBLE_DEVICES="${GPU}"

# exec replaces this shell with the trainer, so signals sent to the launcher
# (e.g. SIGTERM from `kill` or a job scheduler) reach python directly instead
# of killing bash and orphaning the training process. The script's exit
# status is pretrain.py's.
exec python pretrain.py "${run_args[@]}"