# Reconstructed from the patch that created hag/config.py.
#
# Commit:  c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3
# Author:  YurenHao0426
# Date:    Sun, 15 Feb 2026 18:19:50 +0000
# Subject: Initial implementation of HAG (Hopfield-Augmented Generation)
#
# Core Hopfield retrieval module with energy-based convergence guarantees,
# memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end
# pipeline. All 45 tests passing on CPU with synthetic data.
#
# Co-Authored-By: Claude Opus 4.6
"""All hyperparameters and configuration dataclasses for HAG."""

from dataclasses import dataclass, field


@dataclass
class HopfieldConfig:
    """Configuration for the Hopfield retrieval module.

    Attributes:
        beta: Inverse temperature. Higher = sharper retrieval.
        max_iter: Maximum Hopfield iteration steps.
        conv_threshold: Stop if ||q_{t+1} - q_t|| < threshold.
        top_k: Number of passages to retrieve from final attention weights.
    """

    beta: float = 1.0
    max_iter: int = 5
    conv_threshold: float = 1e-4
    top_k: int = 5


@dataclass
class MemoryBankConfig:
    """Configuration for the memory bank.

    Attributes:
        embedding_dim: Must match encoder output dim.
        normalize: L2-normalize embeddings in memory bank.
    """

    embedding_dim: int = 768
    normalize: bool = True


@dataclass
class EncoderConfig:
    """Configuration for the query/passage encoder."""

    model_name: str = "facebook/contriever-msmarco"
    # NOTE(review): presumably the maximum input length in tokens and the
    # encoding batch size -- confirm against the encoder implementation.
    max_length: int = 512
    batch_size: int = 64


@dataclass
class GeneratorConfig:
    """Configuration for the LLM generator.

    Attributes:
        model_name: Identifier of the generator model.
        max_new_tokens: NOTE(review): presumably the cap on generated
            tokens per answer -- confirm against the generator code.
        temperature: 0.0 by default: greedy decoding for reproducibility.
    """

    model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
    max_new_tokens: int = 128
    temperature: float = 0.0


@dataclass
class PipelineConfig:
    """Top-level pipeline configuration.

    Each nested section is built through ``default_factory``, so every
    ``PipelineConfig`` instance gets its own sub-config objects instead of
    sharing one mutable default.

    Attributes:
        hopfield: Settings for the Hopfield retrieval module.
        memory: Settings for the memory bank.
        encoder: Settings for the query/passage encoder.
        generator: Settings for the LLM generator.
        retriever_type: "hopfield" or "faiss".
    """

    hopfield: HopfieldConfig = field(default_factory=HopfieldConfig)
    memory: MemoryBankConfig = field(default_factory=MemoryBankConfig)
    encoder: EncoderConfig = field(default_factory=EncoderConfig)
    generator: GeneratorConfig = field(default_factory=GeneratorConfig)
    retriever_type: str = "hopfield"