summary | refs | log | tree | commit | diff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rw-r--r-- train.py 364
1 files changed, 194 insertions, 170 deletions
diff --git a/train.py b/train.py
index 65fb633..f47572b 100644
--- a/train.py
+++ b/train.py
@@ -4,7 +4,7 @@ PPO Self-Play Training for Blazing Eights.
Architecture:
- Single policy network shared across all seats
- Self-play: all players use the same (latest) policy
- - Collect trajectories by running full games
+ - Batched game collection: many games run in parallel with batched inference
- Standard PPO update with masked invalid actions
Usage:
@@ -85,7 +85,7 @@ class PolicyValueNet(nn.Module):
# ---------------------------------------------------------------------------
-# Trajectory Collection
+# Trajectory Storage
# ---------------------------------------------------------------------------
class Transition:
__slots__ = ["obs", "action", "log_prob", "value", "reward", "done", "legal_mask"]
@@ -108,56 +108,97 @@ def greedy_random_action(legal: list[int]) -> int:
return int(np.random.choice(legal))
-def collect_game(env: BlazingEightsEnv, model: PolicyValueNet, device="cpu"):
- """
- Play one full game, return per-player trajectories.
- All players use the same model (self-play).
+# ---------------------------------------------------------------------------
+# Batched Game Collection
+# ---------------------------------------------------------------------------
+def collect_games_batch(num_games: int, num_players: int, model: PolicyValueNet,
+ device="cpu", max_steps=500):
+ """Run multiple games simultaneously with batched model inference.
+
+ Instead of running games one-by-one (each step = batch_size=1 forward pass),
+ this runs all games in lockstep: at each step, all active games' observations
+ are batched into a single forward pass.
+
+ Returns:
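+ envs: list of environments after collection (for reading winner/done);
+ a game cut off by max_steps may be unfinished, so check env.done
+ trajectories: one dict per game, mapping player index -> list of Transition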
"""
- obs = env.reset()
- trajectories: dict[int, list[Transition]] = defaultdict(list)
- max_steps = 500
+ envs = [BlazingEightsEnv(num_players=num_players) for _ in range(num_games)]
+ obs_list = [env.reset() for env in envs]
+ trajectories = [defaultdict(list) for _ in range(num_games)]
+ active = set(range(num_games))
for _ in range(max_steps):
- player = env.current_player
- legal = env.legal_actions()
- if not legal:
+ if not active:
break
- action, log_prob, value = model.get_action(obs, legal, device)
-
- # Build legal mask
- legal_mask = np.zeros(TOTAL_ACTIONS, dtype=np.float32)
- for a in legal:
- legal_mask[a] = 1.0
-
- obs_next, rewards, done, info = env.step(action)
-
- # Store transition for the acting player
- trajectories[player].append(Transition(
- obs=obs.copy(),
- action=action,
- log_prob=log_prob,
- value=value,
- reward=rewards[player],
- done=done,
- legal_mask=legal_mask,
- ))
-
- # If done, also assign terminal rewards to other players' last transitions
- if done:
- for p in range(env.num_players):
- if p != player and trajectories[p]:
- trajectories[p][-1].reward = rewards[p]
- trajectories[p][-1].done = True
+ # Gather observations and legal masks for all active games
+ indices = []
+ batch_obs = []
+ batch_masks = []
+ batch_players = []
+
+ for i in sorted(active):
+ legal = envs[i].legal_actions()
+ if not legal:
+ active.discard(i)
+ continue
+ mask = np.zeros(TOTAL_ACTIONS, dtype=np.float32)
+ for a in legal:
+ mask[a] = 1.0
+ indices.append(i)
+ batch_obs.append(obs_list[i])
+ batch_masks.append(mask)
+ batch_players.append(envs[i].current_player)
+
+ if not indices:
break
- obs = obs_next
+ # Single batched forward pass for all active games
+ obs_t = torch.tensor(np.array(batch_obs), dtype=torch.float32, device=device)
+ mask_t = torch.tensor(np.array(batch_masks), dtype=torch.float32, device=device)
+
+ with torch.inference_mode():
+ logits, values = model(obs_t, mask_t)
+ probs = F.softmax(logits, dim=-1)
+ dist = Categorical(probs)
+ actions = dist.sample()
+ log_probs = dist.log_prob(actions)
+
+ actions_np = actions.cpu().numpy()
+ log_probs_np = log_probs.cpu().numpy()
+ values_np = values.cpu().numpy()
+
+ # Step each environment
+ for j, i in enumerate(indices):
+ player = batch_players[j]
+ action = int(actions_np[j])
+ obs_next, rewards, done, info = envs[i].step(action)
+
+ trajectories[i][player].append(Transition(
+ obs=batch_obs[j],
+ action=action,
+ log_prob=float(log_probs_np[j]),
+ value=float(values_np[j]),
+ reward=rewards[player],
+ done=done,
+ legal_mask=batch_masks[j],
+ ))
+
+ if done:
+ for p in range(envs[i].num_players):
+ if p != player and trajectories[i][p]:
+ trajectories[i][p][-1].reward = rewards[p]
+ trajectories[i][p][-1].done = True
+ active.discard(i)
+ else:
+ obs_list[i] = obs_next
- return trajectories
+ return envs, trajectories
# ---------------------------------------------------------------------------
-# PPO Update
+# PPO Utilities
# ---------------------------------------------------------------------------
def compute_gae(transitions: list[Transition], gamma=0.99, lam=0.95):
"""Compute GAE returns and advantages."""
@@ -184,75 +225,6 @@ def compute_gae(transitions: list[Transition], gamma=0.99, lam=0.95):
return returns.tolist(), advantages.tolist()
-def ppo_update(model: PolicyValueNet, optimizer: torch.optim.Optimizer,
- all_transitions: list[Transition], device="cpu",
- epochs=4, batch_size=256, clip_eps=0.2, vf_coef=0.5, ent_coef=0.01):
- """PPO clipped surrogate update."""
- if not all_transitions:
- return {}
-
- # Prepare tensors
- obs_arr = np.array([t.obs for t in all_transitions])
- actions_arr = np.array([t.action for t in all_transitions])
- old_log_probs_arr = np.array([t.log_prob for t in all_transitions])
- masks_arr = np.array([t.legal_mask for t in all_transitions])
-
- # Compute GAE (treat all transitions as one sequence — not ideal, but we
- # already computed per-game, so we just concatenate pre-computed values)
- returns_arr = np.array([t.reward for t in all_transitions]) # placeholder
- advantages_arr = np.array([t.reward for t in all_transitions]) # placeholder
-
- obs_t = torch.tensor(obs_arr, dtype=torch.float32, device=device)
- actions_t = torch.tensor(actions_arr, dtype=torch.long, device=device)
- old_log_probs_t = torch.tensor(old_log_probs_arr, dtype=torch.float32, device=device)
- masks_t = torch.tensor(masks_arr, dtype=torch.float32, device=device)
- returns_t = torch.tensor(returns_arr, dtype=torch.float32, device=device)
- advantages_t = torch.tensor(advantages_arr, dtype=torch.float32, device=device)
-
- # Normalize advantages
- if len(advantages_t) > 1:
- advantages_t = (advantages_t - advantages_t.mean()) / (advantages_t.std() + 1e-8)
-
- total_loss_sum = 0
- n_updates = 0
-
- for _ in range(epochs):
- indices = np.arange(len(all_transitions))
- np.random.shuffle(indices)
-
- for start in range(0, len(indices), batch_size):
- end = min(start + batch_size, len(indices))
- idx = indices[start:end]
-
- b_obs = obs_t[idx]
- b_actions = actions_t[idx]
- b_old_lp = old_log_probs_t[idx]
- b_masks = masks_t[idx]
- b_returns = returns_t[idx]
- b_advantages = advantages_t[idx]
-
- new_log_probs, values, entropy = model.evaluate(b_obs, b_masks, b_actions)
-
- ratio = torch.exp(new_log_probs - b_old_lp)
- surr1 = ratio * b_advantages
- surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * b_advantages
- policy_loss = -torch.min(surr1, surr2).mean()
-
- value_loss = F.mse_loss(values, b_returns)
-
- loss = policy_loss + vf_coef * value_loss - ent_coef * entropy.mean()
-
- optimizer.zero_grad()
- loss.backward()
- nn.utils.clip_grad_norm_(model.parameters(), 0.5)
- optimizer.step()
-
- total_loss_sum += loss.item()
- n_updates += 1
-
- return {"loss": total_loss_sum / max(n_updates, 1)}
-
-
# ---------------------------------------------------------------------------
# Greedy Warmup (Behavioral Cloning)
# ---------------------------------------------------------------------------
@@ -311,7 +283,9 @@ def train(args):
train_device = "cuda" if torch.cuda.is_available() else "cpu"
collect_device = "cpu" # env simulation always on CPU
print(f"Train device: {train_device}, Collect device: {collect_device}")
+ collect_batch = args.collect_batch if args.collect_batch is not None else args.update_every
print(f"Training for {args.num_players} players, {args.episodes} episodes")
+ print(f"Batch collection: {collect_batch} games per batch")
# Model lives on CPU for game collection; moves to GPU for PPO updates
model = PolicyValueNet().to(collect_device)
@@ -333,30 +307,44 @@ def train(args):
# Stats
win_counts = defaultdict(int)
- game_lengths = []
- batch_transitions = []
+ all_game_lengths = []
recent_loss = 0.0
recent_loss_count = 0
- pbar = tqdm(range(1, args.episodes + 1), desc="Training", unit="ep")
- for ep in pbar:
- env = BlazingEightsEnv(num_players=args.num_players)
- trajectories = collect_game(env, model, collect_device)
-
- # Record stats
- if env.done:
- win_counts[env.winner] += 1
- game_lengths.append(sum(len(v) for v in trajectories.values()))
-
- # Compute GAE per player and collect
- for player, trans_list in trajectories.items():
- returns, advantages = compute_gae(trans_list, gamma=args.gamma, lam=args.lam)
- for i, t in enumerate(trans_list):
- t.reward = returns[i] if i < len(returns) else t.reward
- batch_transitions.extend(trans_list)
-
- # Update every `update_every` episodes
- if ep % args.update_every == 0 and batch_transitions:
+ ep = 0
+ next_log = args.log_every
+ next_eval = args.eval_every
+ next_save = args.save_every
+ pbar = tqdm(total=args.episodes, desc="Training", unit="ep")
+
+ while ep < args.episodes:
+ games_this_batch = min(collect_batch, args.episodes - ep)
+
+ # Collect games in parallel with batched inference
+ envs, batch_trajectories = collect_games_batch(
+ games_this_batch, args.num_players, model, collect_device
+ )
+
+ # Process trajectories: compute GAE and collect all transitions
+ batch_transitions = []
+ for i in range(games_this_batch):
+ env = envs[i]
+ traj = batch_trajectories[i]
+ if env.done:
+ win_counts[env.winner] += 1
+ all_game_lengths.append(sum(len(v) for v in traj.values()))
+
+ for player, trans_list in traj.items():
+ returns, advantages = compute_gae(trans_list, gamma=args.gamma, lam=args.lam)
+ for k, t in enumerate(trans_list):
+ t.reward = returns[k] if k < len(returns) else t.reward
+ batch_transitions.extend(trans_list)
+
+ ep += games_this_batch
+ pbar.update(games_this_batch)
+
+ # PPO update
+ if batch_transitions:
returns_for_update = np.array([t.reward for t in batch_transitions])
values_for_update = np.array([t.value for t in batch_transitions])
advs = returns_for_update - values_for_update
@@ -380,15 +368,15 @@ def train(args):
if len(advs_t) > 1:
advs_t = (advs_t - advs_t.mean()) / (advs_t.std() + 1e-8)
- # PPO update
+ # PPO clipped surrogate update
batch_loss = 0.0
n_updates = 0
for _ in range(args.ppo_epochs):
- indices = np.arange(len(batch_transitions))
- np.random.shuffle(indices)
- for start in range(0, len(indices), args.batch_size):
- end = min(start + args.batch_size, len(indices))
- idx = indices[start:end]
+ perm = np.arange(len(batch_transitions))
+ np.random.shuffle(perm)
+ for start in range(0, len(perm), args.batch_size):
+ end = min(start + args.batch_size, len(perm))
+ idx = perm[start:end]
b_obs = obs_t[idx]
b_actions = actions_t[idx]
@@ -420,11 +408,9 @@ def train(args):
if train_device != collect_device:
model.to(collect_device)
- batch_transitions = []
-
# Logging
- if ep % args.log_every == 0:
- avg_len = np.mean(game_lengths[-args.log_every:]) if game_lengths else 0
+ if ep >= next_log:
+ avg_len = np.mean(all_game_lengths[-args.log_every:]) if all_game_lengths else 0
avg_loss = recent_loss / max(recent_loss_count, 1)
total_games = sum(win_counts.values())
wr0 = win_counts[0] / max(total_games, 1)
@@ -432,19 +418,21 @@ def train(args):
wr0=f"{wr0:.1%}", games=total_games)
recent_loss = 0.0
recent_loss_count = 0
+ next_log += args.log_every
- # Evaluate vs greedy + write log every eval_every episodes
- if ep % args.eval_every == 0:
- avg_len = np.mean(game_lengths[-args.eval_every:]) if game_lengths else 0
+ # Evaluate vs greedy + write log
+ if ep >= next_eval:
+ avg_len = np.mean(all_game_lengths[-args.eval_every:]) if all_game_lengths else 0
avg_loss_log = recent_loss / max(recent_loss_count, 1) if recent_loss_count > 0 else 0
- vs_wr = evaluate_vs_random(model, num_players=args.num_players,
- num_games=500, device=collect_device)
+ vs_wr = evaluate_vs_greedy_batch(model, num_players=args.num_players,
+ num_games=500, device=collect_device)
with open(log_path, "a") as f:
f.write(f"{ep},{avg_len:.1f},{avg_loss_log:.4f},{vs_wr:.4f}\n")
tqdm.write(f" [Eval ep{ep}] avg_len={avg_len:.1f} vs_greedy={vs_wr:.1%}")
+ next_eval += args.eval_every
# Save checkpoint
- if ep % args.save_every == 0:
+ if ep >= next_save:
path = f"{args.save_path}_ep{ep}.pt"
torch.save({
"model": model.state_dict(),
@@ -453,6 +441,7 @@ def train(args):
"num_players": args.num_players,
}, path)
tqdm.write(f" Saved checkpoint: {path}")
+ next_save += args.save_every
# Final save
torch.save({
@@ -468,29 +457,62 @@ def train(args):
# ---------------------------------------------------------------------------
-# Evaluation: play against random
+# Evaluation: play against greedy (batched)
# ---------------------------------------------------------------------------
-def evaluate_vs_random(model: PolicyValueNet, num_players=2, num_games=1000, device="cpu"):
- """Player 0 = model, others = random. Returns player 0 win rate."""
- wins = 0
- for _ in range(num_games):
- env = BlazingEightsEnv(num_players=num_players)
- obs = env.reset()
- for _ in range(500):
- player = env.current_player
- legal = env.legal_actions()
+def evaluate_vs_greedy_batch(model: PolicyValueNet, num_players=2, num_games=500, device="cpu"):
+ """Batched evaluation: player 0 = model, others = greedy random. Returns player 0's win rate over num_games."""
+ envs = [BlazingEightsEnv(num_players=num_players) for _ in range(num_games)]
+ obs_list = [env.reset() for env in envs]
+ active = set(range(num_games))
+
+ for _ in range(500):
+ if not active:
+ break
+
+ # Separate model-controlled (player 0) and greedy-controlled turns
+ model_idx = []
+ model_obs = []
+ model_masks = []
+ greedy_pairs = []
+
+ for i in sorted(active):
+ legal = envs[i].legal_actions()
if not legal:
- break
- if player == 0:
- action, _, _ = model.get_action(obs, legal, device)
+ active.discard(i)
+ continue
+ if envs[i].current_player == 0:
+ mask = np.zeros(TOTAL_ACTIONS, dtype=np.float32)
+ for a in legal:
+ mask[a] = 1.0
+ model_idx.append(i)
+ model_obs.append(obs_list[i])
+ model_masks.append(mask)
else:
- action = greedy_random_action(legal)
- obs, rewards, done, info = env.step(action)
+ greedy_pairs.append((i, greedy_random_action(legal)))
+
+ # Batched model inference for player 0 turns
+ if model_obs:
+ obs_t = torch.tensor(np.array(model_obs), dtype=torch.float32, device=device)
+ mask_t = torch.tensor(np.array(model_masks), dtype=torch.float32, device=device)
+ with torch.inference_mode():
+ logits, _ = model(obs_t, mask_t)
+ actions = Categorical(F.softmax(logits, dim=-1)).sample().cpu().numpy()
+ for j, i in enumerate(model_idx):
+ obs_next, _, done, _ = envs[i].step(int(actions[j]))
+ if done:
+ active.discard(i)
+ else:
+ obs_list[i] = obs_next
+
+ # Greedy actions for other players
+ for i, action in greedy_pairs:
+ obs_next, _, done, _ = envs[i].step(action)
if done:
- if env.winner == 0:
- wins += 1
- break
- return wins / num_games
+ active.discard(i)
+ else:
+ obs_list[i] = obs_next
+
+ return sum(1 for e in envs if e.done and e.winner == 0) / num_games
if __name__ == "__main__":
@@ -509,15 +531,17 @@ if __name__ == "__main__":
parser.add_argument("--eval_every", type=int, default=10000)
parser.add_argument("--save_every", type=int, default=10000)
parser.add_argument("--save_path", type=str, default="blazing_ppo")
+ parser.add_argument("--collect_batch", type=int, default=None,
+ help="Parallel game collection batch size (default: same as update_every)")
parser.add_argument("--greedy_warmup", type=int, default=2000,
help="Number of greedy games for behavioral cloning warmup (0 to skip)")
args = parser.parse_args()
model = train(args)
- # Eval vs random
- print("\nEvaluating vs random opponents...")
+ # Eval vs greedy
+ print("\nEvaluating vs greedy opponents...")
for n in [2, 3, 4, 5]:
- if n <= args.num_players + 1: # only eval for trained player count
- wr = evaluate_vs_random(model, num_players=n, num_games=1000)
+ if n <= args.num_players + 1:
+ wr = evaluate_vs_greedy_batch(model, num_players=n, num_games=1000)
print(f" {n} players: win rate = {wr:.1%} (random baseline: {1/n:.1%})")