diff options
Diffstat (limited to 'train.py')
| -rw-r--r-- | train.py | 10 |
1 files changed, 9 insertions, 1 deletions
@@ -100,6 +100,14 @@ class Transition: self.legal_mask = legal_mask +def greedy_random_action(legal: list[int]) -> int: + """Pick a random playable card; only draw/pass if no card to play.""" + play_actions = [a for a in legal if a < NUM_CARDS or (56 <= a <= 59)] + if play_actions: + return int(np.random.choice(play_actions)) + return int(np.random.choice(legal)) + + def collect_game(env: BlazingEightsEnv, model: PolicyValueNet, device="cpu"): """ Play one full game, return per-player trajectories. @@ -389,7 +397,7 @@ def evaluate_vs_random(model: PolicyValueNet, num_players=2, num_games=1000, dev if player == 0: action, _, _ = model.get_action(obs, legal, device) else: - action = np.random.choice(legal) + action = greedy_random_action(legal) obs, rewards, done, info = env.step(action) if done: if env.winner == 0: |
