diff options
Diffstat (limited to 'train.py')
| -rw-r--r-- | train.py | 12 |
1 file changed, 6 insertions, 6 deletions
@@ -20,6 +20,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.distributions import Categorical from collections import defaultdict +from tqdm import tqdm from blazing_env import BlazingEightsEnv, TOTAL_ACTIONS, NUM_CARDS, DRAW_ACTION, PASS_ACTION @@ -260,7 +261,8 @@ def train(args): game_lengths = [] batch_transitions = [] - for ep in range(1, args.episodes + 1): + pbar = tqdm(range(1, args.episodes + 1), desc="Training", unit="ep") + for ep in pbar: env = BlazingEightsEnv(num_players=args.num_players) trajectories = collect_game(env, model, device) @@ -344,10 +346,8 @@ def train(args): if ep % args.log_every == 0: avg_len = np.mean(game_lengths[-args.log_every:]) if game_lengths else 0 total_games = sum(win_counts.values()) - win_rates = {p: win_counts[p] / max(total_games, 1) for p in range(args.num_players)} - print(f"Episode {ep:>7d} | Avg game length: {avg_len:.1f} | " - f"Win rates: {win_rates} | " - f"Total games: {total_games}") + wr0 = win_counts[0] / max(total_games, 1) + pbar.set_postfix(avg_len=f"{avg_len:.1f}", wr0=f"{wr0:.1%}", games=total_games) # Save checkpoint if ep % args.save_every == 0: @@ -358,7 +358,7 @@ def train(args): "episode": ep, "num_players": args.num_players, }, path) - print(f" Saved checkpoint: {path}") + tqdm.write(f" Saved checkpoint: {path}") # Final save torch.save({ |
