diff options
Diffstat (limited to 'hag/encoder.py')
| -rw-r--r-- | hag/encoder.py | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/hag/encoder.py b/hag/encoder.py index 7e103f3..c380ad1 100644 --- a/hag/encoder.py +++ b/hag/encoder.py @@ -17,18 +17,20 @@ class Encoder: For testing, use FakeEncoder instead. """ - def __init__(self, config: EncoderConfig) -> None: + def __init__(self, config: EncoderConfig, device: str = "cpu") -> None: self.config = config + self.device = torch.device(device) self._tokenizer = None self._model = None def _load_model(self) -> None: - """Lazy-load the model and tokenizer.""" + """Lazy-load the model and tokenizer, placing model on device.""" from transformers import AutoModel, AutoTokenizer - logger.info("Loading encoder model: %s", self.config.model_name) + logger.info("Loading encoder model: %s (device=%s)", self.config.model_name, self.device) self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name) self._model = AutoModel.from_pretrained(self.config.model_name) + self._model.to(self.device) self._model.eval() @torch.no_grad() @@ -39,7 +41,7 @@ class Encoder: texts: single string or list of strings Returns: - (1, d) tensor for single input, (N, d) for list input. + (1, d) tensor for single input, (N, d) for list input. On self.device. """ if self._model is None: self._load_model() @@ -54,6 +56,7 @@ class Encoder: truncation=True, return_tensors="pt", ) + inputs = {k: v.to(self.device) for k, v in inputs.items()} outputs = self._model(**inputs) # Mean pooling over token embeddings embeddings = outputs.last_hidden_state.mean(dim=1) # (N, d) |
