diff options
| author | haoyuren <13851610112@163.com> | 2026-02-22 01:57:04 -0600 |
|---|---|---|
| committer | haoyuren <13851610112@163.com> | 2026-02-22 01:57:04 -0600 |
| commit | 480913b234ecf6147666bce641cecbaaeadd408a (patch) | |
| tree | 2181d625cc756cccea89399993dfe436cc70cc72 | |
| parent | 60e5072dcc654322e050f54ae4b789550e6aa40a (diff) | |
Add tqdm progress bar, fix Colab username
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
| -rw-r--r-- | train.py | 12 |
| -rw-r--r-- | train_colab.ipynb | 13 |

2 files changed, 8 insertions, 17 deletions
```diff
@@ -20,6 +20,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributions import Categorical
 from collections import defaultdict
+from tqdm import tqdm
 
 from blazing_env import BlazingEightsEnv, TOTAL_ACTIONS, NUM_CARDS, DRAW_ACTION, PASS_ACTION
 
@@ -260,7 +261,8 @@ def train(args):
     game_lengths = []
     batch_transitions = []
 
-    for ep in range(1, args.episodes + 1):
+    pbar = tqdm(range(1, args.episodes + 1), desc="Training", unit="ep")
+    for ep in pbar:
         env = BlazingEightsEnv(num_players=args.num_players)
         trajectories = collect_game(env, model, device)
 
@@ -344,10 +346,8 @@ def train(args):
         if ep % args.log_every == 0:
             avg_len = np.mean(game_lengths[-args.log_every:]) if game_lengths else 0
             total_games = sum(win_counts.values())
-            win_rates = {p: win_counts[p] / max(total_games, 1) for p in range(args.num_players)}
-            print(f"Episode {ep:>7d} | Avg game length: {avg_len:.1f} | "
-                  f"Win rates: {win_rates} | "
-                  f"Total games: {total_games}")
+            wr0 = win_counts[0] / max(total_games, 1)
+            pbar.set_postfix(avg_len=f"{avg_len:.1f}", wr0=f"{wr0:.1%}", games=total_games)
 
         # Save checkpoint
         if ep % args.save_every == 0:
@@ -358,7 +358,7 @@ def train(args):
                 "episode": ep,
                 "num_players": args.num_players,
             }, path)
-            print(f" Saved checkpoint: {path}")
+            tqdm.write(f" Saved checkpoint: {path}")
 
     # Final save
     torch.save({
diff --git a/train_colab.ipynb b/train_colab.ipynb
index 4856b26..24ae0dc 100644
--- a/train_colab.ipynb
+++ b/train_colab.ipynb
@@ -31,16 +31,7 @@
   {
    "cell_type": "code",
    "metadata": {},
-   "source": [
-    "# ====== CONFIG ======\n",
-    "GITHUB_USERNAME = \"haoyuren\" # <-- your GitHub username\n",
-    "REPO_NAME = \"blazing8\"\n",
-    "# ====================\n",
-    "\n",
-    "!git clone https://github.com/{GITHUB_USERNAME}/{REPO_NAME}.git\n",
-    "%cd {REPO_NAME}\n",
-    "!pip install -q torch numpy"
-   ],
+   "source": "# ====== CONFIG ======\nGITHUB_USERNAME = \"YurenHao0426\"\nREPO_NAME = \"blazing8\"\n# ====================\n\n!git clone https://github.com/{GITHUB_USERNAME}/{REPO_NAME}.git\n%cd {REPO_NAME}\n!pip install -q torch numpy tqdm",
    "execution_count": null,
    "outputs": []
   },
@@ -201,4 +192,4 @@
    "outputs": []
   }
  ]
-}
+}
\ No newline at end of file
```
