diff options
| -rw-r--r-- | train.py | 4 |
1 file changed, 2 insertions, 2 deletions
@@ -24,13 +24,13 @@ os.environ.setdefault("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "2700")
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model_name', type=str, default='Qwen3-8B', help='Model name')
+    parser.add_argument('--model_name', type=str, default='Qwen2.5-Math-7B', help='Model name')
     parser.add_argument('--model_path', type=str, default=None, help='Local model path')
     parser.add_argument('--train_data', type=str, default='dataset/1shot_rlvr/pi1_r1280.parquet', help='Training data file path')
     parser.add_argument('--save_root', type=str, default=None, help='Checkpoint save root directory')
     parser.add_argument('--effective_batch', type=int, default=64, help='Global batch size')
     parser.add_argument('--micro_batch_size', type=str, default=2, help='Micro batch size or "auto"')
-    parser.add_argument('--temperature', type=float, default=0.9, help='Temperature coefficient')
+    parser.add_argument('--temperature', type=float, default=0.5, help='Temperature coefficient')
     parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
     parser.add_argument('--log_steps', type=int, default=1, help='Logging step interval')
     parser.add_argument('--save_steps', type=int, default=1, help='Checkpoint saving step interval')
