diff options
Diffstat (limited to 'notebooks/full_text_index_vs_kg_rag.ipynb')
| -rw-r--r-- | notebooks/full_text_index_vs_kg_rag.ipynb | 893 |
1 files changed, 893 insertions, 0 deletions
# %%
# Notebook: notebooks/full_text_index_vs_kg_rag.ipynb
# Relative paths ('../data/...') are resolved from the notebooks/ folder.
import os

import matplotlib.pyplot as plt
import numpy as np
import openai
import pandas as pd
import requests
import seaborn as sns
from dotenv import load_dotenv
from neo4j import GraphDatabase, basic_auth
from tqdm import tqdm

# %% [markdown]
# ## Load data

# %%
data = pd.read_csv('../data/rag_comparison_data.csv')

# %% [markdown]
# ## Custom functions

# %%
def connect_to_graph():
    """Open a Neo4j driver using the module-level graph credentials.

    Reads ``graph_usr``, ``graph_psw`` and ``graph_uri``, which are assigned in
    the "Graph credentials" cell further down; that cell must run before the
    first call.

    Returns:
        The driver object returned by ``GraphDatabase.driver``.
    """
    auth = basic_auth(graph_usr, graph_psw)
    return GraphDatabase.driver(graph_uri, auth=auth)


def run_cypher(cypher_query, parameters=None):
    """Run a Cypher query and collect the association columns of every record.

    The query must return the columns ``d_name``, ``r_type``, ``g_label``,
    ``g_name`` and ``relationship_properties``.

    Args:
        cypher_query: Cypher text to execute.
        parameters: Optional dict of query parameters (``$name`` placeholders).
            Defaults to no parameters, which keeps the original one-argument
            call working.

    Returns:
        A list of 5-tuples
        ``(d_name, r_type, g_label, g_name, relationship_properties)``.
    """
    # A new driver is created (and closed) on every call, as in the original.
    # NOTE(review): NEO4J_DB is read into `database` below but never passed to
    # the session, so the driver's default database is used - confirm intent.
    with connect_to_graph() as sdb:
        with sdb.session() as session:
            result = session.run(cypher_query, parameters or {})
            # Records are consumed while the session is still open.
            return [
                (
                    row['d_name'],
                    row['r_type'],
                    row['g_label'],
                    row['g_name'],
                    row['relationship_properties'],
                )
                for row in result
            ]


def lucene_search(query, source, timeout=30):
    """Call the SPOKE search endpoint for ``query`` restricted to ``source``.

    Args:
        query: Free-text search string (the notebook passes the whole question).
        source: Node type placed in the URL path, e.g. ``'Disease'``.
        timeout: Seconds to wait for the server before ``requests`` raises.
            New optional argument; the original call had no timeout and could
            block indefinitely.

    Returns:
        The ``requests.Response`` object, unmodified.
    """
    # NOTE(review): `query` is interpolated into the URL path as-is. Characters
    # such as '?', '#' or '/' in a question change how the URL is parsed -
    # confirm the encoding the API expects before percent-encoding here.
    source_search_uri = f"https://spoke.rbvi.ucsf.edu/api/v1/search/{source}/{query}"
    return requests.get(source_search_uri, timeout=timeout)


def _provenance_from_properties(properties):
    """Return the provenance text stored on a relationship, or ``''``.

    Prefers the ``'sources'`` entry (joined with ', ' when it is a collection)
    and falls back to ``'source'``.
    """
    if not isinstance(properties, dict):
        return ''
    sources = properties.get('sources')
    if sources:
        # Guard against a plain string, which ', '.join would split per character.
        if isinstance(sources, str):
            return sources
        return ', '.join(str(item) for item in sources)
    single_source = properties.get('source')
    return '' if single_source is None else str(single_source)


def _format_association(record):
    """Render one ``run_cypher`` tuple as a context sentence plus its properties."""
    d_name, r_type, g_labels, g_name, properties = record
    predicate = r_type.split('_')[0].lower()
    g_label = g_labels[0] if g_labels else ''
    prov = _provenance_from_properties(properties)
    return (
        f'Disease {d_name} {predicate} {g_label} {g_name}. '
        f'Provenance of this association is {prov}. \n{properties}'
    )


