Diffstat (limited to 'kg_rag')
-rw-r--r--  kg_rag/.DS_Store                                                              bin  0 -> 6148 bytes
-rw-r--r--  kg_rag/__init__.py                                                              0
-rw-r--r--  kg_rag/config_loader.py                                                        17
-rw-r--r--  kg_rag/prompt_based_generation/GPT/run_mcq_qa.py                               31
-rw-r--r--  kg_rag/prompt_based_generation/GPT/run_true_false_generation.py               33
-rw-r--r--  kg_rag/prompt_based_generation/GPT/text_generation.py                          33
-rw-r--r--  kg_rag/prompt_based_generation/Llama/run_mcq_qa.py                             38
-rw-r--r--  kg_rag/prompt_based_generation/Llama/run_mcq_qa_medgpt.py                      38
-rw-r--r--  kg_rag/prompt_based_generation/Llama/run_true_false_generation.py              35
-rw-r--r--  kg_rag/prompt_based_generation/Llama/text_generation.py                        36
-rw-r--r--  kg_rag/rag_based_generation/GPT/drug_action.py                                 52
-rw-r--r--  kg_rag/rag_based_generation/GPT/drug_repurposing_v2.py                         68
-rw-r--r--  kg_rag/rag_based_generation/GPT/run_drug_repurposing.py                        57
-rw-r--r--  kg_rag/rag_based_generation/GPT/run_mcq_qa.py                                  91
-rw-r--r--  kg_rag/rag_based_generation/GPT/run_single_disease_entity_hyperparameter_tuning.py  61
-rw-r--r--  kg_rag/rag_based_generation/GPT/run_true_false_generation.py                   52
-rw-r--r--  kg_rag/rag_based_generation/GPT/run_two_disease_entity_hyperparameter_tuning.py  57
-rw-r--r--  kg_rag/rag_based_generation/GPT/text_generation.py                              61
-rw-r--r--  kg_rag/rag_based_generation/Llama/run_drug_repurposing.py                       60
-rw-r--r--  kg_rag/rag_based_generation/Llama/run_mcq_qa.py                                 61
-rw-r--r--  kg_rag/rag_based_generation/Llama/run_mcq_qa_medgpt.py                          61
-rw-r--r--  kg_rag/rag_based_generation/Llama/run_true_false_generation.py                  59
-rw-r--r--  kg_rag/rag_based_generation/Llama/text_generation.py                            60
-rw-r--r--  kg_rag/run_setup.py                                                             77
-rw-r--r--  kg_rag/test/__init__.py                                                          0
-rw-r--r--  kg_rag/test/test_vectordb.py                                                    42
-rw-r--r--  kg_rag/utility.py                                                              443
-rw-r--r--  kg_rag/vectorDB/__init__.py                                                      0
-rw-r--r--  kg_rag/vectorDB/create_vectordb.py                                              35
29 files changed, 1658 insertions, 0 deletions
diff --git a/kg_rag/.DS_Store b/kg_rag/.DS_Store
new file mode 100644
index 0000000..9f93d7f
--- /dev/null
+++ b/kg_rag/.DS_Store
Binary files differ
diff --git a/kg_rag/__init__.py b/kg_rag/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/kg_rag/__init__.py
diff --git a/kg_rag/config_loader.py b/kg_rag/config_loader.py
new file mode 100644
index 0000000..f3f5bd4
--- /dev/null
+++ b/kg_rag/config_loader.py
@@ -0,0 +1,17 @@
+import yaml
+import os
+
+with open('config.yaml', 'r') as f:
+ config_data = yaml.safe_load(f)
+
+with open('system_prompts.yaml', 'r') as f:
+ system_prompts = yaml.safe_load(f)
+
+#if 'GPT_CONFIG_FILE' in config_data:
+# config_data['GPT_CONFIG_FILE'] = config_data['GPT_CONFIG_FILE'].replace('$HOME', os.environ['HOME'])
+
+
+__all__ = [
+ 'config_data',
+ 'system_prompts'
+]
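+
+# Typical usage elsewhere in the package (kg_rag/utility.py does `from kg_rag.config_loader import *`,
+# which pulls in the two names listed in __all__ above). A minimal illustrative snippet:
+#
+#     from kg_rag.config_loader import config_data, system_prompts
+#     save_path = config_data["SAVE_RESULTS_PATH"]
+#     mcq_prompt = system_prompts["MCQ_QUESTION"]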
diff --git a/kg_rag/prompt_based_generation/GPT/run_mcq_qa.py b/kg_rag/prompt_based_generation/GPT/run_mcq_qa.py
new file mode 100644
index 0000000..762242e
--- /dev/null
+++ b/kg_rag/prompt_based_generation/GPT/run_mcq_qa.py
@@ -0,0 +1,31 @@
+from kg_rag.utility import *
+import sys
+from tqdm import tqdm
+
+CHAT_MODEL_ID = sys.argv[1]
+
+QUESTION_PATH = config_data["MCQ_PATH"]
+SYSTEM_PROMPT = system_prompts["MCQ_QUESTION_PROMPT_BASED"]
+SAVE_PATH = config_data["SAVE_RESULTS_PATH"]
+TEMPERATURE = config_data["LLM_TEMPERATURE"]
+
+CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID
+
+save_name = "_".join(CHAT_MODEL_ID.split("-"))+"_prompt_based_response_for_two_hop_mcq_from_monarch_and_robokop.csv"
+
+
+def main():
+ start_time = time.time()
+ question_df = pd.read_csv(QUESTION_PATH)
+ answer_list = []
+ for index, row in tqdm(question_df.head(50).iterrows(), total=50):
+ question = "Question: "+ row["text"]
+ output = get_GPT_response(question, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE)
+ answer_list.append((row["text"], row["correct_node"], output))
+ answer_df = pd.DataFrame(answer_list, columns=["question", "correct_answer", "llm_answer"])
+ answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True)
+ print("Completed in {} min".format((time.time()-start_time)/60))
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/kg_rag/prompt_based_generation/GPT/run_true_false_generation.py b/kg_rag/prompt_based_generation/GPT/run_true_false_generation.py
new file mode 100644
index 0000000..0d248db
--- /dev/null
+++ b/kg_rag/prompt_based_generation/GPT/run_true_false_generation.py
@@ -0,0 +1,33 @@
+from kg_rag.utility import *
+import sys
+
+
+CHAT_MODEL_ID = sys.argv[1]
+
+QUESTION_PATH = config_data["TRUE_FALSE_PATH"]
+SYSTEM_PROMPT = system_prompts["TRUE_FALSE_QUESTION_PROMPT_BASED"]
+SAVE_PATH = config_data["SAVE_RESULTS_PATH"]
+TEMPERATURE = config_data["LLM_TEMPERATURE"]
+
+CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID
+
+save_name = "_".join(CHAT_MODEL_ID.split("-"))+"_prompt_based_one_hop_true_false_binary_response.csv"
+
+
+def main():
+ start_time = time.time()
+ question_df = pd.read_csv(QUESTION_PATH)
+ answer_list = []
+ for index, row in question_df.iterrows():
+ question = "Question: "+ row["text"]
+ output = get_GPT_response(question, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE)
+ answer_list.append((row["text"], row["label"], output))
+ answer_df = pd.DataFrame(answer_list, columns=["question", "label", "llm_answer"])
+ answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True)
+ print("Completed in {} min".format((time.time()-start_time)/60))
+
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/kg_rag/prompt_based_generation/GPT/text_generation.py b/kg_rag/prompt_based_generation/GPT/text_generation.py
new file mode 100644
index 0000000..235ece7
--- /dev/null
+++ b/kg_rag/prompt_based_generation/GPT/text_generation.py
@@ -0,0 +1,33 @@
+from kg_rag.utility import *
+import argparse
+
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('-g', type=str, default='gpt-35-turbo', help='GPT model selection')
+args = parser.parse_args()
+
+CHAT_MODEL_ID = args.g
+
+SYSTEM_PROMPT = system_prompts["PROMPT_BASED_TEXT_GENERATION"]
+TEMPERATURE = config_data["LLM_TEMPERATURE"]
+
+CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID
+
+
+def main():
+ print(" ")
+ question = input("Enter your question : ")
+ print("Here is the prompt-based answer:")
+ print("")
+ output = get_GPT_response(question, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE)
+ stream_out(output)
+
+
+
+
+
+if __name__ == "__main__":
+ main()
+
+
diff --git a/kg_rag/prompt_based_generation/Llama/run_mcq_qa.py b/kg_rag/prompt_based_generation/Llama/run_mcq_qa.py
new file mode 100644
index 0000000..e3eb1b0
--- /dev/null
+++ b/kg_rag/prompt_based_generation/Llama/run_mcq_qa.py
@@ -0,0 +1,38 @@
+from langchain import PromptTemplate, LLMChain
+from kg_rag.utility import *
+
+
+QUESTION_PATH = config_data["MCQ_PATH"]
+SYSTEM_PROMPT = system_prompts["MCQ_QUESTION_PROMPT_BASED"]
+SAVE_PATH = config_data["SAVE_RESULTS_PATH"]
+MODEL_NAME = config_data["LLAMA_MODEL_NAME"]
+BRANCH_NAME = config_data["LLAMA_MODEL_BRANCH"]
+CACHE_DIR = config_data["LLM_CACHE_DIR"]
+
+
+
+save_name = "_".join(MODEL_NAME.split("/")[-1].split("-"))+"_prompt_based_two_hop_mcq_from_monarch_and_robokop_response.csv"
+
+INSTRUCTION = "Question: {question}"
+
+
+def main():
+ start_time = time.time()
+ llm = llama_model(MODEL_NAME, BRANCH_NAME, CACHE_DIR)
+ template = get_prompt(INSTRUCTION, SYSTEM_PROMPT)
+ prompt = PromptTemplate(template=template, input_variables=["question"])
+ llm_chain = LLMChain(prompt=prompt, llm=llm)
+ question_df = pd.read_csv(QUESTION_PATH)
+ answer_list = []
+ for index, row in question_df.iterrows():
+ question = row["text"]
+ output = llm_chain.run(question)
+ answer_list.append((row["text"], row["correct_node"], output))
+ answer_df = pd.DataFrame(answer_list, columns=["question", "correct_answer", "llm_answer"])
+ answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True)
+ print("Completed in {} min".format((time.time()-start_time)/60))
+
+
+
+if __name__ == "__main__":
+ main()
diff --git a/kg_rag/prompt_based_generation/Llama/run_mcq_qa_medgpt.py b/kg_rag/prompt_based_generation/Llama/run_mcq_qa_medgpt.py
new file mode 100644
index 0000000..91a23cb
--- /dev/null
+++ b/kg_rag/prompt_based_generation/Llama/run_mcq_qa_medgpt.py
@@ -0,0 +1,38 @@
+from langchain import PromptTemplate, LLMChain
+from kg_rag.utility import *
+
+
+QUESTION_PATH = config_data["MCQ_PATH"]
+SYSTEM_PROMPT = system_prompts["MCQ_QUESTION_PROMPT_BASED"]
+SAVE_PATH = config_data["SAVE_RESULTS_PATH"]
+MODEL_NAME = 'PharMolix/BioMedGPT-LM-7B'
+BRANCH_NAME = 'main'
+CACHE_DIR = config_data["LLM_CACHE_DIR"]
+
+
+
+save_name = "_".join(MODEL_NAME.split("/")[-1].split("-"))+"_prompt_based_mcq_from_monarch_and_robokop_response.csv"
+
+INSTRUCTION = "Question: {question}"
+
+
+def main():
+ start_time = time.time()
+ llm = llama_model(MODEL_NAME, BRANCH_NAME, CACHE_DIR)
+ template = get_prompt(INSTRUCTION, SYSTEM_PROMPT)
+ prompt = PromptTemplate(template=template, input_variables=["question"])
+ llm_chain = LLMChain(prompt=prompt, llm=llm)
+ question_df = pd.read_csv(QUESTION_PATH)
+ answer_list = []
+ for index, row in question_df.iterrows():
+ question = row["text"]
+ output = llm_chain.run(question)
+ answer_list.append((row["text"], row["correct_node"], output))
+ answer_df = pd.DataFrame(answer_list, columns=["question", "correct_answer", "llm_answer"])
+ answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True)
+ print("Completed in {} min".format((time.time()-start_time)/60))
+
+
+
+if __name__ == "__main__":
+ main()
diff --git a/kg_rag/prompt_based_generation/Llama/run_true_false_generation.py b/kg_rag/prompt_based_generation/Llama/run_true_false_generation.py
new file mode 100644
index 0000000..81a98a1
--- /dev/null
+++ b/kg_rag/prompt_based_generation/Llama/run_true_false_generation.py
@@ -0,0 +1,35 @@
+from langchain import PromptTemplate, LLMChain
+from kg_rag.utility import *
+
+QUESTION_PATH = config_data["TRUE_FALSE_PATH"]
+SYSTEM_PROMPT = system_prompts["TRUE_FALSE_QUESTION_PROMPT_BASED"]
+SAVE_PATH = config_data["SAVE_RESULTS_PATH"]
+MODEL_NAME = config_data["LLAMA_MODEL_NAME"]
+BRANCH_NAME = config_data["LLAMA_MODEL_BRANCH"]
+CACHE_DIR = config_data["LLM_CACHE_DIR"]
+
+
+INSTRUCTION = "Question: {question}"
+
+save_name = "_".join(MODEL_NAME.split("/")[-1].split("-"))+"_prompt_based_one_hop_true_false_binary_response.csv"
+
+def main():
+ start_time = time.time()
+ llm = llama_model(MODEL_NAME, BRANCH_NAME, CACHE_DIR)
+ template = get_prompt(INSTRUCTION, SYSTEM_PROMPT)
+ prompt = PromptTemplate(template=template, input_variables=["question"])
+ llm_chain = LLMChain(prompt=prompt, llm=llm)
+ question_df = pd.read_csv(QUESTION_PATH)
+ answer_list = []
+ for index, row in question_df.iterrows():
+ question = row["text"]
+ output = llm_chain.run(question)
+ answer_list.append((row["text"], row["label"], output))
+ answer_df = pd.DataFrame(answer_list, columns=["question", "label", "llm_answer"])
+ answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True)
+ print("Completed in {} min".format((time.time()-start_time)/60))
+
+
+
+if __name__ == "__main__":
+ main()
diff --git a/kg_rag/prompt_based_generation/Llama/text_generation.py b/kg_rag/prompt_based_generation/Llama/text_generation.py
new file mode 100644
index 0000000..49bfebb
--- /dev/null
+++ b/kg_rag/prompt_based_generation/Llama/text_generation.py
@@ -0,0 +1,36 @@
+from langchain import PromptTemplate, LLMChain
+from kg_rag.utility import *
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument('-m', type=str, default='method-1', help='Method to choose for Llama model')
+args = parser.parse_args()
+
+METHOD = args.m
+
+
+SYSTEM_PROMPT = system_prompts["PROMPT_BASED_TEXT_GENERATION"]
+MODEL_NAME = config_data["LLAMA_MODEL_NAME"]
+BRANCH_NAME = config_data["LLAMA_MODEL_BRANCH"]
+CACHE_DIR = config_data["LLM_CACHE_DIR"]
+
+
+
+INSTRUCTION = "Question: {question}"
+
+
+def main():
+ llm = llama_model(MODEL_NAME, BRANCH_NAME, CACHE_DIR, stream=True, method=METHOD)
+ template = get_prompt(INSTRUCTION, SYSTEM_PROMPT)
+ prompt = PromptTemplate(template=template, input_variables=["question"])
+ llm_chain = LLMChain(prompt=prompt, llm=llm)
+ print(" ")
+ question = input("Enter your question : ")
+ print("Here is the prompt-based answer:")
+ print("")
+ output = llm_chain.run(question)
+
+
+
+if __name__ == "__main__":
+ main()
diff --git a/kg_rag/rag_based_generation/GPT/drug_action.py b/kg_rag/rag_based_generation/GPT/drug_action.py
new file mode 100644
index 0000000..60c0acf
--- /dev/null
+++ b/kg_rag/rag_based_generation/GPT/drug_action.py
@@ -0,0 +1,52 @@
+from kg_rag.utility import *
+import argparse
+
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('-g', type=str, default='gpt-35-turbo', help='GPT model selection')
+parser.add_argument('-i', type=bool, default=False, help='Flag for interactive mode')
+parser.add_argument('-e', type=bool, default=False, help='Flag for showing evidence of association from the graph')
+args = parser.parse_args()
+
+CHAT_MODEL_ID = args.g
+INTERACTIVE = args.i
+EDGE_EVIDENCE = bool(args.e)
+
+
+SYSTEM_PROMPT = system_prompts["DRUG_ACTION"]
+CONTEXT_VOLUME = int(config_data["CONTEXT_VOLUME"])
+QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"])
+QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"])
+VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"]
+NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"]
+SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"]
+SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"]
+TEMPERATURE = config_data["LLM_TEMPERATURE"]
+
+
+CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID
+
+vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)
+embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL)
+node_context_df = pd.read_csv(NODE_CONTEXT_PATH)
+
+def main():
+ print(" ")
+ question = input("Enter your question : ")
+ if not INTERACTIVE:
+ print("Retrieving context from SPOKE graph...")
+ context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, EDGE_EVIDENCE)
+ print("Here is the KG-RAG based answer:")
+ print("")
+ enriched_prompt = "Context: "+ context + "\n" + "Question: " + question
+ output = get_GPT_response(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE)
+ stream_out(output)
+ else:
+ interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, CHAT_MODEL_ID, EDGE_EVIDENCE, SYSTEM_PROMPT)
+
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/kg_rag/rag_based_generation/GPT/drug_repurposing_v2.py b/kg_rag/rag_based_generation/GPT/drug_repurposing_v2.py
new file mode 100644
index 0000000..d95053b
--- /dev/null
+++ b/kg_rag/rag_based_generation/GPT/drug_repurposing_v2.py
@@ -0,0 +1,68 @@
+from kg_rag.utility import *
+import argparse
+
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('-g', type=str, default='gpt-35-turbo', help='GPT model selection')
+parser.add_argument('-i', type=bool, default=False, help='Flag for interactive mode')
+parser.add_argument('-e', type=bool, default=False, help='Flag for showing evidence of association from the graph')
+args = parser.parse_args()
+
+
+CHAT_MODEL_ID = args.g
+INTERACTIVE = args.i
+EDGE_EVIDENCE = bool(args.e)
+
+SYSTEM_PROMPT = system_prompts["DRUG_REPURPOSING_V2"]
+CONTEXT_VOLUME = int(config_data["CONTEXT_VOLUME"])
+QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"])
+QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"])
+VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"]
+NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"]
+SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"]
+SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"]
+TEMPERATURE = config_data["LLM_TEMPERATURE"]
+
+
+CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID
+
+vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)
+embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL)
+node_context_df = pd.read_csv(NODE_CONTEXT_PATH)
+
+print('')
+question = input("Question : ")
+
+question_template = f'''
+Answer the question asked at the end by referring to the context.
+See the example below.
+Example 1:
+ Question:
+ What drugs can be repurposed for disease X?
+ Context:
+ Compound Alizapride DOWNREGULATES Gene APOE and Provenance of this association is XX. Gene APOE ASSOCIATES Disease X and Provenance of this association is YY. Gene TTR encodes Protein Transthyretin (ATTR) and Provenance of this association is ZZ. Compound Acetylcysteine treats Disease X and Provenance of this association is PP.
+ Answer:
+ Since Alizapride downregulates gene APOE (Provenance XX) and APOE is associated with Disease X (Provenance YY), Alizapride can be repurposed to treat Disease X. The p-value for these associations is XXXX and the z-score values for these associations are YYYY.
+Question:
+{question}
+'''
+
+def main():
+ if not INTERACTIVE:
+ print("Retrieving context from SPOKE graph...")
+ context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, EDGE_EVIDENCE)
+ print("Here is the KG-RAG based answer:")
+ print("")
+ enriched_prompt = "Context: "+ context + "\n" + "Question: " + question
+ output = get_GPT_response(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE)
+ stream_out(output)
+ else:
+ interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, CHAT_MODEL_ID, EDGE_EVIDENCE, SYSTEM_PROMPT)
+
+
+
+if __name__ == "__main__":
+ main()
+
+
diff --git a/kg_rag/rag_based_generation/GPT/run_drug_repurposing.py b/kg_rag/rag_based_generation/GPT/run_drug_repurposing.py
new file mode 100644
index 0000000..8a5726d
--- /dev/null
+++ b/kg_rag/rag_based_generation/GPT/run_drug_repurposing.py
@@ -0,0 +1,57 @@
+'''
+This script takes the drug repurposing style questions from the csv file and saves the results as another csv file.
+Before running this script, make sure to configure the file paths in the config.yaml file.
+The command line argument should be either 'gpt-4' or 'gpt-35-turbo'.
+'''
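+# Example invocation (illustrative, not part of the original file; run-as-module style as referenced in kg_rag/run_setup.py):
+#     python -m kg_rag.rag_based_generation.GPT.run_drug_repurposing gpt-4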
+
+from kg_rag.utility import *
+import sys
+
+
+
+CHAT_MODEL_ID = sys.argv[1]
+
+QUESTION_PATH = config_data["DRUG_REPURPOSING_PATH"]
+SYSTEM_PROMPT = system_prompts["DRUG_REPURPOSING"]
+CONTEXT_VOLUME = int(config_data["CONTEXT_VOLUME"])
+QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"])
+QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"])
+VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"]
+NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"]
+SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"]
+SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"]
+TEMPERATURE = config_data["LLM_TEMPERATURE"]
+SAVE_PATH = config_data["SAVE_RESULTS_PATH"]
+
+
+CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID
+
+save_name = "_".join(CHAT_MODEL_ID.split("-"))+"_drug_repurposing_questions_response.csv"
+
+
+vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)
+embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL)
+node_context_df = pd.read_csv(NODE_CONTEXT_PATH)
+
+
+def main():
+ start_time = time.time()
+ question_df = pd.read_csv(QUESTION_PATH)
+ answer_list = []
+ for index, row in question_df.iterrows():
+ question = row["text"]
+ context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY)
+ enriched_prompt = "Context: " + context + "\n" + "Question: " + question
+ output = get_GPT_response(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE)
+ answer_list.append((row["disease_in_question"], row["refDisease"], row["compoundGroundTruth"], row["text"], output))
+ answer_df = pd.DataFrame(answer_list, columns=["disease_in_question", "refDisease", "compoundGroundTruth", "text", "llm_answer"])
+ answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True)
+ print("Completed in {} min".format((time.time()-start_time)/60))
+
+
+
+if __name__ == "__main__":
+ main()
+
+
+
diff --git a/kg_rag/rag_based_generation/GPT/run_mcq_qa.py b/kg_rag/rag_based_generation/GPT/run_mcq_qa.py
new file mode 100644
index 0000000..edf0415
--- /dev/null
+++ b/kg_rag/rag_based_generation/GPT/run_mcq_qa.py
@@ -0,0 +1,91 @@
+'''
+This script takes the MCQ style questions from the csv file and saves the results as another csv file.
+Before running this script, make sure to configure the file paths in the config.yaml file.
+The command line argument should be either 'gpt-4' or 'gpt-35-turbo'.
+'''
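+# Example invocation (illustrative, not part of the original file):
+#     python -m kg_rag.rag_based_generation.GPT.run_mcq_qa gpt-4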
+
+from kg_rag.utility import *
+import sys
+
+
+from tqdm import tqdm
+CHAT_MODEL_ID = sys.argv[1]
+
+QUESTION_PATH = config_data["MCQ_PATH"]
+SYSTEM_PROMPT = system_prompts["MCQ_QUESTION"]
+CONTEXT_VOLUME = int(config_data["CONTEXT_VOLUME"])
+QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"])
+QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"])
+VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"]
+NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"]
+SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"]
+SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"]
+TEMPERATURE = config_data["LLM_TEMPERATURE"]
+SAVE_PATH = config_data["SAVE_RESULTS_PATH"]
+
+
+CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID
+
+save_name = "_".join(CHAT_MODEL_ID.split("-"))+"_kg_rag_based_mcq_{mode}.csv"
+
+
+vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)
+embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL)
+node_context_df = pd.read_csv(NODE_CONTEXT_PATH)
+edge_evidence = False
+
+
+MODE = "0"
+### MODE 0: Original KG_RAG ###
+### MODE 1: jsonlize the context from KG search ###
+### MODE 2: Add the prior domain knowledge ###
+### MODE 3: Combine MODE 1 & 2 ###
+
+def main():
+ start_time = time.time()
+ question_df = pd.read_csv(QUESTION_PATH)
+ answer_list = []
+
+ for index, row in tqdm(question_df.iterrows(), total=306):
+ try:
+ question = row["text"]
+ if MODE == "0":
+ ### MODE 0: Original KG_RAG ###
+ context = retrieve_context(row["text"], vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, edge_evidence, model_id=CHAT_MODEL_ID)
+ enriched_prompt = "Context: "+ context + "\n" + "Question: "+ question
+ output = get_Gemini_response(enriched_prompt, SYSTEM_PROMPT, temperature=TEMPERATURE)
+
+ if MODE == "1":
+ ### MODE 1: jsonlize the context from KG search ###
+ ### Please implement the first strategy here ###
+ output = '...'
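+ # A possible sketch for MODE 1 (an assumption, not part of the original code): retrieve the
+ # context as in MODE 0, wrap each context sentence in a JSON record, and prompt with the
+ # JSON-formatted context instead of the raw string, e.g.
+ # context = retrieve_context(row["text"], vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, edge_evidence, model_id=CHAT_MODEL_ID)
+ # facts = [s.strip() for s in context.split(". ") if s.strip()]
+ # enriched_prompt = "Context (JSON): " + json.dumps([{"fact": f} for f in facts]) + "\n" + "Question: " + question
+ # output = get_Gemini_response(enriched_prompt, SYSTEM_PROMPT, temperature=TEMPERATURE)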
+
+ if MODE == "2":
+ ### MODE 2: Add the prior domain knowledge ###
+ ### Please implement the second strategy here ###
+ output = '...'
+
+ if MODE == "3":
+ ### MODE 3: Combine MODE 1 & 2 ###
+ ### Please implement the third strategy here ###
+ output = '...'
+
+ answer_list.append((row["text"], row["correct_node"], output))
+ except Exception as e:
+ print("Error in processing question: ", row["text"])
+ print("Error: ", e)
+ answer_list.append((row["text"], row["correct_node"], "Error"))
+
+
+ answer_df = pd.DataFrame(answer_list, columns=["question", "correct_answer", "llm_answer"])
+ output_file = os.path.join(SAVE_PATH, save_name.format(mode=MODE))
+ answer_df.to_csv(output_file, index=False, header=True)
+ print("Save the model outputs in ", output_file)
+ print("Completed in {} min".format((time.time()-start_time)/60))
+
+
+
+if __name__ == "__main__":
+ main()
+
+
diff --git a/kg_rag/rag_based_generation/GPT/run_single_disease_entity_hyperparameter_tuning.py b/kg_rag/rag_based_generation/GPT/run_single_disease_entity_hyperparameter_tuning.py
new file mode 100644
index 0000000..aaf8071
--- /dev/null
+++ b/kg_rag/rag_based_generation/GPT/run_single_disease_entity_hyperparameter_tuning.py
@@ -0,0 +1,61 @@
+'''
+This script is used for hyperparameter tuning on one-hop graph traversal questions.
+Hyperparameters are 'CONTEXT_VOLUME_LIST' and 'SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL_LIST'.
+
+This will run on the one-hop graph traversal questions from the csv file and save the results as another csv file.
+
+Before running this script, make sure to configure the file paths in the config.yaml file.
+The command line argument should be either 'gpt-4' or 'gpt-35-turbo'.
+'''
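+# Example invocation (illustrative, not part of the original file):
+#     python -m kg_rag.rag_based_generation.GPT.run_single_disease_entity_hyperparameter_tuning gpt-35-turbo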
+
+from kg_rag.utility import *
+import sys
+
+CHAT_MODEL_ID = sys.argv[1]
+
+CONTEXT_VOLUME_LIST = [10, 50, 100, 150, 200]
+SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL_LIST = ["pritamdeka/S-PubMedBert-MS-MARCO", "sentence-transformers/all-MiniLM-L6-v2"]
+SAVE_NAME_LIST = ["pubmedBert_based_one_hop_questions_parameter_tuning_round_{}.csv", "miniLM_based_one_hop_questions_parameter_tuning_round_{}.csv"]
+
+QUESTION_PATH = config_data["SINGLE_DISEASE_ENTITY_FILE"]
+SYSTEM_PROMPT = system_prompts["SINGLE_DISEASE_ENTITY_VALIDATION"]
+QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"])
+QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"])
+VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"]
+NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"]
+SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"]
+TEMPERATURE = config_data["LLM_TEMPERATURE"]
+SAVE_PATH = config_data["SAVE_RESULTS_PATH"]
+
+CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID
+
+vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)
+node_context_df = pd.read_csv(NODE_CONTEXT_PATH)
+edge_evidence = False
+
+def main():
+ start_time = time.time()
+ question_df = pd.read_csv(QUESTION_PATH)
+ for transformer_index, sentence_embedding_model_for_context_retrieval in enumerate(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL_LIST):
+ embedding_function_for_context_retrieval = load_sentence_transformer(sentence_embedding_model_for_context_retrieval)
+ for context_index, context_volume in enumerate(CONTEXT_VOLUME_LIST):
+ answer_list = []
+ for index, row in question_df.iterrows():
+ question = row["text"]
+ context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, context_volume, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, edge_evidence)
+ enriched_prompt = "Context: "+ context + "\n" + "Question: " + question
+ output = get_GPT_response(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE)
+ if not output:
+ time.sleep(5)
+ answer_list.append((row["disease_1"], row["Compounds"], row["Diseases"], row["text"], output, context_volume))
+ answer_df = pd.DataFrame(answer_list, columns=["disease", "compound_groundTruth", "disease_groundTruth", "text", "llm_answer", "context_volume"])
+ save_name = "_".join(CHAT_MODEL_ID.split("-"))+SAVE_NAME_LIST[transformer_index].format(context_index+1)
+ answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True)
+ print("Completed in {} min".format((time.time()-start_time)/60))
+
+
+
+if __name__ == "__main__":
+ main()
+
+
diff --git a/kg_rag/rag_based_generation/GPT/run_true_false_generation.py b/kg_rag/rag_based_generation/GPT/run_true_false_generation.py
new file mode 100644
index 0000000..7b8d0e3
--- /dev/null
+++ b/kg_rag/rag_based_generation/GPT/run_true_false_generation.py
@@ -0,0 +1,52 @@
+'''
+This script takes the True/False style questions from the csv file and saves the results as another csv file.
+Before running this script, make sure to configure the file paths in the config.yaml file.
+The command line argument should be either 'gpt-4' or 'gpt-35-turbo'.
+'''
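+# Example invocation (illustrative, not part of the original file):
+#     python -m kg_rag.rag_based_generation.GPT.run_true_false_generation gpt-4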
+
+from kg_rag.utility import *
+import sys
+
+CHAT_MODEL_ID = sys.argv[1]
+
+QUESTION_PATH = config_data["TRUE_FALSE_PATH"]
+SYSTEM_PROMPT = system_prompts["TRUE_FALSE_QUESTION"]
+QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"])
+QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"])
+VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"]
+NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"]
+SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"]
+SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"]
+TEMPERATURE = config_data["LLM_TEMPERATURE"]
+SAVE_PATH = config_data["SAVE_RESULTS_PATH"]
+CONTEXT_VOLUME = 100
+
+
+CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID
+
+save_name = "_".join(CHAT_MODEL_ID.split("-"))+"_kg_rag_based_true_false_binary_response.csv"
+
+
+vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)
+embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL)
+node_context_df = pd.read_csv(NODE_CONTEXT_PATH)
+edge_evidence = False
+
+def main():
+ start_time = time.time()
+ question_df = pd.read_csv(QUESTION_PATH)
+ answer_list = []
+ for index, row in question_df.iterrows():
+ question = row["text"]
+ context = retrieve_context(row["text"], vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, edge_evidence)
+ enriched_prompt = "Context: "+ context + "\n" + "Question: "+ question
+ output = get_GPT_response(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE)
+ answer_list.append((row["text"], row["label"], output))
+ answer_df = pd.DataFrame(answer_list, columns=["question", "label", "llm_answer"])
+ answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True)
+ print("Completed in {} min".format((time.time()-start_time)/60))
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/kg_rag/rag_based_generation/GPT/run_two_disease_entity_hyperparameter_tuning.py b/kg_rag/rag_based_generation/GPT/run_two_disease_entity_hyperparameter_tuning.py
new file mode 100644
index 0000000..043f39d
--- /dev/null
+++ b/kg_rag/rag_based_generation/GPT/run_two_disease_entity_hyperparameter_tuning.py
@@ -0,0 +1,57 @@
+'''
+This script is used for hyperparameter tuning on two-hop graph traversal questions.
+Hyperparameters are 'CONTEXT_VOLUME_LIST' and 'SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL_LIST'.
+
+This will run on the two-hop graph traversal questions from the csv file and save the results as another csv file.
+
+Before running this script, make sure to configure the file paths in the config.yaml file.
+The command line argument should be either 'gpt-4' or 'gpt-35-turbo'.
+'''
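+# Example invocation (illustrative, not part of the original file):
+#     python -m kg_rag.rag_based_generation.GPT.run_two_disease_entity_hyperparameter_tuning gpt-35-turbo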
+
+from kg_rag.utility import *
+import sys
+
+
+CHAT_MODEL_ID = sys.argv[1]
+
+CONTEXT_VOLUME_LIST = [10, 50, 100, 150, 200]
+SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL_LIST = ["pritamdeka/S-PubMedBert-MS-MARCO", "sentence-transformers/all-MiniLM-L6-v2"]
+SAVE_NAME_LIST = ["pubmedBert_based_two_hop_questions_parameter_tuning_round_{}.csv", "miniLM_based_two_hop_questions_parameter_tuning_round_{}.csv"]
+
+QUESTION_PATH = config_data["TWO_DISEASE_ENTITY_FILE"]
+SYSTEM_PROMPT = system_prompts["TWO_DISEASE_ENTITY_VALIDATION"]
+QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"])
+QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"])
+VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"]
+NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"]
+SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"]
+TEMPERATURE = config_data["LLM_TEMPERATURE"]
+SAVE_PATH = config_data["SAVE_RESULTS_PATH"]
+
+CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID
+
+vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)
+node_context_df = pd.read_csv(NODE_CONTEXT_PATH)
+edge_evidence = False
+
+def main():
+ start_time = time.time()
+ question_df = pd.read_csv(QUESTION_PATH)
+ for transformer_index, sentence_embedding_model_for_context_retrieval in enumerate(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL_LIST):
+ embedding_function_for_context_retrieval = load_sentence_transformer(sentence_embedding_model_for_context_retrieval)
+ for context_index, context_volume in enumerate(CONTEXT_VOLUME_LIST):
+ answer_list = []
+ for index, row in question_df.iterrows():
+ question = row["text"]
+ context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, context_volume, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, edge_evidence)
+ enriched_prompt = "Context: "+ context + "\n" + "Question: " + question
+ output = get_GPT_response(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE)
+ if not output:
+ time.sleep(5)
+ answer_list.append((row["disease_1"], row["disease_2"], row["central_nodes"], row["text"], output, context_volume))
+ answer_df = pd.DataFrame(answer_list, columns=["disease_1", "disease_2", "central_nodes_groundTruth", "text", "llm_answer", "context_volume"])
+ save_name = "_".join(CHAT_MODEL_ID.split("-"))+SAVE_NAME_LIST[transformer_index].format(context_index+1)
+ answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True)
+ print("Completed in {} min".format((time.time()-start_time)/60))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/kg_rag/rag_based_generation/GPT/text_generation.py b/kg_rag/rag_based_generation/GPT/text_generation.py
new file mode 100644
index 0000000..f2fcee1
--- /dev/null
+++ b/kg_rag/rag_based_generation/GPT/text_generation.py
@@ -0,0 +1,61 @@
+'''
+This script takes a question from the user in an interactive fashion and returns the KG-RAG based response in real time.
+Before running this script, make sure to configure the config.yaml file.
+The command line argument should be either 'gpt-4' or 'gpt-35-turbo'.
+'''
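+# Example invocation (illustrative, not part of the original file; -i and -e are the optional flags defined below):
+#     python -m kg_rag.rag_based_generation.GPT.text_generation -g gpt-4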
+
+from kg_rag.utility import *
+import argparse
+
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('-g', type=str, default='gpt-35-turbo', help='GPT model selection')
+parser.add_argument('-i', type=bool, default=False, help='Flag for interactive mode')
+parser.add_argument('-e', type=bool, default=False, help='Flag for showing evidence of association from the graph')
+args = parser.parse_args()
+
+CHAT_MODEL_ID = args.g
+INTERACTIVE = args.i
+EDGE_EVIDENCE = bool(args.e)
+
+
+SYSTEM_PROMPT = system_prompts["KG_RAG_BASED_TEXT_GENERATION"]
+CONTEXT_VOLUME = int(config_data["CONTEXT_VOLUME"])
+QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"])
+QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"])
+VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"]
+NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"]
+SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"]
+SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"]
+TEMPERATURE = config_data["LLM_TEMPERATURE"]
+
+
+CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID if openai.api_type == "azure" else None
+
+
+vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)
+embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL)
+node_context_df = pd.read_csv(NODE_CONTEXT_PATH)
+
+def main():
+ print(" ")
+ question = input("Enter your question : ")
+ if not INTERACTIVE:
+ print("Retrieving context from SPOKE graph...")
+ context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, EDGE_EVIDENCE)
+ print("Here is the KG-RAG based answer:")
+ print("")
+ enriched_prompt = "Context: "+ context + "\n" + "Question: " + question
+ output = get_GPT_response(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE)
+ stream_out(output)
+ else:
+ interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, CHAT_MODEL_ID, EDGE_EVIDENCE, SYSTEM_PROMPT)
+
+
+
+if __name__ == "__main__":
+ main()
+
+
+
diff --git a/kg_rag/rag_based_generation/Llama/run_drug_repurposing.py b/kg_rag/rag_based_generation/Llama/run_drug_repurposing.py
new file mode 100644
index 0000000..0b8d2f0
--- /dev/null
+++ b/kg_rag/rag_based_generation/Llama/run_drug_repurposing.py
@@ -0,0 +1,60 @@
+'''
+This script takes the drug repurposing style questions from the csv file and saves the results as another csv file.
+This script makes use of the Llama model.
+Before running this script, make sure to configure the file paths in the config.yaml file.
+'''
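+# Example invocation (illustrative, not part of the original file; the Llama model settings come from config.yaml):
+#     python -m kg_rag.rag_based_generation.Llama.run_drug_repurposing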
+
+from langchain import PromptTemplate, LLMChain
+from kg_rag.utility import *
+import sys
+
+QUESTION_PATH = config_data["DRUG_REPURPOSING_PATH"]
+SYSTEM_PROMPT = system_prompts["DRUG_REPURPOSING"]
+CONTEXT_VOLUME = int(config_data["CONTEXT_VOLUME"])
+QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"])
+QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"])
+VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"]
+NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"]
+SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"]
+SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"]
+SAVE_PATH = config_data["SAVE_RESULTS_PATH"]
+MODEL_NAME = config_data["LLAMA_MODEL_NAME"]
+BRANCH_NAME = config_data["LLAMA_MODEL_BRANCH"]
+CACHE_DIR = config_data["LLM_CACHE_DIR"]
+
+
+save_name = "_".join(MODEL_NAME.split("/")[-1].split("-"))+"_drug_repurposing_questions_response.csv"
+
+
+INSTRUCTION = "Context:\n\n{context} \n\nQuestion: {question}"
+
+vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)
+embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL)
+node_context_df = pd.read_csv(NODE_CONTEXT_PATH)
+
+
+
+def main():
+ start_time = time.time()
+ llm = llama_model(MODEL_NAME, BRANCH_NAME, CACHE_DIR, max_new_tokens=1024)
+ template = get_prompt(INSTRUCTION, SYSTEM_PROMPT)
+ prompt = PromptTemplate(template=template, input_variables=["context", "question"])
+ llm_chain = LLMChain(prompt=prompt, llm=llm)
+ question_df = pd.read_csv(QUESTION_PATH)
+ answer_list = []
+ for index, row in question_df.iterrows():
+ question = row["text"]
+ context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY)
+ output = llm_chain.run(context=context, question=question)
+ answer_list.append((row["disease_in_question"], row["refDisease"], row["compoundGroundTruth"], row["text"], output))
+ answer_df = pd.DataFrame(answer_list, columns=["disease_in_question", "refDisease", "compoundGroundTruth", "text", "llm_answer"])
+ answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True)
+ print("Completed in {} min".format((time.time()-start_time)/60))
+
+
+
+if __name__ == "__main__":
+ main()
+
+
+
diff --git a/kg_rag/rag_based_generation/Llama/run_mcq_qa.py b/kg_rag/rag_based_generation/Llama/run_mcq_qa.py
new file mode 100644
index 0000000..67ae43c
--- /dev/null
+++ b/kg_rag/rag_based_generation/Llama/run_mcq_qa.py
@@ -0,0 +1,61 @@
+'''
+This script takes the MCQ style questions from the csv file and saves the results as another csv file.
+This script makes use of the Llama model.
+Before running this script, make sure to configure the file paths in the config.yaml file.
+'''
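+# Example invocation (illustrative, not part of the original file):
+#     python -m kg_rag.rag_based_generation.Llama.run_mcq_qa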
+from tqdm import tqdm
+from langchain import PromptTemplate, LLMChain
+from kg_rag.utility import *
+
+
+QUESTION_PATH = config_data["MCQ_PATH"]
+SYSTEM_PROMPT = system_prompts["MCQ_QUESTION"]
+CONTEXT_VOLUME = int(config_data["CONTEXT_VOLUME"])
+QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"])
+QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"])
+VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"]
+NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"]
+SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"]
+SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"]
+SAVE_PATH = config_data["SAVE_RESULTS_PATH"]
+MODEL_NAME = config_data["LLAMA_MODEL_NAME"]
+BRANCH_NAME = config_data["LLAMA_MODEL_BRANCH"]
+CACHE_DIR = config_data["LLM_CACHE_DIR"]
+
+save_name = "_".join(MODEL_NAME.split("/")[-1].split("-"))+"_kg_rag_based_mcq_from_monarch_and_robokop_response.csv"
+
+
+INSTRUCTION = "Context:\n\n{context} \n\nQuestion: {question}"
+
+vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)
+embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL)
+node_context_df = pd.read_csv(NODE_CONTEXT_PATH)
+edge_evidence = False
+
+
+
+def main():
+ start_time = time.time()
+ llm = llama_model(MODEL_NAME, BRANCH_NAME, CACHE_DIR)
+ template = get_prompt(INSTRUCTION, SYSTEM_PROMPT)
+ prompt = PromptTemplate(template=template, input_variables=["context", "question"])
+ llm_chain = LLMChain(prompt=prompt, llm=llm)
+ question_df = pd.read_csv(QUESTION_PATH)
+ answer_list = []
+ for index, row in tqdm(question_df.iterrows()):
+ question = row["text"]
+ context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, edge_evidence)
+ output = llm_chain.run(context=context, question=question)
+ answer_list.append((row["text"], row["correct_node"], output))
+ answer_df = pd.DataFrame(answer_list, columns=["question", "correct_answer", "llm_answer"])
+ answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True)
+ print("Completed in {} min".format((time.time()-start_time)/60))
+
+
+
+
+
+if __name__ == "__main__":
+ main()
+
+
diff --git a/kg_rag/rag_based_generation/Llama/run_mcq_qa_medgpt.py b/kg_rag/rag_based_generation/Llama/run_mcq_qa_medgpt.py
new file mode 100644
index 0000000..813b601
--- /dev/null
+++ b/kg_rag/rag_based_generation/Llama/run_mcq_qa_medgpt.py
@@ -0,0 +1,61 @@
+'''
+This script takes the MCQ style questions from the csv file and saves the results as another csv file.
+This script makes use of the BioMedGPT-LM-7B model (a Llama-based model, hard-coded below).
+Before running this script, make sure to configure the file paths in the config.yaml file.
+'''
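+# Example invocation (illustrative, not part of the original file; the model name is fixed to PharMolix/BioMedGPT-LM-7B below):
+#     python -m kg_rag.rag_based_generation.Llama.run_mcq_qa_medgpt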
+
+from langchain import PromptTemplate, LLMChain
+from kg_rag.utility import *
+
+
+QUESTION_PATH = config_data["MCQ_PATH"]
+SYSTEM_PROMPT = system_prompts["MCQ_QUESTION"]
+CONTEXT_VOLUME = int(config_data["CONTEXT_VOLUME"])
+QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"])
+QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"])
+VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"]
+NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"]
+SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"]
+SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"]
+SAVE_PATH = config_data["SAVE_RESULTS_PATH"]
+MODEL_NAME = 'PharMolix/BioMedGPT-LM-7B'
+BRANCH_NAME = 'main'
+CACHE_DIR = config_data["LLM_CACHE_DIR"]
+
+save_name = "_".join(MODEL_NAME.split("/")[-1].split("-"))+"_kg_rag_based_mcq_from_monarch_and_robokop_response.csv"
+
+
+INSTRUCTION = "Context:\n\n{context} \n\nQuestion: {question}"
+
+vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)
+embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL)
+node_context_df = pd.read_csv(NODE_CONTEXT_PATH)
+edge_evidence = False
+
+
+def main():
+ start_time = time.time()
+ llm = llama_model(MODEL_NAME, BRANCH_NAME, CACHE_DIR)
+ template = get_prompt(INSTRUCTION, SYSTEM_PROMPT)
+ prompt = PromptTemplate(template=template, input_variables=["context", "question"])
+ llm_chain = LLMChain(prompt=prompt, llm=llm)
+ question_df = pd.read_csv(QUESTION_PATH)
+ question_df = question_df.sample(50, random_state=40)
+ answer_list = []
+ for index, row in question_df.iterrows():
+ question = row["text"]
+ context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, edge_evidence)
+ output = llm_chain.run(context=context, question=question)
+ answer_list.append((row["text"], row["correct_node"], output))
+ answer_df = pd.DataFrame(answer_list, columns=["question", "correct_answer", "llm_answer"])
+ answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True)
+ print("Completed in {} min".format((time.time()-start_time)/60))
+
+
+
+
+
+if __name__ == "__main__":
+ main()
+
+
diff --git a/kg_rag/rag_based_generation/Llama/run_true_false_generation.py b/kg_rag/rag_based_generation/Llama/run_true_false_generation.py
new file mode 100644
index 0000000..fa1a37d
--- /dev/null
+++ b/kg_rag/rag_based_generation/Llama/run_true_false_generation.py
@@ -0,0 +1,59 @@
+'''
+This script takes the True/False style questions from the csv file and saves the results as another csv file.
+This script makes use of the Llama model.
+Before running this script, make sure to configure the file paths in the config.yaml file.
+'''
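+# Example invocation (illustrative, not part of the original file):
+#     python -m kg_rag.rag_based_generation.Llama.run_true_false_generation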
+
+from langchain import PromptTemplate, LLMChain
+from kg_rag.utility import *
+import sys
+
+
+QUESTION_PATH = config_data["TRUE_FALSE_PATH"]
+SYSTEM_PROMPT = system_prompts["TRUE_FALSE_QUESTION"]
+QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"])
+QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"])
+VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"]
+NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"]
+SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"]
+SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"]
+SAVE_PATH = config_data["SAVE_RESULTS_PATH"]
+MODEL_NAME = config_data["LLAMA_MODEL_NAME"]
+BRANCH_NAME = config_data["LLAMA_MODEL_BRANCH"]
+CACHE_DIR = config_data["LLM_CACHE_DIR"]
+CONTEXT_VOLUME = 100
+edge_evidence = False
+
+save_name = "_".join(MODEL_NAME.split("/")[-1].split("-"))+"_kg_rag_based_true_false_binary_response.csv"
+
+
+INSTRUCTION = "Context:\n\n{context} \n\nQuestion: {question}"
+
+vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)
+embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL)
+node_context_df = pd.read_csv(NODE_CONTEXT_PATH)
+
+
+def main():
+ start_time = time.time()
+ llm = llama_model(MODEL_NAME, BRANCH_NAME, CACHE_DIR)
+ template = get_prompt(INSTRUCTION, SYSTEM_PROMPT)
+ prompt = PromptTemplate(template=template, input_variables=["context", "question"])
+ llm_chain = LLMChain(prompt=prompt, llm=llm)
+ question_df = pd.read_csv(QUESTION_PATH)
+ answer_list = []
+ for index, row in question_df.iterrows():
+ question = row["text"]
+ context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, edge_evidence)
+ output = llm_chain.run(context=context, question=question)
+ answer_list.append((row["text"], row["label"], output))
+ answer_df = pd.DataFrame(answer_list, columns=["question", "label", "llm_answer"])
+ answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True)
+ print("Completed in {} min".format((time.time()-start_time)/60))
+
+
+
+if __name__ == "__main__":
+ main()
+
+
\ No newline at end of file
diff --git a/kg_rag/rag_based_generation/Llama/text_generation.py b/kg_rag/rag_based_generation/Llama/text_generation.py
new file mode 100644
index 0000000..2824135
--- /dev/null
+++ b/kg_rag/rag_based_generation/Llama/text_generation.py
@@ -0,0 +1,60 @@
+from langchain import PromptTemplate, LLMChain
+from kg_rag.utility import *
+import argparse
+
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('-i', type=bool, default=False, help='Flag for interactive mode')
+parser.add_argument('-m', type=str, default='method-1', help='Method to choose for Llama model')
+parser.add_argument('-e', type=bool, default=False, help='Flag for showing evidence of association from the graph')
+args = parser.parse_args()
+
+INTERACTIVE = args.i
+METHOD = args.m
+EDGE_EVIDENCE = bool(args.e)
+
+
+SYSTEM_PROMPT = system_prompts["KG_RAG_BASED_TEXT_GENERATION"]
+CONTEXT_VOLUME = int(config_data["CONTEXT_VOLUME"])
+QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"])
+QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"])
+VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"]
+NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"]
+SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"]
+SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"]
+MODEL_NAME = config_data["LLAMA_MODEL_NAME"]
+BRANCH_NAME = config_data["LLAMA_MODEL_BRANCH"]
+CACHE_DIR = config_data["LLM_CACHE_DIR"]
+
+
+INSTRUCTION = "Context:\n\n{context} \n\nQuestion: {question}"
+
+vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)
+embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL)
+node_context_df = pd.read_csv(NODE_CONTEXT_PATH)
+
+def main():
+ print(" ")
+ question = input("Enter your question : ")
+ if not INTERACTIVE:
+ template = get_prompt(INSTRUCTION, SYSTEM_PROMPT)
+ prompt = PromptTemplate(template=template, input_variables=["context", "question"])
+ llm = llama_model(MODEL_NAME, BRANCH_NAME, CACHE_DIR, stream=True, method=METHOD)
+ llm_chain = LLMChain(prompt=prompt, llm=llm)
+ print("Retrieving context from SPOKE graph...")
+ context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, EDGE_EVIDENCE)
+ print("Here is the KG-RAG based answer using Llama:")
+ print("")
+ output = llm_chain.run(context=context, question=question)
+ else:
+ interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, "llama", EDGE_EVIDENCE, SYSTEM_PROMPT, llama_method=METHOD)
+
+
+
+
+
+
+
+if __name__ == "__main__":
+ main()
diff --git a/kg_rag/run_setup.py b/kg_rag/run_setup.py
new file mode 100644
index 0000000..04c856c
--- /dev/null
+++ b/kg_rag/run_setup.py
@@ -0,0 +1,77 @@
+import os
+from kg_rag.utility import config_data
+
+def download_llama(method):
+ from kg_rag.utility import llama_model
+ try:
+ llama_model(config_data["LLAMA_MODEL_NAME"], config_data["LLAMA_MODEL_BRANCH"], config_data["LLM_CACHE_DIR"], method=method)
+ print("Model is successfully downloaded to the provided cache directory!")
+ except:
+ print("Model is not downloaded! Make sure the above mentioned conditions are satisfied")
+
+
+print("")
+print("Starting to set up KG-RAG ...")
+print("")
+
+# user_input = input("Did you update the config.yaml file with all necessary configurations (such as GPT .env path, vectorDB file paths, other file paths)? Enter Y or N: ")
+# print("")
+# if user_input == "Y":
+if True:
+ print("Checking disease vectorDB ...")
+ print("The current VECTOR_DB_PATH is ", config_data["VECTOR_DB_PATH"])
+ try:
+ if os.path.exists(config_data["VECTOR_DB_PATH"]):
+ print("vectorDB already exists!")
+ else:
+ print("Creating vectorDB ...")
+ from kg_rag.vectorDB.create_vectordb import create_vectordb
+ create_vectordb()
+ print("Congratulations! The disease database is completed.")
+ except:
+ print("Double check the path that was given in VECTOR_DB_PATH of config.yaml file.")
+ '''
+ print("")
+ user_input_1 = input("Do you want to install Llama model? Enter Y or N: ")
+ if user_input_1 == "Y":
+ user_input_2 = input("Did you update the config.yaml file with proper configuration for downloading Llama model? Enter Y or N: ")
+ if user_input_2 == "Y":
+ user_input_3 = input("Are you using official Llama model from Meta? Enter Y or N: ")
+ if user_input_3 == "Y":
+ user_input_4 = input("Did you get access to use the model? Enter Y or N: ")
+ if user_input_4 == "Y":
+ download_llama()
+ print("Congratulations! Setup is completed.")
+ else:
+ print("Aborting!")
+ else:
+ download_llama(method='method-1')
+ user_input_5 = input("Did you get a message like 'Model is not downloaded!'? Enter Y or N: ")
+ if user_input_5 == "N":
+ print("Congratulations! Setup is completed.")
+ else:
+ download_llama(method='method-2')
+ user_input_6 = input("Did you get a message like 'Model is not downloaded!'? Enter Y or N: ")
+ if user_input_6 == "N":
+ print("""
+ IMPORTANT :
+ Llama model was downloaded using 'LlamaTokenizer' instead of 'AutoTokenizer' method.
+ So, when you run text generation script, please provide an extra command line argument '-m method-2'.
+ For example:
+ python -m kg_rag.rag_based_generation.Llama.text_generation -m method-2
+ """)
+ print("Congratulations! Setup is completed.")
+ else:
+ print("We have now tried two methods to download Llama. If they both do not work, then please check the Llama configuration requirement in the huggingface model card page. Aborting!")
+ else:
+ print("Aborting!")
+ else:
+ print("No problem. Llama will get installed on-the-fly when you run the model for the first time.")
+ print("Congratulations! Setup is completed.")
+ '''
+else:
+ print("As the first step, update config.yaml file and then run this python script again.")
+
+
+
+
diff --git a/kg_rag/test/__init__.py b/kg_rag/test/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/kg_rag/test/__init__.py
diff --git a/kg_rag/test/test_vectordb.py b/kg_rag/test/test_vectordb.py
new file mode 100644
index 0000000..9365971
--- /dev/null
+++ b/kg_rag/test/test_vectordb.py
@@ -0,0 +1,42 @@
+from kg_rag.utility import *
+import sys
+
+VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"]
+SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"]
+
+print("Testing vectorDB loading ...")
+print("")
+try:
+ vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)
+ print("vectorDB is loaded succesfully!")
+except:
+ print("vectorDB is not loaded. Check the path given in 'VECTOR_DB_PATH' of config.yaml")
+ print("")
+ sys.exit(1)
+try:
+ print("")
+ print("Testing entity extraction ...")
+ print("")
+ entity = "psoriasis"
+ print("Inputting '{}' as the entity to test ...".format(entity))
+ print("")
+ node_search_result = vectorstore.similarity_search_with_score(entity, k=1)
+ extracted_entity = node_search_result[0][0].page_content
+ print("Extracted entity is '{}'".format(extracted_entity))
+ print("")
+ if extracted_entity == "psoriasis":
+ print("Entity extraction is successful!")
+ print("")
+ print("vectorDB is correctly populated and is good to go!")
+ else:
+ print("Entity extraction is not successful. Make sure vectorDB is populated correctly. Refer 'How to run KG-RAG' Step 5")
+ print("")
+ sys.exit(1)
+except Exception:
+ print("Entity extraction is not successful. Make sure vectorDB is populated correctly. Refer to 'How to run KG-RAG' Step 5.")
+ print("")
+ sys.exit(1)
+
+
+
+
diff --git a/kg_rag/utility.py b/kg_rag/utility.py
new file mode 100644
index 0000000..975367b
--- /dev/null
+++ b/kg_rag/utility.py
@@ -0,0 +1,443 @@
+import pandas as pd
+import numpy as np
+from sklearn.metrics.pairwise import cosine_similarity
+from joblib import Memory
+import json
+import openai
+import os
+import sys
+from tenacity import retry, stop_after_attempt, wait_random_exponential
+import time
+from dotenv import load_dotenv, find_dotenv
+import torch
+from langchain import HuggingFacePipeline
+from langchain.vectorstores import Chroma
+from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, TextStreamer, GPTQConfig
+from kg_rag.config_loader import *
+import ast
+import requests
+import google.generativeai as genai
+
+
+memory = Memory("cachegpt", verbose=0)
+
+# Config openai library
+config_file = config_data['GPT_CONFIG_FILE']
+load_dotenv(config_file)
+api_key = os.environ.get('API_KEY')
+api_version = os.environ.get('API_VERSION')
+resource_endpoint = os.environ.get('RESOURCE_ENDPOINT')
+openai.api_type = config_data['GPT_API_TYPE']
+openai.api_key = api_key
+if resource_endpoint:
+ openai.api_base = resource_endpoint
+if api_version:
+ openai.api_version = api_version
+
+genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
+
+
+torch.cuda.empty_cache()
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+
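+# Thin wrapper around requests.get for calls to the SPOKE REST API.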
+def get_spoke_api_resp(base_uri, end_point, params=None):
+ uri = base_uri + end_point
+ if params:
+ return requests.get(uri, params=params)
+ else:
+ return requests.get(uri)
+
+
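+# Fetch the neighborhood of a Disease node from the SPOKE API (depth and cutoffs come from config.yaml)
+# and flatten it into a natural-language context string plus a DataFrame of edges with provenance.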
+@retry(wait=wait_random_exponential(min=10, max=30), stop=stop_after_attempt(5))
+def get_context_using_spoke_api(node_value):
+ type_end_point = "/api/v1/types"
+ result = get_spoke_api_resp(config_data['BASE_URI'], type_end_point)
+ data_spoke_types = result.json()
+ node_types = list(data_spoke_types["nodes"].keys())
+ edge_types = list(data_spoke_types["edges"].keys())
+ node_types_to_remove = ["DatabaseTimestamp", "Version"]
+ filtered_node_types = [node_type for node_type in node_types if node_type not in node_types_to_remove]
+ api_params = {
+ 'node_filters' : filtered_node_types,
+ 'edge_filters': edge_types,
+ 'cutoff_Compound_max_phase': config_data['cutoff_Compound_max_phase'],
+ 'cutoff_Protein_source': config_data['cutoff_Protein_source'],
+ 'cutoff_DaG_diseases_sources': config_data['cutoff_DaG_diseases_sources'],
+ 'cutoff_DaG_textmining': config_data['cutoff_DaG_textmining'],
+ 'cutoff_CtD_phase': config_data['cutoff_CtD_phase'],
+ 'cutoff_PiP_confidence': config_data['cutoff_PiP_confidence'],
+ 'cutoff_ACTeG_level': config_data['cutoff_ACTeG_level'],
+ 'cutoff_DpL_average_prevalence': config_data['cutoff_DpL_average_prevalence'],
+ 'depth' : config_data['depth']
+ }
+ node_type = "Disease"
+ attribute = "name"
+ nbr_end_point = "/api/v1/neighborhood/{}/{}/{}".format(node_type, attribute, node_value)
+ result = get_spoke_api_resp(config_data['BASE_URI'], nbr_end_point, params=api_params)
+ node_context = result.json()
+ nbr_nodes = []
+ nbr_edges = []
+ for item in node_context:
+ if "_" not in item["data"]["neo4j_type"]:
+ try:
+ if item["data"]["neo4j_type"] == "Protein":
+ nbr_nodes.append((item["data"]["neo4j_type"], item["data"]["id"], item["data"]["properties"]["description"]))
+ else:
+ nbr_nodes.append((item["data"]["neo4j_type"], item["data"]["id"], item["data"]["properties"]["name"]))
+ except:
+ nbr_nodes.append((item["data"]["neo4j_type"], item["data"]["id"], item["data"]["properties"]["identifier"]))
+ elif "_" in item["data"]["neo4j_type"]:
+ try:
+ provenance = ", ".join(item["data"]["properties"]["sources"])
+ except:
+ try:
+ provenance = item["data"]["properties"]["source"]
+ if isinstance(provenance, list):
+ provenance = ", ".join(provenance)
+ except:
+ try:
+ preprint_list = ast.literal_eval(item["data"]["properties"]["preprint_list"])
+ if len(preprint_list) > 0:
+ provenance = ", ".join(preprint_list)
+ else:
+ pmid_list = ast.literal_eval(item["data"]["properties"]["pmid_list"])
+ pmid_list = map(lambda x:"pubmedId:"+x, pmid_list)
+ if len(pmid_list) > 0:
+ provenance = ", ".join(pmid_list)
+ else:
+ provenance = "Based on data from Institute For Systems Biology (ISB)"
+ except:
+ provenance = "SPOKE-KG"
+ try:
+ evidence = item["data"]["properties"]
+ except:
+ evidence = None
+ nbr_edges.append((item["data"]["source"], item["data"]["neo4j_type"], item["data"]["target"], provenance, evidence))
+ nbr_nodes_df = pd.DataFrame(nbr_nodes, columns=["node_type", "node_id", "node_name"])
+ nbr_edges_df = pd.DataFrame(nbr_edges, columns=["source", "edge_type", "target", "provenance", "evidence"])
+ merge_1 = pd.merge(nbr_edges_df, nbr_nodes_df, left_on="source", right_on="node_id").drop("node_id", axis=1)
+ merge_1.loc[:,"node_name"] = merge_1.node_type + " " + merge_1.node_name
+ merge_1.drop(["source", "node_type"], axis=1, inplace=True)
+ merge_1 = merge_1.rename(columns={"node_name":"source"})
+ merge_2 = pd.merge(merge_1, nbr_nodes_df, left_on="target", right_on="node_id").drop("node_id", axis=1)
+ merge_2.loc[:,"node_name"] = merge_2.node_type + " " + merge_2.node_name
+ merge_2.drop(["target", "node_type"], axis=1, inplace=True)
+ merge_2 = merge_2.rename(columns={"node_name":"target"})
+ merge_2 = merge_2[["source", "edge_type", "target", "provenance", "evidence"]]
+ merge_2.loc[:, "predicate"] = merge_2.edge_type.apply(lambda x:x.split("_")[0])
+ merge_2.loc[:, "context"] = merge_2.source + " " + merge_2.predicate.str.lower() + " " + merge_2.target + " and Provenance of this association is " + merge_2.provenance + "."
+ context = merge_2.context.str.cat(sep=' ')
+ context += node_value + " has a " + node_context[0]["data"]["properties"]["source"] + " identifier of " + node_context[0]["data"]["properties"]["identifier"] + " and Provenance of this is from " + node_context[0]["data"]["properties"]["source"] + "."
+ return context, merge_2
+
+# if edge_evidence:
+# merge_2.loc[:, "context"] = merge_2.source + " " + merge_2.predicate.str.lower() + " " + merge_2.target + " and Provenance of this association is " + merge_2.provenance + " and attributes associated with this association is in the following JSON format:\n " + merge_2.evidence.astype('str') + "\n\n"
+# else:
+# merge_2.loc[:, "context"] = merge_2.source + " " + merge_2.predicate.str.lower() + " " + merge_2.target + " and Provenance of this association is " + merge_2.provenance + ". "
+# context = merge_2.context.str.cat(sep=' ')
+# context += node_value + " has a " + node_context[0]["data"]["properties"]["source"] + " identifier of " + node_context[0]["data"]["properties"]["identifier"] + " and Provenance of this is from " + node_context[0]["data"]["properties"]["source"] + "."
+# return context
+
+
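+# Wrap a user instruction and system prompt in Llama-2 chat markers ([INST] / <<SYS>>).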
+def get_prompt(instruction, new_system_prompt):
+ system_prompt = B_SYS + new_system_prompt + E_SYS
+ prompt_template = B_INST + system_prompt + instruction + E_INST
+ return prompt_template
+
+
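+# Load a Llama model from Hugging Face and wrap it in a LangChain HuggingFacePipeline.
+# 'method-1' uses AutoTokenizer/AutoModelForCausalLM; 'method-2' uses LlamaTokenizer/LlamaForCausalLM as an alternative download path.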
+def llama_model(model_name, branch_name, cache_dir, temperature=0, top_p=1, max_new_tokens=512, stream=False, method='method-1'):
+ if method == 'method-1':
+ tokenizer = AutoTokenizer.from_pretrained(model_name,
+ revision=branch_name,
+ cache_dir=cache_dir)
+ model = AutoModelForCausalLM.from_pretrained(model_name,
+ device_map='auto',
+ torch_dtype=torch.float16,
+ revision=branch_name,
+ cache_dir=cache_dir)
+ elif method == 'method-2':
+ import transformers
+ tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name,
+ revision=branch_name,
+ cache_dir=cache_dir,
+ legacy=False,
+ token="hf_WbtWB...")
+ model = transformers.LlamaForCausalLM.from_pretrained(model_name,
+ device_map='auto',
+ torch_dtype=torch.float16,
+ revision=branch_name,
+ cache_dir=cache_dir,
+ token="hf_WbtWB...")
+ if not stream:
+ pipe = pipeline("text-generation",
+ model = model,
+ tokenizer = tokenizer,
+ torch_dtype = torch.bfloat16,
+ device_map = "auto",
+ max_new_tokens = max_new_tokens,
+ do_sample = True
+ )
+ else:
+ streamer = TextStreamer(tokenizer)
+ pipe = pipeline("text-generation",
+ model = model,
+ tokenizer = tokenizer,
+ torch_dtype = torch.bfloat16,
+ device_map = "auto",
+ max_new_tokens = max_new_tokens,
+ do_sample = True,
+ streamer=streamer
+ )
+ llm = HuggingFacePipeline(pipeline = pipe,
+ model_kwargs = {"temperature":temperature, "top_p":top_p})
+ return llm
+
+
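+# Call the OpenAI chat completion API; retried with random exponential backoff on failure.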
+@retry(wait=wait_random_exponential(min=10, max=30), stop=stop_after_attempt(5))
+def fetch_GPT_response(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature=0):
+
+ response = openai.ChatCompletion.create(
+ temperature=temperature,
+ # deployment_id=chat_deployment_id,
+ model=chat_model_id,
+ messages=[
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": instruction}
+ ]
+ )
+
+ if 'choices' in response \
+ and isinstance(response['choices'], list) \
+ and len(response['choices']) > 0 \
+ and 'message' in response['choices'][0] \
+ and 'content' in response['choices'][0]['message']:
+ return response['choices'][0]['message']['content']
+ else:
+ return 'Unexpected response'
+
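+# Disk-cached wrapper (joblib Memory in 'cachegpt') so repeated identical prompts reuse the stored response.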
+@memory.cache
+def get_GPT_response(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature=0):
+ res = fetch_GPT_response(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature)
+ return res
+
+
+@retry(wait=wait_random_exponential(min=10, max=30), stop=stop_after_attempt(5))
+def fetch_Gemini_response(instruction, system_prompt, temperature=0.0):
+ model = genai.GenerativeModel(
+ model_name="gemini-2.0-flash",
+ system_instruction=system_prompt,
+ )
+ response = model.generate_content(instruction)
+ return response.text
+
+
+@memory.cache
+def get_Gemini_response(instruction, system_prompt, temperature=0.0):
+ res = fetch_Gemini_response(instruction, system_prompt, temperature)
+ return res
+
+
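+# Print the output in roughly 50 chunks with a short delay to mimic token streaming in the terminal.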
+def stream_out(output):
+ CHUNK_SIZE = max(1, int(round(len(output)/50)))  # guard against a zero step for very short outputs
+ SLEEP_TIME = 0.1
+ for i in range(0, len(output), CHUNK_SIZE):
+ print(output[i:i+CHUNK_SIZE], end='')
+ sys.stdout.flush()
+ time.sleep(SLEEP_TIME)
+ print("\n")
+
+
+def get_gpt35():
+ chat_model_id = 'gpt-35-turbo'
+ chat_deployment_id = chat_model_id if openai.api_type == 'azure' else None
+ return chat_model_id, chat_deployment_id
+
+
+def get_gpt4o_mini():
+ chat_model_id = 'gpt-4o-mini'
+ chat_deployment_id = chat_model_id if openai.api_type == 'azure' else None
+ return chat_model_id, chat_deployment_id
+
+
+def get_gemini():
+ chat_model_id = 'gemini-2.0-flash'
+ chat_deployment_id = chat_model_id if openai.api_type == 'azure' else None
+ return chat_model_id, chat_deployment_id
+
+
+def disease_entity_extractor(text):
+ chat_model_id, chat_deployment_id = get_gpt35()
+ resp = get_GPT_response(text, system_prompts["DISEASE_ENTITY_EXTRACTION"], chat_model_id, chat_deployment_id, temperature=0)
+ try:
+ entity_dict = json.loads(resp)
+ return entity_dict["Diseases"]
+ except:
+ return None
+
+
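+# Extract disease entities with Gemini; markdown code fences are stripped before parsing the JSON response.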
+def disease_entity_extractor_v2(text, model_id):
+ assert model_id in ("gemini-2.0-flash")
+ prompt_updated = system_prompts["DISEASE_ENTITY_EXTRACTION"] + "\n" + "Sentence : " + text
+ resp = get_Gemini_response(prompt_updated, system_prompts["DISEASE_ENTITY_EXTRACTION"], temperature=0.0)
+ if resp.startswith("```json\n"):
+ resp = resp.replace("```json\n", "", 1)
+ if resp.endswith("\n```"):
+ resp = resp.replace("\n```", "", -1)
+ try:
+ entity_dict = json.loads(resp)
+ return entity_dict["Diseases"]
+ except:
+ return None
+
+
+def load_sentence_transformer(sentence_embedding_model):
+ return SentenceTransformerEmbeddings(model_name=sentence_embedding_model)
+
+
+def load_chroma(vector_db_path, sentence_embedding_model):
+ embedding_function = load_sentence_transformer(sentence_embedding_model)
+ return Chroma(persist_directory=vector_db_path, embedding_function=embedding_function)
+
+
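+# Core KG-RAG retrieval: extract disease entities, match them to vectorDB nodes, fetch node context
+# (from a local table or the SPOKE API), and keep only sentences whose similarity to the question
+# clears the percentile and minimum thresholds, up to 'context_volume' sentences in total.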
+def retrieve_context(question, vectorstore, embedding_function, node_context_df, context_volume, context_sim_threshold, context_sim_min_threshold, edge_evidence,model_id="gpt-3.5-turbo", api=False):
+ print("question:", question)
+ entities = disease_entity_extractor_v2(question, model_id)
+ print("entities:", entities)
+ node_hits = []
+ if entities:
+ max_number_of_high_similarity_context_per_node = int(context_volume/len(entities))
+ for entity in entities:
+ node_search_result = vectorstore.similarity_search_with_score(entity, k=1)
+ node_hits.append(node_search_result[0][0].page_content)
+ question_embedding = embedding_function.embed_query(question)
+ node_context_extracted = ""
+ for node_name in node_hits:
+ if not api:
+ node_context = node_context_df[node_context_df.node_name == node_name].node_context.values[0]
+ else:
+ node_context,context_table = get_context_using_spoke_api(node_name)
+ node_context_list = node_context.split(". ")
+ node_context_embeddings = embedding_function.embed_documents(node_context_list)
+ similarities = [cosine_similarity(np.array(question_embedding).reshape(1, -1), np.array(node_context_embedding).reshape(1, -1)) for node_context_embedding in node_context_embeddings]
+ similarities = sorted([(e, i) for i, e in enumerate(similarities)], reverse=True)
+ percentile_threshold = np.percentile([s[0] for s in similarities], context_sim_threshold)
+ high_similarity_indices = [s[1] for s in similarities if s[0] > percentile_threshold and s[0] > context_sim_min_threshold]
+ if len(high_similarity_indices) > max_number_of_high_similarity_context_per_node:
+ high_similarity_indices = high_similarity_indices[:max_number_of_high_similarity_context_per_node]
+ high_similarity_context = [node_context_list[index] for index in high_similarity_indices]
+ if edge_evidence:
+ high_similarity_context = list(map(lambda x:x+'.', high_similarity_context))
+ context_table = context_table[context_table.context.isin(high_similarity_context)]
+ context_table.loc[:, "context"] = context_table.source + " " + context_table.predicate.str.lower() + " " + context_table.target + " and Provenance of this association is " + context_table.provenance + " and attributes associated with this association is in the following JSON format:\n " + context_table.evidence.astype('str') + "\n\n"
+ node_context_extracted += context_table.context.str.cat(sep=' ')
+ else:
+ node_context_extracted += ". ".join(high_similarity_context)
+ node_context_extracted += ". "
+ return node_context_extracted
+ else:
+ node_hits = vectorstore.similarity_search_with_score(question, k=5)
+ max_number_of_high_similarity_context_per_node = int(context_volume/5)
+ question_embedding = embedding_function.embed_query(question)
+ node_context_extracted = ""
+ for node in node_hits:
+ node_name = node[0].page_content
+ if not api:
+ node_context = node_context_df[node_context_df.node_name == node_name].node_context.values[0]
+ else:
+ node_context, context_table = get_context_using_spoke_api(node_name)
+ node_context_list = node_context.split(". ")
+ node_context_embeddings = embedding_function.embed_documents(node_context_list)
+ similarities = [cosine_similarity(np.array(question_embedding).reshape(1, -1), np.array(node_context_embedding).reshape(1, -1)) for node_context_embedding in node_context_embeddings]
+ similarities = sorted([(e, i) for i, e in enumerate(similarities)], reverse=True)
+ percentile_threshold = np.percentile([s[0] for s in similarities], context_sim_threshold)
+ high_similarity_indices = [s[1] for s in similarities if s[0] > percentile_threshold and s[0] > context_sim_min_threshold]
+ if len(high_similarity_indices) > max_number_of_high_similarity_context_per_node:
+ high_similarity_indices = high_similarity_indices[:max_number_of_high_similarity_context_per_node]
+ high_similarity_context = [node_context_list[index] for index in high_similarity_indices]
+ if edge_evidence:
+ high_similarity_context = list(map(lambda x:x+'.', high_similarity_context))
+ context_table = context_table[context_table.context.isin(high_similarity_context)]
+ context_table.loc[:, "context"] = context_table.source + " " + context_table.predicate.str.lower() + " " + context_table.target + " and Provenance of this association is " + context_table.provenance + " and attributes associated with this association is in the following JSON format:\n " + context_table.evidence.astype('str') + "\n\n"
+ node_context_extracted += context_table.context.str.cat(sep=' ')
+ else:
+ node_context_extracted += ". ".join(high_similarity_context)
+ node_context_extracted += ". "
+ return node_context_extracted
+
+
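+# Interactive walk-through of the KG-RAG pipeline: entity extraction, node matching, context retrieval, pruning, and LLM prompting.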
+def interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, llm_type, edge_evidence, system_prompt, api=True, llama_method="method-1"):
+ print(" ")
+ input("Press enter for Step 1 - Disease entity extraction using GPT-3.5-Turbo")
+ print("Processing ...")
+ entities = disease_entity_extractor_v2(question, "gpt-4o-mini")
+ max_number_of_high_similarity_context_per_node = int(config_data["CONTEXT_VOLUME"]/len(entities))
+ print("Extracted entity from the prompt = '{}'".format(", ".join(entities)))
+ print(" ")
+
+ input("Press enter for Step 2 - Match extracted Disease entity to SPOKE nodes")
+ print("Finding vector similarity ...")
+ node_hits = []
+ for entity in entities:
+ node_search_result = vectorstore.similarity_search_with_score(entity, k=1)
+ node_hits.append(node_search_result[0][0].page_content)
+ print("Matched entities from SPOKE = '{}'".format(", ".join(node_hits)))
+ print(" ")
+
+ input("Press enter for Step 3 - Context extraction from SPOKE")
+ node_context = []
+ for node_name in node_hits:
+ if not api:
+ node_context.append(node_context_df[node_context_df.node_name == node_name].node_context.values[0])
+ else:
+ context, context_table = get_context_using_spoke_api(node_name)
+ node_context.append(context)
+ print("Extracted Context is : ")
+ print(". ".join(node_context))
+ print(" ")
+
+ input("Press enter for Step 4 - Context pruning")
+ question_embedding = embedding_function_for_context_retrieval.embed_query(question)
+ node_context_extracted = ""
+ for node_name in node_hits:
+ if not api:
+ node_context = node_context_df[node_context_df.node_name == node_name].node_context.values[0]
+ else:
+ node_context, context_table = get_context_using_spoke_api(node_name)
+ node_context_list = node_context.split(". ")
+ node_context_embeddings = embedding_function_for_context_retrieval.embed_documents(node_context_list)
+ similarities = [cosine_similarity(np.array(question_embedding).reshape(1, -1), np.array(node_context_embedding).reshape(1, -1)) for node_context_embedding in node_context_embeddings]
+ similarities = sorted([(e, i) for i, e in enumerate(similarities)], reverse=True)
+ percentile_threshold = np.percentile([s[0] for s in similarities], config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"])
+ high_similarity_indices = [s[1] for s in similarities if s[0] > percentile_threshold and s[0] > config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"]]
+ if len(high_similarity_indices) > max_number_of_high_similarity_context_per_node:
+ high_similarity_indices = high_similarity_indices[:max_number_of_high_similarity_context_per_node]
+ high_similarity_context = [node_context_list[index] for index in high_similarity_indices]
+ if edge_evidence:
+ high_similarity_context = list(map(lambda x:x+'.', high_similarity_context))
+ context_table = context_table[context_table.context.isin(high_similarity_context)]
+ context_table.loc[:, "context"] = context_table.source + " " + context_table.predicate.str.lower() + " " + context_table.target + " and Provenance of this association is " + context_table.provenance + " and attributes associated with this association is in the following JSON format:\n " + context_table.evidence.astype('str') + "\n\n"
+ node_context_extracted += context_table.context.str.cat(sep=' ')
+ else:
+ node_context_extracted += ". ".join(high_similarity_context)
+ node_context_extracted += ". "
+ print("Pruned Context is : ")
+ print(node_context_extracted)
+ print(" ")
+
+ input("Press enter for Step 5 - LLM prompting")
+ print("Prompting ", llm_type)
+ if llm_type == "llama":
+ from langchain import PromptTemplate, LLMChain
+ template = get_prompt("Context:\n\n{context} \n\nQuestion: {question}", system_prompt)
+ prompt = PromptTemplate(template=template, input_variables=["context", "question"])
+ llm = llama_model(config_data["LLAMA_MODEL_NAME"], config_data["LLAMA_MODEL_BRANCH"], config_data["LLM_CACHE_DIR"], stream=True, method=llama_method)
+ llm_chain = LLMChain(prompt=prompt, llm=llm)
+ output = llm_chain.run(context=node_context_extracted, question=question)
+ elif "gpt" in llm_type:
+ enriched_prompt = "Context: "+ node_context_extracted + "\n" + "Question: " + question
+ output = get_GPT_response(enriched_prompt, system_prompt, llm_type, llm_type, temperature=config_data["LLM_TEMPERATURE"])
+ stream_out(output)
diff --git a/kg_rag/vectorDB/__init__.py b/kg_rag/vectorDB/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/kg_rag/vectorDB/__init__.py
diff --git a/kg_rag/vectorDB/create_vectordb.py b/kg_rag/vectorDB/create_vectordb.py
new file mode 100644
index 0000000..b02c9ff
--- /dev/null
+++ b/kg_rag/vectorDB/create_vectordb.py
@@ -0,0 +1,35 @@
+import pickle
+from kg_rag.utility import RecursiveCharacterTextSplitter, Chroma, SentenceTransformerEmbeddings, config_data, time
+
+
+DATA_PATH = config_data["VECTOR_DB_DISEASE_ENTITY_PATH"]
+VECTOR_DB_NAME = config_data["VECTOR_DB_PATH"]
+CHUNK_SIZE = int(config_data["VECTOR_DB_CHUNK_SIZE"])
+CHUNK_OVERLAP = int(config_data["VECTOR_DB_CHUNK_OVERLAP"])
+BATCH_SIZE = int(config_data["VECTOR_DB_BATCH_SIZE"])
+SENTENCE_EMBEDDING_MODEL = config_data["VECTOR_DB_SENTENCE_EMBEDDING_MODEL"]
+
+
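+# Load the pickled list of disease names and attach a SPOKE-source metadata entry to each.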
+def load_data():
+ with open(DATA_PATH, "rb") as f:
+ data = pickle.load(f)
+ metadata_list = list(map(lambda x:{"source": x + " from SPOKE knowledge graph"}, data))
+ return data, metadata_list
+
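+# Chunk the disease names, embed them with the configured sentence-transformer model, and persist the Chroma vectorDB in batches.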
+def create_vectordb():
+ start_time = time.time()
+ data, metadata_list = load_data()
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
+ docs = text_splitter.create_documents(data, metadatas=metadata_list)
+ batches = [docs[i:i + BATCH_SIZE] for i in range(0, len(docs), BATCH_SIZE)]
+ vectorstore = Chroma(embedding_function=SentenceTransformerEmbeddings(model_name=SENTENCE_EMBEDDING_MODEL),
+ persist_directory=VECTOR_DB_NAME)
+ for batch in batches:
+ vectorstore.add_documents(documents=batch)
+ end_time = round((time.time() - start_time)/(60), 2)
+ print("VectorDB is created in {} mins".format(end_time))
+
+
+if __name__ == "__main__":
+ create_vectordb()
+