In [1]:
from langchain.chat_models import AzureChatOpenAI
from dotenv import load_dotenv
import os
import openai
from langchain.chains import GraphCypherQAChain
from langchain.graphs import Neo4jGraph
from langchain.callbacks import get_openai_callback
import pandas as pd
from tqdm import tqdm


In [2]:
curated_data = pd.read_csv('../data/benchmark_data/mcq_questions.csv')


In [3]:
LLM_MODEL = 'gpt-4'


In [4]:

load_dotenv(os.path.join(os.path.expanduser('~'), '.gpt_config.env'))
API_KEY = os.environ.get('API_KEY')
API_VERSION = os.environ.get('API_VERSION')
RESOURCE_ENDPOINT = os.environ.get('RESOURCE_ENDPOINT')
openai.api_type = "azure"
openai.api_key = API_KEY
openai.api_base = RESOURCE_ENDPOINT
openai.api_version = API_VERSION
chat_deployment_id = LLM_MODEL
chat_model_id = chat_deployment_id
temperature = 0
chat_model = AzureChatOpenAI(openai_api_key=API_KEY, openai_api_version=API_VERSION, azure_deployment=chat_deployment_id, azure_endpoint=RESOURCE_ENDPOINT, temperature=temperature)
load_dotenv(os.path.join(os.path.expanduser('~'), '.spoke_neo4j_config.env'))
username = os.environ.get('NEO4J_USER')
password = os.environ.get('NEO4J_PSW')
url = os.environ.get('NEO4J_URL')
database = os.environ.get('NEO4J_DB')



  warn_deprecated(


In [5]:
graph = Neo4jGraph(url=url, username=username, password=password, database = database)
chain = GraphCypherQAChain.from_llm(chat_model, graph=graph, verbose=True, validate_cypher=True,return_intermediate_steps=True)



In [16]:
%%time

cypher_rag_out = []
for index, row in tqdm(curated_data.iterrows()):
    question = row['text']
    with get_openai_callback() as cb:
        try:
            cypher_rag_answer = chain.run(query=question, return_final_only=True, verbose=False)
        except ValueError as e:
            cypher_rag_answer = None
                
    cypher_rag_out.append((row['text'], row['correct_node'], cypher_rag_answer, cb.total_tokens))    


0it [00:00, ?it/s]



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3m[0m
Full Context:
[32;1m[1;3m[][0m


282it [00:09, 29.53it/s]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3m[0m
Full Context:
[32;1m[1;3m[][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m


282it [00:21, 29.53it/s]

Generated Cypher:
[32;1m[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs7426056', 'rs2736100', 'rs2187668', 'rs2107595', 'rs7405776'] AND c.disease IN ['central nervous system cancer', 'lung adenocarcinoma'] RETURN g.variant, c.disease[0m
Full Context:
[32;1m[1;3m[][0m


284it [00:28,  7.63it/s]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease)-[:ASSOCIATES_DaG]->(g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound)
WHERE d.name IN ['thoracic aortic aneurysm', 'abdominal aortic aneurysm'] AND c.variant IN ['rs1642764', 'rs595244', 'rs139606545', 'rs12077210', 'rs12917707']
RETURN c.variant AS Variant, d.name AS Disease[0m
Full Context:
[32;1m[1;3m[][0m


285it [00:41,  4.53it/s]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs8077245', 'rs11712066', 'rs11209026', 'rs10830962', 'rs6010620'] AND c.name IN ['Crohn\'s disease', 'ankylosing spondylitis'] RETURN g.variant, c.name[0m
Full Context:
[32;1m[1;3m[][0m


286it [00:49,  3.41it/s]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs1047891', 'rs9268905', 'rs3197999', 'rs1025128', 'rs4624820'] AND c.disease IN ['Crohn\'s disease', 'sclerosing cholangitis'] RETURN g.variant, c.disease[0m
Full Context:
[32;1m[1;3m[][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m


286it [01:01,  3.41it/s]

Generated Cypher:
[32;1m[1;3mMATCH (d:Disease)-[:ASSOCIATES_DaG]->(g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) 
WHERE d.name IN ['lung adenocarcinoma', 'central nervous system cancer'] AND c.variant IN ['rs10490924', 'rs10830962', 'rs2736100', 'rs2391769', 'rs9272143'] 
RETURN c.variant, d.name[0m
Full Context:
[32;1m[1;3m[][0m


288it [01:10,  1.70it/s]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs11759064', 'rs975730', 'rs1150757', 'rs2294008', 'rs7453920'] AND c.disease IN ['gastric fundus cancer', 'duodenal ulcer'] RETURN g.variant, c.disease[0m
Full Context:
[32;1m[1;3m[][0m


289it [01:18,  1.37it/s]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m




Generated Cypher:
[32;1m[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC|:ADVRESPONSE_TO_mGarC]->(c:Compound) 
WHERE g.variant IN ['rs2294008', 'rs2072499', 'rs3197999', 'rs1537377', 'rs988958'] 
AND c.disease IN ['gastric fundus cancer', 'atrophic gastritis'] 
RETURN g.variant, c.disease[0m
Full Context:
[32;1m[1;3m[][0m


290it [01:25,  1.13it/s]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs254020', 'rs4625', 'rs6059655', 'rs11738191', 'rs2963222'] AND c.name IN ['keratinocyte carcinoma', 'skin sensitivity to sun'] RETURN g.variant, c.name;[0m
Full Context:
[32;1m[1;3m[][0m


291it [01:35,  1.22s/it]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene)-[:RESISTANT_TO_mGrC]->(c:Compound) 
WHERE g.variant IN ['rs6679677', 'rs12187903', 'rs1333047', 'rs11585651', 'rs55730499'] 
AND c.disease IN ['keratinocyte carcinoma', 'anti-neutrophil antibody associated vasculitis'] 
RETURN g.variant, c.disease[0m
Full Context:
[32;1m[1;3m[][0m


292it [01:40,  1.45s/it]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs7936312', 'rs325485', 'rs13191786', 'rs72928038', 'rs7523907'] AND c.name IN ['keratinocyte carcinoma', 'autoimmune disease'] RETURN g.variant, c.name;[0m
Full Context:
[32;1m[1;3m[][0m


293it [01:46,  1.75s/it]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs1950829', 'rs13263709', 'rs1126809', 'rs34871267', 'rs2431108'] AND c.disease IN ['keratinocyte carcinoma', 'age-related hearing impairment'] RETURN g.variant, c.disease[0m
Full Context:
[32;1m[1;3m[][0m


294it [01:57,  2.59s/it]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs34243448', 'rs1893592', 'rs1765871', 'rs229541', 'rs10455872'] AND c.disease IN ['aortic stenosis', 'large artery stroke'] RETURN g.variant, c.disease[0m
Full Context:
[32;1m[1;3m[][0m


295it [02:03,  2.94s/it]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs1042704', 'rs6059655', 'rs34396849', 'rs10052804', 'rs11747125'] AND c.name IN ['skin sensitivity to sun', 'keratinocyte carcinoma'] RETURN g.variant, c.name;[0m
Full Context:
[32;1m[1;3m[][0m


296it [02:08,  3.23s/it]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs61815704', 'rs4149909', 'rs36001488', 'rs1333047', 'rs1126809'] AND c.name IN ['skin sensitivity to sun', 'age-related hearing impairment'] RETURN g.variant, c.name[0m
Full Context:
[32;1m[1;3m[][0m


297it [02:18,  4.31s/it]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3m[0m
Full Context:
[32;1m[1;3m[][0m


298it [02:28,  5.38s/it]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3m[0m
Full Context:
[32;1m[1;3m[][0m


299it [02:35,  5.76s/it]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs259919', 'rs10455872', 'rs11958220', 'rs72928038', 'rs141343442'] AND c.disease IN ['autoimmune disease', 'keratinocyte carcinoma'] RETURN g.variant, c.disease;[0m
Full Context:
[32;1m[1;3m[][0m


300it [02:41,  5.73s/it]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs2310752', 'rs7528604', 'rs6679677', 'rs34691223', 'rs2963222'] AND c.disease IN ['autoimmune disease', 'anti-neutrophil antibody associated vasculitis'] RETURN g.variant, c.disease[0m
Full Context:
[32;1m[1;3m[][0m


301it [02:50,  6.52s/it]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs62324212', 'rs10059133', 'rs149943', 'rs12931267', 'rs17156671'] AND c.disease IN ['autoimmune disease', 'atopic asthma'] RETURN g.variant, c.disease;[0m
Full Context:
[32;1m[1;3m[][0m


302it [02:55,  6.23s/it]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs10455872', 'rs35781323', 'rs4615152', 'rs761934676', 'rs11965538'] AND c.disease IN ['large artery stroke', 'aortic stenosis'] RETURN g.variant, c.disease[0m
Full Context:
[32;1m[1;3m[][0m


303it [03:00,  5.95s/it]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs259919', 'rs2503199', 'rs325485', 'rs1126809', 'rs229541'] AND c.disease IN ['age-related hearing impairment', 'keratinocyte carcinoma'] RETURN g.variant, c.disease[0m
Full Context:
[32;1m[1;3m[][0m


304it [03:08,  6.37s/it]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs1126809', 'rs416223', 'rs12722502', 'rs9419958', 'rs1333049'] AND c.name IN ['age-related hearing impairment', 'skin sensitivity to sun'] RETURN g.variant, c.name;[0m
Full Context:
[32;1m[1;3m[][0m


305it [03:18,  7.47s/it]


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs2447827', 'rs1937455', 'rs62324212', 'rs12205199', 'rs4482879'] AND c.disease IN ['atopic asthma', 'autoimmune disease'] RETURN g.variant, c.disease[0m
Full Context:
[32;1m[1;3m[][0m


306it [03:26,  1.48it/s]


[1m> Finished chain.[0m
CPU times: user 1.83 s, sys: 194 ms, total: 2.03 s
Wall time: 3min 26s





In [18]:
cypher_rag_out_df = pd.DataFrame(cypher_rag_out, columns=['text', 'label', 'cypher_rag_answer', 'total_tokens'])


In [19]:

save_path = '../data/results'
os.makedirs(save_path, exist_ok=True)
cypher_rag_out_df.to_csv(os.path.join(save_path, 'cypher_rag_mcq_output.csv'), index=False)