def get_context_using_lucene_search(query):
    """Build the LLM context for ``query`` from the top Disease search hit.

    Steps: search the Disease index with the raw question, take the first hit's
    name, fetch every relationship of that Disease node from the graph, and
    render each one as a sentence followed by the relationship's properties.

    Args:
        query: The natural-language question.

    Returns:
        The context string, or ``''`` when the search does not return HTTP 200,
        returns no hits, or the disease has no relationships.
    """
    source = 'Disease'
    source_resp = lucene_search(query, source)
    if source_resp.status_code != 200:
        return ''
    source_resp_data = source_resp.json()
    if not source_resp_data:
        # Empty hit list: the original indexed [0] unconditionally and raised.
        return ''
    source_name = source_resp_data[0]['name']

    # The name is passed as a query parameter rather than spliced into the
    # Cypher text, so quotes inside a disease name cannot break the query.
    cypher = '''
    MATCH (d:Disease {name: $disease_name})-[r]-(g)
    RETURN DISTINCT d.name AS d_name, TYPE(r) AS r_type, LABELS(g) AS g_label, g.name AS g_name, PROPERTIES(r) AS relationship_properties
    '''
    graph_out = run_cypher(cypher, {'disease_name': source_name})

    # Provenance lives in the relationship properties (tuple index 4). The
    # original looked it up on index 3, the neighbour's name (a string), so both
    # lookups raised, were swallowed by bare excepts, and provenance was always
    # empty. In the recorded run the model accordingly answered that provenance
    # was "not provided in the context".
    # Records are now separated by a newline; previously the next sentence
    # started directly after the preceding properties dict.
    return '\n'.join(_format_association(record) for record in graph_out)


def chat_completion_with_token_usage(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature):
    """Send one system+user exchange to the chat model.

    Args:
        instruction: Text of the user message.
        system_prompt: Text of the system message.
        chat_model_id: Value passed as ``model``.
        chat_deployment_id: Value passed as ``deployment_id``.
        temperature: Sampling temperature passed through unchanged.

    Returns:
        ``(content, total_tokens)``: the first choice's message content and the
        ``usage.total_tokens`` reported in the response.
    """
    response = openai.ChatCompletion.create(
        temperature=temperature,
        deployment_id=chat_deployment_id,
        model=chat_model_id,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": instruction}
        ]
    )
    return response['choices'][0]['message']['content'], response.usage.total_tokens

# %% [markdown]
# ## OpenAI credentials

# %%
# Credentials come from ~/.gpt_config.env; nothing is hard-coded here.
load_dotenv(os.path.join(os.path.expanduser('~'), '.gpt_config.env'))

api_key = os.environ.get('API_KEY')
api_version = os.environ.get('API_VERSION')
resource_endpoint = os.environ.get('RESOURCE_ENDPOINT')
openai.api_type = 'azure'
openai.api_key = api_key
# Endpoint and version are only overridden when the env file provides them.
if resource_endpoint:
    openai.api_base = resource_endpoint
if api_version:
    openai.api_version = api_version

# %% [markdown]
# ## Graph credentials

# %%
# Neo4j settings come from ~/.spoke_neo4j_config.env.
load_dotenv(os.path.join(os.path.expanduser('~'), '.spoke_neo4j_config.env'))
graph_usr = os.environ.get('NEO4J_USER')
graph_psw = os.environ.get('NEO4J_PSW')
graph_uri = os.environ.get('NEO4J_URL')
database = os.environ.get('NEO4J_DB')
# %% [markdown]
# ## Setting system prompt for LLM

# %%
# System message sent with every chat completion in this notebook.
# NOTE(review): indentation inside this literal is reproduced as it appears in
# the flattened source - confirm against the original notebook.
system_prompt = '''
 You are an expert biomedical researcher. 
 For answering the Question at the end with brevity, you need to first read the Context provided. 
 Then give your final answer briefly, by citing the Provenance information from the context. 
 You can find Provenance from the Context statement 'Provenance of this association is <Provenance>'. 
 Do not forget to cite the Provenance information. 
 Note that, if Provenance is 'GWAS' report it as 'GWAS Catalog'. 
 If Provenance is 'DISEASES' report it as 'DISEASES database - https://diseases.jensenlab.org'. 
 Additionally, when providing drug or medication suggestions, give maximum information available and then advise the user to seek guidance from a healthcare professional as a precautionary measure.
'''

