In [None]:
import os
# Move the working directory one level up from wherever the kernel started.
# NOTE(review): this is not idempotent - re-running the cell climbs another
# level. Presumably the intent is to make the `kg_rag` package importable
# from the repository root; confirm against the project layout.
os.chdir('..')

In [None]:
from langchain.chains import GraphCypherQAChain
from langchain.chat_models import ChatOpenAI
from langchain.graphs import Neo4jGraph
from langchain.callbacks import get_openai_callback
from dotenv import load_dotenv
import os
import openai
import pandas as pd
from neo4j.exceptions import CypherSyntaxError
from kg_rag.utility import *
from tqdm import tqdm
import pandas as pd



## Choose the LLM

In [145]:
# Deployment name handed to the chat model when the RAG chain is built.
LLM_MODEL = "gpt-4-32k"


## Load test data

In [146]:
# Benchmark questions; the cells below read the `question` and
# `question_perturbed` columns from this frame.
# NOTE(review): the path is relative to the current working directory, which
# the first cell changes with os.chdir('..') - confirm it resolves as intended.
data = pd.read_csv(filepath_or_buffer='../data/rag_comparison_data.csv')



## Custom function for neo4j RAG chain

In [149]:
def get_neo4j_cypher_rag_chain():
    """Build a ``GraphCypherQAChain`` backed by the configured Neo4j graph.

    Configuration is loaded with ``load_dotenv`` from two files in the
    user's home directory:

    * ``~/.spoke_neo4j_config.env`` - ``NEO4J_USER``, ``NEO4J_PSW``,
      ``NEO4J_URI`` and ``NEO4J_DB`` for the graph connection.
    * ``~/.gpt_config.env`` - ``API_KEY``, ``API_VERSION`` and
      ``RESOURCE_ENDPOINT`` for the LLM endpoint.

    The chat deployment name is taken from the module-level ``LLM_MODEL``.

    Side effects:
        Populates ``os.environ`` via ``load_dotenv`` and sets the global
        ``openai.api_type`` (to ``"azure"``), ``openai.api_key``,
        ``openai.api_base`` and ``openai.api_version``.

    Returns:
        The chain produced by ``GraphCypherQAChain.from_llm`` with
        ``verbose=True``, ``validate_cypher=True`` and
        ``return_intermediate_steps=True``.
    """
    home_dir = os.path.expanduser('~')

    # --- Neo4j connection -------------------------------------------------
    load_dotenv(os.path.join(home_dir, '.spoke_neo4j_config.env'))
    graph = Neo4jGraph(
        url=os.environ.get('NEO4J_URI'),
        username=os.environ.get('NEO4J_USER'),
        password=os.environ.get('NEO4J_PSW'),
        database=os.environ.get('NEO4J_DB'),
    )

    # --- LLM endpoint -----------------------------------------------------
    load_dotenv(os.path.join(home_dir, '.gpt_config.env'))
    api_key = os.environ.get('API_KEY')
    openai.api_type = "azure"
    openai.api_key = api_key
    openai.api_base = os.environ.get('RESOURCE_ENDPOINT')
    openai.api_version = os.environ.get('API_VERSION')

    # NOTE(review): the recorded output of the cell that calls this function
    # shows a warning that `engine` was transferred to model_kwargs; confirm
    # this is the intended way to select the deployment.
    chat_model = ChatOpenAI(
        openai_api_key=api_key,
        engine=LLM_MODEL,
        temperature=0,  # temperature 0 for the comparison runs
    )

    return GraphCypherQAChain.from_llm(
        chat_model,
        graph=graph,
        verbose=True,
        validate_cypher=True,
        return_intermediate_steps=True,
    )

## Initiate neo4j RAG chain

In [150]:
%%time
# Build the chain once; the evaluation cells below reuse this single
# instance for every question.
neo4j_rag_chain = get_neo4j_cypher_rag_chain()


 engine was transferred to model_kwargs.
 Please confirm that engine is what you intended.


CPU times: user 14.6 ms, sys: 4.67 ms, total: 19.2 ms
Wall time: 22.1 s


## Run on test data

In [118]:
%%time

neo4j_rag_answer = []
total_tokens_used = []

