summaryrefslogtreecommitdiff
path: root/kg_rag/vectorDB/create_vectordb.py
blob: b02c9ff37cb71f2f221930650d8fc40bbd5104fd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import pickle
from kg_rag.utility import RecursiveCharacterTextSplitter, Chroma, SentenceTransformerEmbeddings, config_data, time


# Settings for building the vector store, all read from kg_rag.utility.config_data.
# Pickled input texts, and the directory handed to Chroma as persist_directory.
DATA_PATH = config_data["VECTOR_DB_DISEASE_ENTITY_PATH"]
VECTOR_DB_NAME = config_data["VECTOR_DB_PATH"]
# Numeric settings are coerced with int() (they may be stored as strings in the config).
CHUNK_SIZE, CHUNK_OVERLAP, BATCH_SIZE = (
    int(config_data[setting_key])
    for setting_key in (
        "VECTOR_DB_CHUNK_SIZE",
        "VECTOR_DB_CHUNK_OVERLAP",
        "VECTOR_DB_BATCH_SIZE",
    )
)
# Model name passed to SentenceTransformerEmbeddings for embedding each chunk.
SENTENCE_EMBEDDING_MODEL = config_data["VECTOR_DB_SENTENCE_EMBEDDING_MODEL"]


def load_data(path=None):
    """Load the pickled texts and build one metadata dict per text.

    Args:
        path: Pickle file to read. When ``None`` (the default), the
            configured ``DATA_PATH`` is used, matching the original behavior.

    Returns:
        A ``(data, metadata_list)`` tuple. ``data`` is the unpickled sequence
        of texts, and ``metadata_list[i]`` is
        ``{"source": data[i] + " from SPOKE knowledge graph"}``.
    """
    if path is None:
        path = DATA_PATH
    # SECURITY: pickle.load can execute arbitrary code while deserializing.
    # Only point this at trusted files produced by this project.
    with open(path, "rb") as f:
        data = pickle.load(f)
    # One metadata entry per text, aligned by index, so create_documents can
    # attach a source label to every chunk derived from that text.
    metadata_list = [{"source": text + " from SPOKE knowledge graph"} for text in data]
    return data, metadata_list

def create_vectordb():
    """Split the loaded texts into chunks and add them to a Chroma vector store.

    Texts and metadata come from ``load_data()``. They are split with
    ``RecursiveCharacterTextSplitter(CHUNK_SIZE, CHUNK_OVERLAP)``, embedded
    with ``SENTENCE_EMBEDDING_MODEL``, and added to a Chroma store whose
    ``persist_directory`` is ``VECTOR_DB_NAME``, ``BATCH_SIZE`` documents per
    ``add_documents`` call. Finally it prints the elapsed wall-clock time in
    minutes.
    """
    started_at = time.time()
    texts, metadatas = load_data()

    splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    chunks = splitter.create_documents(texts, metadatas=metadatas)
    chunk_batches = [
        chunks[offset:offset + BATCH_SIZE]
        for offset in range(0, len(chunks), BATCH_SIZE)
    ]

    embedder = SentenceTransformerEmbeddings(model_name=SENTENCE_EMBEDDING_MODEL)
    store = Chroma(embedding_function=embedder, persist_directory=VECTOR_DB_NAME)
    # Feed documents in fixed-size batches rather than all at once.
    # NOTE(review): older langchain/chromadb (<0.4) needed an explicit
    # store.persist() to flush to disk; confirm the pinned version persists
    # automatically.
    for chunk_batch in chunk_batches:
        store.add_documents(documents=chunk_batch)

    elapsed_minutes = round((time.time() - started_at) / 60, 2)
    print("VectorDB is created in {} mins".format(elapsed_minutes))


if __name__ == "__main__":
    # Build the vector DB only when this file is run as a script, not on import.
    create_vectordb()