summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--train.py22
1 files changed, 5 insertions, 17 deletions
diff --git a/train.py b/train.py
index 71da896..e23b23d 100644
--- a/train.py
+++ b/train.py
@@ -28,7 +28,7 @@ def parse_args():
parser.add_argument('--train_data', type=str, default='', help='Training data file path')
parser.add_argument('--save_root', type=str, default=None, help='Checkpoint save root directory')
parser.add_argument('--effective_batch', type=int, default=64, help='Global batch size')
- parser.add_argument('--micro_batch_size', type=str, default='auto', help='Micro batch size or "auto"')
+ parser.add_argument('--micro_batch_size', type=str, default=2, help='Micro batch size or "auto"')
parser.add_argument('--temperature', type=float, default=0.9, help='Temperature coefficient')
parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
parser.add_argument('--log_steps', type=int, default=1, help='Logging step interval')
@@ -66,19 +66,6 @@ def get_optimal_micro_batch_size(model_name: str, world_size: int = 1) -> int:
return min(base_batch + 1, int(base_batch * 1.5))
return base_batch
-def adjust_batch_config_for_deepspeed(effective_batch: int, micro_batch_size: int, world_size: int):
- ideal = effective_batch / (micro_batch_size * world_size)
- if ideal.is_integer():
- return effective_batch, micro_batch_size, int(ideal)
- down = max(1, int(ideal))
- up = down + 1
- down_eff = micro_batch_size * down * world_size
- up_eff = micro_batch_size * up * world_size
- use_up = abs(effective_batch - up_eff) < abs(effective_batch - down_eff)
- steps = up if use_up else down
- eff = up_eff if use_up else down_eff
- return eff, micro_batch_size, steps
-
def apply_chat_template(tokenizer, problem: str) -> str:
return tokenizer.apply_chat_template(
[{"role": "user", "content": problem}],
@@ -92,8 +79,9 @@ def main():
torch.backends.cudnn.benchmark = False
world_size = int(os.getenv("WORLD_SIZE", "1"))
- micro_bs = int(args.micro_batch_size) if args.micro_batch_size != 'auto' else get_optimal_micro_batch_size(args.model_name, world_size)
- eff_bs, micro_bs, accum_steps = adjust_batch_config_for_deepspeed(args.effective_batch, micro_bs, world_size)
+ micro_bs = int(args.micro_batch_size)
+ eff_bs = args.effective_batch
+ accum_steps = max(1, eff_bs // (micro_bs * world_size))
temp = args.temperature
lr = args.learning_rate
@@ -182,4 +170,4 @@ def main():
wandb.finish()
if __name__ == "__main__":
- main() \ No newline at end of file
+ main()