# trm/evaluators/arc.py
#
# Provenance (from the patch header this file was exported with):
#   commit  66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a
#   author  YurenHao0426, Sat, 13 Jun 2026 12:35:36 -0500
#   subject "rrm workspace: TRM/HRM/SRM code, Maze dataset,
#           dynamical-analysis pipeline"
"""ARC evaluator: collects per-batch predictions, votes across augmentations
and ranks, and reports pass@K plus an optional ``submission.json``."""
from typing import Dict, Sequence, Optional
import os
import json

import torch
import numpy as np
from numba import njit
import torch.distributed as dist

from dataset.build_arc_dataset import inverse_aug, grid_hash, arc_grid_to_np
from dataset.common import PuzzleDatasetMetadata


@njit
def _crop(grid: np.ndarray):
    """Find maximum-sized rectangle without any EOS token inside.

    The flat token array is viewed as a 30x30 grid.  The rectangle is
    anchored at the top-left corner; any token outside ``[2, 11]`` is treated
    as a terminator.  The returned crop has 2 subtracted from every token and
    is cast to ``uint8``.
    """
    grid = grid.reshape(30, 30)

    max_area = 0
    max_size = (0, 0)
    nr, nc = grid.shape

    # The usable width can only shrink as more rows are included.
    num_c = nc
    for num_r in range(1, nr + 1):
        # Scan for maximum c
        for c in range(1, num_c + 1):
            x = grid[num_r - 1, c - 1]
            if (x < 2) | (x > 11):
                num_c = c - 1
                break

        area = num_r * num_c
        if area > max_area:
            max_area = area
            max_size = (num_r, num_c)

    return (grid[:max_size[0], :max_size[1]] - 2).astype(np.uint8)


