summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: haoyuren <13851610112@163.com>, 2026-02-22 01:57:04 -0600
committer: haoyuren <13851610112@163.com>, 2026-02-22 01:57:04 -0600
commit: 480913b234ecf6147666bce641cecbaaeadd408a (patch)
tree: 2181d625cc756cccea89399993dfe436cc70cc72
parent: 60e5072dcc654322e050f54ae4b789550e6aa40a (diff)
Add tqdm progress bar, fix Colab username
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
-rw-r--r-- train.py 12
-rw-r--r-- train_colab.ipynb 13
2 files changed, 8 insertions, 17 deletions
diff --git a/train.py b/train.py
index 96affb2..e955c09 100644
--- a/train.py
+++ b/train.py
@@ -20,6 +20,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from collections import defaultdict
+from tqdm import tqdm
from blazing_env import BlazingEightsEnv, TOTAL_ACTIONS, NUM_CARDS, DRAW_ACTION, PASS_ACTION
@@ -260,7 +261,8 @@ def train(args):
game_lengths = []
batch_transitions = []
- for ep in range(1, args.episodes + 1):
+ pbar = tqdm(range(1, args.episodes + 1), desc="Training", unit="ep")
+ for ep in pbar:
env = BlazingEightsEnv(num_players=args.num_players)
trajectories = collect_game(env, model, device)
@@ -344,10 +346,8 @@ def train(args):
if ep % args.log_every == 0:
avg_len = np.mean(game_lengths[-args.log_every:]) if game_lengths else 0
total_games = sum(win_counts.values())
- win_rates = {p: win_counts[p] / max(total_games, 1) for p in range(args.num_players)}
- print(f"Episode {ep:>7d} | Avg game length: {avg_len:.1f} | "
- f"Win rates: {win_rates} | "
- f"Total games: {total_games}")
+ wr0 = win_counts[0] / max(total_games, 1)
+ pbar.set_postfix(avg_len=f"{avg_len:.1f}", wr0=f"{wr0:.1%}", games=total_games)
# Save checkpoint
if ep % args.save_every == 0:
@@ -358,7 +358,7 @@ def train(args):
"episode": ep,
"num_players": args.num_players,
}, path)
- print(f" Saved checkpoint: {path}")
+ tqdm.write(f" Saved checkpoint: {path}")
# Final save
torch.save({
diff --git a/train_colab.ipynb b/train_colab.ipynb
index 4856b26..24ae0dc 100644
--- a/train_colab.ipynb
+++ b/train_colab.ipynb
@@ -31,16 +31,7 @@
{
"cell_type": "code",
"metadata": {},
- "source": [
- "# ====== CONFIG ======\n",
- "GITHUB_USERNAME = \"haoyuren\" # <-- your GitHub username\n",
- "REPO_NAME = \"blazing8\"\n",
- "# ====================\n",
- "\n",
- "!git clone https://github.com/{GITHUB_USERNAME}/{REPO_NAME}.git\n",
- "%cd {REPO_NAME}\n",
- "!pip install -q torch numpy"
- ],
+ "source": "# ====== CONFIG ======\nGITHUB_USERNAME = \"YurenHao0426\"\nREPO_NAME = \"blazing8\"\n# ====================\n\n!git clone https://github.com/{GITHUB_USERNAME}/{REPO_NAME}.git\n%cd {REPO_NAME}\n!pip install -q torch numpy tqdm",
"execution_count": null,
"outputs": []
},
@@ -201,4 +192,4 @@
"outputs": []
}
]
-}
+}
\ No newline at end of file