# The same identifier is later passed as both the model and the deployment id.
chat_model = 'gpt-4-32k'
temperature = 0

# %% [markdown]
# ## Example query for Lucene based RAG
# ## Example query for Lucene based RAG (after perturbation)

# %%
# The two questions differ only in casing: the first uses the original names,
# the second is the "perturbed" (lower-cased) variant.
# Recorded output for both in the committed notebook:
#   "Yes, Parkinson's disease is associated with the PINK1 gene. Provenance of
#    this association is not provided in the context."
example_queries = (
    "Is Parkinson's disease associated with PINK1 gene?",
    "Is parkinson's disease associated with pink1 gene?",
)

for query in example_queries:
    # Retrieve graph context for the question, then ask the model.
    context = get_context_using_lucene_search(query)
    prompt = f'''
Context: {context}
Question: {query}
'''
    output, token_usage = chat_completion_with_token_usage(
        prompt,
        system_prompt,
        chat_model,
        chat_model,
        temperature,
    )
    print(output)

# %% [markdown]
# ## Lucene based context extraction
# %%
# Originally a %%time cell; recorded wall time 43.3 s for 100 questions.
# One context string per question; '' marks a question whose search did not
# return HTTP 200 or whose disease had no relationships. Such rows are kept
# (filter with data[data.extracted_context != ''] if they should be excluded).
lucene_based_context_list = [
    get_context_using_lucene_search(record['question'])
    for _, record in tqdm(data.iterrows())
]
data['extracted_context'] = lucene_based_context_list

# %% [markdown]
# ## Lucene based context extraction - after query perturbation

# %%
# Originally a %%time cell; recorded wall time 1 min for 100 questions.
# Same extraction, driven by the perturbed wording of each question. Empty
# contexts can be filtered with
# data[data.extracted_context_after_perturbation != ''].
lucene_based_context_list = [
    get_context_using_lucene_search(record['question_perturbed'])
    for _, record in tqdm(data.iterrows())
]
data['extracted_context_after_perturbation'] = lucene_based_context_list
# %%
def answer_with_llm(frame, question_col, context_col):
    """Ask the chat model every question in ``frame`` using its stored context.

    Args:
        frame: DataFrame holding one question per row.
        question_col: Column with the question text.
        context_col: Column with the context retrieved for that question.

    Returns:
        ``(answers, token_counts)``: two lists aligned with ``frame``'s rows. A
        failed call contributes ``None`` to both lists so the run continues.
    """
    answers, token_counts = [], []
    for _, record in tqdm(frame.iterrows(), total=len(frame)):
        # NOTE(review): the exact leading whitespace of this template is not
        # recoverable from the flattened source - confirm if it matters.
        prompt = f'''
    Context: {record[context_col]}
    Question: {record[question_col]}
    '''
        try:
            answer, tokens = chat_completion_with_token_usage(
                prompt, system_prompt, chat_model, chat_model, temperature
            )
        except Exception as err:
            # Best effort is kept, but a bare `except:` would also trap
            # KeyboardInterrupt and hide why the call failed.
            tqdm.write(f'LLM call failed for {record[question_col]!r}: {err!r}')
            answer, tokens = None, None
        answers.append(answer)
        token_counts.append(tokens)
    return answers, token_counts


# %%
# Originally a %%time cell; recorded wall time 5min 10s for 100 questions.
full_text_track_output, full_text_track_token_usage = answer_with_llm(
    data, 'question', 'extracted_context'
)
data['full_text_index_answer'] = full_text_track_output
data['token_usage'] = full_text_track_token_usage

# %%
# Originally a %%time cell; recorded wall time 4min 53s for 100 questions.
full_text_track_output, full_text_track_token_usage = answer_with_llm(
    data, 'question_perturbed', 'extracted_context_after_perturbation'
)
data['full_text_index_answer_after_perturbation'] = full_text_track_output
data['token_usage_after_perturbation'] = full_text_track_token_usage

