blob: 0f4254d96d4ea01339e66d6d20575baa12bb8322 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
|
"""Data types used across HAG modules."""
from dataclasses import dataclass, field
from typing import List, Optional
import torch
@dataclass
class HopfieldResult:
    """Result from Hopfield iterative retrieval.

    Attributes:
        attention_weights: Attention weights, shaped ``(batch, N)`` or
            ``(N,)``.
        converged_query: Final query state, shaped ``(batch, d)`` or
            ``(d,)``.
        num_steps: Number of iteration steps (presumably the count actually
            executed before stopping — confirm against the retriever).
        trajectory: Optional list of query states ``q_t``; defaults to
            ``None``.
        energy_curve: Optional list of energies ``E(q_t)``; defaults to
            ``None``.
    """

    attention_weights: torch.Tensor  # (batch, N) or (N,)
    converged_query: torch.Tensor  # (batch, d) or (d,)
    num_steps: int
    # Optional diagnostics; None by default rather than a mutable default.
    trajectory: Optional[List[torch.Tensor]] = None  # list of q_t
    energy_curve: Optional[List[torch.Tensor]] = None  # list of E(q_t)
@dataclass
class RetrievalResult:
    """Result from a retriever (FAISS or Hopfield).

    Attributes:
        passages: Retrieved passage texts.
        scores: Top-k scores.
        indices: Top-k indices.
        hopfield_result: Hopfield retrieval details; defaults to ``None``
            (presumably left unset by the FAISS retriever — confirm).

    NOTE(review): ``passages``, ``scores`` and ``indices`` presumably line
    up element-wise in top-k order — verify against the retrievers.
    """

    passages: List[str]
    scores: torch.Tensor  # top-k scores
    indices: torch.Tensor  # top-k indices
    hopfield_result: Optional[HopfieldResult] = None
@dataclass
class PipelineResult:
    """Result from the full RAG/HAG pipeline.

    Attributes:
        question: The input question.
        answer: The answer text produced by the pipeline.
        retrieved_passages: Passages retrieved for the question.
        retrieval_result: The retriever output for this question.

    NOTE(review): ``retrieved_passages`` presumably mirrors
    ``retrieval_result.passages`` — confirm against the pipeline.
    """

    question: str
    answer: str
    retrieved_passages: List[str]
    retrieval_result: RetrievalResult
|