-rw-r--r--  README.md  |  36
-rw-r--r--  train.py   | 286
2 files changed, 189 insertions, 133 deletions
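Note on the batch arithmetic: the updated launch commands in the README diff below pass `--effective_batch 64` together with `--micro_batch_size auto`, and the rewritten `train.py` derives the gradient-accumulation steps from those values, rounding to the nearest achievable effective batch when the request does not divide evenly across GPUs and accumulation steps. A minimal sketch of that rounding, assuming a micro batch of 3 on 4 GPUs purely for illustration (these numbers are not repo defaults):

```python
# Mirrors the rounding performed by adjust_batch_config_for_deepspeed() in the diff below.
# micro_batch = 3 and world_size = 4 are illustrative assumptions.
effective_batch, micro_batch, world_size = 64, 3, 4

ideal = effective_batch / (micro_batch * world_size)   # 5.33... accumulation steps requested
down = max(1, int(ideal))                              # 5 steps -> effective batch 60
up = down + 1                                          # 6 steps -> effective batch 72
down_eff = micro_batch * down * world_size
up_eff = micro_batch * up * world_size
steps = up if abs(effective_batch - up_eff) < abs(effective_batch - down_eff) else down
print(steps, micro_batch * steps * world_size)         # -> 5 60, the closest achievable batch
```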
diff --git a/README.md b/README.md
@@ -7,7 +7,23 @@
 ### Reproducing One-shot EM Training (SOTA)
 
 ```bash
-accelerate launch train.py --lr 2e-5 --temperature 0.5 --bsz 64
+accelerate launch train.py \
+    --model_name Qwen2.5-Math-7B \
+    --model_path /path/to/Qwen2.5-Math-7B \
+    --train_data dataset/1shot_rlvr/pi1_r1280.parquet \
+    --eval_data dataset/1shot_rlvr/pi1_r1280.parquet \
+    --effective_batch 64 \
+    --micro_batch_size auto \
+    --temperature 0.5 \
+    --learning_rate 2e-5 \
+    --eval_steps 5 \
+    --eval_batch_size 4 \
+    --eval_size 10 \
+    --max_steps 1000 \
+    --log_steps 1 \
+    --save_steps 1 \
+    --run_name my_experiment \
+    --wandb_project entropy-maximization-ft
 ```
 
 ---
@@ -15,7 +31,23 @@ accelerate launch train.py --lr 2e-5 --temperature 0.5 --bsz 64
 ### Reproducing Multi-shot EM Training
 
 ```bash
-accelerate launch train.py --lr 2e-5 --temperature 0.5 --bsz 64 --data_path "dataset/numina/numina_00.parquet"
+accelerate launch train.py \
+    --model_name Qwen2.5-Math-7B \
+    --model_path /path/to/Qwen2.5-Math-7B \
+    --train_data dataset/numina/numina_00.parquet \
+    --eval_data dataset/numina/numina_01.parquet \
+    --effective_batch 64 \
+    --micro_batch_size auto \
+    --temperature 0.5 \
+    --learning_rate 2e-5 \
+    --eval_steps 5 \
+    --eval_batch_size 4 \
+    --eval_size 10 \
+    --max_steps 1000 \
+    --log_steps 1 \
+    --save_steps 1 \
+    --run_name multi_shot_experiment \
+    --wandb_project entropy-maximization-ft
 ```
 
 ---
diff --git a/train.py b/train.py
@@ -1,71 +1,54 @@
+#!/usr/bin/env python3
+"""
+train.py
+
+One-shot Entropy Minimization Fine-tuning Script
+"""
+
 import argparse
+import os
 import random
-import string
+import time
 from pathlib import Path
+
+import psutil
+import torch
 import torch.nn.functional as F
-import os, math, torch, pandas as pd
+import pandas as pd
 import numpy as np
 from torch.utils.data import Dataset, DataLoader
+
+import wandb
+
 from accelerate import Accelerator, DeepSpeedPlugin
 from accelerate.utils import set_seed
 from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AdamW
-
-parser = argparse.ArgumentParser()
-parser.add_argument("--temperature", type=float, default=0.5)
-parser.add_argument("--lr", type=float, default=2e-5)
-parser.add_argument("--bsz", type=int, default=64)
-parser.add_argument("--seed", type=int, default=15)
-parser.add_argument("--data_path", type=str, default="dataset/1shot_rlvr/pi1_r1280.parquet")
-args = parser.parse_args()
-
-set_seed(args.seed)
-torch.backends.cudnn.deterministic = True
-torch.backends.cudnn.benchmark = False
-
-temp_str = str(args.temperature).replace(".", "")
-lr_str = f"{args.lr:.0e}"
-bsz_str = str(args.bsz)
-save_root = f"checkpoints/qwen25math7b/t{temp_str}_lr{lr_str}_bsz{bsz_str}_seed{args.seed}"
-
-temperature = args.temperature
-learning_rate = args.lr
-batch_size = args.bsz
-micro_batch_size = 2
-world_size = int(os.environ.get("WORLD_SIZE", 1))
-accum_steps = max(1, batch_size // (micro_batch_size * world_size))
-
-DEEPSPEED_CONFIG = {
-    "train_micro_batch_size_per_gpu": micro_batch_size,
-    "train_batch_size": batch_size,
-    "gradient_accumulation_steps": accum_steps,
-    "bf16": {"enabled": True},
-    "zero_optimization": {
-        "stage": 2,
-        "offload_optimizer": {"device": "cpu"},
-        "offload_param": {"device": "none"}
-    },
-    "gradient_clipping": 1.0,
-}
-
-ds_plugin = DeepSpeedPlugin(hf_ds_config=DEEPSPEED_CONFIG)
-accelerator = Accelerator(
-    mixed_precision="bf16",
-    gradient_accumulation_steps=accum_steps,
-    deepspeed_plugin=ds_plugin,
-)
-print = accelerator.print
-
-model_name = "Qwen2.5-Math-7B"
-model_path = f"/volume/ailab4sci/models/{model_name}"
-config = AutoConfig.from_pretrained(model_path)
-config.use_cache = False
-model = AutoModelForCausalLM.from_pretrained(model_path, config=config)
-model.gradient_checkpointing_enable()
-
-tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
-if tokenizer.pad_token is None:
-    tokenizer.pad_token = tokenizer.eos_token
+# Set the NCCL timeout so distributed training does not time out (45 minutes)
+os.environ.setdefault("NCCL_TIMEOUT", "2700")  # 45 minutes
+os.environ.setdefault("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "2700")
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Entropy Minimization Fine-tuning")
+    parser.add_argument('--model_name', type=str, default='Qwen2.5-Math-7B', help='Model name')
+    parser.add_argument('--model_path', type=str, default=None, help='Local model path')
+    parser.add_argument('--train_data', type=str, default='', help='Training data file path')
+    parser.add_argument('--save_root', type=str, default=None, help='Checkpoint save root directory')
+    parser.add_argument('--effective_batch', type=int, default=64, help='Global batch size')
+    parser.add_argument('--micro_batch_size', type=str, default='auto', help='Micro batch size or "auto"')
+    parser.add_argument('--temperature', type=float, default=0.9, help='Temperature coefficient')
+    parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
+    parser.add_argument('--log_steps', type=int, default=1, help='Logging step interval')
+    parser.add_argument('--save_steps', type=int, default=1, help='Checkpoint saving step interval')
+    parser.add_argument('--max_steps', type=int, default=1000, help='Maximum training steps')
+    parser.add_argument('--sample_temp', type=float, default=None, help='Generation temperature parameter')
+    parser.add_argument('--run_name', type=str, default=None, help='Experiment run name')
+    parser.add_argument('--wandb_project', type=str, default='entropy-maximization-ft', help='W&B project name')
+    parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name')
+    parser.add_argument('--enable_self_distill', action='store_true', help='Enable self-distillation MSE constraint')
+    parser.add_argument('--self_distill_gamma', type=float, default=0.0, help='Self-distillation MSE weight γ')
+    parser.add_argument('--seed', type=int, default=42, help='Random seed')
+    return parser.parse_args()
 
 class FTDataset(Dataset):
     def __init__(self, rows):
         self.rows = rows
@@ -75,83 +58,124 @@ class FTDataset(Dataset):
 def custom_collate(batch):
     return {"input": [item["input"] for item in batch]}
 
-def apply_chat_template(problem: str) -> str:
+def get_optimal_micro_batch_size(model_name: str, world_size: int = 1) -> int:
+    model_configs = {
+        "1.5B": {"base_batch": 4, "keywords": ["1.5B", "1B"]},
+        "2B": {"base_batch": 4, "keywords": ["2B"]},
+        "3B": {"base_batch": 2, "keywords": ["3B"]},
+        "7B": {"base_batch": 2, "keywords": ["7B"]},
+        "8B+": {"base_batch": 1, "keywords": ["8B", "9B", "10B", "11B", "12B", "13B", "14B"]},
+    }
+    model_name_upper = model_name.upper()
+    detected = next((cfg for cfg in model_configs.values() if any(k in model_name_upper for k in cfg["keywords"])), None)
+    base_batch = detected["base_batch"] if detected else 2
+    if world_size > 1:
+        return min(base_batch + 1, int(base_batch * 1.5))
+    return base_batch
+
+def adjust_batch_config_for_deepspeed(effective_batch: int, micro_batch_size: int, world_size: int):
+    ideal = effective_batch / (micro_batch_size * world_size)
+    if ideal.is_integer():
+        return effective_batch, micro_batch_size, int(ideal)
+    down = max(1, int(ideal))
+    up = down + 1
+    down_eff = micro_batch_size * down * world_size
+    up_eff = micro_batch_size * up * world_size
+    use_up = abs(effective_batch - up_eff) < abs(effective_batch - down_eff)
+    steps = up if use_up else down
+    eff = up_eff if use_up else down_eff
+    return eff, micro_batch_size, steps
+
+def apply_chat_template(tokenizer, problem: str) -> str:
     return tokenizer.apply_chat_template(
         [{"role": "user", "content": problem}],
         tokenize=False,
         add_generation_prompt=True
     )
 
-df = pd.read_parquet(args.data_path)
-data = [{"input": apply_chat_template(p)} for p in df["problem"].dropna().tolist()]
-dataset = FTDataset(data)
-data_loader = DataLoader(
-    dataset,
-    batch_size=micro_batch_size,
-    shuffle=True,
-    collate_fn=custom_collate,
-)
-
-optimizer = AdamW(model.parameters(), lr=learning_rate)
-model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)
-
-model.train()
-for step, batch in enumerate(data_loader, start=1):
-    if step > 30:
-        break
-
-    with accelerator.accumulate(model):
-        enc = tokenizer(
-            batch["input"],
-            return_tensors="pt",
-            padding="longest",
-            truncation=True,
-            max_length=2048,
-        ).to(accelerator.device)
-
-        with torch.no_grad():
-            gen_ids = accelerator.unwrap_model(model).generate(
-                **enc,
-                max_new_tokens=2048,
-                do_sample=True,
-                top_p=0.95,
-                temperature=temperature,
-                synced_gpus=True,
-                repetition_penalty=1.15,
-                pad_token_id=151643,
-                use_cache=False
-            )
-
-        seq = torch.cat(
-            [enc.input_ids, gen_ids[:, enc.input_ids.shape[1]:]],
-            dim=1
-        )[:, :4096]
-
-        pad_mask = seq.ne(tokenizer.pad_token_id)
-        prompt_len = pad_mask[:, :enc.input_ids.shape[1]].sum(-1)
-        token_idx = torch.arange(seq.size(1), device=seq.device)
-        gen_mask = (token_idx.unsqueeze(0) >= prompt_len.unsqueeze(1)) & pad_mask
-
-        logits = model(seq, attention_mask=pad_mask).logits
-        probs = F.softmax(logits / temperature, dim=-1)
-        H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1)
-        loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
-
-        accelerator.backward(loss)
-        accelerator.clip_grad_norm_(model.parameters(), 1.0)
-        optimizer.step(); optimizer.zero_grad()
-        print(f"Step {step} | loss={loss.item():.8f}")
+def main():
+    args = parse_args()
+    set_seed(args.seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+    world_size = int(os.getenv("WORLD_SIZE", "1"))
+    micro_bs = int(args.micro_batch_size) if args.micro_batch_size != 'auto' else get_optimal_micro_batch_size(args.model_name, world_size)
+    eff_bs, micro_bs, accum_steps = adjust_batch_config_for_deepspeed(args.effective_batch, micro_bs, world_size)
+
+    temp = args.temperature
+    lr = args.learning_rate
+
+    save_root = args.save_root or (f"checkpoints/{args.model_name}/{args.run_name}" if args.run_name else f"checkpoints/{args.model_name}")
+
+    ds_config = {
+        "train_micro_batch_size_per_gpu": micro_bs,
+        "train_batch_size": eff_bs,
+        "gradient_accumulation_steps": accum_steps,
+        "bf16": {"enabled": True},
+        "zero_optimization": {"stage": 2, "offload_optimizer": {"device": "cpu"}, "offload_param": {"device": "none"}},
+        "gradient_clipping": 1.0,
+    }
+    accelerator = Accelerator(mixed_precision="bf16", gradient_accumulation_steps=accum_steps, deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=ds_config))
+    print = accelerator.print
+
+    model_path = args.model_path or f"/volume/ailab4sci/models/{args.model_name}"
+    config = AutoConfig.from_pretrained(model_path); config.use_cache = False
+    model = AutoModelForCausalLM.from_pretrained(model_path, config=config); model.gradient_checkpointing_enable()
+    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
+    if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
+
+    if accelerator.is_main_process:
+        wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args))
+
+    df = pd.read_parquet(args.train_data)
+    train_data = [{"input": apply_chat_template(tokenizer, p)} for p in df["problem"].dropna().tolist()]
+    train_loader = DataLoader(FTDataset(train_data), batch_size=micro_bs, shuffle=True, collate_fn=custom_collate)
+
+    optimizer = AdamW(model.parameters(), lr=lr)
+    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
+
+    prev_logits = None
+    model.train()
+    for step, batch in enumerate(train_loader, start=1):
+        if step > args.max_steps:
+            print(f"Reached the maximum number of steps ({args.max_steps}), stopping training"); break
+        with accelerator.accumulate(model):
+            enc = tokenizer(batch["input"], return_tensors="pt", padding="longest", truncation=True, max_length=2048).to(accelerator.device)
+            with torch.no_grad():
+                gen_ids = accelerator.unwrap_model(model).generate(**enc, max_new_tokens=512, do_sample=True, top_p=0.95, temperature=args.sample_temp if args.sample_temp is not None else temp, synced_gpus=True, pad_token_id=tokenizer.pad_token_id, use_cache=False)
+            seq = torch.cat([enc.input_ids, gen_ids[:, enc.input_ids.shape[1]:]], dim=1)[:, :4096]
+            pad_mask = seq.ne(tokenizer.pad_token_id); prompt_len = pad_mask[:, :enc.input_ids.shape[1]].sum(-1)
+            token_idx = torch.arange(seq.size(1), device=seq.device)
+            gen_mask = (token_idx.unsqueeze(0) >= prompt_len.unsqueeze(1)) & pad_mask
+
+            logits = model(seq, attention_mask=pad_mask).logits
+            probs = F.softmax(logits / temp, dim=-1)
+            H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1)
+            loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
+
+            if args.enable_self_distill and prev_logits is not None:
+                loss = loss + args.self_distill_gamma * F.mse_loss(logits, prev_logits)
+
+            prev_logits = logits.detach()
+            accelerator.backward(loss); accelerator.clip_grad_norm_(model.parameters(), 1.0)
+            optimizer.step(); optimizer.zero_grad()
+
+        if accelerator.is_main_process:
+            if step % args.log_steps == 0:
+                print(f"Step {step} | loss={loss.item():.6f}")
+                wandb.log({"step": step, "loss": loss.item()})
+            # Save a checkpoint every save_steps steps
+            if step % args.save_steps == 0:
+                ckpt = Path(save_root) / f"step_{step}"
+                ckpt.mkdir(parents=True, exist_ok=True)
+                accelerator.unwrap_model(model).save_pretrained(ckpt, safe_serialization=True)
+                tokenizer.save_pretrained(ckpt)
+                print(f"Checkpoint saved to {ckpt}")
 
     if accelerator.is_main_process:
-        ckpt_dir = Path(save_root) / f"step_{step}"
-        ckpt_dir.mkdir(parents=True, exist_ok=True)
-        accelerator.unwrap_model(model).save_pretrained(ckpt_dir, safe_serialization=True)
-        tokenizer.save_pretrained(ckpt_dir)
-        accelerator.wait_for_everyone()
-        print(f"Checkpoint saved to {ckpt_dir}")
-
-if accelerator.is_main_process:
-    final_dir = Path(save_root) / "final"
-    final_dir.mkdir(parents=True, exist_ok=True)
-    accelerator.unwrap_model(model).save_pretrained(final_dir, safe_serialization=True)
-    tokenizer.save_pretrained(final_dir)
-    print(f"Final checkpoint saved to {final_dir}")
\ No newline at end of file
+        final = Path(save_root) / "final"; final.mkdir(parents=True, exist_ok=True)
+        accelerator.unwrap_model(model).save_pretrained(final, safe_serialization=True); tokenizer.save_pretrained(final)
+        print(f"Final checkpoint saved to {final}"); wandb.finish()
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
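For reference, the objective the new training loop minimizes is the mean token-level entropy of the model's own sampled continuations, computed from the logits at the loss temperature. A minimal self-contained sketch of that loss with toy tensor shapes; `logits`, `gen_mask`, and the temperature value here are illustrative stand-ins rather than the script's runtime values:

```python
import torch
import torch.nn.functional as F

# Toy stand-ins: batch of 2 sequences, 5 positions, vocabulary of 11.
logits = torch.randn(2, 5, 11)
gen_mask = torch.tensor([[0, 0, 1, 1, 1],
                         [0, 1, 1, 1, 0]], dtype=torch.bool)  # generated, non-pad positions
temperature = 0.5

probs = F.softmax(logits / temperature, dim=-1)               # per-position distribution over the vocabulary
token_entropy = -(probs * torch.log(probs + 1e-12)).sum(-1)   # entropy per position, shape (2, 5)
loss = (token_entropy * gen_mask).sum() / gen_mask.sum().clamp_min(1)
print(loss)  # scalar driven down during training: mean entropy over generated tokens
```

In train.py the same expression is evaluated over the concatenated prompt-plus-generation sequence, with the mask built from prompt lengths and padding and with `--temperature` as the divisor.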
