summary | refs | log | tree | commit | diff
path: root/hrm/dataset/common.py
diff options
context:
space:
mode:
author YurenHao0426 <blackhao0426@gmail.com> 2026-06-13 12:35:36 -0500
committer YurenHao0426 <blackhao0426@gmail.com> 2026-06-13 12:35:36 -0500
commit 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch)
tree c29cba61124018755a19b02c9d33e3ad5f2e05cc /hrm/dataset/common.py
rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipeline [HEAD, main]
Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'hrm/dataset/common.py')
-rw-r--r-- hrm/dataset/common.py 51
1 files changed, 51 insertions, 0 deletions
diff --git a/hrm/dataset/common.py b/hrm/dataset/common.py
new file mode 100644
index 0000000..7bc51c6
--- /dev/null
+++ b/hrm/dataset/common.py
@@ -0,0 +1,51 @@
+from typing import List, Optional
+
+import pydantic
+import numpy as np
+
+
+# Lookup table: DIHEDRAL_INVERSE[tid] is the transform id that undoes `tid`,
+# using the ids understood by `dihedral_transform`.
+DIHEDRAL_INVERSE = [
+    0,  # 0: identity undoes itself
+    3,  # 1: rot90 k=1 is undone by rot90 k=3
+    2,  # 2: rot90 k=2 (half turn) undoes itself
+    1,  # 3: rot90 k=3 is undone by rot90 k=1
+    4,  # 4: fliplr undoes itself
+    5,  # 5: flipud undoes itself
+    6,  # 6: transpose undoes itself
+    7,  # 7: fliplr(rot90 k=1) undoes itself
+]
+
+
+class PuzzleDatasetMetadata(pydantic.BaseModel):
+ """Pydantic model holding the metadata fields of a puzzle dataset.
+
+ NOTE(review): the grouping comments below are inferred from the field
+ names and types only; confirm the exact meaning of each field against
+ the code that writes and reads this metadata.
+ """
+
+ # Special ids (presumably token / puzzle-identifier ids -- TODO confirm).
+ pad_id: int
+ ignore_label_id: Optional[int]  # annotated Optional, so None is an accepted value
+ blank_identifier_id: int
+
+ # Size-like fields (presumably vocabulary size, sequence length and
+ # number of distinct puzzle identifiers -- TODO confirm).
+ vocab_size: int
+ seq_len: int
+ num_puzzle_identifiers: int
+
+ # Dataset-level statistics; note mean_puzzle_examples is a float.
+ total_groups: int
+ mean_puzzle_examples: float
+
+ # List of set names (what a "set" denotes is not visible here -- verify).
+ sets: List[str]
+
+
+def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
+    """Apply one of the 8 dihedral symmetries, selected by ``tid``, to ``arr``.
+
+    Transform ids:
+        0      identity (``arr`` itself is returned, not a copy)
+        1..3   ``np.rot90`` with ``k`` equal to ``tid``
+        4      ``np.fliplr`` (horizontal flip)
+        5      ``np.flipud`` (vertical flip)
+        6      ``arr.T`` (reflection along the main diagonal)
+        7      ``np.fliplr(np.rot90(arr, k=1))`` (anti-diagonal reflection)
+
+    Any other ``tid`` is treated like the identity: ``arr`` is returned
+    unchanged and no error is raised.
+
+    NOTE(review): assumes ``arr`` is a 2-D grid (``np.rot90``/``np.fliplr``
+    need at least two dimensions) -- confirm against callers.
+    """
+    # Pure rotations: the transform id doubles as the quarter-turn count.
+    for quarter_turns in (1, 2, 3):
+        if tid == quarter_turns:
+            return np.rot90(arr, k=quarter_turns)
+
+    if tid == 4:
+        return np.fliplr(arr)  # horizontal flip
+    if tid == 5:
+        return np.flipud(arr)  # vertical flip
+    if tid == 6:
+        return arr.T  # transpose (reflection along main diagonal)
+    if tid == 7:
+        # One quarter turn followed by a horizontal flip gives the
+        # anti-diagonal reflection.
+        rotated = np.rot90(arr, k=1)
+        return np.fliplr(rotated)
+
+    # tid == 0 (identity) and every unrecognised id end up here.
+    return arr
+
+
+def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
+    """Apply to ``arr`` the dihedral transform that undoes transform ``tid``.
+
+    The inverse id is read from ``DIHEDRAL_INVERSE`` and handed to
+    ``dihedral_transform``. Because ``tid`` is used as a plain list index,
+    an id past the end of the table raises ``IndexError`` and a negative id
+    is resolved from the end of the table.
+    """
+    inverse_tid = DIHEDRAL_INVERSE[tid]
+    return dihedral_transform(arr, inverse_tid)