"""Data types used across HAG modules.""" from dataclasses import dataclass, field from typing import List, Optional import torch @dataclass class HopfieldResult: """Result from Hopfield iterative retrieval.""" attention_weights: torch.Tensor # (batch, N) or (N,) converged_query: torch.Tensor # (batch, d) or (d,) num_steps: int trajectory: Optional[List[torch.Tensor]] = None # list of q_t energy_curve: Optional[List[torch.Tensor]] = None # list of E(q_t) @dataclass class RetrievalResult: """Result from a retriever (FAISS or Hopfield).""" passages: List[str] scores: torch.Tensor # top-k scores indices: torch.Tensor # top-k indices hopfield_result: Optional[HopfieldResult] = None @dataclass class PipelineResult: """Result from the full RAG/HAG pipeline.""" question: str answer: str retrieved_passages: List[str] retrieval_result: RetrievalResult