summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorhaoyuren <13851610112@163.com>2026-02-22 01:48:03 -0600
committerhaoyuren <13851610112@163.com>2026-02-22 01:48:03 -0600
commit72cf72d704ca1a3bf4e2a5e04dcbbad99dc0f98e (patch)
tree55cb96c17a0a71bc3c7155d65fd19cc185bf495c
Initial commit: Blazing Eights RL agent
- Game environment with draw-then-decide rule (no auto-play on draw) - PPO self-play training script - Interactive human vs AI game (versus.py) - Real-time play assistant (play.py) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
-rw-r--r--.gitignore4
-rw-r--r--README.md67
-rw-r--r--blazing_env.py506
-rw-r--r--play.py270
-rw-r--r--train.py425
-rw-r--r--versus.py341
6 files changed, 1613 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..8947638
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+*.pt
+__pycache__/
+*.pyc
+.DS_Store
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..711f793
--- /dev/null
+++ b/README.md
@@ -0,0 +1,67 @@
+# Blazing Eights — RL Agent
+
+Self-play PPO agent for the Blazing Eights card game.
+
+## Setup
+
+```bash
+pip install torch numpy
+```
+
+## Files
+
- `blazing_env.py` — Game environment (2-5 players)
- `train.py` — PPO self-play training
- `versus.py` — Interactive human vs AI game
- `play.py` — Real-time play assistant (input game state, get best move)
+
+## Training
+
+```bash
+# Train a 2-player agent (~10min on CPU for 100k episodes)
+python train.py --num_players 2 --episodes 100000
+
+# Train for 3 players (may need more episodes)
+python train.py --num_players 3 --episodes 200000
+
+# Custom hyperparams
+python train.py --num_players 4 --episodes 300000 --lr 1e-4 --ent_coef 0.02
+```
+
+Training saves checkpoints every 10k episodes and a final model.
+
+## Real-time Play Assistant
+
+After training, use the assistant during a real game:
+
+```bash
+python play.py --model blazing_ppo_final.pt --num_players 3
+```
+
+It will prompt you for:
+1. Your hand (e.g., `8h,Ks,3d,SWAP`)
+2. Top discard card (e.g., `6d`)
+3. Active suit if an 8 was played
+4. Direction (cw/ccw)
+5. Other players' hand sizes
+6. Approximate deck size
+
+Then shows ranked action recommendations with probabilities.
+
+## Game Rules
+
+- **56 cards**: standard 52 + 4 Swap cards
+- **Match**: suit or rank of top card
+- **8**: Wild — choose a suit for next player
+- **K**: All other players draw 1
+- **Q**: Reverse direction (no effect in 2-player)
+- **J**: Skip next player
+- **Swap**: Swap your entire hand with next player (playable anytime, no match needed)
- **Can't play**: Draw 1, then decide — play any legal card or pass (the drawn card is never auto-played)
+- **Win**: First to empty hand
+
+## Tips for Better Training
+
+1. **Train per player count** — the optimal policy differs significantly for 2 vs 5 players.
+2. **Increase episodes for more players** — larger games have more variance, need more samples.
+3. **Opponent modeling** — after self-play, you can fine-tune against specific opponent behaviors by replacing some players with heuristic bots that mimic your friends' tendencies.
+4. **Curriculum** — start training with 2 players, then use the trained model to initialize training for 3+ players.
diff --git a/blazing_env.py b/blazing_env.py
new file mode 100644
index 0000000..c3d97ae
--- /dev/null
+++ b/blazing_env.py
@@ -0,0 +1,506 @@
+"""
+Blazing Eights — multi-agent card game environment.
+
+Cards:
+ 52 standard cards (4 suits × 13 ranks: A,2..10,J,Q,K)
+ + 4 Swap cards (no suit, index 52-55)
+ Total: 56 cards
+
+Special cards:
+ 8 → Wild: player chooses a suit, next player must match that suit
+ K → All OTHER players draw 1 card from the deck
+ Q → Reverse direction (no effect in 2-player games)
+ J → Skip next player's turn
+ Swap → Swap entire hand with next player (playable anytime on your turn, no match needed)
+
+Rules:
+ - Match top card by suit OR rank (unless playing 8 or Swap)
+ - Can't play → draw 1; if drawn card is playable, play it immediately
+ - First player to empty hand wins
+ - Initial hand: 5 cards each
+"""
+
+import numpy as np
+from typing import Optional
+
+# ---------------------------------------------------------------------------
+# Card encoding
+# ---------------------------------------------------------------------------
+# Standard cards: index 0-51
+# suit = index // 13 (0=♠, 1=♥, 2=♦, 3=♣)
+# rank = index % 13 (0=A, 1=2, 2=3, ..., 9=10, 10=J, 11=Q, 12=K)
+# Swap cards: index 52, 53, 54, 55
+
# Deck composition: 52 standard cards followed by 4 suitless Swap cards.
NUM_STANDARD = 52
NUM_SWAP = 4
NUM_CARDS = NUM_STANDARD + NUM_SWAP

# Rank indices (0=A, 1=2, ..., 9=10, 10=J, 11=Q, 12=K).
RANK_A, RANK_J, RANK_Q, RANK_K = 0, 10, 11, 12
RANK_8 = 7  # rank index for 8 (0=A,1=2,...,7=8)

# Display tables shared by card_name().
_SUIT_GLYPHS = "♠♥♦♣"
_RANK_LABELS = ["A"] + [str(v) for v in range(2, 11)] + ["J", "Q", "K"]


def card_suit(c: int) -> int:
    """Return suit of a standard card (0-3). Swap cards return -1."""
    return -1 if c >= NUM_STANDARD else c // 13


def card_rank(c: int) -> int:
    """Return rank of a standard card (0-12). Swap cards return -1."""
    return -1 if c >= NUM_STANDARD else c % 13


def is_swap(c: int) -> bool:
    """True for the four suitless Swap cards (indices 52-55)."""
    return c >= NUM_STANDARD


def card_name(c: int) -> str:
    """Human-readable card label, e.g. 'A♠', '10♦', 'K♣', 'SWAP-0'."""
    if is_swap(c):
        return f"SWAP-{c - NUM_STANDARD}"
    return f"{_RANK_LABELS[card_rank(c)]}{_SUIT_GLYPHS[card_suit(c)]}"
