"""LLM generation wrapper for producing answers from retrieved context.""" import logging from typing import List from hag.config import GeneratorConfig logger = logging.getLogger(__name__) PROMPT_TEMPLATE = """Answer the following question based on the provided context passages. Context: {context} Question: {question} Answer:""" class Generator: """LLM-based answer generator. Uses a HuggingFace causal LM (e.g., Llama-3.1-8B-Instruct). For testing, use FakeGenerator instead. """ def __init__(self, config: GeneratorConfig) -> None: self.config = config self._tokenizer = None self._model = None def _load_model(self) -> None: """Lazy-load the model and tokenizer.""" from transformers import AutoModelForCausalLM, AutoTokenizer logger.info("Loading generator model: %s", self.config.model_name) self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name) self._model = AutoModelForCausalLM.from_pretrained( self.config.model_name, torch_dtype="auto", ) self._model.eval() def generate(self, question: str, passages: List[str]) -> str: """Generate an answer given a question and retrieved passages. Args: question: the user question passages: list of retrieved passage texts Returns: Generated answer string. """ if self._model is None: self._load_model() context = "\n\n".join( f"[{i+1}] {p}" for i, p in enumerate(passages) ) prompt = PROMPT_TEMPLATE.format(context=context, question=question) inputs = self._tokenizer(prompt, return_tensors="pt") outputs = self._model.generate( **inputs, max_new_tokens=self.config.max_new_tokens, temperature=self.config.temperature if self.config.temperature > 0 else None, do_sample=self.config.temperature > 0, ) # Decode only the generated tokens (skip the prompt) generated = outputs[0][inputs["input_ids"].shape[1]:] return self._tokenizer.decode(generated, skip_special_tokens=True).strip() class FakeGenerator: """Deterministic mock generator for testing. No model download needed.""" def generate(self, question: str, passages: List[str]) -> str: """Return a mock answer. Args: question: the user question passages: list of retrieved passages Returns: Mock answer string. """ return "mock answer"