From 16815c8c5ec263c4bd1a0af60030c1c0efa1421e Mon Sep 17 00:00:00 2001 From: zitian-gao Date: Tue, 27 May 2025 16:41:24 +0800 Subject: update --- train.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) (limited to 'train.py') diff --git a/train.py b/train.py index 557b209..6413e4b 100644 --- a/train.py +++ b/train.py @@ -11,11 +11,11 @@ from accelerate.utils import set_seed from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AdamW -parser = argparse.ArgumentParser(description="Train model with configurable hyperparameters") -parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature") -parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate") -parser.add_argument("--effective_batch", type=int, default=64, help="Effective batch size") -parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility") +parser = argparse.ArgumentParser() +parser.add_argument("--temperature", type=float, default=0.5) +parser.add_argument("--lr", type=float, default=2e-5) +parser.add_argument("--bsz", type=int) +parser.add_argument("--seed", type=int, default=15) args = parser.parse_args() set_seed(args.seed) @@ -25,7 +25,7 @@ torch.backends.cudnn.benchmark = False temp_str = str(args.temperature).replace(".", "") lr_str = f"{args.lr:.0e}" -bsz_str = str(args.effective_batch) +bsz_str = str(args.bsz) -save_root = f"/volume/ailab4sci/ztgao/em/checkpoints_32b/t{temp_str}_lr{lr_str}_bsz{bsz_str}_seed{args.seed}" +save_root = f"/volume/ailab4sci/ztgao/em/checkpoints/qwen25math7b/t{temp_str}_lr{lr_str}_bsz{bsz_str}_seed{args.seed}" temperature = args.temperature learning_rate = args.lr @@ -55,7 +55,7 @@ accelerator = Accelerator( ) print = accelerator.print -model_name = "Qwen2.5-32B-Instruct" +model_name = "Qwen2.5-Math-7B" model_path = f"/volume/ailab4sci/models/{model_name}" config = AutoConfig.from_pretrained(model_path) config.use_cache = False @@ -116,7 +116,8 @@ for step, batch in
enumerate(data_loader, start=1): temperature=temperature, synced_gpus=True, repetition_penalty=1.15, - pad_token_id=151643 + pad_token_id=151643, + use_cache=False ) seq = torch.cat( -- cgit v1.2.3