summaryrefslogtreecommitdiff
path: root/research/flossing/launch_hrm_multi4_perturb_ablation.sh
blob: 07fba63674a2fbc249742ad647568186b0104c51 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#!/usr/bin/env bash
#
# Trajectory-perturbation ablation launcher for the HRM multi4 /
# log-uniform-noise Sudoku-extreme runs.
#
# Usage: bash launch_hrm_multi4_perturb_ablation.sh CASE [GPU]
#   CASE  h | l | joint | all   ("all" runs h, l and joint one after another)
#   GPU   value exported as CUDA_VISIBLE_DEVICES (default: 0)
set -o errexit -o pipefail

# Positional arguments: CASE is mandatory (abort with a hint if missing),
# GPU falls back to device 0.
CASE="${1:?case required: h | l | joint | all}"
GPU="${2:-0}"

# Absolute directory holding this script, resolved once up front so later
# directory changes cannot invalidate it.
SCRIPT_DIR="$(dirname "${BASH_SOURCE[0]}")"
SCRIPT_DIR="$(cd "${SCRIPT_DIR}" && pwd)"

#######################################
# Launch one HRM pretraining run for a single trajectory-perturbation target.
# Activates the conda env, changes into the HRM checkout and runs
# pretrain.py with the multi4 / log-uniform trajectory-augmentation overrides.
# Globals:
#   GPU (read)                      - exported as CUDA_VISIBLE_DEVICES
#   CONDA_SH (read, optional)       - conda shell hook to source
#                                     (default: the original miniconda path)
#   CONDA_ENV_NAME (read, optional) - conda env to activate (default: rrm)
#   RRM_ROOT (read, optional)       - repo root containing hrm/ and data/
#                                     (default: /home/yurenh2/rrm)
#   WANDB_MODE, CUDA_VISIBLE_DEVICES (exported)
# Arguments:
#   $1 - value passed as +trajectory_perturb (e.g. h | l | joint)
#   $2 - suffix embedded in the run name and checkpoint directory
# Side effects:
#   Changes the working directory of the calling shell to "${RRM_ROOT}/hrm".
# Returns:
#   Non-zero if the conda hook, env activation or cd fails; otherwise the
#   exit status of pretrain.py.
#######################################
run_case() {
  local perturb="$1"
  local suffix="$2"

  # Machine-specific locations; defaults reproduce the original setup, the
  # environment can override them to launch on another host/user.
  local conda_sh="${CONDA_SH:-/home/yurenh2/miniconda3/etc/profile.d/conda.sh}"
  local conda_env="${CONDA_ENV_NAME:-rrm}"
  local rrm_root="${RRM_ROOT:-/home/yurenh2/rrm}"

  # Single source of truth for the names, so checkpoint_path always matches
  # project_name/run_name.
  local project_name='Sudoku-extreme-1k-aug-1000 ACT-torch'
  local run_name="HierarchicalReasoningModel_ACTV1 multi4-loguniform-${suffix}-repro"

  # shellcheck source=/dev/null
  source "${conda_sh}" || {
    printf 'run_case: cannot source conda hook: %s\n' "${conda_sh}" >&2
    return 1
  }
  conda activate "${conda_env}" || return
  cd "${rrm_root}/hrm" || return

  export WANDB_MODE=offline
  export CUDA_VISIBLE_DEVICES="${GPU}"

  python pretrain.py \
    data_path="${rrm_root}/data/sudoku-extreme-1k-aug-1000" \
    +project_name="${project_name}" \
    +run_name="${run_name}" \
    +checkpoint_path="checkpoints/${project_name}/${run_name}" \
    epochs=20000 eval_interval=2000 checkpoint_every_eval=true \
    global_batch_size=768 \
    lr=0.0001 lr_min_ratio=1.0 lr_warmup_steps=2000 \
    beta1=0.9 beta2=0.95 weight_decay=1.0 \
    puzzle_emb_lr=0.0001 puzzle_emb_weight_decay=1.0 \
    +trajectory_augment=true \
    +trajectory_n=4 \
    +trajectory_noise_std=0.001 \
    +trajectory_noise_min=0.00003 \
    +trajectory_noise_max=0.003 \
    +trajectory_noise_sampling=loguniform \
    +trajectory_sigma_start=0.0 \
    +trajectory_sigma_ramp_steps=5000 \
    +trajectory_perturb="${perturb}"
}

# Dispatch on the requested ablation case. Each single case maps a
# trajectory_perturb value to the suffix used in the run/checkpoint name.
case "${CASE}" in
  h)
    run_case h honly
    ;;
  l)
    run_case l lonly
    ;;
  joint)
    run_case joint jointnorm
    ;;
  all)
    # Re-run this very file once per case, each in a fresh bash process, so
    # run_case's cd/exports/conda activation never leak between runs. The
    # file is located via BASH_SOURCE instead of a hard-coded name, so a
    # renamed or copied launcher re-invokes itself rather than the original
    # script. With errexit on, the first failing run aborts the sweep.
    self="${SCRIPT_DIR}/${BASH_SOURCE[0]##*/}"
    for sub_case in h l joint; do
      bash "${self}" "${sub_case}" "${GPU}"
    done
    ;;
  *)
    printf 'unknown case: %s\n' "${CASE}" >&2
    exit 2
    ;;
esac