"""LLM generation wrapper for producing answers from retrieved context.""" import logging from typing import List import torch from hag.config import GeneratorConfig logger = logging.getLogger(__name__) PROMPT_TEMPLATE = """Answer the following question based on the provided context passages. Give ONLY the answer itself in a few words, with no explanation. Context: {context} Question: {question} Answer:""" class Generator: """LLM-based answer generator. Uses a HuggingFace causal LM (e.g., Llama-3.1-8B-Instruct). For testing, use FakeGenerator instead. """ def __init__(self, config: GeneratorConfig, device: str = "cpu") -> None: self.config = config self.device = torch.device(device) self._tokenizer = None self._model = None def _load_model(self) -> None: """Lazy-load the model and tokenizer, placing model on device.""" from transformers import AutoModelForCausalLM, AutoTokenizer logger.info("Loading generator model: %s (device=%s)", self.config.model_name, self.device) self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name) self._model = AutoModelForCausalLM.from_pretrained( self.config.model_name, torch_dtype="auto", device_map=self.device, ) self._model.eval() def generate(self, question: str, passages: List[str]) -> str: """Generate an answer given a question and retrieved passages. Args: question: the user question passages: list of retrieved passage texts Returns: Generated answer string. """ if self._model is None: self._load_model() context = "\n\n".join( f"[{i+1}] {p}" for i, p in enumerate(passages) ) prompt = PROMPT_TEMPLATE.format(context=context, question=question) inputs = self._tokenizer(prompt, return_tensors="pt") inputs = {k: v.to(self.device) for k, v in inputs.items()} outputs = self._model.generate( **inputs, max_new_tokens=self.config.max_new_tokens, temperature=self.config.temperature if self.config.temperature > 0 else None, do_sample=self.config.temperature > 0, repetition_penalty=1.2, ) # Decode only the generated tokens (skip the prompt) generated = outputs[0][inputs["input_ids"].shape[1]:] answer = self._tokenizer.decode(generated, skip_special_tokens=True).strip() # Take only the first sentence/line as the answer for sep in ["\n", ". ", ".\n"]: if sep in answer: answer = answer.split(sep)[0].strip() break return answer class FakeGenerator: """Deterministic mock generator for testing. No model download needed.""" def generate(self, question: str, passages: List[str]) -> str: """Return a mock answer. Args: question: the user question passages: list of retrieved passages Returns: Mock answer string. """ return "mock answer"