diff options
| -rw-r--r-- | blazing_env.py | 9 | ||||
| -rw-r--r-- | train.py | 10 |
2 files changed, 16 insertions, 3 deletions
diff --git a/blazing_env.py b/blazing_env.py
index c3d97ae..c440293 100644
--- a/blazing_env.py
+++ b/blazing_env.py
@@ -15,7 +15,10 @@ Special cards:
 
 Rules:
   - Match top card by suit OR rank (unless playing 8 or Swap)
-  - Can't play → draw 1; if drawn card is playable, play it immediately
+  - Player may freely choose to draw even if they have playable cards
+  - After drawing, player may play any playable card OR pass (end turn)
+  - Each turn allows at most one draw
+  - If no playable cards and deck is empty, player must pass
   - First player to empty hand wins
   - Initial hand: 5 cards each
 """
@@ -92,7 +95,9 @@ class BlazingEightsEnv:
         self.rng = np.random.default_rng(seed)
 
         # Build & shuffle deck
-        deck = list(range(NUM_CARDS))
+        # In 2-player games, remove Q cards (reverse has no effect)
+        deck = [c for c in range(NUM_CARDS)
+                if not (self.num_players == 2 and card_rank(c) == RANK_Q)]
         self.rng.shuffle(deck)
 
         # Deal 5 cards each
diff --git a/train.py b/train.py
--- a/train.py
+++ b/train.py
@@ -100,6 +100,14 @@ class Transition:
         self.legal_mask = legal_mask
 
 
+def greedy_random_action(legal: list[int]) -> int:
+    """Pick a random playable card; only draw/pass if no card to play."""
+    play_actions = [a for a in legal if a < NUM_CARDS or (56 <= a <= 59)]
+    if play_actions:
+        return int(np.random.choice(play_actions))
+    return int(np.random.choice(legal))
+
+
 def collect_game(env: BlazingEightsEnv, model: PolicyValueNet, device="cpu"):
     """
     Play one full game, return per-player trajectories.
@@ -389,7 +397,7 @@ def evaluate_vs_random(model: PolicyValueNet, num_players=2, num_games=1000, dev
             if player == 0:
                 action, _, _ = model.get_action(obs, legal, device)
             else:
-                action = np.random.choice(legal)
+                action = greedy_random_action(legal)
             obs, rewards, done, info = env.step(action)
             if done:
                 if env.winner == 0:
