From 800e1f1f33d93cb7a1812dff1dc0ef85289ef075 Mon Sep 17 00:00:00 2001
From: haoyuren <13851610112@163.com>
Date: Sun, 22 Feb 2026 12:06:23 -0600
Subject: Auto-calibrate collect_batch when not specified

Benchmarks batch sizes [64,128,256,512] and picks smallest within 10%
of peak throughput. Smaller batches = more frequent PPO updates = better
training quality at similar speed.

Co-Authored-By: Claude Opus 4.6
---
 train.py | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 62 insertions(+), 2 deletions(-)

(limited to 'train.py')

diff --git a/train.py b/train.py
index f47572b..31ce926 100644
--- a/train.py
+++ b/train.py
@@ -276,6 +276,61 @@ def greedy_warmup(model: PolicyValueNet, optimizer: torch.optim.Optimizer,
         print(f" Epoch {epoch+1}/{epochs}: loss={total_loss/n_batches:.4f}")
 
 
+# ---------------------------------------------------------------------------
+# Auto-calibrate collect_batch
+# ---------------------------------------------------------------------------
+def calibrate_collect_batch(num_players: int, model: PolicyValueNet, device="cpu",
+                            candidates=(64, 128, 256, 512), tolerance=0.1):
+    """Find a good parallel game batch size by benchmarking throughput.
+
+    Strategy: time one collect_games_batch call per candidate size, then
+    pick the smallest size whose throughput is within ``tolerance`` of the
+    peak. Smaller batches mean more frequent PPO updates which is better
+    for training quality.
+
+    Args:
+        num_players: Passed through to collect_games_batch.
+        model: Passed through to collect_games_batch.
+        device: Passed through to collect_games_batch.
+        candidates: Batch sizes to benchmark (sorted ascending before use).
+        tolerance: Allowed fractional shortfall from peak, in [0, 1).
+
+    Returns:
+        The smallest candidate reaching (1 - tolerance) * peak games/s.
+
+    Raises:
+        ValueError: If candidates is empty or tolerance is out of range.
+    """
+    import time
+
+    candidates = sorted(candidates)
+    if not candidates:
+        raise ValueError("candidates must not be empty")
+    if not 0.0 <= tolerance < 1.0:
+        raise ValueError("tolerance must be in [0, 1)")
+
+    print("Auto-calibrating collect_batch...")
+    rates = []
+    for size in candidates:
+        # perf_counter is monotonic and high resolution; time.time() can jump.
+        t0 = time.perf_counter()
+        collect_games_batch(size, num_players, model, device)
+        # Clamp so a run below clock resolution cannot divide by zero.
+        elapsed = max(time.perf_counter() - t0, 1e-9)
+        rate = size / elapsed
+        rates.append(rate)
+        print(f" batch={size}: {rate:.0f} games/s")
+
+    peak = max(rates)
+    threshold = peak * (1.0 - tolerance)
+    # The peak entry always meets the threshold, so a match always exists.
+    size, rate = next(
+        (s, r) for s, r in zip(candidates, rates) if r >= threshold
+    )
+    print(f" Selected: {size} ({rate:.0f} games/s, peak={peak:.0f})")
+    return size
+
+
 # ---------------------------------------------------------------------------
 # Training Loop
 # ---------------------------------------------------------------------------
@@ -283,9 +338,7 @@ def train(args):
     train_device = "cuda" if torch.cuda.is_available() else "cpu"
     collect_device = "cpu"  # env simulation always on CPU
     print(f"Train device: {train_device}, Collect device: {collect_device}")
-    collect_batch = args.collect_batch if args.collect_batch is not None else args.update_every
     print(f"Training for {args.num_players} players, {args.episodes} episodes")
-    print(f"Batch collection: {collect_batch} games per batch")
 
     # Model lives on CPU for game collection; moves to GPU for PPO updates
     model = PolicyValueNet().to(collect_device)
@@ -300,6 +353,13 @@ def train(args):
     if train_device != collect_device:
         model.to(collect_device)
 
+    # Determine collect_batch size
+    if args.collect_batch is not None:
+        collect_batch = args.collect_batch
+        print(f"Batch collection: {collect_batch} games per batch")
+    else:
+        collect_batch = calibrate_collect_batch(args.num_players, model, collect_device)
+
     # Training log
     log_path = f"{args.save_path}_log.csv"
     with open(log_path, "w") as f:
-- 
cgit v1.2.3