{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5b1e0c2a",
   "metadata": {},
   "source": [
    "# Generate GWAS disease-gene test questions\n",
    "\n",
    "For every disease in the KG that has GWAS-only `ASSOCIATES_DaG` edges (and whose name is stored in the vectorDB), take the associated gene with the lowest GWAS p-value, turn the pair into a templated question plus a lower-cased (\"perturbed\") variant, and save a random sample of `N_QUESTIONS` rows to `../data/rag_comparison_data.csv`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "403d179c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from neo4j import GraphDatabase, basic_auth\n",
    "from dotenv import load_dotenv\n",
    "import os\n",
    "import pickle\n",
    "import random\n",
    "import pandas as pd\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80ee9a49",
   "metadata": {},
   "source": [
    "## Set number of questions to generate and the random seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa80e37b",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_QUESTIONS = 100\n",
    "# Seed shared by question templating and row sampling, so the saved CSV is reproducible.\n",
    "SEED = 42\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac046718",
   "metadata": {},
   "source": [
    "## Load KG credentials"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d41be45",
   "metadata": {},
   "outputs": [],
   "source": [
    "load_dotenv(os.path.join(os.path.expanduser('~'), '.spoke_neo4j_config.env'))\n",
    "username = os.environ.get('NEO4J_USER')\n",
    "password = os.environ.get('NEO4J_PSW')\n",
    "url = os.environ.get('NEO4J_URI')\n",
    "database = os.environ.get('NEO4J_DB')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf3354e7",
   "metadata": {},
   "source": [
    "## Load disease names stored in vectorDB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ec9d667",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local project artifact; pickle.load can execute arbitrary code, so only load trusted files.\n",
    "with open('../data/disease_with_relation_to_genes.pickle', 'rb') as f:\n",
    "    disease_names = pickle.load(f)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "654a9a58",
   "metadata": {},
   "source": [
    "## Extract, per disease, the GWAS-associated gene with the lowest p-value from the KG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c280e781",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# Only edges whose sole source is GWAS are kept. Rows are ordered by p-value before the\n",
    "# per-disease aggregation; COLLECT keeps that order, so element [0] of each list belongs to\n",
    "# the lowest p-value. Gene and p-value are collected from the same ordered rows, so they stay paired.\n",
    "# (Grouping key is `d` only -- including `pvalue` would yield one row per distinct p-value.)\n",
    "gwas_query = '''\n",
    "    MATCH (d:Disease)-[r:ASSOCIATES_DaG]->(g:Gene)\n",
    "    WHERE r.sources = ['GWAS']\n",
    "    WITH d, g, r.gwas_pvalue AS pvalue\n",
    "    ORDER BY pvalue\n",
    "    WITH d, COLLECT(g)[0] AS gene_with_lowest_pvalue, COLLECT(pvalue)[0] AS lowest_pvalue\n",
    "    RETURN d.name AS disease_name, gene_with_lowest_pvalue.name AS gene_name, lowest_pvalue AS pvalue\n",
    "'''\n",
    "\n",
    "sdb = GraphDatabase.driver(url, auth=basic_auth(username, password))\n",
    "try:\n",
    "    # `database` is None when NEO4J_DB is unset, which selects the server's default database.\n",
    "    with sdb.session(database=database) as session:\n",
    "        with session.begin_transaction() as tx:\n",
    "            result = tx.run(gwas_query)\n",
    "            # Consume the result inside the transaction, before it is closed.\n",
    "            out_list = [(row['disease_name'], row['gene_name'], row['pvalue']) for row in result]\n",
    "finally:\n",
    "    # Release the driver's connections even if the query fails.\n",
    "    sdb.close()\n",
    "\n",
    "gwas_disease_names = (\n",
    "    pd.DataFrame(out_list, columns=['disease_name', 'gene_name', 'gwas_pvalue'])\n",
    "    .drop_duplicates()\n",
    ")\n",
    "# .copy() so the question columns added below go into an independent frame, not a view of the filter.\n",
    "gwas_disease_names = gwas_disease_names[gwas_disease_names.disease_name.isin(disease_names)].copy()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0db2757f",
   "metadata": {},
   "source": [
    "## Create test questions from the extracted relationships"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fe85753",
   "metadata": {},
   "outputs": [],
   "source": [
    "TEMPLATE_QUESTIONS = [\n",
    "    'Is {} associated with {}?',\n",
    "    'What is the GWAS p-value for the association between {} and {}?'\n",
    "]\n",
    "\n",
    "\n",
    "def build_questions(df, lowercase=False, seed=SEED):\n",
    "    \"\"\"Return one templated question per row of `df`, in row order.\n",
    "\n",
    "    For each row a template is drawn from TEMPLATE_QUESTIONS and, with probability 0.5,\n",
    "    the disease and gene names are swapped. A private random.Random(seed) is used, so\n",
    "    calls with the same `seed` draw identical templates and orderings; the lower-cased\n",
    "    variant therefore differs from the original only in letter case.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pd.DataFrame\n",
    "        Must have string columns 'disease_name' and 'gene_name'.\n",
    "    lowercase : bool\n",
    "        Lower-case both names before formatting them into the template.\n",
    "    seed : int\n",
    "        Seed for the per-call random number generator.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of str, same length as `df`.\n",
    "    \"\"\"\n",
    "    rng = random.Random(seed)\n",
    "    questions = []\n",
    "    for disease, gene in zip(df['disease_name'], df['gene_name']):\n",
    "        if lowercase:\n",
    "            disease, gene = disease.lower(), gene.lower()\n",
    "        # Draw order (template first, then swap coin) matches the original notebook.\n",
    "        template = rng.choice(TEMPLATE_QUESTIONS)\n",
    "        first, second = (disease, gene) if rng.random() < 0.5 else (gene, disease)\n",
    "        questions.append(template.format(first, second))\n",
    "    return questions\n",
    "\n",
    "\n",
    "gwas_disease_names['question'] = build_questions(gwas_disease_names)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f1800f5",
   "metadata": {},
   "source": [
    "## Create perturbed test questions (lower case names) from the extracted relationships"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c788c8d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same seed as above, so each perturbed question uses the same template and name order as its original.\n",
    "gwas_disease_names['question_perturbed'] = build_questions(gwas_disease_names, lowercase=True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06eed996",
   "metadata": {},
   "source": [
    "## Sample and save the test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f02bb5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "if len(gwas_disease_names) < N_QUESTIONS:\n",
    "    raise ValueError(\n",
    "        f'Only {len(gwas_disease_names)} disease-gene pairs available; cannot sample N_QUESTIONS={N_QUESTIONS}.'\n",
    "    )\n",
    "\n",
    "gwas_disease_names_selected = gwas_disease_names.sample(N_QUESTIONS, random_state=SEED)\n",
    "\n",
    "gwas_disease_names_selected.to_csv('../data/rag_comparison_data.csv', index=False)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}