summaryrefslogtreecommitdiff
path: root/hag/datatypes.py
blob: 0f4254d96d4ea01339e66d6d20575baa12bb8322 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
"""Data types used across HAG modules."""

from dataclasses import dataclass, field
from typing import List, Optional

import torch


@dataclass
class HopfieldResult:
    """Outcome of an iterative Hopfield retrieval.

    Attributes:
        attention_weights: Attention weights, shaped ``(batch, N)`` or ``(N,)``.
        converged_query: Final query state, shaped ``(batch, d)`` or ``(d,)``.
        num_steps: Number of iteration steps reported for this retrieval.
        trajectory: Optional list of the intermediate queries ``q_t``;
            defaults to None when not supplied.
        energy_curve: Optional list of the energies ``E(q_t)``; defaults
            to None when not supplied.

    Note:
        The dataclass-generated ``__eq__`` compares fields with ``==``.
        For multi-element tensors that comparison cannot be reduced to a
        single bool and can raise, so compare tensor fields explicitly
        (e.g. with ``torch.equal``) rather than comparing whole results.
    """

    attention_weights: torch.Tensor
    converged_query: torch.Tensor
    num_steps: int
    # Optional histories: both fall back to None when the caller omits them.
    trajectory: Optional[List[torch.Tensor]] = field(default=None)
    energy_curve: Optional[List[torch.Tensor]] = field(default=None)


@dataclass
class RetrievalResult:
    """Top-k output of a retriever (FAISS or Hopfield).

    Attributes:
        passages: Retrieved passage texts.
        scores: Top-k scores, as a tensor.
        indices: Top-k indices, as a tensor.
        hopfield_result: Optional detailed Hopfield retrieval state;
            defaults to None when not supplied (presumably only populated
            by the Hopfield retriever — confirm against callers).
    """

    passages: List[str]
    scores: torch.Tensor
    indices: torch.Tensor
    hopfield_result: Optional[HopfieldResult] = field(default=None)


@dataclass
class PipelineResult:
    """End-to-end output of the RAG/HAG pipeline for a single question.

    Attributes:
        question: The input question text.
        answer: The answer text produced for ``question``.
        retrieved_passages: Passage texts retrieved for the question.
        retrieval_result: The full retriever output (passages, top-k
            scores and indices, optional Hopfield details).
    """

    question: str
    answer: str
    retrieved_passages: List[str]
    retrieval_result: RetrievalResult