# %%
def correct_retrieval_percentage(frame, answer_col, pvalue_col='gwas_pvalue'):
    """Percentage of rows whose answer counts as a correct retrieval.

    An answer is counted when it is not null and either contains the literal
    text ``'Yes'`` or contains the string form of that row's ``pvalue_col``
    value. The denominator is every row of ``frame``, including rows without
    an answer.

    Args:
        frame: DataFrame with the answers and the reference p-values.
        answer_col: Column holding the model's answer.
        pvalue_col: Column holding the reference p-value.

    Returns:
        A float between 0 and 100, or NaN when ``frame`` has no rows.
    """
    total_rows = frame.shape[0]
    if total_rows == 0:
        return float('nan')
    # Work on derived Series instead of assigning into a dropna() slice, which
    # is what raised SettingWithCopyWarning in the original cells.
    answered = frame.dropna(subset=[answer_col])
    answers = answered[answer_col].astype(str)
    says_yes = answers.str.contains('Yes', regex=False)
    cites_pvalue = pd.Series(
        [
            # A missing p-value would stringify to 'nan' and match any answer
            # containing that substring, so it never counts as a citation.
            pd.notna(pvalue) and str(pvalue) in answer
            for pvalue, answer in zip(answered[pvalue_col], answers)
        ],
        index=answered.index,
        dtype=bool,
    )
    return 100 * int((says_yes | cites_pvalue).sum()) / total_rows


# %%
# Recorded in the committed notebook: 61.0%
full_text_index_based_total_correct_retrieval_percentage = correct_retrieval_percentage(
    data, 'full_text_index_answer'
)

print(f'Correct retrieval percentage for Full-text index based retrieval {full_text_index_based_total_correct_retrieval_percentage}%')

# %%
# Recorded in the committed notebook: 58.0%
full_text_index_based_total_correct_retrieval_perturbed_percentage = correct_retrieval_percentage(
    data, 'full_text_index_answer_after_perturbation'
)

print(f'Correct retrieval percentage for Full-text index based retrieval after name perturbation {full_text_index_based_total_correct_retrieval_perturbed_percentage}%')

# %% [markdown]
# ## KG-RAG

# %%
kg_rag = pd.read_csv('../data/results/kg_rag_output.csv')

# %%
# Recorded in the committed notebook: 97.0%
kg_rag_total_correct_retrieval_percentage = correct_retrieval_percentage(
    kg_rag, 'kg_rag_answer'
)

print(f'Correct retrieval percentage for KG-RAG {kg_rag_total_correct_retrieval_percentage}%')

# %%
# Recorded in the committed notebook: 97.0%
kg_rag_total_correct_retrieval_perturbed_percentage = correct_retrieval_percentage(
    kg_rag, 'kg_rag_answer_perturbed'
)

print(f'Correct retrieval percentage for KG-RAG after name perturbation {kg_rag_total_correct_retrieval_perturbed_percentage}%')

# %% [markdown]
# ## Cypher-RAG

# %%
neo4j_rag = pd.read_csv('../data/results/cypher_rag_output.csv')

# %%
# Recorded in the committed notebook: 75.0%
neo4j_rag_total_correct_retrieval_percentage = correct_retrieval_percentage(
    neo4j_rag, 'neo4j_rag_answer'
)

print(f'Correct retrieval percentage for Cypher-RAG {neo4j_rag_total_correct_retrieval_percentage}%')

# %%
# Recorded in the committed notebook: 0.0%
neo4j_rag_total_correct_retrieval_perturbed_percentage = correct_retrieval_percentage(
    neo4j_rag, 'neo4j_rag_answer_perturbed'
)

print(f'Correct retrieval percentage for Cypher-RAG after name perturbation {neo4j_rag_total_correct_retrieval_perturbed_percentage}%')

# %%
# Grouped bars: one group per condition, one bar per retrieval method.
categories = ['Unperturbed', 'Perturbed']
full_text_index_rag_percentage = [full_text_index_based_total_correct_retrieval_percentage, full_text_index_based_total_correct_retrieval_perturbed_percentage]
neo4j_rag_percentage = [neo4j_rag_total_correct_retrieval_percentage, neo4j_rag_total_correct_retrieval_perturbed_percentage]
kg_rag_percentage = [kg_rag_total_correct_retrieval_percentage, kg_rag_total_correct_retrieval_perturbed_percentage]

