diff options
| author | maszhongming <mingz5@illinois.edu> | 2025-09-16 15:15:29 -0500 |
|---|---|---|
| committer | maszhongming <mingz5@illinois.edu> | 2025-09-16 15:15:29 -0500 |
| commit | 73c194f304f827b55081b15524479f82a1b7d94c (patch) | |
| tree | 5e8660e421915420892c5eca18f1ad680f80a861 /notebooks/mcq_using_cypher_rag.ipynb | |
Initial commit
Diffstat (limited to 'notebooks/mcq_using_cypher_rag.ipynb')
| -rw-r--r-- | notebooks/mcq_using_cypher_rag.ipynb | 763 |
1 files changed, 763 insertions, 0 deletions
diff --git a/notebooks/mcq_using_cypher_rag.ipynb b/notebooks/mcq_using_cypher_rag.ipynb new file mode 100644 index 0000000..51a393f --- /dev/null +++ b/notebooks/mcq_using_cypher_rag.ipynb @@ -0,0 +1,763 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "ad38bed4-0077-4097-81dc-8db1d23255ad", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.chat_models import AzureChatOpenAI\n", + "from dotenv import load_dotenv\n", + "import os\n", + "import openai\n", + "from langchain.chains import GraphCypherQAChain\n", + "from langchain.graphs import Neo4jGraph\n", + "from langchain.callbacks import get_openai_callback\n", + "import pandas as pd\n", + "from tqdm import tqdm\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d035a52f-442b-43f2-a7ba-a04006937f0a", + "metadata": {}, + "outputs": [], + "source": [ + "curated_data = pd.read_csv('../data/benchmark_data/mcq_questions.csv')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ddb263ec-9bd5-4caf-8950-13b400a75b98", + "metadata": {}, + "outputs": [], + "source": [ + "LLM_MODEL = 'gpt-4'\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ff7f9255-e5fd-4eec-aade-2f7cacb8b8cd", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/karthiksoman/anaconda3/envs/cypher_rag/lib/python3.10/site-packages/langchain_core/_api/deprecation.py:119: LangChainDeprecationWarning: The class `AzureChatOpenAI` was deprecated in LangChain 0.0.10 and will be removed in 0.3.0. An updated version of the class exists in the langchain-openai package and should be used instead. To use it run `pip install -U langchain-openai` and import as `from langchain_openai import AzureChatOpenAI`.\n", + " warn_deprecated(\n" + ] + } + ], + "source": [ + "\n", + "load_dotenv(os.path.join(os.path.expanduser('~'), '.gpt_config.env'))\n", + "API_KEY = os.environ.get('API_KEY')\n", + "API_VERSION = os.environ.get('API_VERSION')\n", + "RESOURCE_ENDPOINT = os.environ.get('RESOURCE_ENDPOINT')\n", + "openai.api_type = \"azure\"\n", + "openai.api_key = API_KEY\n", + "openai.api_base = RESOURCE_ENDPOINT\n", + "openai.api_version = API_VERSION\n", + "chat_deployment_id = LLM_MODEL\n", + "chat_model_id = chat_deployment_id\n", + "temperature = 0\n", + "chat_model = AzureChatOpenAI(openai_api_key=API_KEY, openai_api_version=API_VERSION, azure_deployment=chat_deployment_id, azure_endpoint=RESOURCE_ENDPOINT, temperature=temperature)\n", + "load_dotenv(os.path.join(os.path.expanduser('~'), '.spoke_neo4j_config.env'))\n", + "username = os.environ.get('NEO4J_USER')\n", + "password = os.environ.get('NEO4J_PSW')\n", + "url = os.environ.get('NEO4J_URL')\n", + "database = os.environ.get('NEO4J_DB')\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "9f0c1ded-fd30-49cb-be0d-369bc5a906bd", + "metadata": {}, + "outputs": [], + "source": [ + "graph = Neo4jGraph(url=url, username=username, password=password, database = database)\n", + "chain = GraphCypherQAChain.from_llm(chat_model, graph=graph, verbose=True, validate_cypher=True,return_intermediate_steps=True)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "922ed8a6-500f-4937-92d2-2df6271798e6", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "0it [00:00, ?it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3m\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "282it [00:09, 29.53it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3m\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "282it [00:21, 29.53it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs7426056', 'rs2736100', 'rs2187668', 'rs2107595', 'rs7405776'] AND c.disease IN ['central nervous system cancer', 'lung adenocarcinoma'] RETURN g.variant, c.disease\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "284it [00:28, 7.63it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (d:Disease)-[:ASSOCIATES_DaG]->(g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound)\n", + "WHERE d.name IN ['thoracic aortic aneurysm', 'abdominal aortic aneurysm'] AND c.variant IN ['rs1642764', 'rs595244', 'rs139606545', 'rs12077210', 'rs12917707']\n", + "RETURN c.variant AS Variant, d.name AS Disease\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "285it [00:41, 4.53it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs8077245', 'rs11712066', 'rs11209026', 'rs10830962', 'rs6010620'] AND c.name IN ['Crohn\\'s disease', 'ankylosing spondylitis'] RETURN g.variant, c.name\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "286it [00:49, 3.41it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs1047891', 'rs9268905', 'rs3197999', 'rs1025128', 'rs4624820'] AND c.disease IN ['Crohn\\'s disease', 'sclerosing cholangitis'] RETURN g.variant, c.disease\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "286it [01:01, 3.41it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (d:Disease)-[:ASSOCIATES_DaG]->(g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) \n", + "WHERE d.name IN ['lung adenocarcinoma', 'central nervous system cancer'] AND c.variant IN ['rs10490924', 'rs10830962', 'rs2736100', 'rs2391769', 'rs9272143'] \n", + "RETURN c.variant, d.name\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "288it [01:10, 1.70it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs11759064', 'rs975730', 'rs1150757', 'rs2294008', 'rs7453920'] AND c.disease IN ['gastric fundus cancer', 'duodenal ulcer'] RETURN g.variant, c.disease\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "289it [01:18, 1.37it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Received notification from DBMS server: {severity: WARNING} {code: Neo.ClientNotification.Statement.FeatureDeprecationWarning} {category: DEPRECATION} {title: This feature is deprecated and will be removed in future versions.} {description: The semantics of using colon in the separation of alternative relationship types will change in a future version. (Please use ':RESPONSE_TO_mGrC|ADVRESPONSE_TO_mGarC' instead)} {position: line: 1, column: 34, offset: 33} for query: \"MATCH (g:Gene)-[:RESPONSE_TO_mGrC|:ADVRESPONSE_TO_mGarC]->(c:Compound) \\nWHERE g.variant IN ['rs2294008', 'rs2072499', 'rs3197999', 'rs1537377', 'rs988958'] \\nAND c.disease IN ['gastric fundus cancer', 'atrophic gastritis'] \\nRETURN g.variant, c.disease\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC|:ADVRESPONSE_TO_mGarC]->(c:Compound) \n", + "WHERE g.variant IN ['rs2294008', 'rs2072499', 'rs3197999', 'rs1537377', 'rs988958'] \n", + "AND c.disease IN ['gastric fundus cancer', 'atrophic gastritis'] \n", + "RETURN g.variant, c.disease\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "290it [01:25, 1.13it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs254020', 'rs4625', 'rs6059655', 'rs11738191', 'rs2963222'] AND c.name IN ['keratinocyte carcinoma', 'skin sensitivity to sun'] RETURN g.variant, c.name;\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "291it [01:35, 1.22s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESISTANT_TO_mGrC]->(c:Compound) \n", + "WHERE g.variant IN ['rs6679677', 'rs12187903', 'rs1333047', 'rs11585651', 'rs55730499'] \n", + "AND c.disease IN ['keratinocyte carcinoma', 'anti-neutrophil antibody associated vasculitis'] \n", + "RETURN g.variant, c.disease\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "292it [01:40, 1.45s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs7936312', 'rs325485', 'rs13191786', 'rs72928038', 'rs7523907'] AND c.name IN ['keratinocyte carcinoma', 'autoimmune disease'] RETURN g.variant, c.name;\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "293it [01:46, 1.75s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs1950829', 'rs13263709', 'rs1126809', 'rs34871267', 'rs2431108'] AND c.disease IN ['keratinocyte carcinoma', 'age-related hearing impairment'] RETURN g.variant, c.disease\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "294it [01:57, 2.59s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs34243448', 'rs1893592', 'rs1765871', 'rs229541', 'rs10455872'] AND c.disease IN ['aortic stenosis', 'large artery stroke'] RETURN g.variant, c.disease\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "295it [02:03, 2.94s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs1042704', 'rs6059655', 'rs34396849', 'rs10052804', 'rs11747125'] AND c.name IN ['skin sensitivity to sun', 'keratinocyte carcinoma'] RETURN g.variant, c.name;\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "296it [02:08, 3.23s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs61815704', 'rs4149909', 'rs36001488', 'rs1333047', 'rs1126809'] AND c.name IN ['skin sensitivity to sun', 'age-related hearing impairment'] RETURN g.variant, c.name\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "297it [02:18, 4.31s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3m\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "298it [02:28, 5.38s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3m\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "299it [02:35, 5.76s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs259919', 'rs10455872', 'rs11958220', 'rs72928038', 'rs141343442'] AND c.disease IN ['autoimmune disease', 'keratinocyte carcinoma'] RETURN g.variant, c.disease;\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "300it [02:41, 5.73s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs2310752', 'rs7528604', 'rs6679677', 'rs34691223', 'rs2963222'] AND c.disease IN ['autoimmune disease', 'anti-neutrophil antibody associated vasculitis'] RETURN g.variant, c.disease\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "301it [02:50, 6.52s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs62324212', 'rs10059133', 'rs149943', 'rs12931267', 'rs17156671'] AND c.disease IN ['autoimmune disease', 'atopic asthma'] RETURN g.variant, c.disease;\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "302it [02:55, 6.23s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs10455872', 'rs35781323', 'rs4615152', 'rs761934676', 'rs11965538'] AND c.disease IN ['large artery stroke', 'aortic stenosis'] RETURN g.variant, c.disease\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "303it [03:00, 5.95s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs259919', 'rs2503199', 'rs325485', 'rs1126809', 'rs229541'] AND c.disease IN ['age-related hearing impairment', 'keratinocyte carcinoma'] RETURN g.variant, c.disease\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "304it [03:08, 6.37s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs1126809', 'rs416223', 'rs12722502', 'rs9419958', 'rs1333049'] AND c.name IN ['age-related hearing impairment', 'skin sensitivity to sun'] RETURN g.variant, c.name;\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "305it [03:18, 7.47s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (g:Gene)-[:RESPONSE_TO_mGrC]->(c:Compound) WHERE g.variant IN ['rs2447827', 'rs1937455', 'rs62324212', 'rs12205199', 'rs4482879'] AND c.disease IN ['atopic asthma', 'autoimmune disease'] RETURN g.variant, c.disease\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[]\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "306it [03:26, 1.48it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "CPU times: user 1.83 s, sys: 194 ms, total: 2.03 s\n", + "Wall time: 3min 26s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "cypher_rag_out = []\n", + "for index, row in tqdm(curated_data.iterrows()):\n", + " question = row['text']\n", + " with get_openai_callback() as cb:\n", + " try:\n", + " cypher_rag_answer = chain.run(query=question, return_final_only=True, verbose=False)\n", + " except ValueError as e:\n", + " cypher_rag_answer = None\n", + " \n", + " cypher_rag_out.append((row['text'], row['correct_node'], cypher_rag_answer, cb.total_tokens)) \n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "b99f74b5-f7c0-4cd1-9825-f832d6676542", + "metadata": {}, + "outputs": [], + "source": [ + "cypher_rag_out_df = pd.DataFrame(cypher_rag_out, columns=['text', 'label', 'cypher_rag_answer', 'total_tokens'])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "77e75f27-fdc8-460c-8bd4-af90c8712a34", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "save_path = '../data/results'\n", + "os.makedirs(save_path, exist_ok=True)\n", + "cypher_rag_out_df.to_csv(os.path.join(save_path, 'cypher_rag_mcq_output.csv'), index=False)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2392ef3e-af62-4cb0-abb2-dbe7f0da8108", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python (cypher_rag)", + "language": "python", + "name": "cypher_rag" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} |
