summaryrefslogtreecommitdiff
path: root/balanced_dataloader.py
diff options
context:
space:
mode:
authorhaoyuren <13851610112@163.com>2025-06-27 11:53:49 -0700
committerhaoyuren <13851610112@163.com>2025-06-27 11:53:49 -0700
commit3c1718c9d245b7e133da7632c06238166b480fa0 (patch)
treec63653778d5a98d8376caeae8f8d24801e90d1de /balanced_dataloader.py
parent24e163f9211fb9a9af561de47898ea64f5f26df4 (diff)
fix dataloader
Diffstat (limited to 'balanced_dataloader.py')
-rw-r--r--balanced_dataloader.py126
1 files changed, 126 insertions, 0 deletions
diff --git a/balanced_dataloader.py b/balanced_dataloader.py
new file mode 100644
index 0000000..a0b89f7
--- /dev/null
+++ b/balanced_dataloader.py
@@ -0,0 +1,126 @@
+#!/usr/bin/env python3
+"""
+平衡的数据加载器 - 确保每个批次包含男女样本
+"""
+import torch
+from torch.utils.data import Dataset, DataLoader
+import random
+from typing import List, Dict
+
+class BalancedGEEDataset(Dataset):
+ def __init__(self, data: List[Dict]):
+ self.data = data
+ # 按性别分组
+ self.male_data = [item for item in data if item['gender'] == 'male']
+ self.female_data = [item for item in data if item['gender'] == 'female']
+
+ print(f"📊 数据分布: male={len(self.male_data)}, female={len(self.female_data)}")
+
+ # 确保有足够的数据
+ if len(self.male_data) == 0 or len(self.female_data) == 0:
+ raise ValueError("数据中必须包含男性和女性样本")
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ return self.data[idx]
+
+ def create_balanced_batches(self, batch_size: int, num_batches: int = None):
+ """创建平衡的批次"""
+ if batch_size < 2:
+ raise ValueError("batch_size必须>=2才能保证性别平衡")
+
+ # 每个批次中男女样本的数量
+ male_per_batch = batch_size // 2
+ female_per_batch = batch_size - male_per_batch
+
+ batches = []
+ max_batches = num_batches or (len(self.data) // batch_size)
+
+ for i in range(max_batches):
+ batch = []
+
+ # 随机选择男性样本
+ male_samples = random.sample(self.male_data,
+ min(male_per_batch, len(self.male_data)))
+ batch.extend(male_samples)
+
+ # 随机选择女性样本
+ female_samples = random.sample(self.female_data,
+ min(female_per_batch, len(self.female_data)))
+ batch.extend(female_samples)
+
+ # 打乱批次内的顺序
+ random.shuffle(batch)
+ batches.append(batch)
+
+ print(f"批次 {i+1}: male={len(male_samples)}, female={len(female_samples)}")
+
+ return batches
+
+def balanced_collate(batch):
+ """平衡的collate函数"""
+ inputs = [item["input"] for item in batch]
+ genders = [item["gender"] for item in batch]
+
+ # 检查批次平衡性
+ male_count = sum(1 for g in genders if g == 'male')
+ female_count = sum(1 for g in genders if g == 'female')
+
+ print(f"🔍 批次检查: male={male_count}, female={female_count}")
+
+ if male_count == 0:
+ print("⚠️ 警告: 批次中没有男性样本!")
+ if female_count == 0:
+ print("⚠️ 警告: 批次中没有女性样本!")
+
+ return {
+ "input": inputs,
+ "gender": genders
+ }
+
+def create_balanced_dataloader(data: List[Dict], batch_size: int, num_batches: int = 10):
+ """创建平衡的数据加载器"""
+ dataset = BalancedGEEDataset(data)
+
+ if batch_size < 2:
+ print("⚠️ 警告: batch_size < 2,无法保证性别平衡")
+ return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=balanced_collate)
+
+ # 创建平衡批次
+ balanced_batches = dataset.create_balanced_batches(batch_size, num_batches)
+
+ # 展平批次为单个数据点
+ flat_data = []
+ for batch in balanced_batches:
+ flat_data.extend(batch)
+
+ # 创建新的数据集
+ balanced_dataset = BalancedGEEDataset(flat_data)
+
+ return DataLoader(balanced_dataset, batch_size=batch_size, shuffle=False, collate_fn=balanced_collate)
+
+# 测试函数
+if __name__ == "__main__":
+ # 测试平衡数据加载器
+ import sys
+ sys.path.append('.')
+ from dataset.gee_processor import GEEProcessor
+
+ class MockTokenizer:
+ def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
+ return messages[0]["content"]
+
+ processor = GEEProcessor(MockTokenizer())
+ test_data = processor.create_test_data(num_samples=20)
+
+ print("🧪 测试平衡数据加载器")
+ dataloader = create_balanced_dataloader(test_data, batch_size=4, num_batches=3)
+
+ for i, batch in enumerate(dataloader):
+ print(f"\n批次 {i+1}:")
+ print(f" 输入数量: {len(batch['input'])}")
+ print(f" 性别: {batch['gender']}")
+ if i >= 2: # 只测试前3个批次
+ break \ No newline at end of file