1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
|
from typing import Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
class GEELoss:
    """Entropy-minimisation loss with a group entropy-gap penalty (GEE).

    The total loss is ``H_bar + lambda_weight * loss_bias`` where ``H_bar`` is
    the batch mean of the per-sample entropies and ``loss_bias`` is the squared
    (or, with ``use_l1``, absolute) difference between the mean entropies of
    the two groups labelled 0 (male) and 1 (female).
    """

    def __init__(self, lambda_weight: float = 3.0, use_l1: bool = False):
        """Store the bias weight and the choice of L1 vs. L2 gap penalty."""
        self.lambda_weight = lambda_weight
        self.use_l1 = use_l1

    def compute_token_entropy(
        self,
        logits: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the per-token entropy of the softmax over the last axis.

        Args:
            logits: Unnormalised scores; the last dimension is reduced.
                Presumably shaped (B, T, V) so the result is (B, T).
            attention_mask: Optional mask multiplied into the result, so
                positions where it is 0 get entropy 0.

        Returns:
            Entropy tensor with the last dimension of ``logits`` removed.
        """
        probs = F.softmax(logits, dim=-1)
        log_probs = F.log_softmax(logits, dim=-1)
        # A logit of -inf yields p = 0 and log p = -inf, and 0 * -inf is NaN.
        # Clamping log p to the smallest finite value makes that term 0 while
        # leaving every finite log-probability untouched.
        finite_log_probs = log_probs.clamp_min(torch.finfo(log_probs.dtype).min)
        H_tok = -(probs * finite_log_probs).sum(-1)  # (B, T)
        if attention_mask is not None:
            H_tok = H_tok * attention_mask
        return H_tok

    def compute_sample_entropy(
        self,
        H_tok: torch.Tensor,
        prompt_lengths: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Average the token entropies of each sample's generated part.

        Only positions at or after ``prompt_lengths[i]`` contribute. Tokens
        with zero entropy are kept (a confident model legitimately produces
        them). When ``attention_mask`` is supplied, padded positions are
        excluded from the mean instead of being counted as zero-entropy
        tokens; without it every position after the prompt is averaged.

        Args:
            H_tok: Token entropies of shape (B, T).
            prompt_lengths: Per-sample prompt length (non-negative), one
                entry per row of ``H_tok``.
            attention_mask: Optional (B, T) mask, non-zero for real tokens.

        Returns:
            Tensor of shape (B,). A sample with no contributing position
            gets 0.
        """
        seq_len = H_tok.size(1)
        positions = torch.arange(seq_len, device=H_tok.device).unsqueeze(0)
        starts = torch.as_tensor(prompt_lengths, device=H_tok.device).reshape(-1, 1)
        gen_mask = positions >= starts  # (B, T): True on generated positions
        if attention_mask is not None:
            gen_mask = gen_mask & attention_mask.to(device=H_tok.device).bool()

        # Accumulate in at least float32 so half-precision inputs do not lose
        # accuracy in the reduction.
        acc_dtype = torch.promote_types(H_tok.dtype, torch.float32)
        # masked_fill (rather than multiplying by the mask) keeps NaN/inf in
        # excluded positions from leaking into the sum.
        totals = H_tok.masked_fill(~gen_mask, 0.0).sum(dim=1, dtype=acc_dtype)
        counts = gen_mask.sum(dim=1).clamp_min(1).to(acc_dtype)
        return totals / counts

    def compute_group_entropy(
        self,
        H_i: torch.Tensor,
        gender_labels: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the mean entropy of each group.

        Labels are assumed to be 0 = male and 1 = female. A group with no
        sample in the batch prints a warning and is reported as 0.0.

        Returns:
            ``(H_male, H_female)`` as scalar tensors.
        """
        male_mask = gender_labels == 0
        female_mask = gender_labels == 1

        if male_mask.sum().item() == 0:
            print("⚠️ 警告: 批次中没有男性样本")
            H_male = torch.tensor(0.0, device=H_i.device)
        else:
            H_male = H_i[male_mask].mean()

        if female_mask.sum().item() == 0:
            print("⚠️ 警告: 批次中没有女性样本")
            H_female = torch.tensor(0.0, device=H_i.device)
        else:
            H_female = H_i[female_mask].mean()

        return H_male, H_female

    def compute_gee_loss(
        self,
        H_i: torch.Tensor,
        gender_labels: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict]:
        """Compute the GEE loss and a dictionary of monitoring metrics.

        Args:
            H_i: Per-sample entropies of shape (B,).
            gender_labels: Per-sample group labels (0 = male, 1 = female).

        Returns:
            ``(loss_total, metrics)`` where ``loss_total`` is a scalar tensor
            and ``metrics`` holds Python floats for logging.
        """
        H_bar = H_i.mean()  # batch-wide mean entropy
        H_male, H_female = self.compute_group_entropy(H_i, gender_labels)

        # The gap is only meaningful when both groups are present. With a
        # missing group its placeholder entropy of 0.0 would otherwise turn
        # the other group's entropy into a spurious bias penalty.
        both_present = bool((gender_labels == 0).any()) and bool(
            (gender_labels == 1).any()
        )
        if both_present:
            gap = H_female - H_male
            loss_bias = gap.abs() if self.use_l1 else gap ** 2
            entropy_gap = gap.abs().item()
        else:
            loss_bias = H_i.new_zeros(())
            entropy_gap = 0.0

        loss_em = H_bar
        loss_total = loss_em + self.lambda_weight * loss_bias

        metrics = {
            'loss_em': loss_em.item(),
            'loss_bias': loss_bias.item(),
            'loss_total': loss_total.item(),
            'H_bar': H_bar.item(),
            'H_male': H_male.item(),
            'H_female': H_female.item(),
            'entropy_gap': entropy_gap,
            'lambda_weight': self.lambda_weight
        }
        return loss_total, metrics

    def update_lambda(self, new_lambda: float):
        """Replace the bias weight (used for automatic annealing)."""
        self.lambda_weight = new_lambda
def gender_to_label(gender_str: str) -> int:
    """Convert a gender string to its numeric label.

    The comparison ignores case and surrounding whitespace, so ``'Male'`` and
    ``' male '`` are not silently mapped to the female label.

    Args:
        gender_str: Gender name; ``'male'`` in any casing maps to 0.

    Returns:
        0 for male, 1 for anything else (including non-string input).
    """
    is_male = isinstance(gender_str, str) and gender_str.strip().lower() == 'male'
    return 0 if is_male else 1
def label_to_gender(label: int) -> str:
    """Map a numeric group label back to its gender string.

    Label ``0`` yields ``'male'``; every other value yields ``'female'``.
    """
    if label == 0:
        return 'male'
    return 'female'
|