summaryrefslogtreecommitdiff
path: root/balanced_dataloader.py
blob: a0b89f732b814118155063addfad21b9b2ad7c6b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#!/usr/bin/env python3
"""
平衡的数据加载器 - 确保每个批次包含男女样本
"""
import random
from typing import Dict, List, Optional

import torch
from torch.utils.data import DataLoader, Dataset

class BalancedGEEDataset(Dataset):
    """Map-style dataset that also groups its samples by gender.

    Every item is expected to be a dict with a ``'gender'`` key. Items whose
    gender is ``'male'`` or ``'female'`` are additionally collected into
    ``male_data`` / ``female_data`` so that gender-balanced batches can be drawn
    with :meth:`create_balanced_batches`.
    """

    def __init__(self, data: List[Dict]):
        """
        Args:
            data: sample dicts, each with a ``'gender'`` key.

        Raises:
            ValueError: if ``data`` contains no male or no female sample.
        """
        self.data = data
        # Group by gender. Samples with any other gender value stay in
        # self.data (indexable) but are never drawn into balanced batches.
        self.male_data = [item for item in data if item['gender'] == 'male']
        self.female_data = [item for item in data if item['gender'] == 'female']

        print(f"📊 数据分布: male={len(self.male_data)}, female={len(self.female_data)}")

        # Both groups are required to build balanced batches.
        if not self.male_data or not self.female_data:
            raise ValueError("数据中必须包含男性和女性样本")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    @staticmethod
    def _draw(pool: List[Dict], k: int) -> List[Dict]:
        """Return ``k`` random samples from ``pool``.

        Draws without replacement when the pool is large enough. Otherwise it
        takes every pool member once (in random order) and fills the remainder
        with replacement, so the result always has exactly ``k`` items.
        """
        if k <= len(pool):
            return random.sample(pool, k)
        return random.sample(pool, len(pool)) + random.choices(pool, k=k - len(pool))

    def create_balanced_batches(self, batch_size: int, num_batches: Optional[int] = None):
        """Create gender-balanced batches.

        Each batch holds ``batch_size // 2`` male samples and
        ``batch_size - batch_size // 2`` female samples, shuffled together.
        If a gender group is smaller than its per-batch quota, it is topped up
        by sampling with replacement. Every batch therefore has exactly
        ``batch_size`` items, and code that re-chunks by ``batch_size`` stays
        aligned with batch boundaries.

        Args:
            batch_size: samples per batch; must be >= 2 so both genders fit.
            num_batches: number of batches to create. ``None`` (default) means
                ``len(self.data) // batch_size``; an explicit ``0`` yields no
                batches.

        Returns:
            A list of batches, each a list of sample dicts.

        Raises:
            ValueError: if ``batch_size < 2``.
        """
        if batch_size < 2:
            raise ValueError("batch_size必须>=2才能保证性别平衡")

        # Per-batch quota for each gender (female gets the extra one when odd).
        male_per_batch = batch_size // 2
        female_per_batch = batch_size - male_per_batch

        # `is None` (not `or`) so that an explicit num_batches=0 is honoured.
        if num_batches is None:
            num_batches = len(self.data) // batch_size

        batches = []
        for i in range(num_batches):
            male_samples = self._draw(self.male_data, male_per_batch)
            female_samples = self._draw(self.female_data, female_per_batch)

            batch = male_samples + female_samples
            # Mix genders within the batch so order carries no signal.
            random.shuffle(batch)
            batches.append(batch)

            print(f"批次 {i+1}: male={len(male_samples)}, female={len(female_samples)}")

        return batches

def balanced_collate(batch):
    """Collate a list of sample dicts into a dict of parallel lists.

    Reads the ``"input"`` and ``"gender"`` key of every sample. Prints how many
    ``'male'`` / ``'female'`` samples the batch contains, plus a warning for
    each gender that is absent.

    Returns:
        ``{"input": [...], "gender": [...]}`` in the original sample order.
    """
    input_list = [sample["input"] for sample in batch]
    gender_list = [sample["gender"] for sample in batch]

    # Balance diagnostics.
    n_male = gender_list.count('male')
    n_female = gender_list.count('female')
    print(f"🔍 批次检查: male={n_male}, female={n_female}")

    for count, warning in (
        (n_male, "⚠️ 警告: 批次中没有男性样本!"),
        (n_female, "⚠️ 警告: 批次中没有女性样本!"),
    ):
        if not count:
            print(warning)

    return dict(input=input_list, gender=gender_list)

def create_balanced_dataloader(data: List[Dict], batch_size: int, num_batches: int = 10):
    """Build a DataLoader that yields the gender-balanced batches from ``data``.

    Args:
        data: sample dicts, each with ``'gender'`` and ``'input'`` keys.
        batch_size: samples per batch. Values < 2 cannot hold both genders, so
            a plain shuffled DataLoader is returned (with a warning) instead.
        num_batches: number of balanced batches to pre-generate.

    Returns:
        A ``DataLoader`` whose batches are collated by ``balanced_collate``.

    Raises:
        ValueError: if ``data`` lacks male or female samples (raised by
            ``BalancedGEEDataset``).
    """
    dataset = BalancedGEEDataset(data)

    if batch_size < 2:
        print("⚠️ 警告: batch_size < 2,无法保证性别平衡")
        return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=balanced_collate)

    balanced_batches = dataset.create_balanced_batches(batch_size, num_batches)

    # Flatten the batches while recording each batch's index span.
    # Re-chunking the flat list by batch_size would break whenever a pre-built
    # batch is shorter than batch_size: every later boundary shifts and the
    # trailing chunk can contain only one gender.
    flat_data = []
    index_batches = []
    for batch in balanced_batches:
        start = len(flat_data)
        flat_data.extend(batch)
        index_batches.append(list(range(start, len(flat_data))))

    if not flat_data:
        # No batches were generated; BalancedGEEDataset([]) would raise, so
        # return an empty loader over the source dataset instead.
        return DataLoader(dataset, batch_sampler=index_batches, collate_fn=balanced_collate)

    balanced_dataset = BalancedGEEDataset(flat_data)

    # batch_sampler replays the exact pre-built batches, in order.
    return DataLoader(balanced_dataset, batch_sampler=index_batches, collate_fn=balanced_collate)

# 测试函数
if __name__ == "__main__":
    # Smoke test: build a small balanced loader from synthetic data produced by
    # GEEProcessor.create_test_data, then print the first three batches.
    import sys
    sys.path.append('.')
    from dataset.gee_processor import GEEProcessor

    class MockTokenizer:
        """Tokenizer stand-in whose chat template is just the first message's content."""

        def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
            return messages[0]["content"]

    test_data = GEEProcessor(MockTokenizer()).create_test_data(num_samples=20)

    print("🧪 测试平衡数据加载器")
    loader = create_balanced_dataloader(test_data, batch_size=4, num_batches=3)

    for batch_no, batch in enumerate(loader, start=1):
        print(f"\n批次 {batch_no}:")
        print(f"  输入数量: {len(batch['input'])}")
        print(f"  性别: {batch['gender']}")
        if batch_no >= 3:  # only inspect the first three batches
            break