summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--train.py42
1 files changed, 40 insertions, 2 deletions
diff --git a/train.py b/train.py
index f47572b..31ce926 100644
--- a/train.py
+++ b/train.py
@@ -277,15 +277,46 @@ def greedy_warmup(model: PolicyValueNet, optimizer: torch.optim.Optimizer,
# ---------------------------------------------------------------------------
+# Auto-calibrate collect_batch
+# ---------------------------------------------------------------------------
+def calibrate_collect_batch(num_players: int, model: PolicyValueNet, device="cpu"):
+ """Find optimal parallel game batch size by benchmarking throughput.
+
+ Strategy: test increasing batch sizes, pick the smallest one whose
+ throughput is within 10% of the peak. Smaller batches mean more frequent
+ PPO updates which is better for training quality.
+ """
+ import time
+ candidates = [64, 128, 256, 512]
+ rates = []
+
+ print("Auto-calibrating collect_batch...")
+ for size in candidates:
+ t0 = time.time()
+ collect_games_batch(size, num_players, model, device)
+ elapsed = time.time() - t0
+ rate = size / elapsed
+ rates.append(rate)
+ print(f" batch={size}: {rate:.0f} games/s")
+
+ peak = max(rates)
+ threshold = peak * 0.9 # within 10% of peak
+ for size, rate in zip(candidates, rates):
+ if rate >= threshold:
+ print(f" Selected: {size} ({rate:.0f} games/s, peak={peak:.0f})")
+ return size
+
+ return candidates[-1]
+
+
+# ---------------------------------------------------------------------------
# Training Loop
# ---------------------------------------------------------------------------
def train(args):
train_device = "cuda" if torch.cuda.is_available() else "cpu"
collect_device = "cpu" # env simulation always on CPU
print(f"Train device: {train_device}, Collect device: {collect_device}")
- collect_batch = args.collect_batch if args.collect_batch is not None else args.update_every
print(f"Training for {args.num_players} players, {args.episodes} episodes")
- print(f"Batch collection: {collect_batch} games per batch")
# Model lives on CPU for game collection; moves to GPU for PPO updates
model = PolicyValueNet().to(collect_device)
@@ -300,6 +331,13 @@ def train(args):
if train_device != collect_device:
model.to(collect_device)
+ # Determine collect_batch size
+ if args.collect_batch is not None:
+ collect_batch = args.collect_batch
+ print(f"Batch collection: {collect_batch} games per batch")
+ else:
+ collect_batch = calibrate_collect_batch(args.num_players, model, collect_device)
+
# Training log
log_path = f"{args.save_path}_log.csv"
with open(log_path, "w") as f: