summaryrefslogtreecommitdiff
path: root/balanced_dataloader.py
blob: a7991780c4e85d5d0ba332dab3c4d148e8e1168a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#!/usr/bin/env python3
"""
平衡的数据加载器 - 确保每个批次包含男女样本
"""
import torch
from torch.utils.data import Dataset, DataLoader
import random
from typing import List, Dict

class BalancedGEEDataset(Dataset):
    def __init__(self, data: List[Dict]):
        self.data = data
        # 按性别分组
        self.male_data = [item for item in data if item['gender'] == 'male']
        self.female_data = [item for item in data if item['gender'] == 'female']
        
        print(f"📊 数据分布: male={len(self.male_data)}, female={len(self.female_data)}")
        
        # 确保有足够的数据
        if len(self.male_data) == 0 or len(self.female_data) == 0:
            raise ValueError("数据中必须包含男性和女性样本")
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]
    
    def create_balanced_batches(self, batch_size: int, num_batches: int = None):
        """创建平衡的批次"""
        if batch_size < 2:
            raise ValueError("batch_size必须>=2才能保证性别平衡")
        
        # 每个批次中男女样本的数量
        male_per_batch = batch_size // 2
        female_per_batch = batch_size - male_per_batch
        
        batches = []
        max_batches = num_batches or (len(self.data) // batch_size)
        
        for i in range(max_batches):
            batch = []
            
            # 随机选择男性样本
            male_samples = random.sample(self.male_data, 
                                       min(male_per_batch, len(self.male_data)))
            batch.extend(male_samples)
            
            # 随机选择女性样本
            female_samples = random.sample(self.female_data, 
                                         min(female_per_batch, len(self.female_data)))
            batch.extend(female_samples)
            
            # 打乱批次内的顺序
            random.shuffle(batch)
            batches.append(batch)
            
            print(f"批次 {i+1}: male={len(male_samples)}, female={len(female_samples)}")
        
        return batches

def balanced_collate(batch, verbose=False):
    """平衡的collate函数"""
    inputs = [item["input"] for item in batch]
    genders = [item["gender"] for item in batch]
    
    # 检查批次平衡性
    male_count = sum(1 for g in genders if g == 'male')
    female_count = sum(1 for g in genders if g == 'female')
    
    if verbose:
        print(f"🔍 批次检查: male={male_count}, female={female_count}")
    
    # 只在不平衡时打印警告
    if male_count == 0:
        print("⚠️ 警告: 批次中没有男性样本!")
    if female_count == 0:
        print("⚠️ 警告: 批次中没有女性样本!")
    
    return {
        "input": inputs,
        "gender": genders
    }

class BalancedDataLoader:
    """自定义平衡数据加载器"""
    def __init__(self, balanced_batches):
        self.batches = balanced_batches
        self.current_idx = 0
    
    def __iter__(self):
        self.current_idx = 0
        return self
    
    def __next__(self):
        if self.current_idx >= len(self.batches):
            raise StopIteration
        
        batch = self.batches[self.current_idx]
        self.current_idx += 1
        
        # 应用collate函数 (不显示详细信息,只显示警告)
        return balanced_collate(batch, verbose=False)

def create_balanced_dataloader(data: List[Dict], batch_size: int, num_batches: int = 10):
    """创建平衡的数据加载器 - 修复版本"""
    dataset = BalancedGEEDataset(data)
    
    if batch_size < 2:
        print("⚠️ 警告: batch_size < 2,无法保证性别平衡")
        return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda batch: balanced_collate(batch, verbose=True))
    
    # 🔧 修复: 直接返回预构造的平衡批次
    balanced_batches = dataset.create_balanced_batches(batch_size, num_batches)
    
    print(f"✅ 创建了 {len(balanced_batches)} 个平衡批次")
    
    return BalancedDataLoader(balanced_batches)

# 测试函数
if __name__ == "__main__":
    # 测试平衡数据加载器
    import sys
    sys.path.append('.')
    from dataset.gee_processor import GEEProcessor
    
    class MockTokenizer:
        def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
            return messages[0]["content"]
    
    processor = GEEProcessor(MockTokenizer())
    test_data = processor.create_test_data(num_samples=20)
    
    print("🧪 测试修复后的平衡数据加载器")
    dataloader = create_balanced_dataloader(test_data, batch_size=4, num_batches=3)
    
    for i, batch in enumerate(dataloader):
        print(f"\n批次 {i+1}:")
        print(f"  输入数量: {len(batch['input'])}")
        print(f"  性别: {batch['gender']}")
        
        # 验证平衡性
        male_count = sum(1 for g in batch['gender'] if g == 'male')
        female_count = sum(1 for g in batch['gender'] if g == 'female')
        
        if male_count > 0 and female_count > 0:
            print(f"  ✅ 批次平衡: male={male_count}, female={female_count}")
        else:
            print(f"  ❌ 批次不平衡: male={male_count}, female={female_count}")
        
        if i >= 2:  # 只测试前3个批次
            break
    
    print("\n✅ 平衡数据加载器测试完成!")