for index, row in data.iterrows():
 question = row['question']
 with get_openai_callback() as cb:
 try:
 neo4j_rag_answer.append(neo4j_rag_chain.run(query=question, return_final_only=True, verbose=False))
 except ValueError as e:
 neo4j_rag_answer.append(None)
 total_tokens_used.append(cb.total_tokens)

data.loc[:,'neo4j_rag_answer'] = neo4j_rag_answer
data.loc[:, 'total_tokens_used'] = total_tokens_used




[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'childhood-onset asthma'})-[r:ASSOCIATES_DaG]->(g:Gene {name: 'RORA'}) RETURN r.gwas_pvalue[0m
Full Context:
[32;1m[1;3m[{'r.gwas_pvalue': 2e-37}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene {name: "SHANK2"})-[r:ASSOCIATES_DaG]-(d:Disease {name: "skin benign neoplasm"}) RETURN r.gwas_pvalue[0m
Full Context:
[32;1m[1;3m[{'r.gwas_pvalue': 5e-08}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'hypertrophic cardiomyopathy'}), (g:Gene {name: 'AMBRA1'})
RETURN EXISTS((d)-[:ASSOCIATES_DaG]->(g)) AS is_associated[0m
Full Context:
[32;1m[1;3m[{'is_associated': True}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: '

Generated Cypher:
[32;1m[1;3mMATCH (g:Gene {name: 'GMDS'}), (d:Disease {name: 'hemorrhoid'}) 
RETURN EXISTS((g)-[:ASSOCIATES_DaG]->(d)) AS is_associated[0m
Full Context:
[32;1m[1;3m[{'is_associated': False}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'rheumatoid arthritis'})-[r:ASSOCIATES_DaG]->(g:Gene {name: 'DPP4'}) RETURN r.gwas_pvalue[0m
Full Context:
[32;1m[1;3m[{'r.gwas_pvalue': 2e-21}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene {name: "SMAD7"})-[r:ASSOCIATES_DaG]-(d:Disease {name: "colon carcinoma"}) RETURN r.gwas_pvalue[0m
Full Context:
[32;1m[1;3m[{'r.gwas_pvalue': 3e-08}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene {name: 'PKIA'})-[r:ASSOCIATES_DaG]->(d:Disease {name: 'pulmonary hypertension'}) RETURN r


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'lung squamous cell carcinoma'}), (g:Gene {name: 'PDS5B'})
RETURN EXISTS((d)-[:ASSOCIATES_DaG]->(g)) AS is_associated[0m
Full Context:
[32;1m[1;3m[{'is_associated': True}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'common variable immunodeficiency'})-[r:ASSOCIATES_DaG]-(g:Gene {name: 'CLEC16A'}) RETURN r.gwas_pvalue[0m
Full Context:
[32;1m[1;3m[{'r.gwas_pvalue': 2e-09}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene {name: 'TERT'}), (d:Disease {name: 'lung non-small cell carcinoma'}) 
RETURN EXISTS((g)-[:ASSOCIATES_DaG]->(d)) AS association_exists[0m
Full Context:
[32;1m[1;3m[{'association_exists': False}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypher


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'Skin Melanoma'})-[r:ASSOCIATES_DaG]-(g:Gene {name: 'CYP1B1'}) RETURN d, r, g[0m
Full Context:
[32;1m[1;3m[][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'myositis'})-[r:ASSOCIATES_DaG]->(g:Gene {name: 'ATP6V1G2'}) RETURN r.gwas_pvalue[0m
Full Context:
[32;1m[1;3m[{'r.gwas_pvalue': 6e-49}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene {name: "SPG7"})-[r:ASSOCIATES_DaG]-(d:Disease {name: "melanoma"}) RETURN r.gwas_pvalue[0m
Full Context:
[32;1m[1;3m[{'r.gwas_pvalue': 9e-26}][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene {name: 'EYA2'})-[:ASSOCIATES_DaG]->(d:Disease {name: 'type 2 diabetes m


[1m> Finished chain.[0m
CPU times: user 3.59 s, sys: 295 ms, total: 3.88 s
Wall time: 9min 27s


## Run on perturbed test data

In [119]:
%%time

neo4j_rag_answer = []
total_tokens_used = []

for index, row in data.iterrows():
 question = row['question_perturbed']
 with get_openai_callback() as cb:
 try:
 neo4j_rag_answer.append(neo4j_rag_chain.run(query=question, return_final_only=True, verbose=False))
 except ValueError as e:
 neo4j_rag_answer.append(None)
 total_tokens_used.append(cb.total_tokens)

data.loc[:,'neo4j_rag_answer_perturbed'] = neo4j_rag_answer
data.loc[:, 'total_tokens_used_perturbed'] = total_tokens_used




[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'childhood-onset asthma'})-[r:ASSOCIATES_DaG]->(g:Gene {name: 'rora'}) RETURN r.gwas_pvalue[0m
Full Context:
[32;1m[1;3m[][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'skin benign neoplasm'}), (g:Gene {name: 'shank2'}), (d)-[r:ASSOCIATES_DaG]->(g) RETURN r.gwas_pvalue[0m
Full Context:
[32;1m[1;3m[][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'hypertrophic cardiomyopathy'}), (g:Gene {name: 'ambra1'})
RETURN EXISTS((d)-[:ASSOCIATES_DaG]->(g)) AS is_associated[0m
Full Context:
[32;1m[1;3m[][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'lung adenocarcinoma'}), (g:Gene {name: 'cyp2a6'})
MATCH (d)-


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene {name: 'prr5l'})-[:ASSOCIATES_DaG]->(d:Disease {name: 'asthma'}) RETURN g, d[0m
Full Context:
[32;1m[1;3m[][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'gastric fundus cancer'})-[r:ASSOCIATES_DaG]->(g:Gene {name: 'gon4l'}) RETURN r.gwas_pvalue[0m
Full Context:
[32;1m[1;3m[][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'type 2 diabetes mellitus'}), (g:Gene {name: 'dnah1'})
RETURN EXISTS((d)-[:ASSOCIATES_DaG]->(g)) AS isAssociated[0m
Full Context:
[32;1m[1;3m[][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'schizophrenia'})-[r:ASSOCIATES_DaG]->(g:Gene {name: 'slc17a3'}) RETURN r.


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene {name: 'kcnk16'})-[r:ASSOCIATES_DaG]->(d:Disease {name: 'type 2 diabetes mellitus'}) RETURN r.gwas_pvalue[0m
Full Context:
[32;1m[1;3m[][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'type 1 diabetes mellitus'}), (g:Gene {name: 'dgkq'})
RETURN EXISTS((d)-[:ASSOCIATES_DaG]->(g)) AS isAssociated[0m
Full Context:
[32;1m[1;3m[][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene {name: "six6"})-[r:ASSOCIATES_DaG]->(d:Disease {name: "refractive error"}) RETURN r.gwas_pvalue[0m
Full Context:
[32;1m[1;3m[][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (g:Gene {name: 'map4k4'})-[r:ASSOCIATES_DaG]-(d:Disease {name: 'parkin


[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'esophageal carcinoma'}), (g:Gene {name: 'casp8'})
RETURN EXISTS((d)-[:ASSOCIATES_DaG]->(g)) AS is_associated[0m
Full Context:
[32;1m[1;3m[][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'Skin Melanoma'})-[r:ASSOCIATES_DaG]-(g:Gene {name: 'GPRC5A'}) RETURN d, r, g[0m
Full Context:
[32;1m[1;3m[][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'lung squamous cell carcinoma'}), (g:Gene {name: 'brca2'}) 
RETURN EXISTS((d)-[:ASSOCIATES_DaG]->(g)) AS is_associated[0m
Full Context:
[32;1m[1;3m[][0m

[1m> Finished chain.[0m


[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'systemic lupus erythematosus'})-[r:ASS

## Save the result

In [120]:
# Write the frame, including the answer and token columns added above,
# to a CSV in the results directory (created if it does not exist yet).
save_path = '../data/results'
os.makedirs(save_path, exist_ok=True)
results_csv = os.path.join(save_path, 'cypher_rag_output.csv')
data.to_csv(results_csv, index=False)

