author    Zitian Gao <zitian.gao@outlook.com>    2025-05-31 14:59:54 +0800
committer GitHub <noreply@github.com>            2025-05-31 14:59:54 +0800
commit    2daf8089388a776c98caf3bed50e81c3c84a7c59 (patch)
tree      e48801127d17436702be0e6392289ebcab753699 /train.py
parent    a79369af2a5d939892ace45bc70f1d6cb0b8b2f8 (diff)
fix bugs
Diffstat (limited to 'train.py')
-rw-r--r--  train.py  30
1 file changed, 14 insertions, 16 deletions
diff --git a/train.py b/train.py
index 54ad33c..9d5828f 100644
--- a/train.py
+++ b/train.py
@@ -7,15 +7,16 @@ from pathlib import Path
import psutil
import torch
import torch.nn.functional as F
+from torch.optim import AdamW
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
-import wandb
+# import wandb
from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.utils import set_seed
-from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AdamW
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
os.environ.setdefault("NCCL_TIMEOUT", "2700")
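Note on the hunk above: AdamW is now imported from torch.optim rather than from transformers, whose bundled AdamW has long been deprecated in favour of the PyTorch implementation. A minimal sketch of the resulting optimizer construction, using a toy module as a stand-in for the real AutoModelForCausalLM (only the learning-rate value is taken from the script's defaults):

import torch
from torch.optim import AdamW  # replacement for the dropped transformers.AdamW import

# Stand-in module purely for illustration; train.py builds the real model with
# AutoModelForCausalLM.from_pretrained(model_path, config=config).
model = torch.nn.Linear(16, 16)
optimizer = AdamW(model.parameters(), lr=2e-5)  # 2e-5 matches the --learning_rate default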
@@ -23,9 +24,9 @@ os.environ.setdefault("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "2700")
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument('--model_name', type=str, default='Qwen2.5-Math-7B', help='Model name')
+ parser.add_argument('--model_name', type=str, default='Qwen3-8B', help='Model name')
parser.add_argument('--model_path', type=str, default=None, help='Local model path')
- parser.add_argument('--train_data', type=str, default='', help='Training data file path')
+ parser.add_argument('--train_data', type=str, default='dataset/1shot_rlvr/pi1_r1280.parquet', help='Training data file path')
parser.add_argument('--save_root', type=str, default=None, help='Checkpoint save root directory')
parser.add_argument('--effective_batch', type=int, default=64, help='Global batch size')
parser.add_argument('--micro_batch_size', type=str, default=2, help='Micro batch size or "auto"')
@@ -33,12 +34,12 @@ def parse_args():
parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
parser.add_argument('--log_steps', type=int, default=1, help='Logging step interval')
parser.add_argument('--save_steps', type=int, default=1, help='Checkpoint saving step interval')
- parser.add_argument('--max_steps', type=int, default=50, help='Maximum training steps')
+ parser.add_argument('--max_steps', type=int, default=1000, help='Maximum training steps')
parser.add_argument('--sample_temp', type=float, default=0.5, help='Generation temperature parameter')
parser.add_argument('--run_name', type=str, default=None, help='Experiment run name')
- parser.add_argument('--wandb_project', type=str, default='one-shot-em', help='W&B project name')
- parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name')
- parser.add_argument('--seed', type=int, default=15, help='Random seed')
+ # parser.add_argument('--wandb_project', type=str, default='entropy-maximization-ft', help='W&B project name')
+ # parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name')
+ parser.add_argument('--seed', type=int, default=42, help='Random seed')
return parser.parse_args()
class FTDataset(Dataset):
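A note on the --micro_batch_size flag kept in the hunk above: it is declared with type=str so the command line can pass either an integer or the literal "auto" (argparse does not apply the type to the non-string default 2, so the value may arrive as an int or a str). The diff does not show how the script resolves it; the helper below is a hypothetical sketch of that normalization, not code from train.py:

def resolve_micro_batch(value, fallback=2):
    """Hypothetical helper: normalize --micro_batch_size to an int.

    "auto" falls back to a fixed default here; a real implementation might
    instead probe free GPU memory to pick a size.
    """
    if isinstance(value, str) and value.strip().lower() == "auto":
        return fallback
    return int(value)

print(resolve_micro_batch("auto"))  # -> 2
print(resolve_micro_batch("4"))     # -> 4
print(resolve_micro_batch(2))       # -> 2 (the argparse default stays an int)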
@@ -101,7 +102,7 @@ def main():
deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=ds_config))
print = accelerator.print
- model_path = args.model_path or f"/volume/ailab4sci/models/{args.model_name}"
+ model_path = args.model_path or f"/volume/pt-train/models/{args.model_name}"
config = AutoConfig.from_pretrained(model_path)
config.use_cache = False
model = AutoModelForCausalLM.from_pretrained(model_path, config=config)
@@ -109,8 +110,8 @@ def main():
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
- if accelerator.is_main_process:
- wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args))
+ # if accelerator.is_main_process:
+ # wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args))
df = pd.read_parquet(args.train_data)
train_data = [{"input": apply_chat_template(tokenizer, p)} for p in df["problem"].dropna().tolist()]
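apply_chat_template(tokenizer, p) above is a project-local helper that is not part of this diff; the sketch below shows what such a wrapper typically does with the Hugging Face tokenizer's built-in chat templating, and is an assumption rather than the repository's actual implementation:

def apply_chat_template(tokenizer, problem: str) -> str:
    """Assumed behaviour: wrap a raw problem string as a single user turn and
    render it with the tokenizer's chat template, leaving the assistant turn
    open so the model generates the answer."""
    messages = [{"role": "user", "content": problem}]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,             # return the formatted prompt string
        add_generation_prompt=True  # append the assistant header for generation
    )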
@@ -154,9 +155,6 @@ def main():
H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1)
loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
- if args.enable_self_distill and prev_logits is not None:
- loss = loss + args.self_distill_gamma * F.mse_loss(logits, prev_logits)
-
prev_logits = logits.detach()
accelerator.backward(loss)
accelerator.clip_grad_norm_(model.parameters(), 1.0)
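The objective that remains after this hunk (with the self-distillation term removed) is the masked mean of per-token entropy over the generated positions. A self-contained sketch of that computation with toy shapes, mirroring the H_tok and gen_mask lines kept above:

import torch

# Toy shapes for illustration: batch=2, seq_len=4, vocab=8.
logits = torch.randn(2, 4, 8)
gen_mask = torch.tensor([[0., 1., 1., 1.],
                         [0., 0., 1., 1.]])  # 1 marks generated tokens

probs = torch.softmax(logits, dim=-1)
# Per-token entropy, as in the diff: -(p * log p) summed over the vocabulary.
H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1)
# Masked mean over generated tokens only; clamp_min guards against an empty mask.
loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
print(loss.item())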
@@ -166,7 +164,7 @@ def main():
if accelerator.is_main_process:
if step % args.log_steps == 0:
print(f"Step {step} | loss={loss.item():.6f}")
- wandb.log({"step": step, "loss": loss.item()})
+ # wandb.log({"step": step, "loss": loss.item()})
if step % args.save_steps == 0:
ckpt = Path(save_root) / f"step_{step}"
@@ -181,7 +179,7 @@ def main():
accelerator.unwrap_model(model).save_pretrained(final, safe_serialization=True)
tokenizer.save_pretrained(final)
print(f"Final checkpoint saved to {final}")
- wandb.finish()
+ # wandb.finish()
if __name__ == "__main__":
main()