From 09d50e47860da0035e178a442dc936028808a0b3 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Mon, 16 Feb 2026 14:44:42 -0600 Subject: Add memory centering, grid search experiments, and energy visualizations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add centering support to MemoryBank (center_query, apply_centering, mean persistence in save/load) to remove centroid attractor in Hopfield dynamics - Add center flag to MemoryBankConfig, device field to PipelineConfig - Grid search scripts: initial (β≤8), residual, high-β, and centered grids with dedup-based LLM caching (89-91% call savings) - Energy landscape visualization: 2D contour, 1D profile, UMAP, PCA heatmap comparing centered vs uncentered dynamics - Experiment log (note.md) documenting 4 rounds of results and root cause analysis of centroid attractor problem - Key finding: β_critical ≈ 37.6 for centered memory; best configs beat FAISS baseline by +3-4% F1 Co-Authored-By: Claude Opus 4.6 --- hag/config.py | 4 +++- hag/encoder.py | 11 ++++++---- hag/generator.py | 22 +++++++++++++++----- hag/memory_bank.py | 61 +++++++++++++++++++++++++++++++++++++++++++++++++++--- hag/pipeline.py | 3 ++- 5 files changed, 87 insertions(+), 14 deletions(-) (limited to 'hag') diff --git a/hag/config.py b/hag/config.py index 793e3a6..10d0aff 100644 --- a/hag/config.py +++ b/hag/config.py @@ -19,6 +19,7 @@ class MemoryBankConfig: embedding_dim: int = 768 # Must match encoder output dim normalize: bool = True # L2-normalize embeddings in memory bank + center: bool = False # Mean-center embeddings to remove centroid attractor @dataclass @@ -35,7 +36,7 @@ class GeneratorConfig: """Configuration for the LLM generator.""" model_name: str = "meta-llama/Llama-3.1-8B-Instruct" - max_new_tokens: int = 128 + max_new_tokens: int = 32 temperature: float = 0.0 # Greedy decoding for reproducibility @@ -48,3 +49,4 @@ class PipelineConfig: encoder: EncoderConfig = field(default_factory=EncoderConfig) generator: GeneratorConfig = field(default_factory=GeneratorConfig) retriever_type: str = "hopfield" # "hopfield" or "faiss" + device: str = "cpu" # "cpu", "cuda", "cuda:0", etc. diff --git a/hag/encoder.py b/hag/encoder.py index 7e103f3..c380ad1 100644 --- a/hag/encoder.py +++ b/hag/encoder.py @@ -17,18 +17,20 @@ class Encoder: For testing, use FakeEncoder instead. """ - def __init__(self, config: EncoderConfig) -> None: + def __init__(self, config: EncoderConfig, device: str = "cpu") -> None: self.config = config + self.device = torch.device(device) self._tokenizer = None self._model = None def _load_model(self) -> None: - """Lazy-load the model and tokenizer.""" + """Lazy-load the model and tokenizer, placing model on device.""" from transformers import AutoModel, AutoTokenizer - logger.info("Loading encoder model: %s", self.config.model_name) + logger.info("Loading encoder model: %s (device=%s)", self.config.model_name, self.device) self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name) self._model = AutoModel.from_pretrained(self.config.model_name) + self._model.to(self.device) self._model.eval() @torch.no_grad() @@ -39,7 +41,7 @@ class Encoder: texts: single string or list of strings Returns: - (1, d) tensor for single input, (N, d) for list input. + (1, d) tensor for single input, (N, d) for list input. On self.device. """ if self._model is None: self._load_model() @@ -54,6 +56,7 @@ class Encoder: truncation=True, return_tensors="pt", ) + inputs = {k: v.to(self.device) for k, v in inputs.items()} outputs = self._model(**inputs) # Mean pooling over token embeddings embeddings = outputs.last_hidden_state.mean(dim=1) # (N, d) diff --git a/hag/generator.py b/hag/generator.py index 2142e0c..d0de468 100644 --- a/hag/generator.py +++ b/hag/generator.py @@ -3,11 +3,13 @@ import logging from typing import List +import torch + from hag.config import GeneratorConfig logger = logging.getLogger(__name__) -PROMPT_TEMPLATE = """Answer the following question based on the provided context passages. +PROMPT_TEMPLATE = """Answer the following question based on the provided context passages. Give ONLY the answer itself in a few words, with no explanation. Context: {context} @@ -24,20 +26,22 @@ class Generator: For testing, use FakeGenerator instead. """ - def __init__(self, config: GeneratorConfig) -> None: + def __init__(self, config: GeneratorConfig, device: str = "cpu") -> None: self.config = config + self.device = torch.device(device) self._tokenizer = None self._model = None def _load_model(self) -> None: - """Lazy-load the model and tokenizer.""" + """Lazy-load the model and tokenizer, placing model on device.""" from transformers import AutoModelForCausalLM, AutoTokenizer - logger.info("Loading generator model: %s", self.config.model_name) + logger.info("Loading generator model: %s (device=%s)", self.config.model_name, self.device) self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name) self._model = AutoModelForCausalLM.from_pretrained( self.config.model_name, torch_dtype="auto", + device_map=self.device, ) self._model.eval() @@ -60,15 +64,23 @@ class Generator: prompt = PROMPT_TEMPLATE.format(context=context, question=question) inputs = self._tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to(self.device) for k, v in inputs.items()} outputs = self._model.generate( **inputs, max_new_tokens=self.config.max_new_tokens, temperature=self.config.temperature if self.config.temperature > 0 else None, do_sample=self.config.temperature > 0, + repetition_penalty=1.2, ) # Decode only the generated tokens (skip the prompt) generated = outputs[0][inputs["input_ids"].shape[1]:] - return self._tokenizer.decode(generated, skip_special_tokens=True).strip() + answer = self._tokenizer.decode(generated, skip_special_tokens=True).strip() + # Take only the first sentence/line as the answer + for sep in ["\n", ". ", ".\n"]: + if sep in answer: + answer = answer.split(sep)[0].strip() + break + return answer class FakeGenerator: diff --git a/hag/memory_bank.py b/hag/memory_bank.py index 42dcc73..0a0a87c 100644 --- a/hag/memory_bank.py +++ b/hag/memory_bank.py @@ -16,12 +16,17 @@ class MemoryBank: The memory bank is M in R^{d x N} where each column is a passage embedding. Also maintains a mapping from column index to passage text for final retrieval. + + When config.center=True, embeddings are mean-centered to remove the centroid + attractor in Hopfield dynamics. The mean is saved so queries can be centered + with the same offset via center_query(). """ def __init__(self, config: MemoryBankConfig) -> None: self.config = config self.embeddings: Optional[torch.Tensor] = None # (d, N) self.passages: List[str] = [] + self.mean: Optional[torch.Tensor] = None # (d,) — saved for query centering def build_from_embeddings( self, embeddings: torch.Tensor, passages: List[str] @@ -38,10 +43,42 @@ class MemoryBank: ) if self.config.normalize: embeddings = F.normalize(embeddings, dim=-1) + if self.config.center: + self.mean = embeddings.mean(dim=0) # (d,) + embeddings = embeddings - self.mean.unsqueeze(0) # (N, d) + logger.info("Centered memory bank (removed mean)") self.embeddings = embeddings.T # Store as (d, N) for efficient matmul self.passages = list(passages) logger.info("Built memory bank with %d passages, dim=%d", self.size, self.dim) + def center_query(self, query: torch.Tensor) -> torch.Tensor: + """Center a query embedding using the saved memory mean. + + Must be called before Hopfield retrieval when config.center=True. + + Args: + query: (d,) or (batch, d) — query embedding(s) + + Returns: + Centered query tensor, same shape as input. + """ + if self.mean is None: + return query + return query - self.mean.to(query.device) + + def apply_centering(self) -> None: + """Center an already-loaded (uncentered) memory bank in-place. + + Useful when loading a memory bank that was saved without centering. + Computes and stores the mean, then subtracts it from embeddings. + """ + if self.embeddings is None: + return + # embeddings is (d, N), mean over columns + self.mean = self.embeddings.mean(dim=1) # (d,) + self.embeddings = self.embeddings - self.mean.unsqueeze(1) # (d, N) + logger.info("Applied centering to loaded memory bank") + def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]: """Given top-k indices, return corresponding passage texts. @@ -67,20 +104,38 @@ class MemoryBank: "embedding_dim": self.config.embedding_dim, "normalize": self.config.normalize, }, + "mean": self.mean, } torch.save(data, path) logger.info("Saved memory bank to %s", path) - def load(self, path: str) -> None: + def load(self, path: str, device: str = "cpu") -> None: """Load memory bank from disk. Args: path: file path to load from + device: device to load tensors onto ("cpu", "cuda", "cuda:0", etc.) """ - data = torch.load(path, weights_only=False) + data = torch.load(path, weights_only=False, map_location=device) self.embeddings = data["embeddings"] self.passages = data["passages"] - logger.info("Loaded memory bank from %s (%d passages)", path, self.size) + self.mean = data.get("mean", None) + logger.info("Loaded memory bank from %s (%d passages, device=%s)", path, self.size, device) + + def to(self, device: str) -> "MemoryBank": + """Move memory bank embeddings to the specified device. + + Args: + device: target device ("cpu", "cuda", "cuda:0", etc.) + + Returns: + self (for chaining). + """ + if self.embeddings is not None: + self.embeddings = self.embeddings.to(device) + if self.mean is not None: + self.mean = self.mean.to(device) + return self @property def size(self) -> int: diff --git a/hag/pipeline.py b/hag/pipeline.py index 1fefb84..086b3be 100644 --- a/hag/pipeline.py +++ b/hag/pipeline.py @@ -82,7 +82,8 @@ class RAGPipeline: if self.retriever_type == "hopfield": retrieval_result = self.hopfield_retriever.retrieve(query_emb) else: - query_np = query_emb.detach().numpy().astype(np.float32) + # FAISS requires CPU numpy arrays + query_np = query_emb.detach().cpu().numpy().astype(np.float32) retrieval_result = self.faiss_retriever.retrieve(query_np) # Generate -- cgit v1.2.3