diff options
Diffstat (limited to 'hag/generator.py')
| -rw-r--r-- | hag/generator.py | 22 |
1 file changed, 17 insertions, 5 deletions
diff --git a/hag/generator.py b/hag/generator.py index 2142e0c..d0de468 100644 --- a/hag/generator.py +++ b/hag/generator.py @@ -3,11 +3,13 @@ import logging from typing import List +import torch + from hag.config import GeneratorConfig logger = logging.getLogger(__name__) -PROMPT_TEMPLATE = """Answer the following question based on the provided context passages. +PROMPT_TEMPLATE = """Answer the following question based on the provided context passages. Give ONLY the answer itself in a few words, with no explanation. Context: {context} @@ -24,20 +26,22 @@ class Generator: For testing, use FakeGenerator instead. """ - def __init__(self, config: GeneratorConfig) -> None: + def __init__(self, config: GeneratorConfig, device: str = "cpu") -> None: self.config = config + self.device = torch.device(device) self._tokenizer = None self._model = None def _load_model(self) -> None: - """Lazy-load the model and tokenizer.""" + """Lazy-load the model and tokenizer, placing model on device.""" from transformers import AutoModelForCausalLM, AutoTokenizer - logger.info("Loading generator model: %s", self.config.model_name) + logger.info("Loading generator model: %s (device=%s)", self.config.model_name, self.device) self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name) self._model = AutoModelForCausalLM.from_pretrained( self.config.model_name, torch_dtype="auto", + device_map=self.device, ) self._model.eval() @@ -60,15 +64,23 @@ class Generator: prompt = PROMPT_TEMPLATE.format(context=context, question=question) inputs = self._tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to(self.device) for k, v in inputs.items()} outputs = self._model.generate( **inputs, max_new_tokens=self.config.max_new_tokens, temperature=self.config.temperature if self.config.temperature > 0 else None, do_sample=self.config.temperature > 0, + repetition_penalty=1.2, ) # Decode only the generated tokens (skip the prompt) generated = outputs[0][inputs["input_ids"].shape[1]:] - 
return self._tokenizer.decode(generated, skip_special_tokens=True).strip() + answer = self._tokenizer.decode(generated, skip_special_tokens=True).strip() + # Take only the first sentence/line as the answer + for sep in ["\n", ". ", ".\n"]: + if sep in answer: + answer = answer.split(sep)[0].strip() + break + return answer class FakeGenerator: |
