Diffstat (limited to 'kg_rag')
29 files changed, 1658 insertions, 0 deletions
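All of the generation scripts added in this changeset follow the same retrieve-then-generate flow: load the disease vector store and node-context table, retrieve graph context for the entity mentioned in the question, prepend that context to the question, and send the enriched prompt to the LLM. Below is a minimal sketch of that shared flow, using only helpers defined in kg_rag/utility.py later in this diff (load_chroma, load_sentence_transformer, retrieve_context, get_GPT_response); the question string and the 'gpt-35-turbo' model/deployment ids are placeholders for illustration, and config.yaml / system_prompts.yaml are assumed to be populated as described in the script docstrings.

from kg_rag.utility import *

question = "What compounds treat psoriasis?"  # illustrative question only
vectorstore = load_chroma(config_data["VECTOR_DB_PATH"], config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"])
embedding_function = load_sentence_transformer(config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"])
node_context_df = pd.read_csv(config_data["NODE_CONTEXT_PATH"])

# Pull SPOKE graph context for the disease entity detected in the question
context = retrieve_context(question, vectorstore, embedding_function, node_context_df,
                           int(config_data["CONTEXT_VOLUME"]),
                           float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"]),
                           float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"]),
                           False)  # edge evidence disabled, as in the batch scripts below

# Prepend the retrieved context and send the enriched prompt to the chat model
enriched_prompt = "Context: " + context + "\n" + "Question: " + question
answer = get_GPT_response(enriched_prompt, system_prompts["KG_RAG_BASED_TEXT_GENERATION"],
                          "gpt-35-turbo", "gpt-35-turbo", temperature=config_data["LLM_TEMPERATURE"])
print(answer)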
diff --git a/kg_rag/.DS_Store b/kg_rag/.DS_Store Binary files differnew file mode 100644 index 0000000..9f93d7f --- /dev/null +++ b/kg_rag/.DS_Store diff --git a/kg_rag/__init__.py b/kg_rag/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/kg_rag/__init__.py diff --git a/kg_rag/config_loader.py b/kg_rag/config_loader.py new file mode 100644 index 0000000..f3f5bd4 --- /dev/null +++ b/kg_rag/config_loader.py @@ -0,0 +1,17 @@ +import yaml +import os + +with open('config.yaml', 'r') as f: + config_data = yaml.safe_load(f) + +with open('system_prompts.yaml', 'r') as f: + system_prompts = yaml.safe_load(f) + +#if 'GPT_CONFIG_FILE' in config_data: +# config_data['GPT_CONFIG_FILE'] = config_data['GPT_CONFIG_FILE'].replace('$HOME', os.environ['HOME']) + + +__all__ = [ + 'config_data', + 'system_prompts' +] diff --git a/kg_rag/prompt_based_generation/GPT/run_mcq_qa.py b/kg_rag/prompt_based_generation/GPT/run_mcq_qa.py new file mode 100644 index 0000000..762242e --- /dev/null +++ b/kg_rag/prompt_based_generation/GPT/run_mcq_qa.py @@ -0,0 +1,31 @@ +from kg_rag.utility import * +import sys +from tqdm import tqdm + +CHAT_MODEL_ID = sys.argv[1] + +QUESTION_PATH = config_data["MCQ_PATH"] +SYSTEM_PROMPT = system_prompts["MCQ_QUESTION_PROMPT_BASED"] +SAVE_PATH = config_data["SAVE_RESULTS_PATH"] +TEMPERATURE = config_data["LLM_TEMPERATURE"] + +CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID + +save_name = "_".join(CHAT_MODEL_ID.split("-"))+"_prompt_based_response_for_two_hop_mcq_from_monarch_and_robokop.csv" + + +def main(): + start_time = time.time() + question_df = pd.read_csv(QUESTION_PATH) + answer_list = [] + for index, row in tqdm(question_df.head(50).iterrows(), total=50): + question = "Question: "+ row["text"] + output = get_GPT_response(question, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE) + answer_list.append((row["text"], row["correct_node"], output)) + answer_df = pd.DataFrame(answer_list, columns=["question", "correct_answer", "llm_answer"]) + answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True) + print("Completed in {} min".format((time.time()-start_time)/60)) + + +if __name__ == "__main__": + main()
\ No newline at end of file diff --git a/kg_rag/prompt_based_generation/GPT/run_true_false_generation.py b/kg_rag/prompt_based_generation/GPT/run_true_false_generation.py new file mode 100644 index 0000000..0d248db --- /dev/null +++ b/kg_rag/prompt_based_generation/GPT/run_true_false_generation.py @@ -0,0 +1,33 @@ +from kg_rag.utility import * +import sys + + +CHAT_MODEL_ID = sys.argv[1] + +QUESTION_PATH = config_data["TRUE_FALSE_PATH"] +SYSTEM_PROMPT = system_prompts["TRUE_FALSE_QUESTION_PROMPT_BASED"] +SAVE_PATH = config_data["SAVE_RESULTS_PATH"] +TEMPERATURE = config_data["LLM_TEMPERATURE"] + +CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID + +save_name = "_".join(CHAT_MODEL_ID.split("-"))+"_prompt_based_one_hop_true_false_binary_response.csv" + + +def main(): + start_time = time.time() + question_df = pd.read_csv(QUESTION_PATH) + answer_list = [] + for index, row in question_df.iterrows(): + question = "Question: "+ row["text"] + output = get_GPT_response(question, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE) + answer_list.append((row["text"], row["label"], output)) + answer_df = pd.DataFrame(answer_list, columns=["question", "label", "llm_answer"]) + answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True) + print("Completed in {} min".format((time.time()-start_time)/60)) + + + +if __name__ == "__main__": + main() + diff --git a/kg_rag/prompt_based_generation/GPT/text_generation.py b/kg_rag/prompt_based_generation/GPT/text_generation.py new file mode 100644 index 0000000..235ece7 --- /dev/null +++ b/kg_rag/prompt_based_generation/GPT/text_generation.py @@ -0,0 +1,33 @@ +from kg_rag.utility import * +import argparse + + + +parser = argparse.ArgumentParser() +parser.add_argument('-g', type=str, default='gpt-35-turbo', help='GPT model selection') +args = parser.parse_args() + +CHAT_MODEL_ID = args.g + +SYSTEM_PROMPT = system_prompts["PROMPT_BASED_TEXT_GENERATION"] +TEMPERATURE = config_data["LLM_TEMPERATURE"] + +CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID + + +def main(): + print(" ") + question = input("Enter your question : ") + print("Here is the prompt-based answer:") + print("") + output = get_GPT_response(question, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE) + stream_out(output) + + + + + +if __name__ == "__main__": + main() + + diff --git a/kg_rag/prompt_based_generation/Llama/run_mcq_qa.py b/kg_rag/prompt_based_generation/Llama/run_mcq_qa.py new file mode 100644 index 0000000..e3eb1b0 --- /dev/null +++ b/kg_rag/prompt_based_generation/Llama/run_mcq_qa.py @@ -0,0 +1,38 @@ +from langchain import PromptTemplate, LLMChain +from kg_rag.utility import * + + +QUESTION_PATH = config_data["MCQ_PATH"] +SYSTEM_PROMPT = system_prompts["MCQ_QUESTION_PROMPT_BASED"] +SAVE_PATH = config_data["SAVE_RESULTS_PATH"] +MODEL_NAME = config_data["LLAMA_MODEL_NAME"] +BRANCH_NAME = config_data["LLAMA_MODEL_BRANCH"] +CACHE_DIR = config_data["LLM_CACHE_DIR"] + + + +save_name = "_".join(MODEL_NAME.split("/")[-1].split("-"))+"_prompt_based_two_hop_mcq_from_monarch_and_robokop_response.csv" + +INSTRUCTION = "Question: {question}" + + +def main(): + start_time = time.time() + llm = llama_model(MODEL_NAME, BRANCH_NAME, CACHE_DIR) + template = get_prompt(INSTRUCTION, SYSTEM_PROMPT) + prompt = PromptTemplate(template=template, input_variables=["question"]) + llm_chain = LLMChain(prompt=prompt, llm=llm) + question_df = pd.read_csv(QUESTION_PATH) + answer_list = [] + for index, row in question_df.iterrows(): + question = row["text"] + output = 
llm_chain.run(question) + answer_list.append((row["text"], row["correct_node"], output)) + answer_df = pd.DataFrame(answer_list, columns=["question", "correct_answer", "llm_answer"]) + answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True) + print("Completed in {} min".format((time.time()-start_time)/60)) + + + +if __name__ == "__main__": + main() diff --git a/kg_rag/prompt_based_generation/Llama/run_mcq_qa_medgpt.py b/kg_rag/prompt_based_generation/Llama/run_mcq_qa_medgpt.py new file mode 100644 index 0000000..91a23cb --- /dev/null +++ b/kg_rag/prompt_based_generation/Llama/run_mcq_qa_medgpt.py @@ -0,0 +1,38 @@ +from langchain import PromptTemplate, LLMChain +from kg_rag.utility import * + + +QUESTION_PATH = config_data["MCQ_PATH"] +SYSTEM_PROMPT = system_prompts["MCQ_QUESTION_PROMPT_BASED"] +SAVE_PATH = config_data["SAVE_RESULTS_PATH"] +MODEL_NAME = 'PharMolix/BioMedGPT-LM-7B' +BRANCH_NAME = 'main' +CACHE_DIR = config_data["LLM_CACHE_DIR"] + + + +save_name = "_".join(MODEL_NAME.split("/")[-1].split("-"))+"_prompt_based_mcq_from_monarch_and_robokop_response.csv" + +INSTRUCTION = "Question: {question}" + + +def main(): + start_time = time.time() + llm = llama_model(MODEL_NAME, BRANCH_NAME, CACHE_DIR) + template = get_prompt(INSTRUCTION, SYSTEM_PROMPT) + prompt = PromptTemplate(template=template, input_variables=["question"]) + llm_chain = LLMChain(prompt=prompt, llm=llm) + question_df = pd.read_csv(QUESTION_PATH) + answer_list = [] + for index, row in question_df.iterrows(): + question = row["text"] + output = llm_chain.run(question) + answer_list.append((row["text"], row["correct_node"], output)) + answer_df = pd.DataFrame(answer_list, columns=["question", "correct_answer", "llm_answer"]) + answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True) + print("Completed in {} min".format((time.time()-start_time)/60)) + + + +if __name__ == "__main__": + main() diff --git a/kg_rag/prompt_based_generation/Llama/run_true_false_generation.py b/kg_rag/prompt_based_generation/Llama/run_true_false_generation.py new file mode 100644 index 0000000..81a98a1 --- /dev/null +++ b/kg_rag/prompt_based_generation/Llama/run_true_false_generation.py @@ -0,0 +1,35 @@ +from langchain import PromptTemplate, LLMChain +from kg_rag.utility import * + +QUESTION_PATH = config_data["TRUE_FALSE_PATH"] +SYSTEM_PROMPT = system_prompts["TRUE_FALSE_QUESTION_PROMPT_BASED"] +SAVE_PATH = config_data["SAVE_RESULTS_PATH"] +MODEL_NAME = config_data["LLAMA_MODEL_NAME"] +BRANCH_NAME = config_data["LLAMA_MODEL_BRANCH"] +CACHE_DIR = config_data["LLM_CACHE_DIR"] + + +INSTRUCTION = "Question: {question}" + +save_name = "_".join(MODEL_NAME.split("/")[-1].split("-"))+"_prompt_based_one_hop_true_false_binary_response.csv" + +def main(): + start_time = time.time() + llm = llama_model(MODEL_NAME, BRANCH_NAME, CACHE_DIR) + template = get_prompt(INSTRUCTION, SYSTEM_PROMPT) + prompt = PromptTemplate(template=template, input_variables=["question"]) + llm_chain = LLMChain(prompt=prompt, llm=llm) + question_df = pd.read_csv(QUESTION_PATH) + answer_list = [] + for index, row in question_df.iterrows(): + question = row["text"] + output = llm_chain.run(question) + answer_list.append((row["text"], row["label"], output)) + answer_df = pd.DataFrame(answer_list, columns=["question", "label", "llm_answer"]) + answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True) + print("Completed in {} min".format((time.time()-start_time)/60)) + + + +if __name__ == "__main__": + main() diff --git 
a/kg_rag/prompt_based_generation/Llama/text_generation.py b/kg_rag/prompt_based_generation/Llama/text_generation.py new file mode 100644 index 0000000..49bfebb --- /dev/null +++ b/kg_rag/prompt_based_generation/Llama/text_generation.py @@ -0,0 +1,36 @@ +from langchain import PromptTemplate, LLMChain +from kg_rag.utility import * +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('-m', type=str, default='method-1', help='Method to choose for Llama model') +args = parser.parse_args() + +METHOD = args.m + + +SYSTEM_PROMPT = system_prompts["PROMPT_BASED_TEXT_GENERATION"] +MODEL_NAME = config_data["LLAMA_MODEL_NAME"] +BRANCH_NAME = config_data["LLAMA_MODEL_BRANCH"] +CACHE_DIR = config_data["LLM_CACHE_DIR"] + + + +INSTRUCTION = "Question: {question}" + + +def main(): + llm = llama_model(MODEL_NAME, BRANCH_NAME, CACHE_DIR, stream=True, method=METHOD) + template = get_prompt(INSTRUCTION, SYSTEM_PROMPT) + prompt = PromptTemplate(template=template, input_variables=["question"]) + llm_chain = LLMChain(prompt=prompt, llm=llm) + print(" ") + question = input("Enter your question : ") + print("Here is the prompt-based answer:") + print("") + output = llm_chain.run(question) + + + +if __name__ == "__main__": + main() diff --git a/kg_rag/rag_based_generation/GPT/drug_action.py b/kg_rag/rag_based_generation/GPT/drug_action.py new file mode 100644 index 0000000..60c0acf --- /dev/null +++ b/kg_rag/rag_based_generation/GPT/drug_action.py @@ -0,0 +1,52 @@ +from kg_rag.utility import * +import argparse + + + +parser = argparse.ArgumentParser() +parser.add_argument('-g', type=str, default='gpt-35-turbo', help='GPT model selection') +parser.add_argument('-i', type=bool, default=False, help='Flag for interactive mode') +parser.add_argument('-e', type=bool, default=False, help='Flag for showing evidence of association from the graph') +args = parser.parse_args() + +CHAT_MODEL_ID = args.g +INTERACTIVE = args.i +EDGE_EVIDENCE = bool(args.e) + + +SYSTEM_PROMPT = system_prompts["DRUG_ACTION"] +CONTEXT_VOLUME = int(config_data["CONTEXT_VOLUME"]) +QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"]) +QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"]) +VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"] +NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"] +SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"] +SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"] +TEMPERATURE = config_data["LLM_TEMPERATURE"] + + +CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID + +vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL) +embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL) +node_context_df = pd.read_csv(NODE_CONTEXT_PATH) + +def main(): + print(" ") + question = input("Enter your question : ") + if not INTERACTIVE: + print("Retrieving context from SPOKE graph...") + context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, EDGE_EVIDENCE) + print("Here is the KG-RAG based answer:") + print("") + enriched_prompt = "Context: "+ context + "\n" + "Question: " + question + output = get_GPT_response(enriched_prompt, SYSTEM_PROMPT, 
CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE) + stream_out(output) + else: + interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, CHAT_MODEL_ID, EDGE_EVIDENCE, SYSTEM_PROMPT) + + + +if __name__ == "__main__": + main() + diff --git a/kg_rag/rag_based_generation/GPT/drug_repurposing_v2.py b/kg_rag/rag_based_generation/GPT/drug_repurposing_v2.py new file mode 100644 index 0000000..d95053b --- /dev/null +++ b/kg_rag/rag_based_generation/GPT/drug_repurposing_v2.py @@ -0,0 +1,68 @@ +from kg_rag.utility import * +import argparse + + + +parser = argparse.ArgumentParser() +parser.add_argument('-g', type=str, default='gpt-35-turbo', help='GPT model selection') +parser.add_argument('-i', type=bool, default=False, help='Flag for interactive mode') +parser.add_argument('-e', type=bool, default=False, help='Flag for showing evidence of association from the graph') +args = parser.parse_args() + + +CHAT_MODEL_ID = args.g +INTERACTIVE = args.i +EDGE_EVIDENCE = bool(args.e) + +SYSTEM_PROMPT = system_prompts["DRUG_REPURPOSING_V2"] +CONTEXT_VOLUME = int(config_data["CONTEXT_VOLUME"]) +QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"]) +QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"]) +VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"] +NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"] +SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"] +SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"] +TEMPERATURE = config_data["LLM_TEMPERATURE"] + + +CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID + +vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL) +embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL) +node_context_df = pd.read_csv(NODE_CONTEXT_PATH) + +print('') +question = input("Question : ") + +question_template = f''' +To the question asked at the end, answer by referring the context. +See example below +Example 1: + Question: + What drugs can be repurposed for disease X? + Context: + Compound Alizapride DOWNREGULATES Gene APOE and Provenance of this association is XX. Gene APOE ASSOCIATES Disease X and Provenance of this association is YY. Gene TTR encodes Protein Transthyretin (ATTR) and Provenance of this association is ZZ. Compound Acetylcysteine treats Disease X and Provenance of this association is PP. + Answer: + Since Alizapride downregulates gene APOE (Provenance XX) and APOE is associated with Disease X (Provenance YY), Alizapride can be repurposed to treat Disease X. p-value for these associations is XXXX and z-score values for these associations is YYYY. 
+Question: +{question} +''' + +def main(): + if not INTERACTIVE: + print("Retrieving context from SPOKE graph...") + context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, EDGE_EVIDENCE) + print("Here is the KG-RAG based answer:") + print("") + enriched_prompt = "Context: "+ context + "\n" + "Question: " + question + output = get_GPT_response(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE) + stream_out(output) + else: + interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, CHAT_MODEL_ID, EDGE_EVIDENCE, SYSTEM_PROMPT) + + + +if __name__ == "__main__": + main() + + diff --git a/kg_rag/rag_based_generation/GPT/run_drug_repurposing.py b/kg_rag/rag_based_generation/GPT/run_drug_repurposing.py new file mode 100644 index 0000000..8a5726d --- /dev/null +++ b/kg_rag/rag_based_generation/GPT/run_drug_repurposing.py @@ -0,0 +1,57 @@ +''' +This script takes the drug repurposing style questions from the csv file and save the result as another csv file. +Before running this script, make sure to configure the filepaths in config.yaml file. +Command line argument should be either 'gpt-4' or 'gpt-35-turbo' +''' + +from kg_rag.utility import * +import sys + + + +CHAT_MODEL_ID = sys.argv[1] + +QUESTION_PATH = config_data["DRUG_REPURPOSING_PATH"] +SYSTEM_PROMPT = system_prompts["DRUG_REPURPOSING"] +CONTEXT_VOLUME = int(config_data["CONTEXT_VOLUME"]) +QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"]) +QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"]) +VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"] +NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"] +SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"] +SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"] +TEMPERATURE = config_data["LLM_TEMPERATURE"] +SAVE_PATH = config_data["SAVE_RESULTS_PATH"] + + +CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID + +save_name = "_".join(CHAT_MODEL_ID.split("-"))+"_drug_repurposing_questions_response.csv" + + +vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL) +embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL) +node_context_df = pd.read_csv(NODE_CONTEXT_PATH) + + +def main(): + start_time = time.time() + question_df = pd.read_csv(QUESTION_PATH) + answer_list = [] + for index, row in question_df.iterrows(): + question = row["text"] + context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY) + enriched_prompt = "Context: " + context + "\n" + "Question: " + question + output = get_GPT_response(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE) + answer_list.append((row["disease_in_question"], row["refDisease"], row["compoundGroundTruth"], row["text"], output)) + answer_df = pd.DataFrame(answer_list, columns=["disease_in_question", "refDisease", "compoundGroundTruth", "text", "llm_answer"]) + answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, 
header=True) + print("Completed in {} min".format((time.time()-start_time)/60)) + + + +if __name__ == "__main__": + main() + + + diff --git a/kg_rag/rag_based_generation/GPT/run_mcq_qa.py b/kg_rag/rag_based_generation/GPT/run_mcq_qa.py new file mode 100644 index 0000000..edf0415 --- /dev/null +++ b/kg_rag/rag_based_generation/GPT/run_mcq_qa.py @@ -0,0 +1,91 @@ +''' +This script takes the MCQ style questions from the csv file and save the result as another csv file. +Before running this script, make sure to configure the filepaths in config.yaml file. +Command line argument should be either 'gpt-4' or 'gpt-35-turbo' +''' + +from kg_rag.utility import * +import sys + + +from tqdm import tqdm +CHAT_MODEL_ID = sys.argv[1] + +QUESTION_PATH = config_data["MCQ_PATH"] +SYSTEM_PROMPT = system_prompts["MCQ_QUESTION"] +CONTEXT_VOLUME = int(config_data["CONTEXT_VOLUME"]) +QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"]) +QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"]) +VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"] +NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"] +SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"] +SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"] +TEMPERATURE = config_data["LLM_TEMPERATURE"] +SAVE_PATH = config_data["SAVE_RESULTS_PATH"] + + +CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID + +save_name = "_".join(CHAT_MODEL_ID.split("-"))+"_kg_rag_based_mcq_{mode}.csv" + + +vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL) +embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL) +node_context_df = pd.read_csv(NODE_CONTEXT_PATH) +edge_evidence = False + + +MODE = "0" +### MODE 0: Original KG_RAG ### +### MODE 1: jsonlize the context from KG search ### +### MODE 2: Add the prior domain knowledge ### +### MODE 3: Combine MODE 1 & 2 ### + +def main(): + start_time = time.time() + question_df = pd.read_csv(QUESTION_PATH) + answer_list = [] + + for index, row in tqdm(question_df.iterrows(), total=306): + try: + question = row["text"] + if MODE == "0": + ### MODE 0: Original KG_RAG ### + context = retrieve_context(row["text"], vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, edge_evidence, model_id=CHAT_MODEL_ID) + enriched_prompt = "Context: "+ context + "\n" + "Question: "+ question + output = get_Gemini_response(enriched_prompt, SYSTEM_PROMPT, temperature=TEMPERATURE) + + if MODE == "1": + ### MODE 1: jsonlize the context from KG search ### + ### Please implement the first strategy here ### + output = '...' + + if MODE == "2": + ### MODE 2: Add the prior domain knowledge ### + ### Please implement the second strategy here ### + output = '...' + + if MODE == "3": + ### MODE 3: Combine MODE 1 & 2 ### + ### Please implement the third strategy here ### + output = '...' 
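+            ### One possible sketch for MODE 1, kept commented out and untested; it assumes that
+            ### `json` is re-exported by kg_rag.utility and that get_Gemini_response accepts the
+            ### same arguments as in the MODE 0 branch above.
+            # context = retrieve_context(row["text"], vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, edge_evidence, model_id=CHAT_MODEL_ID)
+            # context_json = json.dumps({"context_statements": [s.strip() for s in context.split(". ") if s.strip()]})
+            # enriched_prompt = "Context (JSON): " + context_json + "\n" + "Question: " + question
+            # output = get_Gemini_response(enriched_prompt, SYSTEM_PROMPT, temperature=TEMPERATURE)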
+ + answer_list.append((row["text"], row["correct_node"], output)) + except Exception as e: + print("Error in processing question: ", row["text"]) + print("Error: ", e) + answer_list.append((row["text"], row["correct_node"], "Error")) + + + answer_df = pd.DataFrame(answer_list, columns=["question", "correct_answer", "llm_answer"]) + output_file = os.path.join(SAVE_PATH, f"{save_name}".format(mode=MODE),) + answer_df.to_csv(output_file, index=False, header=True) + print("Save the model outputs in ", output_file) + print("Completed in {} min".format((time.time()-start_time)/60)) + + + +if __name__ == "__main__": + main() + + diff --git a/kg_rag/rag_based_generation/GPT/run_single_disease_entity_hyperparameter_tuning.py b/kg_rag/rag_based_generation/GPT/run_single_disease_entity_hyperparameter_tuning.py new file mode 100644 index 0000000..aaf8071 --- /dev/null +++ b/kg_rag/rag_based_generation/GPT/run_single_disease_entity_hyperparameter_tuning.py @@ -0,0 +1,61 @@ +''' +This script is used for hyperparameter tuning on one-hop graph traversal questions. +Hyperparameters are 'CONTEXT_VOLUME_LIST' and 'SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL_LIST' + +This will run on one-hop graph traveral questions from the csv file and save the result as another csv file. + +Before running this script, make sure to configure the filepaths in config.yaml file. +Command line argument should be either 'gpt-4' or 'gpt-35-turbo' +''' + +from kg_rag.utility import * +import sys + +CHAT_MODEL_ID = sys.argv[1] + +CONTEXT_VOLUME_LIST = [10, 50, 100, 150, 200] +SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL_LIST = ["pritamdeka/S-PubMedBert-MS-MARCO", "sentence-transformers/all-MiniLM-L6-v2"] +SAVE_NAME_LIST = ["pubmedBert_based_one_hop_questions_parameter_tuning_round_{}.csv", "miniLM_based_one_hop_questions_parameter_tuning_round_{}.csv"] + +QUESTION_PATH = config_data["SINGLE_DISEASE_ENTITY_FILE"] +SYSTEM_PROMPT = system_prompts["SINGLE_DISEASE_ENTITY_VALIDATION"] +QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"]) +QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"]) +VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"] +NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"] +SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"] +TEMPERATURE = config_data["LLM_TEMPERATURE"] +SAVE_PATH = config_data["SAVE_RESULTS_PATH"] + +CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID + +vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL) +node_context_df = pd.read_csv(NODE_CONTEXT_PATH) +edge_evidence = False + +def main(): + start_time = time.time() + question_df = pd.read_csv(QUESTION_PATH) + for tranformer_index, sentence_embedding_model_for_context_retrieval in enumerate(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL_LIST): + embedding_function_for_context_retrieval = load_sentence_transformer(sentence_embedding_model_for_context_retrieval) + for context_index, context_volume in enumerate(CONTEXT_VOLUME_LIST): + answer_list = [] + for index, row in question_df.iterrows(): + question = row["text"] + context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, context_volume, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, edge_evidence) + enriched_prompt = "Context: "+ context + "\n" + "Question: " + question + output = 
get_GPT_response(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE) + if not output: + time.sleep(5) + answer_list.append((row["disease_1"], row["Compounds"], row["Diseases"], row["text"], output, context_volume)) + answer_df = pd.DataFrame(answer_list, columns=["disease", "compound_groundTruth", "disease_groundTruth", "text", "llm_answer", "context_volume"]) + save_name = "_".join(CHAT_MODEL_ID.split("-"))+SAVE_NAME_LIST[tranformer_index].format(context_index+1) + answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True) + print("Completed in {} min".format((time.time()-start_time)/60)) + + + +if __name__ == "__main__": + main() + + diff --git a/kg_rag/rag_based_generation/GPT/run_true_false_generation.py b/kg_rag/rag_based_generation/GPT/run_true_false_generation.py new file mode 100644 index 0000000..7b8d0e3 --- /dev/null +++ b/kg_rag/rag_based_generation/GPT/run_true_false_generation.py @@ -0,0 +1,52 @@ +''' +This script takes the True/False style questions from the csv file and save the result as another csv file. +Before running this script, make sure to configure the filepaths in config.yaml file. +Command line argument should be either 'gpt-4' or 'gpt-35-turbo' +''' + +from kg_rag.utility import * +import sys + +CHAT_MODEL_ID = sys.argv[1] + +QUESTION_PATH = config_data["TRUE_FALSE_PATH"] +SYSTEM_PROMPT = system_prompts["TRUE_FALSE_QUESTION"] +QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"]) +QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"]) +VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"] +NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"] +SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"] +SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"] +TEMPERATURE = config_data["LLM_TEMPERATURE"] +SAVE_PATH = config_data["SAVE_RESULTS_PATH"] +CONTEXT_VOLUME = 100 + + +CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID + +save_name = "_".join(CHAT_MODEL_ID.split("-"))+"_kg_rag_based_true_false_binary_response.csv" + + +vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL) +embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL) +node_context_df = pd.read_csv(NODE_CONTEXT_PATH) +edge_evidence = False + +def main(): + start_time = time.time() + question_df = pd.read_csv(QUESTION_PATH) + answer_list = [] + for index, row in question_df.iterrows(): + question = row["text"] + context = retrieve_context(row["text"], vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, edge_evidence) + enriched_prompt = "Context: "+ context + "\n" + "Question: "+ question + output = get_GPT_response(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE) + answer_list.append((row["text"], row["label"], output)) + answer_df = pd.DataFrame(answer_list, columns=["question", "label", "llm_answer"]) + answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True) + print("Completed in {} min".format((time.time()-start_time)/60)) + + +if __name__ == "__main__": + main() + diff --git a/kg_rag/rag_based_generation/GPT/run_two_disease_entity_hyperparameter_tuning.py 
b/kg_rag/rag_based_generation/GPT/run_two_disease_entity_hyperparameter_tuning.py
new file mode 100644
index 0000000..043f39d
--- /dev/null
+++ b/kg_rag/rag_based_generation/GPT/run_two_disease_entity_hyperparameter_tuning.py
@@ -0,0 +1,60 @@
+'''
+This script is used for hyperparameter tuning on two-hop graph traversal questions.
+Hyperparameters are 'CONTEXT_VOLUME_LIST' and 'SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL_LIST'
+
+This will run on two-hop graph traversal questions from the csv file and save the result as another csv file.
+
+Before running this script, make sure to configure the filepaths in config.yaml file.
+Command line argument should be either 'gpt-4' or 'gpt-35-turbo'
+'''
+
+from kg_rag.utility import *
+import sys
+
+
+CHAT_MODEL_ID = sys.argv[1]
+
+CONTEXT_VOLUME_LIST = [10, 50, 100, 150, 200]
+SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL_LIST = ["pritamdeka/S-PubMedBert-MS-MARCO", "sentence-transformers/all-MiniLM-L6-v2"]
+SAVE_NAME_LIST = ["pubmedBert_based_two_hop_questions_parameter_tuning_round_{}.csv", "miniLM_based_two_hop_questions_parameter_tuning_round_{}.csv"]
+
+QUESTION_PATH = config_data["TWO_DISEASE_ENTITY_FILE"]
+SYSTEM_PROMPT = system_prompts["TWO_DISEASE_ENTITY_VALIDATION"]
+QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"])
+QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"])
+VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"]
+NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"]
+SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"]
+TEMPERATURE = config_data["LLM_TEMPERATURE"]
+SAVE_PATH = config_data["SAVE_RESULTS_PATH"]
+
+CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID
+
+vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)
+node_context_df = pd.read_csv(NODE_CONTEXT_PATH)
+edge_evidence = False
+
+def main():
+    start_time = time.time()
+    question_df = pd.read_csv(QUESTION_PATH)
+    for transformer_index, sentence_embedding_model_for_context_retrieval in enumerate(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL_LIST):
+        # Load the context-retrieval embedding model selected for this tuning round
+        embedding_function_for_context_retrieval = load_sentence_transformer(sentence_embedding_model_for_context_retrieval)
+        for context_index, context_volume in enumerate(CONTEXT_VOLUME_LIST):
+            answer_list = []
+            for index, row in question_df.iterrows():
+                question = row["text"]
+                context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, context_volume, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, edge_evidence)
+                enriched_prompt = "Context: "+ context + "\n" + "Question: " + question
+                output = get_GPT_response(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE)
+                if not output:
+                    time.sleep(5)
+                answer_list.append((row["disease_1"], row["disease_2"], row["central_nodes"], row["text"], output, context_volume))
+            answer_df = pd.DataFrame(answer_list, columns=["disease_1", "disease_2", "central_nodes_groundTruth", "text", "llm_answer", "context_volume"])
+            save_name = "_".join(CHAT_MODEL_ID.split("-"))+SAVE_NAME_LIST[transformer_index].format(context_index+1)
+            answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True)
+    print("Completed in {} min".format((time.time()-start_time)/60))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/kg_rag/rag_based_generation/GPT/text_generation.py
b/kg_rag/rag_based_generation/GPT/text_generation.py new file mode 100644 index 0000000..f2fcee1 --- /dev/null +++ b/kg_rag/rag_based_generation/GPT/text_generation.py @@ -0,0 +1,61 @@ +''' +This script takes a question from the user in an interactive fashion and returns the KG-RAG based response in real time +Before running this script, make sure to configure config.yaml file. +Command line argument should be either 'gpt-4' or 'gpt-35-turbo' +''' + +from kg_rag.utility import * +import argparse + + + +parser = argparse.ArgumentParser() +parser.add_argument('-g', type=str, default='gpt-35-turbo', help='GPT model selection') +parser.add_argument('-i', type=bool, default=False, help='Flag for interactive mode') +parser.add_argument('-e', type=bool, default=False, help='Flag for showing evidence of association from the graph') +args = parser.parse_args() + +CHAT_MODEL_ID = args.g +INTERACTIVE = args.i +EDGE_EVIDENCE = bool(args.e) + + +SYSTEM_PROMPT = system_prompts["KG_RAG_BASED_TEXT_GENERATION"] +CONTEXT_VOLUME = int(config_data["CONTEXT_VOLUME"]) +QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"]) +QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"]) +VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"] +NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"] +SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"] +SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"] +TEMPERATURE = config_data["LLM_TEMPERATURE"] + + +CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID if openai.api_type == "azure" else None + + +vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL) +embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL) +node_context_df = pd.read_csv(NODE_CONTEXT_PATH) + +def main(): + print(" ") + question = input("Enter your question : ") + if not INTERACTIVE: + print("Retrieving context from SPOKE graph...") + context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, EDGE_EVIDENCE) + print("Here is the KG-RAG based answer:") + print("") + enriched_prompt = "Context: "+ context + "\n" + "Question: " + question + output = get_GPT_response(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE) + stream_out(output) + else: + interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, CHAT_MODEL_ID, EDGE_EVIDENCE, SYSTEM_PROMPT) + + + +if __name__ == "__main__": + main() + + + diff --git a/kg_rag/rag_based_generation/Llama/run_drug_repurposing.py b/kg_rag/rag_based_generation/Llama/run_drug_repurposing.py new file mode 100644 index 0000000..0b8d2f0 --- /dev/null +++ b/kg_rag/rag_based_generation/Llama/run_drug_repurposing.py @@ -0,0 +1,60 @@ +''' +This script takes the drug repurposing style questions from the csv file and save the result as another csv file. +This script makes use of Llama model. +Before running this script, make sure to configure the filepaths in config.yaml file. 
+''' + +from langchain import PromptTemplate, LLMChain +from kg_rag.utility import * +import sys + +QUESTION_PATH = config_data["DRUG_REPURPOSING_PATH"] +SYSTEM_PROMPT = system_prompts["DRUG_REPURPOSING"] +CONTEXT_VOLUME = int(config_data["CONTEXT_VOLUME"]) +QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"]) +QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"]) +VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"] +NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"] +SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"] +SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"] +SAVE_PATH = config_data["SAVE_RESULTS_PATH"] +MODEL_NAME = config_data["LLAMA_MODEL_NAME"] +BRANCH_NAME = config_data["LLAMA_MODEL_BRANCH"] +CACHE_DIR = config_data["LLM_CACHE_DIR"] + + +save_name = "_".join(MODEL_NAME.split("/")[-1].split("-"))+"_drug_repurposing_questions_response.csv" + + +INSTRUCTION = "Context:\n\n{context} \n\nQuestion: {question}" + +vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL) +embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL) +node_context_df = pd.read_csv(NODE_CONTEXT_PATH) + + + +def main(): + start_time = time.time() + llm = llama_model(MODEL_NAME, BRANCH_NAME, CACHE_DIR, max_new_tokens=1024) + template = get_prompt(INSTRUCTION, SYSTEM_PROMPT) + prompt = PromptTemplate(template=template, input_variables=["context", "question"]) + llm_chain = LLMChain(prompt=prompt, llm=llm) + question_df = pd.read_csv(QUESTION_PATH) + answer_list = [] + for index, row in question_df.iterrows(): + question = row["text"] + context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY) + output = llm_chain.run(context=context, question=question) + answer_list.append((row["disease_in_question"], row["refDisease"], row["compoundGroundTruth"], row["text"], output)) + answer_df = pd.DataFrame(answer_list, columns=["disease_in_question", "refDisease", "compoundGroundTruth", "text", "llm_answer"]) + answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True) + print("Completed in {} min".format((time.time()-start_time)/60)) + + + +if __name__ == "__main__": + main() + + + diff --git a/kg_rag/rag_based_generation/Llama/run_mcq_qa.py b/kg_rag/rag_based_generation/Llama/run_mcq_qa.py new file mode 100644 index 0000000..67ae43c --- /dev/null +++ b/kg_rag/rag_based_generation/Llama/run_mcq_qa.py @@ -0,0 +1,61 @@ +''' +This script takes the MCQ style questions from the csv file and save the result as another csv file. +This script makes use of Llama model. +Before running this script, make sure to configure the filepaths in config.yaml file. 
+''' +from tqdm import tqdm +from langchain import PromptTemplate, LLMChain +from kg_rag.utility import * + + +QUESTION_PATH = config_data["MCQ_PATH"] +SYSTEM_PROMPT = system_prompts["MCQ_QUESTION"] +CONTEXT_VOLUME = int(config_data["CONTEXT_VOLUME"]) +QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"]) +QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"]) +VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"] +NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"] +SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"] +SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"] +SAVE_PATH = config_data["SAVE_RESULTS_PATH"] +MODEL_NAME = config_data["LLAMA_MODEL_NAME"] +BRANCH_NAME = config_data["LLAMA_MODEL_BRANCH"] +CACHE_DIR = config_data["LLM_CACHE_DIR"] + +save_name = "_".join(MODEL_NAME.split("/")[-1].split("-"))+"_kg_rag_based_mcq_from_monarch_and_robokop_response.csv" + + +INSTRUCTION = "Context:\n\n{context} \n\nQuestion: {question}" + +vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL) +embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL) +node_context_df = pd.read_csv(NODE_CONTEXT_PATH) +edge_evidence = False + + + +def main(): + start_time = time.time() + llm = llama_model(MODEL_NAME, BRANCH_NAME, CACHE_DIR) + template = get_prompt(INSTRUCTION, SYSTEM_PROMPT) + prompt = PromptTemplate(template=template, input_variables=["context", "question"]) + llm_chain = LLMChain(prompt=prompt, llm=llm) + question_df = pd.read_csv(QUESTION_PATH) + answer_list = [] + for index, row in tqdm(question_df.iterrows()): + question = row["text"] + context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, edge_evidence) + output = llm_chain.run(context=context, question=question) + answer_list.append((row["text"], row["correct_node"], output)) + answer_df = pd.DataFrame(answer_list, columns=["question", "correct_answer", "llm_answer"]) + answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True) + print("Completed in {} min".format((time.time()-start_time)/60)) + + + + + +if __name__ == "__main__": + main() + + diff --git a/kg_rag/rag_based_generation/Llama/run_mcq_qa_medgpt.py b/kg_rag/rag_based_generation/Llama/run_mcq_qa_medgpt.py new file mode 100644 index 0000000..813b601 --- /dev/null +++ b/kg_rag/rag_based_generation/Llama/run_mcq_qa_medgpt.py @@ -0,0 +1,61 @@ +''' +This script takes the MCQ style questions from the csv file and save the result as another csv file. +This script makes use of Llama model. +Before running this script, make sure to configure the filepaths in config.yaml file. 
+''' + +from langchain import PromptTemplate, LLMChain +from kg_rag.utility import * + + +QUESTION_PATH = config_data["MCQ_PATH"] +SYSTEM_PROMPT = system_prompts["MCQ_QUESTION"] +CONTEXT_VOLUME = int(config_data["CONTEXT_VOLUME"]) +QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"]) +QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"]) +VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"] +NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"] +SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"] +SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"] +SAVE_PATH = config_data["SAVE_RESULTS_PATH"] +MODEL_NAME = 'PharMolix/BioMedGPT-LM-7B' +BRANCH_NAME = 'main' +CACHE_DIR = config_data["LLM_CACHE_DIR"] + +save_name = "_".join(MODEL_NAME.split("/")[-1].split("-"))+"_kg_rag_based_mcq_from_monarch_and_robokop_response.csv" + + +INSTRUCTION = "Context:\n\n{context} \n\nQuestion: {question}" + +vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL) +embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL) +node_context_df = pd.read_csv(NODE_CONTEXT_PATH) +edge_evidence = False + + +def main(): + start_time = time.time() + llm = llama_model(MODEL_NAME, BRANCH_NAME, CACHE_DIR) + template = get_prompt(INSTRUCTION, SYSTEM_PROMPT) + prompt = PromptTemplate(template=template, input_variables=["context", "question"]) + llm_chain = LLMChain(prompt=prompt, llm=llm) + question_df = pd.read_csv(QUESTION_PATH) + question_df = question_df.sample(50, random_state=40) + answer_list = [] + for index, row in question_df.iterrows(): + question = row["text"] + context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, edge_evidence) + output = llm_chain.run(context=context, question=question) + answer_list.append((row["text"], row["correct_node"], output)) + answer_df = pd.DataFrame(answer_list, columns=["question", "correct_answer", "llm_answer"]) + answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True) + print("Completed in {} min".format((time.time()-start_time)/60)) + + + + + +if __name__ == "__main__": + main() + + diff --git a/kg_rag/rag_based_generation/Llama/run_true_false_generation.py b/kg_rag/rag_based_generation/Llama/run_true_false_generation.py new file mode 100644 index 0000000..fa1a37d --- /dev/null +++ b/kg_rag/rag_based_generation/Llama/run_true_false_generation.py @@ -0,0 +1,59 @@ +''' +This script takes the True/False style questions from the csv file and save the result as another csv file. +This script makes use of Llama model. +Before running this script, make sure to configure the filepaths in config.yaml file. 
+''' + +from langchain import PromptTemplate, LLMChain +from kg_rag.utility import * +import sys + + +QUESTION_PATH = config_data["TRUE_FALSE_PATH"] +SYSTEM_PROMPT = system_prompts["TRUE_FALSE_QUESTION"] +QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"]) +QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"]) +VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"] +NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"] +SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"] +SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"] +SAVE_PATH = config_data["SAVE_RESULTS_PATH"] +MODEL_NAME = config_data["LLAMA_MODEL_NAME"] +BRANCH_NAME = config_data["LLAMA_MODEL_BRANCH"] +CACHE_DIR = config_data["LLM_CACHE_DIR"] +CONTEXT_VOLUME = 100 +edge_evidence = False + +save_name = "_".join(MODEL_NAME.split("/")[-1].split("-"))+"_kg_rag_based_true_false_binary_response.csv" + + +INSTRUCTION = "Context:\n\n{context} \n\nQuestion: {question}" + +vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL) +embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL) +node_context_df = pd.read_csv(NODE_CONTEXT_PATH) + + +def main(): + start_time = time.time() + llm = llama_model(MODEL_NAME, BRANCH_NAME, CACHE_DIR) + template = get_prompt(INSTRUCTION, SYSTEM_PROMPT) + prompt = PromptTemplate(template=template, input_variables=["context", "question"]) + llm_chain = LLMChain(prompt=prompt, llm=llm) + question_df = pd.read_csv(QUESTION_PATH) + answer_list = [] + for index, row in question_df.iterrows(): + question = row["text"] + context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, edge_evidence) + output = llm_chain.run(context=context, question=question) + answer_list.append((row["text"], row["label"], output)) + answer_df = pd.DataFrame(answer_list, columns=["question", "label", "llm_answer"]) + answer_df.to_csv(os.path.join(SAVE_PATH, save_name), index=False, header=True) + print("Completed in {} min".format((time.time()-start_time)/60)) + + + +if __name__ == "__main__": + main() + +
\ No newline at end of file diff --git a/kg_rag/rag_based_generation/Llama/text_generation.py b/kg_rag/rag_based_generation/Llama/text_generation.py new file mode 100644 index 0000000..2824135 --- /dev/null +++ b/kg_rag/rag_based_generation/Llama/text_generation.py @@ -0,0 +1,60 @@ +from langchain import PromptTemplate, LLMChain +from kg_rag.utility import * +import argparse + + + +parser = argparse.ArgumentParser() +parser.add_argument('-i', type=bool, default=False, help='Flag for interactive mode') +parser.add_argument('-m', type=str, default='method-1', help='Method to choose for Llama model') +parser.add_argument('-e', type=bool, default=False, help='Flag for showing evidence of association from the graph') +args = parser.parse_args() + +INTERACTIVE = args.i +METHOD = args.m +EDGE_EVIDENCE = bool(args.e) + + +SYSTEM_PROMPT = system_prompts["KG_RAG_BASED_TEXT_GENERATION"] +CONTEXT_VOLUME = int(config_data["CONTEXT_VOLUME"]) +QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"]) +QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"]) +VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"] +NODE_CONTEXT_PATH = config_data["NODE_CONTEXT_PATH"] +SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"] +SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL"] +MODEL_NAME = config_data["LLAMA_MODEL_NAME"] +BRANCH_NAME = config_data["LLAMA_MODEL_BRANCH"] +CACHE_DIR = config_data["LLM_CACHE_DIR"] + + +INSTRUCTION = "Context:\n\n{context} \n\nQuestion: {question}" + +vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL) +embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL) +node_context_df = pd.read_csv(NODE_CONTEXT_PATH) + +def main(): + print(" ") + question = input("Enter your question : ") + if not INTERACTIVE: + template = get_prompt(INSTRUCTION, SYSTEM_PROMPT) + prompt = PromptTemplate(template=template, input_variables=["context", "question"]) + llm = llama_model(MODEL_NAME, BRANCH_NAME, CACHE_DIR, stream=True, method=METHOD) + llm_chain = LLMChain(prompt=prompt, llm=llm) + print("Retrieving context from SPOKE graph...") + context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, EDGE_EVIDENCE) + print("Here is the KG-RAG based answer using Llama:") + print("") + output = llm_chain.run(context=context, question=question) + else: + interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, "llama", EDGE_EVIDENCE, SYSTEM_PROMPT, llama_method=METHOD) + + + + + + + +if __name__ == "__main__": + main() diff --git a/kg_rag/run_setup.py b/kg_rag/run_setup.py new file mode 100644 index 0000000..04c856c --- /dev/null +++ b/kg_rag/run_setup.py @@ -0,0 +1,77 @@ +import os +from kg_rag.utility import config_data + +def download_llama(method): + from kg_rag.utility import llama_model + try: + llama_model(config_data["LLAMA_MODEL_NAME"], config_data["LLAMA_MODEL_BRANCH"], config_data["LLM_CACHE_DIR"], method=method) + print("Model is successfully downloaded to the provided cache directory!") + except: + print("Model is not downloaded! 
Make sure the above mentioned conditions are satisfied") + + +print("") +print("Starting to set up KG-RAG ...") +print("") + +# user_input = input("Did you update the config.yaml file with all necessary configurations (such as GPT .env path, vectorDB file paths, other file paths)? Enter Y or N: ") +# print("") +# if user_input == "Y": +if True: + print("Checking disease vectorDB ...") + print("The current VECTOR_DB_PATH is ", config_data["VECTOR_DB_PATH"]) + try: + if os.path.exists(config_data["VECTOR_DB_PATH"]): + print("vectorDB already exists!") + else: + print("Creating vectorDB ...") + from kg_rag.vectorDB.create_vectordb import create_vectordb + create_vectordb() + print("Congratulations! The disease database is completed.") + except: + print("Double check the path that was given in VECTOR_DB_PATH of config.yaml file.") + ''' + print("") + user_input_1 = input("Do you want to install Llama model? Enter Y or N: ") + if user_input_1 == "Y": + user_input_2 = input("Did you update the config.yaml file with proper configuration for downloading Llama model? Enter Y or N: ") + if user_input_2 == "Y": + user_input_3 = input("Are you using official Llama model from Meta? Enter Y or N: ") + if user_input_3 == "Y": + user_input_4 = input("Did you get access to use the model? Enter Y or N: ") + if user_input_4 == "Y": + download_llama() + print("Congratulations! Setup is completed.") + else: + print("Aborting!") + else: + download_llama(method='method-1') + user_input_5 = input("Did you get a message like 'Model is not downloaded!'? Enter Y or N: ") + if user_input_5 == "N": + print("Congratulations! Setup is completed.") + else: + download_llama(method='method-2') + user_input_6 = input("Did you get a message like 'Model is not downloaded!'? Enter Y or N: ") + if user_input_6 == "N": + print(""" + IMPORTANT : + Llama model was downloaded using 'LlamaTokenizer' instead of 'AutoTokenizer' method. + So, when you run text generation script, please provide an extra command line argument '-m method-2'. + For example: + python -m kg_rag.rag_based_generation.Llama.text_generation -m method-2 + """) + print("Congratulations! Setup is completed.") + else: + print("We have now tried two methods to download Llama. If they both do not work, then please check the Llama configuration requirement in the huggingface model card page. Aborting!") + else: + print("Aborting!") + else: + print("No problem. Llama will get installed on-the-fly when you run the model for the first time.") + print("Congratulations! Setup is completed.") + ''' +else: + print("As the first step, update config.yaml file and then run this python script again.") + + + + diff --git a/kg_rag/test/__init__.py b/kg_rag/test/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/kg_rag/test/__init__.py diff --git a/kg_rag/test/test_vectordb.py b/kg_rag/test/test_vectordb.py new file mode 100644 index 0000000..9365971 --- /dev/null +++ b/kg_rag/test/test_vectordb.py @@ -0,0 +1,42 @@ +from kg_rag.utility import * +import sys + +VECTOR_DB_PATH = config_data["VECTOR_DB_PATH"] +SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data["SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL"] + +print("Testing vectorDB loading ...") +print("") +try: + vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL) + print("vectorDB is loaded succesfully!") +except: + print("vectorDB is not loaded. 
Check the path given in 'VECTOR_DB_PATH' of config.yaml") + print("") + sys.exit(1) +try: + print("") + print("Testing entity extraction ...") + print("") + entity = "psoriasis" + print("Inputting '{}' as the entity to test ...".format(entity)) + print("") + node_search_result = vectorstore.similarity_search_with_score(entity, k=1) + extracted_entity = node_search_result[0][0].page_content + print("Extracted entity is '{}'".format(extracted_entity)) + print("") + if extracted_entity == "psoriasis": + print("Entity extraction is successful!") + print("") + print("vectorDB is correctly populated and is good to go!") + else: + print("Entity extraction is not successful. Make sure vectorDB is populated correctly. Refer 'How to run KG-RAG' Step 5") + print("") + sys.exit(1) +except: + print("Entity extraction is not successful. Make sure vectorDB is populated correctly. Refer 'How to run KG-RAG' Step 5") + print("") + sys.exit(1) + + + + diff --git a/kg_rag/utility.py b/kg_rag/utility.py new file mode 100644 index 0000000..975367b --- /dev/null +++ b/kg_rag/utility.py @@ -0,0 +1,443 @@ +import pandas as pd +import numpy as np +from sklearn.metrics.pairwise import cosine_similarity +from joblib import Memory +import json +import openai +import os +import sys +from tenacity import retry, stop_after_attempt, wait_random_exponential +import time +from dotenv import load_dotenv, find_dotenv +import torch +from langchain import HuggingFacePipeline +from langchain.vectorstores import Chroma +from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings +from langchain.text_splitter import RecursiveCharacterTextSplitter +from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, TextStreamer, GPTQConfig +from kg_rag.config_loader import * +import ast +import requests +import google.generativeai as genai + + +memory = Memory("cachegpt", verbose=0) + +# Config openai library +config_file = config_data['GPT_CONFIG_FILE'] +load_dotenv(config_file) +api_key = os.environ.get('API_KEY') +api_version = os.environ.get('API_VERSION') +resource_endpoint = os.environ.get('RESOURCE_ENDPOINT') +openai.api_type = config_data['GPT_API_TYPE'] +openai.api_key = api_key +if resource_endpoint: + openai.api_base = resource_endpoint +if api_version: + openai.api_version = api_version + +genai.configure(api_key=os.environ["GOOGLE_API_KEY"]) + + +torch.cuda.empty_cache() +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n" + + +def get_spoke_api_resp(base_uri, end_point, params=None): + uri = base_uri + end_point + if params: + return requests.get(uri, params=params) + else: + return requests.get(uri) + + +@retry(wait=wait_random_exponential(min=10, max=30), stop=stop_after_attempt(5)) +def get_context_using_spoke_api(node_value): + type_end_point = "/api/v1/types" + result = get_spoke_api_resp(config_data['BASE_URI'], type_end_point) + data_spoke_types = result.json() + node_types = list(data_spoke_types["nodes"].keys()) + edge_types = list(data_spoke_types["edges"].keys()) + node_types_to_remove = ["DatabaseTimestamp", "Version"] + filtered_node_types = [node_type for node_type in node_types if node_type not in node_types_to_remove] + api_params = { + 'node_filters' : filtered_node_types, + 'edge_filters': edge_types, + 'cutoff_Compound_max_phase': config_data['cutoff_Compound_max_phase'], + 'cutoff_Protein_source': config_data['cutoff_Protein_source'], + 'cutoff_DaG_diseases_sources': config_data['cutoff_DaG_diseases_sources'], + 'cutoff_DaG_textmining': 
config_data['cutoff_DaG_textmining'],
+        'cutoff_CtD_phase': config_data['cutoff_CtD_phase'],
+        'cutoff_PiP_confidence': config_data['cutoff_PiP_confidence'],
+        'cutoff_ACTeG_level': config_data['cutoff_ACTeG_level'],
+        'cutoff_DpL_average_prevalence': config_data['cutoff_DpL_average_prevalence'],
+        'depth': config_data['depth']
+    }
+    node_type = "Disease"
+    attribute = "name"
+    nbr_end_point = "/api/v1/neighborhood/{}/{}/{}".format(node_type, attribute, node_value)
+    result = get_spoke_api_resp(config_data['BASE_URI'], nbr_end_point, params=api_params)
+    node_context = result.json()
+    nbr_nodes = []
+    nbr_edges = []
+    for item in node_context:
+        if "_" not in item["data"]["neo4j_type"]:
+            try:
+                if item["data"]["neo4j_type"] == "Protein":
+                    nbr_nodes.append((item["data"]["neo4j_type"], item["data"]["id"], item["data"]["properties"]["description"]))
+                else:
+                    nbr_nodes.append((item["data"]["neo4j_type"], item["data"]["id"], item["data"]["properties"]["name"]))
+            except:
+                nbr_nodes.append((item["data"]["neo4j_type"], item["data"]["id"], item["data"]["properties"]["identifier"]))
+        elif "_" in item["data"]["neo4j_type"]:
+            try:
+                provenance = ", ".join(item["data"]["properties"]["sources"])
+            except:
+                try:
+                    provenance = item["data"]["properties"]["source"]
+                    if isinstance(provenance, list):
+                        provenance = ", ".join(provenance)
+                except:
+                    try:
+                        preprint_list = ast.literal_eval(item["data"]["properties"]["preprint_list"])
+                        if len(preprint_list) > 0:
+                            provenance = ", ".join(preprint_list)
+                        else:
+                            pmid_list = ast.literal_eval(item["data"]["properties"]["pmid_list"])
+                            # materialize the map so the len() check below works in Python 3
+                            pmid_list = list(map(lambda x: "pubmedId:" + x, pmid_list))
+                            if len(pmid_list) > 0:
+                                provenance = ", ".join(pmid_list)
+                            else:
+                                provenance = "Based on data from Institute For Systems Biology (ISB)"
+                    except:
+                        provenance = "SPOKE-KG"
+            try:
+                evidence = item["data"]["properties"]
+            except:
+                evidence = None
+            nbr_edges.append((item["data"]["source"], item["data"]["neo4j_type"], item["data"]["target"], provenance, evidence))
+    nbr_nodes_df = pd.DataFrame(nbr_nodes, columns=["node_type", "node_id", "node_name"])
+    nbr_edges_df = pd.DataFrame(nbr_edges, columns=["source", "edge_type", "target", "provenance", "evidence"])
+    merge_1 = pd.merge(nbr_edges_df, nbr_nodes_df, left_on="source", right_on="node_id").drop("node_id", axis=1)
+    merge_1.loc[:, "node_name"] = merge_1.node_type + " " + merge_1.node_name
+    merge_1.drop(["source", "node_type"], axis=1, inplace=True)
+    merge_1 = merge_1.rename(columns={"node_name": "source"})
+    merge_2 = pd.merge(merge_1, nbr_nodes_df, left_on="target", right_on="node_id").drop("node_id", axis=1)
+    merge_2.loc[:, "node_name"] = merge_2.node_type + " " + merge_2.node_name
+    merge_2.drop(["target", "node_type"], axis=1, inplace=True)
+    merge_2 = merge_2.rename(columns={"node_name": "target"})
+    merge_2 = merge_2[["source", "edge_type", "target", "provenance", "evidence"]]
+    merge_2.loc[:, "predicate"] = merge_2.edge_type.apply(lambda x: x.split("_")[0])
+    merge_2.loc[:, "context"] = merge_2.source + " " + merge_2.predicate.str.lower() + " " + merge_2.target + " and Provenance of this association is " + merge_2.provenance + "."
+    context = merge_2.context.str.cat(sep=' ')
+    context += node_value + " has a " + node_context[0]["data"]["properties"]["source"] + " identifier of " + node_context[0]["data"]["properties"]["identifier"] + " and Provenance of this is from " + node_context[0]["data"]["properties"]["source"] + "." 
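+    # At this point `context` is one long string of natural-language sentences (one per
+    # SPOKE association, each ending with its provenance), and `merge_2` is the matching
+    # DataFrame with source / edge_type / target / provenance / evidence plus the derived
+    # predicate and context columns.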
+    return context, merge_2
+
+# if edge_evidence:
+#     merge_2.loc[:, "context"] = merge_2.source + " " + merge_2.predicate.str.lower() + " " + merge_2.target + " and Provenance of this association is " + merge_2.provenance + " and attributes associated with this association is in the following JSON format:\n " + merge_2.evidence.astype('str') + "\n\n"
+# else:
+#     merge_2.loc[:, "context"] = merge_2.source + " " + merge_2.predicate.str.lower() + " " + merge_2.target + " and Provenance of this association is " + merge_2.provenance + ". "
+# context = merge_2.context.str.cat(sep=' ')
+# context += node_value + " has a " + node_context[0]["data"]["properties"]["source"] + " identifier of " + node_context[0]["data"]["properties"]["identifier"] + " and Provenance of this is from " + node_context[0]["data"]["properties"]["source"] + "."
+# return context
+
+
+def get_prompt(instruction, new_system_prompt):
+    system_prompt = B_SYS + new_system_prompt + E_SYS
+    prompt_template = B_INST + system_prompt + instruction + E_INST
+    return prompt_template
+
+
+def llama_model(model_name, branch_name, cache_dir, temperature=0, top_p=1, max_new_tokens=512, stream=False, method='method-1'):
+    if method == 'method-1':
+        tokenizer = AutoTokenizer.from_pretrained(model_name, revision=branch_name, cache_dir=cache_dir)
+        model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', torch_dtype=torch.float16, revision=branch_name, cache_dir=cache_dir)
+    elif method == 'method-2':
+        import transformers
+        # "hf_WbtWB..." is a truncated placeholder; a real Hugging Face access token should be supplied (e.g. via config or environment)
+        tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name, revision=branch_name, cache_dir=cache_dir, legacy=False, token="hf_WbtWB...")
+        model = transformers.LlamaForCausalLM.from_pretrained(model_name, device_map='auto', torch_dtype=torch.float16, revision=branch_name, cache_dir=cache_dir, token="hf_WbtWB...")
+    if not stream:
+        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, torch_dtype=torch.bfloat16, device_map="auto", max_new_tokens=max_new_tokens, do_sample=True)
+    else:
+        streamer = TextStreamer(tokenizer)
+        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, torch_dtype=torch.bfloat16, device_map="auto", max_new_tokens=max_new_tokens, do_sample=True, streamer=streamer)
+    llm = HuggingFacePipeline(pipeline=pipe, model_kwargs={"temperature": temperature, "top_p": top_p})
+    return llm
+
+
+@retry(wait=wait_random_exponential(min=10, max=30), stop=stop_after_attempt(5))
+def fetch_GPT_response(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature=0):
+    response = openai.ChatCompletion.create(
+        temperature=temperature,
+        # deployment_id=chat_deployment_id,
+        model=chat_model_id,
+        messages=[
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": instruction}
+        ]
+    )
+    # only index into choices when the response actually contains a non-empty choices list
+    if 'choices' in response \
+            and isinstance(response['choices'], list) \
+            and len(response['choices']) > 0 \
+            and 'message' in response['choices'][0] \
+            and 'content' in response['choices'][0]['message']:
+        return response['choices'][0]['message']['content']
+    else:
+        return 'Unexpected response'
+
+
+@memory.cache
+def get_GPT_response(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature=0):
+    res = fetch_GPT_response(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature)
+    return res
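+
+
+# Note: get_GPT_response is fetch_GPT_response wrapped in joblib disk caching (the "cachegpt"
+# Memory defined at the top of this module), so repeated calls with identical arguments are
+# answered from the local cache rather than re-hitting the API, while fetch_GPT_response
+# itself retries up to five times with 10-30 s random exponential back-off via tenacity.
+# Illustrative call: get_GPT_response("What is psoriasis?", system_prompts["PROMPT_BASED_TEXT_GENERATION"], "gpt-4o-mini", "gpt-4o-mini", temperature=0)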
+
+
+@retry(wait=wait_random_exponential(min=10, max=30), stop=stop_after_attempt(5))
+def fetch_Gemini_response(instruction, system_prompt, temperature=0.0):
+    model = genai.GenerativeModel(
+        model_name="gemini-2.0-flash",
+        system_instruction=system_prompt,
+    )
+    # forward the requested temperature to the Gemini API (assumes google-generativeai exposes genai.GenerationConfig)
+    response = model.generate_content(instruction, generation_config=genai.GenerationConfig(temperature=temperature))
+    return response.text
+
+
+@memory.cache
+def get_Gemini_response(instruction, system_prompt, temperature=0.0):
+    res = fetch_Gemini_response(instruction, system_prompt, temperature)
+    return res
+
+
+def stream_out(output):
+    # guard against a zero range() step when the output is very short
+    CHUNK_SIZE = max(1, int(round(len(output)/50)))
+    SLEEP_TIME = 0.1
+    for i in range(0, len(output), CHUNK_SIZE):
+        print(output[i:i+CHUNK_SIZE], end='')
+        sys.stdout.flush()
+        time.sleep(SLEEP_TIME)
+    print("\n")
+
+
+def get_gpt35():
+    chat_model_id = 'gpt-35-turbo'
+    chat_deployment_id = chat_model_id if openai.api_type == 'azure' else None
+    return chat_model_id, chat_deployment_id
+
+
+def get_gpt4o_mini():
+    chat_model_id = 'gpt-4o-mini'
+    chat_deployment_id = chat_model_id if openai.api_type == 'azure' else None
+    return chat_model_id, chat_deployment_id
+
+
+def get_gemini():
+    chat_model_id = 'gemini-2.0-flash'
+    chat_deployment_id = chat_model_id if openai.api_type == 'azure' else None
+    return chat_model_id, chat_deployment_id
+
+
+def disease_entity_extractor(text):
+    chat_model_id, chat_deployment_id = get_gpt35()
+    resp = get_GPT_response(text, system_prompts["DISEASE_ENTITY_EXTRACTION"], chat_model_id, chat_deployment_id, temperature=0)
+    try:
+        entity_dict = json.loads(resp)
+        return entity_dict["Diseases"]
+    except:
+        return None
+
+
+def disease_entity_extractor_v2(text, model_id):
+    # entity extraction currently supports only the Gemini model
+    assert model_id in ("gemini-2.0-flash",)
+    prompt_updated = system_prompts["DISEASE_ENTITY_EXTRACTION"] + "\n" + "Sentence : " + text
+    resp = get_Gemini_response(prompt_updated, system_prompts["DISEASE_ENTITY_EXTRACTION"], temperature=0.0)
+    # strip a Markdown code fence if the model wrapped its JSON in one
+    if resp.startswith("```json\n"):
+        resp = resp.replace("```json\n", "", 1)
+    if resp.endswith("\n```"):
+        resp = resp[:-len("\n```")]
+    try:
+        entity_dict = json.loads(resp)
+        return entity_dict["Diseases"]
+    except:
+        return None
+
+
+def load_sentence_transformer(sentence_embedding_model):
+    return SentenceTransformerEmbeddings(model_name=sentence_embedding_model)
+
+
+def load_chroma(vector_db_path, sentence_embedding_model):
+    embedding_function = load_sentence_transformer(sentence_embedding_model)
+    return Chroma(persist_directory=vector_db_path, embedding_function=embedding_function)
+
+
+def retrieve_context(question, vectorstore, embedding_function, node_context_df, context_volume, context_sim_threshold, context_sim_min_threshold, edge_evidence, model_id="gpt-3.5-turbo", api=False):
+    # note: disease_entity_extractor_v2 currently only accepts "gemini-2.0-flash", so callers should pass that model_id explicitly
+    print("question:", question)
+    entities = disease_entity_extractor_v2(question, model_id)
+    print("entities:", entities)
+    node_hits = []
+    if entities:
+        max_number_of_high_similarity_context_per_node = int(context_volume/len(entities))
+        for entity in entities:
+            node_search_result = vectorstore.similarity_search_with_score(entity, k=1)
+            node_hits.append(node_search_result[0][0].page_content)
+        question_embedding = embedding_function.embed_query(question)
+        node_context_extracted = ""
+        for node_name in node_hits:
+            if not api:
+                node_context = node_context_df[node_context_df.node_name == node_name].node_context.values[0]
+            else:
+                node_context, context_table = get_context_using_spoke_api(node_name)
+            node_context_list = node_context.split(". 
") + node_context_embeddings = embedding_function.embed_documents(node_context_list) + similarities = [cosine_similarity(np.array(question_embedding).reshape(1, -1), np.array(node_context_embedding).reshape(1, -1)) for node_context_embedding in node_context_embeddings] + similarities = sorted([(e, i) for i, e in enumerate(similarities)], reverse=True) + percentile_threshold = np.percentile([s[0] for s in similarities], context_sim_threshold) + high_similarity_indices = [s[1] for s in similarities if s[0] > percentile_threshold and s[0] > context_sim_min_threshold] + if len(high_similarity_indices) > max_number_of_high_similarity_context_per_node: + high_similarity_indices = high_similarity_indices[:max_number_of_high_similarity_context_per_node] + high_similarity_context = [node_context_list[index] for index in high_similarity_indices] + if edge_evidence: + high_similarity_context = list(map(lambda x:x+'.', high_similarity_context)) + context_table = context_table[context_table.context.isin(high_similarity_context)] + context_table.loc[:, "context"] = context_table.source + " " + context_table.predicate.str.lower() + " " + context_table.target + " and Provenance of this association is " + context_table.provenance + " and attributes associated with this association is in the following JSON format:\n " + context_table.evidence.astype('str') + "\n\n" + node_context_extracted += context_table.context.str.cat(sep=' ') + else: + node_context_extracted += ". ".join(high_similarity_context) + node_context_extracted += ". " + return node_context_extracted + else: + node_hits = vectorstore.similarity_search_with_score(question, k=5) + max_number_of_high_similarity_context_per_node = int(context_volume/5) + question_embedding = embedding_function.embed_query(question) + node_context_extracted = "" + for node in node_hits: + node_name = node[0].page_content + if not api: + node_context = node_context_df[node_context_df.node_name == node_name].node_context.values[0] + else: + node_context, context_table = get_context_using_spoke_api(node_name) + node_context_list = node_context.split(". 
") + node_context_embeddings = embedding_function.embed_documents(node_context_list) + similarities = [cosine_similarity(np.array(question_embedding).reshape(1, -1), np.array(node_context_embedding).reshape(1, -1)) for node_context_embedding in node_context_embeddings] + similarities = sorted([(e, i) for i, e in enumerate(similarities)], reverse=True) + percentile_threshold = np.percentile([s[0] for s in similarities], context_sim_threshold) + high_similarity_indices = [s[1] for s in similarities if s[0] > percentile_threshold and s[0] > context_sim_min_threshold] + if len(high_similarity_indices) > max_number_of_high_similarity_context_per_node: + high_similarity_indices = high_similarity_indices[:max_number_of_high_similarity_context_per_node] + high_similarity_context = [node_context_list[index] for index in high_similarity_indices] + if edge_evidence: + high_similarity_context = list(map(lambda x:x+'.', high_similarity_context)) + context_table = context_table[context_table.context.isin(high_similarity_context)] + context_table.loc[:, "context"] = context_table.source + " " + context_table.predicate.str.lower() + " " + context_table.target + " and Provenance of this association is " + context_table.provenance + " and attributes associated with this association is in the following JSON format:\n " + context_table.evidence.astype('str') + "\n\n" + node_context_extracted += context_table.context.str.cat(sep=' ') + else: + node_context_extracted += ". ".join(high_similarity_context) + node_context_extracted += ". " + return node_context_extracted + + +def interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, llm_type, edge_evidence, system_prompt, api=True, llama_method="method-1"): + print(" ") + input("Press enter for Step 1 - Disease entity extraction using GPT-3.5-Turbo") + print("Processing ...") + entities = disease_entity_extractor_v2(question, "gpt-4o-mini") + max_number_of_high_similarity_context_per_node = int(config_data["CONTEXT_VOLUME"]/len(entities)) + print("Extracted entity from the prompt = '{}'".format(", ".join(entities))) + print(" ") + + input("Press enter for Step 2 - Match extracted Disease entity to SPOKE nodes") + print("Finding vector similarity ...") + node_hits = [] + for entity in entities: + node_search_result = vectorstore.similarity_search_with_score(entity, k=1) + node_hits.append(node_search_result[0][0].page_content) + print("Matched entities from SPOKE = '{}'".format(", ".join(node_hits))) + print(" ") + + input("Press enter for Step 3 - Context extraction from SPOKE") + node_context = [] + for node_name in node_hits: + if not api: + node_context.append(node_context_df[node_context_df.node_name == node_name].node_context.values[0]) + else: + context, context_table = get_context_using_spoke_api(node_name) + node_context.append(context) + print("Extracted Context is : ") + print(". ".join(node_context)) + print(" ") + + input("Press enter for Step 4 - Context pruning") + question_embedding = embedding_function_for_context_retrieval.embed_query(question) + node_context_extracted = "" + for node_name in node_hits: + if not api: + node_context = node_context_df[node_context_df.node_name == node_name].node_context.values[0] + else: + node_context, context_table = get_context_using_spoke_api(node_name) + node_context_list = node_context.split(". 
") + node_context_embeddings = embedding_function_for_context_retrieval.embed_documents(node_context_list) + similarities = [cosine_similarity(np.array(question_embedding).reshape(1, -1), np.array(node_context_embedding).reshape(1, -1)) for node_context_embedding in node_context_embeddings] + similarities = sorted([(e, i) for i, e in enumerate(similarities)], reverse=True) + percentile_threshold = np.percentile([s[0] for s in similarities], config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"]) + high_similarity_indices = [s[1] for s in similarities if s[0] > percentile_threshold and s[0] > config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"]] + if len(high_similarity_indices) > max_number_of_high_similarity_context_per_node: + high_similarity_indices = high_similarity_indices[:max_number_of_high_similarity_context_per_node] + high_similarity_context = [node_context_list[index] for index in high_similarity_indices] + if edge_evidence: + high_similarity_context = list(map(lambda x:x+'.', high_similarity_context)) + context_table = context_table[context_table.context.isin(high_similarity_context)] + context_table.loc[:, "context"] = context_table.source + " " + context_table.predicate.str.lower() + " " + context_table.target + " and Provenance of this association is " + context_table.provenance + " and attributes associated with this association is in the following JSON format:\n " + context_table.evidence.astype('str') + "\n\n" + node_context_extracted += context_table.context.str.cat(sep=' ') + else: + node_context_extracted += ". ".join(high_similarity_context) + node_context_extracted += ". " + print("Pruned Context is : ") + print(node_context_extracted) + print(" ") + + input("Press enter for Step 5 - LLM prompting") + print("Prompting ", llm_type) + if llm_type == "llama": + from langchain import PromptTemplate, LLMChain + template = get_prompt("Context:\n\n{context} \n\nQuestion: {question}", system_prompt) + prompt = PromptTemplate(template=template, input_variables=["context", "question"]) + llm = llama_model(config_data["LLAMA_MODEL_NAME"], config_data["LLAMA_MODEL_BRANCH"], config_data["LLM_CACHE_DIR"], stream=True, method=llama_method) + llm_chain = LLMChain(prompt=prompt, llm=llm) + output = llm_chain.run(context=node_context_extracted, question=question) + elif "gpt" in llm_type: + enriched_prompt = "Context: "+ node_context_extracted + "\n" + "Question: " + question + output = get_GPT_response(enriched_prompt, system_prompt, llm_type, llm_type, temperature=config_data["LLM_TEMPERATURE"]) + stream_out(output) diff --git a/kg_rag/vectorDB/__init__.py b/kg_rag/vectorDB/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/kg_rag/vectorDB/__init__.py diff --git a/kg_rag/vectorDB/create_vectordb.py b/kg_rag/vectorDB/create_vectordb.py new file mode 100644 index 0000000..b02c9ff --- /dev/null +++ b/kg_rag/vectorDB/create_vectordb.py @@ -0,0 +1,35 @@ +import pickle +from kg_rag.utility import RecursiveCharacterTextSplitter, Chroma, SentenceTransformerEmbeddings, config_data, time + + +DATA_PATH = config_data["VECTOR_DB_DISEASE_ENTITY_PATH"] +VECTOR_DB_NAME = config_data["VECTOR_DB_PATH"] +CHUNK_SIZE = int(config_data["VECTOR_DB_CHUNK_SIZE"]) +CHUNK_OVERLAP = int(config_data["VECTOR_DB_CHUNK_OVERLAP"]) +BATCH_SIZE = int(config_data["VECTOR_DB_BATCH_SIZE"]) +SENTENCE_EMBEDDING_MODEL = config_data["VECTOR_DB_SENTENCE_EMBEDDING_MODEL"] + + +def load_data(): + with open(DATA_PATH, "rb") as f: + data = pickle.load(f) + metadata_list = list(map(lambda 
x:{"source": x + " from SPOKE knowledge graph"}, data)) + return data, metadata_list + +def create_vectordb(): + start_time = time.time() + data, metadata_list = load_data() + text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP) + docs = text_splitter.create_documents(data, metadatas=metadata_list) + batches = [docs[i:i + BATCH_SIZE] for i in range(0, len(docs), BATCH_SIZE)] + vectorstore = Chroma(embedding_function=SentenceTransformerEmbeddings(model_name=SENTENCE_EMBEDDING_MODEL), + persist_directory=VECTOR_DB_NAME) + for batch in batches: + vectorstore.add_documents(documents=batch) + end_time = round((time.time() - start_time)/(60), 2) + print("VectorDB is created in {} mins".format(end_time)) + + +if __name__ == "__main__": + create_vectordb() + |
