summaryrefslogtreecommitdiff
path: root/scripts/evaluate_gee.sh
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/evaluate_gee.sh')
-rwxr-xr-xscripts/evaluate_gee.sh58
1 files changed, 58 insertions, 0 deletions
diff --git a/scripts/evaluate_gee.sh b/scripts/evaluate_gee.sh
new file mode 100755
index 0000000..46f6055
--- /dev/null
+++ b/scripts/evaluate_gee.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+
+# GEE评估脚本
+# 使用方法: bash scripts/evaluate_gee.sh
+
+echo "开始GEE评估..."
+
+# 设置环境变量
+export CUDA_VISIBLE_DEVICES=0
+
+# 模型路径(请根据实际情况修改)
+BASE_MODEL_PATH="/volume/pt-train/models/Qwen2.5-Math-7B"
+GEE_MODEL_PATH="checkpoints/Qwen2.5-Math-7B/one_shot_gee/final"
+
+# 检查模型路径
+if [ ! -d "$BASE_MODEL_PATH" ]; then
+ echo "错误: 基础模型路径不存在: $BASE_MODEL_PATH"
+ exit 1
+fi
+
+if [ ! -d "$GEE_MODEL_PATH" ]; then
+ echo "错误: GEE模型路径不存在: $GEE_MODEL_PATH"
+ echo "请先运行训练脚本"
+ exit 1
+fi
+
+echo "基础模型: $BASE_MODEL_PATH"
+echo "GEE模型: $GEE_MODEL_PATH"
+
+# 运行评估
+python -c "
+import sys
+sys.path.append('.')
+from evaluation.gee_evaluator import GEEEvaluator
+
+# 创建评估器
+evaluator = GEEEvaluator('$BASE_MODEL_PATH')
+
+# 生成测试数据
+test_data = evaluator.create_winogender_style_data(num_samples=100)
+
+# 定义要比较的模型
+model_paths = {
+ 'Base': '$BASE_MODEL_PATH',
+ 'GEE': '$GEE_MODEL_PATH'
+}
+
+# 比较模型
+results = evaluator.compare_models(model_paths, test_data)
+
+# 打印摘要
+evaluator.print_summary(results)
+
+# 绘制结果
+evaluator.plot_results(results, 'gee_evaluation_results.png')
+"
+
+echo "评估完成!结果已保存到 gee_evaluation_results.png" \ No newline at end of file