Diffstat (limited to 'hag/generator.py')
-rw-r--r--  hag/generator.py  87
1 file changed, 87 insertions, 0 deletions
diff --git a/hag/generator.py b/hag/generator.py
new file mode 100644
index 0000000..2142e0c
--- /dev/null
+++ b/hag/generator.py
@@ -0,0 +1,87 @@
+"""LLM generation wrapper for producing answers from retrieved context."""
+
+import logging
+from typing import List
+
+from hag.config import GeneratorConfig
+
+logger = logging.getLogger(__name__)
+
+PROMPT_TEMPLATE = """Answer the following question based on the provided context passages.
+
+Context:
+{context}
+
+Question: {question}
+
+Answer:"""
+
+
+class Generator:
+ """LLM-based answer generator.
+
+ Uses a HuggingFace causal LM (e.g., Llama-3.1-8B-Instruct).
+ For testing, use FakeGenerator instead.
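+
+    Example (illustrative sketch; assumes GeneratorConfig can be constructed
+    with the model_name, max_new_tokens, and temperature fields used below):
+
+        config = GeneratorConfig(
+            model_name="meta-llama/Llama-3.1-8B-Instruct",
+            max_new_tokens=128,
+            temperature=0.0,
+        )
+        answer = Generator(config).generate(
+            "Who wrote Hamlet?",
+            ["Hamlet is a tragedy written by William Shakespeare."],
+        )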
+ """
+
+ def __init__(self, config: GeneratorConfig) -> None:
+ self.config = config
+ self._tokenizer = None
+ self._model = None
+
+ def _load_model(self) -> None:
+ """Lazy-load the model and tokenizer."""
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ logger.info("Loading generator model: %s", self.config.model_name)
+ self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
+ self._model = AutoModelForCausalLM.from_pretrained(
+ self.config.model_name,
+ torch_dtype="auto",
+ )
+ self._model.eval()
+
+    def generate(self, question: str, passages: List[str]) -> str:
+        """Generate an answer given a question and retrieved passages.
+
+        Args:
+            question: the user question
+            passages: list of retrieved passage texts
+
+        Returns:
+            Generated answer string.
+        """
+        if self._model is None:
+            self._load_model()
+
+        context = "\n\n".join(
+            f"[{i+1}] {p}" for i, p in enumerate(passages)
+        )
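+        # e.g. for two passages: "[1] <first passage>\n\n[2] <second passage>"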
+        prompt = PROMPT_TEMPLATE.format(context=context, question=question)
+
+        inputs = self._tokenizer(prompt, return_tensors="pt")
+        outputs = self._model.generate(
+            **inputs,
+            max_new_tokens=self.config.max_new_tokens,
+            temperature=self.config.temperature if self.config.temperature > 0 else None,
+            do_sample=self.config.temperature > 0,
+        )
+        # Decode only the generated tokens (skip the prompt)
+        generated = outputs[0][inputs["input_ids"].shape[1]:]
+        return self._tokenizer.decode(generated, skip_special_tokens=True).strip()
+
+
+class FakeGenerator:
+ """Deterministic mock generator for testing. No model download needed."""
+
+    def generate(self, question: str, passages: List[str]) -> str:
+        """Return a mock answer.
+
+        Args:
+            question: the user question
+            passages: list of retrieved passages
+
+        Returns:
+            Mock answer string.
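+
+        Example (illustrative; any inputs yield the same fixed string,
+        which keeps tests deterministic):
+
+            FakeGenerator().generate("any question", ["any passage"])
+            # -> "mock answer"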
+ """
+ return "mock answer"