In [171]:
from neo4j import GraphDatabase, basic_auth
from dotenv import load_dotenv
import os
import pickle
import random
import pandas as pd


## Set number of questions to generate

In [172]:
# Number of rows (one question per row) drawn for the final test set.
N_QUESTIONS: int = 100


## Load KG credentials

In [173]:
# Pull the Neo4j connection settings from a dotenv file kept in the user's
# home directory, then read each setting back out of the process environment.
home_dir = os.path.expanduser('~')
load_dotenv(os.path.join(home_dir, '.spoke_neo4j_config.env'))

# Any variable that is not set comes back as None.
username = os.getenv('NEO4J_USER')
password = os.getenv('NEO4J_PSW')
url = os.getenv('NEO4J_URI')
database = os.getenv('NEO4J_DB')


## Load disease names stored in vectorDB

In [174]:
# Unpickle the collection of disease names; it is used below to filter the
# GWAS rows via `.isin(disease_names)`.
with open('../data/disease_with_relation_to_genes.pickle', 'rb') as pickle_file:
    disease_names = pickle.load(pickle_file)
    

## Extract GWAS Disease-Gene relation from the KG

In [175]:
%%time

# Open a driver to the knowledge graph using the credentials loaded earlier.
auth = basic_auth(username, password)
sdb = GraphDatabase.driver(url, auth=auth)

# Disease -> Gene associations whose only listed source is GWAS, ordered by
# GWAS p-value.
# NOTE(review): in the second WITH, `d` and `pvalue` are both non-aggregated,
# so COLLECT(g)[0] is evaluated per (disease, p-value) pair rather than once
# per disease. The alias `gene_with_lowest_pvalue` suggests one row per
# disease was intended -- confirm before changing, since it alters the data.
gwas_query = '''
    MATCH (d:Disease)-[r:ASSOCIATES_DaG]->(g:Gene)
    WHERE r.sources = ['GWAS']
    WITH d, g, r.gwas_pvalue AS pvalue
    ORDER BY pvalue
    WITH d, COLLECT(g)[0] AS gene_with_lowest_pvalue, pvalue
    RETURN d.name AS disease_name, gene_with_lowest_pvalue.name AS gene_name, pvalue
'''

try:
    with sdb.session() as session:
        with session.begin_transaction() as tx:
            result = tx.run(gwas_query)
            # Materialise the records while the transaction is still open.
            out_list = [
                (row['disease_name'], row['gene_name'], row['pvalue'])
                for row in result
            ]
finally:
    # Always release the driver's connections, even when the query fails.
    sdb.close()

gwas_disease_names = pd.DataFrame(out_list, columns=['disease_name', 'gene_name', 'gwas_pvalue']).drop_duplicates()

# Keep only diseases present in `disease_names`. The explicit .copy() makes the
# result an independent frame, so the later cells that add columns to it do not
# trigger pandas' SettingWithCopyWarning.
gwas_disease_names = gwas_disease_names[gwas_disease_names.disease_name.isin(disease_names)].copy()


CPU times: user 158 ms, sys: 19.6 ms, total: 178 ms
Wall time: 550 ms


## Create test questions from the extracted relationships

In [176]:
%%time

# Two phrasings are available; one is picked at random for every row.
template_questions = [
    'Is {} associated with {}?',
    'What is the GWAS p-value for the association between {} and {}?'
]


def _build_question(disease, gene):
    """Fill a randomly chosen template with the two names in random order.

    The RNG is consumed in a fixed order (template choice first, then the
    coin flip for the name order) so that results are reproducible for a
    given seed.
    """
    template = random.choice(template_questions)
    if random.random() < 0.5:
        return template.format(disease, gene)
    return template.format(gene, disease)


# Fixed seed so the same question is produced for each row on every run.
random.seed(42)
test_questions = [
    _build_question(disease, gene)
    for disease, gene in zip(gwas_disease_names['disease_name'],
                             gwas_disease_names['gene_name'])
]

gwas_disease_names.loc[:, 'question'] = test_questions





CPU times: user 97.3 ms, sys: 1.08 ms, total: 98.4 ms
Wall time: 97.7 ms


## Create perturbed test questions (lower case names) from the extracted relationships

In [177]:
%%time

# Same two phrasings as for the unperturbed questions.
template_questions = [
    'Is {} associated with {}?',
    'What is the GWAS p-value for the association between {} and {}?'
]


def _build_perturbed_question(disease, gene):
    """Fill a randomly chosen template with the lower-cased names.

    The RNG is consumed in a fixed order (template choice first, then the
    coin flip for the name order) so that results are reproducible for a
    given seed.
    """
    disease_lc, gene_lc = disease.lower(), gene.lower()
    template = random.choice(template_questions)
    if random.random() < 0.5:
        return template.format(disease_lc, gene_lc)
    return template.format(gene_lc, disease_lc)


# Re-seeding with the same value replays the same sequence of random draws,
# row by row.
random.seed(42)
test_questions_perturbed = [
    _build_perturbed_question(disease, gene)
    for disease, gene in zip(gwas_disease_names['disease_name'],
                             gwas_disease_names['gene_name'])
]

gwas_disease_names.loc[:, 'question_perturbed'] = test_questions_perturbed

# Seeded draw of N_QUESTIONS rows for the final test set.
gwas_disease_names_selected = gwas_disease_names.sample(n=N_QUESTIONS, random_state=42)




CPU times: user 96 ms, sys: 962 µs, total: 97 ms
Wall time: 96.3 ms


## Save the test data

In [178]:
# Seeded draw of N_QUESTIONS rows, so the same subset is selected on every run.
gwas_disease_names_selected = gwas_disease_names.sample(n=N_QUESTIONS, random_state=42)

# Write the selected rows to CSV without the DataFrame index column.
output_path = '../data/rag_comparison_data.csv'
gwas_disease_names_selected.to_csv(output_path, index=False)
