summaryrefslogtreecommitdiff
path: root/notebooks/mcq_using_cypher_rag.ipynb
diff options
context:
space:
mode:
authormaszhongming <mingz5@illinois.edu>2025-09-16 15:15:29 -0500
committermaszhongming <mingz5@illinois.edu>2025-09-16 15:15:29 -0500
commit73c194f304f827b55081b15524479f82a1b7d94c (patch)
tree5e8660e421915420892c5eca18f1ad680f80a861 /notebooks/mcq_using_cypher_rag.ipynb
Initial commit
Diffstat (limited to 'notebooks/mcq_using_cypher_rag.ipynb')
-rw-r--r--notebooks/mcq_using_cypher_rag.ipynb763
1 files changed, 763 insertions, 0 deletions
diff --git a/notebooks/mcq_using_cypher_rag.ipynb b/notebooks/mcq_using_cypher_rag.ipynb
new file mode 100644
index 0000000..51a393f
--- /dev/null
+++ b/notebooks/mcq_using_cypher_rag.ipynb
@@ -0,0 +1,763 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "ad38bed4-0077-4097-81dc-8db1d23255ad",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from langchain.chat_models import AzureChatOpenAI\n",
+ "from dotenv import load_dotenv\n",
+ "import os\n",
+ "import openai\n",
+ "from langchain.chains import GraphCypherQAChain\n",
+ "from langchain.graphs import Neo4jGraph\n",
+ "from langchain.callbacks import get_openai_callback\n",
+ "import pandas as pd\n",
+ "from tqdm import tqdm\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "d035a52f-442b-43f2-a7ba-a04006937f0a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "curated_data = pd.read_csv('../data/benchmark_data/mcq_questions.csv')\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "ddb263ec-9bd5-4caf-8950-13b400a75b98",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "LLM_MODEL = 'gpt-4'\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "ff7f9255-e5fd-4eec-aade-2f7cacb8b8cd",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/karthiksoman/anaconda3/envs/cypher_rag/lib/python3.10/site-packages/langchain_core/_api/deprecation.py:119: LangChainDeprecationWarning: The class `AzureChatOpenAI` was deprecated in LangChain 0.0.10 and will be removed in 0.3.0. An updated version of the class exists in the langchain-openai package and should be used instead. To use it run `pip install -U langchain-openai` and import as `from langchain_openai import AzureChatOpenAI`.\n",
+ " warn_deprecated(\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "load_dotenv(os.path.join(os.path.expanduser('~'), '.gpt_config.env'))\n",
+ "API_KEY = os.environ.get('API_KEY')\n",
+ "API_VERSION = os.environ.get('API_VERSION')\n",
+ "RESOURCE_ENDPOINT = os.environ.get('RESOURCE_ENDPOINT')\n",
+ "openai.api_type = \"azure\"\n",
+ "openai.api_key = API_KEY\n",
+ "openai.api_base = RESOURCE_ENDPOINT\n",
+ "openai.api_version = API_VERSION\n",
+ "chat_deployment_id = LLM_MODEL\n",
+ "chat_model_id = chat_deployment_id\n",
+ "temperature = 0\n",
+ "chat_model = AzureChatOpenAI(openai_api_key=API_KEY, openai_api_version=API_VERSION, azure_deployment=chat_deployment_id, azure_endpoint=RESOURCE_ENDPOINT, temperature=temperature)\n",
+ "load_dotenv(os.path.join(os.path.expanduser('~'), '.spoke_neo4j_config.env'))\n",
+ "username = os.environ.get('NEO4J_USER')\n",
+ "password = os.environ.get('NEO4J_PSW')\n",
+ "url = os.environ.get('NEO4J_URL')\n",
+ "database = os.environ.get('NEO4J_DB')\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "9f0c1ded-fd30-49cb-be0d-369bc5a906bd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "graph = Neo4jGraph(url=url, username=username, password=password, database = database)\n",
+ "chain = GraphCypherQAChain.from_llm(chat_model, graph=graph, verbose=True, validate_cypher=True,return_intermediate_steps=True)\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "922ed8a6-500f-4937-92d2-2df6271798e6",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "0it [00:00, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3m\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "282it [00:09, 29.53it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3m\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n",
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "282it [00:21, 29.53it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs7426056', 'rs2736100', 'rs2187668', 'rs2107595', 'rs7405776'] AND c.disease IN ['central nervous system cancer', 'lung adenocarcinoma'] RETURN g.variant, c.disease\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "284it [00:28, 7.63it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3mMATCH (d:Disease)-[:ASSOCIATES_DaG]->(g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound)\n",
+ "WHERE d.name IN ['thoracic aortic aneurysm', 'abdominal aortic aneurysm'] AND c.variant IN ['rs1642764', 'rs595244', 'rs139606545', 'rs12077210', 'rs12917707']\n",
+ "RETURN c.variant AS Variant, d.name AS Disease\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "285it [00:41, 4.53it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs8077245', 'rs11712066', 'rs11209026', 'rs10830962', 'rs6010620'] AND c.name IN ['Crohn\\'s disease', 'ankylosing spondylitis'] RETURN g.variant, c.name\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "286it [00:49, 3.41it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs1047891', 'rs9268905', 'rs3197999', 'rs1025128', 'rs4624820'] AND c.disease IN ['Crohn\\'s disease', 'sclerosing cholangitis'] RETURN g.variant, c.disease\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n",
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "286it [01:01, 3.41it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3mMATCH (d:Disease)-[:ASSOCIATES_DaG]->(g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) \n",
+ "WHERE d.name IN ['lung adenocarcinoma', 'central nervous system cancer'] AND c.variant IN ['rs10490924', 'rs10830962', 'rs2736100', 'rs2391769', 'rs9272143'] \n",
+ "RETURN c.variant, d.name\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "288it [01:10, 1.70it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs11759064', 'rs975730', 'rs1150757', 'rs2294008', 'rs7453920'] AND c.disease IN ['gastric fundus cancer', 'duodenal ulcer'] RETURN g.variant, c.disease\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "289it [01:18, 1.37it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Received notification from DBMS server: {severity: WARNING} {code: Neo.ClientNotification.Statement.FeatureDeprecationWarning} {category: DEPRECATION} {title: This feature is deprecated and will be removed in future versions.} {description: The semantics of using colon in the separation of alternative relationship types will change in a future version. (Please use ':RESPONSE_TO_mGrC|ADVRESPONSE_TO_mGarC' instead)} {position: line: 1, column: 34, offset: 33} for query: \"MATCH (g:Gene)-[:RESPONSE_TO_mGrC|:ADVRESPONSE_TO_mGarC]->(c:Compound) \\nWHERE g.variant IN ['rs2294008', 'rs2072499', 'rs3197999', 'rs1537377', 'rs988958'] \\nAND c.disease IN ['gastric fundus cancer', 'atrophic gastritis'] \\nRETURN g.variant, c.disease\"\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC|:ADVRESPONSE_TO_mGarC]->(c:Compound) \n",
+ "WHERE g.variant IN ['rs2294008', 'rs2072499', 'rs3197999', 'rs1537377', 'rs988958'] \n",
+ "AND c.disease IN ['gastric fundus cancer', 'atrophic gastritis'] \n",
+ "RETURN g.variant, c.disease\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "290it [01:25, 1.13it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs254020', 'rs4625', 'rs6059655', 'rs11738191', 'rs2963222'] AND c.name IN ['keratinocyte carcinoma', 'skin sensitivity to sun'] RETURN g.variant, c.name;\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "291it [01:35, 1.22s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESISTANT_TO_mGrC]->(c:Compound) \n",
+ "WHERE g.variant IN ['rs6679677', 'rs12187903', 'rs1333047', 'rs11585651', 'rs55730499'] \n",
+ "AND c.disease IN ['keratinocyte carcinoma', 'anti-neutrophil antibody associated vasculitis'] \n",
+ "RETURN g.variant, c.disease\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "292it [01:40, 1.45s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs7936312', 'rs325485', 'rs13191786', 'rs72928038', 'rs7523907'] AND c.name IN ['keratinocyte carcinoma', 'autoimmune disease'] RETURN g.variant, c.name;\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "293it [01:46, 1.75s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs1950829', 'rs13263709', 'rs1126809', 'rs34871267', 'rs2431108'] AND c.disease IN ['keratinocyte carcinoma', 'age-related hearing impairment'] RETURN g.variant, c.disease\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "294it [01:57, 2.59s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs34243448', 'rs1893592', 'rs1765871', 'rs229541', 'rs10455872'] AND c.disease IN ['aortic stenosis', 'large artery stroke'] RETURN g.variant, c.disease\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "295it [02:03, 2.94s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs1042704', 'rs6059655', 'rs34396849', 'rs10052804', 'rs11747125'] AND c.name IN ['skin sensitivity to sun', 'keratinocyte carcinoma'] RETURN g.variant, c.name;\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "296it [02:08, 3.23s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs61815704', 'rs4149909', 'rs36001488', 'rs1333047', 'rs1126809'] AND c.name IN ['skin sensitivity to sun', 'age-related hearing impairment'] RETURN g.variant, c.name\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "297it [02:18, 4.31s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3m\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "298it [02:28, 5.38s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3m\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "299it [02:35, 5.76s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs259919', 'rs10455872', 'rs11958220', 'rs72928038', 'rs141343442'] AND c.disease IN ['autoimmune disease', 'keratinocyte carcinoma'] RETURN g.variant, c.disease;\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "300it [02:41, 5.73s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs2310752', 'rs7528604', 'rs6679677', 'rs34691223', 'rs2963222'] AND c.disease IN ['autoimmune disease', 'anti-neutrophil antibody associated vasculitis'] RETURN g.variant, c.disease\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "301it [02:50, 6.52s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs62324212', 'rs10059133', 'rs149943', 'rs12931267', 'rs17156671'] AND c.disease IN ['autoimmune disease', 'atopic asthma'] RETURN g.variant, c.disease;\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "302it [02:55, 6.23s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs10455872', 'rs35781323', 'rs4615152', 'rs761934676', 'rs11965538'] AND c.disease IN ['large artery stroke', 'aortic stenosis'] RETURN g.variant, c.disease\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "303it [03:00, 5.95s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs259919', 'rs2503199', 'rs325485', 'rs1126809', 'rs229541'] AND c.disease IN ['age-related hearing impairment', 'keratinocyte carcinoma'] RETURN g.variant, c.disease\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "304it [03:08, 6.37s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs1126809', 'rs416223', 'rs12722502', 'rs9419958', 'rs1333049'] AND c.name IN ['age-related hearing impairment', 'skin sensitivity to sun'] RETURN g.variant, c.name;\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "305it [03:18, 7.47s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "\n",
+ "\n",
+ "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n",
+ "Generated Cypher:\n",
+ "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs2447827', 'rs1937455', 'rs62324212', 'rs12205199', 'rs4482879'] AND c.disease IN ['atopic asthma', 'autoimmune disease'] RETURN g.variant, c.disease\u001b[0m\n",
+ "Full Context:\n",
+ "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "306it [03:26, 1.48it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\u001b[1m> Finished chain.\u001b[0m\n",
+ "CPU times: user 1.83 s, sys: 194 ms, total: 2.03 s\n",
+ "Wall time: 3min 26s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "\n",
+ "cypher_rag_out = []\n",
+ "for index, row in tqdm(curated_data.iterrows()):\n",
+ " question = row['text']\n",
+ " with get_openai_callback() as cb:\n",
+ " try:\n",
+ " cypher_rag_answer = chain.run(query=question, return_final_only=True, verbose=False)\n",
+ " except ValueError as e:\n",
+ " cypher_rag_answer = None\n",
+ " \n",
+ " cypher_rag_out.append((row['text'], row['correct_node'], cypher_rag_answer, cb.total_tokens)) \n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "b99f74b5-f7c0-4cd1-9825-f832d6676542",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cypher_rag_out_df = pd.DataFrame(cypher_rag_out, columns=['text', 'label', 'cypher_rag_answer', 'total_tokens'])\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "77e75f27-fdc8-460c-8bd4-af90c8712a34",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "save_path = '../data/results'\n",
+ "os.makedirs(save_path, exist_ok=True)\n",
+ "cypher_rag_out_df.to_csv(os.path.join(save_path, 'cypher_rag_mcq_output.csv'), index=False)\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2392ef3e-af62-4cb0-abb2-dbe7f0da8108",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python (cypher_rag)",
+ "language": "python",
+ "name": "cypher_rag"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}