author     zitian-gao <zitian.gao@outlook.com>  2025-05-27 19:22:44 +0800
committer  zitian-gao <zitian.gao@outlook.com>  2025-05-27 19:22:44 +0800
commit     18c28c964e3e5f027d190aca3098a1c245f1f70b (patch)
tree       0bde02fe30579890a7285ec4fa90a00694ef957a /train.py
parent     95cec6f05bcb35fe1d368528337263d88f7f171f (diff)
update
Diffstat (limited to 'train.py')
-rw-r--r--  train.py | 15 ++++++++-------
1 file changed, 8 insertions(+), 7 deletions(-)
diff --git a/train.py b/train.py
index 6413e4b..3092e10 100644
--- a/train.py
+++ b/train.py
@@ -14,8 +14,9 @@ from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AdamW
parser = argparse.ArgumentParser()
parser.add_argument("--temperature", type=float, default=0.5)
parser.add_argument("--lr", type=float, default=2e-5)
-parser.add_argument("--bsz", type=int)
+parser.add_argument("--bsz", type=int, default=64)
parser.add_argument("--seed", type=int, default=15)
+parser.add_argument("--data_path", type=str, default="dataset/1shot_rlvr/pi1_r1280.parquet")
args = parser.parse_args()
set_seed(args.seed)
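Before this change, running train.py with no flags failed in two ways: --bsz had no default (so args.bsz was None when omitted), and the body below read a nonexistent args.effective_batch. A minimal standalone sketch of the parser as it stands after this hunk, showing that a bare invocation now resolves every argument (a hypothetical check, not part of train.py):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--temperature", type=float, default=0.5)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--bsz", type=int, default=64)
    parser.add_argument("--seed", type=int, default=15)
    parser.add_argument("--data_path", type=str, default="dataset/1shot_rlvr/pi1_r1280.parquet")

    # Parse an empty argv: every value now falls back to its default.
    args = parser.parse_args([])
    assert args.bsz == 64
    assert args.data_path == "dataset/1shot_rlvr/pi1_r1280.parquet"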
@@ -24,19 +25,19 @@ torch.backends.cudnn.benchmark = False
temp_str = str(args.temperature).replace(".", "")
lr_str = f"{args.lr:.0e}"
-bsz_str = str(args.effective_batch)
-save_root = f"/volume/ailab4sci/ztgao/em/checkpoints/qwen25math7b/t{temp_str}_lr{lr_str}_bsz{bsz_str}_seed{args.seed}"
+bsz_str = str(args.bsz)
+save_root = f"checkpoints/qwen25math7b/t{temp_str}_lr{lr_str}_bsz{bsz_str}_seed{args.seed}"
temperature = args.temperature
learning_rate = args.lr
-effective_batch = args.effective_batch
+batch_size = args.bsz
micro_batch_size = 2
world_size = int(os.environ.get("WORLD_SIZE", 1))
-accum_steps = max(1, effective_batch // (micro_batch_size * world_size))
+accum_steps = max(1, batch_size // (micro_batch_size * world_size))
DEEPSPEED_CONFIG = {
"train_micro_batch_size_per_gpu": micro_batch_size,
- "train_batch_size": effective_batch,
+ "train_batch_size": batch_size,
"gradient_accumulation_steps": accum_steps,
"bf16": {"enabled": True},
"zero_optimization": {
@@ -80,7 +81,7 @@ def apply_chat_template(problem: str) -> str:
tokenize=False, add_generation_prompt=True
)
-df = pd.read_parquet("/volume/ailab4sci/ztgao/em/dataset/1shot_rlvr/pi1_r1280.parquet")
+df = pd.read_parquet(args.data_path)
data = [{"input": apply_chat_template(p)} for p in df["problem"].dropna().tolist()]
dataset = FTDataset(data)
data_loader = DataLoader(
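FTDataset is defined elsewhere in train.py and is untouched by this commit; a minimal sketch consistent with how it is constructed and consumed here (hypothetical, the real class may carry more logic):

    from torch.utils.data import Dataset

    class FTDataset(Dataset):
        # Wraps a list of {"input": <chat-templated prompt>} records.
        def __init__(self, records):
            self.records = records

        def __len__(self):
            return len(self.records)

        def __getitem__(self, idx):
            return self.records[idx]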