summary | refs | log | tree | commit | diff
path: root/kg_rag/vectorDB
diff options
context:
space:
mode:
author maszhongming <mingz5@illinois.edu> 2025-09-16 15:15:29 -0500
committer maszhongming <mingz5@illinois.edu> 2025-09-16 15:15:29 -0500
commit 73c194f304f827b55081b15524479f82a1b7d94c (patch)
tree 5e8660e421915420892c5eca18f1ad680f80a861 /kg_rag/vectorDB
Initial commit
Diffstat (limited to 'kg_rag/vectorDB')
-rw-r--r-- kg_rag/vectorDB/__init__.py 0
-rw-r--r-- kg_rag/vectorDB/create_vectordb.py 35
2 files changed, 35 insertions, 0 deletions
diff --git a/kg_rag/vectorDB/__init__.py b/kg_rag/vectorDB/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/kg_rag/vectorDB/__init__.py
diff --git a/kg_rag/vectorDB/create_vectordb.py b/kg_rag/vectorDB/create_vectordb.py
new file mode 100644
index 0000000..b02c9ff
--- /dev/null
+++ b/kg_rag/vectorDB/create_vectordb.py
@@ -0,0 +1,35 @@
+import pickle
+from kg_rag.utility import RecursiveCharacterTextSplitter, Chroma, SentenceTransformerEmbeddings, config_data, time
+
+
+# Settings read once, at import time, from config_data (imported from
+# kg_rag.utility). Numeric settings are coerced with int().
+# NOTE(review): presumably config_data is a dict-like loaded from the project
+# config file, so a missing key would fail at import -- confirm in kg_rag.utility.
+# Path of the pickle file opened by load_data().
+DATA_PATH = config_data["VECTOR_DB_DISEASE_ENTITY_PATH"]
+# Passed to Chroma as persist_directory in create_vectordb().
+VECTOR_DB_NAME = config_data["VECTOR_DB_PATH"]
+# chunk_size / chunk_overlap handed to RecursiveCharacterTextSplitter.
+CHUNK_SIZE = int(config_data["VECTOR_DB_CHUNK_SIZE"])
+CHUNK_OVERLAP = int(config_data["VECTOR_DB_CHUNK_OVERLAP"])
+# Number of documents passed to each vectorstore.add_documents() call.
+BATCH_SIZE = int(config_data["VECTOR_DB_BATCH_SIZE"])
+# model_name given to SentenceTransformerEmbeddings.
+SENTENCE_EMBEDDING_MODEL = config_data["VECTOR_DB_SENTENCE_EMBEDDING_MODEL"]
+
+
+def load_data():
+    """Load the pickled entries at DATA_PATH and build their metadata.
+
+    Returns:
+        A ``(data, metadata_list)`` tuple. ``data`` is the unpickled object;
+        ``metadata_list`` holds one ``{"source": ...}`` dict per entry of
+        ``data``, in the same order.
+    """
+    # NOTE(review): pickle.load can execute code embedded in the file; this
+    # presumably reads a trusted, locally generated pickle -- confirm.
+    with open(DATA_PATH, "rb") as pickle_file:
+        data = pickle.load(pickle_file)
+    # Each entry is concatenated with a str suffix, so entries are presumably
+    # plain strings (entity names) -- verify against the pickle's producer.
+    metadata_list = [
+        {"source": entry + " from SPOKE knowledge graph"} for entry in data
+    ]
+    return data, metadata_list
+
+def create_vectordb():
+    """Build the Chroma vector store from the data returned by load_data().
+
+    The loaded texts are split into documents using CHUNK_SIZE and
+    CHUNK_OVERLAP, then added to a Chroma store (persist_directory set to
+    VECTOR_DB_NAME) in slices of BATCH_SIZE documents. The elapsed time, in
+    minutes rounded to two decimals, is printed when finished.
+    """
+    started_at = time.time()
+    texts, metadatas = load_data()
+    splitter = RecursiveCharacterTextSplitter(
+        chunk_size=CHUNK_SIZE,
+        chunk_overlap=CHUNK_OVERLAP,
+    )
+    documents = splitter.create_documents(texts, metadatas=metadatas)
+    # Start offset of every batch; computed before the store is created,
+    # matching the original order of operations.
+    batch_starts = range(0, len(documents), BATCH_SIZE)
+    embedder = SentenceTransformerEmbeddings(model_name=SENTENCE_EMBEDDING_MODEL)
+    vectorstore = Chroma(
+        embedding_function=embedder,
+        persist_directory=VECTOR_DB_NAME,
+    )
+    # Add one slice per call instead of passing every document at once.
+    for start in batch_starts:
+        vectorstore.add_documents(documents=documents[start:start + BATCH_SIZE])
+    elapsed_minutes = round((time.time() - started_at) / 60, 2)
+    print("VectorDB is created in {} mins".format(elapsed_minutes))
+
+
+# Script entry point: build the vector database only when this file is run
+# directly, not when it is imported.
+if __name__ == "__main__":
+ create_vectordb()
+