summaryrefslogtreecommitdiff
path: root/hag/datatypes.py
diff options
context:
space:
mode:
Diffstat (limited to 'hag/datatypes.py')
-rw-r--r--hag/datatypes.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/hag/datatypes.py b/hag/datatypes.py
new file mode 100644
index 0000000..0f4254d
--- /dev/null
+++ b/hag/datatypes.py
@@ -0,0 +1,37 @@
+"""Data types used across HAG modules."""
+
+from dataclasses import dataclass, field
+from typing import List, Optional
+
+import torch
+
+
+@dataclass
+class HopfieldResult:
+ """Result from Hopfield iterative retrieval."""
+
+ attention_weights: torch.Tensor # (batch, N) or (N,)
+ converged_query: torch.Tensor # (batch, d) or (d,)
+ num_steps: int
+ trajectory: Optional[List[torch.Tensor]] = None # list of q_t
+ energy_curve: Optional[List[torch.Tensor]] = None # list of E(q_t)
+
+
+@dataclass
+class RetrievalResult:
+ """Result from a retriever (FAISS or Hopfield)."""
+
+ passages: List[str]
+ scores: torch.Tensor # top-k scores
+ indices: torch.Tensor # top-k indices
+ hopfield_result: Optional[HopfieldResult] = None
+
+
+@dataclass
+class PipelineResult:
+ """Result from the full RAG/HAG pipeline."""
+
+ question: str
+ answer: str
+ retrieved_passages: List[str]
+ retrieval_result: RetrievalResult