From 480913b234ecf6147666bce641cecbaaeadd408a Mon Sep 17 00:00:00 2001 From: haoyuren <13851610112@163.com> Date: Sun, 22 Feb 2026 01:57:04 -0600 Subject: Add tqdm progress bar, fix Colab username Co-Authored-By: Claude Opus 4.6 --- train.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'train.py') diff --git a/train.py b/train.py index 96affb2..e955c09 100644 --- a/train.py +++ b/train.py @@ -20,6 +20,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.distributions import Categorical from collections import defaultdict +from tqdm import tqdm from blazing_env import BlazingEightsEnv, TOTAL_ACTIONS, NUM_CARDS, DRAW_ACTION, PASS_ACTION @@ -260,7 +261,8 @@ def train(args): game_lengths = [] batch_transitions = [] - for ep in range(1, args.episodes + 1): + pbar = tqdm(range(1, args.episodes + 1), desc="Training", unit="ep") + for ep in pbar: env = BlazingEightsEnv(num_players=args.num_players) trajectories = collect_game(env, model, device) @@ -344,10 +346,8 @@ def train(args): if ep % args.log_every == 0: avg_len = np.mean(game_lengths[-args.log_every:]) if game_lengths else 0 total_games = sum(win_counts.values()) - win_rates = {p: win_counts[p] / max(total_games, 1) for p in range(args.num_players)} - print(f"Episode {ep:>7d} | Avg game length: {avg_len:.1f} | " - f"Win rates: {win_rates} | " - f"Total games: {total_games}") + wr0 = win_counts[0] / max(total_games, 1) + pbar.set_postfix(avg_len=f"{avg_len:.1f}", wr0=f"{wr0:.1%}", games=total_games) # Save checkpoint if ep % args.save_every == 0: @@ -358,7 +358,7 @@ def train(args): "episode": ep, "num_players": args.num_players, }, path) - print(f" Saved checkpoint: {path}") + tqdm.write(f" Saved checkpoint: {path}") # Final save torch.save({ -- cgit v1.2.3