summaryrefslogtreecommitdiff
path: root/scripts/recompute_embeddings.py
blob: 884cc7b485404ca1a1f76bf1fac63aba8b4fa970 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import json
import os
import sys
import numpy as np
import torch
from tqdm import tqdm

# Add src to sys.path
sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))

from personalization.config.settings import load_local_models_config
from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
from personalization.retrieval.preference_store.schemas import MemoryCard

# Input: JSONL file with one serialized MemoryCard per non-blank line.
# NOTE: relative path, resolved against the current working directory (not
# this script's location), so run the script from the directory containing
# `data/`.
CARDS_FILE: str = "data/corpora/memory_cards.jsonl"
# Output: float32 matrix written with np.save; row i is the embedding of the
# i-th card read from CARDS_FILE. Also cwd-relative.
EMBEDDINGS_FILE: str = "data/corpora/memory_embeddings.npy"

def _load_cards(cards_file):
    """Parse a JSONL file of memory cards, skipping whitespace-only lines.

    Returns:
        A tuple ``(cards, texts)`` in file order, where ``texts[i]`` is
        ``cards[i].note_text`` -- the text that gets embedded for card ``i``.
    """
    cards = []
    texts = []
    with open(cards_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            card = MemoryCard.model_validate_json(line)
            cards.append(card)
            # The preference note (note_text), not the raw query, is what is
            # embedded and used as the retrieval key.
            texts.append(card.note_text)
    return cards, texts


def recompute_embeddings(cards_file=None, embeddings_file=None, batch_size=32):
    """Re-embed every memory card and overwrite the embeddings matrix.

    Reads ``cards_file`` (JSONL of ``MemoryCard``), encodes each card's
    ``note_text`` with ``Qwen3Embedding8B`` in batches, and saves a float32
    array whose row ``i`` corresponds to the ``i``-th non-blank line of the
    cards file. If the cards file is missing or contains no cards, nothing is
    written and the model is never loaded.

    Args:
        cards_file: Path to the cards JSONL. Defaults to ``CARDS_FILE``.
        embeddings_file: Output path for the saved array. It is used verbatim
            (no ``.npy`` suffix is appended). Defaults to ``EMBEDDINGS_FILE``.
        batch_size: Number of texts passed to ``encode`` per call.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
        RuntimeError: If the model returns a different number of embeddings
            than there are cards, which would misalign rows with cards.
    """
    # Resolve defaults at call time so the signature does not depend on the
    # module constants at definition time.
    if cards_file is None:
        cards_file = CARDS_FILE
    if embeddings_file is None:
        embeddings_file = EMBEDDINGS_FILE
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    if not os.path.exists(cards_file):
        print(f"Error: {cards_file} not found.", file=sys.stderr)
        return

    # Read the cards before loading the model: the embedding model is
    # presumably costly to load and is useless when there is nothing to encode.
    print(f"Reading memory cards from {cards_file}...")
    cards, texts = _load_cards(cards_file)
    print(f"Total cards: {len(cards)}")

    if not cards:
        print("No cards found.")
        return

    print("Loading configuration and model...")
    cfg = load_local_models_config()
    embed_model = Qwen3Embedding8B.from_config(cfg)

    print("Computing embeddings...")
    all_embs = []
    for start in tqdm(range(0, len(texts), batch_size)):
        batch_texts = texts[start : start + batch_size]
        # With return_tensor=False, encode is expected to yield one vector per
        # input text (presumably a list of lists) -- confirm in the model class.
        all_embs.extend(embed_model.encode(batch_texts, return_tensor=False))

    emb_array = np.array(all_embs, dtype=np.float32)
    print(f"Embeddings shape: {emb_array.shape}")

    # Rows are matched to cards purely by position, so a count mismatch would
    # silently attach the wrong embedding to every later card.
    if emb_array.shape[0] != len(cards):
        raise RuntimeError(
            f"Got {emb_array.shape[0]} embeddings for {len(cards)} cards; "
            "refusing to save a misaligned matrix."
        )

    print(f"Saving to {embeddings_file}...")
    # Write to a sibling temp file and atomically swap it in, so an
    # interrupted run never replaces the previous matrix with a truncated one.
    tmp_path = f"{embeddings_file}.tmp"
    with open(tmp_path, "wb") as f:
        np.save(f, emb_array)
    os.replace(tmp_path, embeddings_file)
    print("Done!")

if __name__ == "__main__":
    # Script entry point. The data paths are cwd-relative, so this is
    # presumably meant to be run from the repository root -- confirm.
    recompute_embeddings()