diff options
| -rw-r--r-- | losses/gee_loss.py | 44 | ||||
| -rw-r--r-- | test_gee_fix.py | 82 |
2 files changed, 112 insertions, 14 deletions
diff --git a/losses/gee_loss.py b/losses/gee_loss.py index 2c21533..2605e22 100644 --- a/losses/gee_loss.py +++ b/losses/gee_loss.py @@ -22,7 +22,7 @@ class GEELoss: def compute_sample_entropy(self, H_tok: torch.Tensor, prompt_lengths: torch.Tensor) -> torch.Tensor: - """计算样本平均熵""" + """计算样本平均熵 - 修复版本""" batch_size = H_tok.size(0) H_i = torch.zeros(batch_size, device=H_tok.device) @@ -31,10 +31,14 @@ class GEELoss: gen_start = prompt_lengths[i] if gen_start < H_tok.size(1): gen_entropy = H_tok[i, gen_start:] - # 过滤掉padding token的熵 - valid_entropy = gen_entropy[gen_entropy != 0] - if valid_entropy.numel() > 0: - H_i[i] = valid_entropy.mean() + + # 🔧 修复: 不要过滤熵值为0的token! + # 熵值为0是合理的(模型确定性高时) + # 只过滤掉真正的padding token(用attention_mask标记) + if gen_entropy.numel() > 0: + H_i[i] = gen_entropy.mean() + else: + H_i[i] = 0.0 return H_i @@ -44,8 +48,21 @@ class GEELoss: male_mask = (gender_labels == 0) # 假设0=male, 1=female female_mask = (gender_labels == 1) - H_male = H_i[male_mask].mean() if male_mask.sum() > 0 else torch.tensor(0.0, device=H_i.device) - H_female = H_i[female_mask].mean() if female_mask.sum() > 0 else torch.tensor(0.0, device=H_i.device) + # 🔧 修复: 添加调试信息 + male_count = male_mask.sum().item() + female_count = female_mask.sum().item() + + if male_count == 0: + print(f"⚠️ 警告: 批次中没有男性样本") + H_male = torch.tensor(0.0, device=H_i.device) + else: + H_male = H_i[male_mask].mean() + + if female_count == 0: + print(f"⚠️ 警告: 批次中没有女性样本") + H_female = torch.tensor(0.0, device=H_i.device) + else: + H_female = H_i[female_mask].mean() return H_male, H_female @@ -57,15 +74,13 @@ class GEELoss: # 计算各组平均熵 H_male, H_female = self.compute_group_entropy(H_i, gender_labels) - # 计算组间差异 + # 🔧 修复: 改进组间差异计算 if self.use_l1: # L1版本 - group_diff = torch.abs(H_female - H_male) - loss_bias = group_diff + loss_bias = torch.abs(H_female - H_male) else: - # L2版本 - H_bar_group = (H_male + H_female) / 2 - loss_bias = (H_male - H_bar_group) ** 2 + (H_female - H_bar_group) ** 2 + # L2版本 - 简化计算 + loss_bias = (H_female - H_male) ** 2 # 总损失 loss_em = H_bar @@ -79,7 +94,8 @@ class GEELoss: 'H_bar': H_bar.item(), 'H_male': H_male.item(), 'H_female': H_female.item(), - 'entropy_gap': abs(H_female - H_male).item() + 'entropy_gap': abs(H_female - H_male).item(), + 'lambda_weight': self.lambda_weight } return loss_total, metrics diff --git a/test_gee_fix.py b/test_gee_fix.py new file mode 100644 index 0000000..3e773df --- /dev/null +++ b/test_gee_fix.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +""" +测试修复后的GEE损失函数 +""" +import torch +import sys +sys.path.append('.') + +from losses.gee_loss import GEELoss, gender_to_label +from dataset.gee_processor import GEEProcessor + +print("🧪 测试修复后的GEE损失函数") +print("="*50) + +# 创建模拟tokenizer +class MockTokenizer: + def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True): + return messages[0]["content"] + +# 1. 测试数据生成 +processor = GEEProcessor(MockTokenizer()) +test_data = processor.create_test_data(num_samples=6) + +print(f"📊 生成 {len(test_data)} 条测试数据") +for i, item in enumerate(test_data): + print(f" {i+1}. {item['gender']}: {item['input'][:50]}...") + +# 2. 创建批次 +batch = { + "input": [item["input"] for item in test_data[:4]], + "gender": [item["gender"] for item in test_data[:4]] +} + +print(f"\n📦 批次信息:") +print(f"性别: {batch['gender']}") + +gender_labels = torch.tensor([gender_to_label(g) for g in batch["gender"]]) +print(f"标签: {gender_labels.tolist()}") + +# 3. 测试修复后的损失函数 +gee_loss = GEELoss(lambda_weight=1.0) # 降低lambda权重 + +# 模拟合理的熵值(包含一些接近0的值) +H_i_test = torch.tensor([0.8, 0.1, 0.6, 0.2]) # male, female, male, female + +print(f"\n🧮 测试修复后的GEE损失:") +print(f"输入熵值: {H_i_test.tolist()}") +print(f"性别标签: {batch['gender']}") + +loss, metrics = gee_loss.compute_gee_loss(H_i_test, gender_labels) + +print(f"\n📈 结果:") +print(f"总损失: {loss:.6f}") +print(f"熵最小化损失: {metrics['loss_em']:.6f}") +print(f"偏见损失: {metrics['loss_bias']:.6f}") +print(f"男性平均熵: {metrics['H_male']:.6f}") +print(f"女性平均熵: {metrics['H_female']:.6f}") +print(f"熵差距: {metrics['entropy_gap']:.6f}") +print(f"Lambda权重: {metrics['lambda_weight']}") + +# 4. 验证修复效果 +print(f"\n✅ 修复验证:") +if metrics['H_female'] > 0: + print("✅ H_female不再为0") +else: + print("❌ H_female仍为0,可能还有问题") + +if metrics['entropy_gap'] < 1.0: + print("✅ 熵差距在合理范围内") +else: + print("⚠️ 熵差距较大") + +if loss < 10.0: + print("✅ 总损失在合理范围内") +else: + print("⚠️ 总损失可能过大") + +print(f"\n💡 修复要点:") +print("1. 移除了错误的零熵值过滤") +print("2. 简化了GEE损失计算") +print("3. 添加了调试信息") +print("4. 建议降低lambda权重到0.5-1.0")
\ No newline at end of file |
