summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorhaoyuren <13851610112@163.com>2026-02-22 01:57:04 -0600
committerhaoyuren <13851610112@163.com>2026-02-22 01:57:04 -0600
commit480913b234ecf6147666bce641cecbaaeadd408a (patch)
tree2181d625cc756cccea89399993dfe436cc70cc72 /train.py
parent60e5072dcc654322e050f54ae4b789550e6aa40a (diff)
Add tqdm progress bar, fix Colab username
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'train.py')
-rw-r--r--train.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/train.py b/train.py
index 96affb2..e955c09 100644
--- a/train.py
+++ b/train.py
@@ -20,6 +20,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from collections import defaultdict
+from tqdm import tqdm
from blazing_env import BlazingEightsEnv, TOTAL_ACTIONS, NUM_CARDS, DRAW_ACTION, PASS_ACTION
@@ -260,7 +261,8 @@ def train(args):
game_lengths = []
batch_transitions = []
- for ep in range(1, args.episodes + 1):
+ pbar = tqdm(range(1, args.episodes + 1), desc="Training", unit="ep")
+ for ep in pbar:
env = BlazingEightsEnv(num_players=args.num_players)
trajectories = collect_game(env, model, device)
@@ -344,10 +346,8 @@ def train(args):
if ep % args.log_every == 0:
avg_len = np.mean(game_lengths[-args.log_every:]) if game_lengths else 0
total_games = sum(win_counts.values())
- win_rates = {p: win_counts[p] / max(total_games, 1) for p in range(args.num_players)}
- print(f"Episode {ep:>7d} | Avg game length: {avg_len:.1f} | "
- f"Win rates: {win_rates} | "
- f"Total games: {total_games}")
+ wr0 = win_counts[0] / max(total_games, 1)
+ pbar.set_postfix(avg_len=f"{avg_len:.1f}", wr0=f"{wr0:.1%}", games=total_games)
# Save checkpoint
if ep % args.save_every == 0:
@@ -358,7 +358,7 @@ def train(args):
"episode": ep,
"num_players": args.num_players,
}, path)
- print(f" Saved checkpoint: {path}")
+ tqdm.write(f" Saved checkpoint: {path}")
# Final save
torch.save({