From 8392bdfc10f92e61303e39bb356522ee491ce97c Mon Sep 17 00:00:00 2001 From: haoyuren <13851610112@163.com> Date: Sun, 22 Feb 2026 11:56:48 -0600 Subject: Batched game collection for ~7x training speedup - collect_games_batch(): run N games in parallel with single batched forward pass per step - evaluate_vs_greedy_batch(): batched evaluation replacing sequential eval - Add --collect_batch CLI arg for configurable parallel game count - Use torch.inference_mode() for faster collection - Update Colab notebook: GPU info, --collect_batch, log download cell Co-Authored-By: Claude Opus 4.6 --- train.py | 364 +++++++++++++++++++++++++++++------------------------- train_colab.ipynb | 54 +++----- 2 files changed, 213 insertions(+), 205 deletions(-) diff --git a/train.py b/train.py index 65fb633..f47572b 100644 --- a/train.py +++ b/train.py @@ -4,7 +4,7 @@ PPO Self-Play Training for Blazing Eights. Architecture: - Single policy network shared across all seats - Self-play: all players use the same (latest) policy - - Collect trajectories by running full games + - Batched game collection: many games run in parallel with batched inference - Standard PPO update with masked invalid actions Usage: @@ -85,7 +85,7 @@ class PolicyValueNet(nn.Module): # --------------------------------------------------------------------------- -# Trajectory Collection +# Trajectory Storage # --------------------------------------------------------------------------- class Transition: __slots__ = ["obs", "action", "log_prob", "value", "reward", "done", "legal_mask"] @@ -108,56 +108,97 @@ def greedy_random_action(legal: list[int]) -> int: return int(np.random.choice(legal)) -def collect_game(env: BlazingEightsEnv, model: PolicyValueNet, device="cpu"): - """ - Play one full game, return per-player trajectories. - All players use the same model (self-play). 
+# --------------------------------------------------------------------------- +# Batched Game Collection +# --------------------------------------------------------------------------- +def collect_games_batch(num_games: int, num_players: int, model: PolicyValueNet, + device="cpu", max_steps=500): + """Run multiple games simultaneously with batched model inference. + + Instead of running games one-by-one (each step = batch_size=1 forward pass), + this runs all games in lockstep: at each step, all active games' observations + are batched into a single forward pass. + + Returns: + envs: list of completed environments (for reading winner/done) + trajectories: list of per-player trajectory dicts """ - obs = env.reset() - trajectories: dict[int, list[Transition]] = defaultdict(list) - max_steps = 500 + envs = [BlazingEightsEnv(num_players=num_players) for _ in range(num_games)] + obs_list = [env.reset() for env in envs] + trajectories = [defaultdict(list) for _ in range(num_games)] + active = set(range(num_games)) for _ in range(max_steps): - player = env.current_player - legal = env.legal_actions() - if not legal: + if not active: break - action, log_prob, value = model.get_action(obs, legal, device) - - # Build legal mask - legal_mask = np.zeros(TOTAL_ACTIONS, dtype=np.float32) - for a in legal: - legal_mask[a] = 1.0 - - obs_next, rewards, done, info = env.step(action) - - # Store transition for the acting player - trajectories[player].append(Transition( - obs=obs.copy(), - action=action, - log_prob=log_prob, - value=value, - reward=rewards[player], - done=done, - legal_mask=legal_mask, - )) - - # If done, also assign terminal rewards to other players' last transitions - if done: - for p in range(env.num_players): - if p != player and trajectories[p]: - trajectories[p][-1].reward = rewards[p] - trajectories[p][-1].done = True + # Gather observations and legal masks for all active games + indices = [] + batch_obs = [] + batch_masks = [] + batch_players = [] + + for i 
in sorted(active): + legal = envs[i].legal_actions() + if not legal: + active.discard(i) + continue + mask = np.zeros(TOTAL_ACTIONS, dtype=np.float32) + for a in legal: + mask[a] = 1.0 + indices.append(i) + batch_obs.append(obs_list[i]) + batch_masks.append(mask) + batch_players.append(envs[i].current_player) + + if not indices: break - obs = obs_next + # Single batched forward pass for all active games + obs_t = torch.tensor(np.array(batch_obs), dtype=torch.float32, device=device) + mask_t = torch.tensor(np.array(batch_masks), dtype=torch.float32, device=device) + + with torch.inference_mode(): + logits, values = model(obs_t, mask_t) + probs = F.softmax(logits, dim=-1) + dist = Categorical(probs) + actions = dist.sample() + log_probs = dist.log_prob(actions) + + actions_np = actions.cpu().numpy() + log_probs_np = log_probs.cpu().numpy() + values_np = values.cpu().numpy() + + # Step each environment + for j, i in enumerate(indices): + player = batch_players[j] + action = int(actions_np[j]) + obs_next, rewards, done, info = envs[i].step(action) + + trajectories[i][player].append(Transition( + obs=batch_obs[j], + action=action, + log_prob=float(log_probs_np[j]), + value=float(values_np[j]), + reward=rewards[player], + done=done, + legal_mask=batch_masks[j], + )) + + if done: + for p in range(envs[i].num_players): + if p != player and trajectories[i][p]: + trajectories[i][p][-1].reward = rewards[p] + trajectories[i][p][-1].done = True + active.discard(i) + else: + obs_list[i] = obs_next - return trajectories + return envs, trajectories # --------------------------------------------------------------------------- -# PPO Update +# PPO Utilities # --------------------------------------------------------------------------- def compute_gae(transitions: list[Transition], gamma=0.99, lam=0.95): """Compute GAE returns and advantages.""" @@ -184,75 +225,6 @@ def compute_gae(transitions: list[Transition], gamma=0.99, lam=0.95): return returns.tolist(), advantages.tolist() -def 
ppo_update(model: PolicyValueNet, optimizer: torch.optim.Optimizer, - all_transitions: list[Transition], device="cpu", - epochs=4, batch_size=256, clip_eps=0.2, vf_coef=0.5, ent_coef=0.01): - """PPO clipped surrogate update.""" - if not all_transitions: - return {} - - # Prepare tensors - obs_arr = np.array([t.obs for t in all_transitions]) - actions_arr = np.array([t.action for t in all_transitions]) - old_log_probs_arr = np.array([t.log_prob for t in all_transitions]) - masks_arr = np.array([t.legal_mask for t in all_transitions]) - - # Compute GAE (treat all transitions as one sequence — not ideal, but we - # already computed per-game, so we just concatenate pre-computed values) - returns_arr = np.array([t.reward for t in all_transitions]) # placeholder - advantages_arr = np.array([t.reward for t in all_transitions]) # placeholder - - obs_t = torch.tensor(obs_arr, dtype=torch.float32, device=device) - actions_t = torch.tensor(actions_arr, dtype=torch.long, device=device) - old_log_probs_t = torch.tensor(old_log_probs_arr, dtype=torch.float32, device=device) - masks_t = torch.tensor(masks_arr, dtype=torch.float32, device=device) - returns_t = torch.tensor(returns_arr, dtype=torch.float32, device=device) - advantages_t = torch.tensor(advantages_arr, dtype=torch.float32, device=device) - - # Normalize advantages - if len(advantages_t) > 1: - advantages_t = (advantages_t - advantages_t.mean()) / (advantages_t.std() + 1e-8) - - total_loss_sum = 0 - n_updates = 0 - - for _ in range(epochs): - indices = np.arange(len(all_transitions)) - np.random.shuffle(indices) - - for start in range(0, len(indices), batch_size): - end = min(start + batch_size, len(indices)) - idx = indices[start:end] - - b_obs = obs_t[idx] - b_actions = actions_t[idx] - b_old_lp = old_log_probs_t[idx] - b_masks = masks_t[idx] - b_returns = returns_t[idx] - b_advantages = advantages_t[idx] - - new_log_probs, values, entropy = model.evaluate(b_obs, b_masks, b_actions) - - ratio = 
torch.exp(new_log_probs - b_old_lp) - surr1 = ratio * b_advantages - surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * b_advantages - policy_loss = -torch.min(surr1, surr2).mean() - - value_loss = F.mse_loss(values, b_returns) - - loss = policy_loss + vf_coef * value_loss - ent_coef * entropy.mean() - - optimizer.zero_grad() - loss.backward() - nn.utils.clip_grad_norm_(model.parameters(), 0.5) - optimizer.step() - - total_loss_sum += loss.item() - n_updates += 1 - - return {"loss": total_loss_sum / max(n_updates, 1)} - - # --------------------------------------------------------------------------- # Greedy Warmup (Behavioral Cloning) # --------------------------------------------------------------------------- @@ -311,7 +283,9 @@ def train(args): train_device = "cuda" if torch.cuda.is_available() else "cpu" collect_device = "cpu" # env simulation always on CPU print(f"Train device: {train_device}, Collect device: {collect_device}") + collect_batch = args.collect_batch if args.collect_batch is not None else args.update_every print(f"Training for {args.num_players} players, {args.episodes} episodes") + print(f"Batch collection: {collect_batch} games per batch") # Model lives on CPU for game collection; moves to GPU for PPO updates model = PolicyValueNet().to(collect_device) @@ -333,30 +307,44 @@ def train(args): # Stats win_counts = defaultdict(int) - game_lengths = [] - batch_transitions = [] + all_game_lengths = [] recent_loss = 0.0 recent_loss_count = 0 - pbar = tqdm(range(1, args.episodes + 1), desc="Training", unit="ep") - for ep in pbar: - env = BlazingEightsEnv(num_players=args.num_players) - trajectories = collect_game(env, model, collect_device) - - # Record stats - if env.done: - win_counts[env.winner] += 1 - game_lengths.append(sum(len(v) for v in trajectories.values())) - - # Compute GAE per player and collect - for player, trans_list in trajectories.items(): - returns, advantages = compute_gae(trans_list, gamma=args.gamma, lam=args.lam) - for i, t 
in enumerate(trans_list): - t.reward = returns[i] if i < len(returns) else t.reward - batch_transitions.extend(trans_list) - - # Update every `update_every` episodes - if ep % args.update_every == 0 and batch_transitions: + ep = 0 + next_log = args.log_every + next_eval = args.eval_every + next_save = args.save_every + pbar = tqdm(total=args.episodes, desc="Training", unit="ep") + + while ep < args.episodes: + games_this_batch = min(collect_batch, args.episodes - ep) + + # Collect games in parallel with batched inference + envs, batch_trajectories = collect_games_batch( + games_this_batch, args.num_players, model, collect_device + ) + + # Process trajectories: compute GAE and collect all transitions + batch_transitions = [] + for i in range(games_this_batch): + env = envs[i] + traj = batch_trajectories[i] + if env.done: + win_counts[env.winner] += 1 + all_game_lengths.append(sum(len(v) for v in traj.values())) + + for player, trans_list in traj.items(): + returns, advantages = compute_gae(trans_list, gamma=args.gamma, lam=args.lam) + for k, t in enumerate(trans_list): + t.reward = returns[k] if k < len(returns) else t.reward + batch_transitions.extend(trans_list) + + ep += games_this_batch + pbar.update(games_this_batch) + + # PPO update + if batch_transitions: returns_for_update = np.array([t.reward for t in batch_transitions]) values_for_update = np.array([t.value for t in batch_transitions]) advs = returns_for_update - values_for_update @@ -380,15 +368,15 @@ def train(args): if len(advs_t) > 1: advs_t = (advs_t - advs_t.mean()) / (advs_t.std() + 1e-8) - # PPO update + # PPO clipped surrogate update batch_loss = 0.0 n_updates = 0 for _ in range(args.ppo_epochs): - indices = np.arange(len(batch_transitions)) - np.random.shuffle(indices) - for start in range(0, len(indices), args.batch_size): - end = min(start + args.batch_size, len(indices)) - idx = indices[start:end] + perm = np.arange(len(batch_transitions)) + np.random.shuffle(perm) + for start in range(0, 
len(perm), args.batch_size): + end = min(start + args.batch_size, len(perm)) + idx = perm[start:end] b_obs = obs_t[idx] b_actions = actions_t[idx] @@ -420,11 +408,9 @@ def train(args): if train_device != collect_device: model.to(collect_device) - batch_transitions = [] - # Logging - if ep % args.log_every == 0: - avg_len = np.mean(game_lengths[-args.log_every:]) if game_lengths else 0 + if ep >= next_log: + avg_len = np.mean(all_game_lengths[-args.log_every:]) if all_game_lengths else 0 avg_loss = recent_loss / max(recent_loss_count, 1) total_games = sum(win_counts.values()) wr0 = win_counts[0] / max(total_games, 1) @@ -432,19 +418,21 @@ def train(args): wr0=f"{wr0:.1%}", games=total_games) recent_loss = 0.0 recent_loss_count = 0 + next_log += args.log_every - # Evaluate vs greedy + write log every eval_every episodes - if ep % args.eval_every == 0: - avg_len = np.mean(game_lengths[-args.eval_every:]) if game_lengths else 0 + # Evaluate vs greedy + write log + if ep >= next_eval: + avg_len = np.mean(all_game_lengths[-args.eval_every:]) if all_game_lengths else 0 avg_loss_log = recent_loss / max(recent_loss_count, 1) if recent_loss_count > 0 else 0 - vs_wr = evaluate_vs_random(model, num_players=args.num_players, - num_games=500, device=collect_device) + vs_wr = evaluate_vs_greedy_batch(model, num_players=args.num_players, + num_games=500, device=collect_device) with open(log_path, "a") as f: f.write(f"{ep},{avg_len:.1f},{avg_loss_log:.4f},{vs_wr:.4f}\n") tqdm.write(f" [Eval ep{ep}] avg_len={avg_len:.1f} vs_greedy={vs_wr:.1%}") + next_eval += args.eval_every # Save checkpoint - if ep % args.save_every == 0: + if ep >= next_save: path = f"{args.save_path}_ep{ep}.pt" torch.save({ "model": model.state_dict(), @@ -453,6 +441,7 @@ def train(args): "num_players": args.num_players, }, path) tqdm.write(f" Saved checkpoint: {path}") + next_save += args.save_every # Final save torch.save({ @@ -468,29 +457,62 @@ def train(args): # 
--------------------------------------------------------------------------- -# Evaluation: play against random +# Evaluation: play against greedy (batched) # --------------------------------------------------------------------------- -def evaluate_vs_random(model: PolicyValueNet, num_players=2, num_games=1000, device="cpu"): - """Player 0 = model, others = random. Returns player 0 win rate.""" - wins = 0 - for _ in range(num_games): - env = BlazingEightsEnv(num_players=num_players) - obs = env.reset() - for _ in range(500): - player = env.current_player - legal = env.legal_actions() +def evaluate_vs_greedy_batch(model: PolicyValueNet, num_players=2, num_games=500, device="cpu"): + """Batched evaluation: player 0 = model, others = greedy random.""" + envs = [BlazingEightsEnv(num_players=num_players) for _ in range(num_games)] + obs_list = [env.reset() for env in envs] + active = set(range(num_games)) + + for _ in range(500): + if not active: + break + + # Separate model-controlled (player 0) and greedy-controlled turns + model_idx = [] + model_obs = [] + model_masks = [] + greedy_pairs = [] + + for i in sorted(active): + legal = envs[i].legal_actions() if not legal: - break - if player == 0: - action, _, _ = model.get_action(obs, legal, device) + active.discard(i) + continue + if envs[i].current_player == 0: + mask = np.zeros(TOTAL_ACTIONS, dtype=np.float32) + for a in legal: + mask[a] = 1.0 + model_idx.append(i) + model_obs.append(obs_list[i]) + model_masks.append(mask) else: - action = greedy_random_action(legal) - obs, rewards, done, info = env.step(action) + greedy_pairs.append((i, greedy_random_action(legal))) + + # Batched model inference for player 0 turns + if model_obs: + obs_t = torch.tensor(np.array(model_obs), dtype=torch.float32, device=device) + mask_t = torch.tensor(np.array(model_masks), dtype=torch.float32, device=device) + with torch.inference_mode(): + logits, _ = model(obs_t, mask_t) + actions = Categorical(F.softmax(logits, 
dim=-1)).sample().cpu().numpy() + for j, i in enumerate(model_idx): + obs_next, _, done, _ = envs[i].step(int(actions[j])) + if done: + active.discard(i) + else: + obs_list[i] = obs_next + + # Greedy actions for other players + for i, action in greedy_pairs: + obs_next, _, done, _ = envs[i].step(action) if done: - if env.winner == 0: - wins += 1 - break - return wins / num_games + active.discard(i) + else: + obs_list[i] = obs_next + + return sum(1 for e in envs if e.done and e.winner == 0) / num_games if __name__ == "__main__": @@ -509,15 +531,17 @@ if __name__ == "__main__": parser.add_argument("--eval_every", type=int, default=10000) parser.add_argument("--save_every", type=int, default=10000) parser.add_argument("--save_path", type=str, default="blazing_ppo") + parser.add_argument("--collect_batch", type=int, default=None, + help="Parallel game collection batch size (default: same as update_every)") parser.add_argument("--greedy_warmup", type=int, default=2000, help="Number of greedy games for behavioral cloning warmup (0 to skip)") args = parser.parse_args() model = train(args) - # Eval vs random - print("\nEvaluating vs random opponents...") + # Eval vs greedy + print("\nEvaluating vs greedy opponents...") for n in [2, 3, 4, 5]: - if n <= args.num_players + 1: # only eval for trained player count - wr = evaluate_vs_random(model, num_players=n, num_games=1000) + if n <= args.num_players + 1: + wr = evaluate_vs_greedy_batch(model, num_players=n, num_games=1000) print(f" {n} players: win rate = {wr:.1%} (random baseline: {1/n:.1%})") diff --git a/train_colab.ipynb b/train_colab.ipynb index afae94d..ac64e50 100644 --- a/train_colab.ipynb +++ b/train_colab.ipynb @@ -15,7 +15,7 @@ { "cell_type": "markdown", "metadata": {}, - "source": "# Blazing Eights - Colab GPU Training\n\nClone repo → Train PPO agent (CPU collect, GPU update) → Push trained model back to GitHub\n\n**Game**: UNO variant with custom special cards (8=Wild, K=All draw, J=Skip, Swap=Swap hands)." 
+ "source": "# Blazing Eights - Colab GPU Training\n\nClone repo → Train PPO agent (batched collection on CPU, PPO on GPU) → Download model & logs\n\n**Game**: UNO variant with custom special cards (8=Wild, K=All draw, J=Skip, Swap=Swap hands)." }, { "cell_type": "markdown", "metadata": {}, @@ -34,60 +34,48 @@ { "cell_type": "code", "metadata": {}, - "source": [ - "import torch\n", - "print(f\"PyTorch: {torch.__version__}\")\n", - "print(f\"CUDA available: {torch.cuda.is_available()}\")\n", - "if torch.cuda.is_available():\n", - " print(f\"GPU: {torch.cuda.get_device_name(0)}\")" - ], + "source": "import torch\nprint(f\"PyTorch: {torch.__version__}\")\nprint(f\"CUDA available: {torch.cuda.is_available()}\")\nif torch.cuda.is_available():\n print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n print(f\"Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB\")\nelse:\n print(\"WARNING: No GPU detected. Go to Runtime → Change runtime type → GPU\")", "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, - "source": [ - "## 2. Train" - ] + "source": "## 2. Train\n\nBatched collection: runs many games in parallel with a single forward pass per step.\n- `--collect_batch`: number of parallel games (higher = faster, more VRAM). Default: same as `--update_every`.\n- Game simulation on CPU, PPO gradient updates on GPU (auto-detected)."
}, { "cell_type": "code", "metadata": {}, - "source": "# 2-player training with greedy warmup + CSV logging\n# Game simulation on CPU, PPO updates on GPU automatically\n!python train.py --num_players 2 --episodes 200000 --save_path blazing_ppo_2p\n\n# Show training log\nimport pandas as pd\ndf = pd.read_csv(\"blazing_ppo_2p_log.csv\")\nprint(df.to_string(index=False))", + "source": "# 2-player training (GPU PPO + batched collection)\n!python train.py \\\n --num_players 2 \\\n --episodes 200000 \\\n --collect_batch 128 \\\n --save_path blazing_ppo_2p", "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": {}, - "source": "# (Optional) 3-player training\n# !python train.py --num_players 3 --episodes 300000 --save_path blazing_ppo_3p\n\n# (Optional) Skip greedy warmup\n# !python train.py --num_players 2 --episodes 200000 --greedy_warmup 0 --save_path blazing_ppo_2p_no_warmup", + "source": "# (Optional) 3-player training\n# !python train.py --num_players 3 --episodes 300000 --collect_batch 128 --save_path blazing_ppo_3p\n\n# (Optional) Larger batch for faster throughput\n# !python train.py --num_players 2 --episodes 200000 --collect_batch 256 --save_path blazing_ppo_2p\n\n# (Optional) Skip greedy warmup\n# !python train.py --num_players 2 --episodes 200000 --greedy_warmup 0 --save_path blazing_ppo_2p_no_warmup", "execution_count": null, "outputs": [] }, { - "cell_type": "markdown", + "cell_type": "code", + "source": "# Show training log\nimport pandas as pd\ndf = pd.read_csv(\"blazing_ppo_2p_log.csv\")\nprint(df.to_string(index=False))", "metadata": {}, - "source": [ - "## 3. Download model locally (Option A)\n", - "Download .pt files directly from Colab to your machine." - ] + "execution_count": null, + "outputs": [] }, { - "cell_type": "code", + "cell_type": "markdown", + "source": "## 3. 
Download model & logs", "metadata": {}, - "source": [ - "from google.colab import files\n", - "import glob\n", - "\n", - "# Download the final model\n", - "for f in glob.glob(\"*_final.pt\"):\n", - " print(f\"Downloading {f}...\")\n", - " files.download(f)" - ], "execution_count": null, "outputs": [] }, + { + "cell_type": "code", + "metadata": {}, + "source": "from google.colab import files\nimport glob\n\n# Download final model(s)\nfor f in glob.glob(\"*_final.pt\"):\n print(f\"Downloading {f}...\")\n files.download(f)\n\n# Download training log(s)\nfor f in glob.glob(\"*_log.csv\"):\n print(f\"Downloading {f}...\")\n files.download(f)" + }, { "cell_type": "markdown", "metadata": {}, @@ -135,17 +123,13 @@ }, { "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5. Quick evaluation" - ] + "source": "## 5. Quick evaluation", + "metadata": {} }, { "cell_type": "code", "metadata": {}, - "source": "import sys\nsys.path.insert(0, \".\")\nfrom train import PolicyValueNet, evaluate_vs_random\n\ndevice = \"cpu\" # eval on CPU (single-sample inference)\nmodel = PolicyValueNet().to(device)\n\nimport glob\nfinal_models = glob.glob(\"*_final.pt\") + glob.glob(\"models/*_final.pt\")\nif final_models:\n ckpt = torch.load(final_models[0], map_location=device, weights_only=True)\n model.load_state_dict(ckpt[\"model\"])\n model.eval()\n print(f\"Loaded: {final_models[0]}\")\n print(f\"Trained for {ckpt.get('episode', '?')} episodes\")\n print()\n\n for n in [2, 3, 4]:\n wr = evaluate_vs_random(model, num_players=n, num_games=2000, device=device)\n print(f\" {n} players: win rate = {wr:.1%} (random baseline: {1/n:.1%})\")\nelse:\n print(\"No model found. 
Train first!\")", - "execution_count": null, - "outputs": [] + "source": "import sys\nsys.path.insert(0, \".\")\nfrom train import PolicyValueNet, evaluate_vs_greedy_batch\n\ndevice = \"cpu\"\nmodel = PolicyValueNet().to(device)\n\nimport glob\nfinal_models = glob.glob(\"*_final.pt\") + glob.glob(\"models/*_final.pt\")\nif final_models:\n ckpt = torch.load(final_models[0], map_location=device, weights_only=True)\n model.load_state_dict(ckpt[\"model\"])\n model.eval()\n print(f\"Loaded: {final_models[0]}\")\n print(f\"Trained for {ckpt.get('episode', '?')} episodes\")\n print()\n\n for n in [2, 3, 4]:\n wr = evaluate_vs_greedy_batch(model, num_players=n, num_games=2000, device=device)\n print(f\" {n} players: win rate = {wr:.1%} (random baseline: {1/n:.1%})\")\nelse:\n print(\"No model found. Train first!\")" } ] } \ No newline at end of file -- cgit v1.2.3