diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-02-16 14:44:42 -0600 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-02-16 14:44:42 -0600 |
| commit | 09d50e47860da0035e178a442dc936028808a0b3 (patch) | |
| tree | 9d651b0c7d289a9a0405953f2da989a3c431f147 /hag/generator.py | |
| parent | c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (diff) | |
- Add centering support to MemoryBank (center_query, apply_centering, mean
persistence in save/load) to remove centroid attractor in Hopfield dynamics
- Add center flag to MemoryBankConfig, device field to PipelineConfig
- Grid search scripts: initial (β≤8), residual, high-β, and centered grids
with dedup-based LLM caching (89-91% call savings)
- Energy landscape visualization: 2D contour, 1D profile, UMAP, PCA heatmap
comparing centered vs uncentered dynamics
- Experiment log (note.md) documenting 4 rounds of results and root cause
analysis of centroid attractor problem
- Key finding: β_critical ≈ 37.6 for centered memory; best configs beat
FAISS baseline by +3-4% F1
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'hag/generator.py')
| -rw-r--r-- | hag/generator.py | 22 |
1 file changed, 17 insertions(+), 5 deletions(-)
diff --git a/hag/generator.py b/hag/generator.py
index 2142e0c..d0de468 100644
--- a/hag/generator.py
+++ b/hag/generator.py
@@ -3,11 +3,13 @@ import logging
 from typing import List
 
+import torch
+
 from hag.config import GeneratorConfig
 
 logger = logging.getLogger(__name__)
 
-PROMPT_TEMPLATE = """Answer the following question based on the provided context passages.
+PROMPT_TEMPLATE = """Answer the following question based on the provided context passages. Give ONLY the answer itself in a few words, with no explanation.
 
 Context:
 {context}
@@ -24,20 +26,22 @@ class Generator:
     For testing, use FakeGenerator instead.
     """
 
-    def __init__(self, config: GeneratorConfig) -> None:
+    def __init__(self, config: GeneratorConfig, device: str = "cpu") -> None:
         self.config = config
+        self.device = torch.device(device)
         self._tokenizer = None
         self._model = None
 
     def _load_model(self) -> None:
-        """Lazy-load the model and tokenizer."""
+        """Lazy-load the model and tokenizer, placing model on device."""
         from transformers import AutoModelForCausalLM, AutoTokenizer
 
-        logger.info("Loading generator model: %s", self.config.model_name)
+        logger.info("Loading generator model: %s (device=%s)", self.config.model_name, self.device)
         self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
         self._model = AutoModelForCausalLM.from_pretrained(
             self.config.model_name,
             torch_dtype="auto",
+            device_map=self.device,
         )
         self._model.eval()
 
@@ -60,15 +64,23 @@ class Generator:
         prompt = PROMPT_TEMPLATE.format(context=context, question=question)
         inputs = self._tokenizer(prompt, return_tensors="pt")
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
         outputs = self._model.generate(
             **inputs,
             max_new_tokens=self.config.max_new_tokens,
             temperature=self.config.temperature if self.config.temperature > 0 else None,
             do_sample=self.config.temperature > 0,
+            repetition_penalty=1.2,
         )
 
         # Decode only the generated tokens (skip the prompt)
         generated = outputs[0][inputs["input_ids"].shape[1]:]
-        return self._tokenizer.decode(generated, skip_special_tokens=True).strip()
+        answer = self._tokenizer.decode(generated, skip_special_tokens=True).strip()
+        # Take only the first sentence/line as the answer
+        for sep in ["\n", ". ", ".\n"]:
+            if sep in answer:
+                answer = answer.split(sep)[0].strip()
+                break
+        return answer
 
 
 class FakeGenerator:
