From 496a954c3aae0357f24d76f88383b2c36bd49277 Mon Sep 17 00:00:00 2001
From: Zitian Gao
Date: Thu, 29 May 2025 15:24:14 +0800
Subject: update params

---
 train.py | 22 +++++-----------------
 1 file changed, 5 insertions(+), 17 deletions(-)
(limited to 'train.py')

diff --git a/train.py b/train.py
index 71da896..e23b23d 100644
--- a/train.py
+++ b/train.py
@@ -28,7 +28,7 @@ def parse_args():
     parser.add_argument('--train_data', type=str, default='', help='Training data file path')
     parser.add_argument('--save_root', type=str, default=None, help='Checkpoint save root directory')
     parser.add_argument('--effective_batch', type=int, default=64, help='Global batch size')
-    parser.add_argument('--micro_batch_size', type=str, default='auto', help='Micro batch size or "auto"')
+    parser.add_argument('--micro_batch_size', type=str, default=2, help='Micro batch size or "auto"')
     parser.add_argument('--temperature', type=float, default=0.9, help='Temperature coefficient')
     parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
     parser.add_argument('--log_steps', type=int, default=1, help='Logging step interval')
@@ -66,19 +66,6 @@ def get_optimal_micro_batch_size(model_name: str, world_size: int = 1) -> int:
         return min(base_batch + 1, int(base_batch * 1.5))
     return base_batch
 
-def adjust_batch_config_for_deepspeed(effective_batch: int, micro_batch_size: int, world_size: int):
-    ideal = effective_batch / (micro_batch_size * world_size)
-    if ideal.is_integer():
-        return effective_batch, micro_batch_size, int(ideal)
-    down = max(1, int(ideal))
-    up = down + 1
-    down_eff = micro_batch_size * down * world_size
-    up_eff = micro_batch_size * up * world_size
-    use_up = abs(effective_batch - up_eff) < abs(effective_batch - down_eff)
-    steps = up if use_up else down
-    eff = up_eff if use_up else down_eff
-    return eff, micro_batch_size, steps
-
 def apply_chat_template(tokenizer, problem: str) -> str:
     return tokenizer.apply_chat_template(
         [{"role": "user", "content": problem}],
@@ -92,8 +79,9 @@ def main():
     torch.backends.cudnn.benchmark = False
 
     world_size = int(os.getenv("WORLD_SIZE", "1"))
-    micro_bs = int(args.micro_batch_size) if args.micro_batch_size != 'auto' else get_optimal_micro_batch_size(args.model_name, world_size)
-    eff_bs, micro_bs, accum_steps = adjust_batch_config_for_deepspeed(args.effective_batch, micro_bs, world_size)
+    micro_bs = int(args.micro_batch_size)
+    eff_bs = args.effective_batch
+    accum_steps = max(1, eff_bs // (micro_bs * world_size))
 
     temp = args.temperature
     lr = args.learning_rate
@@ -182,4 +170,4 @@ def main():
         wandb.finish()
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
-- 
cgit v1.2.3
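
For context on the replacement logic: accumulation steps are now derived by floor division, so the realized global batch is micro_bs * accum_steps * world_size and silently undershoots --effective_batch whenever the division is inexact, whereas the removed adjust_batch_config_for_deepspeed rounded to the nearest achievable effective batch. A minimal sketch of the arithmetic, using hypothetical values not taken from the commit:

    # New logic: floor division, clamped to at least one step.
    eff_bs, micro_bs, world_size = 64, 2, 4
    accum_steps = max(1, eff_bs // (micro_bs * world_size))  # 64 // 8 == 8
    realized = micro_bs * accum_steps * world_size           # 2 * 8 * 4 == 64 (exact)

    # Inexact case: the realized global batch undershoots the request.
    eff_bs, micro_bs, world_size = 64, 3, 4
    accum_steps = max(1, eff_bs // (micro_bs * world_size))  # 64 // 12 == 5
    realized = micro_bs * accum_steps * world_size           # 3 * 5 * 4 == 60, not 64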