diff options
| author | haoyuren <13851610112@163.com> | 2026-02-22 12:06:23 -0600 |
|---|---|---|
| committer | haoyuren <13851610112@163.com> | 2026-02-22 12:06:23 -0600 |
| commit | 800e1f1f33d93cb7a1812dff1dc0ef85289ef075 (patch) | |
| tree | 7a72228f3cb639046269a89931f8c8cebf33ee84 /train.py | |
| parent | dda6db0777620f8139bd476e27e6b275c0679358 (diff) | |
Auto-calibrate collect_batch when not specified
Benchmarks batch sizes [64,128,256,512] and picks smallest
within 10% of peak throughput. Smaller batches = more frequent
PPO updates = better training quality at similar speed.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'train.py')
| -rw-r--r-- | train.py | 42 |
1 file changed, 40 insertions, 2 deletions
@@ -277,15 +277,46 @@ def greedy_warmup(model: PolicyValueNet, optimizer: torch.optim.Optimizer, # --------------------------------------------------------------------------- +# Auto-calibrate collect_batch +# --------------------------------------------------------------------------- +def calibrate_collect_batch(num_players: int, model: PolicyValueNet, device="cpu"): + """Find optimal parallel game batch size by benchmarking throughput. + + Strategy: test increasing batch sizes, pick the smallest one whose + throughput is within 10% of the peak. Smaller batches mean more frequent + PPO updates which is better for training quality. + """ + import time + candidates = [64, 128, 256, 512] + rates = [] + + print("Auto-calibrating collect_batch...") + for size in candidates: + t0 = time.time() + collect_games_batch(size, num_players, model, device) + elapsed = time.time() - t0 + rate = size / elapsed + rates.append(rate) + print(f" batch={size}: {rate:.0f} games/s") + + peak = max(rates) + threshold = peak * 0.9 # within 10% of peak + for size, rate in zip(candidates, rates): + if rate >= threshold: + print(f" Selected: {size} ({rate:.0f} games/s, peak={peak:.0f})") + return size + + return candidates[-1] + + +# --------------------------------------------------------------------------- # Training Loop # --------------------------------------------------------------------------- def train(args): train_device = "cuda" if torch.cuda.is_available() else "cpu" collect_device = "cpu" # env simulation always on CPU print(f"Train device: {train_device}, Collect device: {collect_device}") - collect_batch = args.collect_batch if args.collect_batch is not None else args.update_every print(f"Training for {args.num_players} players, {args.episodes} episodes") - print(f"Batch collection: {collect_batch} games per batch") # Model lives on CPU for game collection; moves to GPU for PPO updates model = PolicyValueNet().to(collect_device) @@ -300,6 +331,13 @@ def train(args): if 
train_device != collect_device: model.to(collect_device) + # Determine collect_batch size + if args.collect_batch is not None: + collect_batch = args.collect_batch + print(f"Batch collection: {collect_batch} games per batch") + else: + collect_batch = calibrate_collect_batch(args.num_players, model, collect_device) + # Training log log_path = f"{args.save_path}_log.csv" with open(log_path, "w") as f: |
