summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
author    haoyuren <13851610112@163.com>  2026-02-22 11:28:45 -0600
committer haoyuren <13851610112@163.com>  2026-02-22 11:28:45 -0600
commit    3887054e02e622ca2cb7878bc0dec63d28c7f223 (patch)
tree      1a341f7562abb41cfc25badde73879a4e914b1ee /train.py
parent    1cb5eb34ead9b4efc1032ec74c6ccc439f007c18 (diff)
Fix SWAP inheritance, stalemate logic, add greedy warmup
- SWAP now inherits previous card's suit/rank for matching
- Observation encodes effective top card when SWAP is on top
- Fix stalemate: only hard passes (can't draw) count, draw+pass resets
- Add behavioral cloning warmup: pre-train on greedy policy before PPO
- 2p win rate vs greedy random: 60.5%

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'train.py')
-rw-r--r--  train.py  58
1 file changed, 58 insertions, 0 deletions
diff --git a/train.py b/train.py
index 7f85267..d247cbb 100644
--- a/train.py
+++ b/train.py
@@ -254,6 +254,57 @@ def ppo_update(model: PolicyValueNet, optimizer: torch.optim.Optimizer,
# ---------------------------------------------------------------------------
+# Greedy Warmup (Behavioral Cloning)
+# ---------------------------------------------------------------------------
def greedy_warmup(model: PolicyValueNet, optimizer: torch.optim.Optimizer,
                  num_players: int, num_games: int = 2000, epochs: int = 5,
                  batch_size: int = 256, device: str = "cpu"):
    """Pre-train the model to imitate greedy play (play if possible, else draw).

    Runs ``num_games`` games where every seat follows the greedy policy,
    records (observation, action, legal-mask) transitions, then fits the
    policy head by cross-entropy for ``epochs`` passes over the dataset
    (behavioral cloning). The model and optimizer are updated in place.

    Args:
        model: Policy/value network; its policy logits are trained here.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_players: Seats per game passed to the environment.
        num_games: Greedy games to collect; 0 makes this a no-op.
        epochs: Full passes over the collected transitions.
        batch_size: Minibatch size for the cross-entropy updates.
        device: Torch device string for the training tensors.
    """
    print(f"Greedy warmup: {num_games} games, {epochs} epochs...")
    obs_list, action_list, mask_list = [], [], []

    for _ in tqdm(range(num_games), desc="Collecting greedy data", unit="game"):
        env = BlazingEightsEnv(num_players=num_players)
        obs = env.reset()
        # Hard step cap guards against a pathological never-ending game.
        for _ in range(500):
            legal = env.legal_actions()
            if not legal:
                break
            action = greedy_random_action(legal)
            legal_mask = np.zeros(TOTAL_ACTIONS, dtype=np.float32)
            # Fancy indexing: one vectorized write instead of a Python loop.
            legal_mask[legal] = 1.0
            obs_list.append(obs.copy())
            action_list.append(action)
            mask_list.append(legal_mask)
            obs, _, done, _ = env.step(action)
            if done:
                break

    print(f"  Collected {len(obs_list)} transitions")
    if not obs_list:
        # No data (e.g. num_games == 0): skip training to avoid a
        # zero-division in the per-epoch loss average below.
        return

    obs_t = torch.tensor(np.array(obs_list), dtype=torch.float32, device=device)
    act_t = torch.tensor(np.array(action_list), dtype=torch.long, device=device)
    mask_t = torch.tensor(np.array(mask_list), dtype=torch.float32, device=device)

    for epoch in range(epochs):
        # Fresh shuffle each epoch so minibatches differ between passes.
        indices = np.arange(len(obs_list))
        np.random.shuffle(indices)
        total_loss = 0.0
        n_batches = 0
        for start in range(0, len(indices), batch_size):
            idx = indices[start:start + batch_size]
            logits, _ = model(obs_t[idx], mask_t[idx])
            # Behavioral cloning: maximize likelihood of the greedy action.
            loss = F.cross_entropy(logits, act_t[idx])
            optimizer.zero_grad()
            loss.backward()
            # Same clip norm (0.5) as the PPO update for consistency.
            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        print(f"  Epoch {epoch+1}/{epochs}: loss={total_loss/n_batches:.4f}")
+
+
+# ---------------------------------------------------------------------------
# Training Loop
# ---------------------------------------------------------------------------
def train(args):
@@ -264,6 +315,11 @@ def train(args):
model = PolicyValueNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+ # Greedy warmup: imitate greedy play before self-play
+ if args.greedy_warmup > 0:
+ greedy_warmup(model, optimizer, args.num_players,
+ num_games=args.greedy_warmup, device=device)
+
# Stats
win_counts = defaultdict(int)
game_lengths = []
@@ -421,6 +477,8 @@ if __name__ == "__main__":
parser.add_argument("--log_every", type=int, default=1000)
parser.add_argument("--save_every", type=int, default=10000)
parser.add_argument("--save_path", type=str, default="blazing_ppo")
+ parser.add_argument("--greedy_warmup", type=int, default=2000,
+ help="Number of greedy games for behavioral cloning warmup (0 to skip)")
args = parser.parse_args()
model = train(args)