author    YurenHao0426 <blackhao0426@gmail.com>    2025-12-17 04:29:37 -0600
committer YurenHao0426 <blackhao0426@gmail.com>    2025-12-17 04:29:37 -0600
commit    e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch)
tree      6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /scripts/recompute_embeddings.py
Initial commit (clean history) (HEAD -> main)
Diffstat (limited to 'scripts/recompute_embeddings.py')
-rw-r--r--    scripts/recompute_embeddings.py    65
1 file changed, 65 insertions(+), 0 deletions(-)
diff --git a/scripts/recompute_embeddings.py b/scripts/recompute_embeddings.py
new file mode 100644
index 0000000..884cc7b
--- /dev/null
+++ b/scripts/recompute_embeddings.py
@@ -0,0 +1,65 @@
+import os
+import sys
+
+import numpy as np
+from tqdm import tqdm
+
+# Make the in-repo "personalization" package importable without installing it.
+sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
+
+from personalization.config.settings import load_local_models_config
+from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
+from personalization.retrieval.preference_store.schemas import MemoryCard
+
+CARDS_FILE = "data/corpora/memory_cards.jsonl"
+EMBEDDINGS_FILE = "data/corpora/memory_embeddings.npy"
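+# NOTE: both paths are relative to the working directory; the script is
+# presumably run from the repository root.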
+
+def recompute_embeddings():
+    if not os.path.exists(CARDS_FILE):
+        print(f"Error: {CARDS_FILE} not found.", file=sys.stderr)
+        sys.exit(1)
+
+ print("Loading configuration and model...")
+ cfg = load_local_models_config()
+ embed_model = Qwen3Embedding8B.from_config(cfg)
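+    # from_config is assumed to handle weight loading and device placement;
+    # this script never touches torch directly.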
+
+ print(f"Reading memory cards from {CARDS_FILE}...")
+ cards = []
+ texts = []
+ with open(CARDS_FILE, "r", encoding="utf-8") as f:
+ for line in f:
+            if not line.strip():
+                continue
+ card = MemoryCard.model_validate_json(line)
+ cards.append(card)
+            # Embed note_text (not raw_query): retrieval matches incoming
+            # queries against the stored preference note.
+ texts.append(card.note_text)
+
+ print(f"Total cards: {len(cards)}")
+
+ if not cards:
+ print("No cards found.")
+ return
+
+ print("Computing embeddings...")
+    # Encode in fixed-size batches to keep peak memory bounded.
+ batch_size = 32
+ all_embs = []
+
+ for i in tqdm(range(0, len(texts), batch_size)):
+ batch_texts = texts[i : i + batch_size]
+ # Qwen3Embedding8B.encode returns list of lists (if return_tensor=False)
+ embs = embed_model.encode(batch_texts, return_tensor=False)
+ all_embs.extend(embs)
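+        # NOTE: vectors are stored exactly as returned; whether they are
+        # L2-normalized depends on Qwen3Embedding8B.encode, which is not
+        # checked here.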
+
+ emb_array = np.array(all_embs, dtype=np.float32)
+ print(f"Embeddings shape: {emb_array.shape}")
+
+ print(f"Saving to {EMBEDDINGS_FILE}...")
+ np.save(EMBEDDINGS_FILE, emb_array)
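+    # Row i of the saved array corresponds to the i-th non-empty line of
+    # CARDS_FILE; regenerate both files together so they stay aligned.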
+ print("Done!")
+
+if __name__ == "__main__":
+ recompute_embeddings()
+
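To run (from the repository root, so the relative data/ paths resolve):

    python scripts/recompute_embeddings.py

For reference, a minimal sketch of how the saved matrix might be consumed downstream. Only the two file paths come from the script above; the cosine-similarity convention and the top_k helper are illustrative assumptions, not part of the repository:

    import json
    import numpy as np

    CARDS_FILE = "data/corpora/memory_cards.jsonl"
    EMBEDDINGS_FILE = "data/corpora/memory_embeddings.npy"

    def top_k(query_emb, k=5):
        """Return (index, score) pairs for the k most similar cards.

        Normalizes defensively, since the script above does not guarantee
        unit-length vectors.
        """
        embs = np.load(EMBEDDINGS_FILE)                  # shape (n_cards, dim)
        embs = embs / np.linalg.norm(embs, axis=1, keepdims=True)
        q = np.asarray(query_emb, dtype=np.float32)
        q = q / np.linalg.norm(q)
        scores = embs @ q                                # cosine similarities
        top = np.argsort(scores)[::-1][:k]
        return [(int(i), float(scores[i])) for i in top]

    # Rows align with the non-empty lines of CARDS_FILE, so the matching card
    # can be recovered by index:
    with open(CARDS_FILE, encoding="utf-8") as f:
        cards = [json.loads(line) for line in f if line.strip()]

The query embedding itself would come from the same Qwen3Embedding8B model, so that query and card vectors live in the same space.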