summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorZitian Gao <zitian.gao@outlook.com>2025-06-01 20:21:26 +0800
committerGitHub <noreply@github.com>2025-06-01 20:21:26 +0800
commitf368e1296a535cb7505e6e4a1406677d926300b7 (patch)
treeeed5b60c39ae1bd2396636c9ed7f8fac674dbf0e /train.py
parent2daf8089388a776c98caf3bed50e81c3c84a7c59 (diff)
update wandb
Diffstat (limited to 'train.py')
-rw-r--r--train.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/train.py b/train.py
index 9d5828f..438c582 100644
--- a/train.py
+++ b/train.py
@@ -37,8 +37,8 @@ def parse_args():
parser.add_argument('--max_steps', type=int, default=1000, help='Maximum training steps')
parser.add_argument('--sample_temp', type=float, default=0.5, help='Generation temperature parameter')
parser.add_argument('--run_name', type=str, default=None, help='Experiment run name')
- # parser.add_argument('--wandb_project', type=str, default='entropy-maximization-ft', help='W&B project name')
- # parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name')
+ parser.add_argument('--wandb_project', type=str, default='entropy-maximization-ft', help='W&B project name')
+ parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name')
parser.add_argument('--seed', type=int, default=42, help='Random seed')
return parser.parse_args()
@@ -110,8 +110,8 @@ def main():
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
- # if accelerator.is_main_process:
- # wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args))
+ if accelerator.is_main_process:
+ wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args))
df = pd.read_parquet(args.train_data)
train_data = [{"input": apply_chat_template(tokenizer, p)} for p in df["problem"].dropna().tolist()]
@@ -164,7 +164,7 @@ def main():
if accelerator.is_main_process:
if step % args.log_steps == 0:
print(f"Step {step} | loss={loss.item():.6f}")
- # wandb.log({"step": step, "loss": loss.item()})
+ wandb.log({"step": step, "loss": loss.item()})
if step % args.save_steps == 0:
ckpt = Path(save_root) / f"step_{step}"
@@ -179,7 +179,7 @@ def main():
accelerator.unwrap_model(model).save_pretrained(final, safe_serialization=True)
tokenizer.save_pretrained(final)
print(f"Final checkpoint saved to {final}")
- # wandb.finish()
+ wandb.finish()
if __name__ == "__main__":
main()