diff options
Diffstat (limited to 'train_gee.py')
| -rw-r--r-- | train_gee.py | 242 |
1 file changed, 242 insertions, 0 deletions
diff --git a/train_gee.py b/train_gee.py
new file mode 100644
index 0000000..2e62fee
--- /dev/null
+++ b/train_gee.py
@@ -0,0 +1,242 @@
+import argparse
+import os
+import random
+import time
+from pathlib import Path
+
+import psutil
+import torch
+import torch.nn.functional as F
+from torch.optim import AdamW
+import pandas as pd
+import numpy as np
+from torch.utils.data import Dataset, DataLoader
+
+import wandb
+
+from accelerate import Accelerator, DeepSpeedPlugin
+from accelerate.utils import set_seed
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
+
+# Import custom modules
+import sys
+sys.path.append('.')
+from dataset.gee_processor import GEEProcessor
+from losses.gee_loss import GEELoss, gender_to_label
+
+os.environ.setdefault("NCCL_TIMEOUT", "2700")
+os.environ.setdefault("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "2700")
+
+class GEEDataset(Dataset):
+    def __init__(self, data):
+        self.data = data
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        return self.data[idx]
+
+def custom_collate(batch):
+    return {
+        "input": [item["input"] for item in batch],
+        "gender": [item["gender"] for item in batch]
+    }
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    # GEE-specific arguments
+    parser.add_argument('--lambda_weight', type=float, default=3.0, help='GEE lambda weight')
+    parser.add_argument('--use_l1', action='store_true', help='Use L1 loss instead of L2')
+    parser.add_argument('--auto_anneal', action='store_true', help='Use automatic annealing')
+    parser.add_argument('--bias_eval_steps', type=int, default=10, help='Bias evaluation frequency')
+    parser.add_argument('--balance_dataset', action='store_true', default=True, help='Balance dataset by gender')
+    parser.add_argument('--target_size', type=int, default=None, help='Target dataset size for balancing')
+
+    # Original arguments (retained)
+    parser.add_argument('--model_name', type=str, default='Qwen2.5-Math-7B', help='Model name')
+    parser.add_argument('--model_path', type=str, default=None, help='Local model path')
+    parser.add_argument('--train_data', type=str, default='dataset/1shot_rlvr/pi1_r1280.parquet', help='Training data file path')
+    parser.add_argument('--save_root', type=str, default=None, help='Checkpoint save root directory')
+    parser.add_argument('--effective_batch', type=int, default=64, help='Global batch size')
+    parser.add_argument('--micro_batch_size', type=int, default=2, help='Micro batch size')
+    parser.add_argument('--temperature', type=float, default=0.5, help='Temperature coefficient')
+    parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
+    parser.add_argument('--log_steps', type=int, default=1, help='Logging step interval')
+    parser.add_argument('--save_steps', type=int, default=1, help='Checkpoint saving step interval')
+    parser.add_argument('--max_steps', type=int, default=50, help='Maximum training steps')
+    parser.add_argument('--sample_temp', type=float, default=0.5, help='Generation temperature parameter')
+    parser.add_argument('--run_name', type=str, default='one_shot_gee', help='Experiment run name')
+    parser.add_argument('--wandb_project', type=str, default='one-shot-gee', help='W&B project name')
+    parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name')
+    parser.add_argument('--seed', type=int, default=15, help='Random seed')
+    parser.add_argument('--use_test_data', action='store_true', help='Use synthetic test data instead of real data')
+    return parser.parse_args()
+
+def apply_chat_template(tokenizer, problem: str) -> str:
+    return tokenizer.apply_chat_template(
+        [{"role": "user", "content": problem}],
+        tokenize=False, add_generation_prompt=True
+    )
+
+def main():
+    args = parse_args()
+    set_seed(args.seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+    world_size = int(os.getenv("WORLD_SIZE", "1"))
+    micro_bs = args.micro_batch_size
+    eff_bs = args.effective_batch
+    accum_steps = max(1, eff_bs // (micro_bs * world_size))
+    temp = args.temperature
+    lr = args.learning_rate
+
+    save_root = args.save_root or (f"checkpoints/{args.model_name}/{args.run_name}" if args.run_name else f"checkpoints/{args.model_name}")
+    ds_config = {
+        "train_micro_batch_size_per_gpu": micro_bs,
+        "train_batch_size": eff_bs,
+        "gradient_accumulation_steps": accum_steps,
+        "bf16": {"enabled": True},
+        "zero_optimization": {
+            "stage": 2,
+            "offload_optimizer": {"device": "cpu"},
+            "offload_param": {"device": "none"}
+        },
+        "gradient_clipping": 1.0,
+    }
+    accelerator = Accelerator(mixed_precision="bf16",
+                              gradient_accumulation_steps=accum_steps,
+                              deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=ds_config))
+    print = accelerator.print
+
+    model_path = args.model_path or f"/volume/pt-train/models/{args.model_name}"
+    config = AutoConfig.from_pretrained(model_path)
+    config.use_cache = False
+    model = AutoModelForCausalLM.from_pretrained(model_path, config=config)
+    model.gradient_checkpointing_enable()
+    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
+    tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
+
+    # Initialize the GEE processor and loss function
+    gee_processor = GEEProcessor(tokenizer)
+    gee_loss_fn = GEELoss(lambda_weight=args.lambda_weight, use_l1=args.use_l1)
+
+    if accelerator.is_main_process:
+        wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args))
+
+    # Prepare data
+    if args.use_test_data:
+        print("Using synthetic test data...")
+        train_data = gee_processor.create_test_data(num_samples=200)
+    else:
+        print("Using real data...")
+        train_data = gee_processor.prepare_gee_data(
+            args.train_data,
+            balance=args.balance_dataset,
+            target_size=args.target_size
+        )
+
+    train_loader = DataLoader(
+        GEEDataset(train_data),
+        batch_size=micro_bs,
+        shuffle=True,
+        collate_fn=custom_collate
+    )
+
+    optimizer = AdamW(model.parameters(), lr=lr)
+    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
+
+    initial_entropy_gap = None
+    model.train()
+
+    for step, batch in enumerate(train_loader, start=1):
+        if step > args.max_steps:
+            print(f"Reached max steps ({args.max_steps}); stopping training.")
+            break
+
+        with accelerator.accumulate(model):
+            # Prepare inputs
+            inputs = tokenizer(
+                batch["input"],
+                return_tensors="pt",
+                padding="longest",
+                truncation=True,
+                max_length=2048
+            ).to(accelerator.device)
+
+            # Generate responses
+            with torch.no_grad():
+                gen_ids = accelerator.unwrap_model(model).generate(
+                    **inputs,
+                    max_new_tokens=512,
+                    do_sample=True,
+                    top_p=0.95,
+                    temperature=args.sample_temp,
+                    synced_gpus=True,
+                    repetition_penalty=1.15,
+                    pad_token_id=tokenizer.pad_token_id,
+                    use_cache=False
+                )
+
+            # Build the full sequence (prompt + generation), truncated to 4096 tokens
+            seq = torch.cat([inputs.input_ids, gen_ids[:, inputs.input_ids.shape[1]:]], dim=1)[:, :4096]
+            pad_mask = seq.ne(tokenizer.pad_token_id)
+            prompt_lengths = pad_mask[:, :inputs.input_ids.shape[1]].sum(-1)
+
+            # Compute logits and entropy
+            logits = model(seq, attention_mask=pad_mask).logits
+            H_tok = gee_loss_fn.compute_token_entropy(logits, pad_mask)
+            H_i = gee_loss_fn.compute_sample_entropy(H_tok, prompt_lengths)
+
+            # Prepare gender labels
+            gender_labels = torch.tensor([
+                gender_to_label(g) for g in batch["gender"]
+            ], device=accelerator.device)
+
+            # Compute the GEE loss
+            loss, metrics = gee_loss_fn.compute_gee_loss(H_i, gender_labels)
+
+            # Automatic annealing (optional)
+            if args.auto_anneal and initial_entropy_gap is None:
+                initial_entropy_gap = metrics['entropy_gap']
+
+            if args.auto_anneal and initial_entropy_gap > 0:
+                current_gap = metrics['entropy_gap']
+                anneal_factor = current_gap / initial_entropy_gap
+                new_lambda = args.lambda_weight * anneal_factor
+                gee_loss_fn.update_lambda(new_lambda)
+                metrics['lambda_weight'] = new_lambda
+
+            # Backward pass
+            accelerator.backward(loss)
+            accelerator.clip_grad_norm_(model.parameters(), 1.0)
+            optimizer.step()
+            optimizer.zero_grad()
+
+        # Logging
+        if accelerator.is_main_process:
+            if step % args.log_steps == 0:
+                print(f"Step {step} | loss={loss.item():.6f} | "
+                      f"entropy_gap={metrics['entropy_gap']:.6f} | "
+                      f"H_male={metrics['H_male']:.6f} | "
+                      f"H_female={metrics['H_female']:.6f}")
+                wandb.log({"step": step, **metrics})
+
+            if step % args.save_steps == 0:
+                ckpt = Path(save_root) / f"step_{step}"
+                ckpt.mkdir(parents=True, exist_ok=True)
+                accelerator.unwrap_model(model).save_pretrained(ckpt, safe_serialization=True)
+                tokenizer.save_pretrained(ckpt)
+                print(f"Checkpoint saved to {ckpt}")
+
+    if accelerator.is_main_process:
+        final = Path(save_root) / "final"
+        final.mkdir(parents=True, exist_ok=True)
+        accelerator.unwrap_model(model).save_pretrained(final, safe_serialization=True)
+        tokenizer.save_pretrained(final)
+        print(f"Final checkpoint saved to {final}")
+        wandb.finish()
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
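
Note for readers of this commit: losses/gee_loss.py is imported above but is not part of the diff. Below is a minimal sketch of the interface train_gee.py relies on (GEELoss, its compute_* methods, update_lambda, and gender_to_label), reconstructed from the call sites only. The internals are assumptions, not the repository's implementation; in particular it assumes the objective is mean generation entropy plus a lambda-weighted penalty on the absolute male/female entropy gap, which is consistent with the auto-anneal logic treating entropy_gap as non-negative.

    # Sketch only: interface inferred from train_gee.py's call sites.
    # Everything in the bodies is an assumption, not the actual code.
    import torch
    import torch.nn.functional as F

    def gender_to_label(gender: str) -> int:
        # Assumed label convention: 0 = male, 1 = female.
        return 0 if str(gender).lower() == "male" else 1

    class GEELoss:
        def __init__(self, lambda_weight: float = 3.0, use_l1: bool = False):
            self.lambda_weight = lambda_weight
            self.use_l1 = use_l1

        def update_lambda(self, new_lambda: float) -> None:
            self.lambda_weight = new_lambda

        def compute_token_entropy(self, logits, pad_mask):
            # Shannon entropy of the next-token distribution at each
            # position, zeroed on padding. (Ignores the one-position
            # offset between logits and tokens for simplicity.)
            logp = F.log_softmax(logits.float(), dim=-1)
            H = -(logp.exp() * logp).sum(-1)        # (batch, seq)
            return H * pad_mask

        def compute_sample_entropy(self, H_tok, prompt_lengths):
            # Mean token entropy over each sample's generated part only.
            idx = torch.arange(H_tok.size(1), device=H_tok.device)
            gen_mask = idx.unsqueeze(0) >= prompt_lengths.unsqueeze(1)
            return (H_tok * gen_mask).sum(-1) / gen_mask.sum(-1).clamp(min=1)

        def compute_gee_loss(self, H_i, gender_labels):
            # Assumed objective: entropy minimization plus a group-gap
            # penalty (L1 or L2, per the --use_l1 flag).
            H_male = H_i[gender_labels == 0].mean()
            H_female = H_i[gender_labels == 1].mean()
            diff = H_male - H_female
            penalty = diff.abs() if self.use_l1 else diff.pow(2)
            loss = H_i.mean() + self.lambda_weight * penalty
            metrics = {
                "entropy_gap": diff.abs().item(),
                "H_male": H_male.item(),
                "H_female": H_female.item(),
            }
            return loss, metrics

The sketch assumes both gender groups appear in every micro-batch (a group mean over an empty selection would be NaN), which is presumably why the script defaults --balance_dataset to on.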

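dataset/gee_processor.py is likewise absent from this commit. Judging from GEEDataset and custom_collate, each training record is a dict with an "input" prompt string and a "gender" tag; the records below are invented for illustration. Since the DeepSpeed config is built inline from ds_config, a run needs no separate JSON file and could be launched with, e.g., `accelerate launch --num_processes 8 train_gee.py --auto_anneal` (process count illustrative).

    # Hypothetical records in the shape GEEDataset / custom_collate expect.
    # The real ones come from GEEProcessor.prepare_gee_data (not in this
    # commit); field values here are made up.
    train_data = [
        {"input": "Emma has 12 apples and buys 7 more. How many now?", "gender": "female"},
        {"input": "Liam has 12 apples and buys 7 more. How many now?", "gender": "male"},
    ]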