diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
| commit | c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch) | |
| tree | 43edac8013fec4e65a0b9cddec5314489b4aafc2 /hag/config.py | |
Core Hopfield retrieval module with energy-based convergence guarantees,
memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end
pipeline. All 45 tests passing on CPU with synthetic data.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'hag/config.py')
| -rw-r--r-- | hag/config.py | 50 |
1 files changed, 50 insertions, 0 deletions
diff --git a/hag/config.py b/hag/config.py new file mode 100644 index 0000000..793e3a6 --- /dev/null +++ b/hag/config.py @@ -0,0 +1,50 @@ +"""All hyperparameters and configuration dataclasses for HAG.""" + +from dataclasses import dataclass, field + + +@dataclass +class HopfieldConfig: + """Configuration for the Hopfield retrieval module.""" + + beta: float = 1.0 # Inverse temperature. Higher = sharper retrieval + max_iter: int = 5 # Maximum Hopfield iteration steps + conv_threshold: float = 1e-4 # Stop if ||q_{t+1} - q_t|| < threshold + top_k: int = 5 # Number of passages to retrieve from final attention weights + + +@dataclass +class MemoryBankConfig: + """Configuration for the memory bank.""" + + embedding_dim: int = 768 # Must match encoder output dim + normalize: bool = True # L2-normalize embeddings in memory bank + + +@dataclass +class EncoderConfig: + """Configuration for the query/passage encoder.""" + + model_name: str = "facebook/contriever-msmarco" + max_length: int = 512 + batch_size: int = 64 + + +@dataclass +class GeneratorConfig: + """Configuration for the LLM generator.""" + + model_name: str = "meta-llama/Llama-3.1-8B-Instruct" + max_new_tokens: int = 128 + temperature: float = 0.0 # Greedy decoding for reproducibility + + +@dataclass +class PipelineConfig: + """Top-level pipeline configuration.""" + + hopfield: HopfieldConfig = field(default_factory=HopfieldConfig) + memory: MemoryBankConfig = field(default_factory=MemoryBankConfig) + encoder: EncoderConfig = field(default_factory=EncoderConfig) + generator: GeneratorConfig = field(default_factory=GeneratorConfig) + retriever_type: str = "hopfield" # "hopfield" or "faiss" |
