summaryrefslogtreecommitdiff
path: root/losses/debiasing_loss.py
blob: b2fe99cc89003b11d07264a5d928b887f67f5148 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import torch.nn.functional as F
from typing import Dict, Tuple
import numpy as np

class DebiasingLoss:
    """
    纯偏见减少损失函数
    目标:最小化男女间的熵差,不包含整体熵最小化
    """
    def __init__(self, use_l1: bool = False, scale_factor: float = 1.0):
        self.use_l1 = use_l1
        self.scale_factor = scale_factor  # 可选的缩放因子
    
    def compute_token_entropy(self, logits: torch.Tensor, 
                            attention_mask: torch.Tensor = None) -> torch.Tensor:
        """计算token级别的条件熵"""
        probs = F.softmax(logits, dim=-1)
        log_probs = F.log_softmax(logits, dim=-1)
        H_tok = -(probs * log_probs).sum(-1)  # (B, T)
        
        if attention_mask is not None:
            H_tok = H_tok * attention_mask
        
        return H_tok
    
    def compute_sample_entropy(self, H_tok: torch.Tensor, 
                             prompt_lengths: torch.Tensor) -> torch.Tensor:
        """计算样本平均熵"""
        batch_size = H_tok.size(0)
        H_i = torch.zeros(batch_size, device=H_tok.device)
        
        for i in range(batch_size):
            # 只计算生成部分的熵(排除prompt部分)
            gen_start = prompt_lengths[i]
            if gen_start < H_tok.size(1):
                gen_entropy = H_tok[i, gen_start:]
                
                if gen_entropy.numel() > 0:
                    H_i[i] = gen_entropy.mean()
                else:
                    H_i[i] = 0.0
        
        return H_i
    
    def compute_group_entropy(self, H_i: torch.Tensor, 
                            gender_labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """计算各组平均熵"""
        male_mask = (gender_labels == 0)  # 假设0=male, 1=female
        female_mask = (gender_labels == 1)
        
        male_count = male_mask.sum().item()
        female_count = female_mask.sum().item()
        
        if male_count == 0:
            print(f"⚠️ 警告: 批次中没有男性样本")
            H_male = torch.tensor(0.0, device=H_i.device)
        else:
            H_male = H_i[male_mask].mean()
            
        if female_count == 0:
            print(f"⚠️ 警告: 批次中没有女性样本")
            H_female = torch.tensor(0.0, device=H_i.device)
        else:
            H_female = H_i[female_mask].mean()
        
        return H_male, H_female
    
    def compute_debiasing_loss(self, H_i: torch.Tensor, 
                             gender_labels: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        """
        计算纯偏见减少损失
        目标:最小化 |H_female - H_male|
        """
        # 计算各组平均熵
        H_male, H_female = self.compute_group_entropy(H_i, gender_labels)
        
        # 计算熵差距(这是我们要最小化的目标)
        entropy_gap = H_female - H_male
        
        if self.use_l1:
            # L1损失:|H_female - H_male|
            debiasing_loss = torch.abs(entropy_gap) * self.scale_factor
        else:
            # L2损失:(H_female - H_male)²
            debiasing_loss = (entropy_gap ** 2) * self.scale_factor
        
        # 计算监控指标
        H_bar = H_i.mean()  # 仅用于监控,不参与损失计算
        
        metrics = {
            'loss_debiasing': debiasing_loss.item(),
            'entropy_gap': abs(entropy_gap.item()),
            'entropy_gap_signed': entropy_gap.item(),  # 带符号的差距
            'H_bar': H_bar.item(),  # 整体平均熵(仅监控)
            'H_male': H_male.item(),
            'H_female': H_female.item(),
            'scale_factor': self.scale_factor
        }
        
        return debiasing_loss, metrics
    
    def update_scale_factor(self, new_scale: float):
        """更新缩放因子(用于调整损失大小)"""
        self.scale_factor = new_scale

def gender_to_label(gender_str: str) -> int:
    """将性别字符串转换为标签"""
    return 0 if gender_str == 'male' else 1

def label_to_gender(label: int) -> str:
    """将标签转换为性别字符串"""
    return 'male' if label == 0 else 'female'