summaryrefslogtreecommitdiff
path: root/arc_eval.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'arc_eval.ipynb')
-rw-r--r--arc_eval.ipynb252
1 files changed, 252 insertions, 0 deletions
diff --git a/arc_eval.ipynb b/arc_eval.ipynb
new file mode 100644
index 0000000..b2786b8
--- /dev/null
+++ b/arc_eval.ipynb
@@ -0,0 +1,252 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import json\n",
+ "from glob import glob\n",
+ "import hashlib\n",
+ "import matplotlib.pyplot as plt\n",
+ "import matplotlib.colors as mcolors\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "import numpy as np\n",
+ "from numba import njit\n",
+ "\n",
+ "from dataset.common import inverse_dihedral_transform\n",
+ "\n",
+ "\n",
+ "DATASET_PATH = \"data/arc-aug-1000\" # ARC-1\n",
+ "# DATASET_PATH = \"data/arc-2-aug-1000\" # ARC-2\n",
+ "\n",
+ "CHECKPOINT_PATH = \"checkpoints/Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV1 amphibian-turaco/step_414456\"\n",
+ "\n",
+ "\n",
+ "PAD_PUZZLE_IDENTIFIER = 0\n",
+ "\n",
+ "# Visualization\n",
+ "ARC_COLOR_MAP = mcolors.ListedColormap([\n",
+ " \"#000000\", # symbol_0: black\n",
+ " \"#0074D9\", # symbol_1: blue\n",
+ " \"#FF4136\", # symbol_2: red\n",
+ " \"#2ECC40\", # symbol_3: green\n",
+ " \"#FFDC00\", # symbol_4: yellow\n",
+ " \"#AAAAAA\", # symbol_5: grey\n",
+ " \"#F012BE\", # symbol_6: fuschia\n",
+ " \"#FF851B\", # symbol_7: orange\n",
+ " \"#7FDBFF\", # symbol_8: teal\n",
+ " \"#870C25\" # symbol_9: brown\n",
+ "])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def load_identifiers_and_preds(dataset_path: str, checkpoint_path: str):\n",
+ " # Load puzzle identifiers\n",
+ " with open(os.path.join(dataset_path, \"identifiers.json\"), \"r\") as f:\n",
+ " identifier_map = json.load(f)\n",
+ " \n",
+ " # Load preds\n",
+ " all_preds = {}\n",
+ " for filename in glob(f\"{checkpoint_path}_all_preds.*\"):\n",
+ " preds = torch.load(filename)\n",
+ " for k, v in preds.items():\n",
+ " all_preds.setdefault(k, [])\n",
+ " all_preds[k].append(v)\n",
+ " \n",
+ " del preds\n",
+ "\n",
+ " all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}\n",
+ " \n",
+ " # Remove paddings\n",
+ " mask = all_preds[\"puzzle_identifiers\"] != PAD_PUZZLE_IDENTIFIER\n",
+ " all_preds = {k: v[mask] for k, v in all_preds.items()}\n",
+ "\n",
+ " return identifier_map, all_preds\n",
+ "\n",
+ "\n",
+ "def inverse_aug(name: str, grid: np.ndarray):\n",
+ " if \"_\" not in name:\n",
+ " return grid\n",
+ "\n",
+ " trans_id, perm = name.split(\"_\")[-2:]\n",
+ " trans_id = int(trans_id[1:]) # Remove \"t\" letter\n",
+ " inv_perm = np.argsort(list(perm))\n",
+ " \n",
+ " return inv_perm[inverse_dihedral_transform(grid, trans_id)]\n",
+ "\n",
+ "\n",
+ "def grid_hash(grid: np.ndarray):\n",
+ " return hash((grid.tobytes(), grid.shape))\n",
+ "\n",
+ "\n",
+ "@njit\n",
+ "def crop(grid: np.ndarray):\n",
+ " # Find maximum-sized rectangle without any EOS token inside.\n",
+ " grid = grid.reshape(30, 30)\n",
+ "\n",
+ " max_area = 0\n",
+ " max_size = (0, 0)\n",
+ " nr, nc = grid.shape\n",
+ " \n",
+ " num_c = nc\n",
+ " for num_r in range(1, nr + 1):\n",
+ " # Scan for maximum c\n",
+ " for c in range(1, num_c + 1):\n",
+ " x = grid[num_r - 1, c - 1]\n",
+ " if (x < 2) | (x > 11):\n",
+ " num_c = c - 1\n",
+ " break\n",
+ " \n",
+ " area = num_r * num_c\n",
+ " if area > max_area:\n",
+ " max_area = area\n",
+ " max_size = (num_r, num_c)\n",
+ "\n",
+ " return grid[:max_size[0], :max_size[1]] - 2\n",
+ "\n",
+ "\n",
+ "def test(visualize, Ks=[1, 2, 10, 100, 1000]):\n",
+ " identifier_map, all_preds = load_identifiers_and_preds(DATASET_PATH, CHECKPOINT_PATH)\n",
+ " \n",
+ " global_hmap = {}\n",
+ " \n",
+ " # Get puzzles and corresponding answers\n",
+ " puzzle_labels = {}\n",
+ " for identifier, input, label in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], all_preds[\"labels\"]):\n",
+ " name = identifier_map[identifier]\n",
+ " if \"_\" not in name: # Not-augmented\n",
+ " puzzle_labels.setdefault(name, {})\n",
+ " \n",
+ " input = crop(input.numpy())\n",
+ " label = crop(label.numpy())\n",
+ "\n",
+ " input_hash = grid_hash(input)\n",
+ " label_hash = grid_hash(label)\n",
+ "\n",
+ " global_hmap[input_hash] = input\n",
+ " global_hmap[label_hash] = label\n",
+ "\n",
+ " assert input_hash not in puzzle_labels[name]\n",
+ " puzzle_labels[name][input_hash] = label_hash\n",
+ " \n",
+ " print (\"Number of puzzles\", len(puzzle_labels))\n",
+ " \n",
+ " # Argmax prediction\n",
+ " preds = all_preds[\"logits\"].argmax(-1)\n",
+ "\n",
+ " # Collate\n",
+ " pred_answers = {}\n",
+ " for identifier, input, pred, q in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], preds, all_preds[\"q_halt_logits\"].sigmoid()):\n",
+ " name = identifier_map[identifier]\n",
+ " orig_name = name.split(\"_\")[0]\n",
+ " \n",
+ " input = input.numpy()\n",
+ " input_hash = grid_hash(inverse_aug(name, crop(input)))\n",
+ " assert input_hash in puzzle_labels[orig_name]\n",
+ " \n",
+ " pred = inverse_aug(name, crop(pred.numpy()))\n",
+ " pred_hash = grid_hash(pred)\n",
+ " global_hmap[pred_hash] = pred\n",
+ " \n",
+ " pred_answers.setdefault(orig_name, {})\n",
+ " pred_answers[orig_name].setdefault(input_hash, [])\n",
+ " pred_answers[orig_name][input_hash].append((pred_hash, q.item()))\n",
+ "\n",
+ " # test-1\n",
+ " if visualize:\n",
+ " num_figs = sum(len(tests) for name, tests in puzzle_labels.items())\n",
+ " fig, axes = plt.subplots(num_figs, 4, figsize=(8, num_figs * 4))\n",
+ " \n",
+ " fig_id = 0\n",
+ " \n",
+ " correct = [0 for _ in range(len(Ks))]\n",
+ " for name, tests in puzzle_labels.items():\n",
+ " num_test_correct = [0 for _ in range(len(Ks))]\n",
+ " for input_hash, label_hash in tests.items():\n",
+ " p = pred_answers[name][input_hash]\n",
+ " p_map = {}\n",
+ " \n",
+ " for h, q in p:\n",
+ " p_map.setdefault(h, [0, 0])\n",
+ " p_map[h][0] += 1\n",
+ " p_map[h][1] += q\n",
+ " \n",
+ " for h, stats in p_map.items():\n",
+ " stats[1] /= stats[0]\n",
+ " \n",
+ " p_map = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)\n",
+ "\n",
+ " # 2-vote\n",
+ " for i, k in enumerate(Ks):\n",
+ " ok = False\n",
+ " for h, stats in p_map[:k]:\n",
+ " ok |= h == label_hash\n",
+ " \n",
+ " num_test_correct[i] += ok\n",
+ "\n",
+ " if visualize:\n",
+ " # Show input and ground truth\n",
+ " axes[fig_id, 0].imshow(global_hmap[input_hash], cmap=ARC_COLOR_MAP)\n",
+ " axes[fig_id, 0].set_title(f\"{name}\\nInput\")\n",
+ " axes[fig_id, 0].axis('off')\n",
+ " \n",
+ " axes[fig_id, 1].imshow(global_hmap[label_hash], cmap=ARC_COLOR_MAP)\n",
+ " axes[fig_id, 1].set_title(f\"{name}\\nAnswer\")\n",
+ " axes[fig_id, 1].axis('off')\n",
+ " \n",
+ " trial_id = 2\n",
+ " for h, stats in p_map[:2]:\n",
+ " ans = global_hmap[h]\n",
+ " \n",
+ " axes[fig_id, trial_id].imshow(ans, cmap=ARC_COLOR_MAP)\n",
+ " axes[fig_id, trial_id].set_title(f\"{name}\\nTrial {trial_id}\")\n",
+ " axes[fig_id, trial_id].axis('off')\n",
+ " \n",
+ " trial_id += 1\n",
+ " \n",
+ " fig_id += 1\n",
+ " \n",
+ " # Total correctness\n",
+ " for i in range(len(Ks)):\n",
+ " correct[i] += num_test_correct[i] == len(tests)\n",
+ "\n",
+ " for i, k in enumerate(Ks):\n",
+ " print (f\"{k}-shot: {correct[i] / len(puzzle_labels) * 100:.2f}%\")\n",
+ "\n",
+ "\n",
+ "test(visualize=False)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}