summaryrefslogtreecommitdiff
path: root/train_colab.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'train_colab.ipynb')
-rw-r--r--train_colab.ipynb12
1 files changed, 7 insertions, 5 deletions
diff --git a/train_colab.ipynb b/train_colab.ipynb
index ac64e50..a591beb 100644
--- a/train_colab.ipynb
+++ b/train_colab.ipynb
@@ -67,14 +67,14 @@
{
"cell_type": "markdown",
"source": "## 3. Download model & logs",
- "metadata": {},
- "execution_count": null,
- "outputs": []
+ "metadata": {}
},
{
"cell_type": "code",
"metadata": {},
- "source": "from google.colab import files\nimport glob\n\n# Download final model(s)\nfor f in glob.glob(\"*_final.pt\"):\n print(f\"Downloading {f}...\")\n files.download(f)\n\n# Download training log(s)\nfor f in glob.glob(\"*_log.csv\"):\n print(f\"Downloading {f}...\")\n files.download(f)"
+ "source": "from google.colab import files\nimport glob\n\n# Download final model(s)\nfor f in glob.glob(\"*_final.pt\"):\n print(f\"Downloading {f}...\")\n files.download(f)\n\n# Download training log(s)\nfor f in glob.glob(\"*_log.csv\"):\n print(f\"Downloading {f}...\")\n files.download(f)",
+ "execution_count": null,
+ "outputs": []
},
{
"cell_type": "markdown",
@@ -129,7 +129,9 @@
{
"cell_type": "code",
"metadata": {},
- "source": "import sys\nsys.path.insert(0, \".\")\nfrom train import PolicyValueNet, evaluate_vs_greedy_batch\n\ndevice = \"cpu\"\nmodel = PolicyValueNet().to(device)\n\nimport glob\nfinal_models = glob.glob(\"*_final.pt\") + glob.glob(\"models/*_final.pt\")\nif final_models:\n ckpt = torch.load(final_models[0], map_location=device, weights_only=True)\n model.load_state_dict(ckpt[\"model\"])\n model.eval()\n print(f\"Loaded: {final_models[0]}\")\n print(f\"Trained for {ckpt.get('episode', '?')} episodes\")\n print()\n\n for n in [2, 3, 4]:\n wr = evaluate_vs_greedy_batch(model, num_players=n, num_games=2000, device=device)\n print(f\" {n} players: win rate = {wr:.1%} (random baseline: {1/n:.1%})\")\nelse:\n print(\"No model found. Train first!\")"
+ "source": "import sys\nsys.path.insert(0, \".\")\nfrom train import PolicyValueNet, evaluate_vs_greedy_batch\n\ndevice = \"cpu\"\nmodel = PolicyValueNet().to(device)\n\nimport glob\nfinal_models = glob.glob(\"*_final.pt\") + glob.glob(\"models/*_final.pt\")\nif final_models:\n ckpt = torch.load(final_models[0], map_location=device, weights_only=True)\n model.load_state_dict(ckpt[\"model\"])\n model.eval()\n print(f\"Loaded: {final_models[0]}\")\n print(f\"Trained for {ckpt.get('episode', '?')} episodes\")\n print()\n\n for n in [2, 3, 4]:\n wr = evaluate_vs_greedy_batch(model, num_players=n, num_games=2000, device=device)\n print(f\" {n} players: win rate = {wr:.1%} (random baseline: {1/n:.1%})\")\nelse:\n print(\"No model found. Train first!\")",
+ "execution_count": null,
+ "outputs": []
}
]
} \ No newline at end of file