+
+
# ---------------------------------------------------------------------------
# Action space
# ---------------------------------------------------------------------------
# Actions 0-55: play card with that index
# Actions 56-59: choose suit after playing an 8 (♠,♥,♦,♣)
# Action 60: draw a card
# Action 61: pass / end turn (legal only after drawing, or when no card
#            can be played and no card can be drawn)
NUM_PLAY_ACTIONS = NUM_CARDS  # 0..55
NUM_SUIT_ACTIONS = 4  # 56..59
DRAW_ACTION = 60
PASS_ACTION = 61  # skip turn (when deck empty & no playable card)
TOTAL_ACTIONS = 62
+
+
class BlazingEightsEnv:
    """
    Multi-agent environment for Blazing Eights.

    Designed for self-play RL. Call step() with the current player's action.
    The env tracks whose turn it is.

    Turn flow: the current player either plays a card (actions 0-55), draws
    once (action 60) and then decides again, or passes (action 61). Playing
    an 8 switches the env into the "choose_suit" phase, where the same
    player must pick a suit (actions 56-59) before play advances.

    Rewards: winner +1, everyone else -1. On a stalemate (every player
    passes in a row) the player(s) with the fewest cards get +0.5, the
    rest -1.
    """

    def __init__(self, num_players: int = 2, seed: Optional[int] = None):
        # Rules (and the fixed 180-float observation layout) are defined
        # for 2-5 seats only.
        assert 2 <= num_players <= 5
        self.num_players = num_players
        self.rng = np.random.default_rng(seed)
        self.reset()

    # ------------------------------------------------------------------
    # Reset
    # ------------------------------------------------------------------
    def reset(self, seed: Optional[int] = None):
        # Reseeding is optional; otherwise the existing RNG stream continues.
        if seed is not None:
            self.rng = np.random.default_rng(seed)

        # Build & shuffle deck
        deck = list(range(NUM_CARDS))
        self.rng.shuffle(deck)

        # Deal 5 cards each
        self.hands: list[list[int]] = []
        idx = 0
        for _ in range(self.num_players):
            self.hands.append(sorted(deck[idx:idx + 5]))
            idx += 5

        # Find a non-special starting card for the discard pile
        # (avoid starting with 8, J, Q, K, or Swap)
        self.discard: list[int] = []
        start_card = None
        remaining = deck[idx:]
        for i, c in enumerate(remaining):
            if not is_swap(c) and card_rank(c) not in (RANK_8, RANK_J, RANK_Q, RANK_K):
                start_card = c
                remaining.pop(i)
                break
        if start_card is None:
            # Extremely unlikely; just use first card
            start_card = remaining.pop(0)
        self.discard.append(start_card)

        self.deck: list[int] = remaining
        self.current_player = int(self.rng.integers(0, self.num_players))
        self.direction = 1  # 1=clockwise, -1=counter-clockwise
        self.done = False
        self.winner = -1

        # State for wild-8: the chosen suit (None if top card is not a wild)
        self.active_suit: Optional[int] = None

        # Phase: "play" or "choose_suit"
        self.phase = "play"
        # Temp storage for the card that triggered choose_suit
        self._pending_8_player: Optional[int] = None

        # For K resolution. NOTE(review): this flag is never set to True
        # anywhere — K resolves immediately in _play_card — so the
        # choose_suit branch of step() that checks it is currently dead.
        self._pending_k = False

        # Track consecutive passes for stalemate detection
        self.consecutive_passes = 0

        # Track whether current player has already drawn this turn
        self.has_drawn_this_turn = False

        # Action history: records recent events visible to all players
        # Each entry: (player, event_type)
        # event_type: 0=played_card, 1=(unused), 2=drew_card, 3=passed
        self.action_history: list[tuple[int, int]] = []
        self.max_history = 20  # keep last 20 events

        # Track swap visibility: after a swap, the swapper sees the received cards
        # This is informational; the obs encodes it
        self.swap_known_cards: dict[int, list[int]] = {}  # player -> known opponent cards

        return self._get_obs(self.current_player)

    # ------------------------------------------------------------------
    # Observation
    # ------------------------------------------------------------------
    def _get_obs(self, player: int) -> np.ndarray:
        """
        Observation vector for `player`:
          [0:56]    one-hot of cards in hand
          [56:60]   top card suit one-hot (or active_suit if wild)
          [60:73]   top card rank one-hot
          [73]      direction (0=cw, 1=ccw)
          [74:74+N-1]  other players' hand sizes (normalized /20)
          [74+N-1]  deck size (normalized /56)
          [75+N-1]  phase: 0=play, 1=choose_suit
          [76+N-1 : 132+N-1]  known cards of next player (from swap), 56 one-hot
          [132+N-1 : 132+N-1+(N-1)*5]  per other player draw info:
              4 floats: last event one-hot (played/drew_played/drew_skipped/passed)
              1 float: consecutive draw-and-skip streak (/10)
        Padded to fixed 180.

        NOTE(review): this env only ever records event types 0, 2 and 3
        (see _record_event call sites); slot 1 ("drew_played") is never
        written during self-play — confirm whether it is meant for
        externally supplied observations.
        """
        obs = np.zeros(180, dtype=np.float32)

        # Hand
        for c in self.hands[player]:
            obs[c] = 1.0

        # Top card info: after a wild 8, the chosen suit overrides the
        # card's printed suit; a (shouldn't-happen) swap on top defaults to 0.
        top = self.discard[-1]
        if self.active_suit is not None:
            suit = self.active_suit
        elif not is_swap(top):
            suit = card_suit(top)
        else:
            suit = 0
        obs[56 + suit] = 1.0

        # Rank is suppressed while a chosen suit is active (wild in effect).
        if not is_swap(top) and self.active_suit is None:
            obs[60 + card_rank(top)] = 1.0

        # Direction
        obs[73] = 0.0 if self.direction == 1 else 1.0

        # Other players' hand sizes, in seat order starting from the seat
        # after `player`.
        for i in range(1, self.num_players):
            other = (player + i) % self.num_players
            obs[74 + i - 1] = len(self.hands[other]) / 20.0

        # Deck size
        obs[74 + self.num_players - 1] = len(self.deck) / 56.0

        # Phase
        obs[75 + self.num_players - 1] = 1.0 if self.phase == "choose_suit" else 0.0

        # Known cards of next player (from swap)
        offset = 76 + self.num_players - 1
        if player in self.swap_known_cards:
            for c in self.swap_known_cards[player]:
                obs[offset + c] = 1.0

        # Per other player: last event type + consecutive draw-skip streak
        # This encodes the "timer tells" — draw then skip means drawn card unplayable
        draw_info_offset = 132 + self.num_players - 1
        for i in range(1, self.num_players):
            other = (player + i) % self.num_players
            base = draw_info_offset + (i - 1) * 5

            # Scan history backwards for this player's events: note the most
            # recent one, and count how many consecutive draws precede it.
            last_event = None
            consec_draw_skip = 0
            for p, evt in reversed(self.action_history):
                if p == other:
                    if last_event is None:
                        last_event = evt
                    if evt == 2:  # drew_and_skipped
                        consec_draw_skip += 1
                    else:
                        break

            if last_event is not None:
                obs[base + last_event] = 1.0
            obs[base + 4] = consec_draw_skip / 10.0

        return obs

    @staticmethod
    def obs_size() -> int:
        # Fixed length of the vector built by _get_obs.
        return 180

    # ------------------------------------------------------------------
    # Legal actions
    # ------------------------------------------------------------------
    def legal_actions(self, player: Optional[int] = None) -> list[int]:
        """Return the list of legal action ids for `player` (default: mover)."""
        if self.done:
            return []
        if player is None:
            player = self.current_player

        # During choose_suit, only the player who played the 8 may act,
        # and only by picking a suit.
        if self.phase == "choose_suit":
            if player == self._pending_8_player:
                return [56, 57, 58, 59]
            else:
                return []

        actions = []
        hand = self.hands[player]
        top = self.discard[-1]

        for c in hand:
            if self._can_play(c, top):
                actions.append(c)

        if self.has_drawn_this_turn:
            # Already drew this turn: can play a card or pass (end turn)
            actions.append(PASS_ACTION)
        else:
            # Can always choose to draw instead of playing
            # (a reshuffle of the discard pile can also supply the card).
            if self.deck or len(self.discard) > 1:
                actions.append(DRAW_ACTION)
            if not actions:
                # No playable cards and no deck: must pass
                actions.append(PASS_ACTION)
        return actions

    def _can_play(self, card: int, top: int) -> bool:
        """True if `card` may legally be played on `top` right now."""
        # Swap cards: always playable
        if is_swap(card):
            return True
        # 8s: always playable (wild)
        if card_rank(card) == RANK_8:
            return True
        # If active_suit is set (after a wild 8), must match that suit
        # (rank matching is NOT allowed against a wild).
        if self.active_suit is not None:
            return card_suit(card) == self.active_suit
        # Normal: match suit or rank
        if is_swap(top):
            # Top is swap — shouldn't happen in normal flow, but match anything
            return True
        return card_suit(card) == card_suit(top) or card_rank(card) == card_rank(top)

    # ------------------------------------------------------------------
    # Step
    # ------------------------------------------------------------------
    def step(self, action: int):
        """
        Apply `action` for the current player.

        Returns (obs_next_player, reward_dict, done, info)
        reward_dict: {player_id: reward}

        Raises AssertionError if the game is over or the action is illegal.
        """
        assert not self.done, "Game is over"

        player = self.current_player
        legal = self.legal_actions(player)
        assert action in legal, f"Illegal action {action}. Legal: {legal}"

        info = {}
        rewards = {i: 0.0 for i in range(self.num_players)}

        # --- Choose suit phase ---
        if self.phase == "choose_suit":
            self.active_suit = action - 56
            self.phase = "play"

            # Now resolve K if pending
            # (see NOTE in reset: _pending_k is currently never True)
            if self._pending_k:
                self._resolve_k(player)
                self._pending_k = False

            # Advance to next player
            self._advance_turn()
            obs = self._get_obs(self.current_player)
            return obs, rewards, False, info

        # --- Play phase ---
        if action == DRAW_ACTION:
            self.consecutive_passes = 0
            # `drawn` may be None when deck + discard are exhausted; the
            # turn still counts as "has drawn", which makes PASS legal.
            drawn = self._draw_card(player)
            self._record_event(player, 2)  # drew a card
            # Card is added to hand; player keeps their turn to decide
            # whether to play it (or any other card) or pass
            self.has_drawn_this_turn = True
            obs = self._get_obs(player)
            return obs, rewards, False, info
        elif action == PASS_ACTION:
            # Skip turn (no cards anywhere)
            self.consecutive_passes += 1
            self._record_event(player, 3)  # passed
            if self.consecutive_passes >= self.num_players:
                # Stalemate: all players passed in a row
                self.done = True
                self.winner = -1  # no winner
                # Player with fewest cards gets partial reward
                min_cards = min(len(h) for h in self.hands)
                for i in range(self.num_players):
                    if len(self.hands[i]) == min_cards:
                        rewards[i] = 0.5
                    else:
                        rewards[i] = -1.0
                obs = self._get_obs(player)
                return obs, rewards, True, {"stalemate": True}
            self._advance_turn()
            obs = self._get_obs(self.current_player)
            return obs, rewards, False, info
        else:
            return self._play_card(player, action, rewards, info)

    def _play_card(self, player: int, card: int, rewards: dict, info: dict):
        """Move `card` from hand to discard and resolve its effect.

        Note: the win check happens BEFORE special effects, so a player
        going out on a K/Q/J/8/Swap ends the game without that card's
        effect (draws, swap, suit choice) being applied.
        """
        self.consecutive_passes = 0
        self._record_event(player, 0)  # played_card
        hand = self.hands[player]
        assert card in hand, f"Card {card} not in hand of player {player}"
        hand.remove(card)
        self.discard.append(card)

        # Clear active suit (unless new card is 8)
        self.active_suit = None

        # Swap knowledge retention: swap_known_cards deliberately persists
        # until overwritten by a later swap; it is not decayed per turn.

        # Check win
        if len(hand) == 0:
            self.done = True
            self.winner = player
            rewards[player] = 1.0
            for i in range(self.num_players):
                if i != player:
                    rewards[i] = -1.0
            obs = self._get_obs(player)
            return obs, rewards, True, {"winner": player}

        # Handle special cards
        if is_swap(card):
            self._resolve_swap(player)
            self._advance_turn()
        elif card_rank(card) == RANK_8:
            # Need to choose suit; same player acts again in choose_suit phase.
            self.phase = "choose_suit"
            self._pending_8_player = player
            obs = self._get_obs(player)
            return obs, rewards, False, info
        elif card_rank(card) == RANK_K:
            # All other players draw 1
            self._resolve_k(player)
            self._advance_turn()
        elif card_rank(card) == RANK_Q:
            # Reverse direction (no effect in 2-player)
            if self.num_players > 2:
                self.direction *= -1
            self._advance_turn()
        elif card_rank(card) == RANK_J:
            # Skip next player
            self._advance_turn()  # skip
            self._advance_turn()  # to the one after
        else:
            self._advance_turn()

        obs = self._get_obs(self.current_player)
        return obs, rewards, False, info

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _advance_turn(self):
        """Move to the next seat in the current direction; reset draw flag."""
        self.has_drawn_this_turn = False
        self.current_player = (self.current_player + self.direction) % self.num_players

    def _draw_card(self, player: int) -> Optional[int]:
        """Pop a card to `player`'s hand; reshuffle discard if needed.

        Returns the drawn card index, or None when no card exists anywhere.
        """
        if not self.deck:
            self._reshuffle_discard()
        if not self.deck:
            return None  # No cards left anywhere
        card = self.deck.pop()
        self.hands[player].append(card)
        return card

    def _reshuffle_discard(self):
        """Reshuffle all but the top card of the discard pile into the deck."""
        if len(self.discard) <= 1:
            return
        top = self.discard[-1]
        self.deck = self.discard[:-1]
        self.discard = [top]
        self.rng.shuffle(self.deck)

    def _resolve_k(self, player: int):
        """All players except `player` draw 1 card."""
        for i in range(self.num_players):
            if i != player:
                self._draw_card(i)

    def _record_event(self, player: int, event_type: int):
        """Record a visible game event (bounded ring of max_history entries)."""
        self.action_history.append((player, event_type))
        if len(self.action_history) > self.max_history:
            self.action_history.pop(0)

    def _resolve_swap(self, player: int):
        """Swap hands with the next player."""
        next_player = (player + self.direction) % self.num_players
        self.hands[player], self.hands[next_player] = self.hands[next_player], self.hands[player]
        # After swap, `player` now has what `next_player` had → player knows these cards
        # And `next_player` now has what `player` had → next_player knows these cards
        self.swap_known_cards[player] = list(self.hands[player])
        self.swap_known_cards[next_player] = list(self.hands[next_player])

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------
    def render(self):
        """Print a human-readable snapshot of the full game state (all hands)."""
        print(f"--- Turn: Player {self.current_player} | Direction: {'→' if self.direction == 1 else '←'} ---")
        top = self.discard[-1]
        suit_names = ["♠", "♥", "♦", "♣"]
        top_str = card_name(top)
        if self.active_suit is not None:
            top_str += f" (active suit: {suit_names[self.active_suit]})"
        print(f"Top card: {top_str}")
        for i in range(self.num_players):
            hand_str = ", ".join(card_name(c) for c in sorted(self.hands[i]))
            marker = " ◀" if i == self.current_player else ""
            print(f"  Player {i}: [{len(self.hands[i])}] {hand_str}{marker}")
        print(f"Deck: {len(self.deck)} cards")
        if self.done:
            print(f"🏆 Player {self.winner} wins!")

    def copy(self):
        """Return a deep copy of the environment state."""
        import copy
        return copy.deepcopy(self)
