summaryrefslogtreecommitdiff
path: root/hag/generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'hag/generator.py')
-rw-r--r--hag/generator.py22
1 file changed, 17 insertions, 5 deletions
diff --git a/hag/generator.py b/hag/generator.py
index 2142e0c..d0de468 100644
--- a/hag/generator.py
+++ b/hag/generator.py
@@ -3,11 +3,13 @@
import logging
from typing import List
+import torch
+
from hag.config import GeneratorConfig
logger = logging.getLogger(__name__)
-PROMPT_TEMPLATE = """Answer the following question based on the provided context passages.
+PROMPT_TEMPLATE = """Answer the following question based on the provided context passages. Give ONLY the answer itself in a few words, with no explanation.
Context:
{context}
@@ -24,20 +26,22 @@ class Generator:
For testing, use FakeGenerator instead.
"""
- def __init__(self, config: GeneratorConfig) -> None:
+ def __init__(self, config: GeneratorConfig, device: str = "cpu") -> None:
self.config = config
+ self.device = torch.device(device)
self._tokenizer = None
self._model = None
def _load_model(self) -> None:
- """Lazy-load the model and tokenizer."""
+ """Lazy-load the model and tokenizer, placing model on device."""
from transformers import AutoModelForCausalLM, AutoTokenizer
- logger.info("Loading generator model: %s", self.config.model_name)
+ logger.info("Loading generator model: %s (device=%s)", self.config.model_name, self.device)
self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
self._model = AutoModelForCausalLM.from_pretrained(
self.config.model_name,
torch_dtype="auto",
+ device_map=self.device,
)
self._model.eval()
@@ -60,15 +64,23 @@ class Generator:
prompt = PROMPT_TEMPLATE.format(context=context, question=question)
inputs = self._tokenizer(prompt, return_tensors="pt")
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
outputs = self._model.generate(
**inputs,
max_new_tokens=self.config.max_new_tokens,
temperature=self.config.temperature if self.config.temperature > 0 else None,
do_sample=self.config.temperature > 0,
+ repetition_penalty=1.2,
)
# Decode only the generated tokens (skip the prompt)
generated = outputs[0][inputs["input_ids"].shape[1]:]
- return self._tokenizer.decode(generated, skip_special_tokens=True).strip()
+ answer = self._tokenizer.decode(generated, skip_special_tokens=True).strip()
+ # Take only the first sentence/line as the answer
+ for sep in ["\n", ". ", ".\n"]:
+ if sep in answer:
+ answer = answer.split(sep)[0].strip()
+ break
+ return answer
class FakeGenerator: