diff options
| author | haoyuren <13851610112@163.com> | 2026-02-22 11:28:45 -0600 |
|---|---|---|
| committer | haoyuren <13851610112@163.com> | 2026-02-22 11:28:45 -0600 |
| commit | 3887054e02e622ca2cb7878bc0dec63d28c7f223 (patch) | |
| tree | 1a341f7562abb41cfc25badde73879a4e914b1ee /train.py | |
| parent | 1cb5eb34ead9b4efc1032ec74c6ccc439f007c18 (diff) | |
Fix SWAP inheritance, stalemate logic, add greedy warmup
- SWAP now inherits previous card's suit/rank for matching
- Observation encodes effective top card when SWAP is on top
- Fix stalemate: only hard passes (can't draw) count, draw+pass resets
- Add behavioral cloning warmup: pre-train on greedy policy before PPO
- 2p win rate vs greedy random: 60.5%
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'train.py')
| -rw-r--r-- | train.py | 58 |
1 file changed, 58 insertions, 0 deletions
# ---------------------------------------------------------------------------
# Greedy Warmup (Behavioral Cloning)
# ---------------------------------------------------------------------------
def greedy_warmup(model: PolicyValueNet, optimizer: torch.optim.Optimizer,
                  num_players: int, num_games: int = 2000, epochs: int = 5,
                  batch_size: int = 256, device: str = "cpu"):
    """Pre-train the model to imitate greedy play (play if possible, else draw).

    Rolls out ``num_games`` episodes driven entirely by
    ``greedy_random_action``, recording (observation, chosen action,
    legal-action mask) triples, then minimizes cross-entropy between the
    policy logits and the greedy actions for ``epochs`` passes over the
    collected data.  The value head is untouched (its output is discarded).

    Args:
        model: policy/value network whose policy head is trained in place.
        optimizer: optimizer already bound to ``model``'s parameters.
        num_players: number of players per BlazingEightsEnv game.
        num_games: greedy games to roll out for the imitation dataset.
        epochs: full passes over the collected transitions.
        batch_size: minibatch size for the cross-entropy updates.
        device: torch device string tensors are placed on.
    """
    print(f"Greedy warmup: {num_games} games, {epochs} epochs...")
    obs_list, action_list, mask_list = [], [], []

    for _ in tqdm(range(num_games), desc="Collecting greedy data", unit="game"):
        env = BlazingEightsEnv(num_players=num_players)
        obs = env.reset()
        # Hard cap of 500 plies per game guards against pathological loops.
        for _ in range(500):
            legal = env.legal_actions()
            if not legal:
                break
            action = greedy_random_action(legal)
            legal_mask = np.zeros(TOTAL_ACTIONS, dtype=np.float32)
            for a in legal:
                legal_mask[a] = 1.0
            obs_list.append(obs.copy())
            action_list.append(action)
            mask_list.append(legal_mask)
            obs, _, done, _ = env.step(action)
            if done:
                break

    # Robustness: with num_games == 0 (or games that end before any legal
    # action) there is nothing to fit, and total_loss / n_batches below
    # would divide by zero.  Leave the model untouched in that case.
    if not obs_list:
        print("  No transitions collected, skipping warmup")
        return

    obs_t = torch.tensor(np.array(obs_list), dtype=torch.float32, device=device)
    act_t = torch.tensor(np.array(action_list), dtype=torch.long, device=device)
    mask_t = torch.tensor(np.array(mask_list), dtype=torch.float32, device=device)
    print(f"  Collected {len(obs_list)} transitions")

    model.train()  # make sure any dropout/norm layers are in training mode
    for epoch in range(epochs):
        indices = np.arange(len(obs_list))
        np.random.shuffle(indices)
        total_loss = 0.0
        n_batches = 0
        for start in range(0, len(indices), batch_size):
            idx = indices[start:start + batch_size]
            # The mask suppresses illegal actions in the logits; the target
            # action is always legal by construction, so cross-entropy is
            # well defined on the masked distribution.
            logits, _ = model(obs_t[idx], mask_t[idx])
            loss = F.cross_entropy(logits, act_t[idx])
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        # NOTE(review): in the extracted source this print statement was
        # split by a raw line break inside the f-string (a syntax error);
        # re-joined into a single statement here.
        print(f"  Epoch {epoch+1}/{epochs}: loss={total_loss/n_batches:.4f}")
