summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorzitian-gao <zitian.gao@outlook.com>2025-05-27 16:41:24 +0800
committerzitian-gao <zitian.gao@outlook.com>2025-05-27 16:41:24 +0800
commit16815c8c5ec263c4bd1a0af60030c1c0efa1421e (patch)
tree42bcfd79a2fde61e848b841d3f51b275f4de46d6 /train.py
parent0f949c08df00e50893cd17416db62673041b5b31 (diff)
update
Diffstat (limited to 'train.py')
-rw-r--r--train.py17
1 files changed, 9 insertions, 8 deletions
diff --git a/train.py b/train.py
index 557b209..6413e4b 100644
--- a/train.py
+++ b/train.py
@@ -11,11 +11,11 @@ from accelerate.utils import set_seed
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AdamW
-parser = argparse.ArgumentParser(description="Train model with configurable hyperparameters")
-parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature")
-parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate")
-parser.add_argument("--effective_batch", type=int, default=64, help="Effective batch size")
-parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
+parser = argparse.ArgumentParser()
+parser.add_argument("--temperature", type=float, default=0.5)
+parser.add_argument("--lr", type=float, default=2e-5)
+parser.add_argument("--bsz", type=int)
+parser.add_argument("--seed", type=int, default=15)
args = parser.parse_args()
set_seed(args.seed)
@@ -25,7 +25,7 @@ torch.backends.cudnn.benchmark = False
temp_str = str(args.temperature).replace(".", "")
lr_str = f"{args.lr:.0e}"
bsz_str = str(args.effective_batch)
-save_root = f"/volume/ailab4sci/ztgao/em/checkpoints_32b/t{temp_str}_lr{lr_str}_bsz{bsz_str}_seed{args.seed}"
+save_root = f"/volume/ailab4sci/ztgao/em/checkpoints/qwen25math7b/t{temp_str}_lr{lr_str}_bsz{bsz_str}_seed{args.seed}"
temperature = args.temperature
learning_rate = args.lr
@@ -55,7 +55,7 @@ accelerator = Accelerator(
)
print = accelerator.print
-model_name = "Qwen2.5-32B-Instruct"
+model_name = "Qwen2.5-Math-7B"
model_path = f"/volume/ailab4sci/models/{model_name}"
config = AutoConfig.from_pretrained(model_path)
config.use_cache = False
@@ -116,7 +116,8 @@ for step, batch in enumerate(data_loader, start=1):
temperature=temperature,
synced_gpus=True,
repetition_penalty=1.15,
- pad_token_id=151643
+ pad_token_id=151643,
+ use_cache=False
)
seq = torch.cat(