From 73c194f304f827b55081b15524479f82a1b7d94c Mon Sep 17 00:00:00 2001 From: maszhongming Date: Tue, 16 Sep 2025 15:15:29 -0500 Subject: Initial commit --- kg_rag/utility.py | 443 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 443 insertions(+) create mode 100644 kg_rag/utility.py (limited to 'kg_rag/utility.py') diff --git a/kg_rag/utility.py b/kg_rag/utility.py new file mode 100644 index 0000000..975367b --- /dev/null +++ b/kg_rag/utility.py @@ -0,0 +1,443 @@ +import pandas as pd +import numpy as np +from sklearn.metrics.pairwise import cosine_similarity +from joblib import Memory +import json +import openai +import os +import sys +from tenacity import retry, stop_after_attempt, wait_random_exponential +import time +from dotenv import load_dotenv, find_dotenv +import torch +from langchain import HuggingFacePipeline +from langchain.vectorstores import Chroma +from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings +from langchain.text_splitter import RecursiveCharacterTextSplitter +from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, TextStreamer, GPTQConfig +from kg_rag.config_loader import * +import ast +import requests +import google.generativeai as genai + + +memory = Memory("cachegpt", verbose=0) + +# Config openai library +config_file = config_data['GPT_CONFIG_FILE'] +load_dotenv(config_file) +api_key = os.environ.get('API_KEY') +api_version = os.environ.get('API_VERSION') +resource_endpoint = os.environ.get('RESOURCE_ENDPOINT') +openai.api_type = config_data['GPT_API_TYPE'] +openai.api_key = api_key +if resource_endpoint: + openai.api_base = resource_endpoint +if api_version: + openai.api_version = api_version + +genai.configure(api_key=os.environ["GOOGLE_API_KEY"]) + + +torch.cuda.empty_cache() +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" + + +def get_spoke_api_resp(base_uri, end_point, params=None): + uri = base_uri + end_point + if params: + return requests.get(uri, params=params) + else: + return requests.get(uri) + + +@retry(wait=wait_random_exponential(min=10, max=30), stop=stop_after_attempt(5)) +def get_context_using_spoke_api(node_value): + type_end_point = "/api/v1/types" + result = get_spoke_api_resp(config_data['BASE_URI'], type_end_point) + data_spoke_types = result.json() + node_types = list(data_spoke_types["nodes"].keys()) + edge_types = list(data_spoke_types["edges"].keys()) + node_types_to_remove = ["DatabaseTimestamp", "Version"] + filtered_node_types = [node_type for node_type in node_types if node_type not in node_types_to_remove] + api_params = { + 'node_filters' : filtered_node_types, + 'edge_filters': edge_types, + 'cutoff_Compound_max_phase': config_data['cutoff_Compound_max_phase'], + 'cutoff_Protein_source': config_data['cutoff_Protein_source'], + 'cutoff_DaG_diseases_sources': config_data['cutoff_DaG_diseases_sources'], + 'cutoff_DaG_textmining': config_data['cutoff_DaG_textmining'], + 'cutoff_CtD_phase': config_data['cutoff_CtD_phase'], + 'cutoff_PiP_confidence': config_data['cutoff_PiP_confidence'], + 'cutoff_ACTeG_level': config_data['cutoff_ACTeG_level'], + 'cutoff_DpL_average_prevalence': config_data['cutoff_DpL_average_prevalence'], + 'depth' : config_data['depth'] + } + node_type = "Disease" + attribute = "name" + nbr_end_point = "/api/v1/neighborhood/{}/{}/{}".format(node_type, attribute, node_value) + result = get_spoke_api_resp(config_data['BASE_URI'], nbr_end_point, params=api_params) + node_context = result.json() + nbr_nodes = [] + nbr_edges = [] + for item in node_context: + if "_" not in item["data"]["neo4j_type"]: + try: + if item["data"]["neo4j_type"] == "Protein": + nbr_nodes.append((item["data"]["neo4j_type"], item["data"]["id"], item["data"]["properties"]["description"])) + else: + nbr_nodes.append((item["data"]["neo4j_type"], item["data"]["id"], item["data"]["properties"]["name"])) + except: + nbr_nodes.append((item["data"]["neo4j_type"], item["data"]["id"], item["data"]["properties"]["identifier"])) + elif "_" in item["data"]["neo4j_type"]: + try: + provenance = ", ".join(item["data"]["properties"]["sources"]) + except: + try: + provenance = item["data"]["properties"]["source"] + if isinstance(provenance, list): + provenance = ", ".join(provenance) + except: + try: + preprint_list = ast.literal_eval(item["data"]["properties"]["preprint_list"]) + if len(preprint_list) > 0: + provenance = ", ".join(preprint_list) + else: + pmid_list = ast.literal_eval(item["data"]["properties"]["pmid_list"]) + pmid_list = map(lambda x:"pubmedId:"+x, pmid_list) + if len(pmid_list) > 0: + provenance = ", ".join(pmid_list) + else: + provenance = "Based on data from Institute For Systems Biology (ISB)" + except: + provenance = "SPOKE-KG" + try: + evidence = item["data"]["properties"] + except: + evidence = None + nbr_edges.append((item["data"]["source"], item["data"]["neo4j_type"], item["data"]["target"], provenance, evidence)) + nbr_nodes_df = pd.DataFrame(nbr_nodes, columns=["node_type", "node_id", "node_name"]) + nbr_edges_df = pd.DataFrame(nbr_edges, columns=["source", "edge_type", "target", "provenance", "evidence"]) + merge_1 = pd.merge(nbr_edges_df, nbr_nodes_df, left_on="source", right_on="node_id").drop("node_id", axis=1) + merge_1.loc[:,"node_name"] = merge_1.node_type + " " + merge_1.node_name + merge_1.drop(["source", "node_type"], axis=1, inplace=True) + merge_1 = merge_1.rename(columns={"node_name":"source"}) + merge_2 = pd.merge(merge_1, nbr_nodes_df, left_on="target", right_on="node_id").drop("node_id", axis=1) + merge_2.loc[:,"node_name"] = merge_2.node_type + " " + merge_2.node_name + merge_2.drop(["target", "node_type"], axis=1, inplace=True) + merge_2 = merge_2.rename(columns={"node_name":"target"}) + merge_2 = merge_2[["source", "edge_type", "target", "provenance", "evidence"]] + merge_2.loc[:, "predicate"] = merge_2.edge_type.apply(lambda x:x.split("_")[0]) + merge_2.loc[:, "context"] = merge_2.source + " " + merge_2.predicate.str.lower() + " " + merge_2.target + " and Provenance of this association is " + merge_2.provenance + "." + context = merge_2.context.str.cat(sep=' ') + context += node_value + " has a " + node_context[0]["data"]["properties"]["source"] + " identifier of " + node_context[0]["data"]["properties"]["identifier"] + " and Provenance of this is from " + node_context[0]["data"]["properties"]["source"] + "." + return context, merge_2 + +# if edge_evidence: +# merge_2.loc[:, "context"] = merge_2.source + " " + merge_2.predicate.str.lower() + " " + merge_2.target + " and Provenance of this association is " + merge_2.provenance + " and attributes associated with this association is in the following JSON format:\n " + merge_2.evidence.astype('str') + "\n\n" +# else: +# merge_2.loc[:, "context"] = merge_2.source + " " + merge_2.predicate.str.lower() + " " + merge_2.target + " and Provenance of this association is " + merge_2.provenance + ". " +# context = merge_2.context.str.cat(sep=' ') +# context += node_value + " has a " + node_context[0]["data"]["properties"]["source"] + " identifier of " + node_context[0]["data"]["properties"]["identifier"] + " and Provenance of this is from " + node_context[0]["data"]["properties"]["source"] + "." +# return context + + +def get_prompt(instruction, new_system_prompt): + system_prompt = B_SYS + new_system_prompt + E_SYS + prompt_template = B_INST + system_prompt + instruction + E_INST + return prompt_template + + +def llama_model(model_name, branch_name, cache_dir, temperature=0, top_p=1, max_new_tokens=512, stream=False, method='method-1'): + if method == 'method-1': + tokenizer = AutoTokenizer.from_pretrained(model_name, + revision=branch_name, + cache_dir=cache_dir) + model = AutoModelForCausalLM.from_pretrained(model_name, + device_map='auto', + torch_dtype=torch.float16, + revision=branch_name, + cache_dir=cache_dir) + elif method == 'method-2': + import transformers + tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name, + revision=branch_name, + cache_dir=cache_dir, + legacy=False, + token="hf_WbtWB...") + model = transformers.LlamaForCausalLM.from_pretrained(model_name, + device_map='auto', + torch_dtype=torch.float16, + revision=branch_name, + cache_dir=cache_dir, + token="hf_WbtWB...") + if not stream: + pipe = pipeline("text-generation", + model = model, + tokenizer = tokenizer, + torch_dtype = torch.bfloat16, + device_map = "auto", + max_new_tokens = max_new_tokens, + do_sample = True + ) + else: + streamer = TextStreamer(tokenizer) + pipe = pipeline("text-generation", + model = model, + tokenizer = tokenizer, + torch_dtype = torch.bfloat16, + device_map = "auto", + max_new_tokens = max_new_tokens, + do_sample = True, + streamer=streamer + ) + llm = HuggingFacePipeline(pipeline = pipe, + model_kwargs = {"temperature":temperature, "top_p":top_p}) + return llm + + +@retry(wait=wait_random_exponential(min=10, max=30), stop=stop_after_attempt(5)) +def fetch_GPT_response(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature=0): + + response = openai.ChatCompletion.create( + temperature=temperature, + # deployment_id=chat_deployment_id, + model=chat_model_id, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": instruction} + ] + ) + + if 'choices' in response \ + and isinstance(response['choices'], list) \ + and len(response) >= 0 \ + and 'message' in response['choices'][0] \ + and 'content' in response['choices'][0]['message']: + return response['choices'][0]['message']['content'] + else: + return 'Unexpected response' + +@memory.cache +def get_GPT_response(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature=0): + res = fetch_GPT_response(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature) + return res + + +@retry(wait=wait_random_exponential(min=10, max=30), stop=stop_after_attempt(5)) +def fetch_Gemini_response(instruction, system_prompt, temperature=0.0): + model = genai.GenerativeModel( + model_name="gemini-2.0-flash", + system_instruction=system_prompt, + ) + response = model.generate_content(instruction) + return response.text + + +@memory.cache +def get_Gemini_response(instruction, system_prompt, temperature=0.0): + res = fetch_Gemini_response(instruction, system_prompt, temperature) + return res + + +def stream_out(output): + CHUNK_SIZE = int(round(len(output)/50)) + SLEEP_TIME = 0.1 + for i in range(0, len(output), CHUNK_SIZE): + print(output[i:i+CHUNK_SIZE], end='') + sys.stdout.flush() + time.sleep(SLEEP_TIME) + print("\n") + + +def get_gpt35(): + chat_model_id = 'gpt-35-turbo' + chat_deployment_id = chat_model_id if openai.api_type == 'azure' else None + return chat_model_id, chat_deployment_id + + +def get_gpt4o_mini(): + chat_model_id = 'gpt-4o-mini' + chat_deployment_id = chat_model_id if openai.api_type == 'azure' else None + return chat_model_id, chat_deployment_id + + +def get_gemini(): + chat_model_id = 'gemini-2.0-flash' + chat_deployment_id = chat_model_id if openai.api_type == 'azure' else None + return chat_model_id, chat_deployment_id + + +def disease_entity_extractor(text): + chat_model_id, chat_deployment_id = get_gpt35() + resp = get_GPT_response(text, system_prompts["DISEASE_ENTITY_EXTRACTION"], chat_model_id, chat_deployment_id, temperature=0) + try: + entity_dict = json.loads(resp) + return entity_dict["Diseases"] + except: + return None + + +def disease_entity_extractor_v2(text, model_id): + assert model_id in ("gemini-2.0-flash") + prompt_updated = system_prompts["DISEASE_ENTITY_EXTRACTION"] + "\n" + "Sentence : " + text + resp = get_Gemini_response(prompt_updated, system_prompts["DISEASE_ENTITY_EXTRACTION"], temperature=0.0) + if resp.startswith("```json\n"): + resp = resp.replace("```json\n", "", 1) + if resp.endswith("\n```"): + resp = resp.replace("\n```", "", -1) + try: + entity_dict = json.loads(resp) + return entity_dict["Diseases"] + except: + return None + + +def load_sentence_transformer(sentence_embedding_model): + return SentenceTransformerEmbeddings(model_name=sentence_embedding_model) + + +def load_chroma(vector_db_path, sentence_embedding_model): + embedding_function = load_sentence_transformer(sentence_embedding_model) + return Chroma(persist_directory=vector_db_path, embedding_function=embedding_function) + + +def retrieve_context(question, vectorstore, embedding_function, node_context_df, context_volume, context_sim_threshold, context_sim_min_threshold, edge_evidence,model_id="gpt-3.5-turbo", api=False): + print("question:", question) + entities = disease_entity_extractor_v2(question, model_id) + print("entities:", entities) + node_hits = [] + if entities: + max_number_of_high_similarity_context_per_node = int(context_volume/len(entities)) + for entity in entities: + node_search_result = vectorstore.similarity_search_with_score(entity, k=1) + node_hits.append(node_search_result[0][0].page_content) + question_embedding = embedding_function.embed_query(question) + node_context_extracted = "" + for node_name in node_hits: + if not api: + node_context = node_context_df[node_context_df.node_name == node_name].node_context.values[0] + else: + node_context,context_table = get_context_using_spoke_api(node_name) + node_context_list = node_context.split(". ") + node_context_embeddings = embedding_function.embed_documents(node_context_list) + similarities = [cosine_similarity(np.array(question_embedding).reshape(1, -1), np.array(node_context_embedding).reshape(1, -1)) for node_context_embedding in node_context_embeddings] + similarities = sorted([(e, i) for i, e in enumerate(similarities)], reverse=True) + percentile_threshold = np.percentile([s[0] for s in similarities], context_sim_threshold) + high_similarity_indices = [s[1] for s in similarities if s[0] > percentile_threshold and s[0] > context_sim_min_threshold] + if len(high_similarity_indices) > max_number_of_high_similarity_context_per_node: + high_similarity_indices = high_similarity_indices[:max_number_of_high_similarity_context_per_node] + high_similarity_context = [node_context_list[index] for index in high_similarity_indices] + if edge_evidence: + high_similarity_context = list(map(lambda x:x+'.', high_similarity_context)) + context_table = context_table[context_table.context.isin(high_similarity_context)] + context_table.loc[:, "context"] = context_table.source + " " + context_table.predicate.str.lower() + " " + context_table.target + " and Provenance of this association is " + context_table.provenance + " and attributes associated with this association is in the following JSON format:\n " + context_table.evidence.astype('str') + "\n\n" + node_context_extracted += context_table.context.str.cat(sep=' ') + else: + node_context_extracted += ". ".join(high_similarity_context) + node_context_extracted += ". " + return node_context_extracted + else: + node_hits = vectorstore.similarity_search_with_score(question, k=5) + max_number_of_high_similarity_context_per_node = int(context_volume/5) + question_embedding = embedding_function.embed_query(question) + node_context_extracted = "" + for node in node_hits: + node_name = node[0].page_content + if not api: + node_context = node_context_df[node_context_df.node_name == node_name].node_context.values[0] + else: + node_context, context_table = get_context_using_spoke_api(node_name) + node_context_list = node_context.split(". ") + node_context_embeddings = embedding_function.embed_documents(node_context_list) + similarities = [cosine_similarity(np.array(question_embedding).reshape(1, -1), np.array(node_context_embedding).reshape(1, -1)) for node_context_embedding in node_context_embeddings] + similarities = sorted([(e, i) for i, e in enumerate(similarities)], reverse=True) + percentile_threshold = np.percentile([s[0] for s in similarities], context_sim_threshold) + high_similarity_indices = [s[1] for s in similarities if s[0] > percentile_threshold and s[0] > context_sim_min_threshold] + if len(high_similarity_indices) > max_number_of_high_similarity_context_per_node: + high_similarity_indices = high_similarity_indices[:max_number_of_high_similarity_context_per_node] + high_similarity_context = [node_context_list[index] for index in high_similarity_indices] + if edge_evidence: + high_similarity_context = list(map(lambda x:x+'.', high_similarity_context)) + context_table = context_table[context_table.context.isin(high_similarity_context)] + context_table.loc[:, "context"] = context_table.source + " " + context_table.predicate.str.lower() + " " + context_table.target + " and Provenance of this association is " + context_table.provenance + " and attributes associated with this association is in the following JSON format:\n " + context_table.evidence.astype('str') + "\n\n" + node_context_extracted += context_table.context.str.cat(sep=' ') + else: + node_context_extracted += ". ".join(high_similarity_context) + node_context_extracted += ". " + return node_context_extracted + + +def interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, llm_type, edge_evidence, system_prompt, api=True, llama_method="method-1"): + print(" ") + input("Press enter for Step 1 - Disease entity extraction using GPT-3.5-Turbo") + print("Processing ...") + entities = disease_entity_extractor_v2(question, "gpt-4o-mini") + max_number_of_high_similarity_context_per_node = int(config_data["CONTEXT_VOLUME"]/len(entities)) + print("Extracted entity from the prompt = '{}'".format(", ".join(entities))) + print(" ") + + input("Press enter for Step 2 - Match extracted Disease entity to SPOKE nodes") + print("Finding vector similarity ...") + node_hits = [] + for entity in entities: + node_search_result = vectorstore.similarity_search_with_score(entity, k=1) + node_hits.append(node_search_result[0][0].page_content) + print("Matched entities from SPOKE = '{}'".format(", ".join(node_hits))) + print(" ") + + input("Press enter for Step 3 - Context extraction from SPOKE") + node_context = [] + for node_name in node_hits: + if not api: + node_context.append(node_context_df[node_context_df.node_name == node_name].node_context.values[0]) + else: + context, context_table = get_context_using_spoke_api(node_name) + node_context.append(context) + print("Extracted Context is : ") + print(". ".join(node_context)) + print(" ") + + input("Press enter for Step 4 - Context pruning") + question_embedding = embedding_function_for_context_retrieval.embed_query(question) + node_context_extracted = "" + for node_name in node_hits: + if not api: + node_context = node_context_df[node_context_df.node_name == node_name].node_context.values[0] + else: + node_context, context_table = get_context_using_spoke_api(node_name) + node_context_list = node_context.split(". ") + node_context_embeddings = embedding_function_for_context_retrieval.embed_documents(node_context_list) + similarities = [cosine_similarity(np.array(question_embedding).reshape(1, -1), np.array(node_context_embedding).reshape(1, -1)) for node_context_embedding in node_context_embeddings] + similarities = sorted([(e, i) for i, e in enumerate(similarities)], reverse=True) + percentile_threshold = np.percentile([s[0] for s in similarities], config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"]) + high_similarity_indices = [s[1] for s in similarities if s[0] > percentile_threshold and s[0] > config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"]] + if len(high_similarity_indices) > max_number_of_high_similarity_context_per_node: + high_similarity_indices = high_similarity_indices[:max_number_of_high_similarity_context_per_node] + high_similarity_context = [node_context_list[index] for index in high_similarity_indices] + if edge_evidence: + high_similarity_context = list(map(lambda x:x+'.', high_similarity_context)) + context_table = context_table[context_table.context.isin(high_similarity_context)] + context_table.loc[:, "context"] = context_table.source + " " + context_table.predicate.str.lower() + " " + context_table.target + " and Provenance of this association is " + context_table.provenance + " and attributes associated with this association is in the following JSON format:\n " + context_table.evidence.astype('str') + "\n\n" + node_context_extracted += context_table.context.str.cat(sep=' ') + else: + node_context_extracted += ". ".join(high_similarity_context) + node_context_extracted += ". " + print("Pruned Context is : ") + print(node_context_extracted) + print(" ") + + input("Press enter for Step 5 - LLM prompting") + print("Prompting ", llm_type) + if llm_type == "llama": + from langchain import PromptTemplate, LLMChain + template = get_prompt("Context:\n\n{context} \n\nQuestion: {question}", system_prompt) + prompt = PromptTemplate(template=template, input_variables=["context", "question"]) + llm = llama_model(config_data["LLAMA_MODEL_NAME"], config_data["LLAMA_MODEL_BRANCH"], config_data["LLM_CACHE_DIR"], stream=True, method=llama_method) + llm_chain = LLMChain(prompt=prompt, llm=llm) + output = llm_chain.run(context=node_context_extracted, question=question) + elif "gpt" in llm_type: + enriched_prompt = "Context: "+ node_context_extracted + "\n" + "Question: " + question + output = get_GPT_response(enriched_prompt, system_prompt, llm_type, llm_type, temperature=config_data["LLM_TEMPERATURE"]) + stream_out(output) -- cgit v1.2.3