-rw-r--r--	train.py	46
1 file changed, 25 insertions(+), 21 deletions(-)
diff --git a/train.py b/train.py
index 1033f7c..71da896 100644
--- a/train.py
+++ b/train.py
@@ -1,10 +1,3 @@
-#!/usr/bin/env python3
-"""
-train.py
-
-One-shot Entropy Minimization Fine-tuning Script
-"""
-
import argparse
import os
import random
@@ -24,12 +17,12 @@ from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.utils import set_seed
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AdamW
-# Set the NCCL timeout to avoid distributed-training timeouts (45 minutes)
-os.environ.setdefault("NCCL_TIMEOUT", "2700")  # 45 minutes
+
+os.environ.setdefault("NCCL_TIMEOUT", "2700")
os.environ.setdefault("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "2700")
def parse_args():
- parser = argparse.ArgumentParser(description="Entropy Minimization Fine-tuning")
+ parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='Qwen2.5-Math-7B', help='Model name')
parser.add_argument('--model_path', type=str, default=None, help='Local model path')
parser.add_argument('--train_data', type=str, default='', help='Training data file path')
@@ -41,7 +34,7 @@ def parse_args():
parser.add_argument('--log_steps', type=int, default=1, help='Logging step interval')
parser.add_argument('--save_steps', type=int, default=1, help='Checkpoint saving step interval')
parser.add_argument('--max_steps', type=int, default=1000, help='Maximum training steps')
- parser.add_argument('--sample_temp', type=float, default=None, help='Generation temperature parameter')
+ parser.add_argument('--sample_temp', type=float, default=0.5, help='Generation temperature parameter')
parser.add_argument('--run_name', type=str, default=None, help='Experiment run name')
parser.add_argument('--wandb_project', type=str, default='entropy-maximization-ft', help='W&B project name')
parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name')
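Note: with this hunk, --sample_temp gets a concrete default (0.5) instead of None, and the generation hunk further down drops the fallback to --temperature. A minimal standalone illustration of the behavioral change (the --temperature default of 1.0 is assumed, it is not shown in this diff):

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument('--sample_temp', type=float, default=0.5)
    p.add_argument('--temperature', type=float, default=1.0)  # assumed default, not in this diff
    args = p.parse_args(['--temperature', '0.3'])

    # Before: sample_temp defaulted to None, so generate() fell back to --temperature (0.3 here).
    # After: sample_temp is always a float, so --temperature no longer influences sampling.
    print(args.sample_temp)  # 0.5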
@@ -101,12 +94,10 @@ def main():
world_size = int(os.getenv("WORLD_SIZE", "1"))
micro_bs = int(args.micro_batch_size) if args.micro_batch_size != 'auto' else get_optimal_micro_batch_size(args.model_name, world_size)
eff_bs, micro_bs, accum_steps = adjust_batch_config_for_deepspeed(args.effective_batch, micro_bs, world_size)
-
temp = args.temperature
lr = args.learning_rate
save_root = args.save_root or (f"checkpoints/{args.model_name}/{args.run_name}" if args.run_name else f"checkpoints/{args.model_name}")
-
ds_config = {
"train_micro_batch_size_per_gpu": micro_bs,
"train_batch_size": eff_bs,
@@ -122,7 +113,7 @@ def main():
config = AutoConfig.from_pretrained(model_path); config.use_cache = False
model = AutoModelForCausalLM.from_pretrained(model_path, config=config); model.gradient_checkpointing_enable()
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
- if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
if accelerator.is_main_process:
wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args))
@@ -133,16 +124,26 @@ def main():
optimizer = AdamW(model.parameters(), lr=lr)
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
-
prev_logits = None
model.train()
+
for step, batch in enumerate(train_loader, start=1):
if step > args.max_steps:
- print(f"达到最大步数 {args.max_steps},停止训练"); break
+ print(f"Exceed max step {args.max_steps}, training stopped.")
+ break
+
with accelerator.accumulate(model):
enc = tokenizer(batch["input"], return_tensors="pt", padding="longest", truncation=True, max_length=2048).to(accelerator.device)
+
with torch.no_grad():
- gen_ids = accelerator.unwrap_model(model).generate(**enc, max_new_tokens=512, do_sample=True, top_p=0.95, temperature=args.sample_temp if args.sample_temp is not None else temp, synced_gpus=True, pad_token_id=tokenizer.pad_token_id, use_cache=False)
+ gen_ids = accelerator.unwrap_model(model).generate(**enc,
+ max_new_tokens=512, do_sample=True,
+ top_p=0.95,
+ temperature=args.sample_temp,
+ synced_gpus=True,
+ pad_token_id=tokenizer.pad_token_id,
+ use_cache=False)
+
seq = torch.cat([enc.input_ids, gen_ids[:, enc.input_ids.shape[1]:]], dim=1)[:, :4096]
pad_mask = seq.ne(tokenizer.pad_token_id); prompt_len = pad_mask[:, :enc.input_ids.shape[1]].sum(-1)
token_idx = torch.arange(seq.size(1), device=seq.device)
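Note: the hunk ends just as the loss mask is being assembled. With padding_side="left", generated tokens start at column enc.input_ids.shape[1] of seq, so the continuation presumably combines token_idx and pad_mask along these lines (a sketch of the assumed next step, not code from the commit; prompt_len likely feeds per-sample normalization or logging elsewhere):

    # Select generated, non-pad positions for the loss.
    prompt_width = enc.input_ids.shape[1]
    gen_mask = (token_idx.unsqueeze(0) >= prompt_width) & pad_mask  # (batch, seq_len) bool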
@@ -164,7 +165,7 @@ def main():
if step % args.log_steps == 0:
print(f"Step {step} | loss={loss.item():.6f}")
wandb.log({"step": step, "loss": loss.item()})
- # Save a checkpoint every save_steps steps
+
if step % args.save_steps == 0:
ckpt = Path(save_root) / f"step_{step}"
ckpt.mkdir(parents=True, exist_ok=True)
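Note: both this checkpoint hunk and the final save below rely on unwrap_model(...).save_pretrained(...). Under DeepSpeed ZeRO-3 that pattern can write partial weights, because parameters live sharded across ranks; Accelerate's get_state_dict gathers a full copy first. A hedged alternative that drops into the same loop (not what this commit does):

    # Gather a full (unsharded) state dict even when DeepSpeed ZeRO shards parameters.
    state_dict = accelerator.get_state_dict(model)
    accelerator.unwrap_model(model).save_pretrained(
        ckpt,
        state_dict=state_dict,
        safe_serialization=True,
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
    )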
@@ -173,9 +174,12 @@ def main():
print(f"Checkpoint saved to {ckpt}")
if accelerator.is_main_process:
- final=Path(save_root)/"final"; final.mkdir(parents=True,exist_ok=True)
- accelerator.unwrap_model(model).save_pretrained(final,safe_serialization=True); tokenizer.save_pretrained(final)
- print(f"Final checkpoint saved to {final}"); wandb.finish()
+ final = Path(save_root) / "final"
+ final.mkdir(parents=True, exist_ok=True)
+ accelerator.unwrap_model(model).save_pretrained(final, safe_serialization=True)
+ tokenizer.save_pretrained(final)
+ print(f"Final checkpoint saved to {final}")
+ wandb.finish()
if __name__ == "__main__":
    main()
\ No newline at end of file
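Note: the loss computation itself falls outside the hunks shown, but the deleted docstring names the objective: one-shot entropy minimization, i.e. minimizing the mean token-level entropy of the model's own predictive distribution over its sampled continuations. A minimal self-contained sketch of such a loss (shapes and the mask convention are assumptions, not code from this commit):

    import torch
    import torch.nn.functional as F

    def entropy_minimization_loss(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # logits: (batch, seq_len, vocab); mask: (batch, seq_len) bool over
        # positions that should contribute (assumed: generated, non-pad tokens).
        logp = F.log_softmax(logits.float(), dim=-1)    # numerically stable log-probs
        entropy = -(logp.exp() * logp).sum(dim=-1)      # H(p) per position, (batch, seq_len)
        return (entropy * mask).sum() / mask.sum().clamp(min=1)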