"""All hyperparameters and configuration dataclasses for HAG.""" from dataclasses import dataclass, field @dataclass class HopfieldConfig: """Configuration for the Hopfield retrieval module.""" beta: float = 1.0 # Inverse temperature. Higher = sharper retrieval max_iter: int = 5 # Maximum Hopfield iteration steps conv_threshold: float = 1e-4 # Stop if ||q_{t+1} - q_t|| < threshold top_k: int = 5 # Number of passages to retrieve from final attention weights @dataclass class MemoryBankConfig: """Configuration for the memory bank.""" embedding_dim: int = 768 # Must match encoder output dim normalize: bool = True # L2-normalize embeddings in memory bank @dataclass class EncoderConfig: """Configuration for the query/passage encoder.""" model_name: str = "facebook/contriever-msmarco" max_length: int = 512 batch_size: int = 64 @dataclass class GeneratorConfig: """Configuration for the LLM generator.""" model_name: str = "meta-llama/Llama-3.1-8B-Instruct" max_new_tokens: int = 128 temperature: float = 0.0 # Greedy decoding for reproducibility @dataclass class PipelineConfig: """Top-level pipeline configuration.""" hopfield: HopfieldConfig = field(default_factory=HopfieldConfig) memory: MemoryBankConfig = field(default_factory=MemoryBankConfig) encoder: EncoderConfig = field(default_factory=EncoderConfig) generator: GeneratorConfig = field(default_factory=GeneratorConfig) retriever_type: str = "hopfield" # "hopfield" or "faiss"