From 09d50e47860da0035e178a442dc936028808a0b3 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Mon, 16 Feb 2026 14:44:42 -0600 Subject: Add memory centering, grid search experiments, and energy visualizations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add centering support to MemoryBank (center_query, apply_centering, mean persistence in save/load) to remove centroid attractor in Hopfield dynamics - Add center flag to MemoryBankConfig, device field to PipelineConfig - Grid search scripts: initial (β≤8), residual, high-β, and centered grids with dedup-based LLM caching (89-91% call savings) - Energy landscape visualization: 2D contour, 1D profile, UMAP, PCA heatmap comparing centered vs uncentered dynamics - Experiment log (note.md) documenting 4 rounds of results and root cause analysis of centroid attractor problem - Key finding: β_critical ≈ 37.6 for centered memory; best configs beat FAISS baseline by +3-4% F1 Co-Authored-By: Claude Opus 4.6 --- hag/generator.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) (limited to 'hag/generator.py') diff --git a/hag/generator.py b/hag/generator.py index 2142e0c..d0de468 100644 --- a/hag/generator.py +++ b/hag/generator.py @@ -3,11 +3,13 @@ import logging from typing import List +import torch + from hag.config import GeneratorConfig logger = logging.getLogger(__name__) -PROMPT_TEMPLATE = """Answer the following question based on the provided context passages. +PROMPT_TEMPLATE = """Answer the following question based on the provided context passages. Give ONLY the answer itself in a few words, with no explanation. Context: {context} @@ -24,20 +26,22 @@ class Generator: For testing, use FakeGenerator instead. """ - def __init__(self, config: GeneratorConfig) -> None: + def __init__(self, config: GeneratorConfig, device: str = "cpu") -> None: self.config = config + self.device = torch.device(device) self._tokenizer = None self._model = None def _load_model(self) -> None: - """Lazy-load the model and tokenizer.""" + """Lazy-load the model and tokenizer, placing model on device.""" from transformers import AutoModelForCausalLM, AutoTokenizer - logger.info("Loading generator model: %s", self.config.model_name) + logger.info("Loading generator model: %s (device=%s)", self.config.model_name, self.device) self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name) self._model = AutoModelForCausalLM.from_pretrained( self.config.model_name, torch_dtype="auto", + device_map=self.device, ) self._model.eval() @@ -60,15 +64,23 @@ class Generator: prompt = PROMPT_TEMPLATE.format(context=context, question=question) inputs = self._tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to(self.device) for k, v in inputs.items()} outputs = self._model.generate( **inputs, max_new_tokens=self.config.max_new_tokens, temperature=self.config.temperature if self.config.temperature > 0 else None, do_sample=self.config.temperature > 0, + repetition_penalty=1.2, ) # Decode only the generated tokens (skip the prompt) generated = outputs[0][inputs["input_ids"].shape[1]:] - return self._tokenizer.decode(generated, skip_special_tokens=True).strip() + answer = self._tokenizer.decode(generated, skip_special_tokens=True).strip() + # Take only the first sentence/line as the answer + for sep in ["\n", ". ", ".\n"]: + if sep in answer: + answer = answer.split(sep)[0].strip() + break + return answer class FakeGenerator: -- cgit v1.2.3