From 59ab57da5a5d222d7fa0c1862a1182c7e5059d72 Mon Sep 17 00:00:00 2001 From: haoyuren <13851610112@163.com> Date: Fri, 27 Jun 2025 12:07:28 -0700 Subject: fix dataloader 2 --- balanced_dataloader.py | 58 +++++++++++++++++++++++++++++++++++++------------- 1 file changed, 43 insertions(+), 15 deletions(-) (limited to 'balanced_dataloader.py') diff --git a/balanced_dataloader.py b/balanced_dataloader.py index a0b89f7..a799178 100644 --- a/balanced_dataloader.py +++ b/balanced_dataloader.py @@ -59,7 +59,7 @@ class BalancedGEEDataset(Dataset): return batches -def balanced_collate(batch): +def balanced_collate(batch, verbose=False): """平衡的collate函数""" inputs = [item["input"] for item in batch] genders = [item["gender"] for item in batch] @@ -68,8 +68,10 @@ def balanced_collate(batch): male_count = sum(1 for g in genders if g == 'male') female_count = sum(1 for g in genders if g == 'female') - print(f"🔍 批次检查: male={male_count}, female={female_count}") + if verbose: + print(f"🔍 批次检查: male={male_count}, female={female_count}") + # 只在不平衡时打印警告 if male_count == 0: print("⚠️ 警告: 批次中没有男性样本!") if female_count == 0: @@ -80,26 +82,40 @@ def balanced_collate(batch): "gender": genders } +class BalancedDataLoader: + """自定义平衡数据加载器""" + def __init__(self, balanced_batches): + self.batches = balanced_batches + self.current_idx = 0 + + def __iter__(self): + self.current_idx = 0 + return self + + def __next__(self): + if self.current_idx >= len(self.batches): + raise StopIteration + + batch = self.batches[self.current_idx] + self.current_idx += 1 + + # 应用collate函数 (不显示详细信息,只显示警告) + return balanced_collate(batch, verbose=False) + def create_balanced_dataloader(data: List[Dict], batch_size: int, num_batches: int = 10): - """创建平衡的数据加载器""" + """创建平衡的数据加载器 - 修复版本""" dataset = BalancedGEEDataset(data) if batch_size < 2: print("⚠️ 警告: batch_size < 2,无法保证性别平衡") - return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=balanced_collate) + return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda batch: balanced_collate(batch, verbose=True)) - # 创建平衡批次 + # 🔧 修复: 直接返回预构造的平衡批次 balanced_batches = dataset.create_balanced_batches(batch_size, num_batches) - # 展平批次为单个数据点 - flat_data = [] - for batch in balanced_batches: - flat_data.extend(batch) - - # 创建新的数据集 - balanced_dataset = BalancedGEEDataset(flat_data) + print(f"✅ 创建了 {len(balanced_batches)} 个平衡批次") - return DataLoader(balanced_dataset, batch_size=batch_size, shuffle=False, collate_fn=balanced_collate) + return BalancedDataLoader(balanced_batches) # 测试函数 if __name__ == "__main__": @@ -115,12 +131,24 @@ if __name__ == "__main__": processor = GEEProcessor(MockTokenizer()) test_data = processor.create_test_data(num_samples=20) - print("🧪 测试平衡数据加载器") + print("🧪 测试修复后的平衡数据加载器") dataloader = create_balanced_dataloader(test_data, batch_size=4, num_batches=3) for i, batch in enumerate(dataloader): print(f"\n批次 {i+1}:") print(f" 输入数量: {len(batch['input'])}") print(f" 性别: {batch['gender']}") + + # 验证平衡性 + male_count = sum(1 for g in batch['gender'] if g == 'male') + female_count = sum(1 for g in batch['gender'] if g == 'female') + + if male_count > 0 and female_count > 0: + print(f" ✅ 批次平衡: male={male_count}, female={female_count}") + else: + print(f" ❌ 批次不平衡: male={male_count}, female={female_count}") + if i >= 2: # 只测试前3个批次 - break \ No newline at end of file + break + + print("\n✅ 平衡数据加载器测试完成!") \ No newline at end of file -- cgit v1.2.3