 README.md |  36
 train.py  | 286
 2 files changed, 189 insertions, 133 deletions
diff --git a/README.md b/README.md
index 2507b74..b5b6f62 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,23 @@
### Reproducing One-shot EM Training (SOTA)
```bash
-accelerate launch train.py --lr 2e-5 --temperature 0.5 --bsz 64
+accelerate launch train.py \
+ --model_name Qwen2.5-Math-7B \
+ --model_path /path/to/Qwen2.5-Math-7B \
+ --train_data dataset/1shot_rlvr/pi1_r1280.parquet \
+ --eval_data dataset/1shot_rlvr/pi1_r1280.parquet \
+ --effective_batch 64 \
+ --micro_batch_size auto \
+ --temperature 0.5 \
+ --learning_rate 2e-5 \
+ --eval_steps 5 \
+ --eval_batch_size 4 \
+ --eval_size 10 \
+ --max_steps 1000 \
+ --log_steps 1 \
+ --save_steps 1 \
+ --run_name my_experiment \
+ --wandb_project entropy-maximization-ft
```
---
@@ -15,7 +31,23 @@ accelerate launch train.py --lr 2e-5 --temperature 0.5 --bsz 64
### Reproducing Multi-shot EM Training
```bash
-accelerate launch train.py --lr 2e-5 --temperature 0.5 --bsz 64 --data_path "dataset/numina/numina_00.parquet"
+accelerate launch train.py \
+ --model_name Qwen2.5-Math-7B \
+ --model_path /path/to/Qwen2.5-Math-7B \
+ --train_data dataset/numina/numina_00.parquet \
+ --eval_data dataset/numina/numina_01.parquet \
+ --effective_batch 64 \
+ --micro_batch_size auto \
+ --temperature 0.5 \
+ --learning_rate 2e-5 \
+ --eval_steps 5 \
+ --eval_batch_size 4 \
+ --eval_size 10 \
+ --max_steps 1000 \
+ --log_steps 1 \
+ --save_steps 1 \
+ --run_name multi_shot_experiment \
+ --wandb_project entropy-maximization-ft
```
---
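Note on the new flags: `--effective_batch` is the global batch size, and `--micro_batch_size auto` lets the script pick a per-GPU micro batch from the model name; gradient-accumulation steps then make up the difference (see `adjust_batch_config_for_deepspeed` in the train.py diff below). A minimal sketch of that arithmetic, with illustrative numbers:

```python
# Sketch only: mirrors the batch-size bookkeeping train.py performs; names here are illustrative.
def accumulation_steps(effective_batch: int, micro_batch: int, world_size: int) -> int:
    """Gradient-accumulation steps so that micro_batch * world_size * steps ~= effective_batch."""
    return max(1, round(effective_batch / (micro_batch * world_size)))

# e.g. --effective_batch 64 with a per-GPU micro batch of 2 on 8 GPUs -> 4 accumulation steps
assert accumulation_steps(64, 2, 8) == 4
```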
diff --git a/train.py b/train.py
index 3092e10..1033f7c 100644
--- a/train.py
+++ b/train.py
@@ -1,71 +1,54 @@
+#!/usr/bin/env python3
+"""
+train.py
+
+One-shot Entropy Minimization Fine-tuning Script
+"""
+
import argparse
+import os
import random
-import string
+import time
from pathlib import Path
+
+import psutil
+import torch
import torch.nn.functional as F
-import os, math, torch, pandas as pd
+import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
+
+import wandb
+
from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.utils import set_seed
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AdamW
-
-parser = argparse.ArgumentParser()
-parser.add_argument("--temperature", type=float, default=0.5)
-parser.add_argument("--lr", type=float, default=2e-5)
-parser.add_argument("--bsz", type=int, default=64)
-parser.add_argument("--seed", type=int, default=15)
-parser.add_argument("--data_path", type=str, default="dataset/1shot_rlvr/pi1_r1280.parquet")
-args = parser.parse_args()
-
-set_seed(args.seed)
-torch.backends.cudnn.deterministic = True
-torch.backends.cudnn.benchmark = False
-
-temp_str = str(args.temperature).replace(".", "")
-lr_str = f"{args.lr:.0e}"
-bsz_str = str(args.bsz)
-save_root = f"checkpoints/qwen25math7b/t{temp_str}_lr{lr_str}_bsz{bsz_str}_seed{args.seed}"
-
-temperature = args.temperature
-learning_rate = args.lr
-batch_size = args.bsz
-micro_batch_size = 2
-world_size = int(os.environ.get("WORLD_SIZE", 1))
-accum_steps = max(1, batch_size // (micro_batch_size * world_size))
-
-DEEPSPEED_CONFIG = {
- "train_micro_batch_size_per_gpu": micro_batch_size,
- "train_batch_size": batch_size,
- "gradient_accumulation_steps": accum_steps,
- "bf16": {"enabled": True},
- "zero_optimization": {
- "stage": 2,
- "offload_optimizer": {"device": "cpu"},
- "offload_param": {"device": "none"}
- },
- "gradient_clipping": 1.0,
-}
-
-ds_plugin = DeepSpeedPlugin(hf_ds_config=DEEPSPEED_CONFIG)
-accelerator = Accelerator(
- mixed_precision="bf16",
- gradient_accumulation_steps=accum_steps,
- deepspeed_plugin=ds_plugin,
-)
-print = accelerator.print
-
-model_name = "Qwen2.5-Math-7B"
-model_path = f"/volume/ailab4sci/models/{model_name}"
-config = AutoConfig.from_pretrained(model_path)
-config.use_cache = False
-model = AutoModelForCausalLM.from_pretrained(model_path, config=config)
-model.gradient_checkpointing_enable()
-
-tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
-if tokenizer.pad_token is None:
- tokenizer.pad_token = tokenizer.eos_token
+# Set the NCCL timeout to avoid distributed-training timeouts (45 minutes).
+os.environ.setdefault("NCCL_TIMEOUT", "2700")  # 45 minutes
+os.environ.setdefault("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "2700")
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Entropy Minimization Fine-tuning")
+ parser.add_argument('--model_name', type=str, default='Qwen2.5-Math-7B', help='Model name')
+ parser.add_argument('--model_path', type=str, default=None, help='Local model path')
+ parser.add_argument('--train_data', type=str, default='', help='Training data file path')
+ parser.add_argument('--save_root', type=str, default=None, help='Checkpoint save root directory')
+ parser.add_argument('--effective_batch', type=int, default=64, help='Global batch size')
+ parser.add_argument('--micro_batch_size', type=str, default='auto', help='Micro batch size or "auto"')
+ parser.add_argument('--temperature', type=float, default=0.9, help='Temperature coefficient')
+ parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
+ parser.add_argument('--log_steps', type=int, default=1, help='Logging step interval')
+ parser.add_argument('--save_steps', type=int, default=1, help='Checkpoint saving step interval')
+ parser.add_argument('--max_steps', type=int, default=1000, help='Maximum training steps')
+ parser.add_argument('--sample_temp', type=float, default=None, help='Generation temperature parameter')
+ parser.add_argument('--run_name', type=str, default=None, help='Experiment run name')
+ parser.add_argument('--wandb_project', type=str, default='entropy-maximization-ft', help='W&B project name')
+ parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name')
+ parser.add_argument('--enable_self_distill', action='store_true', help='Enable self-distillation MSE constraint')
+ parser.add_argument('--self_distill_gamma', type=float, default=0.0, help='Self-distillation MSE weight γ')
+    # Evaluation options referenced by the README launch commands
+    parser.add_argument('--eval_data', type=str, default='', help='Evaluation data file path')
+    parser.add_argument('--eval_steps', type=int, default=5, help='Evaluation step interval')
+    parser.add_argument('--eval_batch_size', type=int, default=4, help='Evaluation batch size')
+    parser.add_argument('--eval_size', type=int, default=10, help='Number of evaluation samples')
+    parser.add_argument('--seed', type=int, default=42, help='Random seed')
+    return parser.parse_args()
class FTDataset(Dataset):
def __init__(self, rows): self.rows = rows
@@ -75,83 +58,124 @@ class FTDataset(Dataset):
def custom_collate(batch):
return {"input": [item["input"] for item in batch]}
-def apply_chat_template(problem: str) -> str:
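+# Heuristic default micro-batch size per GPU, keyed off the parameter count in the model name.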
+def get_optimal_micro_batch_size(model_name: str, world_size: int = 1) -> int:
+ model_configs = {
+ "1.5B": {"base_batch": 4, "keywords": ["1.5B", "1B"]},
+ "2B": {"base_batch": 4, "keywords": ["2B"]},
+ "3B": {"base_batch": 2, "keywords": ["3B"]},
+ "7B": {"base_batch": 2, "keywords": ["7B"]},
+ "8B+": {"base_batch": 1, "keywords": ["8B", "9B", "10B", "11B", "12B", "13B", "14B"]},
+ }
+ model_name_upper = model_name.upper()
+ detected = next((cfg for cfg in model_configs.values() if any(k in model_name_upper for k in cfg["keywords"])), None)
+ base_batch = detected["base_batch"] if detected else 2
+ if world_size > 1:
+ return min(base_batch + 1, int(base_batch * 1.5))
+ return base_batch
+
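+# DeepSpeed requires train_batch_size == micro_batch * grad_accum_steps * world_size exactly,
+# so round the requested effective batch to the nearest achievable value
+# (e.g. effective 64 with micro batch 2 on 8 GPUs -> 4 accumulation steps).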
+def adjust_batch_config_for_deepspeed(effective_batch: int, micro_batch_size: int, world_size: int):
+ ideal = effective_batch / (micro_batch_size * world_size)
+ if ideal.is_integer():
+ return effective_batch, micro_batch_size, int(ideal)
+ down = max(1, int(ideal))
+ up = down + 1
+ down_eff = micro_batch_size * down * world_size
+ up_eff = micro_batch_size * up * world_size
+ use_up = abs(effective_batch - up_eff) < abs(effective_batch - down_eff)
+ steps = up if use_up else down
+ eff = up_eff if use_up else down_eff
+ return eff, micro_batch_size, steps
+
+def apply_chat_template(tokenizer, problem: str) -> str:
return tokenizer.apply_chat_template(
[{"role": "user", "content": problem}],
tokenize=False, add_generation_prompt=True
)
-df = pd.read_parquet(args.data_path)
-data = [{"input": apply_chat_template(p)} for p in df["problem"].dropna().tolist()]
-dataset = FTDataset(data)
-data_loader = DataLoader(
- dataset,
- batch_size=micro_batch_size,
- shuffle=True,
- collate_fn=custom_collate,
-)
-
-optimizer = AdamW(model.parameters(), lr=learning_rate)
-model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)
-
-model.train()
-for step, batch in enumerate(data_loader, start=1):
- if step > 30:
- break
-
- with accelerator.accumulate(model):
- enc = tokenizer(
- batch["input"],
- return_tensors="pt",
- padding="longest",
- truncation=True,
- max_length=2048,
- ).to(accelerator.device)
-
- with torch.no_grad():
- gen_ids = accelerator.unwrap_model(model).generate(
- **enc,
- max_new_tokens=2048,
- do_sample=True,
- top_p=0.95,
- temperature=temperature,
- synced_gpus=True,
- repetition_penalty=1.15,
- pad_token_id=151643,
- use_cache=False
- )
-
- seq = torch.cat(
- [enc.input_ids, gen_ids[:, enc.input_ids.shape[1]:]],
- dim=1
- )[:, :4096]
-
- pad_mask = seq.ne(tokenizer.pad_token_id)
- prompt_len = pad_mask[:, :enc.input_ids.shape[1]].sum(-1)
- token_idx = torch.arange(seq.size(1), device=seq.device)
- gen_mask = (token_idx.unsqueeze(0) >= prompt_len.unsqueeze(1)) & pad_mask
-
- logits = model(seq, attention_mask=pad_mask).logits
- probs = F.softmax(logits / temperature, dim=-1)
- H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1)
- loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
-
- accelerator.backward(loss)
- accelerator.clip_grad_norm_(model.parameters(), 1.0)
- optimizer.step(); optimizer.zero_grad()
- print(f"Step {step} | loss={loss.item():.8f}")
+def main():
+ args = parse_args()
+ set_seed(args.seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
+ micro_bs = int(args.micro_batch_size) if args.micro_batch_size != 'auto' else get_optimal_micro_batch_size(args.model_name, world_size)
+ eff_bs, micro_bs, accum_steps = adjust_batch_config_for_deepspeed(args.effective_batch, micro_bs, world_size)
+
+ temp = args.temperature
+ lr = args.learning_rate
+
+ save_root = args.save_root or (f"checkpoints/{args.model_name}/{args.run_name}" if args.run_name else f"checkpoints/{args.model_name}")
+
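+    # ZeRO stage 2 with optimizer state offloaded to CPU; bf16 end to end.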
+ ds_config = {
+ "train_micro_batch_size_per_gpu": micro_bs,
+ "train_batch_size": eff_bs,
+ "gradient_accumulation_steps": accum_steps,
+ "bf16": {"enabled": True},
+ "zero_optimization": {"stage": 2, "offload_optimizer": {"device": "cpu"}, "offload_param": {"device": "none"}},
+ "gradient_clipping": 1.0,
+ }
+ accelerator = Accelerator(mixed_precision="bf16", gradient_accumulation_steps=accum_steps, deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=ds_config))
+ print = accelerator.print
+
+ model_path = args.model_path or f"/volume/ailab4sci/models/{args.model_name}"
+ config = AutoConfig.from_pretrained(model_path); config.use_cache = False
+ model = AutoModelForCausalLM.from_pretrained(model_path, config=config); model.gradient_checkpointing_enable()
+ tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
+ if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
+
+ if accelerator.is_main_process:
+ wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args))
+
+ df = pd.read_parquet(args.train_data)
+ train_data = [{"input": apply_chat_template(tokenizer, p)} for p in df["problem"].dropna().tolist()]
+ train_loader = DataLoader(FTDataset(train_data), batch_size=micro_bs, shuffle=True, collate_fn=custom_collate)
+
+ optimizer = AdamW(model.parameters(), lr=lr)
+ model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
+
+ prev_logits = None
+ model.train()
+ for step, batch in enumerate(train_loader, start=1):
+ if step > args.max_steps:
+ print(f"达到最大步数 {args.max_steps},停止训练"); break
+ with accelerator.accumulate(model):
+ enc = tokenizer(batch["input"], return_tensors="pt", padding="longest", truncation=True, max_length=2048).to(accelerator.device)
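+            # Sample continuations from the current policy; gradients flow only through the
+            # scoring pass below, not through generation.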
+ with torch.no_grad():
+ gen_ids = accelerator.unwrap_model(model).generate(**enc, max_new_tokens=512, do_sample=True, top_p=0.95, temperature=args.sample_temp if args.sample_temp is not None else temp, synced_gpus=True, pad_token_id=tokenizer.pad_token_id, use_cache=False)
+ seq = torch.cat([enc.input_ids, gen_ids[:, enc.input_ids.shape[1]:]], dim=1)[:, :4096]
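+            # Build a mask selecting only the sampled continuation: positions at or beyond the
+            # prompt length that are not padding.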
+ pad_mask = seq.ne(tokenizer.pad_token_id); prompt_len = pad_mask[:, :enc.input_ids.shape[1]].sum(-1)
+ token_idx = torch.arange(seq.size(1), device=seq.device)
+ gen_mask = (token_idx.unsqueeze(0) >= prompt_len.unsqueeze(1)) & pad_mask
+
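+            # EM objective: mean token-level entropy of the temperature-scaled softmax,
+            # averaged over the generated tokens only.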
+ logits = model(seq, attention_mask=pad_mask).logits
+ probs = F.softmax(logits / temp, dim=-1)
+ H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1)
+ loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
+
+            # Optional self-distillation: MSE pull toward the previous step's logits, applied only
+            # when the cached logits match the current shape (sampled lengths and the final
+            # micro-batch size can differ between steps).
+            if args.enable_self_distill and prev_logits is not None and prev_logits.shape == logits.shape:
+                loss = loss + args.self_distill_gamma * F.mse_loss(logits, prev_logits)
+
+            prev_logits = logits.detach()
+ accelerator.backward(loss); accelerator.clip_grad_norm_(model.parameters(), 1.0)
+ optimizer.step(); optimizer.zero_grad()
+
+ if accelerator.is_main_process:
+ if step % args.log_steps == 0:
+ print(f"Step {step} | loss={loss.item():.6f}")
+ wandb.log({"step": step, "loss": loss.item()})
+            # Save a checkpoint every save_steps steps
+ if step % args.save_steps == 0:
+ ckpt = Path(save_root) / f"step_{step}"
+ ckpt.mkdir(parents=True, exist_ok=True)
+ accelerator.unwrap_model(model).save_pretrained(ckpt, safe_serialization=True)
+ tokenizer.save_pretrained(ckpt)
+ print(f"Checkpoint saved to {ckpt}")
if accelerator.is_main_process:
- ckpt_dir = Path(save_root) / f"step_{step}"
- ckpt_dir.mkdir(parents=True, exist_ok=True)
- accelerator.unwrap_model(model).save_pretrained(ckpt_dir, safe_serialization=True)
- tokenizer.save_pretrained(ckpt_dir)
- accelerator.wait_for_everyone()
- print(f"Checkpoint saved to {ckpt_dir}")
-
-if accelerator.is_main_process:
- final_dir = Path(save_root) / "final"
- final_dir.mkdir(parents=True, exist_ok=True)
- accelerator.unwrap_model(model).save_pretrained(final_dir, safe_serialization=True)
- tokenizer.save_pretrained(final_dir)
- print(f"Final checkpoint saved to {final_dir}") \ No newline at end of file
+ final=Path(save_root)/"final"; final.mkdir(parents=True,exist_ok=True)
+ accelerator.unwrap_model(model).save_pretrained(final,safe_serialization=True); tokenizer.save_pretrained(final)
+ print(f"Final checkpoint saved to {final}"); wandb.finish()
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
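For readers skimming the diff, the core of the new training loop is an entropy-minimization objective computed on the model's own sampled continuations. A self-contained sketch of that loss (tensor names and shapes here are illustrative, not the script's API):

```python
import torch
import torch.nn.functional as F

def token_entropy_loss(logits: torch.Tensor, gen_mask: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """Mean entropy of softmax(logits / temperature), restricted to generated tokens.

    logits:   [batch, seq_len, vocab_size] scores from the causal LM
    gen_mask: [batch, seq_len] bool, True at sampled (non-prompt, non-pad) positions
    """
    probs = F.softmax(logits / temperature, dim=-1)
    entropy = -(probs * torch.log(probs + 1e-12)).sum(dim=-1)  # [batch, seq_len]
    return (entropy * gen_mask).sum() / gen_mask.sum().clamp_min(1)

# Illustrative check with random tensors
logits = torch.randn(2, 8, 32)
gen_mask = torch.zeros(2, 8, dtype=torch.bool)
gen_mask[:, 4:] = True
print(token_entropy_loss(logits, gen_mask).item())
```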