diff options
Diffstat (limited to 'notebooks/rag_comparison_questions.ipynb')
| -rw-r--r-- | notebooks/rag_comparison_questions.ipynb | 272 |
1 files changed, 272 insertions, 0 deletions
diff --git a/notebooks/rag_comparison_questions.ipynb b/notebooks/rag_comparison_questions.ipynb new file mode 100644 index 0000000..2eadac4 --- /dev/null +++ b/notebooks/rag_comparison_questions.ipynb @@ -0,0 +1,272 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 171, + "id": "403d179c", + "metadata": {}, + "outputs": [], + "source": [ + "from neo4j import GraphDatabase, basic_auth\n", + "from dotenv import load_dotenv\n", + "import os\n", + "import pickle\n", + "import random\n", + "import pandas as pd\n" + ] + }, + { + "cell_type": "markdown", + "id": "80ee9a49", + "metadata": {}, + "source": [ + "## Set number of questions to generate" + ] + }, + { + "cell_type": "code", + "execution_count": 172, + "id": "fa80e37b", + "metadata": {}, + "outputs": [], + "source": [ + "N_QUESTIONS = 100\n" + ] + }, + { + "cell_type": "markdown", + "id": "ac046718", + "metadata": {}, + "source": [ + "## Load KG credentials" + ] + }, + { + "cell_type": "code", + "execution_count": 173, + "id": "8d41be45", + "metadata": {}, + "outputs": [], + "source": [ + "load_dotenv(os.path.join(os.path.expanduser('~'), '.spoke_neo4j_config.env'))\n", + "username = os.environ.get('NEO4J_USER')\n", + "password = os.environ.get('NEO4J_PSW')\n", + "url = os.environ.get('NEO4J_URI')\n", + "database = os.environ.get('NEO4J_DB')\n" + ] + }, + { + "cell_type": "markdown", + "id": "cf3354e7", + "metadata": {}, + "source": [ + "## Load disease names stored in vectorDB" + ] + }, + { + "cell_type": "code", + "execution_count": 174, + "id": "2ec9d667", + "metadata": {}, + "outputs": [], + "source": [ + "with open('../data/disease_with_relation_to_genes.pickle', 'rb') as f:\n", + " disease_names = pickle.load(f)\n", + " " + ] + }, + { + "cell_type": "markdown", + "id": "654a9a58", + "metadata": {}, + "source": [ + "## Extract GWAS Disease-Gene relation from the KG" + ] + }, + { + "cell_type": "code", + "execution_count": 175, + "id": "c280e781", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 158 ms, sys: 19.6 ms, total: 178 ms\n", + "Wall time: 550 ms\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "auth = basic_auth(username, password)\n", + "sdb = GraphDatabase.driver(url, auth=auth)\n", + "\n", + "gwas_query = '''\n", + " MATCH (d:Disease)-[r:ASSOCIATES_DaG]->(g:Gene)\n", + " WHERE r.sources = ['GWAS']\n", + " WITH d, g, r.gwas_pvalue AS pvalue\n", + " ORDER BY pvalue\n", + " WITH d, COLLECT(g)[0] AS gene_with_lowest_pvalue, pvalue\n", + " RETURN d.name AS disease_name, gene_with_lowest_pvalue.name AS gene_name, pvalue\n", + "'''\n", + "\n", + "with sdb.session() as session:\n", + " with session.begin_transaction() as tx:\n", + " result = tx.run(gwas_query)\n", + " out_list = []\n", + " for row in result:\n", + " out_list.append((row['disease_name'], row['gene_name'], row['pvalue']))\n", + "\n", + "gwas_disease_names = pd.DataFrame(out_list, columns=['disease_name', 'gene_name', 'gwas_pvalue']).drop_duplicates()\n", + "sdb.close()\n", + "\n", + "gwas_disease_names = gwas_disease_names[gwas_disease_names.disease_name.isin(disease_names)]\n" + ] + }, + { + "cell_type": "markdown", + "id": "0db2757f", + "metadata": {}, + "source": [ + "## Create test questions from the extracted relationships" + ] + }, + { + "cell_type": "code", + "execution_count": 176, + "id": "9fe85753", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 97.3 ms, sys: 1.08 ms, total: 98.4 ms\n", + "Wall time: 97.7 ms\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "template_questions = [\n", + " 'Is {} associated with {}?',\n", + " 'What is the GWAS p-value for the association between {} and {}?'\n", + "]\n", + "\n", + "test_questions = []\n", + "random.seed(42)\n", + "for index,row in gwas_disease_names.iterrows():\n", + " selected_question = random.choice(template_questions)\n", + " if random.random() < 0.5:\n", + " test_questions.append(selected_question.format(row['disease_name'], row['gene_name']))\n", + " else:\n", + " test_questions.append(selected_question.format(row['gene_name'], row['disease_name']))\n", + "\n", + "gwas_disease_names.loc[:,'question'] = test_questions\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "2f1800f5", + "metadata": {}, + "source": [ + "## Create perturbed test questions (lower case names) from the extracted relationships" + ] + }, + { + "cell_type": "code", + "execution_count": 177, + "id": "c788c8d2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 96 ms, sys: 962 µs, total: 97 ms\n", + "Wall time: 96.3 ms\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "template_questions = [\n", + " 'Is {} associated with {}?',\n", + " 'What is the GWAS p-value for the association between {} and {}?'\n", + "]\n", + "\n", + "test_questions_perturbed = []\n", + "random.seed(42)\n", + "for index,row in gwas_disease_names.iterrows():\n", + " selected_question = random.choice(template_questions)\n", + " if random.random() < 0.5:\n", + " test_questions_perturbed.append(selected_question.format(row['disease_name'].lower(), row['gene_name'].lower()))\n", + " else:\n", + " test_questions_perturbed.append(selected_question.format(row['gene_name'].lower(), row['disease_name'].lower()))\n", + "\n", + "gwas_disease_names.loc[:,'question_perturbed'] = test_questions_perturbed\n", + "\n", + "gwas_disease_names_selected = gwas_disease_names.sample(N_QUESTIONS, random_state=42)\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "06eed996", + "metadata": {}, + "source": [ + "## Save the test data" + ] + }, + { + "cell_type": "code", + "execution_count": 178, + "id": "7f02bb5b", + "metadata": {}, + "outputs": [], + "source": [ + "gwas_disease_names_selected = gwas_disease_names.sample(N_QUESTIONS, random_state=42)\n", + "\n", + "gwas_disease_names_selected.to_csv('../data/rag_comparison_data.csv', index=False)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea680eb0", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} |
