summaryrefslogtreecommitdiff
path: root/hag/config.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-02-15 18:19:50 +0000
committerYurenHao0426 <blackhao0426@gmail.com>2026-02-15 18:19:50 +0000
commitc90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch)
tree43edac8013fec4e65a0b9cddec5314489b4aafc2 /hag/config.py
Initial implementation of HAG (Hopfield-Augmented Generation)HEADmaster
Core Hopfield retrieval module with energy-based convergence guarantees, memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end pipeline. All 45 tests passing on CPU with synthetic data. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'hag/config.py')
-rw-r--r--hag/config.py50
1 files changed, 50 insertions, 0 deletions
diff --git a/hag/config.py b/hag/config.py
new file mode 100644
index 0000000..793e3a6
--- /dev/null
+++ b/hag/config.py
@@ -0,0 +1,50 @@
+"""All hyperparameters and configuration dataclasses for HAG."""
+
+from dataclasses import dataclass, field
+
+
+@dataclass
+class HopfieldConfig:
+ """Configuration for the Hopfield retrieval module."""
+
+ beta: float = 1.0 # Inverse temperature. Higher = sharper retrieval
+ max_iter: int = 5 # Maximum Hopfield iteration steps
+ conv_threshold: float = 1e-4 # Stop if ||q_{t+1} - q_t|| < threshold
+ top_k: int = 5 # Number of passages to retrieve from final attention weights
+
+
+@dataclass
+class MemoryBankConfig:
+ """Configuration for the memory bank."""
+
+ embedding_dim: int = 768 # Must match encoder output dim
+ normalize: bool = True # L2-normalize embeddings in memory bank
+
+
+@dataclass
+class EncoderConfig:
+ """Configuration for the query/passage encoder."""
+
+ model_name: str = "facebook/contriever-msmarco"
+ max_length: int = 512
+ batch_size: int = 64
+
+
+@dataclass
+class GeneratorConfig:
+ """Configuration for the LLM generator."""
+
+ model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
+ max_new_tokens: int = 128
+ temperature: float = 0.0 # Greedy decoding for reproducibility
+
+
+@dataclass
+class PipelineConfig:
+ """Top-level pipeline configuration."""
+
+ hopfield: HopfieldConfig = field(default_factory=HopfieldConfig)
+ memory: MemoryBankConfig = field(default_factory=MemoryBankConfig)
+ encoder: EncoderConfig = field(default_factory=EncoderConfig)
+ generator: GeneratorConfig = field(default_factory=GeneratorConfig)
+ retriever_type: str = "hopfield" # "hopfield" or "faiss"