from copy import deepcopy from src.utils import compute_mdhash_id import numpy as np import pandas as pd import os class EmbeddingStore: def __init__(self, embedding_model, db_filename, batch_size, namespace): self.embedding_model = embedding_model self.db_filename = db_filename self.batch_size = batch_size self.namespace = namespace self.hash_ids = [] self.texts = [] self.embeddings = [] self.hash_id_to_text = {} self.hash_id_to_idx = {} self.text_to_hash_id = {} self._load_data() def _load_data(self): if os.path.exists(self.db_filename): df = pd.read_parquet(self.db_filename) self.hash_ids = df["hash_id"].values.tolist() self.texts = df["text"].values.tolist() self.embeddings = df["embedding"].values.tolist() self.hash_id_to_idx = {h: idx for idx, h in enumerate(self.hash_ids)} self.hash_id_to_text = {h: t for h, t in zip(self.hash_ids, self.texts)} self.text_to_hash_id = {t: h for t, h in zip(self.texts, self.hash_ids)} print(f"[{self.namespace}] Loaded {len(self.hash_ids)} records from {self.db_filename}") def insert_text(self, text_list): nodes_dict = {} for text in text_list: nodes_dict[compute_mdhash_id(text, prefix=self.namespace + "-")] = {'content': text} all_hash_ids = list(nodes_dict.keys()) existing = set(self.hash_ids) missing_ids = [h for h in all_hash_ids if h not in existing] texts_to_encode = [nodes_dict[hash_id]["content"] for hash_id in missing_ids] all_embeddings = self.embedding_model.encode(texts_to_encode,normalize_embeddings=True, show_progress_bar=False,batch_size=self.batch_size) self._upsert(missing_ids, texts_to_encode, all_embeddings) def _upsert(self, hash_ids, texts, embeddings): self.hash_ids.extend(hash_ids) self.texts.extend(texts) self.embeddings.extend(embeddings) self.hash_id_to_idx = {h: idx for idx, h in enumerate(self.hash_ids)} self.hash_id_to_text = {h: t for h, t in zip(self.hash_ids, self.texts)} self.text_to_hash_id = {t: h for t, h in zip(self.texts, self.hash_ids)} self._save_data() def _save_data(self): data_to_save = pd.DataFrame({ "hash_id": self.hash_ids, "text": self.texts, "embedding": self.embeddings }) os.makedirs(os.path.dirname(self.db_filename), exist_ok=True) data_to_save.to_parquet(self.db_filename, index=False) def get_hash_id_to_text(self): return deepcopy(self.hash_id_to_text) def encode_texts(self, texts): return self.embedding_model.encode(texts, normalize_embeddings=True, show_progress_bar=False, batch_size=self.batch_size) def get_embeddings(self, hash_ids): if not hash_ids: return np.array([]) indices = np.array([self.hash_id_to_idx[h] for h in hash_ids], dtype=np.intp) embeddings = np.array(self.embeddings)[indices] return embeddings