diff options
Diffstat (limited to 'arc_eval.ipynb')
| -rw-r--r-- | arc_eval.ipynb | 252 |
1 files changed, 252 insertions, 0 deletions
diff --git a/arc_eval.ipynb b/arc_eval.ipynb new file mode 100644 index 0000000..b2786b8 --- /dev/null +++ b/arc_eval.ipynb @@ -0,0 +1,252 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "from glob import glob\n", + "import hashlib\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.colors as mcolors\n", + "\n", + "import torch\n", + "import torch.nn.functional as F\n", + "import numpy as np\n", + "from numba import njit\n", + "\n", + "from dataset.common import inverse_dihedral_transform\n", + "\n", + "\n", + "DATASET_PATH = \"data/arc-aug-1000\" # ARC-1\n", + "# DATASET_PATH = \"data/arc-2-aug-1000\" # ARC-2\n", + "\n", + "CHECKPOINT_PATH = \"checkpoints/Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV1 amphibian-turaco/step_414456\"\n", + "\n", + "\n", + "PAD_PUZZLE_IDENTIFIER = 0\n", + "\n", + "# Visualization\n", + "ARC_COLOR_MAP = mcolors.ListedColormap([\n", + " \"#000000\", # symbol_0: black\n", + " \"#0074D9\", # symbol_1: blue\n", + " \"#FF4136\", # symbol_2: red\n", + " \"#2ECC40\", # symbol_3: green\n", + " \"#FFDC00\", # symbol_4: yellow\n", + " \"#AAAAAA\", # symbol_5: grey\n", + " \"#F012BE\", # symbol_6: fuschia\n", + " \"#FF851B\", # symbol_7: orange\n", + " \"#7FDBFF\", # symbol_8: teal\n", + " \"#870C25\" # symbol_9: brown\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def load_identifiers_and_preds(dataset_path: str, checkpoint_path: str):\n", + " # Load puzzle identifiers\n", + " with open(os.path.join(dataset_path, \"identifiers.json\"), \"r\") as f:\n", + " identifier_map = json.load(f)\n", + " \n", + " # Load preds\n", + " all_preds = {}\n", + " for filename in glob(f\"{checkpoint_path}_all_preds.*\"):\n", + " preds = torch.load(filename)\n", + " for k, v in preds.items():\n", + " all_preds.setdefault(k, [])\n", + " all_preds[k].append(v)\n", + " \n", + " del preds\n", + "\n", + " all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}\n", + " \n", + " # Remove paddings\n", + " mask = all_preds[\"puzzle_identifiers\"] != PAD_PUZZLE_IDENTIFIER\n", + " all_preds = {k: v[mask] for k, v in all_preds.items()}\n", + "\n", + " return identifier_map, all_preds\n", + "\n", + "\n", + "def inverse_aug(name: str, grid: np.ndarray):\n", + " if \"_\" not in name:\n", + " return grid\n", + "\n", + " trans_id, perm = name.split(\"_\")[-2:]\n", + " trans_id = int(trans_id[1:]) # Remove \"t\" letter\n", + " inv_perm = np.argsort(list(perm))\n", + " \n", + " return inv_perm[inverse_dihedral_transform(grid, trans_id)]\n", + "\n", + "\n", + "def grid_hash(grid: np.ndarray):\n", + " return hash((grid.tobytes(), grid.shape))\n", + "\n", + "\n", + "@njit\n", + "def crop(grid: np.ndarray):\n", + " # Find maximum-sized rectangle without any EOS token inside.\n", + " grid = grid.reshape(30, 30)\n", + "\n", + " max_area = 0\n", + " max_size = (0, 0)\n", + " nr, nc = grid.shape\n", + " \n", + " num_c = nc\n", + " for num_r in range(1, nr + 1):\n", + " # Scan for maximum c\n", + " for c in range(1, num_c + 1):\n", + " x = grid[num_r - 1, c - 1]\n", + " if (x < 2) | (x > 11):\n", + " num_c = c - 1\n", + " break\n", + " \n", + " area = num_r * num_c\n", + " if area > max_area:\n", + " max_area = area\n", + " max_size = (num_r, num_c)\n", + "\n", + " return grid[:max_size[0], :max_size[1]] - 2\n", + "\n", + "\n", + "def test(visualize, Ks=[1, 2, 10, 100, 1000]):\n", + " identifier_map, all_preds = load_identifiers_and_preds(DATASET_PATH, CHECKPOINT_PATH)\n", + " \n", + " global_hmap = {}\n", + " \n", + " # Get puzzles and corresponding answers\n", + " puzzle_labels = {}\n", + " for identifier, input, label in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], all_preds[\"labels\"]):\n", + " name = identifier_map[identifier]\n", + " if \"_\" not in name: # Not-augmented\n", + " puzzle_labels.setdefault(name, {})\n", + " \n", + " input = crop(input.numpy())\n", + " label = crop(label.numpy())\n", + "\n", + " input_hash = grid_hash(input)\n", + " label_hash = grid_hash(label)\n", + "\n", + " global_hmap[input_hash] = input\n", + " global_hmap[label_hash] = label\n", + "\n", + " assert input_hash not in puzzle_labels[name]\n", + " puzzle_labels[name][input_hash] = label_hash\n", + " \n", + " print (\"Number of puzzles\", len(puzzle_labels))\n", + " \n", + " # Argmax prediction\n", + " preds = all_preds[\"logits\"].argmax(-1)\n", + "\n", + " # Collate\n", + " pred_answers = {}\n", + " for identifier, input, pred, q in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], preds, all_preds[\"q_halt_logits\"].sigmoid()):\n", + " name = identifier_map[identifier]\n", + " orig_name = name.split(\"_\")[0]\n", + " \n", + " input = input.numpy()\n", + " input_hash = grid_hash(inverse_aug(name, crop(input)))\n", + " assert input_hash in puzzle_labels[orig_name]\n", + " \n", + " pred = inverse_aug(name, crop(pred.numpy()))\n", + " pred_hash = grid_hash(pred)\n", + " global_hmap[pred_hash] = pred\n", + " \n", + " pred_answers.setdefault(orig_name, {})\n", + " pred_answers[orig_name].setdefault(input_hash, [])\n", + " pred_answers[orig_name][input_hash].append((pred_hash, q.item()))\n", + "\n", + " # test-1\n", + " if visualize:\n", + " num_figs = sum(len(tests) for name, tests in puzzle_labels.items())\n", + " fig, axes = plt.subplots(num_figs, 4, figsize=(8, num_figs * 4))\n", + " \n", + " fig_id = 0\n", + " \n", + " correct = [0 for _ in range(len(Ks))]\n", + " for name, tests in puzzle_labels.items():\n", + " num_test_correct = [0 for _ in range(len(Ks))]\n", + " for input_hash, label_hash in tests.items():\n", + " p = pred_answers[name][input_hash]\n", + " p_map = {}\n", + " \n", + " for h, q in p:\n", + " p_map.setdefault(h, [0, 0])\n", + " p_map[h][0] += 1\n", + " p_map[h][1] += q\n", + " \n", + " for h, stats in p_map.items():\n", + " stats[1] /= stats[0]\n", + " \n", + " p_map = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)\n", + "\n", + " # 2-vote\n", + " for i, k in enumerate(Ks):\n", + " ok = False\n", + " for h, stats in p_map[:k]:\n", + " ok |= h == label_hash\n", + " \n", + " num_test_correct[i] += ok\n", + "\n", + " if visualize:\n", + " # Show input and ground truth\n", + " axes[fig_id, 0].imshow(global_hmap[input_hash], cmap=ARC_COLOR_MAP)\n", + " axes[fig_id, 0].set_title(f\"{name}\\nInput\")\n", + " axes[fig_id, 0].axis('off')\n", + " \n", + " axes[fig_id, 1].imshow(global_hmap[label_hash], cmap=ARC_COLOR_MAP)\n", + " axes[fig_id, 1].set_title(f\"{name}\\nAnswer\")\n", + " axes[fig_id, 1].axis('off')\n", + " \n", + " trial_id = 2\n", + " for h, stats in p_map[:2]:\n", + " ans = global_hmap[h]\n", + " \n", + " axes[fig_id, trial_id].imshow(ans, cmap=ARC_COLOR_MAP)\n", + " axes[fig_id, trial_id].set_title(f\"{name}\\nTrial {trial_id}\")\n", + " axes[fig_id, trial_id].axis('off')\n", + " \n", + " trial_id += 1\n", + " \n", + " fig_id += 1\n", + " \n", + " # Total correctness\n", + " for i in range(len(Ks)):\n", + " correct[i] += num_test_correct[i] == len(tests)\n", + "\n", + " for i, k in enumerate(Ks):\n", + " print (f\"{k}-shot: {correct[i] / len(puzzle_labels) * 100:.2f}%\")\n", + "\n", + "\n", + "test(visualize=False)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} |
