blob: 248e9a280ac411f9cc5b2559f070adf89b665216 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
|
#!/usr/bin/env bash
# Smoke test: verify that both HRM and TRM can start their training loops on
# the shared Sudoku dataset.
#
# This does NOT run full training. Each model is launched under a wall-clock
# limit (HRM 90s, TRM 120s by default) and only has to not crash:
#   * exit status 0   -> finished (1 epoch) inside the limit   -> OK
#   * exit status 124 -> still training when `timeout` fired   -> OK
#   * any other status, or "error"/"traceback" in the log      -> FAILED
#
# Prerequisites: a conda env named `rrm`, and the dataset built with
#   scripts/build_datasets.sh sudoku
#
# Optional environment overrides:
#   SMOKE_LOG_DIR   where hrm_smoke.log / trm_smoke.log are written (default: /tmp)
#   SMOKE_GPU       CUDA_VISIBLE_DEVICES used for both runs        (default: 0)
#   SMOKE_HRM_SECS  HRM time limit in seconds                       (default: 90)
#   SMOKE_TRM_SECS  TRM time limit in seconds                       (default: 120)
set -euo pipefail

# Split declaration from assignment so a failing `cd` is not masked.
REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
readonly REPO_ROOT
readonly DATA_DIR="$REPO_ROOT/data/sudoku-extreme-1k-aug-1000"
readonly LOG_DIR="${SMOKE_LOG_DIR:-/tmp}"
readonly GPU="${SMOKE_GPU:-0}"
readonly HRM_SECS="${SMOKE_HRM_SECS:-90}"
readonly TRM_SECS="${SMOKE_TRM_SECS:-120}"

#######################################
# Run one training command under `timeout`, then decide whether it crashed.
# Arguments:
#   $1   - label used in status messages (e.g. HRM)
#   $2   - time limit in seconds, passed to `timeout`
#   $3   - log file; receives the command's stdout and stderr (overwritten)
#   $4.. - the command and its arguments
# Outputs:
#   "<label> smoke OK" on stdout; on failure, any matching log lines on
#   stdout, then "<label> smoke FAILED", the exit status and the last 30
#   log lines on stderr.
# Returns:
#   0 when the run looks healthy; otherwise exits the whole script with 1.
#######################################
run_smoke() {
  local label=$1 limit=$2 log=$3
  shift 3

  local rc=0
  # Capture the status instead of discarding it with `|| true`: a command that
  # cannot even start (e.g. `python` missing -> 127) or is killed by a signal
  # (128+N) may leave no "error"/"traceback" text in the log at all.
  timeout "$limit" "$@" >"$log" 2>&1 || rc=$?

  # Log scan is a heuristic: any line containing "error" or "traceback"
  # (case-insensitive) counts as a crash, benign mentions included.
  if (( rc != 0 && rc != 124 )) || grep -iE -- "error|traceback" "$log"; then
    {
      echo "$label smoke FAILED"
      echo "exit status: $rc (log: $log)"
      tail -n 30 -- "$log" || true
    } >&2
    exit 1
  fi
  echo "$label smoke OK"
}

if [[ ! -d "$DATA_DIR" ]]; then
  echo "datasets not built; run scripts/build_datasets.sh sudoku first" >&2
  exit 1
fi

if ! command -v conda >/dev/null 2>&1; then
  echo "conda not found on PATH" >&2
  exit 1
fi
# Separate assignment so a failing `conda info` aborts here under `set -e`
# instead of silently producing a bogus path for `source`.
conda_base="$(conda info --base)"
# shellcheck source=/dev/null
source "$conda_base/etc/profile.d/conda.sh"
# Activation hooks are not guaranteed to be nounset-clean; relax `set -u`
# only for the activation itself.
set +u
conda activate rrm
set -u

mkdir -p -- "$LOG_DIR"

echo "=== HRM smoke ==="
cd "$REPO_ROOT/hrm"
run_smoke HRM "$HRM_SECS" "$LOG_DIR/hrm_smoke.log" \
  env WANDB_MODE=offline OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES="$GPU" \
  python pretrain.py \
  data_path="$DATA_DIR" \
  epochs=1 eval_interval=1 global_batch_size=64 \
  lr=7e-5 puzzle_emb_lr=7e-5 weight_decay=1.0 puzzle_emb_weight_decay=1.0

echo "=== TRM smoke ==="
cd "$REPO_ROOT/trm"
run_smoke TRM "$TRM_SECS" "$LOG_DIR/trm_smoke.log" \
  env WANDB_MODE=offline OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES="$GPU" \
  python pretrain.py \
  arch=trm \
  data_paths="[$DATA_DIR]" \
  evaluators="[]" \
  epochs=1 eval_interval=1 global_batch_size=64 \
  lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
  arch.mlp_t=True arch.pos_encodings=none \
  arch.L_layers=2 arch.H_cycles=3 arch.L_cycles=6 \
  +run_name=smoke_test ema=True

echo "==> all smoke tests passed"
|