summaryrefslogtreecommitdiff
path: root/test_gee_training.py
diff options
context:
space:
mode:
Diffstat (limited to 'test_gee_training.py')
-rw-r--r--test_gee_training.py231
1 files changed, 231 insertions, 0 deletions
diff --git a/test_gee_training.py b/test_gee_training.py
new file mode 100644
index 0000000..82cce04
--- /dev/null
+++ b/test_gee_training.py
@@ -0,0 +1,231 @@
+#!/usr/bin/env python3
+"""
+GEE训练逻辑测试脚本
+模拟训练过程而不需要真实模型
+"""
+
+import sys
+import os
+import torch
+import numpy as np
+from pathlib import Path
+
+# 添加项目路径
+sys.path.append('.')
+
+from dataset.gee_processor import GEEProcessor
+from losses.gee_loss import GEELoss, gender_to_label
+
+class MockTokenizer:
+ def __init__(self):
+ self.pad_token_id = 0
+ self.eos_token = '<|endoftext|>'
+ self.pad_token = self.eos_token
+
+ def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
+ return messages[0]["content"]
+
+ def __call__(self, texts, return_tensors=None, padding=None, truncation=None, max_length=None):
+ # 模拟tokenization
+ batch_size = len(texts)
+ seq_len = 50 # 固定序列长度用于测试
+
+ return {
+ 'input_ids': torch.randint(1, 1000, (batch_size, seq_len)),
+ 'attention_mask': torch.ones(batch_size, seq_len)
+ }
+
+class MockModel:
+ def __init__(self):
+ self.device = 'cpu'
+
+ def __call__(self, input_ids, attention_mask=None):
+ batch_size, seq_len = input_ids.shape
+ vocab_size = 1000
+
+ # 模拟logits输出
+ logits = torch.randn(batch_size, seq_len, vocab_size)
+
+ class MockOutput:
+ def __init__(self, logits):
+ self.logits = logits
+
+ return MockOutput(logits)
+
+ def generate(self, input_ids, attention_mask=None, max_new_tokens=50, **kwargs):
+ batch_size, prompt_len = input_ids.shape
+ # 模拟生成新的token
+ new_tokens = torch.randint(1, 1000, (batch_size, max_new_tokens))
+ return torch.cat([input_ids, new_tokens], dim=1)
+
+def test_gee_training_logic():
+ """测试GEE训练逻辑"""
+ print("="*60)
+ print("测试GEE训练逻辑")
+ print("="*60)
+
+ # 初始化组件
+ tokenizer = MockTokenizer()
+ model = MockModel()
+ gee_processor = GEEProcessor(tokenizer)
+ gee_loss_fn = GEELoss(lambda_weight=3.0, use_l1=False)
+
+ # 生成测试数据
+ train_data = gee_processor.create_test_data(num_samples=20)
+ print(f"生成训练数据: {len(train_data)} 条")
+
+ # 模拟训练循环
+ batch_size = 4
+ num_steps = 5
+
+ print(f"\n开始模拟训练 ({num_steps} 步)...")
+
+ for step in range(1, num_steps + 1):
+ # 创建batch
+ batch_data = train_data[(step-1)*batch_size:step*batch_size]
+ if len(batch_data) < batch_size:
+ # 循环使用数据
+ batch_data = train_data[:batch_size]
+
+ batch = {
+ "input": [item["input"] for item in batch_data],
+ "gender": [item["gender"] for item in batch_data]
+ }
+
+ # 模拟tokenization
+ inputs = tokenizer(batch["input"])
+
+ # 模拟生成
+ gen_ids = model.generate(**inputs, max_new_tokens=20)
+
+ # 准备完整序列
+ seq = gen_ids[:, :100] # 限制长度用于测试
+ prompt_lengths = torch.tensor([inputs['input_ids'].shape[1]] * batch_size)
+
+ # 计算logits和熵
+ mock_output = model(seq)
+ logits = mock_output.logits
+
+ # 计算GEE损失
+ H_tok = gee_loss_fn.compute_token_entropy(logits)
+ H_i = gee_loss_fn.compute_sample_entropy(H_tok, prompt_lengths)
+
+ # 准备性别标签
+ gender_labels = torch.tensor([gender_to_label(g) for g in batch["gender"]])
+
+ # 计算损失
+ loss, metrics = gee_loss_fn.compute_gee_loss(H_i, gender_labels)
+
+ # 打印训练日志
+ print(f"Step {step} | loss={loss.item():.6f} | "
+ f"entropy_gap={metrics['entropy_gap']:.6f} | "
+ f"H_male={metrics['H_male']:.6f} | "
+ f"H_female={metrics['H_female']:.6f}")
+
+ # 验证损失计算
+ assert not torch.isnan(loss), "损失为NaN"
+ assert loss.item() > 0, "损失应该为正值"
+ assert 'entropy_gap' in metrics, "缺少entropy_gap指标"
+
+ print("✓ GEE训练逻辑测试通过")
+
+def test_different_lambdas():
+ """测试不同lambda值的影响"""
+ print("\n" + "="*60)
+ print("测试不同lambda值的影响")
+ print("="*60)
+
+ tokenizer = MockTokenizer()
+ model = MockModel()
+ gee_processor = GEEProcessor(tokenizer)
+
+ # 测试不同的lambda值
+ lambda_values = [0.0, 1.0, 3.0, 5.0]
+
+ # 创建固定的测试数据
+ batch_size = 4
+ seq_len = 50
+ vocab_size = 1000
+
+ logits = torch.randn(batch_size, seq_len, vocab_size)
+ prompt_lengths = torch.tensor([20, 20, 20, 20])
+ gender_labels = torch.tensor([0, 1, 0, 1]) # male, female, male, female
+
+ print("Lambda值对损失的影响:")
+ print("Lambda\tEM Loss\tBias Loss\tTotal Loss\tEntropy Gap")
+ print("-" * 60)
+
+ for lambda_val in lambda_values:
+ gee_loss_fn = GEELoss(lambda_weight=lambda_val, use_l1=False)
+
+ H_tok = gee_loss_fn.compute_token_entropy(logits)
+ H_i = gee_loss_fn.compute_sample_entropy(H_tok, prompt_lengths)
+ loss, metrics = gee_loss_fn.compute_gee_loss(H_i, gender_labels)
+
+ print(f"{lambda_val:.1f}\t{metrics['loss_em']:.4f}\t"
+ f"{metrics['loss_bias']:.4f}\t{metrics['loss_total']:.4f}\t"
+ f"{metrics['entropy_gap']:.4f}")
+
+ print("✓ Lambda值测试通过")
+
+def test_l1_vs_l2():
+ """测试L1和L2损失的差异"""
+ print("\n" + "="*60)
+ print("测试L1和L2损失的差异")
+ print("="*60)
+
+ # 创建固定的测试数据
+ batch_size = 4
+ seq_len = 50
+ vocab_size = 1000
+
+ logits = torch.randn(batch_size, seq_len, vocab_size)
+ prompt_lengths = torch.tensor([20, 20, 20, 20])
+ gender_labels = torch.tensor([0, 1, 0, 1])
+
+ # 测试L2版本
+ gee_loss_l2 = GEELoss(lambda_weight=3.0, use_l1=False)
+ H_tok = gee_loss_l2.compute_token_entropy(logits)
+ H_i = gee_loss_l2.compute_sample_entropy(H_tok, prompt_lengths)
+ loss_l2, metrics_l2 = gee_loss_l2.compute_gee_loss(H_i, gender_labels)
+
+ # 测试L1版本
+ gee_loss_l1 = GEELoss(lambda_weight=3.0, use_l1=True)
+ loss_l1, metrics_l1 = gee_loss_l1.compute_gee_loss(H_i, gender_labels)
+
+ print(f"L2损失: {metrics_l2['loss_total']:.6f} (bias: {metrics_l2['loss_bias']:.6f})")
+ print(f"L1损失: {metrics_l1['loss_total']:.6f} (bias: {metrics_l1['loss_bias']:.6f})")
+ print(f"熵差距: {metrics_l2['entropy_gap']:.6f}")
+
+ print("✓ L1 vs L2测试通过")
+
+def main():
+ """主测试函数"""
+ print("开始GEE训练逻辑测试...")
+
+ try:
+ test_gee_training_logic()
+ test_different_lambdas()
+ test_l1_vs_l2()
+
+ print("\n" + "="*60)
+ print("所有训练逻辑测试通过!✓")
+ print("="*60)
+ print("\n核心功能验证:")
+ print("✅ 数据处理流程正常")
+ print("✅ 损失函数计算正确")
+ print("✅ 训练循环逻辑正确")
+ print("✅ 不同参数配置有效")
+ print("\n🎯 准备就绪,可以进行真实模型训练!")
+
+ except Exception as e:
+ print(f"\n测试失败: {e}")
+ import traceback
+ traceback.print_exc()
+ return False
+
+ return True
+
+if __name__ == "__main__":
+ success = main()
+ sys.exit(0 if success else 1) \ No newline at end of file