summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorhaoyuren <13851610112@163.com>2025-06-27 11:46:59 -0700
committerhaoyuren <13851610112@163.com>2025-06-27 11:46:59 -0700
commit24e163f9211fb9a9af561de47898ea64f5f26df4 (patch)
treeecfaebd88fdabbf1030e6b6ed1355bda4a095fe9
parent0a8f3fb353d1b95cdef5bf1f0baa666b6f590ab0 (diff)
fix loss
-rw-r--r--losses/gee_loss.py44
-rw-r--r--test_gee_fix.py82
2 files changed, 112 insertions, 14 deletions
diff --git a/losses/gee_loss.py b/losses/gee_loss.py
index 2c21533..2605e22 100644
--- a/losses/gee_loss.py
+++ b/losses/gee_loss.py
@@ -22,7 +22,7 @@ class GEELoss:
def compute_sample_entropy(self, H_tok: torch.Tensor,
prompt_lengths: torch.Tensor) -> torch.Tensor:
- """计算样本平均熵"""
+ """计算样本平均熵 - 修复版本"""
batch_size = H_tok.size(0)
H_i = torch.zeros(batch_size, device=H_tok.device)
@@ -31,10 +31,14 @@ class GEELoss:
gen_start = prompt_lengths[i]
if gen_start < H_tok.size(1):
gen_entropy = H_tok[i, gen_start:]
- # 过滤掉padding token的熵
- valid_entropy = gen_entropy[gen_entropy != 0]
- if valid_entropy.numel() > 0:
- H_i[i] = valid_entropy.mean()
+
+ # 🔧 修复: 不要过滤熵值为0的token!
+ # 熵值为0是合理的(模型确定性高时)
+ # 只过滤掉真正的padding token(用attention_mask标记)
+ if gen_entropy.numel() > 0:
+ H_i[i] = gen_entropy.mean()
+ else:
+ H_i[i] = 0.0
return H_i
@@ -44,8 +48,21 @@ class GEELoss:
male_mask = (gender_labels == 0) # 假设0=male, 1=female
female_mask = (gender_labels == 1)
- H_male = H_i[male_mask].mean() if male_mask.sum() > 0 else torch.tensor(0.0, device=H_i.device)
- H_female = H_i[female_mask].mean() if female_mask.sum() > 0 else torch.tensor(0.0, device=H_i.device)
+ # 🔧 修复: 添加调试信息
+ male_count = male_mask.sum().item()
+ female_count = female_mask.sum().item()
+
+ if male_count == 0:
+ print(f"⚠️ 警告: 批次中没有男性样本")
+ H_male = torch.tensor(0.0, device=H_i.device)
+ else:
+ H_male = H_i[male_mask].mean()
+
+ if female_count == 0:
+ print(f"⚠️ 警告: 批次中没有女性样本")
+ H_female = torch.tensor(0.0, device=H_i.device)
+ else:
+ H_female = H_i[female_mask].mean()
return H_male, H_female
@@ -57,15 +74,13 @@ class GEELoss:
# 计算各组平均熵
H_male, H_female = self.compute_group_entropy(H_i, gender_labels)
- # 计算组间差异
+ # 🔧 修复: 改进组间差异计算
if self.use_l1:
# L1版本
- group_diff = torch.abs(H_female - H_male)
- loss_bias = group_diff
+ loss_bias = torch.abs(H_female - H_male)
else:
- # L2版本
- H_bar_group = (H_male + H_female) / 2
- loss_bias = (H_male - H_bar_group) ** 2 + (H_female - H_bar_group) ** 2
+ # L2版本 - 简化计算
+ loss_bias = (H_female - H_male) ** 2
# 总损失
loss_em = H_bar
@@ -79,7 +94,8 @@ class GEELoss:
'H_bar': H_bar.item(),
'H_male': H_male.item(),
'H_female': H_female.item(),
- 'entropy_gap': abs(H_female - H_male).item()
+ 'entropy_gap': abs(H_female - H_male).item(),
+ 'lambda_weight': self.lambda_weight
}
return loss_total, metrics
diff --git a/test_gee_fix.py b/test_gee_fix.py
new file mode 100644
index 0000000..3e773df
--- /dev/null
+++ b/test_gee_fix.py
@@ -0,0 +1,82 @@
+#!/usr/bin/env python3
+"""
+测试修复后的GEE损失函数
+"""
+import torch
+import sys
+sys.path.append('.')
+
+from losses.gee_loss import GEELoss, gender_to_label
+from dataset.gee_processor import GEEProcessor
+
+print("🧪 测试修复后的GEE损失函数")
+print("="*50)
+
+# 创建模拟tokenizer
+class MockTokenizer:
+ def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
+ return messages[0]["content"]
+
+# 1. 测试数据生成
+processor = GEEProcessor(MockTokenizer())
+test_data = processor.create_test_data(num_samples=6)
+
+print(f"📊 生成 {len(test_data)} 条测试数据")
+for i, item in enumerate(test_data):
+ print(f" {i+1}. {item['gender']}: {item['input'][:50]}...")
+
+# 2. 创建批次
+batch = {
+ "input": [item["input"] for item in test_data[:4]],
+ "gender": [item["gender"] for item in test_data[:4]]
+}
+
+print(f"\n📦 批次信息:")
+print(f"性别: {batch['gender']}")
+
+gender_labels = torch.tensor([gender_to_label(g) for g in batch["gender"]])
+print(f"标签: {gender_labels.tolist()}")
+
+# 3. 测试修复后的损失函数
+gee_loss = GEELoss(lambda_weight=1.0) # 降低lambda权重
+
+# 模拟合理的熵值(包含一些接近0的值)
+H_i_test = torch.tensor([0.8, 0.1, 0.6, 0.2]) # male, female, male, female
+
+print(f"\n🧮 测试修复后的GEE损失:")
+print(f"输入熵值: {H_i_test.tolist()}")
+print(f"性别标签: {batch['gender']}")
+
+loss, metrics = gee_loss.compute_gee_loss(H_i_test, gender_labels)
+
+print(f"\n📈 结果:")
+print(f"总损失: {loss:.6f}")
+print(f"熵最小化损失: {metrics['loss_em']:.6f}")
+print(f"偏见损失: {metrics['loss_bias']:.6f}")
+print(f"男性平均熵: {metrics['H_male']:.6f}")
+print(f"女性平均熵: {metrics['H_female']:.6f}")
+print(f"熵差距: {metrics['entropy_gap']:.6f}")
+print(f"Lambda权重: {metrics['lambda_weight']}")
+
+# 4. 验证修复效果
+print(f"\n✅ 修复验证:")
+if metrics['H_female'] > 0:
+ print("✅ H_female不再为0")
+else:
+ print("❌ H_female仍为0,可能还有问题")
+
+if metrics['entropy_gap'] < 1.0:
+ print("✅ 熵差距在合理范围内")
+else:
+ print("⚠️ 熵差距较大")
+
+if loss < 10.0:
+ print("✅ 总损失在合理范围内")
+else:
+ print("⚠️ 总损失可能过大")
+
+print(f"\n💡 修复要点:")
+print("1. 移除了错误的零熵值过滤")
+print("2. 简化了GEE损失计算")
+print("3. 添加了调试信息")
+print("4. 建议降低lambda权重到0.5-1.0") \ No newline at end of file