+
+
+# ---------------------------------------------------------------------------
+# Quick test
+# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Smoke test: run a random-policy game and print every move.
    env = BlazingEightsEnv(num_players=3, seed=42)
    env.render()
    print()

    def _describe(a):
        # Human-readable label for any action id.
        if a < NUM_CARDS:
            return card_name(a)
        if a < DRAW_ACTION:
            return "suit " + "♠♥♦♣"[a - 56]
        return "DRAW" if a == DRAW_ACTION else "PASS"

    for _ in range(200):
        mover = env.current_player
        options = env.legal_actions()
        if not options:
            break
        choice = env.rng.choice(options)
        print(f"Player {mover} plays: {_describe(choice)}")
        _, _, finished, _ = env.step(choice)
        if finished:
            env.render()
            break
    else:
        print("Game didn't finish in 200 steps")
diff --git a/play.py b/play.py
new file mode 100644
index 0000000..51d5864
--- /dev/null
+++ b/play.py
@@ -0,0 +1,270 @@
+"""
+Blazing Eights — Real-time Play Assistant.
+
+Load a trained model and get recommended actions during a real game.
+You input the game state, it tells you the best move.
+
+Usage:
+ python play.py --model blazing_ppo_final.pt --num_players 3
+"""
+
import argparse
import torch
import torch.nn.functional as F
import numpy as np
from blazing_env import (
    BlazingEightsEnv, card_name, card_suit, card_rank,
    is_swap, RANK_8, RANK_J, RANK_Q, RANK_K, NUM_STANDARD, NUM_CARDS,
    TOTAL_ACTIONS, DRAW_ACTION, PASS_ACTION
)

# PolicyValueNet is defined in train.py, NOT in blazing_env — importing it
# from blazing_env raised ImportError and made this script unrunnable.
# (The sys.path hack was unnecessary: Python already puts the script's
# directory on sys.path.)
from train import PolicyValueNet
+
+
SUIT_NAMES = ["♠ spades", "♥ hearts", "♦ diamonds", "♣ clubs"]
SUIT_SHORT = ["s", "h", "d", "c"]
RANK_NAMES = ["A", "2", "3", "4", "5", "6", "7", "8", "9", "10", "J", "Q", "K"]

# Lookup tables built once at import time (the originals were rebuilt on
# every parse_card call).
_RANK_MAP = {r: i for i, r in enumerate(RANK_NAMES)}
_SUIT_MAP = {"s": 0, "h": 1, "d": 2, "c": 3,
             "♠": 0, "♥": 1, "♦": 2, "♣": 3}


def parse_card(s: str) -> int:
    """Parse a card string like '8h', 'Ks', 'SWAP', '10d' into a card index.

    Every swap card maps to index 52 — the four swap cards are
    interchangeable, so the caller need not distinguish them.

    Raises:
        ValueError: if the string (including empty input) is not a
            recognizable card.
    """
    s = s.strip().upper()
    # Any "SW..." spelling counts as a swap card (covers "SWAP" too —
    # the original had two redundant branches for this).
    if s.startswith("SW"):
        return 52

    # Split into rank and suit; "10" is the only two-character rank.
    if s.startswith("10"):
        rank_str, suit_str = "10", s[2:].lower()
    else:
        rank_str, suit_str = s[:1], s[1:].lower()

    if rank_str not in _RANK_MAP or suit_str not in _SUIT_MAP:
        raise ValueError(f"Cannot parse card: {s}")

    # Card index = suit * 13 + rank (matches blazing_env's encoding).
    return _SUIT_MAP[suit_str] * 13 + _RANK_MAP[rank_str]
