Diffstat (limited to 'train_gee.py')
-rw-r--r--  train_gee.py  242
1 file changed, 242 insertions(+), 0 deletions(-)
diff --git a/train_gee.py b/train_gee.py
new file mode 100644
index 0000000..2e62fee
--- /dev/null
+++ b/train_gee.py
@@ -0,0 +1,242 @@
+import argparse
+import os
+import random
+import time
+from pathlib import Path
+
+import psutil
+import torch
+import torch.nn.functional as F
+from torch.optim import AdamW
+import pandas as pd
+import numpy as np
+from torch.utils.data import Dataset, DataLoader
+
+import wandb
+
+from accelerate import Accelerator, DeepSpeedPlugin
+from accelerate.utils import set_seed
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
+
+# Import custom modules
+import sys
+sys.path.append('.')
+from dataset.gee_processor import GEEProcessor
+from losses.gee_loss import GEELoss, gender_to_label
+
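+# Raise NCCL/heartbeat timeouts; long in-loop generation can otherwise trip the distributed watchdog.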
+os.environ.setdefault("NCCL_TIMEOUT", "2700")
+os.environ.setdefault("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "2700")
+
+class GEEDataset(Dataset):
+    def __init__(self, data):
+        self.data = data
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        return self.data[idx]
+
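+# Collate raw strings rather than tensors; tokenization happens inside the training loop.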
+def custom_collate(batch):
+    return {
+        "input": [item["input"] for item in batch],
+        "gender": [item["gender"] for item in batch]
+    }
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    # GEE-specific arguments
+    parser.add_argument('--lambda_weight', type=float, default=3.0, help='GEE lambda weight')
+    parser.add_argument('--use_l1', action='store_true', help='Use L1 loss instead of L2')
+    parser.add_argument('--auto_anneal', action='store_true', help='Use automatic annealing')
+    parser.add_argument('--bias_eval_steps', type=int, default=10, help='Bias evaluation frequency')
+    parser.add_argument('--balance_dataset', action=argparse.BooleanOptionalAction, default=True, help='Balance dataset by gender (disable with --no-balance_dataset)')
+    parser.add_argument('--target_size', type=int, default=None, help='Target dataset size for balancing')
+
+    # Original training arguments
+    parser.add_argument('--model_name', type=str, default='Qwen2.5-Math-7B', help='Model name')
+    parser.add_argument('--model_path', type=str, default=None, help='Local model path')
+    parser.add_argument('--train_data', type=str, default='dataset/1shot_rlvr/pi1_r1280.parquet', help='Training data file path')
+    parser.add_argument('--save_root', type=str, default=None, help='Checkpoint save root directory')
+    parser.add_argument('--effective_batch', type=int, default=64, help='Global batch size')
+    parser.add_argument('--micro_batch_size', type=int, default=2, help='Micro batch size')
+    parser.add_argument('--temperature', type=float, default=0.5, help='Temperature coefficient')
+    parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
+    parser.add_argument('--log_steps', type=int, default=1, help='Logging step interval')
+    parser.add_argument('--save_steps', type=int, default=1, help='Checkpoint saving step interval')
+    parser.add_argument('--max_steps', type=int, default=50, help='Maximum training steps')
+    parser.add_argument('--sample_temp', type=float, default=0.5, help='Generation temperature parameter')
+    parser.add_argument('--run_name', type=str, default='one_shot_gee', help='Experiment run name')
+    parser.add_argument('--wandb_project', type=str, default='one-shot-gee', help='W&B project name')
+    parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name')
+    parser.add_argument('--seed', type=int, default=15, help='Random seed')
+    parser.add_argument('--use_test_data', action='store_true', help='Use synthetic test data instead of real data')
+    return parser.parse_args()
+
+def apply_chat_template(tokenizer, problem: str) -> str:
+    return tokenizer.apply_chat_template(
+        [{"role": "user", "content": problem}],
+        tokenize=False, add_generation_prompt=True
+    )
+
+def main():
+    args = parse_args()
+    set_seed(args.seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
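+    # Derive accumulation steps so micro_batch_size * world_size * accum_steps == effective_batch.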
+    world_size = int(os.getenv("WORLD_SIZE", "1"))
+    micro_bs = args.micro_batch_size
+    eff_bs = args.effective_batch
+    accum_steps = max(1, eff_bs // (micro_bs * world_size))
+    temp = args.temperature
+    lr = args.learning_rate
+
+    save_root = args.save_root or (f"checkpoints/{args.model_name}/{args.run_name}" if args.run_name else f"checkpoints/{args.model_name}")
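+    # Minimal DeepSpeed ZeRO-2 config: shard optimizer state and gradients, offload the
+    # optimizer to CPU, and keep parameters on GPU.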
+    ds_config = {
+        "train_micro_batch_size_per_gpu": micro_bs,
+        "train_batch_size": eff_bs,
+        "gradient_accumulation_steps": accum_steps,
+        "bf16": {"enabled": True},
+        "zero_optimization": {
+            "stage": 2,
+            "offload_optimizer": {"device": "cpu"},
+            "offload_param": {"device": "none"}
+        },
+        "gradient_clipping": 1.0,
+    }
+    accelerator = Accelerator(mixed_precision="bf16",
+                              gradient_accumulation_steps=accum_steps,
+                              deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=ds_config))
+    print = accelerator.print
+
+    model_path = args.model_path or f"/volume/pt-train/models/{args.model_name}"
+    config = AutoConfig.from_pretrained(model_path)
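+    # The KV cache is useless during training and incompatible with gradient checkpointing.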
+    config.use_cache = False
+    model = AutoModelForCausalLM.from_pretrained(model_path, config=config)
+    model.gradient_checkpointing_enable()
+    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # Initialize the GEE processor and loss function
+    gee_processor = GEEProcessor(tokenizer)
+    gee_loss_fn = GEELoss(lambda_weight=args.lambda_weight, use_l1=args.use_l1)
+
+    if accelerator.is_main_process:
+        wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args))
+
+    # Prepare data
+    if args.use_test_data:
+        print("Using synthetic test data...")
+        train_data = gee_processor.create_test_data(num_samples=200)
+    else:
+        print("Using real data...")
+        train_data = gee_processor.prepare_gee_data(
+            args.train_data,
+            balance=args.balance_dataset,
+            target_size=args.target_size
+        )
+
+    train_loader = DataLoader(
+        GEEDataset(train_data),
+        batch_size=micro_bs,
+        shuffle=True,
+        collate_fn=custom_collate
+    )
+
+    optimizer = AdamW(model.parameters(), lr=lr)
+    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
+
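+    # Remember the first observed entropy gap; optional annealing scales lambda against it.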
+    initial_entropy_gap = None
+    model.train()
+
+    for step, batch in enumerate(train_loader, start=1):
+        if step > args.max_steps:
+            print(f"Reached max steps ({args.max_steps}); stopping training.")
+            break
+
+        with accelerator.accumulate(model):
+            # Tokenize prompts (left padding keeps generation flush with the sequence end)
+            inputs = tokenizer(
+                batch["input"],
+                return_tensors="pt",
+                padding="longest",
+                truncation=True,
+                max_length=2048
+            ).to(accelerator.device)
+
+            # Sample responses without tracking gradients
+            with torch.no_grad():
+                gen_ids = accelerator.unwrap_model(model).generate(
+                    **inputs,
+                    max_new_tokens=512,
+                    do_sample=True,
+                    top_p=0.95,
+                    temperature=args.sample_temp,
+                    synced_gpus=True,
+                    repetition_penalty=1.15,
+                    pad_token_id=tokenizer.pad_token_id,
+                    use_cache=False
+                )
+
+            # Assemble the full prompt+completion sequence, capped at 4096 tokens
+            # (generate() already returns the prompt followed by the completion)
+            seq = torch.cat([inputs.input_ids, gen_ids[:, inputs.input_ids.shape[1]:]], dim=1)[:, :4096]
+            pad_mask = seq.ne(tokenizer.pad_token_id)
+            # With left padding, non-pad tokens in the prompt region give the true prompt lengths
+            prompt_lengths = pad_mask[:, :inputs.input_ids.shape[1]].sum(-1)
+
+            # Forward pass with gradients: per-token entropy, then per-sample entropy
+            logits = model(seq, attention_mask=pad_mask).logits
+            H_tok = gee_loss_fn.compute_token_entropy(logits, pad_mask)
+            H_i = gee_loss_fn.compute_sample_entropy(H_tok, prompt_lengths)
+
+            # Prepare gender labels
+            gender_labels = torch.tensor([
+                gender_to_label(g) for g in batch["gender"]
+            ], device=accelerator.device)
+
+            # Compute the GEE loss
+            loss, metrics = gee_loss_fn.compute_gee_loss(H_i, gender_labels)
+
+            # Optional automatic annealing: shrink lambda as the entropy gap closes
+            if args.auto_anneal and initial_entropy_gap is None:
+                initial_entropy_gap = metrics['entropy_gap']
+
+            if args.auto_anneal and initial_entropy_gap > 0:
+                current_gap = metrics['entropy_gap']
+                anneal_factor = current_gap / initial_entropy_gap
+                new_lambda = args.lambda_weight * anneal_factor
+                gee_loss_fn.update_lambda(new_lambda)
+                metrics['lambda_weight'] = new_lambda
+
+            # Backward pass and optimizer step
+            accelerator.backward(loss)
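+            # Note: under DeepSpeed, clipping is typically driven by "gradient_clipping"
+            # in the config; Accelerate defers to it, so this call is a safe no-op there.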
+            accelerator.clip_grad_norm_(model.parameters(), 1.0)
+            optimizer.step()
+            optimizer.zero_grad()
+
+        # Logging
+        if accelerator.is_main_process:
+            if step % args.log_steps == 0:
+                print(f"Step {step} | loss={loss.item():.6f} | "
+                      f"entropy_gap={metrics['entropy_gap']:.6f} | "
+                      f"H_male={metrics['H_male']:.6f} | "
+                      f"H_female={metrics['H_female']:.6f}")
+                wandb.log({"step": step, **metrics})
+
+            if step % args.save_steps == 0:
+                ckpt = Path(save_root) / f"step_{step}"
+                ckpt.mkdir(parents=True, exist_ok=True)
+                accelerator.unwrap_model(model).save_pretrained(ckpt, safe_serialization=True)
+                tokenizer.save_pretrained(ckpt)
+                print(f"Checkpoint saved to {ckpt}")
+
+    if accelerator.is_main_process:
+        final = Path(save_root) / "final"
+        final.mkdir(parents=True, exist_ok=True)
+        accelerator.unwrap_model(model).save_pretrained(final, safe_serialization=True)
+        tokenizer.save_pretrained(final)
+        print(f"Final checkpoint saved to {final}")
+        wandb.finish()
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file