{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "md-title",
   "metadata": {},
   "source": [
    "# Cypher-RAG benchmark: multiple-choice questions\n",
    "\n",
    "Sends every question in `../data/benchmark_data/mcq_questions.csv` through a LangChain `GraphCypherQAChain` (Azure OpenAI chat model + Neo4j graph) and writes the answers and token counts to `../data/results/cypher_rag_mcq_output.csv`.\n",
    "\n",
    "**Prerequisites** (read from the home directory):\n",
    "\n",
    "- `~/.gpt_config.env` defining `API_KEY`, `API_VERSION`, `RESOURCE_ENDPOINT`\n",
    "- `~/.spoke_neo4j_config.env` defining `NEO4J_USER`, `NEO4J_PSW`, `NEO4J_URL`, `NEO4J_DB`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad38bed4-0077-4097-81dc-8db1d23255ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import openai\n",
    "import pandas as pd\n",
    "from dotenv import load_dotenv\n",
    "from langchain.callbacks import get_openai_callback\n",
    "from langchain.chains import GraphCypherQAChain\n",
    "from langchain.chat_models import AzureChatOpenAI\n",
    "from langchain.graphs import Neo4jGraph\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-config",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddb263ec-9bd5-4caf-8950-13b400a75b98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Azure deployment name of the chat model.\n",
    "LLM_MODEL = 'gpt-4'\n",
    "TEMPERATURE = 0\n",
    "\n",
    "# When True the chain prints the generated Cypher and the retrieved context for every question.\n",
    "# Set to False to keep the notebook output small.\n",
    "CHAIN_VERBOSE = True\n",
    "\n",
    "QUESTIONS_PATH = '../data/benchmark_data/mcq_questions.csv'\n",
    "SAVE_PATH = '../data/results'\n",
    "OUTPUT_FILENAME = 'cypher_rag_mcq_output.csv'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-load",
   "metadata": {},
   "source": [
    "## Load benchmark questions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d035a52f-442b-43f2-a7ba-a04006937f0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "curated_data = pd.read_csv(QUESTIONS_PATH)\n",
    "\n",
    "# Fail here, before any API call is made, if the columns used by the benchmark loop are absent.\n",
    "missing_columns = {'text', 'correct_node'} - set(curated_data.columns)\n",
    "assert not missing_columns, f'missing expected columns: {sorted(missing_columns)}'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-connect",
   "metadata": {},
   "source": [
    "## Connect to Azure OpenAI and Neo4j"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff7f9255-e5fd-4eec-aade-2f7cacb8b8cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Azure OpenAI settings come from ~/.gpt_config.env.\n",
    "load_dotenv(os.path.join(os.path.expanduser('~'), '.gpt_config.env'))\n",
    "API_KEY = os.environ.get('API_KEY')\n",
    "API_VERSION = os.environ.get('API_VERSION')\n",
    "RESOURCE_ENDPOINT = os.environ.get('RESOURCE_ENDPOINT')\n",
    "\n",
    "# Module-level `openai` settings. NOTE(review): the chat model below receives the same\n",
    "# values explicitly, so these are presumably redundant - confirm before removing.\n",
    "openai.api_type = \"azure\"\n",
    "openai.api_key = API_KEY\n",
    "openai.api_base = RESOURCE_ENDPOINT\n",
    "openai.api_version = API_VERSION\n",
    "\n",
    "chat_deployment_id = LLM_MODEL\n",
    "chat_model_id = chat_deployment_id\n",
    "chat_model = AzureChatOpenAI(\n",
    "    openai_api_key=API_KEY,\n",
    "    openai_api_version=API_VERSION,\n",
    "    azure_deployment=chat_deployment_id,\n",
    "    azure_endpoint=RESOURCE_ENDPOINT,\n",
    "    temperature=TEMPERATURE,\n",
    ")\n",
    "\n",
    "# Neo4j connection settings come from ~/.spoke_neo4j_config.env.\n",
    "load_dotenv(os.path.join(os.path.expanduser('~'), '.spoke_neo4j_config.env'))\n",
    "username = os.environ.get('NEO4J_USER')\n",
    "password = os.environ.get('NEO4J_PSW')\n",
    "url = os.environ.get('NEO4J_URL')\n",
    "database = os.environ.get('NEO4J_DB')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f0c1ded-fd30-49cb-be0d-369bc5a906bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "graph = Neo4jGraph(url=url, username=username, password=password, database=database)\n",
    "chain = GraphCypherQAChain.from_llm(\n",
    "    chat_model,\n",
    "    graph=graph,\n",
    "    verbose=CHAIN_VERBOSE,\n",
    "    validate_cypher=True,\n",
    "    return_intermediate_steps=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-run",
   "metadata": {},
   "source": [
    "## Run the benchmark\n",
    "\n",
    "Each question is sent through the chain once. A question whose chain call raises `ValueError` is stored with a `None` answer so the run can continue; the index and error message are kept in `failed_questions`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "922ed8a6-500f-4937-92d2-2df6271798e6",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "cypher_rag_out = []\n",
    "failed_questions = []  # (row index, error message) for every question that raised ValueError\n",
    "\n",
    "# `iterrows()` is a generator, so tqdm needs an explicit total to show progress and an ETA.\n",
    "for index, row in tqdm(curated_data.iterrows(), total=len(curated_data)):\n",
    "    question = row['text']\n",
    "    with get_openai_callback() as cb:\n",
    "        try:\n",
    "            # The chain takes its inputs as a dict; the final answer is under 'result'.\n",
    "            # Verbosity is set once, when the chain is built (CHAIN_VERBOSE).\n",
    "            cypher_rag_answer = chain.invoke({'query': question})['result']\n",
    "        except ValueError as e:\n",
    "            # Keep going so one failing question does not abort the run, but record why it failed.\n",
    "            cypher_rag_answer = None\n",
    "            failed_questions.append((index, str(e)))\n",
    "\n",
    "    cypher_rag_out.append((row['text'], row['correct_node'], cypher_rag_answer, cb.total_tokens))\n",
    "\n",
    "print(f'{len(failed_questions)} of {len(curated_data)} questions raised ValueError and have no answer')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-save",
   "metadata": {},
   "source": [
    "## Save results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b99f74b5-f7c0-4cd1-9825-f832d6676542",
   "metadata": {},
   "outputs": [],
   "source": [
    "cypher_rag_out_df = pd.DataFrame(cypher_rag_out, columns=['text', 'label', 'cypher_rag_answer', 'total_tokens'])\n",
    "\n",
    "# One output row per input question.\n",
    "assert len(cypher_rag_out_df) == len(curated_data)\n",
    "cypher_rag_out_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77e75f27-fdc8-460c-8bd4-af90c8712a34",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_path = SAVE_PATH\n",
    "os.makedirs(save_path, exist_ok=True)\n",
    "cypher_rag_out_df.to_csv(os.path.join(save_path, OUTPUT_FILENAME), index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (cypher_rag)",
   "language": "python",
   "name": "cypher_rag"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}