+
+
def build_obs_from_input(hand: list[int], top_card: int, active_suit: int | None,
                         direction: int, other_hand_sizes: list[int],
                         deck_size: int, num_players: int,
                         known_opponent_cards: list[int] | None = None,
                         other_last_events: list[int] | None = None,
                         other_draw_streaks: list[int] | None = None) -> np.ndarray:
    """Build observation vector from manual game state input.

    The 180-float layout must stay in exact parity with
    BlazingEightsEnv._get_obs — the trained network depends on it:
    hand one-hot [0:56], suit [56:60], rank [60:73], direction [73],
    opponent hand sizes from [74], deck size, phase, known opponent
    cards, then per-opponent (last-event one-hot + draw streak) blocks.

    Args:
        hand: card indices (0-55) currently held.
        top_card: index of the top discard card.
        active_suit: suit chosen after a wild 8, or None.
        direction: 1 = clockwise, -1 = counter-clockwise.
        other_hand_sizes: hand sizes in seat order after the player.
        deck_size: (approximate) number of cards left in the deck.
        num_players: total seats, 2-5.
        known_opponent_cards: cards known via a swap, if any.
        other_last_events: per-opponent last event code (-1 = unknown).
            NOTE(review): codes here come from interactive_loop's p/d/s
            mapping — confirm they line up with the env's event codes.
        other_draw_streaks: per-opponent consecutive draw-skip counts.
    """
    obs = np.zeros(180, dtype=np.float32)

    # Hand
    for c in hand:
        obs[c] = 1.0

    # Top card suit (the chosen suit overrides a wild 8's printed suit)
    if active_suit is not None:
        suit = active_suit
    elif not is_swap(top_card):
        suit = card_suit(top_card)
    else:
        suit = 0
    obs[56 + suit] = 1.0

    # Top card rank (suppressed while a wild suit is in effect)
    if not is_swap(top_card) and active_suit is None:
        obs[60 + card_rank(top_card)] = 1.0

    # Direction
    obs[73] = 0.0 if direction == 1 else 1.0

    # Other players' hand sizes
    for i, sz in enumerate(other_hand_sizes):
        obs[74 + i] = sz / 20.0

    # Deck size
    obs[74 + num_players - 1] = deck_size / 56.0

    # Phase (always play in interactive mode)
    obs[75 + num_players - 1] = 0.0

    # Known opponent cards
    if known_opponent_cards:
        offset = 76 + num_players - 1
        for c in known_opponent_cards:
            obs[offset + c] = 1.0

    # Per other player draw info (5 slots each: 4 event one-hot + streak)
    draw_info_offset = 132 + num_players - 1
    if other_last_events:
        for i, evt in enumerate(other_last_events):
            if evt >= 0:
                obs[draw_info_offset + i * 5 + evt] = 1.0
    if other_draw_streaks:
        for i, streak in enumerate(other_draw_streaks):
            obs[draw_info_offset + i * 5 + 4] = streak / 10.0

    return obs
+
+
def get_recommendations(model: PolicyValueNet, obs: np.ndarray, hand: list[int],
                        top_card: int, active_suit: int | None, device="cpu"):
    """Rank the legal actions by policy probability; also return the value."""

    def _playable(c: int) -> bool:
        # Swaps and wild 8s are always legal; with a chosen suit in effect
        # only that suit matches; otherwise match suit or rank of the top.
        if is_swap(c) or card_rank(c) == RANK_8:
            return True
        if active_suit is not None:
            return card_suit(c) == active_suit
        if is_swap(top_card):
            return False
        return card_suit(c) == card_suit(top_card) or card_rank(c) == card_rank(top_card)

    legal = [c for c in hand if _playable(c)]
    if not legal:
        # Caller should check deck; in practice use env's legal_actions
        legal = [DRAW_ACTION]

    # Run the network once with a mask allowing only the legal actions.
    obs_t = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
    mask = torch.zeros(1, TOTAL_ACTIONS, device=device)
    mask[0, legal] = 1.0

    with torch.no_grad():
        logits, value = model.forward(obs_t, mask)
        probs = F.softmax(logits, dim=-1).squeeze(0).cpu().numpy()

    def _label(a: int) -> str:
        # Human-readable action name.
        if a == DRAW_ACTION:
            return "DRAW"
        if a >= NUM_CARDS:
            return f"Choose {SUIT_NAMES[a - 56]}"
        return card_name(a)

    # Highest probability first (stable, so ties keep hand order).
    ranked = sorted(((a, _label(a), probs[a]) for a in legal),
                    key=lambda item: -item[2])
    return ranked, value.item()
+
+
def interactive_loop(model_path: str, num_players: int):
    """Prompt-driven assistant: read a game state, print ranked actions.

    Loops until the user types 'quit' or hits Ctrl-C. Parse errors on any
    prompt restart the turn rather than crashing.
    """
    device = "cpu"

    # Load model (weights_only=True restricts torch.load to tensor data)
    model = PolicyValueNet()
    checkpoint = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint["model"])
    model.eval()
    print(f"Loaded model from {model_path}")
    print(f"Trained for {checkpoint.get('episode', '?')} episodes, "
          f"{checkpoint.get('num_players', '?')} players")
    print()

    print("=" * 60)
    print("  Blazing Eights — Play Assistant")
    print("  Card format: rank+suit (e.g., 8h, Ks, 10d, Ac, SWAP)")
    print("  Type 'quit' to exit")
    print("=" * 60)

    while True:
        print("\n--- New Turn ---")
        try:
            # Hand
            hand_str = input("Your hand (comma-separated, e.g., 8h,Ks,3d,SWAP): ").strip()
            if hand_str.lower() == "quit":
                break
            hand = [parse_card(c) for c in hand_str.split(",")]

            # Top card
            top_str = input("Top card on discard pile: ").strip()
            top_card = parse_card(top_str)

            # Active suit (only asked when the top card is an 8)
            active_suit = None
            if card_rank(top_card) == RANK_8:
                suit_str = input("Active suit (s/h/d/c): ").strip().lower()
                suit_map = {"s": 0, "h": 1, "d": 2, "c": 3}
                active_suit = suit_map.get(suit_str)

            # Direction (anything other than "ccw" means clockwise)
            dir_str = input("Direction (cw/ccw) [cw]: ").strip().lower()
            direction = -1 if dir_str == "ccw" else 1

            # Other players' hand sizes
            sizes_str = input(f"Other players' hand sizes (comma-sep, {num_players-1} values): ").strip()
            other_sizes = [int(x) for x in sizes_str.split(",")]

            # Deck size estimate
            deck_str = input("Approximate deck size [20]: ").strip()
            deck_size = int(deck_str) if deck_str else 20

            # Draw info for other players
            # p=played from hand, d=drew and played, s=drew and skipped, ?=unknown
            # NOTE(review): the env itself records 0=played, 2=drew, 3=passed;
            # the d→1 mapping here produces a feature the net never saw in
            # self-play training — confirm the intended code alignment.
            event_str = input(f"Other players' last action ({num_players-1} values, p/d/s/?): ").strip().lower()
            event_map = {"p": 0, "d": 1, "s": 2, "?": -1, "": -1}
            other_events = None
            if event_str:
                other_events = [event_map.get(x.strip(), -1) for x in event_str.split(",")]

            streak_str = input(f"Other players' consecutive draw-skip count ({num_players-1} values) [0s]: ").strip()
            other_streaks = None
            if streak_str:
                other_streaks = [int(x) for x in streak_str.split(",")]

            # Build obs and get recommendation
            obs = build_obs_from_input(
                hand, top_card, active_suit, direction,
                other_sizes, deck_size, num_players,
                other_last_events=other_events,
                other_draw_streaks=other_streaks,
            )
            ranked, value = get_recommendations(model, obs, hand, top_card, active_suit, device)

            # Map value to a win-probability display — assumes the value
            # head lives in [-1, 1] (matching the ±1 terminal rewards).
            print(f"\n  Win probability estimate: {(value + 1) / 2:.1%}")
            print("  Recommended actions:")
            for i, (action, name, prob) in enumerate(ranked):
                bar = "█" * int(prob * 30)
                print(f"  {'→' if i == 0 else ' '} {name:<12s} {prob:6.1%} {bar}")

            # If best action is an 8, also show suit recommendation
            if ranked and ranked[0][0] < NUM_CARDS and card_rank(ranked[0][0]) == RANK_8:
                print("\n  If you play 8, recommended suit:")
                # Quick eval for each suit: re-run the value head on an obs
                # with the candidate suit forced active.
                for suit_idx in range(4):
                    temp_obs = obs.copy()
                    # Set active suit
                    temp_obs[56:60] = 0
                    temp_obs[56 + suit_idx] = 1.0
                    temp_obs[60:73] = 0  # clear rank (wild)
                    obs_t = torch.tensor(temp_obs, dtype=torch.float32).unsqueeze(0)
                    # All-ones mask is fine here: only the value output is
                    # used, and the mask affects the policy logits only.
                    mask = torch.ones(1, TOTAL_ACTIONS)  # dummy
                    with torch.no_grad():
                        _, v = model.forward(obs_t, mask)
                    print(f"    {SUIT_NAMES[suit_idx]}: estimated value {v.item():.3f}")

        except (ValueError, IndexError) as e:
            print(f"  Error: {e}. Try again.")
        except KeyboardInterrupt:
            break

    print("Goodbye!")
