diff options
| -rw-r--r-- | blazing_env.py | 41 | ||||
| -rw-r--r-- | train.py | 58 |
2 files changed, 85 insertions, 14 deletions
diff --git a/blazing_env.py b/blazing_env.py
index c440293..3f4b407 100644
--- a/blazing_env.py
+++ b/blazing_env.py
@@ -11,7 +11,7 @@ Special cards:
   K → All OTHER players draw 1 card from the deck
   Q → Reverse direction (no effect in 2-player games)
   J → Skip next player's turn
-  Swap → Swap entire hand with next player (playable anytime on your turn, no match needed)
+  Swap → Swap entire hand with next player (playable anytime; next card must match the card before the Swap)
 
 Rules:
   - Match top card by suit OR rank (unless playing 8 or Swap)
@@ -182,18 +182,19 @@ class BlazingEightsEnv:
         for c in self.hands[player]:
             obs[c] = 1.0
 
-        # Top card info
+        # Top card info (SWAP inherits previous card)
         top = self.discard[-1]
+        eff = self._effective_top() if is_swap(top) else top
         if self.active_suit is not None:
             suit = self.active_suit
-        elif not is_swap(top):
-            suit = card_suit(top)
+        elif not is_swap(eff):
+            suit = card_suit(eff)
         else:
             suit = 0
         obs[56 + suit] = 1.0
 
-        if not is_swap(top) and self.active_suit is None:
-            obs[60 + card_rank(top)] = 1.0
+        if not is_swap(eff) and self.active_suit is None:
+            obs[60 + card_rank(eff)] = 1.0
 
         # Direction
         obs[73] = 0.0 if self.direction == 1 else 1.0
@@ -279,6 +280,13 @@ class BlazingEightsEnv:
             actions.append(PASS_ACTION)
         return actions
 
+    def _effective_top(self) -> int:
+        """Find the last non-SWAP card in discard for matching purposes."""
+        for c in reversed(self.discard):
+            if not is_swap(c):
+                return c
+        return self.discard[-1] # fallback (all swaps, shouldn't happen)
+
     def _can_play(self, card: int, top: int) -> bool:
         # Swap cards: always playable
         if is_swap(card):
@@ -289,10 +297,11 @@ class BlazingEightsEnv:
         # If active_suit is set (after a wild 8), must match that suit
         if self.active_suit is not None:
             return card_suit(card) == self.active_suit
-        # Normal: match suit or rank
+        # SWAP on top: inherit previous non-SWAP card's suit/rank
         if is_swap(top):
-            # Top is swap — shouldn't happen in normal flow, but match anything
-            return True
+            top = self._effective_top()
+            if is_swap(top):
+                return True # all swaps, match anything
         return card_suit(card) == card_suit(top) or card_rank(card) == card_rank(top)
 
     # ------------------------------------------------------------------
@@ -329,7 +338,6 @@ class BlazingEightsEnv:
 
         # --- Play phase ---
         if action == DRAW_ACTION:
-            self.consecutive_passes = 0
             drawn = self._draw_card(player)
             self._record_event(player, 2) # drew a card
             # Card is added to hand; player keeps their turn to decide
@@ -338,8 +346,12 @@ class BlazingEightsEnv:
             obs = self._get_obs(player)
             return obs, rewards, False, info
         elif action == PASS_ACTION:
-            # Skip turn (no cards anywhere)
-            self.consecutive_passes += 1
+            if self.has_drawn_this_turn:
+                # Drew but chose not to play — game state changed, not stalemate
+                self.consecutive_passes = 0
+            else:
+                # Hard pass: can't draw and can't play — real stalemate signal
+                self.consecutive_passes += 1
             self._record_event(player, 3) # passed
             if self.consecutive_passes >= self.num_players:
                 # Stalemate: all players passed in a row
@@ -368,8 +380,9 @@ class BlazingEightsEnv:
         hand.remove(card)
         self.discard.append(card)
 
-        # Clear active suit (unless new card is 8)
-        self.active_suit = None
+        # Clear active suit (unless new card is 8 or SWAP — SWAP inherits)
+        if not is_swap(card):
+            self.active_suit = None
 
         # Clear swap knowledge for this player (cards change over time)
         # We keep it until they play; after playing, knowledge decays
diff --git a/train.py b/train.py
--- a/train.py
+++ b/train.py
@@ -254,6 +254,57 @@ def ppo_update(model: PolicyValueNet, optimizer: torch.optim.Optimizer,
 
 
 # ---------------------------------------------------------------------------
+# Greedy Warmup (Behavioral Cloning)
+# ---------------------------------------------------------------------------
+def greedy_warmup(model: PolicyValueNet, optimizer: torch.optim.Optimizer,
+                  num_players: int, num_games: int = 2000, epochs: int = 5,
+                  batch_size: int = 256, device: str = "cpu"):
+    """Pre-train the model to imitate greedy play (play if possible, else draw)."""
+    print(f"Greedy warmup: {num_games} games, {epochs} epochs...")
+    obs_list, action_list, mask_list = [], [], []
+
+    for _ in tqdm(range(num_games), desc="Collecting greedy data", unit="game"):
+        env = BlazingEightsEnv(num_players=num_players)
+        obs = env.reset()
+        for _ in range(500):
+            legal = env.legal_actions()
+            if not legal:
+                break
+            action = greedy_random_action(legal)
+            legal_mask = np.zeros(TOTAL_ACTIONS, dtype=np.float32)
+            for a in legal:
+                legal_mask[a] = 1.0
+            obs_list.append(obs.copy())
+            action_list.append(action)
+            mask_list.append(legal_mask)
+            obs, _, done, _ = env.step(action)
+            if done:
+                break
+
+    obs_t = torch.tensor(np.array(obs_list), dtype=torch.float32, device=device)
+    act_t = torch.tensor(np.array(action_list), dtype=torch.long, device=device)
+    mask_t = torch.tensor(np.array(mask_list), dtype=torch.float32, device=device)
+    print(f" Collected {len(obs_list)} transitions")
+
+    for epoch in range(epochs):
+        indices = np.arange(len(obs_list))
+        np.random.shuffle(indices)
+        total_loss = 0
+        n_batches = 0
+        for start in range(0, len(indices), batch_size):
+            idx = indices[start:start + batch_size]
+            logits, _ = model(obs_t[idx], mask_t[idx])
+            loss = F.cross_entropy(logits, act_t[idx])
+            optimizer.zero_grad()
+            loss.backward()
+            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+            optimizer.step()
+            total_loss += loss.item()
+            n_batches += 1
+        print(f" Epoch {epoch+1}/{epochs}: loss={total_loss/n_batches:.4f}")
+
+
+# ---------------------------------------------------------------------------
 # Training Loop
 # ---------------------------------------------------------------------------
 def train(args):
@@ -264,6 +315,11 @@ def train(args):
     model = PolicyValueNet().to(device)
     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
 
+    # Greedy warmup: imitate greedy play before self-play
+    if args.greedy_warmup > 0:
+        greedy_warmup(model, optimizer, args.num_players,
+                      num_games=args.greedy_warmup, device=device)
+
     # Stats
     win_counts = defaultdict(int)
     game_lengths = []
@@ -421,6 +477,8 @@ if __name__ == "__main__":
     parser.add_argument("--log_every", type=int, default=1000)
     parser.add_argument("--save_every", type=int, default=10000)
     parser.add_argument("--save_path", type=str, default="blazing_ppo")
+    parser.add_argument("--greedy_warmup", type=int, default=2000,
+                        help="Number of greedy games for behavioral cloning warmup (0 to skip)")
     args = parser.parse_args()
 
     model = train(args)
