summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--smart_balanced_dataloader.py218
-rw-r--r--train_gee_smart.py279
2 files changed, 497 insertions, 0 deletions
diff --git a/smart_balanced_dataloader.py b/smart_balanced_dataloader.py
new file mode 100644
index 0000000..54eb630
--- /dev/null
+++ b/smart_balanced_dataloader.py
@@ -0,0 +1,218 @@
+#!/usr/bin/env python3
+"""
+智能平衡数据加载器 - 自动检测并重新生成不平衡批次
+"""
+import torch
+from torch.utils.data import Dataset, DataLoader
+import random
+from typing import List, Dict
+import warnings
+
+class SmartBalancedGEEDataset(Dataset):
+ def __init__(self, data: List[Dict]):
+ self.data = data
+ # 按性别分组
+ self.male_data = [item for item in data if item['gender'] == 'male']
+ self.female_data = [item for item in data if item['gender'] == 'female']
+
+ print(f"📊 数据分布: male={len(self.male_data)}, female={len(self.female_data)}")
+
+ # 确保有足够的数据
+ if len(self.male_data) == 0 or len(self.female_data) == 0:
+ raise ValueError("数据中必须包含男性和女性样本")
+
+ # 确保有足够的样本进行重新生成
+ min_samples = min(len(self.male_data), len(self.female_data))
+ if min_samples < 10:
+ warnings.warn(f"样本数量较少 (min={min_samples}),可能影响批次生成质量")
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ return self.data[idx]
+
+ def generate_balanced_batch(self, batch_size: int, max_retries: int = 10):
+ """生成一个平衡的批次,如果失败会重试"""
+ if batch_size < 2:
+ raise ValueError("batch_size必须>=2才能保证性别平衡")
+
+ # 计算每个性别需要的样本数
+ male_per_batch = batch_size // 2
+ female_per_batch = batch_size - male_per_batch
+
+ for attempt in range(max_retries):
+ try:
+ batch = []
+
+ # 随机选择男性样本
+ if len(self.male_data) >= male_per_batch:
+ male_samples = random.sample(self.male_data, male_per_batch)
+ batch.extend(male_samples)
+ else:
+ # 如果男性样本不够,用替换的方式采样
+ male_samples = random.choices(self.male_data, k=male_per_batch)
+ batch.extend(male_samples)
+
+ # 随机选择女性样本
+ if len(self.female_data) >= female_per_batch:
+ female_samples = random.sample(self.female_data, female_per_batch)
+ batch.extend(female_samples)
+ else:
+ # 如果女性样本不够,用替换的方式采样
+ female_samples = random.choices(self.female_data, k=female_per_batch)
+ batch.extend(female_samples)
+
+ # 打乱批次内的顺序
+ random.shuffle(batch)
+
+ # 验证批次平衡性
+ male_count = sum(1 for item in batch if item['gender'] == 'male')
+ female_count = sum(1 for item in batch if item['gender'] == 'female')
+
+ if male_count > 0 and female_count > 0:
+ return batch
+ else:
+ print(f"⚠️ 尝试 {attempt+1}: 批次不平衡 (male={male_count}, female={female_count}),重新生成...")
+
+ except Exception as e:
+ print(f"⚠️ 尝试 {attempt+1} 失败: {e}")
+ continue
+
+ # 如果所有尝试都失败,返回一个强制平衡的批次
+ print(f"❌ {max_retries} 次尝试后仍然失败,强制生成平衡批次")
+ return self._force_balanced_batch(batch_size)
+
+ def _force_balanced_batch(self, batch_size: int):
+ """强制生成一个平衡批次"""
+ batch = []
+ male_per_batch = batch_size // 2
+ female_per_batch = batch_size - male_per_batch
+
+ # 强制添加男性样本(允许重复)
+ for _ in range(male_per_batch):
+ batch.append(random.choice(self.male_data))
+
+ # 强制添加女性样本(允许重复)
+ for _ in range(female_per_batch):
+ batch.append(random.choice(self.female_data))
+
+ random.shuffle(batch)
+ return batch
+
+class SmartBalancedDataLoader:
+ """智能平衡数据加载器"""
+ def __init__(self, dataset: SmartBalancedGEEDataset, batch_size: int, num_batches: int):
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.num_batches = num_batches
+ self.current_idx = 0
+
+ print(f"🧠 智能数据加载器初始化: batch_size={batch_size}, num_batches={num_batches}")
+
+ def __iter__(self):
+ self.current_idx = 0
+ return self
+
+ def __next__(self):
+ if self.current_idx >= self.num_batches:
+ raise StopIteration
+
+ # 动态生成平衡批次
+ batch = self.dataset.generate_balanced_batch(self.batch_size)
+ self.current_idx += 1
+
+ # 应用collate函数并验证
+ return self._smart_collate(batch)
+
+ def _smart_collate(self, batch, max_regenerate: int = 3):
+ """智能collate函数,如果检测到不平衡会重新生成"""
+ inputs = [item["input"] for item in batch]
+ genders = [item["gender"] for item in batch]
+
+ # 检查批次平衡性
+ male_count = sum(1 for g in genders if g == 'male')
+ female_count = sum(1 for g in genders if g == 'female')
+
+ # 如果不平衡,尝试重新生成
+ regenerate_count = 0
+ while (male_count == 0 or female_count == 0) and regenerate_count < max_regenerate:
+ print(f"🔄 检测到不平衡批次 (male={male_count}, female={female_count}),重新生成...")
+
+ # 重新生成批次
+ batch = self.dataset.generate_balanced_batch(self.batch_size)
+ inputs = [item["input"] for item in batch]
+ genders = [item["gender"] for item in batch]
+
+ # 重新检查
+ male_count = sum(1 for g in genders if g == 'male')
+ female_count = sum(1 for g in genders if g == 'female')
+ regenerate_count += 1
+
+ # 最终检查
+ if male_count == 0:
+ print("❌ 最终警告: 批次中没有男性样本!")
+ if female_count == 0:
+ print("❌ 最终警告: 批次中没有女性样本!")
+
+ if male_count > 0 and female_count > 0:
+ print(f"✅ 平衡批次: male={male_count}, female={female_count}")
+
+ return {
+ "input": inputs,
+ "gender": genders
+ }
+
+def create_smart_balanced_dataloader(data: List[Dict], batch_size: int, num_batches: int = 10):
+ """创建智能平衡数据加载器"""
+
+ if batch_size < 2:
+ print("⚠️ 警告: batch_size < 2,无法保证性别平衡")
+ # 回退到普通DataLoader
+ dataset = SmartBalancedGEEDataset(data)
+ return DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+ dataset = SmartBalancedGEEDataset(data)
+
+ print(f"🧠 创建智能平衡数据加载器")
+ print(f" 批次大小: {batch_size}")
+ print(f" 批次数量: {num_batches}")
+ print(f" 每批次配置: male={batch_size//2}, female={batch_size-batch_size//2}")
+
+ return SmartBalancedDataLoader(dataset, batch_size, num_batches)
+
+# 测试函数
+if __name__ == "__main__":
+ # 测试智能平衡数据加载器
+ import sys
+ sys.path.append('.')
+ from dataset.gee_processor import GEEProcessor
+
+ class MockTokenizer:
+ def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
+ return messages[0]["content"]
+
+ processor = GEEProcessor(MockTokenizer())
+ test_data = processor.create_test_data(num_samples=20)
+
+ print("🧪 测试智能平衡数据加载器")
+ dataloader = create_smart_balanced_dataloader(test_data, batch_size=4, num_batches=5)
+
+ for i, batch in enumerate(dataloader):
+ print(f"\n=== 批次 {i+1} ===")
+ print(f"输入数量: {len(batch['input'])}")
+ print(f"性别分布: {batch['gender']}")
+
+ # 验证平衡性
+ male_count = sum(1 for g in batch['gender'] if g == 'male')
+ female_count = sum(1 for g in batch['gender'] if g == 'female')
+
+ if male_count > 0 and female_count > 0:
+ print(f"✅ 批次完美平衡: male={male_count}, female={female_count}")
+ else:
+ print(f"❌ 批次仍然不平衡: male={male_count}, female={female_count}")
+
+ if i >= 4: # 测试5个批次
+ break
+
+ print("\n�� 智能平衡数据加载器测试完成!") \ No newline at end of file
diff --git a/train_gee_smart.py b/train_gee_smart.py
new file mode 100644
index 0000000..d368762
--- /dev/null
+++ b/train_gee_smart.py
@@ -0,0 +1,279 @@
+#!/usr/bin/env python3
+"""
+GEE训练脚本 - 使用智能平衡数据加载器
+绝对保证每个批次都包含男女样本
+"""
+import argparse
+import os
+import torch
+import torch.nn.functional as F
+from torch.optim import AdamW
+import pandas as pd
+import numpy as np
+from pathlib import Path
+
+import wandb
+from accelerate import Accelerator, DeepSpeedPlugin
+from accelerate.utils import set_seed
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
+
+# 导入自定义模块
+import sys
+sys.path.append('.')
+from dataset.gee_processor import GEEProcessor
+from losses.gee_loss import GEELoss, gender_to_label
+from smart_balanced_dataloader import create_smart_balanced_dataloader
+
+os.environ.setdefault("NCCL_TIMEOUT", "2700")
+os.environ.setdefault("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "2700")
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ # GEE相关参数
+ parser.add_argument('--lambda_weight', type=float, default=0.5, help='GEE lambda weight')
+ parser.add_argument('--use_l1', action='store_true', help='Use L1 loss instead of L2')
+ parser.add_argument('--auto_anneal', action='store_true', help='Use automatic annealing')
+
+ # 模型参数
+ parser.add_argument('--model_name', type=str, default='Qwen2.5-Math-1.5B-Instruct', help='Model name')
+ parser.add_argument('--model_path', type=str, required=True, help='Model path')
+ parser.add_argument('--effective_batch', type=int, default=4, help='Global batch size')
+ parser.add_argument('--micro_batch_size', type=int, default=2, help='Micro batch size (must >=2 for balance)')
+ parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
+ parser.add_argument('--max_steps', type=int, default=10, help='Maximum training steps')
+ parser.add_argument('--sample_temp', type=float, default=0.7, help='Generation temperature')
+
+ # 运行参数
+ parser.add_argument('--run_name', type=str, default='gee_smart_balanced', help='Run name')
+ parser.add_argument('--wandb_project', type=str, default='one-shot-gee', help='W&B project name')
+ parser.add_argument('--use_test_data', action='store_true', help='Use synthetic test data')
+ parser.add_argument('--seed', type=int, default=42, help='Random seed')
+ parser.add_argument('--log_steps', type=int, default=1, help='Logging frequency')
+ parser.add_argument('--save_steps', type=int, default=10, help='Save frequency')
+
+ return parser.parse_args()
+
+def main():
+ args = parse_args()
+ set_seed(args.seed)
+
+ # 强制检查batch_size
+ if args.micro_batch_size < 2:
+ print("❌ 错误: micro_batch_size必须>=2才能保证性别平衡!")
+ print("请使用: --micro_batch_size 2 或更大")
+ return
+
+ # DeepSpeed配置
+ ds_config = {
+ "train_micro_batch_size_per_gpu": args.micro_batch_size,
+ "train_batch_size": args.effective_batch,
+ "gradient_accumulation_steps": max(1, args.effective_batch // args.micro_batch_size),
+ "bf16": {"enabled": True},
+ "zero_optimization": {
+ "stage": 2,
+ "offload_optimizer": {"device": "cpu"},
+ "offload_param": {"device": "none"}
+ },
+ "gradient_clipping": 1.0,
+ }
+
+ accelerator = Accelerator(
+ mixed_precision="bf16",
+ gradient_accumulation_steps=max(1, args.effective_batch // args.micro_batch_size),
+ deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=ds_config)
+ )
+
+ print = accelerator.print
+ print(f"🧠 开始智能GEE训练 - {args.run_name}")
+ print(f"📊 配置信息:")
+ print(f" 批次大小: micro={args.micro_batch_size}, effective={args.effective_batch}")
+ print(f" Lambda权重: {args.lambda_weight}")
+ print(f" 最大步数: {args.max_steps}")
+ print(f" 智能平衡: ✅ 启用")
+
+ # 加载模型
+ model_path = args.model_path
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+ config.use_cache = False
+ model = AutoModelForCausalLM.from_pretrained(model_path, config=config, trust_remote_code=True)
+ model.gradient_checkpointing_enable()
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # 初始化GEE处理器和损失函数
+ gee_processor = GEEProcessor(tokenizer)
+ gee_loss_fn = GEELoss(lambda_weight=args.lambda_weight, use_l1=args.use_l1)
+
+ if accelerator.is_main_process:
+ wandb.init(project=args.wandb_project, name=args.run_name, config=vars(args))
+
+ # 准备数据 - 使用智能平衡数据加载器
+ if args.use_test_data:
+ print("📊 使用合成测试数据...")
+ train_data = gee_processor.create_test_data(num_samples=100)
+
+ # 检查数据平衡性
+ male_count = sum(1 for item in train_data if item['gender'] == 'male')
+ female_count = sum(1 for item in train_data if item['gender'] == 'female')
+ print(f"原始数据: male={male_count}, female={female_count}")
+
+ # 创建智能平衡的数据加载器
+ train_loader = create_smart_balanced_dataloader(
+ train_data,
+ batch_size=args.micro_batch_size,
+ num_batches=args.max_steps + 5 # 额外的批次确保有足够数据
+ )
+ else:
+ print("❌ 请使用 --use_test_data 进行测试")
+ return
+
+ optimizer = AdamW(model.parameters(), lr=args.learning_rate)
+
+ # 注意:智能数据加载器不需要accelerator.prepare处理
+ model, optimizer = accelerator.prepare(model, optimizer)
+
+ print(f"🎯 开始训练...")
+ print(f" 期望看到: 每个批次都显示 '✅ 平衡批次'")
+ print(f" 不应该看到: '❌ 警告: 批次中没有男性/女性样本'")
+
+ # 开始训练
+ model.train()
+ initial_entropy_gap = None
+ successful_steps = 0
+ failed_steps = 0
+
+ for step, batch in enumerate(train_loader, start=1):
+ if step > args.max_steps:
+ print(f"🛑 达到最大步数 {args.max_steps},训练结束")
+ break
+
+ with accelerator.accumulate(model):
+ try:
+ # 验证批次平衡性
+ male_count = sum(1 for g in batch["gender"] if g == 'male')
+ female_count = sum(1 for g in batch["gender"] if g == 'female')
+
+ if male_count == 0 or female_count == 0:
+ print(f"💥 Step {step}: 智能加载器失败!male={male_count}, female={female_count}")
+ failed_steps += 1
+ continue
+ else:
+ successful_steps += 1
+
+ # 准备输入
+ inputs = tokenizer(
+ batch["input"],
+ return_tensors="pt",
+ padding="longest",
+ truncation=True,
+ max_length=1024
+ ).to(accelerator.device)
+
+ # 生成回答
+ with torch.no_grad():
+ gen_ids = accelerator.unwrap_model(model).generate(
+ **inputs,
+ max_new_tokens=128,
+ do_sample=True,
+ top_p=0.95,
+ temperature=args.sample_temp,
+ synced_gpus=True,
+ repetition_penalty=1.15,
+ pad_token_id=tokenizer.pad_token_id,
+ use_cache=False
+ )
+
+ # 准备完整序列
+ seq = torch.cat([inputs.input_ids, gen_ids[:, inputs.input_ids.shape[1]:]], dim=1)
+ pad_mask = seq.ne(tokenizer.pad_token_id)
+ prompt_lengths = pad_mask[:, :inputs.input_ids.shape[1]].sum(-1)
+
+ # 计算logits和熵
+ logits = model(seq, attention_mask=pad_mask).logits
+ H_tok = gee_loss_fn.compute_token_entropy(logits, pad_mask)
+ H_i = gee_loss_fn.compute_sample_entropy(H_tok, prompt_lengths)
+
+ # 准备性别标签
+ gender_labels = torch.tensor([
+ gender_to_label(g) for g in batch["gender"]
+ ], device=accelerator.device)
+
+ # 计算GEE损失
+ loss, metrics = gee_loss_fn.compute_gee_loss(H_i, gender_labels)
+
+ # 自动退火(可选)
+ if args.auto_anneal and initial_entropy_gap is None:
+ initial_entropy_gap = metrics['entropy_gap']
+
+ if args.auto_anneal and initial_entropy_gap > 0:
+ current_gap = metrics['entropy_gap']
+ anneal_factor = current_gap / initial_entropy_gap
+ new_lambda = args.lambda_weight * anneal_factor
+ gee_loss_fn.update_lambda(new_lambda)
+ metrics['lambda_weight'] = new_lambda
+
+ # 反向传播
+ accelerator.backward(loss)
+ accelerator.clip_grad_norm_(model.parameters(), 1.0)
+ optimizer.step()
+ optimizer.zero_grad()
+
+ # 日志记录
+ if accelerator.is_main_process:
+ if step % args.log_steps == 0:
+ print(f"🎯 Step {step} | loss={loss.item():.6f} | "
+ f"gap={metrics['entropy_gap']:.6f} | "
+ f"H_male={metrics['H_male']:.6f} | "
+ f"H_female={metrics['H_female']:.6f} | "
+ f"批次[{male_count}M,{female_count}F]")
+
+ # 添加成功率统计
+ success_rate = successful_steps / (successful_steps + failed_steps) * 100
+ metrics['success_rate'] = success_rate
+
+ wandb.log({"step": step, **metrics})
+
+ # 保存检查点
+ if accelerator.is_main_process and step % args.save_steps == 0:
+ ckpt = Path(f"checkpoints/{args.model_name}/{args.run_name}") / f"step_{step}"
+ ckpt.mkdir(parents=True, exist_ok=True)
+ accelerator.unwrap_model(model).save_pretrained(ckpt, safe_serialization=True)
+ tokenizer.save_pretrained(ckpt)
+ print(f"💾 检查点已保存: {ckpt}")
+
+ except Exception as e:
+ print(f"❌ 训练步骤 {step} 出错: {e}")
+ failed_steps += 1
+ continue
+
+ if accelerator.is_main_process:
+ # 最终统计
+ total_steps = successful_steps + failed_steps
+ success_rate = successful_steps / total_steps * 100 if total_steps > 0 else 0
+
+ print(f"\n🎉 智能GEE训练完成!")
+ print(f"📊 最终统计:")
+ print(f" 成功步数: {successful_steps}")
+ print(f" 失败步数: {failed_steps}")
+ print(f" 成功率: {success_rate:.1f}%")
+
+ if success_rate >= 95:
+ print("✅ 智能平衡数据加载器工作完美!")
+ elif success_rate >= 80:
+ print("⚠️ 智能平衡数据加载器基本正常,偶有问题")
+ else:
+ print("❌ 智能平衡数据加载器需要进一步优化")
+
+ # 保存最终模型
+ final = Path(f"checkpoints/{args.model_name}/{args.run_name}") / "final"
+ final.mkdir(parents=True, exist_ok=True)
+ accelerator.unwrap_model(model).save_pretrained(final, safe_serialization=True)
+ tokenizer.save_pretrained(final)
+ print(f"💾 最终模型已保存: {final}")
+
+ wandb.finish()
+
+if __name__ == "__main__":
+ main() \ No newline at end of file