From c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Sun, 15 Feb 2026 18:19:50 +0000
Subject: Initial implementation of HAG (Hopfield-Augmented Generation)

Core Hopfield retrieval module with energy-based convergence guarantees,
memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end
pipeline. All 45 tests passing on CPU with synthetic data.

Co-Authored-By: Claude Opus 4.6
---
 hag/generator.py | 87 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 87 insertions(+)
 create mode 100644 hag/generator.py

(limited to 'hag/generator.py')

diff --git a/hag/generator.py b/hag/generator.py
new file mode 100644
index 0000000..2142e0c
--- /dev/null
+++ b/hag/generator.py
@@ -0,0 +1,87 @@
+"""LLM generation wrapper for producing answers from retrieved context."""
+
+import logging
+from typing import List
+
+from hag.config import GeneratorConfig
+
+logger = logging.getLogger(__name__)
+
+PROMPT_TEMPLATE = """Answer the following question based on the provided context passages.
+
+Context:
+{context}
+
+Question: {question}
+
+Answer:"""
+
+
+class Generator:
+    """LLM-based answer generator.
+
+    Uses a HuggingFace causal LM (e.g., Llama-3.1-8B-Instruct).
+    For testing, use FakeGenerator instead.
+    """
+
+    def __init__(self, config: GeneratorConfig) -> None:
+        self.config = config
+        self._tokenizer = None
+        self._model = None
+
+    def _load_model(self) -> None:
+        """Lazy-load the model and tokenizer."""
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        logger.info("Loading generator model: %s", self.config.model_name)
+        self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
+        self._model = AutoModelForCausalLM.from_pretrained(
+            self.config.model_name,
+            torch_dtype="auto",
+        )
+        self._model.eval()
+
+    def generate(self, question: str, passages: List[str]) -> str:
+        """Generate an answer given a question and retrieved passages.
+
+        Args:
+            question: the user question
+            passages: list of retrieved passage texts
+
+        Returns:
+            Generated answer string.
+        """
+        if self._model is None:
+            self._load_model()
+
+        context = "\n\n".join(
+            f"[{i+1}] {p}" for i, p in enumerate(passages)
+        )
+        prompt = PROMPT_TEMPLATE.format(context=context, question=question)
+
+        inputs = self._tokenizer(prompt, return_tensors="pt")
+        outputs = self._model.generate(
+            **inputs,
+            max_new_tokens=self.config.max_new_tokens,
+            temperature=self.config.temperature if self.config.temperature > 0 else None,
+            do_sample=self.config.temperature > 0,
+        )
+        # Decode only the generated tokens (skip the prompt)
+        generated = outputs[0][inputs["input_ids"].shape[1]:]
+        return self._tokenizer.decode(generated, skip_special_tokens=True).strip()
+
+
+class FakeGenerator:
+    """Deterministic mock generator for testing. No model download needed."""
+
+    def generate(self, question: str, passages: List[str]) -> str:
+        """Return a mock answer.
+
+        Args:
+            question: the user question
+            passages: list of retrieved passages
+
+        Returns:
+            Mock answer string.
+        """
+        return "mock answer"
--
cgit v1.2.3
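
A rough usage sketch (not part of the patch above). It assumes GeneratorConfig
in hag/config.py is a plain dataclass exposing the model_name, max_new_tokens,
and temperature fields that Generator reads; the actual constructor may differ.

    from hag.config import GeneratorConfig
    from hag.generator import FakeGenerator, Generator

    # Hypothetical config values: the model name follows the example in the
    # Generator docstring, and temperature=0 maps to do_sample=False.
    config = GeneratorConfig(
        model_name="meta-llama/Llama-3.1-8B-Instruct",
        max_new_tokens=128,
        temperature=0.0,
    )

    passages = [
        "Hopfield networks store patterns as attractors of an energy function.",
        "Retrieval iterates the update rule until it settles on a stored pattern.",
    ]

    # In tests, FakeGenerator avoids any model download and returns a fixed string.
    print(FakeGenerator().generate("What do Hopfield networks store?", passages))

    # With weights available, the real Generator lazy-loads the model on the first
    # generate() call and decodes only the newly generated tokens.
    generator = Generator(config)
    answer = generator.generate("What do Hopfield networks store?", passages)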