summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorhaoyuren <13851610112@163.com>2025-06-27 13:39:26 -0700
committerhaoyuren <13851610112@163.com>2025-06-27 13:39:26 -0700
commita939274b08f377a2ed93f7234e11f5257ef7667a (patch)
tree120763741912208aa5a66af7411419e12595c9b0
parent9e45bd180d84e0d8e3b3962b16b0a437827af9f6 (diff)
fix 2
-rw-r--r--smart_balanced_dataloader.py65
-rw-r--r--test_first_batch_fix.py45
2 files changed, 100 insertions, 10 deletions
diff --git a/smart_balanced_dataloader.py b/smart_balanced_dataloader.py
index 54eb630..925b323 100644
--- a/smart_balanced_dataloader.py
+++ b/smart_balanced_dataloader.py
@@ -45,35 +45,36 @@ class SmartBalancedGEEDataset(Dataset):
try:
batch = []
- # 随机选择男性样本
+ # 🔧 修复: 使用更好的随机策略确保平衡
+ # 强制选择所需数量的男性样本
if len(self.male_data) >= male_per_batch:
male_samples = random.sample(self.male_data, male_per_batch)
- batch.extend(male_samples)
else:
# 如果男性样本不够,用替换的方式采样
male_samples = random.choices(self.male_data, k=male_per_batch)
- batch.extend(male_samples)
- # 随机选择女性样本
+ batch.extend(male_samples)
+
+ # 强制选择所需数量的女性样本
if len(self.female_data) >= female_per_batch:
female_samples = random.sample(self.female_data, female_per_batch)
- batch.extend(female_samples)
else:
# 如果女性样本不够,用替换的方式采样
female_samples = random.choices(self.female_data, k=female_per_batch)
- batch.extend(female_samples)
+
+ batch.extend(female_samples)
# 打乱批次内的顺序
random.shuffle(batch)
- # 验证批次平衡性
+ # 最终验证批次平衡性
male_count = sum(1 for item in batch if item['gender'] == 'male')
female_count = sum(1 for item in batch if item['gender'] == 'female')
if male_count > 0 and female_count > 0:
return batch
else:
- print(f"⚠️ 尝试 {attempt+1}: 批次不平衡 (male={male_count}, female={female_count}),重新生成...")
+ print(f"❌ 尝试 {attempt+1}: 批次不平衡 (male={male_count}, female={female_count}),重新生成...")
except Exception as e:
print(f"⚠️ 尝试 {attempt+1} 失败: {e}")
@@ -118,13 +119,57 @@ class SmartBalancedDataLoader:
if self.current_idx >= self.num_batches:
raise StopIteration
- # 动态生成平衡批次
- batch = self.dataset.generate_balanced_batch(self.batch_size)
+ # 🔧 特殊处理第一个批次,确保绝对平衡
+ if self.current_idx == 0:
+ print(f"🎯 特殊处理第一个批次")
+ batch = self._generate_guaranteed_balanced_batch()
+ else:
+ # 其他批次使用标准方法
+ batch = self.dataset.generate_balanced_batch(self.batch_size)
+
self.current_idx += 1
# 应用collate函数并验证
return self._smart_collate(batch)
+ def _generate_guaranteed_balanced_batch(self):
+ """保证生成平衡的第一个批次"""
+ batch_size = self.batch_size
+ male_per_batch = batch_size // 2
+ female_per_batch = batch_size - male_per_batch
+
+ batch = []
+
+ # 强制选择男性样本(取前N个或轮换选择)
+ if len(self.dataset.male_data) >= male_per_batch:
+ # 不使用随机,而是轮换选择
+ male_samples = self.dataset.male_data[:male_per_batch]
+ else:
+ # 重复选择
+ male_samples = (self.dataset.male_data * ((male_per_batch // len(self.dataset.male_data)) + 1))[:male_per_batch]
+
+ batch.extend(male_samples)
+
+ # 强制选择女性样本(取前N个或轮换选择)
+ if len(self.dataset.female_data) >= female_per_batch:
+ # 不使用随机,而是轮换选择
+ female_samples = self.dataset.female_data[:female_per_batch]
+ else:
+ # 重复选择
+ female_samples = (self.dataset.female_data * ((female_per_batch // len(self.dataset.female_data)) + 1))[:female_per_batch]
+
+ batch.extend(female_samples)
+
+ # 验证
+ male_count = sum(1 for item in batch if item['gender'] == 'male')
+ female_count = sum(1 for item in batch if item['gender'] == 'female')
+ print(f"🎯 第一批次: male={male_count}, female={female_count} (强制平衡)")
+
+ # 打乱顺序
+ random.shuffle(batch)
+
+ return batch
+
def _smart_collate(self, batch, max_regenerate: int = 3):
"""智能collate函数,如果检测到不平衡会重新生成"""
inputs = [item["input"] for item in batch]
diff --git a/test_first_batch_fix.py b/test_first_batch_fix.py
new file mode 100644
index 0000000..be4e32f
--- /dev/null
+++ b/test_first_batch_fix.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python3
+"""
+测试第一个批次修复是否有效
+"""
+import sys
+sys.path.append('.')
+
+from dataset.gee_processor import GEEProcessor
+from smart_balanced_dataloader import create_smart_balanced_dataloader
+
+class MockTokenizer:
+ def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
+ return messages[0]["content"]
+
+print("🧪 测试第一个批次修复")
+print("="*50)
+
+# 创建测试数据
+processor = GEEProcessor(MockTokenizer())
+test_data = processor.create_test_data(num_samples=20)
+
+# 测试多次运行,确保第一个批次总是平衡
+print("🔄 测试5次运行,确保第一个批次总是平衡:")
+
+for test_run in range(5):
+ print(f"\n--- 测试运行 {test_run + 1} ---")
+
+ # 创建数据加载器
+ dataloader = create_smart_balanced_dataloader(test_data, batch_size=2, num_batches=3)
+
+ # 只关注第一个批次
+ first_batch = next(iter(dataloader))
+
+ male_count = sum(1 for g in first_batch['gender'] if g == 'male')
+ female_count = sum(1 for g in first_batch['gender'] if g == 'female')
+
+ print(f"第一批次: {first_batch['gender']}")
+ print(f"统计: male={male_count}, female={female_count}")
+
+ if male_count > 0 and female_count > 0:
+ print("✅ 第一批次平衡")
+ else:
+ print("❌ 第一批次仍然不平衡!")
+
+print(f"\n🎯 如果以上所有测试都显示'✅ 第一批次平衡',则修复成功!") \ No newline at end of file