diff options
Diffstat (limited to 'test_gee_training.py')
| -rw-r--r-- | test_gee_training.py | 231 |
1 files changed, 231 insertions, 0 deletions
#!/usr/bin/env python3
# Recovered from a diff that adds test_gee_training.py as a new file
# (index 0000000..82cce04, 231 added lines).
"""
GEE training-logic test script.

Simulates the training process without requiring a real model.
"""

import sys
import os
import torch
import numpy as np
from pathlib import Path

# Add the project path
sys.path.append('.')

from dataset.gee_processor import GEEProcessor
from losses.gee_loss import GEELoss, gender_to_label

# Fixed seed so the random mock tensors (and therefore the printed losses)
# are reproducible from one run to the next.
_SEED = 0


class MockTokenizer:
    """Tokenizer stand-in that returns random token ids of a fixed length."""

    def __init__(self):
        self.pad_token_id = 0
        self.eos_token = '<|endoftext|>'
        self.pad_token = self.eos_token

    def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
        """Return the content of the first message; other arguments are ignored."""
        return messages[0]["content"]

    def __call__(self, texts, return_tensors=None, padding=None, truncation=None, max_length=None):
        """Simulate tokenization of ``texts``.

        Only ``len(texts)`` is used; the keyword arguments are accepted and
        ignored.

        Returns:
            A dict with random ``input_ids`` in [1, 1000) and an all-ones
            ``attention_mask``, both of shape (len(texts), 50).
        """
        batch_size = len(texts)
        seq_len = 50  # fixed sequence length for testing

        return {
            'input_ids': torch.randint(1, 1000, (batch_size, seq_len)),
            'attention_mask': torch.ones(batch_size, seq_len)
        }


class MockOutput:
    """Forward-pass result of MockModel; exposes only a ``logits`` attribute."""

    def __init__(self, logits):
        self.logits = logits


class MockModel:
    """Model stand-in that produces random logits and random generated tokens."""

    def __init__(self):
        self.device = 'cpu'

    def __call__(self, input_ids, attention_mask=None):
        """Return a MockOutput holding random logits.

        The logits have shape (batch, seq_len, 1000), where batch and seq_len
        come from the 2-D ``input_ids``. ``attention_mask`` is ignored.
        """
        batch_size, seq_len = input_ids.shape
        vocab_size = 1000

        # Simulated logits output
        logits = torch.randn(batch_size, seq_len, vocab_size)
        return MockOutput(logits)

    def generate(self, input_ids, attention_mask=None, max_new_tokens=50, **kwargs):
        """Append ``max_new_tokens`` random token ids to each row of ``input_ids``.

        ``attention_mask`` and any extra keyword arguments are ignored.
        """
        # Unpacking two values keeps the requirement that input_ids is 2-D.
        batch_size, _ = input_ids.shape
        # Simulate generating new tokens
        new_tokens = torch.randint(1, 1000, (batch_size, max_new_tokens))
        return torch.cat([input_ids, new_tokens], dim=1)


def test_gee_training_logic():
    """Run a mock training loop and check the GEE loss at every step."""
    torch.manual_seed(_SEED)

    print("="*60)
    print("测试GEE训练逻辑")
    print("="*60)

    # Initialise the components
    tokenizer = MockTokenizer()
    model = MockModel()
    gee_processor = GEEProcessor(tokenizer)
    gee_loss_fn = GEELoss(lambda_weight=3.0, use_l1=False)

    # Generate test data
    train_data = gee_processor.create_test_data(num_samples=20)
    print(f"生成训练数据: {len(train_data)} 条")

    # Simulated training loop
    batch_size = 4
    num_steps = 5

    print(f"\n开始模拟训练 ({num_steps} 步)...")

    for step in range(1, num_steps + 1):
        # Build the batch
        batch_data = train_data[(step-1)*batch_size:step*batch_size]
        if len(batch_data) < batch_size:
            # Wrap around and reuse data from the start
            batch_data = train_data[:batch_size]

        batch = {
            "input": [item["input"] for item in batch_data],
            "gender": [item["gender"] for item in batch_data]
        }

        # Simulated tokenization
        inputs = tokenizer(batch["input"])

        # Simulated generation
        gen_ids = model.generate(**inputs, max_new_tokens=20)

        # Prepare the full sequence
        seq = gen_ids[:, :100]  # cap the length for testing
        # One prompt length per sample actually in the batch, so the tensor
        # still lines up if fewer than batch_size samples are available.
        prompt_len = inputs['input_ids'].shape[1]
        prompt_lengths = torch.tensor([prompt_len] * len(batch_data))

        # Compute logits and entropy
        mock_output = model(seq)
        logits = mock_output.logits

        # Compute the GEE loss inputs
        H_tok = gee_loss_fn.compute_token_entropy(logits)
        H_i = gee_loss_fn.compute_sample_entropy(H_tok, prompt_lengths)

        # Prepare the gender labels
        gender_labels = torch.tensor([gender_to_label(g) for g in batch["gender"]])

        # Compute the loss
        loss, metrics = gee_loss_fn.compute_gee_loss(H_i, gender_labels)

        # Print the training log
        print(f"Step {step} | loss={loss.item():.6f} | "
              f"entropy_gap={metrics['entropy_gap']:.6f} | "
              f"H_male={metrics['H_male']:.6f} | "
              f"H_female={metrics['H_female']:.6f}")

        # Validate the loss computation
        assert not torch.isnan(loss), "损失为NaN"
        assert loss.item() > 0, "损失应该为正值"
        assert 'entropy_gap' in metrics, "缺少entropy_gap指标"

    print("✓ GEE训练逻辑测试通过")


