summary | refs | log | tree | commit | diff
path: root/hag
diff options
context:
space:
mode:
author    YurenHao0426 <Blackhao0426@gmail.com>  2026-02-16 14:44:42 -0600
committer YurenHao0426 <Blackhao0426@gmail.com>  2026-02-16 14:44:42 -0600
commit    09d50e47860da0035e178a442dc936028808a0b3 (patch)
tree      9d651b0c7d289a9a0405953f2da989a3c431f147 /hag
parent    c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (diff)
Add memory centering, grid search experiments, and energy visualizations (HEAD, master)
- Add centering support to MemoryBank (center_query, apply_centering, mean persistence in save/load) to remove centroid attractor in Hopfield dynamics
- Add center flag to MemoryBankConfig, device field to PipelineConfig
- Grid search scripts: initial (β≤8), residual, high-β, and centered grids with dedup-based LLM caching (89-91% call savings)
- Energy landscape visualization: 2D contour, 1D profile, UMAP, PCA heatmap comparing centered vs uncentered dynamics
- Experiment log (note.md) documenting 4 rounds of results and root cause analysis of centroid attractor problem
- Key finding: β_critical ≈ 37.6 for centered memory; best configs beat FAISS baseline by +3-4% F1

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'hag')
-rw-r--r--  hag/config.py        4
-rw-r--r--  hag/encoder.py      11
-rw-r--r--  hag/generator.py    22
-rw-r--r--  hag/memory_bank.py  61
-rw-r--r--  hag/pipeline.py      3
5 files changed, 87 insertions, 14 deletions
diff --git a/hag/config.py b/hag/config.py
index 793e3a6..10d0aff 100644
--- a/hag/config.py
+++ b/hag/config.py
@@ -19,6 +19,7 @@ class MemoryBankConfig:
embedding_dim: int = 768 # Must match encoder output dim
normalize: bool = True # L2-normalize embeddings in memory bank
+ center: bool = False # Mean-center embeddings to remove centroid attractor
@dataclass
@@ -35,7 +36,7 @@ class GeneratorConfig:
"""Configuration for the LLM generator."""
model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
- max_new_tokens: int = 128
+ max_new_tokens: int = 32
temperature: float = 0.0 # Greedy decoding for reproducibility
@@ -48,3 +49,4 @@ class PipelineConfig:
encoder: EncoderConfig = field(default_factory=EncoderConfig)
generator: GeneratorConfig = field(default_factory=GeneratorConfig)
retriever_type: str = "hopfield" # "hopfield" or "faiss"
+ device: str = "cpu" # "cpu", "cuda", "cuda:0", etc.
diff --git a/hag/encoder.py b/hag/encoder.py
index 7e103f3..c380ad1 100644
--- a/hag/encoder.py
+++ b/hag/encoder.py
@@ -17,18 +17,20 @@ class Encoder:
For testing, use FakeEncoder instead.
"""
- def __init__(self, config: EncoderConfig) -> None:
+ def __init__(self, config: EncoderConfig, device: str = "cpu") -> None:
self.config = config
+ self.device = torch.device(device)
self._tokenizer = None
self._model = None
def _load_model(self) -> None:
- """Lazy-load the model and tokenizer."""
+ """Lazy-load the model and tokenizer, placing model on device."""
from transformers import AutoModel, AutoTokenizer
- logger.info("Loading encoder model: %s", self.config.model_name)
+ logger.info("Loading encoder model: %s (device=%s)", self.config.model_name, self.device)
self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
self._model = AutoModel.from_pretrained(self.config.model_name)
+ self._model.to(self.device)
self._model.eval()
@torch.no_grad()
@@ -39,7 +41,7 @@ class Encoder:
texts: single string or list of strings
Returns:
- (1, d) tensor for single input, (N, d) for list input.
+ (1, d) tensor for single input, (N, d) for list input. On self.device.
"""
if self._model is None:
self._load_model()
@@ -54,6 +56,7 @@ class Encoder:
truncation=True,
return_tensors="pt",
)
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
outputs = self._model(**inputs)
# Mean pooling over token embeddings
embeddings = outputs.last_hidden_state.mean(dim=1) # (N, d)
diff --git a/hag/generator.py b/hag/generator.py
index 2142e0c..d0de468 100644
--- a/hag/generator.py
+++ b/hag/generator.py
@@ -3,11 +3,13 @@
import logging
from typing import List
+import torch
+
from hag.config import GeneratorConfig
logger = logging.getLogger(__name__)
-PROMPT_TEMPLATE = """Answer the following question based on the provided context passages.
+PROMPT_TEMPLATE = """Answer the following question based on the provided context passages. Give ONLY the answer itself in a few words, with no explanation.
Context:
{context}
@@ -24,20 +26,22 @@ class Generator:
For testing, use FakeGenerator instead.
"""
- def __init__(self, config: GeneratorConfig) -> None:
+ def __init__(self, config: GeneratorConfig, device: str = "cpu") -> None:
self.config = config
+ self.device = torch.device(device)
self._tokenizer = None
self._model = None
def _load_model(self) -> None:
- """Lazy-load the model and tokenizer."""
+ """Lazy-load the model and tokenizer, placing model on device."""
from transformers import AutoModelForCausalLM, AutoTokenizer
- logger.info("Loading generator model: %s", self.config.model_name)
+ logger.info("Loading generator model: %s (device=%s)", self.config.model_name, self.device)
self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
self._model = AutoModelForCausalLM.from_pretrained(
self.config.model_name,
torch_dtype="auto",
+ device_map=self.device,
)
self._model.eval()
@@ -60,15 +64,23 @@ class Generator:
prompt = PROMPT_TEMPLATE.format(context=context, question=question)
inputs = self._tokenizer(prompt, return_tensors="pt")
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
outputs = self._model.generate(
**inputs,
max_new_tokens=self.config.max_new_tokens,
temperature=self.config.temperature if self.config.temperature > 0 else None,
do_sample=self.config.temperature > 0,
+ repetition_penalty=1.2,
)
# Decode only the generated tokens (skip the prompt)
generated = outputs[0][inputs["input_ids"].shape[1]:]
- return self._tokenizer.decode(generated, skip_special_tokens=True).strip()
+ answer = self._tokenizer.decode(generated, skip_special_tokens=True).strip()
+ # Take only the first sentence/line as the answer
+ for sep in ["\n", ". ", ".\n"]:
+ if sep in answer:
+ answer = answer.split(sep)[0].strip()
+ break
+ return answer
class FakeGenerator:
diff --git a/hag/memory_bank.py b/hag/memory_bank.py
index 42dcc73..0a0a87c 100644
--- a/hag/memory_bank.py
+++ b/hag/memory_bank.py
@@ -16,12 +16,17 @@ class MemoryBank:
The memory bank is M in R^{d x N} where each column is a passage embedding.
Also maintains a mapping from column index to passage text for final retrieval.
+
+ When config.center=True, embeddings are mean-centered to remove the centroid
+ attractor in Hopfield dynamics. The mean is saved so queries can be centered
+ with the same offset via center_query().
"""
def __init__(self, config: MemoryBankConfig) -> None:
self.config = config
self.embeddings: Optional[torch.Tensor] = None # (d, N)
self.passages: List[str] = []
+ self.mean: Optional[torch.Tensor] = None # (d,) — saved for query centering
def build_from_embeddings(
self, embeddings: torch.Tensor, passages: List[str]
@@ -38,10 +43,42 @@ class MemoryBank:
)
if self.config.normalize:
embeddings = F.normalize(embeddings, dim=-1)
+ if self.config.center:
+ self.mean = embeddings.mean(dim=0) # (d,)
+ embeddings = embeddings - self.mean.unsqueeze(0) # (N, d)
+ logger.info("Centered memory bank (removed mean)")
self.embeddings = embeddings.T # Store as (d, N) for efficient matmul
self.passages = list(passages)
logger.info("Built memory bank with %d passages, dim=%d", self.size, self.dim)
+ def center_query(self, query: torch.Tensor) -> torch.Tensor:
+ """Center a query embedding using the saved memory mean.
+
+ Must be called before Hopfield retrieval when config.center=True.
+
+ Args:
+ query: (d,) or (batch, d) — query embedding(s)
+
+ Returns:
+ Centered query tensor, same shape as input.
+ """
+ if self.mean is None:
+ return query
+ return query - self.mean.to(query.device)
+
+ def apply_centering(self) -> None:
+ """Center an already-loaded (uncentered) memory bank in-place.
+
+ Useful when loading a memory bank that was saved without centering.
+ Computes and stores the mean, then subtracts it from embeddings.
+ """
+ if self.embeddings is None:
+ return
+ # embeddings is (d, N), mean over columns
+ self.mean = self.embeddings.mean(dim=1) # (d,)
+ self.embeddings = self.embeddings - self.mean.unsqueeze(1) # (d, N)
+ logger.info("Applied centering to loaded memory bank")
+
def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]:
"""Given top-k indices, return corresponding passage texts.
@@ -67,20 +104,38 @@ class MemoryBank:
"embedding_dim": self.config.embedding_dim,
"normalize": self.config.normalize,
},
+ "mean": self.mean,
}
torch.save(data, path)
logger.info("Saved memory bank to %s", path)
- def load(self, path: str) -> None:
+ def load(self, path: str, device: str = "cpu") -> None:
"""Load memory bank from disk.
Args:
path: file path to load from
+ device: device to load tensors onto ("cpu", "cuda", "cuda:0", etc.)
"""
- data = torch.load(path, weights_only=False)
+ data = torch.load(path, weights_only=False, map_location=device)
self.embeddings = data["embeddings"]
self.passages = data["passages"]
- logger.info("Loaded memory bank from %s (%d passages)", path, self.size)
+ self.mean = data.get("mean", None)
+ logger.info("Loaded memory bank from %s (%d passages, device=%s)", path, self.size, device)
+
+ def to(self, device: str) -> "MemoryBank":
+ """Move memory bank embeddings to the specified device.
+
+ Args:
+ device: target device ("cpu", "cuda", "cuda:0", etc.)
+
+ Returns:
+ self (for chaining).
+ """
+ if self.embeddings is not None:
+ self.embeddings = self.embeddings.to(device)
+ if self.mean is not None:
+ self.mean = self.mean.to(device)
+ return self
@property
def size(self) -> int:
diff --git a/hag/pipeline.py b/hag/pipeline.py
index 1fefb84..086b3be 100644
--- a/hag/pipeline.py
+++ b/hag/pipeline.py
@@ -82,7 +82,8 @@ class RAGPipeline:
if self.retriever_type == "hopfield":
retrieval_result = self.hopfield_retriever.retrieve(query_emb)
else:
- query_np = query_emb.detach().numpy().astype(np.float32)
+ # FAISS requires CPU numpy arrays
+ query_np = query_emb.detach().cpu().numpy().astype(np.float32)
retrieval_result = self.faiss_retriever.retrieve(query_np)
# Generate