summaryrefslogtreecommitdiff
path: root/hag/encoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'hag/encoder.py')
-rw-r--r-- hag/encoder.py | 11
1 files changed, 7 insertions, 4 deletions
diff --git a/hag/encoder.py b/hag/encoder.py
index 7e103f3..c380ad1 100644
--- a/hag/encoder.py
+++ b/hag/encoder.py
@@ -17,18 +17,20 @@ class Encoder:
For testing, use FakeEncoder instead.
"""
- def __init__(self, config: EncoderConfig) -> None:
+ def __init__(self, config: EncoderConfig, device: str = "cpu") -> None:
self.config = config
+ self.device = torch.device(device)
self._tokenizer = None
self._model = None
def _load_model(self) -> None:
- """Lazy-load the model and tokenizer."""
+ """Lazy-load the model and tokenizer, placing model on device."""
from transformers import AutoModel, AutoTokenizer
- logger.info("Loading encoder model: %s", self.config.model_name)
+ logger.info("Loading encoder model: %s (device=%s)", self.config.model_name, self.device)
self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
self._model = AutoModel.from_pretrained(self.config.model_name)
+ self._model.to(self.device)
self._model.eval()
@torch.no_grad()
@@ -39,7 +41,7 @@ class Encoder:
texts: single string or list of strings
Returns:
- (1, d) tensor for single input, (N, d) for list input.
+ (1, d) tensor for single input, (N, d) for list input. On self.device.
"""
if self._model is None:
self._load_model()
@@ -54,6 +56,7 @@ class Encoder:
truncation=True,
return_tensors="pt",
)
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
outputs = self._model(**inputs)
# Mean pooling over token embeddings
embeddings = outputs.last_hidden_state.mean(dim=1) # (N, d)