def test_different_lambdas():
    """Print the loss components for several lambda weights on fixed logits."""
    torch.manual_seed(_SEED)

    print("\n" + "="*60)
    print("测试不同lambda值的影响")
    print("="*60)

    # Lambda values to compare
    lambda_values = [0.0, 1.0, 3.0, 5.0]

    # Fixed test data shared by every lambda value
    batch_size = 4
    seq_len = 50
    vocab_size = 1000

    logits = torch.randn(batch_size, seq_len, vocab_size)
    prompt_lengths = torch.tensor([20, 20, 20, 20])
    gender_labels = torch.tensor([0, 1, 0, 1])  # male, female, male, female

    print("Lambda值对损失的影响:")
    print("Lambda\tEM Loss\tBias Loss\tTotal Loss\tEntropy Gap")
    print("-" * 60)

    for lambda_val in lambda_values:
        gee_loss_fn = GEELoss(lambda_weight=lambda_val, use_l1=False)

        H_tok = gee_loss_fn.compute_token_entropy(logits)
        H_i = gee_loss_fn.compute_sample_entropy(H_tok, prompt_lengths)
        loss, metrics = gee_loss_fn.compute_gee_loss(H_i, gender_labels)

        print(f"{lambda_val:.1f}\t{metrics['loss_em']:.4f}\t"
              f"{metrics['loss_bias']:.4f}\t{metrics['loss_total']:.4f}\t"
              f"{metrics['entropy_gap']:.4f}")

    print("✓ Lambda值测试通过")


def test_l1_vs_l2():
    """Print the L2 and L1 variants of the GEE loss for the same entropies."""
    torch.manual_seed(_SEED)

    print("\n" + "="*60)
    print("测试L1和L2损失的差异")
    print("="*60)

    # Fixed test data
    batch_size = 4
    seq_len = 50
    vocab_size = 1000

    logits = torch.randn(batch_size, seq_len, vocab_size)
    prompt_lengths = torch.tensor([20, 20, 20, 20])
    gender_labels = torch.tensor([0, 1, 0, 1])

    # L2 variant
    gee_loss_l2 = GEELoss(lambda_weight=3.0, use_l1=False)
    H_tok = gee_loss_l2.compute_token_entropy(logits)
    H_i = gee_loss_l2.compute_sample_entropy(H_tok, prompt_lengths)
    loss_l2, metrics_l2 = gee_loss_l2.compute_gee_loss(H_i, gender_labels)

    # L1 variant, reusing the per-sample entropies computed above
    gee_loss_l1 = GEELoss(lambda_weight=3.0, use_l1=True)
    loss_l1, metrics_l1 = gee_loss_l1.compute_gee_loss(H_i, gender_labels)

    print(f"L2损失: {metrics_l2['loss_total']:.6f} (bias: {metrics_l2['loss_bias']:.6f})")
    print(f"L1损失: {metrics_l1['loss_total']:.6f} (bias: {metrics_l1['loss_bias']:.6f})")
    print(f"熵差距: {metrics_l2['entropy_gap']:.6f}")

    print("✓ L1 vs L2测试通过")


def main():
    """Run every test in order.

    Returns:
        True if all tests complete, False if any raises (the traceback is
        printed before returning).
    """
    print("开始GEE训练逻辑测试...")

    try:
        test_gee_training_logic()
        test_different_lambdas()
        test_l1_vs_l2()

        print("\n" + "="*60)
        print("所有训练逻辑测试通过!✓")
        print("="*60)
        print("\n核心功能验证:")
        print("✅ 数据处理流程正常")
        print("✅ 损失函数计算正确")
        print("✅ 训练循环逻辑正确")
        print("✅ 不同参数配置有效")
        print("\n🎯 准备就绪,可以进行真实模型训练!")

    except Exception as e:
        # Top-level boundary: report the failure and signal it via the
        # return value so the process exit code reflects it.
        print(f"\n测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False

    return True


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)
\ No newline at end of file |
