diff options
| -rw-r--r-- | .gitignore | 4 | ||||
| -rw-r--r-- | README.md | 67 | ||||
| -rw-r--r-- | blazing_env.py | 506 | ||||
| -rw-r--r-- | play.py | 270 | ||||
| -rw-r--r-- | train.py | 425 | ||||
| -rw-r--r-- | versus.py | 341 |
6 files changed, 1613 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8947638 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +*.pt +__pycache__/ +*.pyc +.DS_Store diff --git a/README.md b/README.md new file mode 100644 index 0000000..711f793 --- /dev/null +++ b/README.md @@ -0,0 +1,67 @@ +# Blazing Eights — RL Agent + +Self-play PPO agent for the Blazing Eights card game. + +## Setup + +```bash +pip install torch numpy +``` + +## Files + +- `blazing_env.py` — Game environment (2-5 players) +- `train.py` — PPO self-play training +- `play.py` — Real-time play assistant (input game state, get best move) + +## Training + +```bash +# Train a 2-player agent (~10min on CPU for 100k episodes) +python train.py --num_players 2 --episodes 100000 + +# Train for 3 players (may need more episodes) +python train.py --num_players 3 --episodes 200000 + +# Custom hyperparams +python train.py --num_players 4 --episodes 300000 --lr 1e-4 --ent_coef 0.02 +``` + +Training saves checkpoints every 10k episodes and a final model. + +## Real-time Play Assistant + +After training, use the assistant during a real game: + +```bash +python play.py --model blazing_ppo_final.pt --num_players 3 +``` + +It will prompt you for: +1. Your hand (e.g., `8h,Ks,3d,SWAP`) +2. Top discard card (e.g., `6d`) +3. Active suit if an 8 was played +4. Direction (cw/ccw) +5. Other players' hand sizes +6. Approximate deck size + +Then shows ranked action recommendations with probabilities. + +## Game Rules + +- **56 cards**: standard 52 + 4 Swap cards +- **Match**: suit or rank of top card +- **8**: Wild — choose a suit for next player +- **K**: All other players draw 1 +- **Q**: Reverse direction (no effect in 2-player) +- **J**: Skip next player +- **Swap**: Swap your entire hand with next player (playable anytime, no match needed) +- **Can't play**: Draw 1, play it if legal +- **Win**: First to empty hand + +## Tips for Better Training + +1. 
**Train per player count** — the optimal policy differs significantly for 2 vs 5 players. +2. **Increase episodes for more players** — larger games have more variance, need more samples. +3. **Opponent modeling** — after self-play, you can fine-tune against specific opponent behaviors by replacing some players with heuristic bots that mimic your friends' tendencies. +4. **Curriculum** — start training with 2 players, then use the trained model to initialize training for 3+ players. diff --git a/blazing_env.py b/blazing_env.py new file mode 100644 index 0000000..c3d97ae --- /dev/null +++ b/blazing_env.py @@ -0,0 +1,506 @@ +""" +Blazing Eights — multi-agent card game environment. + +Cards: + 52 standard cards (4 suits × 13 ranks: A,2..10,J,Q,K) + + 4 Swap cards (no suit, index 52-55) + Total: 56 cards + +Special cards: + 8 → Wild: player chooses a suit, next player must match that suit + K → All OTHER players draw 1 card from the deck + Q → Reverse direction (no effect in 2-player games) + J → Skip next player's turn + Swap → Swap entire hand with next player (playable anytime on your turn, no match needed) + +Rules: + - Match top card by suit OR rank (unless playing 8 or Swap) + - Can't play → draw 1; if drawn card is playable, play it immediately + - First player to empty hand wins + - Initial hand: 5 cards each +""" + +import numpy as np +from typing import Optional + +# --------------------------------------------------------------------------- +# Card encoding +# --------------------------------------------------------------------------- +# Standard cards: index 0-51 +# suit = index // 13 (0=♠, 1=♥, 2=♦, 3=♣) +# rank = index % 13 (0=A, 1=2, 2=3, ..., 9=10, 10=J, 11=Q, 12=K) +# Swap cards: index 52, 53, 54, 55 + +NUM_STANDARD = 52 +NUM_SWAP = 4 +NUM_CARDS = NUM_STANDARD + NUM_SWAP + +RANK_A, RANK_J, RANK_Q, RANK_K = 0, 10, 11, 12 +RANK_8 = 7 # rank index for 8 (0=A,1=2,...,7=8) + +def card_suit(c: int) -> int: + """Return suit of a standard card (0-3). 
Swap cards return -1.""" + return c // 13 if c < NUM_STANDARD else -1 + +def card_rank(c: int) -> int: + """Return rank of a standard card (0-12). Swap cards return -1.""" + return c % 13 if c < NUM_STANDARD else -1 + +def is_swap(c: int) -> bool: + return c >= NUM_STANDARD + +def card_name(c: int) -> str: + if is_swap(c): + return f"SWAP-{c - NUM_STANDARD}" + suits = "♠♥♦♣" + ranks = ["A"] + [str(i) for i in range(2, 11)] + ["J", "Q", "K"] + return f"{ranks[card_rank(c)]}{suits[card_suit(c)]}" + + +# --------------------------------------------------------------------------- +# Action space +# --------------------------------------------------------------------------- +# Actions 0-55: play card with that index +# Actions 56-59: choose suit after playing an 8 (♠,♥,♦,♣) +# Action 60: draw a card +NUM_PLAY_ACTIONS = NUM_CARDS # 0..55 +NUM_SUIT_ACTIONS = 4 # 56..59 +DRAW_ACTION = 60 +PASS_ACTION = 61 # skip turn (when deck empty & no playable card) +TOTAL_ACTIONS = 62 + + +class BlazingEightsEnv: + """ + Multi-agent environment for Blazing Eights. + + Designed for self-play RL. Call step() with the current player's action. + The env tracks whose turn it is. 
+ """ + + def __init__(self, num_players: int = 2, seed: Optional[int] = None): + assert 2 <= num_players <= 5 + self.num_players = num_players + self.rng = np.random.default_rng(seed) + self.reset() + + # ------------------------------------------------------------------ + # Reset + # ------------------------------------------------------------------ + def reset(self, seed: Optional[int] = None): + if seed is not None: + self.rng = np.random.default_rng(seed) + + # Build & shuffle deck + deck = list(range(NUM_CARDS)) + self.rng.shuffle(deck) + + # Deal 5 cards each + self.hands: list[list[int]] = [] + idx = 0 + for _ in range(self.num_players): + self.hands.append(sorted(deck[idx:idx + 5])) + idx += 5 + + # Find a non-special starting card for the discard pile + # (avoid starting with 8, J, Q, K, or Swap) + self.discard: list[int] = [] + start_card = None + remaining = deck[idx:] + for i, c in enumerate(remaining): + if not is_swap(c) and card_rank(c) not in (RANK_8, RANK_J, RANK_Q, RANK_K): + start_card = c + remaining.pop(i) + break + if start_card is None: + # Extremely unlikely; just use first card + start_card = remaining.pop(0) + self.discard.append(start_card) + + self.deck: list[int] = remaining + self.current_player = int(self.rng.integers(0, self.num_players)) + self.direction = 1 # 1=clockwise, -1=counter-clockwise + self.done = False + self.winner = -1 + + # State for wild-8: the chosen suit (None if top card is not a wild) + self.active_suit: Optional[int] = None + + # Phase: "play" or "choose_suit" + self.phase = "play" + # Temp storage for the card that triggered choose_suit + self._pending_8_player: Optional[int] = None + + # For K resolution + self._pending_k = False + + # Track consecutive passes for stalemate detection + self.consecutive_passes = 0 + + # Track whether current player has already drawn this turn + self.has_drawn_this_turn = False + + # Action history: records recent events visible to all players + # Each entry: (player, 
def _get_obs(self, player: int) -> np.ndarray:
    """
    Build the fixed-size observation vector (180 float32) seen by `player`.

    Layout, with N = self.num_players:
        [0:56]          multi-hot of the cards in the player's hand
        [56:60]         suit to match, one-hot: the active suit after a wild 8,
                        otherwise the top card's suit (slot 0 if the top card
                        is a Swap)
        [60:73]         top card rank one-hot; left empty when the top card is
                        a Swap or a wild suit is active
        [73]            direction (0 = clockwise, 1 = counter-clockwise)
        [74 : 73+N]     other players' hand sizes, in seat-offset order (/20)
        [73+N]          deck size (/56)
        [74+N]          phase (0 = play, 1 = choose_suit)
        [75+N : 131+N]  cards stored for this player in swap_known_cards
        [131+N : 131+N + 5*(N-1)]  per other player, 5 floats:
            4 floats  one-hot outcome of that player's most recent finished
                      turn: 0 = played from hand, 1 = drew then played,
                      2 = drew then passed, 3 = passed without drawing
            1 float   number of consecutive most recent turns that ended in
                      "drew then passed" (/10)
        All remaining entries stay zero.

    Args:
        player: seat index whose point of view is encoded.

    Returns:
        A new float32 array of length 180.
    """
    obs = np.zeros(180, dtype=np.float32)
    n_players = self.num_players

    # Hand
    for c in self.hands[player]:
        obs[c] = 1.0

    # Suit / rank the next card has to match
    top = self.discard[-1]
    if self.active_suit is not None:
        suit = self.active_suit
    elif not is_swap(top):
        suit = card_suit(top)
    else:
        suit = 0
    obs[56 + suit] = 1.0

    if not is_swap(top) and self.active_suit is None:
        obs[60 + card_rank(top)] = 1.0

    # Direction
    obs[73] = 0.0 if self.direction == 1 else 1.0

    # Other players' hand sizes
    for i in range(1, n_players):
        other = (player + i) % n_players
        obs[74 + i - 1] = len(self.hands[other]) / 20.0

    # Deck size
    obs[74 + n_players - 1] = len(self.deck) / 56.0

    # Phase
    obs[75 + n_players - 1] = 1.0 if self.phase == "choose_suit" else 0.0

    # Cards remembered from the last swap
    offset = 76 + n_players - 1
    for c in self.swap_known_cards.get(player, ()):
        obs[offset + c] = 1.0

    # Per other player: outcome of the last turn + draw-and-pass streak.
    # The history stores raw events (0 = played, 2 = drew, 3 = passed) and a
    # draw is always followed by a play or a pass of the same player, so the
    # raw "last event" is never a draw once that player's turn is over.
    # Fold the raw events into per-turn outcomes first.
    draw_info_offset = 132 + n_players - 1
    for i in range(1, n_players):
        other = (player + i) % n_players
        base = draw_info_offset + (i - 1) * 5

        outcomes = []  # one code per finished turn, oldest first
        drew = False
        for p, evt in self.action_history:
            if p != other:
                continue
            if evt == 2:
                drew = True
            elif evt == 0:
                outcomes.append(1 if drew else 0)
                drew = False
            elif evt == 3:
                outcomes.append(2 if drew else 3)
                drew = False
        # A trailing draw with no follow-up is an unfinished turn: ignored.
        # NOTE: the history is capped, so the oldest turn in the window may
        # have lost its leading draw event.

        if outcomes:
            obs[base + outcomes[-1]] = 1.0

        streak = 0
        for code in reversed(outcomes):
            if code != 2:
                break
            streak += 1
        obs[base + 4] = streak / 10.0

    return obs
if player is None: + player = self.current_player + + if self.phase == "choose_suit": + if player == self._pending_8_player: + return [56, 57, 58, 59] + else: + return [] + + actions = [] + hand = self.hands[player] + top = self.discard[-1] + + for c in hand: + if self._can_play(c, top): + actions.append(c) + + if self.has_drawn_this_turn: + # Already drew this turn: can play a card or pass (end turn) + actions.append(PASS_ACTION) + else: + # Can always choose to draw instead of playing + if self.deck or len(self.discard) > 1: + actions.append(DRAW_ACTION) + if not actions: + # No playable cards and no deck: must pass + actions.append(PASS_ACTION) + return actions + + def _can_play(self, card: int, top: int) -> bool: + # Swap cards: always playable + if is_swap(card): + return True + # 8s: always playable (wild) + if card_rank(card) == RANK_8: + return True + # If active_suit is set (after a wild 8), must match that suit + if self.active_suit is not None: + return card_suit(card) == self.active_suit + # Normal: match suit or rank + if is_swap(top): + # Top is swap — shouldn't happen in normal flow, but match anything + return True + return card_suit(card) == card_suit(top) or card_rank(card) == card_rank(top) + + # ------------------------------------------------------------------ + # Step + # ------------------------------------------------------------------ + def step(self, action: int): + """ + Returns (obs_next_player, reward_dict, done, info) + reward_dict: {player_id: reward} + """ + assert not self.done, "Game is over" + + player = self.current_player + legal = self.legal_actions(player) + assert action in legal, f"Illegal action {action}. 
Legal: {legal}" + + info = {} + rewards = {i: 0.0 for i in range(self.num_players)} + + # --- Choose suit phase --- + if self.phase == "choose_suit": + self.active_suit = action - 56 + self.phase = "play" + + # Now resolve K if pending + if self._pending_k: + self._resolve_k(player) + self._pending_k = False + + # Advance to next player + self._advance_turn() + obs = self._get_obs(self.current_player) + return obs, rewards, False, info + + # --- Play phase --- + if action == DRAW_ACTION: + self.consecutive_passes = 0 + drawn = self._draw_card(player) + self._record_event(player, 2) # drew a card + # Card is added to hand; player keeps their turn to decide + # whether to play it (or any other card) or pass + self.has_drawn_this_turn = True + obs = self._get_obs(player) + return obs, rewards, False, info + elif action == PASS_ACTION: + # Skip turn (no cards anywhere) + self.consecutive_passes += 1 + self._record_event(player, 3) # passed + if self.consecutive_passes >= self.num_players: + # Stalemate: all players passed in a row + self.done = True + self.winner = -1 # no winner + # Player with fewest cards gets partial reward + min_cards = min(len(h) for h in self.hands) + for i in range(self.num_players): + if len(self.hands[i]) == min_cards: + rewards[i] = 0.5 + else: + rewards[i] = -1.0 + obs = self._get_obs(player) + return obs, rewards, True, {"stalemate": True} + self._advance_turn() + obs = self._get_obs(self.current_player) + return obs, rewards, False, info + else: + return self._play_card(player, action, rewards, info) + + def _play_card(self, player: int, card: int, rewards: dict, info: dict): + self.consecutive_passes = 0 + self._record_event(player, 0) # played_card + hand = self.hands[player] + assert card in hand, f"Card {card} not in hand of player {player}" + hand.remove(card) + self.discard.append(card) + + # Clear active suit (unless new card is 8) + self.active_suit = None + + # Clear swap knowledge for this player (cards change over time) + # We 
keep it until they play; after playing, knowledge decays + # Actually let's just keep swap_known_cards until overwritten + + # Check win + if len(hand) == 0: + self.done = True + self.winner = player + rewards[player] = 1.0 + for i in range(self.num_players): + if i != player: + rewards[i] = -1.0 + obs = self._get_obs(player) + return obs, rewards, True, {"winner": player} + + # Handle special cards + if is_swap(card): + self._resolve_swap(player) + self._advance_turn() + elif card_rank(card) == RANK_8: + # Need to choose suit + self.phase = "choose_suit" + self._pending_8_player = player + # Check if K also (8 is rank 7, K is rank 12 — not the same, so no overlap) + # 8 is not K, so no K effect here + obs = self._get_obs(player) + return obs, rewards, False, info + elif card_rank(card) == RANK_K: + # All other players draw 1 + self._resolve_k(player) + self._advance_turn() + elif card_rank(card) == RANK_Q: + # Reverse direction (no effect in 2-player) + if self.num_players > 2: + self.direction *= -1 + self._advance_turn() + elif card_rank(card) == RANK_J: + # Skip next player + self._advance_turn() # skip + self._advance_turn() # to the one after + else: + self._advance_turn() + + obs = self._get_obs(self.current_player) + return obs, rewards, False, info + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + def _advance_turn(self): + self.has_drawn_this_turn = False + self.current_player = (self.current_player + self.direction) % self.num_players + + def _draw_card(self, player: int) -> Optional[int]: + if not self.deck: + self._reshuffle_discard() + if not self.deck: + return None # No cards left anywhere + card = self.deck.pop() + self.hands[player].append(card) + return card + + def _reshuffle_discard(self): + """Reshuffle all but the top card of the discard pile into the deck.""" + if len(self.discard) <= 1: + return + top = self.discard[-1] + self.deck = 
def _resolve_swap(self, player: int):
    """
    Exchange hands between `player` and the next player in turn order.

    Also refreshes `swap_known_cards` for both seats. Each of them has just
    handed over a hand it could see, so what it now knows is the hand held
    by the *other* seat. Its own new hand is already part of its
    observation and must not be stored here.

    Args:
        player: seat index of the player who played the Swap card.
    """
    next_player = (player + self.direction) % self.num_players
    hands = self.hands
    hands[player], hands[next_player] = hands[next_player], hands[player]

    # Copies, so later draws/plays do not silently edit the snapshot.
    self.swap_known_cards[player] = list(hands[next_player])
    self.swap_known_cards[next_player] = list(hands[player])
def parse_card(s: str) -> int:
    """
    Convert a card token such as '8h', 'Ks', '10d' or 'SWAP' to a card index.

    Standard cards map to ``suit * 13 + rank`` with suits s/h/d/c (or
    ♠/♥/♦/♣) numbered 0-3 and ranks A, 2..10, J, Q, K numbered 0-12.
    Any token starting with 'SW' maps to 52; the individual Swap cards are
    not distinguished. Matching ignores case and surrounding whitespace.

    Args:
        s: the token typed by the user.

    Returns:
        The card index.

    Raises:
        ValueError: if the token is empty or does not name a card.
    """
    s = s.strip().upper()
    if not s:
        # Previously this fell through to s[0] and surfaced as IndexError.
        raise ValueError("Cannot parse card: empty input")

    if s.startswith("SW"):
        return 52

    # "10" is the only two-character rank.
    if s.startswith("10"):
        rank_str, suit_str = "10", s[2:]
    else:
        rank_str, suit_str = s[0], s[1:]
    suit_str = suit_str.lower()

    ranks = ("A", "2", "3", "4", "5", "6", "7", "8", "9", "10", "J", "Q", "K")
    rank_map = {r: i for i, r in enumerate(ranks)}
    suit_map = {"s": 0, "h": 1, "d": 2, "c": 3,
                "♠": 0, "♥": 1, "♦": 2, "♣": 3}

    if rank_str not in rank_map or suit_str not in suit_map:
        raise ValueError(f"Cannot parse card: {s}")

    return suit_map[suit_str] * 13 + rank_map[rank_str]
def interactive_loop(model_path: str, num_players: int):
    """
    Run the console assistant: read a game state, print ranked moves, repeat.

    Each round prompts for the hand, the top discard, the active suit (only
    when the top card is an 8), direction, the other players' hand sizes,
    a deck-size estimate and optional per-opponent draw information, then
    prints the model's value estimate and action probabilities. The loop
    ends on 'quit' at the hand prompt, Ctrl-C, or end of input.

    Input that cannot be parsed or has the wrong number of per-opponent
    values is reported and the round restarts.

    Args:
        model_path: checkpoint file whose "model" entry holds the state dict.
        num_players: total number of players at the table, including you.
    """
    device = "cpu"

    # Load model
    model = PolicyValueNet()
    checkpoint = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint["model"])
    model.eval()
    print(f"Loaded model from {model_path}")
    print(f"Trained for {checkpoint.get('episode', '?')} episodes, "
          f"{checkpoint.get('num_players', '?')} players")
    print()

    print("=" * 60)
    print(" Blazing Eights — Play Assistant")
    print(" Card format: rank+suit (e.g., 8h, Ks, 10d, Ac, SWAP)")
    print(" Type 'quit' to exit")
    print("=" * 60)

    num_others = num_players - 1
    suit_map = {"s": 0, "h": 1, "d": 2, "c": 3}
    # p=played from hand, d=drew and played, s=drew and skipped, ?=unknown
    event_map = {"p": 0, "d": 1, "s": 2, "?": -1, "": -1}

    def per_opponent(values: list, label: str) -> list:
        """Reject lists that do not hold exactly one entry per other player."""
        # build_obs_from_input writes entry i at a fixed offset + i, so extra
        # entries would spill into neighbouring observation slots.
        if len(values) != num_others:
            raise ValueError(
                f"expected {num_others} {label}, got {len(values)}")
        return values

    while True:
        print("\n--- New Turn ---")
        try:
            # Hand
            hand_str = input("Your hand (comma-separated, e.g., 8h,Ks,3d,SWAP): ").strip()
            if hand_str.lower() == "quit":
                break
            hand = [parse_card(c) for c in hand_str.split(",")]

            # Top card
            top_str = input("Top card on discard pile: ").strip()
            top_card = parse_card(top_str)

            # Active suit (if top is 8)
            active_suit = None
            if card_rank(top_card) == RANK_8:
                suit_str = input("Active suit (s/h/d/c): ").strip().lower()
                if suit_str not in suit_map:
                    # A silent None here would make the 8 look like a plain card.
                    raise ValueError(f"unknown suit {suit_str!r}")
                active_suit = suit_map[suit_str]

            # Direction
            dir_str = input("Direction (cw/ccw) [cw]: ").strip().lower()
            direction = -1 if dir_str == "ccw" else 1

            # Other players' hand sizes
            sizes_str = input(f"Other players' hand sizes (comma-sep, {num_players-1} values): ").strip()
            other_sizes = per_opponent(
                [int(x) for x in sizes_str.split(",")], "hand sizes")

            # Deck size estimate
            deck_str = input("Approximate deck size [20]: ").strip()
            deck_size = int(deck_str) if deck_str else 20

            # Draw info for other players
            event_str = input(f"Other players' last action ({num_players-1} values, p/d/s/?): ").strip().lower()
            other_events = None
            if event_str:
                other_events = per_opponent(
                    [event_map.get(x.strip(), -1) for x in event_str.split(",")],
                    "last actions")

            streak_str = input(f"Other players' consecutive draw-skip count ({num_players-1} values) [0s]: ").strip()
            other_streaks = None
            if streak_str:
                other_streaks = per_opponent(
                    [int(x) for x in streak_str.split(",")], "streak counts")

            # Build obs and get recommendation
            obs = build_obs_from_input(
                hand, top_card, active_suit, direction,
                other_sizes, deck_size, num_players,
                other_last_events=other_events,
                other_draw_streaks=other_streaks,
            )
            ranked, value = get_recommendations(model, obs, hand, top_card, active_suit, device)

            print(f"\n Win probability estimate: {(value + 1) / 2:.1%}")
            print(" Recommended actions:")
            for i, (action, name, prob) in enumerate(ranked):
                bar = "█" * int(prob * 30)
                print(f" {'→' if i == 0 else ' '} {name:<12s} {prob:6.1%} {bar}")

            # If best action is an 8, also show suit recommendation
            if ranked and ranked[0][0] < NUM_CARDS and card_rank(ranked[0][0]) == RANK_8:
                print("\n If you play 8, recommended suit:")
                # Score each suit by the value head on an observation with
                # that suit set as the one to match and the rank cleared.
                for suit_idx in range(4):
                    temp_obs = obs.copy()
                    temp_obs[56:60] = 0
                    temp_obs[56 + suit_idx] = 1.0
                    temp_obs[60:73] = 0  # clear rank (wild)
                    obs_t = torch.tensor(temp_obs, dtype=torch.float32).unsqueeze(0)
                    mask = torch.ones(1, TOTAL_ACTIONS)  # only the value is used
                    with torch.no_grad():
                        _, v = model.forward(obs_t, mask)
                    print(f" {SUIT_NAMES[suit_idx]}: estimated value {v.item():.3f}")

        except (ValueError, IndexError) as e:
            print(f" Error: {e}. Try again.")
        except (KeyboardInterrupt, EOFError):
            # EOFError: stdin was closed or piped input ran out.
            break

    print("Goodbye!")
class PolicyValueNet(nn.Module):
    """
    Shared-trunk MLP with a masked policy head and a scalar value head.

    Args:
        obs_size: length of the observation vector.
        action_size: number of discrete actions (width of the policy head).
        hidden: width of the shared layers; each head uses hidden // 2.
    """

    def __init__(self, obs_size: int = 180, action_size: int = TOTAL_ACTIONS, hidden: int = 256):
        super().__init__()
        # Remembered so get_action sizes its mask like the policy head, even
        # when a non-default action_size is used.
        self.action_size = action_size
        self.shared = nn.Sequential(
            nn.Linear(obs_size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )
        self.policy_head = nn.Sequential(
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(),
            nn.Linear(hidden // 2, action_size),
        )
        self.value_head = nn.Sequential(
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(),
            nn.Linear(hidden // 2, 1),
        )

    def forward(self, obs: torch.Tensor, legal_mask: torch.Tensor):
        """
        obs: (B, obs_size)
        legal_mask: (B, action_size) — 1 for legal, 0 for illegal
        Returns: logits (masked), value of shape (B,)
        """
        h = self.shared(obs)
        logits = self.policy_head(h)
        # Push illegal actions to a large negative so softmax gives them ~0.
        logits = logits + (1 - legal_mask) * (-1e9)
        value = self.value_head(h).squeeze(-1)
        return logits, value

    def get_action(self, obs: np.ndarray, legal_actions: list[int], device="cpu"):
        """
        Sample one action among `legal_actions` from the current policy.

        Returns:
            (action, log_prob, value) as Python scalars.

        Raises:
            ValueError: if `legal_actions` is empty. With every action masked
                the softmax would be uniform over all actions, so an illegal
                one would be sampled silently.
        """
        if not legal_actions:
            raise ValueError("legal_actions must contain at least one action")
        obs_t = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
        mask = torch.zeros(1, self.action_size, device=device)
        for a in legal_actions:
            mask[0, a] = 1.0
        with torch.no_grad():
            logits, value = self.forward(obs_t, mask)
            probs = F.softmax(logits, dim=-1)
            dist = Categorical(probs)
            action = dist.sample()
        return action.item(), dist.log_prob(action).item(), value.item()

    def evaluate(self, obs_t: torch.Tensor, mask_t: torch.Tensor, actions_t: torch.Tensor):
        """
        Score a batch of stored actions under the current policy (PPO update).

        Returns:
            (log_probs, values, entropy), each of shape (B,).
        """
        logits, values = self.forward(obs_t, mask_t)
        probs = F.softmax(logits, dim=-1)
        dist = Categorical(probs)
        log_probs = dist.log_prob(actions_t)
        entropy = dist.entropy()
        return log_probs, values, entropy
def compute_gae(transitions: "list[Transition]", gamma=0.99, lam=0.95):
    """
    Generalized Advantage Estimation over one player's transition sequence.

    The value of the following stored transition is used as the bootstrap,
    except at the last transition or at a `done` one, where it is zero.

    Args:
        transitions: items exposing `.reward`, `.value` and `.done`.
        gamma: discount factor.
        lam: GAE lambda.

    Returns:
        (returns, advantages) as two lists of floats, one entry per
        transition; two empty lists for an empty input.
    """
    steps = len(transitions)
    if steps == 0:
        return [], []

    value_list = [tr.value for tr in transitions]
    adv = np.zeros(steps, dtype=np.float32)
    running = 0.0

    # Walk backwards so each step can reuse the accumulated tail.
    for k in range(steps - 1, -1, -1):
        tr = transitions[k]
        keep = 1 - tr.done  # 0 cuts the recursion at episode end
        if k == steps - 1 or tr.done:
            bootstrap = 0.0
        else:
            bootstrap = value_list[k + 1]
        td_error = tr.reward + gamma * bootstrap * keep - tr.value
        running = td_error + gamma * lam * keep * running
        adv[k] = running

    ret = adv + np.array(value_list)
    return ret.tolist(), adv.tolist()
def ppo_update(model: PolicyValueNet, optimizer: torch.optim.Optimizer,
               all_transitions: list[Transition], device="cpu",
               epochs=4, batch_size=256, clip_eps=0.2, vf_coef=0.5, ent_coef=0.01):
    """Run clipped-surrogate PPO updates over a batch of transitions.

    Each transition's ``reward`` must already hold its GAE *return* (this is
    how ``train()`` prepares transitions before batching). The advantage is
    recovered as ``return - value`` and normalised over the whole batch.
    Previously the advantage was a placeholder copy of the return, so the
    value baseline was never subtracted.

    Args:
        model: Network exposing ``evaluate(obs, masks, actions)`` that
            returns ``(log_probs, values, entropy)``.
        optimizer: Optimizer over ``model``'s parameters.
        all_transitions: Transitions with ``obs``, ``action``, ``log_prob``,
            ``value``, ``reward`` (the return) and ``legal_mask``.
        device: Device on which the batch tensors are created.
        epochs: Number of passes over the batch.
        batch_size: Mini-batch size.
        clip_eps: PPO ratio clipping range.
        vf_coef: Weight of the value loss.
        ent_coef: Weight of the entropy bonus.

    Returns:
        ``{"loss": mean_minibatch_loss}``, or ``{}`` when there is nothing
        to train on.
    """
    if not all_transitions:
        return {}

    def to_tensor(array, dtype=torch.float32):
        return torch.tensor(array, dtype=dtype, device=device)

    returns_arr = np.array([t.reward for t in all_transitions])
    values_arr = np.array([t.value for t in all_transitions])

    obs_t = to_tensor(np.array([t.obs for t in all_transitions]))
    actions_t = to_tensor(np.array([t.action for t in all_transitions]), torch.long)
    old_log_probs_t = to_tensor(np.array([t.log_prob for t in all_transitions]))
    masks_t = to_tensor(np.array([t.legal_mask for t in all_transitions]))
    returns_t = to_tensor(returns_arr)
    advantages_t = to_tensor(returns_arr - values_arr)

    # Normalising a single sample would divide by an undefined std.
    if len(advantages_t) > 1:
        advantages_t = (advantages_t - advantages_t.mean()) / (advantages_t.std() + 1e-8)

    n_samples = len(all_transitions)
    total_loss_sum = 0.0
    n_updates = 0

    for _ in range(epochs):
        indices = np.arange(n_samples)
        np.random.shuffle(indices)

        for start in range(0, n_samples, batch_size):
            idx = indices[start:start + batch_size]

            new_log_probs, values, entropy = model.evaluate(
                obs_t[idx], masks_t[idx], actions_t[idx])

            ratio = torch.exp(new_log_probs - old_log_probs_t[idx])
            b_advantages = advantages_t[idx]
            surr1 = ratio * b_advantages
            surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * b_advantages
            policy_loss = -torch.min(surr1, surr2).mean()

            value_loss = F.mse_loss(values, returns_t[idx])

            loss = policy_loss + vf_coef * value_loss - ent_coef * entropy.mean()

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

            total_loss_sum += loss.item()
            n_updates += 1

    return {"loss": total_loss_sum / max(n_updates, 1)}
def _save_checkpoint(model, optimizer, episode, num_players, path):
    """Write model and optimizer state plus run metadata to ``path``."""
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "episode": episode,
        "num_players": num_players,
    }, path)


def _run_ppo_epochs(model, optimizer, batch, args, device, vf_coef=0.5):
    """Run ``args.ppo_epochs`` clipped-PPO passes over ``batch``.

    Every transition in ``batch`` must carry its GAE return in ``reward``.
    The advantage is recovered as ``return - value`` and normalised over the
    batch.
    """
    returns_arr = np.array([t.reward for t in batch])
    values_arr = np.array([t.value for t in batch])

    def to_tensor(array, dtype=torch.float32):
        return torch.tensor(array, dtype=dtype, device=device)

    obs_t = to_tensor(np.array([t.obs for t in batch]))
    actions_t = to_tensor(np.array([t.action for t in batch]), torch.long)
    old_lp_t = to_tensor(np.array([t.log_prob for t in batch]))
    masks_t = to_tensor(np.array([t.legal_mask for t in batch]))
    returns_t = to_tensor(returns_arr)
    advs_t = to_tensor(returns_arr - values_arr)

    if len(advs_t) > 1:
        advs_t = (advs_t - advs_t.mean()) / (advs_t.std() + 1e-8)

    for _ in range(args.ppo_epochs):
        indices = np.arange(len(batch))
        np.random.shuffle(indices)
        for start in range(0, len(indices), args.batch_size):
            idx = indices[start:start + args.batch_size]

            new_lp, values, entropy = model.evaluate(obs_t[idx], masks_t[idx], actions_t[idx])
            ratio = torch.exp(new_lp - old_lp_t[idx])
            b_advs = advs_t[idx]
            surr1 = ratio * b_advs
            surr2 = torch.clamp(ratio, 1 - args.clip_eps, 1 + args.clip_eps) * b_advs
            policy_loss = -torch.min(surr1, surr2).mean()
            value_loss = F.mse_loss(values, returns_t[idx])
            loss = policy_loss + vf_coef * value_loss - args.ent_coef * entropy.mean()

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()


def train(args):
    """Train a PPO agent by self-play and return the trained model.

    Reads from ``args``: ``num_players``, ``episodes``, ``lr``, ``gamma``,
    ``lam``, ``update_every``, ``ppo_epochs``, ``batch_size``, ``clip_eps``,
    ``ent_coef``, ``log_every``, ``save_every`` and ``save_path``.

    One game is played per episode. A PPO update runs every
    ``update_every`` episodes, and transitions left over when training ends
    get one final update instead of being discarded. Checkpoints go to
    ``{save_path}_ep{N}.pt`` and the final model to ``{save_path}_final.pt``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    print(f"Training for {args.num_players} players, {args.episodes} episodes")

    model = PolicyValueNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Stats
    win_counts = defaultdict(int)
    game_lengths = []
    batch_transitions = []

    for ep in range(1, args.episodes + 1):
        env = BlazingEightsEnv(num_players=args.num_players)
        trajectories = collect_game(env, model, device)

        # Record stats
        if env.done:
            win_counts[env.winner] += 1
        game_lengths.append(sum(len(v) for v in trajectories.values()))

        # GAE is computed per player. The return replaces the raw reward so
        # the update step can derive the advantage as (return - value).
        for trans_list in trajectories.values():
            returns, _ = compute_gae(trans_list, gamma=args.gamma, lam=args.lam)
            for t, ret in zip(trans_list, returns):
                t.reward = ret
            batch_transitions.extend(trans_list)

        # Update every `update_every` episodes
        if ep % args.update_every == 0:
            if batch_transitions:
                _run_ppo_epochs(model, optimizer, batch_transitions, args, device)
            batch_transitions = []

        # Logging
        if ep % args.log_every == 0:
            avg_len = np.mean(game_lengths[-args.log_every:]) if game_lengths else 0
            total_games = sum(win_counts.values())
            win_rates = {p: win_counts[p] / max(total_games, 1) for p in range(args.num_players)}
            print(f"Episode {ep:>7d} | Avg game length: {avg_len:.1f} | "
                  f"Win rates: {win_rates} | "
                  f"Total games: {total_games}")

        # Save checkpoint
        if ep % args.save_every == 0:
            path = f"{args.save_path}_ep{ep}.pt"
            _save_checkpoint(model, optimizer, ep, args.num_players, path)
            print(f" Saved checkpoint: {path}")

    # Episodes collected after the last scheduled update are still on-policy.
    if batch_transitions:
        _run_ppo_epochs(model, optimizer, batch_transitions, args, device)

    # Final save
    _save_checkpoint(model, optimizer, args.episodes, args.num_players,
                     f"{args.save_path}_final.pt")
    print(f"Training complete. Final model saved to {args.save_path}_final.pt")

    return model
