diff options
Diffstat (limited to 'smart_balanced_dataloader.py')
| -rw-r--r-- | smart_balanced_dataloader.py | 218 |
1 files changed, 218 insertions, 0 deletions
diff --git a/smart_balanced_dataloader.py b/smart_balanced_dataloader.py new file mode 100644 index 0000000..54eb630 --- /dev/null +++ b/smart_balanced_dataloader.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 +""" +智能平衡数据加载器 - 自动检测并重新生成不平衡批次 +""" +import torch +from torch.utils.data import Dataset, DataLoader +import random +from typing import List, Dict +import warnings + +class SmartBalancedGEEDataset(Dataset): + def __init__(self, data: List[Dict]): + self.data = data + # 按性别分组 + self.male_data = [item for item in data if item['gender'] == 'male'] + self.female_data = [item for item in data if item['gender'] == 'female'] + + print(f"📊 数据分布: male={len(self.male_data)}, female={len(self.female_data)}") + + # 确保有足够的数据 + if len(self.male_data) == 0 or len(self.female_data) == 0: + raise ValueError("数据中必须包含男性和女性样本") + + # 确保有足够的样本进行重新生成 + min_samples = min(len(self.male_data), len(self.female_data)) + if min_samples < 10: + warnings.warn(f"样本数量较少 (min={min_samples}),可能影响批次生成质量") + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + def generate_balanced_batch(self, batch_size: int, max_retries: int = 10): + """生成一个平衡的批次,如果失败会重试""" + if batch_size < 2: + raise ValueError("batch_size必须>=2才能保证性别平衡") + + # 计算每个性别需要的样本数 + male_per_batch = batch_size // 2 + female_per_batch = batch_size - male_per_batch + + for attempt in range(max_retries): + try: + batch = [] + + # 随机选择男性样本 + if len(self.male_data) >= male_per_batch: + male_samples = random.sample(self.male_data, male_per_batch) + batch.extend(male_samples) + else: + # 如果男性样本不够,用替换的方式采样 + male_samples = random.choices(self.male_data, k=male_per_batch) + batch.extend(male_samples) + + # 随机选择女性样本 + if len(self.female_data) >= female_per_batch: + female_samples = random.sample(self.female_data, female_per_batch) + batch.extend(female_samples) + else: + # 如果女性样本不够,用替换的方式采样 + female_samples = random.choices(self.female_data, k=female_per_batch) + batch.extend(female_samples) + + # 打乱批次内的顺序 + random.shuffle(batch) + + # 验证批次平衡性 + male_count = sum(1 for item in batch if item['gender'] == 'male') + female_count = sum(1 for item in batch if item['gender'] == 'female') + + if male_count > 0 and female_count > 0: + return batch + else: + print(f"⚠️ 尝试 {attempt+1}: 批次不平衡 (male={male_count}, female={female_count}),重新生成...") + + except Exception as e: + print(f"⚠️ 尝试 {attempt+1} 失败: {e}") + continue + + # 如果所有尝试都失败,返回一个强制平衡的批次 + print(f"❌ {max_retries} 次尝试后仍然失败,强制生成平衡批次") + return self._force_balanced_batch(batch_size) + + def _force_balanced_batch(self, batch_size: int): + """强制生成一个平衡批次""" + batch = [] + male_per_batch = batch_size // 2 + female_per_batch = batch_size - male_per_batch + + # 强制添加男性样本(允许重复) + for _ in range(male_per_batch): + batch.append(random.choice(self.male_data)) + + # 强制添加女性样本(允许重复) + for _ in range(female_per_batch): + batch.append(random.choice(self.female_data)) + + random.shuffle(batch) + return batch + +class SmartBalancedDataLoader: + """智能平衡数据加载器""" + def __init__(self, dataset: SmartBalancedGEEDataset, batch_size: int, num_batches: int): + self.dataset = dataset + self.batch_size = batch_size + self.num_batches = num_batches + self.current_idx = 0 + + print(f"🧠 智能数据加载器初始化: batch_size={batch_size}, num_batches={num_batches}") + + def __iter__(self): + self.current_idx = 0 + return self + + def __next__(self): + if self.current_idx >= self.num_batches: + raise StopIteration + + # 动态生成平衡批次 + batch = self.dataset.generate_balanced_batch(self.batch_size) + self.current_idx += 1 + + # 应用collate函数并验证 + return self._smart_collate(batch) + + def _smart_collate(self, batch, max_regenerate: int = 3): + """智能collate函数,如果检测到不平衡会重新生成""" + inputs = [item["input"] for item in batch] + genders = [item["gender"] for item in batch] + + # 检查批次平衡性 + male_count = sum(1 for g in genders if g == 'male') + female_count = sum(1 for g in genders if g == 'female') + + # 如果不平衡,尝试重新生成 + regenerate_count = 0 + while (male_count == 0 or female_count == 0) and regenerate_count < max_regenerate: + print(f"🔄 检测到不平衡批次 (male={male_count}, female={female_count}),重新生成...") + + # 重新生成批次 + batch = self.dataset.generate_balanced_batch(self.batch_size) + inputs = [item["input"] for item in batch] + genders = [item["gender"] for item in batch] + + # 重新检查 + male_count = sum(1 for g in genders if g == 'male') + female_count = sum(1 for g in genders if g == 'female') + regenerate_count += 1 + + # 最终检查 + if male_count == 0: + print("❌ 最终警告: 批次中没有男性样本!") + if female_count == 0: + print("❌ 最终警告: 批次中没有女性样本!") + + if male_count > 0 and female_count > 0: + print(f"✅ 平衡批次: male={male_count}, female={female_count}") + + return { + "input": inputs, + "gender": genders + } + +def create_smart_balanced_dataloader(data: List[Dict], batch_size: int, num_batches: int = 10): + """创建智能平衡数据加载器""" + + if batch_size < 2: + print("⚠️ 警告: batch_size < 2,无法保证性别平衡") + # 回退到普通DataLoader + dataset = SmartBalancedGEEDataset(data) + return DataLoader(dataset, batch_size=batch_size, shuffle=True) + + dataset = SmartBalancedGEEDataset(data) + + print(f"🧠 创建智能平衡数据加载器") + print(f" 批次大小: {batch_size}") + print(f" 批次数量: {num_batches}") + print(f" 每批次配置: male={batch_size//2}, female={batch_size-batch_size//2}") + + return SmartBalancedDataLoader(dataset, batch_size, num_batches) + +# 测试函数 +if __name__ == "__main__": + # 测试智能平衡数据加载器 + import sys + sys.path.append('.') + from dataset.gee_processor import GEEProcessor + + class MockTokenizer: + def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True): + return messages[0]["content"] + + processor = GEEProcessor(MockTokenizer()) + test_data = processor.create_test_data(num_samples=20) + + print("🧪 测试智能平衡数据加载器") + dataloader = create_smart_balanced_dataloader(test_data, batch_size=4, num_batches=5) + + for i, batch in enumerate(dataloader): + print(f"\n=== 批次 {i+1} ===") + print(f"输入数量: {len(batch['input'])}") + print(f"性别分布: {batch['gender']}") + + # 验证平衡性 + male_count = sum(1 for g in batch['gender'] if g == 'male') + female_count = sum(1 for g in batch['gender'] if g == 'female') + + if male_count > 0 and female_count > 0: + print(f"✅ 批次完美平衡: male={male_count}, female={female_count}") + else: + print(f"❌ 批次仍然不平衡: male={male_count}, female={female_count}") + + if i >= 4: # 测试5个批次 + break + + print("\n�� 智能平衡数据加载器测试完成!")
\ No newline at end of file |
