Diffstat (limited to 'losses')
-rw-r--r--  losses/debiasing_loss.py  113
1 file changed, 113 insertions, 0 deletions
diff --git a/losses/debiasing_loss.py b/losses/debiasing_loss.py
new file mode 100644
index 0000000..b2fe99c
--- /dev/null
+++ b/losses/debiasing_loss.py
@@ -0,0 +1,113 @@
+import torch
+import torch.nn.functional as F
+from typing import Dict, Tuple, Optional
+import numpy as np
+
+class DebiasingLoss:
+ """
+ 纯偏见减少损失函数
+ 目标:最小化男女间的熵差,不包含整体熵最小化
+ """
+ def __init__(self, use_l1: bool = False, scale_factor: float = 1.0):
+ self.use_l1 = use_l1
+ self.scale_factor = scale_factor # 可选的缩放因子
+
+ def compute_token_entropy(self, logits: torch.Tensor,
+ attention_mask: torch.Tensor = None) -> torch.Tensor:
+ """计算token级别的条件熵"""
+ probs = F.softmax(logits, dim=-1)
+ log_probs = F.log_softmax(logits, dim=-1)
+ H_tok = -(probs * log_probs).sum(-1) # (B, T)
+
+ if attention_mask is not None:
+ H_tok = H_tok * attention_mask
+
+ return H_tok
+
+    def compute_sample_entropy(self, H_tok: torch.Tensor,
+                               prompt_lengths: torch.Tensor) -> torch.Tensor:
+        """Compute the per-sample mean entropy over the generated tokens."""
+        batch_size = H_tok.size(0)
+        H_i = torch.zeros(batch_size, device=H_tok.device)
+
+        for i in range(batch_size):
+            # Only average entropy over the generated part (exclude the prompt)
+            gen_start = prompt_lengths[i]
+            if gen_start < H_tok.size(1):
+                gen_entropy = H_tok[i, gen_start:]
+
+                if gen_entropy.numel() > 0:
+                    H_i[i] = gen_entropy.mean()
+                else:
+                    H_i[i] = 0.0
+
+        return H_i
+
+    def compute_group_entropy(self, H_i: torch.Tensor,
+                              gender_labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute the mean entropy of each gender group."""
+        male_mask = (gender_labels == 0)  # convention: 0 = male, 1 = female
+        female_mask = (gender_labels == 1)
+
+        male_count = male_mask.sum().item()
+        female_count = female_mask.sum().item()
+
+        if male_count == 0:
+            print("⚠️ Warning: no male samples in this batch")
+            H_male = torch.tensor(0.0, device=H_i.device)
+        else:
+            H_male = H_i[male_mask].mean()
+
+        if female_count == 0:
+            print("⚠️ Warning: no female samples in this batch")
+            H_female = torch.tensor(0.0, device=H_i.device)
+        else:
+            H_female = H_i[female_mask].mean()
+
+        return H_male, H_female
+
+    def compute_debiasing_loss(self, H_i: torch.Tensor,
+                               gender_labels: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
+        """
+        Compute the pure bias-reduction loss.
+        Goal: minimize |H_female - H_male|.
+        """
+        # Mean entropy of each group
+        H_male, H_female = self.compute_group_entropy(H_i, gender_labels)
+
+        # Entropy gap (the quantity we want to minimize)
+        entropy_gap = H_female - H_male
+
+        if self.use_l1:
+            # L1 loss: |H_female - H_male|
+            debiasing_loss = torch.abs(entropy_gap) * self.scale_factor
+        else:
+            # L2 loss: (H_female - H_male)^2
+            debiasing_loss = (entropy_gap ** 2) * self.scale_factor
+
+        # Monitoring metrics
+        H_bar = H_i.mean()  # monitoring only, not part of the loss
+
+        metrics = {
+            'loss_debiasing': debiasing_loss.item(),
+            'entropy_gap': abs(entropy_gap.item()),
+            'entropy_gap_signed': entropy_gap.item(),  # signed gap
+            'H_bar': H_bar.item(),  # overall mean entropy (monitoring only)
+            'H_male': H_male.item(),
+            'H_female': H_female.item(),
+            'scale_factor': self.scale_factor
+        }
+
+        return debiasing_loss, metrics
+
+    def update_scale_factor(self, new_scale: float):
+        """Update the scaling factor (used to adjust the loss magnitude)."""
+        self.scale_factor = new_scale
+
+def gender_to_label(gender_str: str) -> int:
+    """Convert a gender string to an integer label (0 = male, 1 = female)."""
+    return 0 if gender_str == 'male' else 1
+
+def label_to_gender(label: int) -> str:
+    """Convert an integer label back to a gender string."""
+    return 'male' if label == 0 else 'female'
\ No newline at end of file
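
Usage note: the snippet below is a minimal sketch of how DebiasingLoss could be wired into a training step. The model call and the batch fields (input_ids, attention_mask, prompt_lengths, gender_labels) are illustrative assumptions, not part of this patch.

    # Hypothetical training-step wiring (illustrative only).
    from losses.debiasing_loss import DebiasingLoss

    loss_fn = DebiasingLoss(use_l1=False, scale_factor=1.0)

    def training_step(model, batch, optimizer):
        # Forward pass of a causal LM that returns per-token logits (assumed API).
        logits = model(batch["input_ids"], attention_mask=batch["attention_mask"]).logits
        # Token-level entropy, masked by attention_mask.
        H_tok = loss_fn.compute_token_entropy(logits, batch["attention_mask"])
        # Per-sample mean entropy over generated tokens only.
        H_i = loss_fn.compute_sample_entropy(H_tok, batch["prompt_lengths"])
        # Gap between female and male mean entropies becomes the loss.
        loss, metrics = loss_fn.compute_debiasing_loss(H_i, batch["gender_labels"])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return metrics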