diff options
| author | blackhao <13851610112@163.com> | 2025-06-25 23:53:15 -0700 |
|---|---|---|
| committer | blackhao <13851610112@163.com> | 2025-06-25 23:53:15 -0700 |
| commit | 0a8f3fb353d1b95cdef5bf1f0baa666b6f590ab0 (patch) | |
| tree | 1a08db7c740ebca82b4b66c876506de761f43276 /scripts | |
| parent | b2d2d05021de3aba1257fdeb69088a82c65a457f (diff) | |
gee init
Diffstat (limited to 'scripts')
| -rwxr-xr-x | scripts/evaluate_gee.sh | 58 | ||||
| -rwxr-xr-x | scripts/quick_test_gee.sh | 45 | ||||
| -rwxr-xr-x | scripts/train_one_shot_gee.sh | 56 |
3 files changed, 159 insertions, 0 deletions
diff --git a/scripts/evaluate_gee.sh b/scripts/evaluate_gee.sh new file mode 100755 index 0000000..46f6055 --- /dev/null +++ b/scripts/evaluate_gee.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +# GEE评估脚本 +# 使用方法: bash scripts/evaluate_gee.sh + +echo "开始GEE评估..." + +# 设置环境变量 +export CUDA_VISIBLE_DEVICES=0 + +# 模型路径(请根据实际情况修改) +BASE_MODEL_PATH="/volume/pt-train/models/Qwen2.5-Math-7B" +GEE_MODEL_PATH="checkpoints/Qwen2.5-Math-7B/one_shot_gee/final" + +# 检查模型路径 +if [ ! -d "$BASE_MODEL_PATH" ]; then + echo "错误: 基础模型路径不存在: $BASE_MODEL_PATH" + exit 1 +fi + +if [ ! -d "$GEE_MODEL_PATH" ]; then + echo "错误: GEE模型路径不存在: $GEE_MODEL_PATH" + echo "请先运行训练脚本" + exit 1 +fi + +echo "基础模型: $BASE_MODEL_PATH" +echo "GEE模型: $GEE_MODEL_PATH" + +# 运行评估 +python -c " +import sys +sys.path.append('.') +from evaluation.gee_evaluator import GEEEvaluator + +# 创建评估器 +evaluator = GEEEvaluator('$BASE_MODEL_PATH') + +# 生成测试数据 +test_data = evaluator.create_winogender_style_data(num_samples=100) + +# 定义要比较的模型 +model_paths = { + 'Base': '$BASE_MODEL_PATH', + 'GEE': '$GEE_MODEL_PATH' +} + +# 比较模型 +results = evaluator.compare_models(model_paths, test_data) + +# 打印摘要 +evaluator.print_summary(results) + +# 绘制结果 +evaluator.plot_results(results, 'gee_evaluation_results.png') +" + +echo "评估完成!结果已保存到 gee_evaluation_results.png"
\ No newline at end of file diff --git a/scripts/quick_test_gee.sh b/scripts/quick_test_gee.sh new file mode 100755 index 0000000..43763df --- /dev/null +++ b/scripts/quick_test_gee.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +# GEE快速测试脚本 +# 使用合成数据进行快速验证,无需真实模型 + +echo "开始GEE快速测试..." + +# 设置环境变量 +export CUDA_VISIBLE_DEVICES=0 + +echo "运行组件测试..." +python test_gee_components.py + +if [ $? -eq 0 ]; then + echo "✓ 组件测试通过" +else + echo "✗ 组件测试失败" + exit 1 +fi + +echo "" +echo "运行快速训练测试(使用合成数据)..." +accelerate launch train_gee.py \ + --model_name Qwen2.5-Math-7B \ + --model_path /volume/pt-train/models/Qwen2.5-Math-7B \ + --use_test_data \ + --effective_batch 8 \ + --micro_batch_size 2 \ + --max_steps 5 \ + --lambda_weight 3.0 \ + --log_steps 1 \ + --save_steps 5 \ + --run_name quick_test_gee \ + --wandb_project one-shot-gee + +if [ $? -eq 0 ]; then + echo "✓ 快速训练测试通过" +else + echo "✗ 快速训练测试失败" + exit 1 +fi + +echo "" +echo "所有快速测试通过!✓" +echo "现在可以运行完整的训练和评估脚本"
\ No newline at end of file diff --git a/scripts/train_one_shot_gee.sh b/scripts/train_one_shot_gee.sh new file mode 100755 index 0000000..9a5fa85 --- /dev/null +++ b/scripts/train_one_shot_gee.sh @@ -0,0 +1,56 @@ +#!/bin/bash + +# One-shot GEE训练脚本 +# 使用方法: bash scripts/train_one_shot_gee.sh + +echo "开始One-shot GEE训练..." + +# 设置环境变量 +export CUDA_VISIBLE_DEVICES=0 +export NCCL_TIMEOUT=2700 +export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=2700 + +# 训练参数 +MODEL_NAME="Qwen2.5-Math-7B" +MODEL_PATH="/volume/pt-train/models/Qwen2.5-Math-7B" # 请根据实际路径修改 +TRAIN_DATA="dataset/1shot_rlvr/pi1_r1280.parquet" +RUN_NAME="one_shot_gee_lambda3" +WANDB_PROJECT="one-shot-gee" + +# 检查模型路径 +if [ ! -d "$MODEL_PATH" ]; then + echo "错误: 模型路径不存在: $MODEL_PATH" + echo "请修改脚本中的MODEL_PATH变量" + exit 1 +fi + +# 检查训练数据 +if [ ! -f "$TRAIN_DATA" ]; then + echo "错误: 训练数据文件不存在: $TRAIN_DATA" + echo "请检查数据文件路径" + exit 1 +fi + +echo "模型路径: $MODEL_PATH" +echo "训练数据: $TRAIN_DATA" +echo "运行名称: $RUN_NAME" + +# 开始训练 +accelerate launch train_gee.py \ + --model_name $MODEL_NAME \ + --model_path $MODEL_PATH \ + --train_data $TRAIN_DATA \ + --effective_batch 64 \ + --micro_batch_size 2 \ + --temperature 0.5 \ + --learning_rate 2e-5 \ + --max_steps 50 \ + --lambda_weight 3.0 \ + --auto_anneal \ + --balance_dataset \ + --log_steps 1 \ + --save_steps 1 \ + --run_name $RUN_NAME \ + --wandb_project $WANDB_PROJECT + +echo "训练完成!"
\ No newline at end of file |
