diff options
Diffstat (limited to 'scripts/recompute_embeddings.py')
| -rw-r--r-- | scripts/recompute_embeddings.py | 65 |
1 files changed, 65 insertions, 0 deletions
"""Recompute the embedding matrix for the stored memory cards.

Script ``scripts/recompute_embeddings.py``: reads the memory cards listed in
``CARDS_FILE`` (one JSON document per line), embeds each card's ``note_text``
with the locally configured ``Qwen3Embedding8B`` model, and saves the result
to ``EMBEDDINGS_FILE`` as a float32 NumPy array. Rows are written in the same
order as the cards appear in the input file.
"""

import json
import os
import sys
import numpy as np
import torch
from tqdm import tqdm

# Add src to sys.path so the ``personalization`` package is importable when
# this script is run directly from the repository.
sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))

from personalization.config.settings import load_local_models_config
from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
from personalization.retrieval.preference_store.schemas import MemoryCard

# NOTE(review): both paths are relative to the current working directory,
# not to this file -- presumably the script is run from the repository root;
# confirm with the caller.
CARDS_FILE = "data/corpora/memory_cards.jsonl"
EMBEDDINGS_FILE = "data/corpora/memory_embeddings.npy"


def _load_note_texts(path):
    """Return the ``note_text`` of every card in the JSONL file at *path*.

    Blank lines are skipped. Every non-blank line is validated as a
    ``MemoryCard``; a line that fails validation raises whatever
    ``MemoryCard.model_validate_json`` raises.
    """
    texts = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            card = MemoryCard.model_validate_json(line)
            # Embedding source: the card's note_text (the preference note),
            # not the raw query -- this is the text used for retrieval.
            texts.append(card.note_text)
    return texts


def recompute_embeddings(batch_size=32):
    """Embed every memory card's ``note_text`` and save the matrix to disk.

    Args:
        batch_size: Number of texts passed to the model per ``encode`` call.

    Returns:
        None. Prints progress to stdout; returns early (after printing a
        message) when ``CARDS_FILE`` is missing or contains no cards.
    """
    if not os.path.exists(CARDS_FILE):
        print(f"Error: {CARDS_FILE} not found.")
        return

    # Read and validate the cards before loading the model, so an empty or
    # malformed cards file is reported without paying for the model load.
    print(f"Reading memory cards from {CARDS_FILE}...")
    texts = _load_note_texts(CARDS_FILE)
    print(f"Total cards: {len(texts)}")

    if not texts:
        print("No cards found.")
        return

    print("Loading configuration and model...")
    cfg = load_local_models_config()
    embed_model = Qwen3Embedding8B.from_config(cfg)

    print("Computing embeddings...")
    batches = []
    for start in tqdm(range(0, len(texts), batch_size)):
        batch_texts = texts[start : start + batch_size]
        # Per the original author's note, encode(..., return_tensor=False)
        # returns a list of lists. Convert each batch to a compact float32
        # array right away instead of accumulating nested Python floats.
        embs = embed_model.encode(batch_texts, return_tensor=False)
        batches.append(np.asarray(embs, dtype=np.float32))

    emb_array = np.concatenate(batches, axis=0)
    print(f"Embeddings shape: {emb_array.shape}")

    print(f"Saving to {EMBEDDINGS_FILE}...")
    np.save(EMBEDDINGS_FILE, emb_array)
    print("Done!")


if __name__ == "__main__":
    recompute_embeddings()
