summaryrefslogtreecommitdiff
path: root/smart_balanced_dataloader.py
diff options
context:
space:
mode:
Diffstat (limited to 'smart_balanced_dataloader.py')
-rw-r--r--smart_balanced_dataloader.py65
1 files changed, 55 insertions, 10 deletions
diff --git a/smart_balanced_dataloader.py b/smart_balanced_dataloader.py
index 54eb630..925b323 100644
--- a/smart_balanced_dataloader.py
+++ b/smart_balanced_dataloader.py
@@ -45,35 +45,36 @@ class SmartBalancedGEEDataset(Dataset):
try:
batch = []
- # 随机选择男性样本
+ # 🔧 修复: 使用更好的随机策略确保平衡
+ # 强制选择所需数量的男性样本
if len(self.male_data) >= male_per_batch:
male_samples = random.sample(self.male_data, male_per_batch)
- batch.extend(male_samples)
else:
# 如果男性样本不够,用替换的方式采样
male_samples = random.choices(self.male_data, k=male_per_batch)
- batch.extend(male_samples)
- # 随机选择女性样本
+ batch.extend(male_samples)
+
+ # 强制选择所需数量的女性样本
if len(self.female_data) >= female_per_batch:
female_samples = random.sample(self.female_data, female_per_batch)
- batch.extend(female_samples)
else:
# 如果女性样本不够,用替换的方式采样
female_samples = random.choices(self.female_data, k=female_per_batch)
- batch.extend(female_samples)
+
+ batch.extend(female_samples)
# 打乱批次内的顺序
random.shuffle(batch)
- # 验证批次平衡性
+ # 最终验证批次平衡性
male_count = sum(1 for item in batch if item['gender'] == 'male')
female_count = sum(1 for item in batch if item['gender'] == 'female')
if male_count > 0 and female_count > 0:
return batch
else:
- print(f"⚠️ 尝试 {attempt+1}: 批次不平衡 (male={male_count}, female={female_count}),重新生成...")
+ print(f"❌ 尝试 {attempt+1}: 批次不平衡 (male={male_count}, female={female_count}),重新生成...")
except Exception as e:
print(f"⚠️ 尝试 {attempt+1} 失败: {e}")
@@ -118,13 +119,57 @@ class SmartBalancedDataLoader:
if self.current_idx >= self.num_batches:
raise StopIteration
- # 动态生成平衡批次
- batch = self.dataset.generate_balanced_batch(self.batch_size)
+ # 🔧 特殊处理第一个批次,确保绝对平衡
+ if self.current_idx == 0:
+ print(f"🎯 特殊处理第一个批次")
+ batch = self._generate_guaranteed_balanced_batch()
+ else:
+ # 其他批次使用标准方法
+ batch = self.dataset.generate_balanced_batch(self.batch_size)
+
self.current_idx += 1
# 应用collate函数并验证
return self._smart_collate(batch)
+ def _generate_guaranteed_balanced_batch(self):
+ """保证生成平衡的第一个批次"""
+ batch_size = self.batch_size
+ male_per_batch = batch_size // 2
+ female_per_batch = batch_size - male_per_batch
+
+ batch = []
+
+ # 强制选择男性样本(取前N个或轮换选择)
+ if len(self.dataset.male_data) >= male_per_batch:
+ # 不使用随机,而是轮换选择
+ male_samples = self.dataset.male_data[:male_per_batch]
+ else:
+ # 重复选择
+ male_samples = (self.dataset.male_data * ((male_per_batch // len(self.dataset.male_data)) + 1))[:male_per_batch]
+
+ batch.extend(male_samples)
+
+ # 强制选择女性样本(取前N个或轮换选择)
+ if len(self.dataset.female_data) >= female_per_batch:
+ # 不使用随机,而是轮换选择
+ female_samples = self.dataset.female_data[:female_per_batch]
+ else:
+ # 重复选择
+ female_samples = (self.dataset.female_data * ((female_per_batch // len(self.dataset.female_data)) + 1))[:female_per_batch]
+
+ batch.extend(female_samples)
+
+ # 验证
+ male_count = sum(1 for item in batch if item['gender'] == 'male')
+ female_count = sum(1 for item in batch if item['gender'] == 'female')
+ print(f"🎯 第一批次: male={male_count}, female={female_count} (强制平衡)")
+
+ # 打乱顺序
+ random.shuffle(batch)
+
+ return batch
+
def _smart_collate(self, batch, max_regenerate: int = 3):
"""智能collate函数,如果检测到不平衡会重新生成"""
inputs = [item["input"] for item in batch]