From ccff87c15263d1d63235643d54322b991366952e Mon Sep 17 00:00:00 2001 From: LuyaoZhuang Date: Sun, 26 Oct 2025 05:09:57 -0400 Subject: commit --- src/embedding_store.py | 80 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 src/embedding_store.py (limited to 'src/embedding_store.py') diff --git a/src/embedding_store.py b/src/embedding_store.py new file mode 100644 index 0000000..2aff57e --- /dev/null +++ b/src/embedding_store.py @@ -0,0 +1,80 @@ +from copy import deepcopy +from src.utils import compute_mdhash_id +import numpy as np +import pandas as pd +import os + +class EmbeddingStore: + def __init__(self, embedding_model, db_filename, batch_size, namespace): + self.embedding_model = embedding_model + self.db_filename = db_filename + self.batch_size = batch_size + self.namespace = namespace + + self.hash_ids = [] + self.texts = [] + self.embeddings = [] + self.hash_id_to_text = {} + self.hash_id_to_idx = {} + self.text_to_hash_id = {} + + self._load_data() + + def _load_data(self): + if os.path.exists(self.db_filename): + df = pd.read_parquet(self.db_filename) + self.hash_ids = df["hash_id"].values.tolist() + self.texts = df["text"].values.tolist() + self.embeddings = df["embedding"].values.tolist() + + self.hash_id_to_idx = {h: idx for idx, h in enumerate(self.hash_ids)} + self.hash_id_to_text = {h: t for h, t in zip(self.hash_ids, self.texts)} + self.text_to_hash_id = {t: h for t, h in zip(self.texts, self.hash_ids)} + print(f"[{self.namespace}] Loaded {len(self.hash_ids)} records from {self.db_filename}") + + def insert_text(self, text_list): + nodes_dict = {} + for text in text_list: + nodes_dict[compute_mdhash_id(text, prefix=self.namespace + "-")] = {'content': text} + + all_hash_ids = list(nodes_dict.keys()) + + existing = set(self.hash_ids) + missing_ids = [h for h in all_hash_ids if h not in existing] + texts_to_encode = [nodes_dict[hash_id]["content"] for hash_id in missing_ids] + all_embeddings = self.embedding_model.encode(texts_to_encode,normalize_embeddings=True, show_progress_bar=False,batch_size=self.batch_size) + + self._upsert(missing_ids, texts_to_encode, all_embeddings) + + def _upsert(self, hash_ids, texts, embeddings): + self.hash_ids.extend(hash_ids) + self.texts.extend(texts) + self.embeddings.extend(embeddings) + + self.hash_id_to_idx = {h: idx for idx, h in enumerate(self.hash_ids)} + self.hash_id_to_text = {h: t for h, t in zip(self.hash_ids, self.texts)} + self.text_to_hash_id = {t: h for t, h in zip(self.texts, self.hash_ids)} + + self._save_data() + + def _save_data(self): + data_to_save = pd.DataFrame({ + "hash_id": self.hash_ids, + "text": self.texts, + "embedding": self.embeddings + }) + os.makedirs(os.path.dirname(self.db_filename), exist_ok=True) + data_to_save.to_parquet(self.db_filename, index=False) + + def get_hash_id_to_text(self): + return deepcopy(self.hash_id_to_text) + + def encode_texts(self, texts): + return self.embedding_model.encode(texts, normalize_embeddings=True, show_progress_bar=False, batch_size=self.batch_size) + + def get_embeddings(self, hash_ids): + if not hash_ids: + return np.array([]) + indices = np.array([self.hash_id_to_idx[h] for h in hash_ids], dtype=np.intp) + embeddings = np.array(self.embeddings)[indices] + return embeddings \ No newline at end of file -- cgit v1.2.3