summaryrefslogtreecommitdiff
path: root/notebooks/rag_comparison_questions.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'notebooks/rag_comparison_questions.ipynb')
-rw-r--r--notebooks/rag_comparison_questions.ipynb272
1 files changed, 272 insertions, 0 deletions
diff --git a/notebooks/rag_comparison_questions.ipynb b/notebooks/rag_comparison_questions.ipynb
new file mode 100644
index 0000000..2eadac4
--- /dev/null
+++ b/notebooks/rag_comparison_questions.ipynb
@@ -0,0 +1,272 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 171,
+ "id": "403d179c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from neo4j import GraphDatabase, basic_auth\n",
+ "from dotenv import load_dotenv\n",
+ "import os\n",
+ "import pickle\n",
+ "import random\n",
+ "import pandas as pd\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "80ee9a49",
+ "metadata": {},
+ "source": [
+ "## Set number of questions to generate"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 172,
+ "id": "fa80e37b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "N_QUESTIONS = 100\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ac046718",
+ "metadata": {},
+ "source": [
+ "## Load KG credentials"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 173,
+ "id": "8d41be45",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "load_dotenv(os.path.join(os.path.expanduser('~'), '.spoke_neo4j_config.env'))\n",
+ "username = os.environ.get('NEO4J_USER')\n",
+ "password = os.environ.get('NEO4J_PSW')\n",
+ "url = os.environ.get('NEO4J_URI')\n",
+ "database = os.environ.get('NEO4J_DB')\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cf3354e7",
+ "metadata": {},
+ "source": [
+ "## Load disease names stored in vectorDB"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 174,
+ "id": "2ec9d667",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open('../data/disease_with_relation_to_genes.pickle', 'rb') as f:\n",
+ " disease_names = pickle.load(f)\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "654a9a58",
+ "metadata": {},
+ "source": [
+ "## Extract GWAS Disease-Gene relation from the KG"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 175,
+ "id": "c280e781",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 158 ms, sys: 19.6 ms, total: 178 ms\n",
+ "Wall time: 550 ms\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "\n",
+ "auth = basic_auth(username, password)\n",
+ "sdb = GraphDatabase.driver(url, auth=auth)\n",
+ "\n",
+ "gwas_query = '''\n",
+ " MATCH (d:Disease)-[r:ASSOCIATES_DaG]->(g:Gene)\n",
+ " WHERE r.sources = ['GWAS']\n",
+ " WITH d, g, r.gwas_pvalue AS pvalue\n",
+ " ORDER BY pvalue\n",
+ " WITH d, COLLECT(g)[0] AS gene_with_lowest_pvalue, pvalue\n",
+ " RETURN d.name AS disease_name, gene_with_lowest_pvalue.name AS gene_name, pvalue\n",
+ "'''\n",
+ "\n",
+ "with sdb.session() as session:\n",
+ " with session.begin_transaction() as tx:\n",
+ " result = tx.run(gwas_query)\n",
+ " out_list = []\n",
+ " for row in result:\n",
+ " out_list.append((row['disease_name'], row['gene_name'], row['pvalue']))\n",
+ "\n",
+ "gwas_disease_names = pd.DataFrame(out_list, columns=['disease_name', 'gene_name', 'gwas_pvalue']).drop_duplicates()\n",
+ "sdb.close()\n",
+ "\n",
+ "gwas_disease_names = gwas_disease_names[gwas_disease_names.disease_name.isin(disease_names)]\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0db2757f",
+ "metadata": {},
+ "source": [
+ "## Create test questions from the extracted relationships"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 176,
+ "id": "9fe85753",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 97.3 ms, sys: 1.08 ms, total: 98.4 ms\n",
+ "Wall time: 97.7 ms\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "\n",
+ "template_questions = [\n",
+ " 'Is {} associated with {}?',\n",
+ " 'What is the GWAS p-value for the association between {} and {}?'\n",
+ "]\n",
+ "\n",
+ "test_questions = []\n",
+ "random.seed(42)\n",
+ "for index,row in gwas_disease_names.iterrows():\n",
+ " selected_question = random.choice(template_questions)\n",
+ " if random.random() < 0.5:\n",
+ " test_questions.append(selected_question.format(row['disease_name'], row['gene_name']))\n",
+ " else:\n",
+ " test_questions.append(selected_question.format(row['gene_name'], row['disease_name']))\n",
+ "\n",
+ "gwas_disease_names.loc[:,'question'] = test_questions\n",
+ "\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2f1800f5",
+ "metadata": {},
+ "source": [
+ "## Create perturbed test questions (lower case names) from the extracted relationships"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 177,
+ "id": "c788c8d2",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 96 ms, sys: 962 µs, total: 97 ms\n",
+ "Wall time: 96.3 ms\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "\n",
+ "template_questions = [\n",
+ " 'Is {} associated with {}?',\n",
+ " 'What is the GWAS p-value for the association between {} and {}?'\n",
+ "]\n",
+ "\n",
+ "test_questions_perturbed = []\n",
+ "random.seed(42)\n",
+ "for index,row in gwas_disease_names.iterrows():\n",
+ " selected_question = random.choice(template_questions)\n",
+ " if random.random() < 0.5:\n",
+ " test_questions_perturbed.append(selected_question.format(row['disease_name'].lower(), row['gene_name'].lower()))\n",
+ " else:\n",
+ " test_questions_perturbed.append(selected_question.format(row['gene_name'].lower(), row['disease_name'].lower()))\n",
+ "\n",
+ "gwas_disease_names.loc[:,'question_perturbed'] = test_questions_perturbed\n",
+ "\n",
+ "gwas_disease_names_selected = gwas_disease_names.sample(N_QUESTIONS, random_state=42)\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "06eed996",
+ "metadata": {},
+ "source": [
+ "## Save the test data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 178,
+ "id": "7f02bb5b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "gwas_disease_names_selected = gwas_disease_names.sample(N_QUESTIONS, random_state=42)\n",
+ "\n",
+ "gwas_disease_names_selected.to_csv('../data/rag_comparison_data.csv', index=False)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ea680eb0",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}