summaryrefslogtreecommitdiff
path: root/balanced_dataloader.py
diff options
context:
space:
mode:
Diffstat (limited to 'balanced_dataloader.py')
-rw-r--r--balanced_dataloader.py58
1 files changed, 43 insertions, 15 deletions
diff --git a/balanced_dataloader.py b/balanced_dataloader.py
index a0b89f7..a799178 100644
--- a/balanced_dataloader.py
+++ b/balanced_dataloader.py
@@ -59,7 +59,7 @@ class BalancedGEEDataset(Dataset):
return batches
-def balanced_collate(batch):
+def balanced_collate(batch, verbose=False):
"""平衡的collate函数"""
inputs = [item["input"] for item in batch]
genders = [item["gender"] for item in batch]
@@ -68,8 +68,10 @@ def balanced_collate(batch):
male_count = sum(1 for g in genders if g == 'male')
female_count = sum(1 for g in genders if g == 'female')
- print(f"🔍 批次检查: male={male_count}, female={female_count}")
+ if verbose:
+ print(f"🔍 批次检查: male={male_count}, female={female_count}")
+ # 只在不平衡时打印警告
if male_count == 0:
print("⚠️ 警告: 批次中没有男性样本!")
if female_count == 0:
@@ -80,26 +82,40 @@ def balanced_collate(batch):
"gender": genders
}
+class BalancedDataLoader:
+ """自定义平衡数据加载器"""
+ def __init__(self, balanced_batches):
+ self.batches = balanced_batches
+ self.current_idx = 0
+
+ def __iter__(self):
+ self.current_idx = 0
+ return self
+
+ def __next__(self):
+ if self.current_idx >= len(self.batches):
+ raise StopIteration
+
+ batch = self.batches[self.current_idx]
+ self.current_idx += 1
+
+ # 应用collate函数 (不显示详细信息,只显示警告)
+ return balanced_collate(batch, verbose=False)
+
def create_balanced_dataloader(data: List[Dict], batch_size: int, num_batches: int = 10):
- """创建平衡的数据加载器"""
+ """创建平衡的数据加载器 - 修复版本"""
dataset = BalancedGEEDataset(data)
if batch_size < 2:
print("⚠️ 警告: batch_size < 2,无法保证性别平衡")
- return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=balanced_collate)
+ return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda batch: balanced_collate(batch, verbose=True))
- # 创建平衡批次
+ # 🔧 修复: 直接返回预构造的平衡批次
balanced_batches = dataset.create_balanced_batches(batch_size, num_batches)
- # 展平批次为单个数据点
- flat_data = []
- for batch in balanced_batches:
- flat_data.extend(batch)
-
- # 创建新的数据集
- balanced_dataset = BalancedGEEDataset(flat_data)
+ print(f"✅ 创建了 {len(balanced_batches)} 个平衡批次")
- return DataLoader(balanced_dataset, batch_size=batch_size, shuffle=False, collate_fn=balanced_collate)
+ return BalancedDataLoader(balanced_batches)
# 测试函数
if __name__ == "__main__":
@@ -115,12 +131,24 @@ if __name__ == "__main__":
processor = GEEProcessor(MockTokenizer())
test_data = processor.create_test_data(num_samples=20)
- print("🧪 测试平衡数据加载器")
+ print("🧪 测试修复后的平衡数据加载器")
dataloader = create_balanced_dataloader(test_data, batch_size=4, num_batches=3)
for i, batch in enumerate(dataloader):
print(f"\n批次 {i+1}:")
print(f" 输入数量: {len(batch['input'])}")
print(f" 性别: {batch['gender']}")
+
+ # 验证平衡性
+ male_count = sum(1 for g in batch['gender'] if g == 'male')
+ female_count = sum(1 for g in batch['gender'] if g == 'female')
+
+ if male_count > 0 and female_count > 0:
+ print(f" ✅ 批次平衡: male={male_count}, female={female_count}")
+ else:
+ print(f" ❌ 批次不平衡: male={male_count}, female={female_count}")
+
if i >= 2: # 只测试前3个批次
- break \ No newline at end of file
+ break
+
+ print("\n✅ 平衡数据加载器测试完成!") \ No newline at end of file