summaryrefslogtreecommitdiff
path: root/hag/encoder.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-02-16 14:44:42 -0600
committerYurenHao0426 <Blackhao0426@gmail.com>2026-02-16 14:44:42 -0600
commit09d50e47860da0035e178a442dc936028808a0b3 (patch)
tree9d651b0c7d289a9a0405953f2da989a3c431f147 /hag/encoder.py
parentc90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (diff)
Add memory centering, grid search experiments, and energy visualizationsHEADmaster
- Add centering support to MemoryBank (center_query, apply_centering, mean persistence in save/load) to remove centroid attractor in Hopfield dynamics - Add center flag to MemoryBankConfig, device field to PipelineConfig - Grid search scripts: initial (β≤8), residual, high-β, and centered grids with dedup-based LLM caching (89-91% call savings) - Energy landscape visualization: 2D contour, 1D profile, UMAP, PCA heatmap comparing centered vs uncentered dynamics - Experiment log (note.md) documenting 4 rounds of results and root cause analysis of centroid attractor problem - Key finding: β_critical ≈ 37.6 for centered memory; best configs beat FAISS baseline by +3-4% F1 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'hag/encoder.py')
-rw-r--r--hag/encoder.py11
1 files changed, 7 insertions, 4 deletions
diff --git a/hag/encoder.py b/hag/encoder.py
index 7e103f3..c380ad1 100644
--- a/hag/encoder.py
+++ b/hag/encoder.py
@@ -17,18 +17,20 @@ class Encoder:
For testing, use FakeEncoder instead.
"""
- def __init__(self, config: EncoderConfig) -> None:
+ def __init__(self, config: EncoderConfig, device: str = "cpu") -> None:
self.config = config
+ self.device = torch.device(device)
self._tokenizer = None
self._model = None
def _load_model(self) -> None:
- """Lazy-load the model and tokenizer."""
+ """Lazy-load the model and tokenizer, placing model on device."""
from transformers import AutoModel, AutoTokenizer
- logger.info("Loading encoder model: %s", self.config.model_name)
+ logger.info("Loading encoder model: %s (device=%s)", self.config.model_name, self.device)
self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
self._model = AutoModel.from_pretrained(self.config.model_name)
+ self._model.to(self.device)
self._model.eval()
@torch.no_grad()
@@ -39,7 +41,7 @@ class Encoder:
texts: single string or list of strings
Returns:
- (1, d) tensor for single input, (N, d) for list input.
+ (1, d) tensor for single input, (N, d) for list input. On self.device.
"""
if self._model is None:
self._load_model()
@@ -54,6 +56,7 @@ class Encoder:
truncation=True,
return_tensors="pt",
)
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
outputs = self._model(**inputs)
# Mean pooling over token embeddings
embeddings = outputs.last_hidden_state.mean(dim=1) # (N, d)