From f368e1296a535cb7505e6e4a1406677d926300b7 Mon Sep 17 00:00:00 2001
From: Zitian Gao
Date: Sun, 1 Jun 2025 20:21:26 +0800
Subject: update wandb

---
 train.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/train.py b/train.py
index 9d5828f..438c582 100644
--- a/train.py
+++ b/train.py
@@ -37,8 +37,8 @@ def parse_args():
     parser.add_argument('--max_steps', type=int, default=1000, help='Maximum training steps')
     parser.add_argument('--sample_temp', type=float, default=0.5, help='Generation temperature parameter')
     parser.add_argument('--run_name', type=str, default=None, help='Experiment run name')
-    # parser.add_argument('--wandb_project', type=str, default='entropy-maximization-ft', help='W&B project name')
-    # parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name')
+    parser.add_argument('--wandb_project', type=str, default='entropy-maximization-ft', help='W&B project name')
+    parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name')
     parser.add_argument('--seed', type=int, default=42, help='Random seed')
     return parser.parse_args()
 
@@ -110,8 +110,8 @@ def main():
     tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
     tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
 
-    # if accelerator.is_main_process:
-    #     wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args))
+    if accelerator.is_main_process:
+        wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args))
 
     df = pd.read_parquet(args.train_data)
     train_data = [{"input": apply_chat_template(tokenizer, p)} for p in df["problem"].dropna().tolist()]
@@ -164,7 +164,7 @@ def main():
         if accelerator.is_main_process:
             if step % args.log_steps == 0:
                 print(f"Step {step} | loss={loss.item():.6f}")
-                # wandb.log({"step": step, "loss": loss.item()})
+                wandb.log({"step": step, "loss": loss.item()})
 
             if step % args.save_steps == 0:
                 ckpt = Path(save_root) / f"step_{step}"
@@ -179,7 +179,7 @@ def main():
     accelerator.unwrap_model(model).save_pretrained(final, safe_serialization=True)
     tokenizer.save_pretrained(final)
     print(f"Final checkpoint saved to {final}")
-    # wandb.finish()
+    wandb.finish()
 
 if __name__ == "__main__":
     main()
--
cgit v1.2.3
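
For reference, the pattern this patch re-enables is W&B logging gated on the Accelerate main process, so that under multi-GPU launches only rank 0 initializes, logs to, and closes the wandb run. Below is a minimal standalone sketch of that pattern, not the repository's train.py; the run name, step count, dummy loss, and offline mode are placeholder assumptions (the project name is the default from the patch).

    # Minimal sketch: main-process-gated W&B logging under Accelerate.
    # Standalone illustration; run name, step count, and the dummy loss
    # are placeholders, not values from the repository's train.py.
    import wandb
    from accelerate import Accelerator

    accelerator = Accelerator()

    if accelerator.is_main_process:
        # Only rank 0 talks to W&B; other ranks skip wandb calls entirely.
        # mode="offline" lets the sketch run without a wandb login.
        wandb.init(project="entropy-maximization-ft", name="demo-run",
                   config={"max_steps": 10, "sample_temp": 0.5},
                   mode="offline")

    for step in range(10):
        loss = 1.0 / (step + 1)  # placeholder loss value
        if accelerator.is_main_process:
            wandb.log({"step": step, "loss": loss})

    if accelerator.is_main_process:
        wandb.finish()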