diff options
| -rw-r--r-- | train.py | 12 |
1 files changed, 6 insertions, 6 deletions
@@ -37,8 +37,8 @@ def parse_args(): parser.add_argument('--max_steps', type=int, default=1000, help='Maximum training steps') parser.add_argument('--sample_temp', type=float, default=0.5, help='Generation temperature parameter') parser.add_argument('--run_name', type=str, default=None, help='Experiment run name') - # parser.add_argument('--wandb_project', type=str, default='entropy-maximization-ft', help='W&B project name') - # parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name') + parser.add_argument('--wandb_project', type=str, default='entropy-maximization-ft', help='W&B project name') + parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name') parser.add_argument('--seed', type=int, default=42, help='Random seed') return parser.parse_args() @@ -110,8 +110,8 @@ def main(): tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left") tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token - # if accelerator.is_main_process: - # wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args)) + if accelerator.is_main_process: + wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args)) df = pd.read_parquet(args.train_data) train_data = [{"input": apply_chat_template(tokenizer, p)} for p in df["problem"].dropna().tolist()] @@ -164,7 +164,7 @@ def main(): if accelerator.is_main_process: if step % args.log_steps == 0: print(f"Step {step} | loss={loss.item():.6f}") - # wandb.log({"step": step, "loss": loss.item()}) + wandb.log({"step": step, "loss": loss.item()}) if step % args.save_steps == 0: ckpt = Path(save_root) / f"step_{step}" @@ -179,7 +179,7 @@ def main(): accelerator.unwrap_model(model).save_pretrained(final, safe_serialization=True) tokenizer.save_pretrained(final) print(f"Final checkpoint saved to {final}") - # wandb.finish() + wandb.finish() if __name__ == "__main__": main() |
