summaryrefslogtreecommitdiff
path: root/hag/config.py
diff options
context:
space:
mode:
Diffstat (limited to 'hag/config.py')
-rw-r--r--hag/config.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/hag/config.py b/hag/config.py
index 793e3a6..10d0aff 100644
--- a/hag/config.py
+++ b/hag/config.py
@@ -19,6 +19,7 @@ class MemoryBankConfig:
embedding_dim: int = 768 # Must match encoder output dim
normalize: bool = True # L2-normalize embeddings in memory bank
+ center: bool = False # Mean-center embeddings to remove centroid attractor
@dataclass
@@ -35,7 +36,7 @@ class GeneratorConfig:
"""Configuration for the LLM generator."""
model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
- max_new_tokens: int = 128
+ max_new_tokens: int = 32
temperature: float = 0.0 # Greedy decoding for reproducibility
@@ -48,3 +49,4 @@ class PipelineConfig:
encoder: EncoderConfig = field(default_factory=EncoderConfig)
generator: GeneratorConfig = field(default_factory=GeneratorConfig)
retriever_type: str = "hopfield" # "hopfield" or "faiss"
+ device: str = "cpu" # "cpu", "cuda", "cuda:0", etc.