diff options
| author | haoyuren <13851610112@163.com> | 2025-06-27 14:09:07 -0700 |
|---|---|---|
| committer | haoyuren <13851610112@163.com> | 2025-06-27 14:09:07 -0700 |
| commit | 6b1180928562d1b407d5792ec20aae22cb3f70fd (patch) | |
| tree | e524f6cba37a8bc053b7b62fead738c2a612c948 /losses/debiasing_loss.py | |
| parent | a939274b08f377a2ed93f7234e11f5257ef7667a (diff) | |
remove EM
Diffstat (limited to 'losses/debiasing_loss.py')
| -rw-r--r-- | losses/debiasing_loss.py | 113 |
1 files changed, 113 insertions, 0 deletions
diff --git a/losses/debiasing_loss.py b/losses/debiasing_loss.py new file mode 100644 index 0000000..b2fe99c --- /dev/null +++ b/losses/debiasing_loss.py @@ -0,0 +1,113 @@ +import torch +import torch.nn.functional as F +from typing import Dict, Tuple +import numpy as np + +class DebiasingLoss: + """ + 纯偏见减少损失函数 + 目标:最小化男女间的熵差,不包含整体熵最小化 + """ + def __init__(self, use_l1: bool = False, scale_factor: float = 1.0): + self.use_l1 = use_l1 + self.scale_factor = scale_factor # 可选的缩放因子 + + def compute_token_entropy(self, logits: torch.Tensor, + attention_mask: torch.Tensor = None) -> torch.Tensor: + """计算token级别的条件熵""" + probs = F.softmax(logits, dim=-1) + log_probs = F.log_softmax(logits, dim=-1) + H_tok = -(probs * log_probs).sum(-1) # (B, T) + + if attention_mask is not None: + H_tok = H_tok * attention_mask + + return H_tok + + def compute_sample_entropy(self, H_tok: torch.Tensor, + prompt_lengths: torch.Tensor) -> torch.Tensor: + """计算样本平均熵""" + batch_size = H_tok.size(0) + H_i = torch.zeros(batch_size, device=H_tok.device) + + for i in range(batch_size): + # 只计算生成部分的熵(排除prompt部分) + gen_start = prompt_lengths[i] + if gen_start < H_tok.size(1): + gen_entropy = H_tok[i, gen_start:] + + if gen_entropy.numel() > 0: + H_i[i] = gen_entropy.mean() + else: + H_i[i] = 0.0 + + return H_i + + def compute_group_entropy(self, H_i: torch.Tensor, + gender_labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """计算各组平均熵""" + male_mask = (gender_labels == 0) # 假设0=male, 1=female + female_mask = (gender_labels == 1) + + male_count = male_mask.sum().item() + female_count = female_mask.sum().item() + + if male_count == 0: + print(f"⚠️ 警告: 批次中没有男性样本") + H_male = torch.tensor(0.0, device=H_i.device) + else: + H_male = H_i[male_mask].mean() + + if female_count == 0: + print(f"⚠️ 警告: 批次中没有女性样本") + H_female = torch.tensor(0.0, device=H_i.device) + else: + H_female = H_i[female_mask].mean() + + return H_male, H_female + + def compute_debiasing_loss(self, H_i: torch.Tensor, + gender_labels: torch.Tensor) -> Tuple[torch.Tensor, Dict]: + """ + 计算纯偏见减少损失 + 目标:最小化 |H_female - H_male| + """ + # 计算各组平均熵 + H_male, H_female = self.compute_group_entropy(H_i, gender_labels) + + # 计算熵差距(这是我们要最小化的目标) + entropy_gap = H_female - H_male + + if self.use_l1: + # L1损失:|H_female - H_male| + debiasing_loss = torch.abs(entropy_gap) * self.scale_factor + else: + # L2损失:(H_female - H_male)² + debiasing_loss = (entropy_gap ** 2) * self.scale_factor + + # 计算监控指标 + H_bar = H_i.mean() # 仅用于监控,不参与损失计算 + + metrics = { + 'loss_debiasing': debiasing_loss.item(), + 'entropy_gap': abs(entropy_gap.item()), + 'entropy_gap_signed': entropy_gap.item(), # 带符号的差距 + 'H_bar': H_bar.item(), # 整体平均熵(仅监控) + 'H_male': H_male.item(), + 'H_female': H_female.item(), + 'scale_factor': self.scale_factor + } + + return debiasing_loss, metrics + + def update_scale_factor(self, new_scale: float): + """更新缩放因子(用于调整损失大小)""" + self.scale_factor = new_scale + +def gender_to_label(gender_str: str) -> int: + """将性别字符串转换为标签""" + return 0 if gender_str == 'male' else 1 + +def label_to_gender(label: int) -> str: + """将标签转换为性别字符串""" + return 'male' if label == 0 else 'female'
\ No newline at end of file |
