"""All hyperparameters and configuration dataclasses for HAG."""
from dataclasses import dataclass, field
@dataclass
class HopfieldConfig:
    """Hyperparameters controlling Hopfield-based retrieval.

    Attributes:
        beta: Inverse temperature; larger values sharpen retrieval.
        max_iter: Upper bound on Hopfield update iterations.
        conv_threshold: Iteration stops once ||q_{t+1} - q_t|| drops below this.
        top_k: How many passages to keep from the final attention weights.
    """

    beta: float = 1.0
    max_iter: int = 5
    conv_threshold: float = 1e-4
    top_k: int = 5
@dataclass
class MemoryBankConfig:
    """Settings for the passage memory bank.

    Attributes:
        embedding_dim: Dimensionality of stored embeddings; must match the
            encoder's output dimension.
        normalize: When True, embeddings are L2-normalized in the bank.
    """

    embedding_dim: int = 768
    normalize: bool = True
@dataclass
class EncoderConfig:
    """Settings for the query/passage encoder.

    Attributes:
        model_name: Hugging Face identifier of the encoder checkpoint.
        max_length: Maximum tokenized sequence length.
        batch_size: Number of texts encoded per forward pass.
    """

    model_name: str = "facebook/contriever-msmarco"
    max_length: int = 512
    batch_size: int = 64
@dataclass
class GeneratorConfig:
    """Settings for the LLM generator.

    Attributes:
        model_name: Hugging Face identifier of the generator checkpoint.
        max_new_tokens: Cap on tokens generated per call.
        temperature: Sampling temperature; 0.0 yields greedy decoding for
            reproducibility.
    """

    model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
    max_new_tokens: int = 128
    temperature: float = 0.0
@dataclass
class PipelineConfig:
    """Aggregate configuration for the whole HAG pipeline.

    Attributes:
        hopfield: Settings for the Hopfield retrieval module.
        memory: Settings for the memory bank.
        encoder: Settings for the query/passage encoder.
        generator: Settings for the LLM generator.
        retriever_type: Retrieval backend to use, either "hopfield" or "faiss".
    """

    hopfield: HopfieldConfig = field(default_factory=HopfieldConfig)
    memory: MemoryBankConfig = field(default_factory=MemoryBankConfig)
    encoder: EncoderConfig = field(default_factory=EncoderConfig)
    generator: GeneratorConfig = field(default_factory=GeneratorConfig)
    retriever_type: str = "hopfield"