summaryrefslogtreecommitdiff
path: root/test_first_batch_fix.py
diff options
context:
space:
mode:
authorhaoyuren <13851610112@163.com>2025-06-27 13:39:26 -0700
committerhaoyuren <13851610112@163.com>2025-06-27 13:39:26 -0700
commita939274b08f377a2ed93f7234e11f5257ef7667a (patch)
tree120763741912208aa5a66af7411419e12595c9b0 /test_first_batch_fix.py
parent9e45bd180d84e0d8e3b3962b16b0a437827af9f6 (diff)
fix 2
Diffstat (limited to 'test_first_batch_fix.py')
-rw-r--r--test_first_batch_fix.py45
1 files changed, 45 insertions, 0 deletions
diff --git a/test_first_batch_fix.py b/test_first_batch_fix.py
new file mode 100644
index 0000000..be4e32f
--- /dev/null
+++ b/test_first_batch_fix.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python3
+"""
+测试第一个批次修复是否有效
+"""
+import sys
+sys.path.append('.')
+
+from dataset.gee_processor import GEEProcessor
+from smart_balanced_dataloader import create_smart_balanced_dataloader
+
+class MockTokenizer:
+ def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
+ return messages[0]["content"]
+
+print("🧪 测试第一个批次修复")
+print("="*50)
+
+# 创建测试数据
+processor = GEEProcessor(MockTokenizer())
+test_data = processor.create_test_data(num_samples=20)
+
+# 测试多次运行,确保第一个批次总是平衡
+print("🔄 测试5次运行,确保第一个批次总是平衡:")
+
+for test_run in range(5):
+ print(f"\n--- 测试运行 {test_run + 1} ---")
+
+ # 创建数据加载器
+ dataloader = create_smart_balanced_dataloader(test_data, batch_size=2, num_batches=3)
+
+ # 只关注第一个批次
+ first_batch = next(iter(dataloader))
+
+ male_count = sum(1 for g in first_batch['gender'] if g == 'male')
+ female_count = sum(1 for g in first_batch['gender'] if g == 'female')
+
+ print(f"第一批次: {first_batch['gender']}")
+ print(f"统计: male={male_count}, female={female_count}")
+
+ if male_count > 0 and female_count > 0:
+ print("✅ 第一批次平衡")
+ else:
+ print("❌ 第一批次仍然不平衡!")
+
+print(f"\n🎯 如果以上所有测试都显示'✅ 第一批次平衡',则修复成功!") \ No newline at end of file