diff options
| author | haoyuren <13851610112@163.com> | 2026-02-22 11:59:01 -0600 |
|---|---|---|
| committer | haoyuren <13851610112@163.com> | 2026-02-22 11:59:01 -0600 |
| commit | 7e15218730fe86b88ac0a53cc84bf929416a0687 (patch) | |
| tree | 4eb5ec4a620d639a1880a24e8c826a1e3e491921 /train_colab.ipynb | |
| parent | 8392bdfc10f92e61303e39bb356522ee491ce97c (diff) | |
Fix invalid notebook cell schema (markdown with execution_count)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'train_colab.ipynb')
| -rw-r--r-- | train_colab.ipynb | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/train_colab.ipynb b/train_colab.ipynb index ac64e50..a591beb 100644 --- a/train_colab.ipynb +++ b/train_colab.ipynb @@ -67,14 +67,14 @@ { "cell_type": "markdown", "source": "## 3. Download model & logs", - "metadata": {}, - "execution_count": null, - "outputs": [] + "metadata": {} }, { "cell_type": "code", "metadata": {}, - "source": "from google.colab import files\nimport glob\n\n# Download final model(s)\nfor f in glob.glob(\"*_final.pt\"):\n print(f\"Downloading {f}...\")\n files.download(f)\n\n# Download training log(s)\nfor f in glob.glob(\"*_log.csv\"):\n print(f\"Downloading {f}...\")\n files.download(f)" + "source": "from google.colab import files\nimport glob\n\n# Download final model(s)\nfor f in glob.glob(\"*_final.pt\"):\n print(f\"Downloading {f}...\")\n files.download(f)\n\n# Download training log(s)\nfor f in glob.glob(\"*_log.csv\"):\n print(f\"Downloading {f}...\")\n files.download(f)", + "execution_count": null, + "outputs": [] }, { "cell_type": "markdown", @@ -129,7 +129,9 @@ { "cell_type": "code", "metadata": {}, - "source": "import sys\nsys.path.insert(0, \".\")\nfrom train import PolicyValueNet, evaluate_vs_greedy_batch\n\ndevice = \"cpu\"\nmodel = PolicyValueNet().to(device)\n\nimport glob\nfinal_models = glob.glob(\"*_final.pt\") + glob.glob(\"models/*_final.pt\")\nif final_models:\n ckpt = torch.load(final_models[0], map_location=device, weights_only=True)\n model.load_state_dict(ckpt[\"model\"])\n model.eval()\n print(f\"Loaded: {final_models[0]}\")\n print(f\"Trained for {ckpt.get('episode', '?')} episodes\")\n print()\n\n for n in [2, 3, 4]:\n wr = evaluate_vs_greedy_batch(model, num_players=n, num_games=2000, device=device)\n print(f\" {n} players: win rate = {wr:.1%} (random baseline: {1/n:.1%})\")\nelse:\n print(\"No model found. Train first!\")" + "source": "import sys\nsys.path.insert(0, \".\")\nfrom train import PolicyValueNet, evaluate_vs_greedy_batch\n\ndevice = \"cpu\"\nmodel = PolicyValueNet().to(device)\n\nimport glob\nfinal_models = glob.glob(\"*_final.pt\") + glob.glob(\"models/*_final.pt\")\nif final_models:\n ckpt = torch.load(final_models[0], map_location=device, weights_only=True)\n model.load_state_dict(ckpt[\"model\"])\n model.eval()\n print(f\"Loaded: {final_models[0]}\")\n print(f\"Trained for {ckpt.get('episode', '?')} episodes\")\n print()\n\n for n in [2, 3, 4]:\n wr = evaluate_vs_greedy_batch(model, num_players=n, num_games=2000, device=device)\n print(f\" {n} players: win rate = {wr:.1%} (random baseline: {1/n:.1%})\")\nelse:\n print(\"No model found. Train first!\")", + "execution_count": null, + "outputs": [] } ] }
\ No newline at end of file |
