summaryrefslogtreecommitdiff
path: root/hag/datatypes.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-02-15 18:19:50 +0000
committerYurenHao0426 <blackhao0426@gmail.com>2026-02-15 18:19:50 +0000
commitc90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch)
tree43edac8013fec4e65a0b9cddec5314489b4aafc2 /hag/datatypes.py
Initial implementation of HAG (Hopfield-Augmented Generation)HEADmaster
Core Hopfield retrieval module with energy-based convergence guarantees, memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end pipeline. All 45 tests passing on CPU with synthetic data. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'hag/datatypes.py')
-rw-r--r--hag/datatypes.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/hag/datatypes.py b/hag/datatypes.py
new file mode 100644
index 0000000..0f4254d
--- /dev/null
+++ b/hag/datatypes.py
@@ -0,0 +1,37 @@
+"""Data types used across HAG modules."""
+
+from dataclasses import dataclass, field
+from typing import List, Optional
+
+import torch
+
+
+@dataclass
+class HopfieldResult:
+ """Result from Hopfield iterative retrieval."""
+
+ attention_weights: torch.Tensor # (batch, N) or (N,)
+ converged_query: torch.Tensor # (batch, d) or (d,)
+ num_steps: int
+ trajectory: Optional[List[torch.Tensor]] = None # list of q_t
+ energy_curve: Optional[List[torch.Tensor]] = None # list of E(q_t)
+
+
+@dataclass
+class RetrievalResult:
+ """Result from a retriever (FAISS or Hopfield)."""
+
+ passages: List[str]
+ scores: torch.Tensor # top-k scores
+ indices: torch.Tensor # top-k indices
+ hopfield_result: Optional[HopfieldResult] = None
+
+
+@dataclass
+class PipelineResult:
+ """Result from the full RAG/HAG pipeline."""
+
+ question: str
+ answer: str
+ retrieved_passages: List[str]
+ retrieval_result: RetrievalResult