+
+
if __name__ == "__main__":
    # CLI entry point: load a checkpoint and start the interactive assistant.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", type=str, required=True, help="Path to trained model .pt file")
    cli.add_argument("--num_players", type=int, default=2)
    opts = cli.parse_args()
    interactive_loop(opts.model, opts.num_players)
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..96affb2
--- /dev/null
+++ b/train.py
@@ -0,0 +1,425 @@
+"""
+PPO Self-Play Training for Blazing Eights.
+
+Architecture:
+ - Single policy network shared across all seats
+ - Self-play: all players use the same (latest) policy
+ - Collect trajectories by running full games
+ - Standard PPO update with masked invalid actions
+
+Usage:
+ python train.py --num_players 2 --episodes 100000 --save_path model.pt
+ python train.py --num_players 3 --episodes 200000
+"""
+
+import argparse
+import os
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributions import Categorical
+from collections import defaultdict
+
+from blazing_env import BlazingEightsEnv, TOTAL_ACTIONS, NUM_CARDS, DRAW_ACTION, PASS_ACTION
+
+
+# ---------------------------------------------------------------------------
+# Policy + Value Network
+# ---------------------------------------------------------------------------
class PolicyValueNet(nn.Module):
    """Actor-critic MLP: a shared torso feeding separate policy and value heads.

    Illegal actions are pushed to a huge negative logit before softmax so
    they receive ~zero probability. Attribute names and layer layout must
    stay stable — saved checkpoints address parameters through them.
    """

    def __init__(self, obs_size: int = 180, action_size: int = TOTAL_ACTIONS, hidden: int = 256):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(obs_size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )
        self.policy_head = nn.Sequential(
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(),
            nn.Linear(hidden // 2, action_size),
        )
        self.value_head = nn.Sequential(
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(),
            nn.Linear(hidden // 2, 1),
        )

    def forward(self, obs: torch.Tensor, legal_mask: torch.Tensor):
        """Return (masked logits, state value).

        obs: (B, obs_size); legal_mask: (B, action_size), 1 = legal, 0 = illegal.
        """
        features = self.shared(obs)
        raw_logits = self.policy_head(features)
        # Subtracting 1e9 where the mask is 0 is arithmetically identical to
        # adding (1 - mask) * (-1e9): illegal actions get ~zero probability.
        masked_logits = raw_logits - (1 - legal_mask) * 1e9
        state_value = self.value_head(features).squeeze(-1)
        return masked_logits, state_value

    def get_action(self, obs: np.ndarray, legal_actions: list[int], device="cpu"):
        """Sample one action; returns (action, log_prob, value) as Python scalars."""
        obs_batch = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
        legal_mask = torch.zeros(1, TOTAL_ACTIONS, device=device)
        legal_mask[0, legal_actions] = 1.0
        with torch.no_grad():
            masked_logits, state_value = self.forward(obs_batch, legal_mask)
        dist = Categorical(F.softmax(masked_logits, dim=-1))
        choice = dist.sample()
        return choice.item(), dist.log_prob(choice).item(), state_value.item()

    def evaluate(self, obs_t: torch.Tensor, mask_t: torch.Tensor, actions_t: torch.Tensor):
        """Re-evaluate stored (obs, action) pairs for the PPO update.

        Returns (log_probs, values, entropy), each batched over dim 0.
        """
        masked_logits, values = self.forward(obs_t, mask_t)
        dist = Categorical(F.softmax(masked_logits, dim=-1))
        return dist.log_prob(actions_t), values, dist.entropy()
+
+
+# ---------------------------------------------------------------------------
+# Trajectory Collection
+# ---------------------------------------------------------------------------
class Transition:
    """One environment step recorded for PPO: observation, chosen action,
    behavior-policy log-prob, value estimate, reward (later overwritten
    with the GAE return), terminal flag, and the legal-action mask."""

    __slots__ = ["obs", "action", "log_prob", "value", "reward", "done", "legal_mask"]

    def __init__(self, obs, action, log_prob, value, reward, done, legal_mask):
        # Assign every field in slot order; keeps the ctor in sync with __slots__.
        for slot, val in zip(self.__slots__,
                             (obs, action, log_prob, value, reward, done, legal_mask)):
            setattr(self, slot, val)
+
+
def collect_game(env: BlazingEightsEnv, model: PolicyValueNet, device="cpu"):
    """
    Play one full game and return per-player trajectories.

    Every seat is controlled by the same `model` (self-play). Each step is
    stored under the acting player's seat index. When the game ends, the
    terminal reward is also written into every *other* player's last stored
    transition (marked done), since those players never act again.

    Returns: dict mapping seat index -> list[Transition] (defaultdict).
    """
    obs = env.reset()
    trajectories: dict[int, list[Transition]] = defaultdict(list)
    max_steps = 500  # hard cap so a stalemated game cannot loop forever

    for _ in range(max_steps):
        player = env.current_player
        legal = env.legal_actions()
        if not legal:
            # No legal action at all: treat the game as dead and stop collecting.
            break

        action, log_prob, value = model.get_action(obs, legal, device)

        # Build legal mask — stored so the PPO update can re-mask logits later.
        legal_mask = np.zeros(TOTAL_ACTIONS, dtype=np.float32)
        for a in legal:
            legal_mask[a] = 1.0

        obs_next, rewards, done, info = env.step(action)

        # Store transition for the acting player (obs captured *before* the step).
        trajectories[player].append(Transition(
            obs=obs.copy(),
            action=action,
            log_prob=log_prob,
            value=value,
            reward=rewards[player],
            done=done,
            legal_mask=legal_mask,
        ))

        # If done, also assign terminal rewards to other players' last transitions
        if done:
            for p in range(env.num_players):
                if p != player and trajectories[p]:
                    trajectories[p][-1].reward = rewards[p]
                    trajectories[p][-1].done = True
            break

        obs = obs_next

    return trajectories
+
+
+# ---------------------------------------------------------------------------
+# PPO Update
+# ---------------------------------------------------------------------------
def compute_gae(transitions: list[Transition], gamma=0.99, lam=0.95):
    """Compute per-step returns and GAE(lambda) advantages for one trajectory.

    Walks the trajectory backwards accumulating the exponentially-weighted
    TD residual; bootstrap value is zero at the trajectory end or at a
    terminal step. Returns (returns, advantages) as plain Python lists.
    """
    n = len(transitions)
    if n == 0:
        return [], []

    adv = np.zeros(n, dtype=np.float32)
    running = 0.0

    for i in range(n - 1, -1, -1):
        tr = transitions[i]
        nonterminal = 1 - tr.done
        # Bootstrap from the next step's value unless this is the last or a
        # terminal transition.
        bootstrap = 0.0 if (i == n - 1 or tr.done) else transitions[i + 1].value
        delta = tr.reward + gamma * bootstrap * nonterminal - tr.value
        running = delta + gamma * lam * nonterminal * running
        adv[i] = running

    rets = adv + np.array([tr.value for tr in transitions])
    return rets.tolist(), adv.tolist()
+
+
def ppo_update(model: PolicyValueNet, optimizer: torch.optim.Optimizer,
               all_transitions: list[Transition], device="cpu",
               epochs=4, batch_size=256, clip_eps=0.2, vf_coef=0.5, ent_coef=0.01):
    """PPO clipped-surrogate update over a batch of transitions.

    Expects each Transition.reward to already hold that step's (GAE) return,
    as the training loop writes after `compute_gae`. Advantages are derived
    as return minus the stored value estimate (baseline) — the previous
    version used the raw reward as both return AND advantage (placeholder).

    Returns a dict with the mean total loss ({} for an empty batch).
    """
    if not all_transitions:
        return {}

    # Stack rollout data into arrays.
    obs_arr = np.array([t.obs for t in all_transitions])
    actions_arr = np.array([t.action for t in all_transitions])
    old_log_probs_arr = np.array([t.log_prob for t in all_transitions])
    masks_arr = np.array([t.legal_mask for t in all_transitions])

    # t.reward holds the pre-computed return; advantage = return - baseline.
    returns_arr = np.array([t.reward for t in all_transitions], dtype=np.float32)
    values_arr = np.array([t.value for t in all_transitions], dtype=np.float32)
    advantages_arr = returns_arr - values_arr

    obs_t = torch.tensor(obs_arr, dtype=torch.float32, device=device)
    actions_t = torch.tensor(actions_arr, dtype=torch.long, device=device)
    old_log_probs_t = torch.tensor(old_log_probs_arr, dtype=torch.float32, device=device)
    masks_t = torch.tensor(masks_arr, dtype=torch.float32, device=device)
    returns_t = torch.tensor(returns_arr, dtype=torch.float32, device=device)
    advantages_t = torch.tensor(advantages_arr, dtype=torch.float32, device=device)

    # Normalize advantages for a stable gradient scale.
    if len(advantages_t) > 1:
        advantages_t = (advantages_t - advantages_t.mean()) / (advantages_t.std() + 1e-8)

    total_loss_sum = 0.0
    n_updates = 0

    for _ in range(epochs):
        indices = np.arange(len(all_transitions))
        np.random.shuffle(indices)

        for start in range(0, len(indices), batch_size):
            idx = indices[start:start + batch_size]  # slice clamps at the end

            b_obs = obs_t[idx]
            b_actions = actions_t[idx]
            b_old_lp = old_log_probs_t[idx]
            b_masks = masks_t[idx]
            b_returns = returns_t[idx]
            b_advantages = advantages_t[idx]

            new_log_probs, values, entropy = model.evaluate(b_obs, b_masks, b_actions)

            # Clipped surrogate objective (PPO).
            ratio = torch.exp(new_log_probs - b_old_lp)
            surr1 = ratio * b_advantages
            surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * b_advantages
            policy_loss = -torch.min(surr1, surr2).mean()

            value_loss = F.mse_loss(values, b_returns)

            loss = policy_loss + vf_coef * value_loss - ent_coef * entropy.mean()

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

            total_loss_sum += loss.item()
            n_updates += 1

    return {"loss": total_loss_sum / max(n_updates, 1)}
+
+
+# ---------------------------------------------------------------------------
+# Training Loop
+# ---------------------------------------------------------------------------
def train(args):
    """Self-play PPO training loop.

    Plays `args.episodes` full games with every seat controlled by the same
    (latest) policy. After each game, each player's per-step rewards are
    replaced by their GAE returns; every `args.update_every` episodes the
    accumulated batch is used for a clipped PPO update. Checkpoints are
    written every `args.save_every` episodes, plus a final model.

    Returns the trained PolicyValueNet.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    print(f"Training for {args.num_players} players, {args.episodes} episodes")

    model = PolicyValueNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Running stats
    win_counts = defaultdict(int)
    game_lengths = []
    batch_transitions = []

    for ep in range(1, args.episodes + 1):
        env = BlazingEightsEnv(num_players=args.num_players)
        trajectories = collect_game(env, model, device)

        # Record stats
        if env.done:
            win_counts[env.winner] += 1
        game_lengths.append(sum(len(v) for v in trajectories.values()))

        # Replace each transition's reward with its GAE return so the update
        # step below can recover advantage = return - value.
        for player, trans_list in trajectories.items():
            returns, advantages = compute_gae(trans_list, gamma=args.gamma, lam=args.lam)
            for i, t in enumerate(trans_list):
                t.reward = returns[i] if i < len(returns) else t.reward
            batch_transitions.extend(trans_list)

        # PPO update every `update_every` episodes.
        if ep % args.update_every == 0:
            returns_for_update = np.array([t.reward for t in batch_transitions])
            values_for_update = np.array([t.value for t in batch_transitions])
            advs = returns_for_update - values_for_update

            obs_arr = np.array([t.obs for t in batch_transitions])
            actions_arr = np.array([t.action for t in batch_transitions])
            old_lp_arr = np.array([t.log_prob for t in batch_transitions])
            masks_arr = np.array([t.legal_mask for t in batch_transitions])

            obs_t = torch.tensor(obs_arr, dtype=torch.float32, device=device)
            actions_t = torch.tensor(actions_arr, dtype=torch.long, device=device)
            old_lp_t = torch.tensor(old_lp_arr, dtype=torch.float32, device=device)
            masks_t = torch.tensor(masks_arr, dtype=torch.float32, device=device)
            returns_t = torch.tensor(returns_for_update, dtype=torch.float32, device=device)
            advs_t = torch.tensor(advs, dtype=torch.float32, device=device)

            # Normalize advantages for a stable gradient scale.
            if len(advs_t) > 1:
                advs_t = (advs_t - advs_t.mean()) / (advs_t.std() + 1e-8)

            # Clipped-surrogate PPO over shuffled minibatches.
            for _ in range(args.ppo_epochs):
                indices = np.arange(len(batch_transitions))
                np.random.shuffle(indices)
                for start in range(0, len(indices), args.batch_size):
                    idx = indices[start:start + args.batch_size]  # slice clamps

                    b_obs = obs_t[idx]
                    b_actions = actions_t[idx]
                    b_old_lp = old_lp_t[idx]
                    b_masks = masks_t[idx]
                    b_returns = returns_t[idx]
                    b_advs = advs_t[idx]

                    new_lp, values, entropy = model.evaluate(b_obs, b_masks, b_actions)
                    ratio = torch.exp(new_lp - b_old_lp)
                    surr1 = ratio * b_advs
                    surr2 = torch.clamp(ratio, 1 - args.clip_eps, 1 + args.clip_eps) * b_advs
                    policy_loss = -torch.min(surr1, surr2).mean()
                    value_loss = F.mse_loss(values, b_returns)
                    loss = policy_loss + 0.5 * value_loss - args.ent_coef * entropy.mean()

                    optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                    optimizer.step()

            batch_transitions = []

        # Logging
        if ep % args.log_every == 0:
            avg_len = np.mean(game_lengths[-args.log_every:]) if game_lengths else 0
            total_games = sum(win_counts.values())
            win_rates = {p: win_counts[p] / max(total_games, 1) for p in range(args.num_players)}
            print(f"Episode {ep:>7d} | Avg game length: {avg_len:.1f} | "
                  f"Win rates: {win_rates} | "
                  f"Total games: {total_games}")

        # Save checkpoint
        if ep % args.save_every == 0:
            path = f"{args.save_path}_ep{ep}.pt"
            torch.save({
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "episode": ep,
                "num_players": args.num_players,
            }, path)
            print(f" Saved checkpoint: {path}")

    # Final save
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "episode": args.episodes,
        "num_players": args.num_players,
    }, f"{args.save_path}_final.pt")
    print(f"Training complete. Final model saved to {args.save_path}_final.pt")

    return model
+
+
+# ---------------------------------------------------------------------------
+# Evaluation: play against random
+# ---------------------------------------------------------------------------
def evaluate_vs_random(model: PolicyValueNet, num_players=2, num_games=1000, device="cpu"):
    """Win rate of the model in seat 0 against uniformly random opponents."""
    wins = 0
    for _ in range(num_games):
        env = BlazingEightsEnv(num_players=num_players)
        obs = env.reset()
        for _ in range(500):  # hard cap per game to avoid stalemated loops
            seat = env.current_player
            legal = env.legal_actions()
            if not legal:
                break
            if seat == 0:
                # Model plays seat 0; everyone else picks uniformly at random.
                action, _, _ = model.get_action(obs, legal, device)
            else:
                action = np.random.choice(legal)
            obs, rewards, done, info = env.step(action)
            if done:
                wins += int(env.winner == 0)
                break
    return wins / num_games
+
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train Blazing Eights PPO agent")
    parser.add_argument("--num_players", type=int, default=2)
    parser.add_argument("--episodes", type=int, default=100000)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--lam", type=float, default=0.95)
    parser.add_argument("--clip_eps", type=float, default=0.2)
    parser.add_argument("--ent_coef", type=float, default=0.01)
    parser.add_argument("--ppo_epochs", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--update_every", type=int, default=64)
    parser.add_argument("--log_every", type=int, default=1000)
    parser.add_argument("--save_every", type=int, default=10000)
    parser.add_argument("--save_path", type=str, default="blazing_ppo")
    args = parser.parse_args()

    model = train(args)

    # Evaluate only at the player count the agent was trained for: the
    # previous `n <= args.num_players + 1` condition contradicted its own
    # comment and also evaluated player counts the policy never trained on.
    print("\nEvaluating vs random opponents...")
    n = args.num_players
    wr = evaluate_vs_random(model, num_players=n, num_games=1000)
    print(f" {n} players: win rate = {wr:.1%} (random baseline: {1/n:.1%})")
diff --git a/versus.py b/versus.py
new file mode 100644
index 0000000..7225e29
--- /dev/null
+++ b/versus.py
@@ -0,0 +1,341 @@
+"""
+Blazing Eights — Human vs AI interactive game.
+
+Play against the trained PPO agent in your terminal.
+
+Usage:
+ python versus.py --model blazing_ppo_final.pt
+ python versus.py --model blazing_ppo_final.pt --num_players 3 # you + 2 AI
+"""
+
+import argparse
+import torch
+import torch.nn.functional as F
+import numpy as np
+from blazing_env import (
+ BlazingEightsEnv, card_name, card_suit, card_rank,
+ is_swap, RANK_8, RANK_J, RANK_Q, RANK_K,
+ NUM_CARDS, TOTAL_ACTIONS, DRAW_ACTION, PASS_ACTION,
+)
+from train import PolicyValueNet
+
# Display glyphs indexed by suit id (0=♠, 1=♥, 2=♦, 3=♣).
SUIT_SYMBOLS = ["♠", "♥", "♦", "♣"]
# Input parsing: accepts both ASCII suit letters and the glyphs themselves.
SUIT_LETTERS = {"s": 0, "h": 1, "d": 2, "c": 3,
                "♠": 0, "♥": 1, "♦": 2, "♣": 3}
# Rank display names indexed by rank id 0-12 (Ace through King).
RANK_NAMES = ["A", "2", "3", "4", "5", "6", "7", "8", "9", "10", "J", "Q", "K"]
+
+
def card_effect(c: int, num_players: int = 2) -> str:
    """Return a short yellow-highlighted effect tag for special cards ("" for plain ones)."""
    yellow = "\033[93m{}\033[0m".format
    if is_swap(c):
        return yellow("换牌")
    rank = card_rank(c)
    if rank == RANK_8:
        return yellow("万能")
    if rank == RANK_K:
        return yellow("全摸")
    if rank == RANK_Q:
        # Reversing direction only means anything with 3+ players.
        return yellow("反转") if num_players > 2 else ""
    if rank == RANK_J:
        return yellow("跳过")
    return ""
+
+
def pretty_card(c: int) -> str:
    """Render one card as an ANSI-colored string (magenta "SWAP" for the swap card)."""
    if is_swap(c):
        return "\033[95mSWAP\033[0m"
    suit = card_suit(c)
    # ♠ and ♣ render white; ♥ and ♦ render red.
    color = ("\033[37m", "\033[91m", "\033[91m", "\033[37m")[suit]
    return f"{color}{RANK_NAMES[card_rank(c)]}{SUIT_SYMBOLS[suit]}\033[0m"
+
+
def pretty_hand(hand: list[int], num_players: int = 2) -> str:
    """Render a hand (sorted by suit, swap cards last) as "[idx] CARD(effect)" tokens."""
    ordered = sorted(hand, key=lambda c: (99 if is_swap(c) else card_suit(c), c))
    tokens = []
    for i, c in enumerate(ordered):
        tag = f"[{i}] {pretty_card(c)}"
        effect = card_effect(c, num_players)
        if effect:
            tag += f"({effect})"
        tokens.append(tag)
    return " ".join(tokens)
+
+
def print_game_state(env: BlazingEightsEnv, human_player: int, show_ai_hand: bool = False):
    """Print the table: top discard (with declared suit), direction, deck size,
    and each seat's hand. The human's own hand is shown only as a count here
    (it is printed in full by human_choose_action); AI hands are revealed only
    when show_ai_hand is set (debug mode / game over)."""
    print()
    print("=" * 55)
    top = env.discard[-1]
    top_str = pretty_card(top)
    # After an 8 was played, the declared suit overrides the printed card's suit.
    if env.active_suit is not None:
        top_str += f" (指定花色: {SUIT_SYMBOLS[env.active_suit]})"
    dir_str = "顺时针 →" if env.direction == 1 else "逆时针 ←"
    print(f" 弃牌堆顶: {top_str} 方向: {dir_str} 牌堆剩余: {len(env.deck)}")
    print("-" * 55)
    for i in range(env.num_players):
        if i == human_player:
            tag = "你"
            hand_str = f"{len(env.hands[i])} 张牌"
        else:
            tag = f"AI-{i}"
            if show_ai_hand:
                hand_str = ", ".join(pretty_card(c) for c in sorted(env.hands[i]))
            else:
                hand_str = f"{len(env.hands[i])} 张牌"
        # Arrow marks whose turn it currently is.
        arrow = " ◀" if i == env.current_player else ""
        print(f" {tag}: {hand_str}{arrow}")
    print("=" * 55)
+
+
def parse_card_input(s: str) -> int:
    """Parse a card token like "8h", "Ks", "10d", "Ac", or "SWAP" into a card index.

    Card index = suit * 13 + rank; the swap card is index 52.
    Raises ValueError for anything unrecognized — including empty or
    one-character input, which previously escaped as a bare IndexError.
    """
    s = s.strip().upper()
    if s.startswith("SWAP") or s.startswith("SW"):
        return 52
    # Guard short input before slicing so bad input raises ValueError, not IndexError.
    if len(s) < 2:
        raise ValueError(f"无法识别: {s} (格式例: 8h, Ks, 10d, Ac, SWAP)")
    if s.startswith("10"):
        rank_str, suit_str = "10", s[2:].lower()
    else:
        rank_str, suit_str = s[0], s[1:].lower()
    rank_map = {r: i for i, r in enumerate(RANK_NAMES)}
    if rank_str not in rank_map or suit_str not in SUIT_LETTERS:
        raise ValueError(f"无法识别: {s} (格式例: 8h, Ks, 10d, Ac, SWAP)")
    return SUIT_LETTERS[suit_str] * 13 + rank_map[rank_str]
+
+
def human_choose_action(env: BlazingEightsEnv, player: int) -> int:
    """Prompt the human for their next action and return a legal action id.

    Handles two phases: suit declaration after playing an 8 (returns a
    56+suit action), and the normal turn (returns a card index, DRAW_ACTION,
    or PASS_ACTION). Loops until a legal choice is entered; typing "q"
    raises KeyboardInterrupt to quit the game.

    Fix: removed an unreachable `return action` that followed the infinite
    suit-selection loop — it could never execute and referenced a possibly
    unbound name.
    """
    # Display order: by suit, swap cards last; input indices refer to this order.
    hand = sorted(env.hands[player], key=lambda c: (card_suit(c) if not is_swap(c) else 99, c))
    legal = env.legal_actions(player)

    if env.phase == "choose_suit":
        print("\n 你打出了 8!选择指定花色:")
        for i, s in enumerate(SUIT_SYMBOLS):
            print(f" [{i}] {s}")
        while True:
            try:
                choice = input(" 选择 (0-3): ").strip()
                idx = int(choice)
                action = 56 + idx  # suit-declaration actions live at 56..59
                if action in legal:
                    return action
                print(" 无效选择,请重试")
            except (ValueError, IndexError):
                print(" 请输入 0-3")

    print(f"\n 你的手牌: {pretty_hand(hand, env.num_players)}")

    # Build playable cards display
    playable = [a for a in legal if a < NUM_CARDS]
    can_draw = DRAW_ACTION in legal
    can_pass = PASS_ACTION in legal

    print(" 可出的牌:", end="")
    if playable:
        playable_names = []
        for a in playable:
            idx_in_hand = hand.index(a)
            effect = card_effect(a, env.num_players)
            tag = f"[{idx_in_hand}]{pretty_card(a)}"
            if effect:
                tag += f"({effect})"
            playable_names.append(tag)
        print(" " + " ".join(playable_names))
    else:
        print(" 无")

    if can_draw:
        print(" [d] 摸牌")
    if can_pass:
        print(" [p] 跳过 (牌堆与弃牌堆均已空)")

    while True:
        choice = input(" 你的选择: ").strip().lower()
        if choice == "d" and can_draw:
            return DRAW_ACTION
        if choice == "p" and can_pass:
            return PASS_ACTION
        if choice == "d" and not can_draw:
            print(" 牌堆已空,无法摸牌")
            continue
        if choice == "p" and not can_pass:
            print(" 还没摸牌,不能直接跳过")
            continue
        if choice == "q":
            raise KeyboardInterrupt
        try:
            idx = int(choice)
            if 0 <= idx < len(hand):
                card = hand[idx]
                if card in playable:
                    return card
                # Handle swap cards (might have multiple copies of the same index)
                if is_swap(card):
                    for a in playable:
                        if is_swap(a):
                            return a
                print(f" {pretty_card(card)} 不能出,请选其他牌")
            else:
                print(f" 序号超出范围 (0-{len(hand)-1})")
        except ValueError:
            print(" 输入序号、d(摸牌) 或 q(退出)")
+
+
def ai_choose_action(env: BlazingEightsEnv, model: PolicyValueNet, player: int, device="cpu") -> int:
    """Let the policy network pick an action for `player` from its own observation.

    Samples from the masked policy (not argmax), so AI play varies between
    games. Fix: dropped the unused `value` binding from the unpacking.
    """
    obs = env._get_obs(player)
    legal = env.legal_actions(player)
    action, _, _ = model.get_action(obs, legal, device)
    return action
+
+
def describe_action(player_name: str, action: int, env: BlazingEightsEnv, drawn_card: int = None):
    """Return a one-line Chinese description of `action` for the console log.

    Covers draw/pass, suit declarations (action >= 56), and card plays with
    their special-effect annotations. `drawn_card` is accepted for API
    compatibility but not used here.
    """
    if action == DRAW_ACTION:
        return f" {player_name} 摸了一张牌"
    if action == PASS_ACTION:
        return f" {player_name} 跳过"
    if action >= 56:
        # Suit-declaration actions map 56..59 onto the four suits.
        return f" {player_name} 指定花色: {SUIT_SYMBOLS[action - 56]}"

    suffix = ""
    if is_swap(action):
        suffix = " → 与下家交换手牌!"
    else:
        rank = card_rank(action)
        if rank == RANK_8:
            suffix = " → 万能牌!选择花色..."
        elif rank == RANK_K:
            suffix = " → 其他所有人各摸 1 张!"
        elif rank == RANK_Q and env.num_players > 2:
            suffix = " → 反转方向!"
        elif rank == RANK_J:
            suffix = " → 跳过下一位!"
    return f" {player_name} 打出 {pretty_card(action)}" + suffix
+
+
def play_game(model_path: str, num_players: int, human_player: int = 0, show_ai: bool = False):
    """Run one interactive game: human at seat `human_player`, AI in all others.

    Loads the trained checkpoint from `model_path` and alternates turns until
    env.done. The draw-then-decide rule is honored on both sides: after a
    draw, the same player keeps the turn and chooses to play or pass. Returns
    None; the outcome is printed to the console.
    """
    device = "cpu"
    model = PolicyValueNet()
    # weights_only=True restricts deserialization to plain tensors/containers.
    checkpoint = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint["model"])
    model.eval()

    print()
    print("╔══════════════════════════════════════╗")
    print("║ Blazing Eights - 人机对战 ║")
    print("╠══════════════════════════════════════╣")
    print(f"║ 玩家数: {num_players} 你是: Player {human_player} ║")
    print("║ 输入序号出牌, d摸牌, p跳过, q退出 ║")
    print("╚══════════════════════════════════════╝")

    env = BlazingEightsEnv(num_players=num_players)
    turn = 0

    while not env.done:
        player = env.current_player
        turn += 1

        if player == human_player:
            print_game_state(env, human_player, show_ai_hand=show_ai)
            try:
                action = human_choose_action(env, player)
            except KeyboardInterrupt:
                # "q" at any prompt aborts the whole game.
                print("\n\n 你退出了游戏。再见!")
                return

            # Describe human action
            name = "你"
            if action == DRAW_ACTION:
                # Remember hand before draw to find the new card
                hand_before = set(env.hands[player])
                obs, rewards, done, info = env.step(action)
                hand_after = set(env.hands[player])
                new_cards = hand_after - hand_before
                if new_cards:
                    drawn = next(iter(new_cards))
                    print(f" 你摸到了 {pretty_card(drawn)}")
                else:
                    print(f" 牌堆已空,没摸到牌")
                # Turn stays with human — loop back to let them decide
                continue
            elif action == PASS_ACTION:
                print(f" 你选择不出牌,结束回合")
                obs, rewards, done, info = env.step(action)
                continue
            else:
                print(describe_action(name, action, env))
                obs, rewards, done, info = env.step(action)
                # If played an 8, need to choose suit
                if env.phase == "choose_suit" and env._pending_8_player == human_player:
                    suit_action = human_choose_action(env, human_player)
                    print(f" 你指定花色: {SUIT_SYMBOLS[suit_action - 56]}")
                    obs, rewards, done, info = env.step(suit_action)
                continue
        else:
            # AI turn
            ai_name = f"AI-{player}"

            # Pending suit declaration from a previous 8 (e.g. after a swap).
            if env.phase == "choose_suit":
                action = ai_choose_action(env, model, player, device)
                print(f" {ai_name} 指定花色: {SUIT_SYMBOLS[action - 56]}")
                obs, rewards, done, info = env.step(action)
                continue

            action = ai_choose_action(env, model, player, device)

            if action == DRAW_ACTION:
                print(f" {ai_name} 摸了一张牌")
                obs, rewards, done, info = env.step(action)
                # AI still has their turn — now decide to play or pass
                action2 = ai_choose_action(env, model, player, device)
                if action2 == PASS_ACTION:
                    print(f" {ai_name} 选择不出牌")
                    obs, rewards, done, info = env.step(action2)
                else:
                    print(describe_action(ai_name, action2, env))
                    obs, rewards, done, info = env.step(action2)
                    if env.phase == "choose_suit" and env._pending_8_player == player:
                        suit_action = ai_choose_action(env, model, player, device)
                        print(f" {ai_name} 指定花色: {SUIT_SYMBOLS[suit_action - 56]}")
                        obs, rewards, done, info = env.step(suit_action)
            elif action == PASS_ACTION:
                print(f" {ai_name} 跳过")
                obs, rewards, done, info = env.step(action)
            else:
                print(describe_action(ai_name, action, env))
                obs, rewards, done, info = env.step(action)
                if env.phase == "choose_suit" and env._pending_8_player == player:
                    suit_action = ai_choose_action(env, model, player, device)
                    print(f" {ai_name} 指定花色: {SUIT_SYMBOLS[suit_action - 56]}")
                    obs, rewards, done, info = env.step(suit_action)

    # Game over
    print_game_state(env, human_player, show_ai_hand=True)
    print()
    if env.winner == human_player:
        print(" 🎉 你赢了!!!")
    elif env.winner >= 0:
        print(f" 💀 AI-{env.winner} 赢了...")
    else:
        # winner < 0 signals a stalemate.
        print(" 平局(僵局)")

    # Show hand sizes
    for i in range(env.num_players):
        name = "你" if i == human_player else f"AI-{i}"
        print(f" {name}: {len(env.hands[i])} 张剩余")
    print()
+
+
def main():
    """CLI entry point: parse options, then loop games until the human declines a rematch."""
    parser = argparse.ArgumentParser(description="Blazing Eights 人机对战")
    parser.add_argument("--model", type=str, default="blazing_ppo_final.pt", help="模型路径")
    parser.add_argument("--num_players", type=int, default=2, help="玩家总数 (2-5)")
    parser.add_argument("--show_ai", action="store_true", help="显示 AI 手牌 (调试用)")
    args = parser.parse_args()

    while True:
        play_game(args.model, args.num_players, human_player=0, show_ai=args.show_ai)
        if input(" 再来一局? (y/n): ").strip().lower() != "y":
            print(" 下次再见!")
            break
+
+
if __name__ == "__main__":  # script entry point
    main()