diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
| commit | c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch) | |
| tree | 43edac8013fec4e65a0b9cddec5314489b4aafc2 /hag/datatypes.py | |
Core Hopfield retrieval module with energy-based convergence guarantees,
memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end
pipeline. All 45 tests passing on CPU with synthetic data.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'hag/datatypes.py')
| -rw-r--r-- | hag/datatypes.py | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/hag/datatypes.py b/hag/datatypes.py new file mode 100644 index 0000000..0f4254d --- /dev/null +++ b/hag/datatypes.py @@ -0,0 +1,37 @@ +"""Data types used across HAG modules.""" + +from dataclasses import dataclass, field +from typing import List, Optional + +import torch + + +@dataclass +class HopfieldResult: + """Result from Hopfield iterative retrieval.""" + + attention_weights: torch.Tensor # (batch, N) or (N,) + converged_query: torch.Tensor # (batch, d) or (d,) + num_steps: int + trajectory: Optional[List[torch.Tensor]] = None # list of q_t + energy_curve: Optional[List[torch.Tensor]] = None # list of E(q_t) + + +@dataclass +class RetrievalResult: + """Result from a retriever (FAISS or Hopfield).""" + + passages: List[str] + scores: torch.Tensor # top-k scores + indices: torch.Tensor # top-k indices + hopfield_result: Optional[HopfieldResult] = None + + +@dataclass +class PipelineResult: + """Result from the full RAG/HAG pipeline.""" + + question: str + answer: str + retrieved_passages: List[str] + retrieval_result: RetrievalResult |
