summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
author blackhao <13851610112@163.com> 2025-06-25 23:53:15 -0700
committer blackhao <13851610112@163.com> 2025-06-25 23:53:15 -0700
commit 0a8f3fb353d1b95cdef5bf1f0baa666b6f590ab0 (patch)
tree 1a08db7c740ebca82b4b66c876506de761f43276 /scripts
parent b2d2d05021de3aba1257fdeb69088a82c65a457f (diff)
gee init
Diffstat (limited to 'scripts')
-rwxr-xr-x scripts/evaluate_gee.sh 58
-rwxr-xr-x scripts/quick_test_gee.sh 45
-rwxr-xr-x scripts/train_one_shot_gee.sh 56
3 files changed, 159 insertions, 0 deletions
diff --git a/scripts/evaluate_gee.sh b/scripts/evaluate_gee.sh
new file mode 100755
index 0000000..46f6055
--- /dev/null
+++ b/scripts/evaluate_gee.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+
+# GEE evaluation script
+# Usage: bash scripts/evaluate_gee.sh
+
+echo "开始GEE评估..."
+
+# Set environment variables
+export CUDA_VISIBLE_DEVICES=0
+
+# Model paths (adjust to your environment)
+BASE_MODEL_PATH="/volume/pt-train/models/Qwen2.5-Math-7B"
+GEE_MODEL_PATH="checkpoints/Qwen2.5-Math-7B/one_shot_gee/final"
+
+# Check that both model directories exist before starting
+if [ ! -d "$BASE_MODEL_PATH" ]; then
+ echo "错误: 基础模型路径不存在: $BASE_MODEL_PATH"
+ exit 1
+fi
+
+if [ ! -d "$GEE_MODEL_PATH" ]; then
+ echo "错误: GEE模型路径不存在: $GEE_MODEL_PATH"
+ echo "请先运行训练脚本"
+ exit 1
+fi
+
+echo "基础模型: $BASE_MODEL_PATH"
+echo "GEE模型: $GEE_MODEL_PATH"
+
+# Run the evaluation; the shell substitutes the two path variables into the Python source, and a failing run aborts the script
+python -c "
+import sys
+sys.path.append('.')
+from evaluation.gee_evaluator import GEEEvaluator
+
+# Create the evaluator
+evaluator = GEEEvaluator('$BASE_MODEL_PATH')
+
+# Generate test data
+test_data = evaluator.create_winogender_style_data(num_samples=100)
+
+# Models to compare
+model_paths = {
+ 'Base': '$BASE_MODEL_PATH',
+ 'GEE': '$GEE_MODEL_PATH'
+}
+
+# Compare the models
+results = evaluator.compare_models(model_paths, test_data)
+
+# Print the summary
+evaluator.print_summary(results)
+
+# Plot the results
+evaluator.plot_results(results, 'gee_evaluation_results.png')
+" || { echo "错误: 评估失败" >&2; exit 1; }
+
+echo "评估完成!结果已保存到 gee_evaluation_results.png"
\ No newline at end of file
diff --git a/scripts/quick_test_gee.sh b/scripts/quick_test_gee.sh
new file mode 100755
index 0000000..43763df
--- /dev/null
+++ b/scripts/quick_test_gee.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+
+# GEE quick test script
+# Runs a fast check with synthetic data; no real model required
+
+echo "开始GEE快速测试..."
+
+# Set environment variables
+export CUDA_VISIBLE_DEVICES=0
+
+echo "运行组件测试..."
+# Stage 1: component tests. Branch on the command itself rather than on $?,
+# so no statement can slip in between and clobber the exit status.
+if python test_gee_components.py; then
+ echo "✓ 组件测试通过"
+else
+ echo "✗ 组件测试失败"
+ exit 1
+fi
+
+echo ""
+echo "运行快速训练测试(使用合成数据)..."
+# Stage 2: a 5-step training run with --use_test_data.
+# A non-zero exit from accelerate takes the failure branch and aborts.
+if accelerate launch train_gee.py \
+ --model_name Qwen2.5-Math-7B \
+ --model_path /volume/pt-train/models/Qwen2.5-Math-7B \
+ --use_test_data \
+ --effective_batch 8 \
+ --micro_batch_size 2 \
+ --max_steps 5 \
+ --lambda_weight 3.0 \
+ --log_steps 1 \
+ --save_steps 5 \
+ --run_name quick_test_gee \
+ --wandb_project one-shot-gee; then
+ echo "✓ 快速训练测试通过"
+else
+ echo "✗ 快速训练测试失败"
+ exit 1
+fi
+
+echo ""
+echo "所有快速测试通过!✓"
+echo "现在可以运行完整的训练和评估脚本"
\ No newline at end of file
diff --git a/scripts/train_one_shot_gee.sh b/scripts/train_one_shot_gee.sh
new file mode 100755
index 0000000..9a5fa85
--- /dev/null
+++ b/scripts/train_one_shot_gee.sh
@@ -0,0 +1,56 @@
+#!/bin/bash
+
+# One-shot GEE training script
+# Usage: bash scripts/train_one_shot_gee.sh
+
+echo "开始One-shot GEE训练..."
+
+# Set environment variables
+export CUDA_VISIBLE_DEVICES=0
+export NCCL_TIMEOUT=2700
+export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=2700
+
+# Training parameters
+MODEL_NAME="Qwen2.5-Math-7B"
+MODEL_PATH="/volume/pt-train/models/Qwen2.5-Math-7B" # adjust to your actual path
+TRAIN_DATA="dataset/1shot_rlvr/pi1_r1280.parquet"
+RUN_NAME="one_shot_gee_lambda3"
+WANDB_PROJECT="one-shot-gee"
+
+# Check that the model directory exists
+if [ ! -d "$MODEL_PATH" ]; then
+ echo "错误: 模型路径不存在: $MODEL_PATH"
+ echo "请修改脚本中的MODEL_PATH变量"
+ exit 1
+fi
+
+# Check that the training data file exists
+if [ ! -f "$TRAIN_DATA" ]; then
+ echo "错误: 训练数据文件不存在: $TRAIN_DATA"
+ echo "请检查数据文件路径"
+ exit 1
+fi
+
+echo "模型路径: $MODEL_PATH"
+echo "训练数据: $TRAIN_DATA"
+echo "运行名称: $RUN_NAME"
+
+# Start training; expansions are quoted so edited values with spaces stay one argument, and a failed launch aborts before the completion message
+accelerate launch train_gee.py \
+ --model_name "$MODEL_NAME" \
+ --model_path "$MODEL_PATH" \
+ --train_data "$TRAIN_DATA" \
+ --effective_batch 64 \
+ --micro_batch_size 2 \
+ --temperature 0.5 \
+ --learning_rate 2e-5 \
+ --max_steps 50 \
+ --lambda_weight 3.0 \
+ --auto_anneal \
+ --balance_dataset \
+ --log_steps 1 \
+ --save_steps 1 \
+ --run_name "$RUN_NAME" \
+ --wandb_project "$WANDB_PROJECT" || { echo "错误: 训练失败" >&2; exit 1; }
+
+echo "训练完成!"
\ No newline at end of file