summaryrefslogtreecommitdiff
path: root/train_colab.ipynb
diff options
context:
space:
mode:
authorhaoyuren <13851610112@163.com>2026-02-22 11:56:48 -0600
committerhaoyuren <13851610112@163.com>2026-02-22 11:56:48 -0600
commit8392bdfc10f92e61303e39bb356522ee491ce97c (patch)
tree91359c42773bf9d5fb8f1d76d393743ff4a55387 /train_colab.ipynb
parent6f7034fabbfbe27197765f335bdcc64ec8c8c85f (diff)
Batched game collection for ~7x training speedup
- collect_games_batch(): run N games in parallel with single batched forward pass per step - evaluate_vs_greedy_batch(): batched evaluation replacing sequential eval - Add --collect_batch CLI arg for configurable parallel game count - Use torch.inference_mode() for faster collection - Update Colab notebook: GPU info, --collect_batch, log download cell Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'train_colab.ipynb')
-rw-r--r--train_colab.ipynb54
1 files changed, 19 insertions, 35 deletions
diff --git a/train_colab.ipynb b/train_colab.ipynb
index afae94d..ac64e50 100644
--- a/train_colab.ipynb
+++ b/train_colab.ipynb
@@ -15,7 +15,7 @@
{
"cell_type": "markdown",
"metadata": {},
- "source": "# Blazing Eights - Colab GPU Training\n\nClone repo → Train PPO agent (CPU collect, GPU update) → Push trained model back to GitHub\n\n**Game**: UNO variant with custom special cards (8=Wild, K=All draw, J=Skip, Swap=Swap hands)."
+ "source": "# Blazing Eights - Colab GPU Training\n\nClone repo → Train PPO agent (batched collection on CPU, PPO on GPU) → Download model & logs\n\n**Game**: UNO variant with custom special cards (8=Wild, K=All draw, J=Skip, Swap=Swap hands)."
},
{
"cell_type": "markdown",
@@ -34,61 +34,49 @@
{
"cell_type": "code",
"metadata": {},
- "source": [
- "import torch\n",
- "print(f\"PyTorch: {torch.__version__}\")\n",
- "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
- "if torch.cuda.is_available():\n",
- " print(f\"GPU: {torch.cuda.get_device_name(0)}\")"
- ],
+ "source": "import torch\nprint(f\"PyTorch: {torch.__version__}\")\nprint(f\"CUDA available: {torch.cuda.is_available()}\")\nif torch.cuda.is_available():\n    print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n    print(f\"Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB\")\nelse:\n    print(\"WARNING: No GPU detected. Go to Runtime → Change runtime type → GPU\")",
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {},
- "source": [
- "## 2. Train"
- ]
+ "source": "## 2. Train\n\nBatched collection: runs many games in parallel with a single forward pass per step.\n- `--collect_batch`: number of parallel games (higher = faster, more VRAM). Default = 64.\n- Game simulation on CPU, PPO gradient updates on GPU (auto-detected)."
},
{
"cell_type": "code",
"metadata": {},
- "source": "# 2-player training with greedy warmup + CSV logging\n# Game simulation on CPU, PPO updates on GPU automatically\n!python train.py --num_players 2 --episodes 200000 --save_path blazing_ppo_2p\n\n# Show training log\nimport pandas as pd\ndf = pd.read_csv(\"blazing_ppo_2p_log.csv\")\nprint(df.to_string(index=False))",
+ "source": "# 2-player training (GPU PPO + batched collection)\n!python train.py \\\n --num_players 2 \\\n --episodes 200000 \\\n --collect_batch 128 \\\n --save_path blazing_ppo_2p",
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {},
- "source": "# (Optional) 3-player training\n# !python train.py --num_players 3 --episodes 300000 --save_path blazing_ppo_3p\n\n# (Optional) Skip greedy warmup\n# !python train.py --num_players 2 --episodes 200000 --greedy_warmup 0 --save_path blazing_ppo_2p_no_warmup",
+ "source": "# (Optional) 3-player training\n# !python train.py --num_players 3 --episodes 300000 --collect_batch 128 --save_path blazing_ppo_3p\n\n# (Optional) Larger batch for faster throughput\n# !python train.py --num_players 2 --episodes 200000 --collect_batch 256 --save_path blazing_ppo_2p\n\n# (Optional) Skip greedy warmup\n# !python train.py --num_players 2 --episodes 200000 --greedy_warmup 0 --save_path blazing_ppo_2p_no_warmup",
"execution_count": null,
"outputs": []
},
{
- "cell_type": "markdown",
+ "cell_type": "code",
+ "source": "# Show training log\nimport pandas as pd\ndf = pd.read_csv(\"blazing_ppo_2p_log.csv\")\nprint(df.to_string(index=False))",
"metadata": {},
- "source": [
- "## 3. Download model locally (Option A)\n",
- "Download .pt files directly from Colab to your machine."
- ]
+ "execution_count": null,
+ "outputs": []
},
{
- "cell_type": "code",
+ "cell_type": "markdown",
+ "source": "## 3. Download model & logs",
- "metadata": {},
- "source": [
- "from google.colab import files\n",
- "import glob\n",
- "\n",
- "# Download the final model\n",
- "for f in glob.glob(\"*_final.pt\"):\n",
- " print(f\"Downloading {f}...\")\n",
- " files.download(f)"
- ],
- "execution_count": null,
- "outputs": []
+ "metadata": {}
},
{
+ "cell_type": "code",
+ "metadata": {},
+ "source": "from google.colab import files\nimport glob\n\n# Download final model(s)\nfor f in glob.glob(\"*_final.pt\"):\n    print(f\"Downloading {f}...\")\n    files.download(f)\n\n# Download training log(s)\nfor f in glob.glob(\"*_log.csv\"):\n    print(f\"Downloading {f}...\")\n    files.download(f)",
+ "execution_count": null,
+ "outputs": []
+ },
+ {
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -135,17 +123,13 @@
},
{
"cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 5. Quick evaluation"
- ]
+ "source": "## 5. Quick evaluation",
+ "metadata": {}
},
{
"cell_type": "code",
"metadata": {},
- "source": "import sys\nsys.path.insert(0, \".\")\nfrom train import PolicyValueNet, evaluate_vs_random\n\ndevice = \"cpu\" # eval on CPU (single-sample inference)\nmodel = PolicyValueNet().to(device)\n\nimport glob\nfinal_models = glob.glob(\"*_final.pt\") + glob.glob(\"models/*_final.pt\")\nif final_models:\n ckpt = torch.load(final_models[0], map_location=device, weights_only=True)\n model.load_state_dict(ckpt[\"model\"])\n model.eval()\n print(f\"Loaded: {final_models[0]}\")\n print(f\"Trained for {ckpt.get('episode', '?')} episodes\")\n print()\n\n for n in [2, 3, 4]:\n wr = evaluate_vs_random(model, num_players=n, num_games=2000, device=device)\n print(f\" {n} players: win rate = {wr:.1%} (random baseline: {1/n:.1%})\")\nelse:\n print(\"No model found. Train first!\")",
- "execution_count": null,
- "outputs": []
+ "source": "import sys\nsys.path.insert(0, \".\")\nfrom train import PolicyValueNet, evaluate_vs_greedy_batch\n\ndevice = \"cpu\"\nmodel = PolicyValueNet().to(device)\n\nimport glob\nfinal_models = glob.glob(\"*_final.pt\") + glob.glob(\"models/*_final.pt\")\nif final_models:\n    ckpt = torch.load(final_models[0], map_location=device, weights_only=True)\n    model.load_state_dict(ckpt[\"model\"])\n    model.eval()\n    print(f\"Loaded: {final_models[0]}\")\n    print(f\"Trained for {ckpt.get('episode', '?')} episodes\")\n    print()\n\n    for n in [2, 3, 4]:\n        wr = evaluate_vs_greedy_batch(model, num_players=n, num_games=2000, device=device)\n        print(f\"  {n} players: win rate = {wr:.1%} (uniform-chance baseline: {1/n:.1%})\")\nelse:\n    print(\"No model found. Train first!\")",
+ "execution_count": null,
+ "outputs": []
}
]
} \ No newline at end of file