class ARC:
    """Accumulates model predictions and scores them against test puzzles.

    Usage, as far as this file shows: call ``begin_eval()``, then
    ``update_batch()`` for every batch, then ``result()`` once on every rank.
    """

    # Keys pulled out of the batch / prediction dicts in ``update_batch``.
    required_outputs = {"inputs", "puzzle_identifiers", "q_halt_logits", "preds"}

    def __init__(self, data_path: str,
                 eval_metadata: PuzzleDatasetMetadata,
                 submission_K: int = 2,
                 pass_Ks: Sequence[int] = (1, 2, 5, 10, 100, 1000),
                 aggregated_voting: bool = True):
        """Load identifier and test-puzzle tables from ``data_path``.

        Args:
            data_path: Directory containing ``identifiers.json`` and
                ``test_puzzles.json``.
            eval_metadata: Supplies ``blank_identifier_id``, the identifier
                value that marks padding rows.
            submission_K: Number of attempts written per test input.
            pass_Ks: The K values for which pass@K is reported.
            aggregated_voting: When true, predictions are kept across
                ``begin_eval()`` calls instead of being cleared.
        """
        super().__init__()
        self.pass_Ks = pass_Ks
        self.submission_K = submission_K
        self.aggregated_voting = aggregated_voting
        self.blank_identifier_id = eval_metadata.blank_identifier_id

        # Load identifiers and test puzzles
        with open(os.path.join(data_path, "identifiers.json"), "r") as f:
            self.identifier_map = json.load(f)
        with open(os.path.join(data_path, "test_puzzles.json"), "r") as f:
            self.test_puzzles = json.load(f)

        # States
        self._local_hmap = {}   # prediction hash -> predicted grid
        self._local_preds = {}  # puzzle name -> input hash -> [(pred hash, q)]

    def begin_eval(self):
        """Reset accumulated predictions unless aggregated voting is enabled."""
        if not self.aggregated_voting:
            # Clear previous predictions
            self._local_hmap = {}
            self._local_preds = {}

    def update_batch(self, batch: Dict[str, torch.Tensor], preds: Dict[str, torch.Tensor]):
        """Record the predictions of one batch into local state.

        Args:
            batch: Input tensors; the keys in ``required_outputs`` are read
                from here or from ``preds``.
            preds: Model output tensors; must (together with ``batch``)
                provide ``q_halt_logits``.

        Raises:
            AssertionError: If ``q_halt_logits`` is missing, or a cropped
                prediction contains a value outside 0-9.
        """
        # Collect required outputs to CPU
        outputs = {}
        q_values = None

        for collection in (batch, preds):
            for k, v in collection.items():
                if k not in self.required_outputs:
                    continue
                if k == "q_halt_logits":
                    q_values = v.to(torch.float64).sigmoid().cpu()
                else:
                    outputs[k] = v.cpu()

        assert q_values is not None

        # Remove padding from outputs.  q_values must be filtered with the
        # same mask, otherwise the zip below pairs predictions with the wrong
        # confidence whenever a padding row is not at the end of the batch.
        mask = outputs["puzzle_identifiers"] != self.blank_identifier_id
        outputs = {k: v[mask] for k, v in outputs.items()}
        q_values = q_values[mask]

        # Get predictions
        rows = zip(
            outputs["puzzle_identifiers"].numpy(),
            outputs["inputs"].numpy(),
            outputs["preds"].numpy(),
            q_values.numpy(),
        )
        for identifier, input_tokens, pred_tokens, q in rows:
            name = self.identifier_map[identifier]
            orig_name, inverse_fn = inverse_aug(name)

            # Undo the augmentation so that all variants of the same test
            # input / prediction hash to the same key.
            input_hash = grid_hash(inverse_fn(_crop(input_tokens)))

            pred = inverse_fn(_crop(pred_tokens))
            assert np.all((pred >= 0) & (pred <= 9)), f"Puzzle {name}'s prediction out of 0-9 range."  # Sanity check

            # Store into local state
            pred_hash = grid_hash(pred)
            self._local_hmap[pred_hash] = pred

            per_puzzle = self._local_preds.setdefault(orig_name, {})
            per_puzzle.setdefault(input_hash, []).append((pred_hash, float(q)))

    def result(self, save_path: Optional[str], rank: int, world_size: int, group: Optional[torch.distributed.ProcessGroup] = None) -> Optional[Dict[str, float]]:
        """Vote over all collected predictions and compute pass@K.

        Args:
            save_path: Directory for ``submission.json`` (created if needed);
                ``None`` skips writing.
            rank: Rank of the calling process.
            world_size: Number of participating processes.  With a value of 1
                no collective is issued, so no process group is required.
            group: Process group used for the gather when ``world_size > 1``.

        Returns:
            On rank 0, a dict ``{"ARC/pass@K": score}`` for every K in
            ``pass_Ks``; ``None`` on every other rank.
        """
        # Gather predictions to rank 0 for voting
        local_state = (self._local_hmap, self._local_preds)
        if world_size > 1:
            global_hmap_preds = [None for _ in range(world_size)] if rank == 0 else None
            dist.gather_object(local_state, global_hmap_preds, dst=0, group=group)
        else:
            # Single process: gathering would only hand back the local state.
            global_hmap_preds = [local_state]

        # Rank 0 logic
        if rank != 0:
            return None

        # Merge the per-rank hash -> grid maps once; first occurrence wins.
        grid_by_hash = {}
        for hmap, _ in global_hmap_preds:  # type: ignore
            for h, grid in hmap.items():
                grid_by_hash.setdefault(h, grid)

        submission = {}
        correct = [0.0 for _ in self.pass_Ks]

        for name, puzzle in self.test_puzzles.items():
            # Process test examples in this puzzle
            submission[name] = []
            num_test_correct = [0 for _ in self.pass_Ks]
            for pair in puzzle["test"]:
                input_hash = grid_hash(arc_grid_to_np(pair["input"]))
                label_hash = grid_hash(arc_grid_to_np(pair["output"]))

                # hash -> [vote count, summed q]
                p_map = {}
                for _, preds in global_hmap_preds:  # type: ignore
                    for h, q in preds.get(name, {}).get(input_hash, ()):
                        stats = p_map.setdefault(h, [0, 0])
                        stats[0] += 1
                        stats[1] += q

                if not p_map:
                    # NOTE(review): this test pair gets no entry in the
                    # submission, so later entries shift position -- confirm
                    # that the submission consumer tolerates this.
                    print (f"Puzzle {name} has no predictions.")
                    continue

                # Turn the summed q into a mean q.
                for stats in p_map.values():
                    stats[1] /= stats[0]

                # Rank by vote count, ties broken by mean q (list comparison).
                ranked = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)

                # vote for different Ks
                for i, k in enumerate(self.pass_Ks):
                    num_test_correct[i] += any(h == label_hash for h, _ in ranked[:k])

                # Query grids
                pred_grids = [
                    grid_by_hash[h]
                    for h, _ in ranked[:self.submission_K]
                    if h in grid_by_hash
                ]

                # Pad to K
                while len(pred_grids) < self.submission_K:
                    pred_grids.append(pred_grids[0])

                submission[name].append({f"attempt_{i + 1}": grid.tolist() for i, grid in enumerate(pred_grids)})

            # Total correctness
            for i in range(len(self.pass_Ks)):
                correct[i] += num_test_correct[i] / len(puzzle["test"])

        # Save submission
        if save_path is not None:
            os.makedirs(save_path, exist_ok=True)
            with open(os.path.join(save_path, "submission.json"), "w") as f:
                json.dump(submission, f)

        # Final result
        all_results = {f"ARC/pass@{k}": correct[i] / len(self.test_puzzles) for i, k in enumerate(self.pass_Ks)}

        return all_results