summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--blazing_env.py41
-rw-r--r--train.py58
2 files changed, 85 insertions, 14 deletions
diff --git a/blazing_env.py b/blazing_env.py
index c440293..3f4b407 100644
--- a/blazing_env.py
+++ b/blazing_env.py
@@ -11,7 +11,7 @@ Special cards:
K → All OTHER players draw 1 card from the deck
Q → Reverse direction (no effect in 2-player games)
J → Skip next player's turn
- Swap → Swap entire hand with next player (playable anytime on your turn, no match needed)
+ Swap → Swap entire hand with next player (playable anytime; next card must match the card before the Swap)
Rules:
- Match top card by suit OR rank (unless playing 8 or Swap)
@@ -182,18 +182,19 @@ class BlazingEightsEnv:
for c in self.hands[player]:
obs[c] = 1.0
- # Top card info
+ # Top card info (SWAP inherits previous card)
top = self.discard[-1]
+ eff = self._effective_top() if is_swap(top) else top
if self.active_suit is not None:
suit = self.active_suit
- elif not is_swap(top):
- suit = card_suit(top)
+ elif not is_swap(eff):
+ suit = card_suit(eff)
else:
suit = 0
obs[56 + suit] = 1.0
- if not is_swap(top) and self.active_suit is None:
- obs[60 + card_rank(top)] = 1.0
+ if not is_swap(eff) and self.active_suit is None:
+ obs[60 + card_rank(eff)] = 1.0
# Direction
obs[73] = 0.0 if self.direction == 1 else 1.0
@@ -279,6 +280,13 @@ class BlazingEightsEnv:
actions.append(PASS_ACTION)
return actions
+ def _effective_top(self) -> int:
+ """Find the last non-SWAP card in discard for matching purposes."""
+ for c in reversed(self.discard):
+ if not is_swap(c):
+ return c
+ return self.discard[-1] # fallback (all swaps, shouldn't happen)
+
def _can_play(self, card: int, top: int) -> bool:
# Swap cards: always playable
if is_swap(card):
@@ -289,10 +297,11 @@ class BlazingEightsEnv:
# If active_suit is set (after a wild 8), must match that suit
if self.active_suit is not None:
return card_suit(card) == self.active_suit
- # Normal: match suit or rank
+ # SWAP on top: inherit previous non-SWAP card's suit/rank
if is_swap(top):
- # Top is swap — shouldn't happen in normal flow, but match anything
- return True
+ top = self._effective_top()
+ if is_swap(top):
+ return True # all swaps, match anything
return card_suit(card) == card_suit(top) or card_rank(card) == card_rank(top)
# ------------------------------------------------------------------
@@ -329,7 +338,6 @@ class BlazingEightsEnv:
# --- Play phase ---
if action == DRAW_ACTION:
- self.consecutive_passes = 0
drawn = self._draw_card(player)
self._record_event(player, 2) # drew a card
# Card is added to hand; player keeps their turn to decide
@@ -338,8 +346,12 @@ class BlazingEightsEnv:
obs = self._get_obs(player)
return obs, rewards, False, info
elif action == PASS_ACTION:
- # Skip turn (no cards anywhere)
- self.consecutive_passes += 1
+ if self.has_drawn_this_turn:
+ # Drew but chose not to play — game state changed, not stalemate
+ self.consecutive_passes = 0
+ else:
+ # Hard pass: can't draw and can't play — real stalemate signal
+ self.consecutive_passes += 1
self._record_event(player, 3) # passed
if self.consecutive_passes >= self.num_players:
# Stalemate: all players passed in a row
@@ -368,8 +380,9 @@ class BlazingEightsEnv:
hand.remove(card)
self.discard.append(card)
- # Clear active suit (unless new card is 8)
- self.active_suit = None
+ # Clear active suit (unless new card is 8 or SWAP — SWAP inherits)
+ if not is_swap(card):
+ self.active_suit = None
# Clear swap knowledge for this player (cards change over time)
# We keep it until they play; after playing, knowledge decays
diff --git a/train.py b/train.py
index 7f85267..d247cbb 100644
--- a/train.py
+++ b/train.py
@@ -254,6 +254,57 @@ def ppo_update(model: PolicyValueNet, optimizer: torch.optim.Optimizer,
# ---------------------------------------------------------------------------
+# Greedy Warmup (Behavioral Cloning)
+# ---------------------------------------------------------------------------
+def greedy_warmup(model: PolicyValueNet, optimizer: torch.optim.Optimizer,
+ num_players: int, num_games: int = 2000, epochs: int = 5,
+ batch_size: int = 256, device: str = "cpu"):
+ """Pre-train the model to imitate greedy play (play if possible, else draw)."""
+ print(f"Greedy warmup: {num_games} games, {epochs} epochs...")
+ obs_list, action_list, mask_list = [], [], []
+
+ for _ in tqdm(range(num_games), desc="Collecting greedy data", unit="game"):
+ env = BlazingEightsEnv(num_players=num_players)
+ obs = env.reset()
+ for _ in range(500):
+ legal = env.legal_actions()
+ if not legal:
+ break
+ action = greedy_random_action(legal)
+ legal_mask = np.zeros(TOTAL_ACTIONS, dtype=np.float32)
+ for a in legal:
+ legal_mask[a] = 1.0
+ obs_list.append(obs.copy())
+ action_list.append(action)
+ mask_list.append(legal_mask)
+ obs, _, done, _ = env.step(action)
+ if done:
+ break
+
+ obs_t = torch.tensor(np.array(obs_list), dtype=torch.float32, device=device)
+ act_t = torch.tensor(np.array(action_list), dtype=torch.long, device=device)
+ mask_t = torch.tensor(np.array(mask_list), dtype=torch.float32, device=device)
+ print(f" Collected {len(obs_list)} transitions")
+
+ for epoch in range(epochs):
+ indices = np.arange(len(obs_list))
+ np.random.shuffle(indices)
+ total_loss = 0
+ n_batches = 0
+ for start in range(0, len(indices), batch_size):
+ idx = indices[start:start + batch_size]
+ logits, _ = model(obs_t[idx], mask_t[idx])
+ loss = F.cross_entropy(logits, act_t[idx])
+ optimizer.zero_grad()
+ loss.backward()
+ nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+ optimizer.step()
+ total_loss += loss.item()
+ n_batches += 1
+ print(f" Epoch {epoch+1}/{epochs}: loss={total_loss/n_batches:.4f}")
+
+
+# ---------------------------------------------------------------------------
# Training Loop
# ---------------------------------------------------------------------------
def train(args):
@@ -264,6 +315,11 @@ def train(args):
model = PolicyValueNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+ # Greedy warmup: imitate greedy play before self-play
+ if args.greedy_warmup > 0:
+ greedy_warmup(model, optimizer, args.num_players,
+ num_games=args.greedy_warmup, device=device)
+
# Stats
win_counts = defaultdict(int)
game_lengths = []
@@ -421,6 +477,8 @@ if __name__ == "__main__":
parser.add_argument("--log_every", type=int, default=1000)
parser.add_argument("--save_every", type=int, default=10000)
parser.add_argument("--save_path", type=str, default="blazing_ppo")
+ parser.add_argument("--greedy_warmup", type=int, default=2000,
+ help="Number of greedy games for behavioral cloning warmup (0 to skip)")
args = parser.parse_args()
model = train(args)