full_text_index_color = 'blue'
neo4j_color = 'red'
kg_rag_color = 'green'

fig, ax = plt.subplots(figsize=(5, 3))
bar_width = 0.25
index = np.arange(len(categories))

# Bars are offset by one bar width either side of the group centre.
ax.bar(index - bar_width, full_text_index_rag_percentage, bar_width, color=full_text_index_color, label='Full-Text Index')
ax.bar(index, neo4j_rag_percentage, bar_width, color=neo4j_color, label='Cypher-RAG')
ax.bar(index + bar_width, kg_rag_percentage, bar_width, color=kg_rag_color, label='KG-RAG')

ax.set_ylabel('Retrieval accuracy')
ax.set_xticks(index)
ax.set_xticklabels(categories)
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

sns.despine()
plt.tight_layout()
plt.show()

fig_path = '../data/results/figures'
os.makedirs(fig_path, exist_ok=True)
fig.savefig(os.path.join(fig_path, 'retrieval_accuracy_three_way_comparison.svg'), format='svg', bbox_inches='tight')

# %% [markdown]
# ## Token usage plot for full-text, cypher-rag and kg-rag

# %%
# Per-question token usage is the mean of the unperturbed and perturbed runs.
data.loc[:, 'token_usage_combined'] = 0.5*(data.token_usage + data.token_usage_after_perturbation)
neo4j_rag.loc[:, 'token_usage'] = 0.5*(neo4j_rag.total_tokens_used + neo4j_rag.total_tokens_used_perturbed)
kg_rag.loc[:, 'token_usage'] = 0.5*(kg_rag.total_tokens_used + kg_rag.total_tokens_used_perturbed)

# .mean() and .sem() skip NaN, so questions whose LLM call failed (token count
# None) are left out of the full-text average rather than counted as zero.
neo4j_avg = neo4j_rag['token_usage'].mean()
neo4j_sem = neo4j_rag['token_usage'].sem()

kg_avg = kg_rag['token_usage'].mean()
kg_sem = kg_rag['token_usage'].sem()

full_text_avg = data['token_usage_combined'].mean()
full_text_sem = data['token_usage_combined'].sem()

fig = plt.figure(figsize=(4, 3))
plt.bar(0, full_text_avg, yerr=full_text_sem, color='blue', ecolor='black', capsize=5, label='Full-Text Index')
plt.bar(1, neo4j_avg, yerr=neo4j_sem, color='red', ecolor='black', capsize=5, label='Cypher-RAG')
plt.bar(2, kg_avg, yerr=kg_sem, color='green', ecolor='black', capsize=5, label='KG-RAG')

plt.ylabel('Average token usage')
plt.xticks([0, 1, 2], ['Full-Text Index', 'Cypher-RAG', 'KG-RAG'], rotation=45, ha='right')
plt.tight_layout()
sns.despine()
plt.show()

percentage_of_reduction_cypher_to_kg = round(100*(neo4j_avg-kg_avg)/neo4j_avg,1)
percentage_of_reduction_fulltext_to_kg = round(100*(full_text_avg-kg_avg)/full_text_avg,1)

# Recorded in the committed notebook: 53.9% and 65.1% respectively.
print(f'There is {percentage_of_reduction_cypher_to_kg}% reduction in token usage for KG-RAG compared to Cypher-RAG')
# Fixed the duplicated word ("compared compared to") in the original message.
print(f'There is {percentage_of_reduction_fulltext_to_kg}% reduction in token usage for KG-RAG compared to Full-Text Index')

fig_path = '../data/results/figures'
os.makedirs(fig_path, exist_ok=True)
fig.savefig(os.path.join(fig_path, 'token_usage_three_way_comparison.svg'), format='svg', bbox_inches='tight')

# %%
# Recorded value in the committed notebook: 10590.367088607594
full_text_avg

# Notebook metadata in the original: kernel "Python (cypher_rag)", Python 3.10.9,
# nbformat 4.5.
