summaryrefslogtreecommitdiff
path: root/collaborativeagents/training/train_sft_lf.sbatch
blob: 7f090b16ef69a6aed8cba974e0fef538f5401abc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#!/bin/bash
#SBATCH --job-name=sft_lf
#SBATCH --account=bfqt-delta-gpu
#SBATCH --partition=gpuH200x8
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=32
#SBATCH --gres=gpu:4
#SBATCH --mem=256G
#SBATCH --time=8:00:00
#SBATCH --output=/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/sft_lf_%j.out
#SBATCH --error=/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/sft_lf_%j.err

# SFT training of Llama-3.1-8B-Instruct with LLaMA-Factory (DeepSpeed ZeRO-3)
# on 4x H200 GPUs. Requires the 'eval' conda env and llama_factory_config.yaml
# in the training directory.
#
# -e: abort the job on any failed command (data check, training) instead of
#     falling through and printing "complete!".
# pipefail: a failure in any pipeline stage fails the pipeline.
# NOTE(review): -u deliberately omitted — conda activation scripts commonly
# reference unset variables and would trip it.
set -eo pipefail

echo "=== SFT Training with LLaMA-Factory (H200) ==="
date
nvidia-smi --query-gpu=index,name,memory.total --format=csv

# Fail fast if the project path is missing rather than running from $HOME.
cd /projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents/training || exit 1
source /u/yurenh2/miniforge3/etc/profile.d/conda.sh
conda activate eval
export HF_HOME=/projects/bfqt/users/yurenh2/hf_cache/huggingface
export WANDB_MODE=offline

echo "Model: /projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/models/llama-3.1-8b-instruct"
echo "Data: training_data/sft_training_data.json"
echo "Output: outputs/sft_reflection_lf"

# Sanity-check the training data before burning GPU time.
echo "Training data size: $(stat -c%s training_data/sft_training_data.json) bytes"
# BUG FIX: the original used \" inside single quotes; bash passes the
# backslashes through literally, making the python -c string a SyntaxError.
# Double quotes need no escaping inside a single-quoted shell string.
echo "Training examples: $(python3 -c 'import json; print(len(json.load(open("training_data/sft_training_data.json"))))')"

echo ""
echo "Starting LLaMA-Factory SFT training with DeepSpeed ZeRO-3..."

# FORCE_TORCHRUN=1 makes llamafactory-cli launch via torchrun so all 4 GPUs
# are used; DeepSpeed ZeRO-3 config comes from llama_factory_config.yaml.
FORCE_TORCHRUN=1 llamafactory-cli train llama_factory_config.yaml

echo ""
echo "SFT Training complete!"
date