summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorhaoyuren <13851610112@163.com>2026-02-22 11:59:01 -0600
committerhaoyuren <13851610112@163.com>2026-02-22 11:59:01 -0600
commit7e15218730fe86b88ac0a53cc84bf929416a0687 (patch)
tree4eb5ec4a620d639a1880a24e8c826a1e3e491921
parent8392bdfc10f92e61303e39bb356522ee491ce97c (diff)
Fix invalid notebook cell schema (markdown with execution_count)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
-rw-r--r--train_colab.ipynb12
1 files changed, 7 insertions, 5 deletions
diff --git a/train_colab.ipynb b/train_colab.ipynb
index ac64e50..a591beb 100644
--- a/train_colab.ipynb
+++ b/train_colab.ipynb
@@ -67,14 +67,14 @@
{
"cell_type": "markdown",
"source": "## 3. Download model & logs",
- "metadata": {},
- "execution_count": null,
- "outputs": []
+ "metadata": {}
},
{
"cell_type": "code",
"metadata": {},
- "source": "from google.colab import files\nimport glob\n\n# Download final model(s)\nfor f in glob.glob(\"*_final.pt\"):\n print(f\"Downloading {f}...\")\n files.download(f)\n\n# Download training log(s)\nfor f in glob.glob(\"*_log.csv\"):\n print(f\"Downloading {f}...\")\n files.download(f)"
+ "source": "from google.colab import files\nimport glob\n\n# Download final model(s)\nfor f in glob.glob(\"*_final.pt\"):\n print(f\"Downloading {f}...\")\n files.download(f)\n\n# Download training log(s)\nfor f in glob.glob(\"*_log.csv\"):\n print(f\"Downloading {f}...\")\n files.download(f)",
+ "execution_count": null,
+ "outputs": []
},
{
"cell_type": "markdown",
@@ -129,7 +129,9 @@
{
"cell_type": "code",
"metadata": {},
- "source": "import sys\nsys.path.insert(0, \".\")\nfrom train import PolicyValueNet, evaluate_vs_greedy_batch\n\ndevice = \"cpu\"\nmodel = PolicyValueNet().to(device)\n\nimport glob\nfinal_models = glob.glob(\"*_final.pt\") + glob.glob(\"models/*_final.pt\")\nif final_models:\n ckpt = torch.load(final_models[0], map_location=device, weights_only=True)\n model.load_state_dict(ckpt[\"model\"])\n model.eval()\n print(f\"Loaded: {final_models[0]}\")\n print(f\"Trained for {ckpt.get('episode', '?')} episodes\")\n print()\n\n for n in [2, 3, 4]:\n wr = evaluate_vs_greedy_batch(model, num_players=n, num_games=2000, device=device)\n print(f\" {n} players: win rate = {wr:.1%} (random baseline: {1/n:.1%})\")\nelse:\n print(\"No model found. Train first!\")"
+ "source": "import sys\nsys.path.insert(0, \".\")\nfrom train import PolicyValueNet, evaluate_vs_greedy_batch\n\ndevice = \"cpu\"\nmodel = PolicyValueNet().to(device)\n\nimport glob\nfinal_models = glob.glob(\"*_final.pt\") + glob.glob(\"models/*_final.pt\")\nif final_models:\n ckpt = torch.load(final_models[0], map_location=device, weights_only=True)\n model.load_state_dict(ckpt[\"model\"])\n model.eval()\n print(f\"Loaded: {final_models[0]}\")\n print(f\"Trained for {ckpt.get('episode', '?')} episodes\")\n print()\n\n for n in [2, 3, 4]:\n wr = evaluate_vs_greedy_batch(model, num_players=n, num_games=2000, device=device)\n print(f\" {n} players: win rate = {wr:.1%} (random baseline: {1/n:.1%})\")\nelse:\n print(\"No model found. Train first!\")",
+ "execution_count": null,
+ "outputs": []
}
]
} \ No newline at end of file