summaryrefslogtreecommitdiff
path: root/smart_balanced_dataloader.py
diff options
context:
space:
mode:
authorhaoyuren <13851610112@163.com>2025-06-27 13:16:51 -0700
committerhaoyuren <13851610112@163.com>2025-06-27 13:16:51 -0700
commit9e45bd180d84e0d8e3b3962b16b0a437827af9f6 (patch)
tree3f85c62039384274f4fb9c02a6a51ed2e4c4dd29 /smart_balanced_dataloader.py
parent59ab57da5a5d222d7fa0c1862a1182c7e5059d72 (diff)
fix generater
Diffstat (limited to 'smart_balanced_dataloader.py')
-rw-r--r--smart_balanced_dataloader.py218
1 files changed, 218 insertions, 0 deletions
diff --git a/smart_balanced_dataloader.py b/smart_balanced_dataloader.py
new file mode 100644
index 0000000..54eb630
--- /dev/null
+++ b/smart_balanced_dataloader.py
@@ -0,0 +1,218 @@
+#!/usr/bin/env python3
+"""
+智能平衡数据加载器 - 自动检测并重新生成不平衡批次
+"""
+import torch
+from torch.utils.data import Dataset, DataLoader
+import random
+from typing import List, Dict
+import warnings
+
+class SmartBalancedGEEDataset(Dataset):
+ def __init__(self, data: List[Dict]):
+ self.data = data
+ # 按性别分组
+ self.male_data = [item for item in data if item['gender'] == 'male']
+ self.female_data = [item for item in data if item['gender'] == 'female']
+
+ print(f"📊 数据分布: male={len(self.male_data)}, female={len(self.female_data)}")
+
+ # 确保有足够的数据
+ if len(self.male_data) == 0 or len(self.female_data) == 0:
+ raise ValueError("数据中必须包含男性和女性样本")
+
+ # 确保有足够的样本进行重新生成
+ min_samples = min(len(self.male_data), len(self.female_data))
+ if min_samples < 10:
+ warnings.warn(f"样本数量较少 (min={min_samples}),可能影响批次生成质量")
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ return self.data[idx]
+
+ def generate_balanced_batch(self, batch_size: int, max_retries: int = 10):
+ """生成一个平衡的批次,如果失败会重试"""
+ if batch_size < 2:
+ raise ValueError("batch_size必须>=2才能保证性别平衡")
+
+ # 计算每个性别需要的样本数
+ male_per_batch = batch_size // 2
+ female_per_batch = batch_size - male_per_batch
+
+ for attempt in range(max_retries):
+ try:
+ batch = []
+
+ # 随机选择男性样本
+ if len(self.male_data) >= male_per_batch:
+ male_samples = random.sample(self.male_data, male_per_batch)
+ batch.extend(male_samples)
+ else:
+ # 如果男性样本不够,用替换的方式采样
+ male_samples = random.choices(self.male_data, k=male_per_batch)
+ batch.extend(male_samples)
+
+ # 随机选择女性样本
+ if len(self.female_data) >= female_per_batch:
+ female_samples = random.sample(self.female_data, female_per_batch)
+ batch.extend(female_samples)
+ else:
+ # 如果女性样本不够,用替换的方式采样
+ female_samples = random.choices(self.female_data, k=female_per_batch)
+ batch.extend(female_samples)
+
+ # 打乱批次内的顺序
+ random.shuffle(batch)
+
+ # 验证批次平衡性
+ male_count = sum(1 for item in batch if item['gender'] == 'male')
+ female_count = sum(1 for item in batch if item['gender'] == 'female')
+
+ if male_count > 0 and female_count > 0:
+ return batch
+ else:
+ print(f"⚠️ 尝试 {attempt+1}: 批次不平衡 (male={male_count}, female={female_count}),重新生成...")
+
+ except Exception as e:
+ print(f"⚠️ 尝试 {attempt+1} 失败: {e}")
+ continue
+
+ # 如果所有尝试都失败,返回一个强制平衡的批次
+ print(f"❌ {max_retries} 次尝试后仍然失败,强制生成平衡批次")
+ return self._force_balanced_batch(batch_size)
+
+ def _force_balanced_batch(self, batch_size: int):
+ """强制生成一个平衡批次"""
+ batch = []
+ male_per_batch = batch_size // 2
+ female_per_batch = batch_size - male_per_batch
+
+ # 强制添加男性样本(允许重复)
+ for _ in range(male_per_batch):
+ batch.append(random.choice(self.male_data))
+
+ # 强制添加女性样本(允许重复)
+ for _ in range(female_per_batch):
+ batch.append(random.choice(self.female_data))
+
+ random.shuffle(batch)
+ return batch
+
+class SmartBalancedDataLoader:
+ """智能平衡数据加载器"""
+ def __init__(self, dataset: SmartBalancedGEEDataset, batch_size: int, num_batches: int):
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.num_batches = num_batches
+ self.current_idx = 0
+
+ print(f"🧠 智能数据加载器初始化: batch_size={batch_size}, num_batches={num_batches}")
+
+ def __iter__(self):
+ self.current_idx = 0
+ return self
+
+ def __next__(self):
+ if self.current_idx >= self.num_batches:
+ raise StopIteration
+
+ # 动态生成平衡批次
+ batch = self.dataset.generate_balanced_batch(self.batch_size)
+ self.current_idx += 1
+
+ # 应用collate函数并验证
+ return self._smart_collate(batch)
+
+ def _smart_collate(self, batch, max_regenerate: int = 3):
+ """智能collate函数,如果检测到不平衡会重新生成"""
+ inputs = [item["input"] for item in batch]
+ genders = [item["gender"] for item in batch]
+
+ # 检查批次平衡性
+ male_count = sum(1 for g in genders if g == 'male')
+ female_count = sum(1 for g in genders if g == 'female')
+
+ # 如果不平衡,尝试重新生成
+ regenerate_count = 0
+ while (male_count == 0 or female_count == 0) and regenerate_count < max_regenerate:
+ print(f"🔄 检测到不平衡批次 (male={male_count}, female={female_count}),重新生成...")
+
+ # 重新生成批次
+ batch = self.dataset.generate_balanced_batch(self.batch_size)
+ inputs = [item["input"] for item in batch]
+ genders = [item["gender"] for item in batch]
+
+ # 重新检查
+ male_count = sum(1 for g in genders if g == 'male')
+ female_count = sum(1 for g in genders if g == 'female')
+ regenerate_count += 1
+
+ # 最终检查
+ if male_count == 0:
+ print("❌ 最终警告: 批次中没有男性样本!")
+ if female_count == 0:
+ print("❌ 最终警告: 批次中没有女性样本!")
+
+ if male_count > 0 and female_count > 0:
+ print(f"✅ 平衡批次: male={male_count}, female={female_count}")
+
+ return {
+ "input": inputs,
+ "gender": genders
+ }
+
+def create_smart_balanced_dataloader(data: List[Dict], batch_size: int, num_batches: int = 10):
+ """创建智能平衡数据加载器"""
+
+ if batch_size < 2:
+ print("⚠️ 警告: batch_size < 2,无法保证性别平衡")
+ # 回退到普通DataLoader
+ dataset = SmartBalancedGEEDataset(data)
+ return DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+ dataset = SmartBalancedGEEDataset(data)
+
+ print(f"🧠 创建智能平衡数据加载器")
+ print(f" 批次大小: {batch_size}")
+ print(f" 批次数量: {num_batches}")
+ print(f" 每批次配置: male={batch_size//2}, female={batch_size-batch_size//2}")
+
+ return SmartBalancedDataLoader(dataset, batch_size, num_batches)
+
+# 测试函数
+if __name__ == "__main__":
+ # 测试智能平衡数据加载器
+ import sys
+ sys.path.append('.')
+ from dataset.gee_processor import GEEProcessor
+
+ class MockTokenizer:
+ def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
+ return messages[0]["content"]
+
+ processor = GEEProcessor(MockTokenizer())
+ test_data = processor.create_test_data(num_samples=20)
+
+ print("🧪 测试智能平衡数据加载器")
+ dataloader = create_smart_balanced_dataloader(test_data, batch_size=4, num_batches=5)
+
+ for i, batch in enumerate(dataloader):
+ print(f"\n=== 批次 {i+1} ===")
+ print(f"输入数量: {len(batch['input'])}")
+ print(f"性别分布: {batch['gender']}")
+
+ # 验证平衡性
+ male_count = sum(1 for g in batch['gender'] if g == 'male')
+ female_count = sum(1 for g in batch['gender'] if g == 'female')
+
+ if male_count > 0 and female_count > 0:
+ print(f"✅ 批次完美平衡: male={male_count}, female={female_count}")
+ else:
+ print(f"❌ 批次仍然不平衡: male={male_count}, female={female_count}")
+
+ if i >= 4: # 测试5个批次
+ break
+
+ print("\n�� 智能平衡数据加载器测试完成!") \ No newline at end of file