def evaluate_vs_random(model: PolicyValueNet, num_players=2, num_games=1000, device="cpu",
                       max_steps=500):
    """Play the model (seat 0) against uniformly random opponents.

    Args:
        model: Policy exposing ``get_action(obs, legal, device)``.
        num_players: Number of seats in each game.
        num_games: Number of games to play.
        device: Passed through to ``model.get_action``. It should be the
            device the model's weights live on.
        max_steps: Per-game cap on environment steps. A game that reaches
            the cap counts as not won.

    Returns:
        Fraction of games won by player 0 (0.0 when ``num_games`` is not
        positive).
    """
    if num_games <= 0:
        return 0.0

    wins = 0
    for _ in range(num_games):
        env = BlazingEightsEnv(num_players=num_players)
        obs = env.reset()
        for _step in range(max_steps):
            player = env.current_player
            legal = env.legal_actions()
            if not legal:
                break
            if player == 0:
                action, _, _ = model.get_action(obs, legal, device)
            else:
                # np.random.choice yields a NumPy scalar; hand the env a plain int.
                action = int(np.random.choice(legal))
            obs, rewards, done, info = env.step(action)
            if done:
                if env.winner == 0:
                    wins += 1
                break
    return wins / num_games


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train Blazing Eights PPO agent")
    parser.add_argument("--num_players", type=int, default=2)
    parser.add_argument("--episodes", type=int, default=100000)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--lam", type=float, default=0.95)
    parser.add_argument("--clip_eps", type=float, default=0.2)
    parser.add_argument("--ent_coef", type=float, default=0.01)
    parser.add_argument("--ppo_epochs", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--update_every", type=int, default=64)
    parser.add_argument("--log_every", type=int, default=1000)
    parser.add_argument("--save_every", type=int, default=10000)
    parser.add_argument("--save_path", type=str, default="blazing_ppo")
    args = parser.parse_args()

    model = train(args)

    # train() moves the model to CUDA when available, so evaluate with the
    # same device string rather than the "cpu" default.
    # NOTE(review): assumes get_action places its inputs on `device` — confirm.
    eval_device = "cuda" if torch.cuda.is_available() else "cpu"

    # Eval vs random
    print("\nEvaluating vs random opponents...")
    for n in [2, 3, 4, 5]:
        # Covers every table size up to one seat beyond the trained count.
        if n <= args.num_players + 1:
            wr = evaluate_vs_random(model, num_players=n, num_games=1000, device=eval_device)
            print(f" {n} players: win rate = {wr:.1%} (random baseline: {1/n:.1%})")
def card_effect(c: int, num_players: int = 2) -> str:
    """Return a short yellow effect tag for a special card, or "" if none."""
    if is_swap(c):
        label = "换牌"
    else:
        rank = card_rank(c)
        if rank == RANK_8:
            label = "万能"
        elif rank == RANK_K:
            label = "全摸"
        elif rank == RANK_Q:
            # A queen is only tagged when there are more than two players.
            label = "反转" if num_players > 2 else None
        elif rank == RANK_J:
            label = "跳过"
        else:
            label = None
    if label is None:
        return ""
    return f"\033[93m{label}\033[0m"


def pretty_card(c: int) -> str:
    """Return an ANSI-coloured display string for card ``c``."""
    if is_swap(c):
        return "\033[95mSWAP\033[0m"
    suit_idx = card_suit(c)
    rank_name = RANK_NAMES[card_rank(c)]
    # Indexed by suit: ♠ white, ♥ red, ♦ red, ♣ white.
    ansi = ("\033[37m", "\033[91m", "\033[91m", "\033[37m")[suit_idx]
    return f"{ansi}{rank_name}{SUIT_SYMBOLS[suit_idx]}\033[0m"


def pretty_hand(hand: list[int], num_players: int = 2) -> str:
    """Render a hand as "[index] card(effect)" entries in display order."""
    def display_order(card):
        # Swap cards sort after every regular suit.
        return (99 if is_swap(card) else card_suit(card), card)

    entries = []
    for position, card in enumerate(sorted(hand, key=display_order)):
        entry = f"[{position}] {pretty_card(card)}"
        effect = card_effect(card, num_players)
        if effect:
            entry = f"{entry}({effect})"
        entries.append(entry)
    return " ".join(entries)


def print_game_state(env: BlazingEightsEnv, human_player: int, show_ai_hand: bool = False):
    """Print the discard top, direction, deck size and every player's hand size.

    AI hands are listed card by card only when ``show_ai_hand`` is true.
    The human's hand is always shown as a count here.
    """
    rule = "=" * 55
    print()
    print(rule)

    top_str = pretty_card(env.discard[-1])
    if env.active_suit is not None:
        top_str += f" (指定花色: {SUIT_SYMBOLS[env.active_suit]})"
    dir_str = "顺时针 →" if env.direction == 1 else "逆时针 ←"
    print(f" 弃牌堆顶: {top_str} 方向: {dir_str} 牌堆剩余: {len(env.deck)}")
    print("-" * 55)

    for seat in range(env.num_players):
        is_human = seat == human_player
        tag = "你" if is_human else f"AI-{seat}"
        if show_ai_hand and not is_human:
            hand_str = ", ".join(pretty_card(c) for c in sorted(env.hands[seat]))
        else:
            hand_str = f"{len(env.hands[seat])} 张牌"
        arrow = " ◀" if seat == env.current_player else ""
        print(f" {tag}: {hand_str}{arrow}")
    print(rule)
def parse_card_input(s: str) -> int:
    """Parse a card written as text (e.g. ``8h``, ``10d``, ``SWAP``) into its id.

    Any input starting with ``SW`` maps to 52. Otherwise the id is
    ``suit * 13 + rank``, using ``SUIT_LETTERS`` and ``RANK_NAMES``.

    Raises:
        ValueError: If the rank or suit is not recognised, including empty
            input (which used to raise IndexError).
    """
    s = s.strip().upper()
    if s.startswith("SW"):
        return 52
    if s.startswith("10"):
        rank_str, suit_str = "10", s[2:].lower()
    else:
        # Slicing keeps empty input on the ValueError path below.
        rank_str, suit_str = s[:1], s[1:].lower()
    rank_map = {r: i for i, r in enumerate(RANK_NAMES)}
    if rank_str not in rank_map or suit_str not in SUIT_LETTERS:
        raise ValueError(f"无法识别: {s} (格式例: 8h, Ks, 10d, Ac, SWAP)")
    return SUIT_LETTERS[suit_str] * 13 + rank_map[rank_str]


def human_choose_action(env: BlazingEightsEnv, player: int) -> int:
    """Prompt the human player until they enter a legal action and return it.

    In the ``choose_suit`` phase the player picks a suit 0-3, returned as
    ``56 + suit``. Otherwise the player enters the index of a card in the
    sorted hand, ``d`` to draw or ``p`` to pass.

    Raises:
        KeyboardInterrupt: If the player enters ``q`` to quit.
    """
    hand = sorted(env.hands[player], key=lambda c: (card_suit(c) if not is_swap(c) else 99, c))
    legal = env.legal_actions(player)

    if env.phase == "choose_suit":
        print("\n 你打出了 8!选择指定花色:")
        for i, s in enumerate(SUIT_SYMBOLS):
            print(f" [{i}] {s}")
        while True:
            choice = input(" 选择 (0-3): ").strip()
            try:
                action = 56 + int(choice)
            except ValueError:
                print(" 请输入 0-3")
                continue
            if action in legal:
                return action
            print(" 无效选择,请重试")

    print(f"\n 你的手牌: {pretty_hand(hand, env.num_players)}")

    # Build playable cards display
    playable = [a for a in legal if a < NUM_CARDS]
    can_draw = DRAW_ACTION in legal
    can_pass = PASS_ACTION in legal

    def index_in_hand(card):
        # NOTE(review): the selection code below lets a playable swap action
        # differ from the swap card actually held, so fall back to the first
        # swap in hand instead of letting list.index raise.
        if card in hand:
            return hand.index(card)
        if is_swap(card):
            for position, held in enumerate(hand):
                if is_swap(held):
                    return position
        return None

    print(" 可出的牌:", end="")
    if playable:
        playable_names = []
        for a in playable:
            idx_in_hand = index_in_hand(a)
            tag = pretty_card(a) if idx_in_hand is None else f"[{idx_in_hand}]{pretty_card(a)}"
            effect = card_effect(a, env.num_players)
            if effect:
                tag += f"({effect})"
            playable_names.append(tag)
        print(" " + " ".join(playable_names))
    else:
        print(" 无")

    if can_draw:
        print(" [d] 摸牌")
    if can_pass:
        print(" [p] 跳过 (牌堆与弃牌堆均已空)")

    while True:
        choice = input(" 你的选择: ").strip().lower()
        if choice == "d":
            if can_draw:
                return DRAW_ACTION
            print(" 牌堆已空,无法摸牌")
            continue
        if choice == "p":
            if can_pass:
                return PASS_ACTION
            print(" 还没摸牌,不能直接跳过")
            continue
        if choice == "q":
            raise KeyboardInterrupt

        try:
            idx = int(choice)
        except ValueError:
            print(" 输入序号、d(摸牌) 或 q(退出)")
            continue

        if not 0 <= idx < len(hand):
            print(f" 序号超出范围 (0-{len(hand)-1})")
            continue

        card = hand[idx]
        if card in playable:
            return card
        # Handle swap cards (might have multiple)
        if is_swap(card):
            for a in playable:
                if is_swap(a):
                    return a
        print(f" {pretty_card(card)} 不能出,请选其他牌")
def ai_choose_action(env: BlazingEightsEnv, model: PolicyValueNet, player: int, device="cpu") -> int:
    """Ask the model for ``player``'s action given their observation and legal moves."""
    observation = env._get_obs(player)
    legal_moves = env.legal_actions(player)
    chosen, _log_prob, _value = model.get_action(observation, legal_moves, device)
    return chosen


def describe_action(player_name: str, action: int, env: BlazingEightsEnv, drawn_card: int = None):
    """Return a one-line description of ``action`` taken by ``player_name``.

    ``drawn_card`` is accepted but not used.
    """
    if action == DRAW_ACTION:
        return f" {player_name} 摸了一张牌"
    if action == PASS_ACTION:
        return f" {player_name} 跳过"
    if action >= 56:
        # Actions from 56 upwards select a suit.
        return f" {player_name} 指定花色: {SUIT_SYMBOLS[action - 56]}"

    rank = card_rank(action)
    if is_swap(action):
        suffix = " → 与下家交换手牌!"
    elif rank == RANK_8:
        suffix = " → 万能牌!选择花色..."
    elif rank == RANK_K:
        suffix = " → 其他所有人各摸 1 张!"
    elif rank == RANK_Q and env.num_players > 2:
        suffix = " → 反转方向!"
    elif rank == RANK_J:
        suffix = " → 跳过下一位!"
    else:
        suffix = ""
    return f" {player_name} 打出 {pretty_card(action)}" + suffix
def play_game(model_path: str, num_players: int, human_player: int = 0, show_ai: bool = False):
    """Play one interactive terminal game between a human and model-driven AIs.

    Args:
        model_path: Checkpoint file; its ``"model"`` entry is loaded into a
            fresh ``PolicyValueNet``.
        num_players: Total number of seats, including the human.
        human_player: Seat index controlled by the human.
        show_ai: Whether AI hands are displayed during the game.

    Returns:
        None. Returns early if the human quits (``human_choose_action``
        raises ``KeyboardInterrupt``).
    """
    # The model always runs on the CPU here; the checkpoint is mapped onto it.
    device = "cpu"
    model = PolicyValueNet()
    checkpoint = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint["model"])
    model.eval()

    print()
    print("╔══════════════════════════════════════╗")
    print("║ Blazing Eights - 人机对战 ║")
    print("╠══════════════════════════════════════╣")
    print(f"║ 玩家数: {num_players} 你是: Player {human_player} ║")
    print("║ 输入序号出牌, d摸牌, p跳过, q退出 ║")
    print("╚══════════════════════════════════════╝")

    env = BlazingEightsEnv(num_players=num_players)
    turn = 0  # incremented each loop iteration; not read anywhere in this function

    while not env.done:
        player = env.current_player
        turn += 1

        if player == human_player:
            print_game_state(env, human_player, show_ai_hand=show_ai)
            try:
                action = human_choose_action(env, player)
            except KeyboardInterrupt:
                # human_choose_action raises this when the player enters "q".
                print("\n\n 你退出了游戏。再见!")
                return

            # Describe human action
            name = "你"
            if action == DRAW_ACTION:
                # Remember hand before draw to find the new card
                hand_before = set(env.hands[player])
                obs, rewards, done, info = env.step(action)
                hand_after = set(env.hands[player])
                new_cards = hand_after - hand_before
                if new_cards:
                    drawn = next(iter(new_cards))
                    print(f" 你摸到了 {pretty_card(drawn)}")
                else:
                    print(f" 牌堆已空,没摸到牌")
                # Turn stays with human — loop back to let them decide
                # NOTE(review): relies on the env keeping current_player
                # unchanged after a draw — confirm against blazing_env.
                continue
            elif action == PASS_ACTION:
                print(f" 你选择不出牌,结束回合")
                obs, rewards, done, info = env.step(action)
                continue
            else:
                print(describe_action(name, action, env))
                obs, rewards, done, info = env.step(action)
                # If played an 8, need to choose suit
                if env.phase == "choose_suit" and env._pending_8_player == human_player:
                    suit_action = human_choose_action(env, human_player)
                    # Suit actions are encoded as 56 + suit index.
                    print(f" 你指定花色: {SUIT_SYMBOLS[suit_action - 56]}")
                    obs, rewards, done, info = env.step(suit_action)
                continue
        else:
            # AI turn
            ai_name = f"AI-{player}"

            # The env is already waiting for a suit choice when this AI's turn starts.
            if env.phase == "choose_suit":
                action = ai_choose_action(env, model, player, device)
                print(f" {ai_name} 指定花色: {SUIT_SYMBOLS[action - 56]}")
                obs, rewards, done, info = env.step(action)
                continue

            action = ai_choose_action(env, model, player, device)

            if action == DRAW_ACTION:
                print(f" {ai_name} 摸了一张牌")
                obs, rewards, done, info = env.step(action)
                # AI still has their turn — now decide to play or pass
                # NOTE(review): the follow-up action is queried for the same
                # `player` without re-checking env.current_player or env.done.
                action2 = ai_choose_action(env, model, player, device)
                if action2 == PASS_ACTION:
                    print(f" {ai_name} 选择不出牌")
                    obs, rewards, done, info = env.step(action2)
                else:
                    print(describe_action(ai_name, action2, env))
                    obs, rewards, done, info = env.step(action2)
                    if env.phase == "choose_suit" and env._pending_8_player == player:
                        suit_action = ai_choose_action(env, model, player, device)
                        print(f" {ai_name} 指定花色: {SUIT_SYMBOLS[suit_action - 56]}")
                        obs, rewards, done, info = env.step(suit_action)
            elif action == PASS_ACTION:
                print(f" {ai_name} 跳过")
                obs, rewards, done, info = env.step(action)
            else:
                print(describe_action(ai_name, action, env))
                obs, rewards, done, info = env.step(action)
                # The AI that played an 8 picks the suit straight away.
                if env.phase == "choose_suit" and env._pending_8_player == player:
                    suit_action = ai_choose_action(env, model, player, device)
                    print(f" {ai_name} 指定花色: {SUIT_SYMBOLS[suit_action - 56]}")
                    obs, rewards, done, info = env.step(suit_action)

    # Game over
    print_game_state(env, human_player, show_ai_hand=True)
    print()
    if env.winner == human_player:
        print(" 🎉 你赢了!!!")
    elif env.winner >= 0:
        print(f" 💀 AI-{env.winner} 赢了...")
    else:
        # NOTE(review): assumes a negative winner value marks a stalemate — confirm.
        print(" 平局(僵局)")

    # Show hand sizes
    for i in range(env.num_players):
        name = "你" if i == human_player else f"AI-{i}"
        print(f" {name}: {len(env.hands[i])} 张剩余")
    print()
def main():
    """Parse command-line options and play games until the user declines a rematch."""
    parser = argparse.ArgumentParser(description="Blazing Eights 人机对战")
    parser.add_argument("--model", type=str, default="blazing_ppo_final.pt", help="模型路径")
    parser.add_argument("--num_players", type=int, default=2, help="玩家总数 (2-5)")
    parser.add_argument("--show_ai", action="store_true", help="显示 AI 手牌 (调试用)")
    args = parser.parse_args()

    keep_playing = True
    while keep_playing:
        play_game(args.model, args.num_players, human_player=0, show_ai=args.show_ai)
        answer = input(" 再来一局? (y/n): ").strip().lower()
        # Anything other than an explicit "y" ends the session.
        keep_playing = answer == "y"
        if not keep_playing:
            print(" 下次再见!")


if __name__ == "__main__":
    main()
