author     maszhongming <mingz5@illinois.edu>  2025-09-16 15:15:29 -0500
committer  maszhongming <mingz5@illinois.edu>  2025-09-16 15:15:29 -0500
commit     73c194f304f827b55081b15524479f82a1b7d94c (patch)
tree       5e8660e421915420892c5eca18f1ad680f80a861  /kg_rag/utility.py
Initial commit
Diffstat (limited to 'kg_rag/utility.py')
-rw-r--r--  kg_rag/utility.py  443
1 file changed, 443 insertions(+), 0 deletions(-)
diff --git a/kg_rag/utility.py b/kg_rag/utility.py
new file mode 100644
index 0000000..975367b
--- /dev/null
+++ b/kg_rag/utility.py
@@ -0,0 +1,443 @@
+import pandas as pd
+import numpy as np
+from sklearn.metrics.pairwise import cosine_similarity
+from joblib import Memory
+import json
+import openai
+import os
+import sys
+from tenacity import retry, stop_after_attempt, wait_random_exponential
+import time
+from dotenv import load_dotenv, find_dotenv
+import torch
+from langchain import HuggingFacePipeline
+from langchain.vectorstores import Chroma
+from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, TextStreamer, GPTQConfig
+from kg_rag.config_loader import *
+import ast
+import requests
+import google.generativeai as genai
+
+
+memory = Memory("cachegpt", verbose=0)
+
+# Config openai library
+config_file = config_data['GPT_CONFIG_FILE']
+load_dotenv(config_file)
+api_key = os.environ.get('API_KEY')
+api_version = os.environ.get('API_VERSION')
+resource_endpoint = os.environ.get('RESOURCE_ENDPOINT')
+openai.api_type = config_data['GPT_API_TYPE']
+openai.api_key = api_key
+if resource_endpoint:
+ openai.api_base = resource_endpoint
+if api_version:
+ openai.api_version = api_version
+
+genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
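+# Illustrative .env layout read by load_dotenv above (placeholder values, not
+# real credentials):
+#   API_KEY=<openai-or-azure-key>
+#   API_VERSION=<azure-api-version>
+#   RESOURCE_ENDPOINT=<https://your-resource.openai.azure.com/>
+#   GOOGLE_API_KEY=<gemini-key>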
+
+
+torch.cuda.empty_cache()
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
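+# Llama-2 chat template delimiters: [INST] ... [/INST] wraps a turn and
+# <<SYS>> ... <</SYS>> wraps the system prompt (used by get_prompt below).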
+
+
+def get_spoke_api_resp(base_uri, end_point, params=None):
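+    """GET base_uri + end_point from the SPOKE REST API, optionally with query parameters."""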
+ uri = base_uri + end_point
+ if params:
+ return requests.get(uri, params=params)
+ else:
+ return requests.get(uri)
+
+
+@retry(wait=wait_random_exponential(min=10, max=30), stop=stop_after_attempt(5))
+def get_context_using_spoke_api(node_value):
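+    """Fetch the neighborhood of a Disease node from the SPOKE API and flatten
+    it into a natural-language context string plus a context DataFrame.
+    """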
+ type_end_point = "/api/v1/types"
+ result = get_spoke_api_resp(config_data['BASE_URI'], type_end_point)
+ data_spoke_types = result.json()
+ node_types = list(data_spoke_types["nodes"].keys())
+ edge_types = list(data_spoke_types["edges"].keys())
+ node_types_to_remove = ["DatabaseTimestamp", "Version"]
+ filtered_node_types = [node_type for node_type in node_types if node_type not in node_types_to_remove]
+ api_params = {
+ 'node_filters' : filtered_node_types,
+ 'edge_filters': edge_types,
+ 'cutoff_Compound_max_phase': config_data['cutoff_Compound_max_phase'],
+ 'cutoff_Protein_source': config_data['cutoff_Protein_source'],
+ 'cutoff_DaG_diseases_sources': config_data['cutoff_DaG_diseases_sources'],
+ 'cutoff_DaG_textmining': config_data['cutoff_DaG_textmining'],
+ 'cutoff_CtD_phase': config_data['cutoff_CtD_phase'],
+ 'cutoff_PiP_confidence': config_data['cutoff_PiP_confidence'],
+ 'cutoff_ACTeG_level': config_data['cutoff_ACTeG_level'],
+ 'cutoff_DpL_average_prevalence': config_data['cutoff_DpL_average_prevalence'],
+ 'depth' : config_data['depth']
+ }
+ node_type = "Disease"
+ attribute = "name"
+ nbr_end_point = "/api/v1/neighborhood/{}/{}/{}".format(node_type, attribute, node_value)
+ result = get_spoke_api_resp(config_data['BASE_URI'], nbr_end_point, params=api_params)
+ node_context = result.json()
+ nbr_nodes = []
+ nbr_edges = []
+    for item in node_context:
+        # Plain neo4j types are nodes; compound types containing "_"
+        # (e.g. "ASSOCIATES_DaG") are edges.
+        if "_" not in item["data"]["neo4j_type"]:
+            try:
+                if item["data"]["neo4j_type"] == "Protein":
+                    nbr_nodes.append((item["data"]["neo4j_type"], item["data"]["id"], item["data"]["properties"]["description"]))
+                else:
+                    nbr_nodes.append((item["data"]["neo4j_type"], item["data"]["id"], item["data"]["properties"]["name"]))
+            except KeyError:
+                nbr_nodes.append((item["data"]["neo4j_type"], item["data"]["id"], item["data"]["properties"]["identifier"]))
+        else:
+            # Resolve provenance from whichever property the edge carries, in
+            # order of preference: sources, source, preprint list, PubMed ids.
+            try:
+                provenance = ", ".join(item["data"]["properties"]["sources"])
+            except (KeyError, TypeError):
+                try:
+                    provenance = item["data"]["properties"]["source"]
+                    if isinstance(provenance, list):
+                        provenance = ", ".join(provenance)
+                except KeyError:
+                    try:
+                        preprint_list = ast.literal_eval(item["data"]["properties"]["preprint_list"])
+                        if len(preprint_list) > 0:
+                            provenance = ", ".join(preprint_list)
+                        else:
+                            pmid_list = ["pubmedId:" + str(pmid) for pmid in ast.literal_eval(item["data"]["properties"]["pmid_list"])]
+                            if len(pmid_list) > 0:
+                                provenance = ", ".join(pmid_list)
+                            else:
+                                provenance = "Based on data from Institute For Systems Biology (ISB)"
+                    except (KeyError, TypeError, ValueError, SyntaxError):
+                        provenance = "SPOKE-KG"
+            evidence = item["data"].get("properties")
+            nbr_edges.append((item["data"]["source"], item["data"]["neo4j_type"], item["data"]["target"], provenance, evidence))
+ nbr_nodes_df = pd.DataFrame(nbr_nodes, columns=["node_type", "node_id", "node_name"])
+ nbr_edges_df = pd.DataFrame(nbr_edges, columns=["source", "edge_type", "target", "provenance", "evidence"])
+ merge_1 = pd.merge(nbr_edges_df, nbr_nodes_df, left_on="source", right_on="node_id").drop("node_id", axis=1)
+ merge_1.loc[:,"node_name"] = merge_1.node_type + " " + merge_1.node_name
+ merge_1.drop(["source", "node_type"], axis=1, inplace=True)
+ merge_1 = merge_1.rename(columns={"node_name":"source"})
+ merge_2 = pd.merge(merge_1, nbr_nodes_df, left_on="target", right_on="node_id").drop("node_id", axis=1)
+ merge_2.loc[:,"node_name"] = merge_2.node_type + " " + merge_2.node_name
+ merge_2.drop(["target", "node_type"], axis=1, inplace=True)
+ merge_2 = merge_2.rename(columns={"node_name":"target"})
+ merge_2 = merge_2[["source", "edge_type", "target", "provenance", "evidence"]]
+ merge_2.loc[:, "predicate"] = merge_2.edge_type.apply(lambda x:x.split("_")[0])
+ merge_2.loc[:, "context"] = merge_2.source + " " + merge_2.predicate.str.lower() + " " + merge_2.target + " and Provenance of this association is " + merge_2.provenance + "."
+ context = merge_2.context.str.cat(sep=' ')
+ context += node_value + " has a " + node_context[0]["data"]["properties"]["source"] + " identifier of " + node_context[0]["data"]["properties"]["identifier"] + " and Provenance of this is from " + node_context[0]["data"]["properties"]["source"] + "."
+ return context, merge_2
+
+
+
+def get_prompt(instruction, new_system_prompt):
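+    """Wrap an instruction and system prompt in the Llama-2 [INST]/<<SYS>> chat template."""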
+ system_prompt = B_SYS + new_system_prompt + E_SYS
+ prompt_template = B_INST + system_prompt + instruction + E_INST
+ return prompt_template
+
+
+def llama_model(model_name, branch_name, cache_dir, temperature=0, top_p=1, max_new_tokens=512, stream=False, method='method-1'):
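+    """Load a Llama model/tokenizer pair and wrap it in a LangChain
+    HuggingFacePipeline. 'method-1' uses the Auto* classes; 'method-2' uses
+    Llama-specific classes and requires a Hugging Face access token.
+    """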
+ if method == 'method-1':
+ tokenizer = AutoTokenizer.from_pretrained(model_name,
+ revision=branch_name,
+ cache_dir=cache_dir)
+ model = AutoModelForCausalLM.from_pretrained(model_name,
+ device_map='auto',
+ torch_dtype=torch.float16,
+ revision=branch_name,
+ cache_dir=cache_dir)
+    elif method == 'method-2':
+        import transformers
+        # The Hugging Face access token is read from the environment instead
+        # of being hardcoded in source (HF_TOKEN is an assumed variable name).
+        hf_token = os.environ.get("HF_TOKEN")
+        tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name,
+                                                                revision=branch_name,
+                                                                cache_dir=cache_dir,
+                                                                legacy=False,
+                                                                token=hf_token)
+        model = transformers.LlamaForCausalLM.from_pretrained(model_name,
+                                                              device_map='auto',
+                                                              torch_dtype=torch.float16,
+                                                              revision=branch_name,
+                                                              cache_dir=cache_dir,
+                                                              token=hf_token)
+    else:
+        raise ValueError("Unknown method: {}".format(method))
+ if not stream:
+ pipe = pipeline("text-generation",
+ model = model,
+ tokenizer = tokenizer,
+ torch_dtype = torch.bfloat16,
+ device_map = "auto",
+ max_new_tokens = max_new_tokens,
+ do_sample = True
+ )
+ else:
+ streamer = TextStreamer(tokenizer)
+ pipe = pipeline("text-generation",
+ model = model,
+ tokenizer = tokenizer,
+ torch_dtype = torch.bfloat16,
+ device_map = "auto",
+ max_new_tokens = max_new_tokens,
+ do_sample = True,
+ streamer=streamer
+ )
+ llm = HuggingFacePipeline(pipeline = pipe,
+ model_kwargs = {"temperature":temperature, "top_p":top_p})
+ return llm
+
+
+@retry(wait=wait_random_exponential(min=10, max=30), stop=stop_after_attempt(5))
+def fetch_GPT_response(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature=0):
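+    """Call the OpenAI ChatCompletion API and return the response text."""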
+
+    # On Azure, the deployment id must be passed alongside the model name;
+    # otherwise it is omitted.
+    extra_kwargs = {"deployment_id": chat_deployment_id} if (openai.api_type == 'azure' and chat_deployment_id) else {}
+    response = openai.ChatCompletion.create(
+        temperature=temperature,
+        model=chat_model_id,
+        messages=[
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": instruction}
+        ],
+        **extra_kwargs
+    )
+
+ if 'choices' in response \
+ and isinstance(response['choices'], list) \
+    and len(response['choices']) > 0 \
+ and 'message' in response['choices'][0] \
+ and 'content' in response['choices'][0]['message']:
+ return response['choices'][0]['message']['content']
+ else:
+ return 'Unexpected response'
+
+@memory.cache
+def get_GPT_response(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature=0):
+ res = fetch_GPT_response(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature)
+ return res
+
+
+@retry(wait=wait_random_exponential(min=10, max=30), stop=stop_after_attempt(5))
+def fetch_Gemini_response(instruction, system_prompt, temperature=0.0):
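+    """Call the Gemini API with a system instruction and return the response text."""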
+ model = genai.GenerativeModel(
+ model_name="gemini-2.0-flash",
+ system_instruction=system_prompt,
+ )
+    response = model.generate_content(
+        instruction,
+        generation_config=genai.GenerationConfig(temperature=temperature),
+    )
+ return response.text
+
+
+@memory.cache
+def get_Gemini_response(instruction, system_prompt, temperature=0.0):
+ res = fetch_Gemini_response(instruction, system_prompt, temperature)
+ return res
+
+
+def stream_out(output):
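+    """Print output in roughly 50 chunks with a short delay to mimic streaming."""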
+    # max(1, ...) keeps the range step valid for very short outputs
+    CHUNK_SIZE = max(1, int(round(len(output)/50)))
+ SLEEP_TIME = 0.1
+ for i in range(0, len(output), CHUNK_SIZE):
+ print(output[i:i+CHUNK_SIZE], end='')
+ sys.stdout.flush()
+ time.sleep(SLEEP_TIME)
+ print("\n")
+
+
+def get_gpt35():
+ chat_model_id = 'gpt-35-turbo'
+ chat_deployment_id = chat_model_id if openai.api_type == 'azure' else None
+ return chat_model_id, chat_deployment_id
+
+
+def get_gpt4o_mini():
+ chat_model_id = 'gpt-4o-mini'
+ chat_deployment_id = chat_model_id if openai.api_type == 'azure' else None
+ return chat_model_id, chat_deployment_id
+
+
+def get_gemini():
+ chat_model_id = 'gemini-2.0-flash'
+ chat_deployment_id = chat_model_id if openai.api_type == 'azure' else None
+ return chat_model_id, chat_deployment_id
+
+
+def disease_entity_extractor(text):
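+    """Extract disease entities from text with GPT-3.5-Turbo; returns a list or None."""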
+ chat_model_id, chat_deployment_id = get_gpt35()
+ resp = get_GPT_response(text, system_prompts["DISEASE_ENTITY_EXTRACTION"], chat_model_id, chat_deployment_id, temperature=0)
+ try:
+ entity_dict = json.loads(resp)
+ return entity_dict["Diseases"]
+    except (json.JSONDecodeError, KeyError):
+ return None
+
+
+def disease_entity_extractor_v2(text, model_id):
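+    """Extract disease entities from text with Gemini; returns a list or None."""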
+    assert model_id in ("gemini-2.0-flash",)
+ prompt_updated = system_prompts["DISEASE_ENTITY_EXTRACTION"] + "\n" + "Sentence : " + text
+ resp = get_Gemini_response(prompt_updated, system_prompts["DISEASE_ENTITY_EXTRACTION"], temperature=0.0)
+    # Strip a Markdown ```json fence if the model wrapped its output in one.
+    if resp.startswith("```json\n"):
+        resp = resp[len("```json\n"):]
+    if resp.endswith("\n```"):
+        resp = resp[:-len("\n```")]
+ try:
+ entity_dict = json.loads(resp)
+ return entity_dict["Diseases"]
+    except (json.JSONDecodeError, KeyError):
+ return None
+
+
+def load_sentence_transformer(sentence_embedding_model):
+ return SentenceTransformerEmbeddings(model_name=sentence_embedding_model)
+
+
+def load_chroma(vector_db_path, sentence_embedding_model):
+ embedding_function = load_sentence_transformer(sentence_embedding_model)
+ return Chroma(persist_directory=vector_db_path, embedding_function=embedding_function)
+
+
+def retrieve_context(question, vectorstore, embedding_function, node_context_df, context_volume, context_sim_threshold, context_sim_min_threshold, edge_evidence, model_id="gemini-2.0-flash", api=False):
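+    """Retrieve question-relevant KG context for the disease entities extracted
+    from the question, falling back to a plain vector search over the question
+    itself when no entity is found.
+    """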
+ print("question:", question)
+ entities = disease_entity_extractor_v2(question, model_id)
+ print("entities:", entities)
+ node_hits = []
+ if entities:
+ max_number_of_high_similarity_context_per_node = int(context_volume/len(entities))
+ for entity in entities:
+ node_search_result = vectorstore.similarity_search_with_score(entity, k=1)
+ node_hits.append(node_search_result[0][0].page_content)
+ question_embedding = embedding_function.embed_query(question)
+ node_context_extracted = ""
+ for node_name in node_hits:
+ if not api:
+ node_context = node_context_df[node_context_df.node_name == node_name].node_context.values[0]
+ else:
+ node_context,context_table = get_context_using_spoke_api(node_name)
+ node_context_list = node_context.split(". ")
+ node_context_embeddings = embedding_function.embed_documents(node_context_list)
+        # A single vectorized cosine_similarity call scores every context
+        # sentence against the question embedding at once.
+        scores = cosine_similarity(np.array(question_embedding).reshape(1, -1), np.array(node_context_embeddings))[0]
+        similarities = sorted([(score, index) for index, score in enumerate(scores)], reverse=True)
+ percentile_threshold = np.percentile([s[0] for s in similarities], context_sim_threshold)
+ high_similarity_indices = [s[1] for s in similarities if s[0] > percentile_threshold and s[0] > context_sim_min_threshold]
+ if len(high_similarity_indices) > max_number_of_high_similarity_context_per_node:
+ high_similarity_indices = high_similarity_indices[:max_number_of_high_similarity_context_per_node]
+ high_similarity_context = [node_context_list[index] for index in high_similarity_indices]
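+        # Note: edge_evidence assumes api=True, since context_table only
+        # exists when the context came from get_context_using_spoke_api.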
+ if edge_evidence:
+ high_similarity_context = list(map(lambda x:x+'.', high_similarity_context))
+ context_table = context_table[context_table.context.isin(high_similarity_context)]
+ context_table.loc[:, "context"] = context_table.source + " " + context_table.predicate.str.lower() + " " + context_table.target + " and Provenance of this association is " + context_table.provenance + " and attributes associated with this association is in the following JSON format:\n " + context_table.evidence.astype('str') + "\n\n"
+ node_context_extracted += context_table.context.str.cat(sep=' ')
+ else:
+ node_context_extracted += ". ".join(high_similarity_context)
+ node_context_extracted += ". "
+ return node_context_extracted
+ else:
+ node_hits = vectorstore.similarity_search_with_score(question, k=5)
+ max_number_of_high_similarity_context_per_node = int(context_volume/5)
+ question_embedding = embedding_function.embed_query(question)
+ node_context_extracted = ""
+ for node in node_hits:
+ node_name = node[0].page_content
+ if not api:
+ node_context = node_context_df[node_context_df.node_name == node_name].node_context.values[0]
+ else:
+ node_context, context_table = get_context_using_spoke_api(node_name)
+ node_context_list = node_context.split(". ")
+ node_context_embeddings = embedding_function.embed_documents(node_context_list)
+            scores = cosine_similarity(np.array(question_embedding).reshape(1, -1), np.array(node_context_embeddings))[0]
+            similarities = sorted([(score, index) for index, score in enumerate(scores)], reverse=True)
+ percentile_threshold = np.percentile([s[0] for s in similarities], context_sim_threshold)
+ high_similarity_indices = [s[1] for s in similarities if s[0] > percentile_threshold and s[0] > context_sim_min_threshold]
+ if len(high_similarity_indices) > max_number_of_high_similarity_context_per_node:
+ high_similarity_indices = high_similarity_indices[:max_number_of_high_similarity_context_per_node]
+ high_similarity_context = [node_context_list[index] for index in high_similarity_indices]
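+            # Note: edge_evidence assumes api=True, since context_table only
+            # exists when the context came from get_context_using_spoke_api.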
+ if edge_evidence:
+ high_similarity_context = list(map(lambda x:x+'.', high_similarity_context))
+ context_table = context_table[context_table.context.isin(high_similarity_context)]
+ context_table.loc[:, "context"] = context_table.source + " " + context_table.predicate.str.lower() + " " + context_table.target + " and Provenance of this association is " + context_table.provenance + " and attributes associated with this association is in the following JSON format:\n " + context_table.evidence.astype('str') + "\n\n"
+ node_context_extracted += context_table.context.str.cat(sep=' ')
+ else:
+ node_context_extracted += ". ".join(high_similarity_context)
+ node_context_extracted += ". "
+ return node_context_extracted
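+
+# Illustrative use of retrieve_context (argument values are hypothetical):
+#   context = retrieve_context("What genes are associated with psoriasis?",
+#                              vectorstore, embedding_function, node_context_df,
+#                              context_volume=150, context_sim_threshold=90,
+#                              context_sim_min_threshold=0.5, edge_evidence=False,
+#                              model_id="gemini-2.0-flash", api=True)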
+
+
+def interactive(question, vectorstore, node_context_df, embedding_function_for_context_retrieval, llm_type, edge_evidence, system_prompt, api=True, llama_method="method-1"):
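+    """Walk through the KG-RAG pipeline step by step, pausing for user input
+    between entity extraction, node matching, context retrieval, pruning, and
+    LLM prompting.
+    """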
+ print(" ")
+ input("Press enter for Step 1 - Disease entity extraction using GPT-3.5-Turbo")
+ print("Processing ...")
+    entities = disease_entity_extractor_v2(question, "gemini-2.0-flash")
+    if not entities:
+        print("No disease entity was found in the prompt.")
+        return
+    max_number_of_high_similarity_context_per_node = int(config_data["CONTEXT_VOLUME"]/len(entities))
+ print("Extracted entity from the prompt = '{}'".format(", ".join(entities)))
+ print(" ")
+
+ input("Press enter for Step 2 - Match extracted Disease entity to SPOKE nodes")
+ print("Finding vector similarity ...")
+ node_hits = []
+ for entity in entities:
+ node_search_result = vectorstore.similarity_search_with_score(entity, k=1)
+ node_hits.append(node_search_result[0][0].page_content)
+ print("Matched entities from SPOKE = '{}'".format(", ".join(node_hits)))
+ print(" ")
+
+ input("Press enter for Step 3 - Context extraction from SPOKE")
+ node_context = []
+ for node_name in node_hits:
+ if not api:
+ node_context.append(node_context_df[node_context_df.node_name == node_name].node_context.values[0])
+ else:
+ context, context_table = get_context_using_spoke_api(node_name)
+ node_context.append(context)
+ print("Extracted Context is : ")
+ print(". ".join(node_context))
+ print(" ")
+
+ input("Press enter for Step 4 - Context pruning")
+ question_embedding = embedding_function_for_context_retrieval.embed_query(question)
+ node_context_extracted = ""
+ for node_name in node_hits:
+ if not api:
+ node_context = node_context_df[node_context_df.node_name == node_name].node_context.values[0]
+ else:
+ node_context, context_table = get_context_using_spoke_api(node_name)
+ node_context_list = node_context.split(". ")
+ node_context_embeddings = embedding_function_for_context_retrieval.embed_documents(node_context_list)
+        scores = cosine_similarity(np.array(question_embedding).reshape(1, -1), np.array(node_context_embeddings))[0]
+        similarities = sorted([(score, index) for index, score in enumerate(scores)], reverse=True)
+ percentile_threshold = np.percentile([s[0] for s in similarities], config_data["QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD"])
+ high_similarity_indices = [s[1] for s in similarities if s[0] > percentile_threshold and s[0] > config_data["QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY"]]
+ if len(high_similarity_indices) > max_number_of_high_similarity_context_per_node:
+ high_similarity_indices = high_similarity_indices[:max_number_of_high_similarity_context_per_node]
+ high_similarity_context = [node_context_list[index] for index in high_similarity_indices]
+ if edge_evidence:
+ high_similarity_context = list(map(lambda x:x+'.', high_similarity_context))
+ context_table = context_table[context_table.context.isin(high_similarity_context)]
+ context_table.loc[:, "context"] = context_table.source + " " + context_table.predicate.str.lower() + " " + context_table.target + " and Provenance of this association is " + context_table.provenance + " and attributes associated with this association is in the following JSON format:\n " + context_table.evidence.astype('str') + "\n\n"
+ node_context_extracted += context_table.context.str.cat(sep=' ')
+ else:
+ node_context_extracted += ". ".join(high_similarity_context)
+ node_context_extracted += ". "
+ print("Pruned Context is : ")
+ print(node_context_extracted)
+ print(" ")
+
+ input("Press enter for Step 5 - LLM prompting")
+ print("Prompting ", llm_type)
+ if llm_type == "llama":
+ from langchain import PromptTemplate, LLMChain
+ template = get_prompt("Context:\n\n{context} \n\nQuestion: {question}", system_prompt)
+ prompt = PromptTemplate(template=template, input_variables=["context", "question"])
+ llm = llama_model(config_data["LLAMA_MODEL_NAME"], config_data["LLAMA_MODEL_BRANCH"], config_data["LLM_CACHE_DIR"], stream=True, method=llama_method)
+ llm_chain = LLMChain(prompt=prompt, llm=llm)
+ output = llm_chain.run(context=node_context_extracted, question=question)
+ elif "gpt" in llm_type:
+ enriched_prompt = "Context: "+ node_context_extracted + "\n" + "Question: " + question
+ output = get_GPT_response(enriched_prompt, system_prompt, llm_type, llm_type, temperature=config_data["LLM_TEMPERATURE"])
+ stream_out(output)