summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rw-r--r--train.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/train.py b/train.py
index d9ab348..01f3ba6 100644
--- a/train.py
+++ b/train.py
@@ -24,13 +24,13 @@ os.environ.setdefault("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "2700")
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument('--model_name', type=str, default='Qwen3-8B', help='Model name')
+ parser.add_argument('--model_name', type=str, default='Qwen2.5-Math-7B', help='Model name')
parser.add_argument('--model_path', type=str, default=None, help='Local model path')
parser.add_argument('--train_data', type=str, default='dataset/1shot_rlvr/pi1_r1280.parquet', help='Training data file path')
parser.add_argument('--save_root', type=str, default=None, help='Checkpoint save root directory')
parser.add_argument('--effective_batch', type=int, default=64, help='Global batch size')
parser.add_argument('--micro_batch_size', type=str, default=2, help='Micro batch size or "auto"')
- parser.add_argument('--temperature', type=float, default=0.9, help='Temperature coefficient')
+ parser.add_argument('--temperature', type=float, default=0.5, help='Temperature coefficient')
parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
parser.add_argument('--log_steps', type=int, default=1, help='Logging step interval')
parser.add_argument('--save_steps', type=int, default=1, help='Checkpoint saving step interval')