def jaccard_similarity(list1, list2):
    """Return the case-insensitive Jaccard similarity of two label collections.

    Parameters
    ----------
    list1, list2 : list | str | None
        Collections of labels. Items are compared via ``str(item).lower()``,
        so matching ignores case and non-string items do not raise.
        A bare string is treated as a single label rather than being iterated
        character by character.

    Returns
    -------
    float
        ``|A ∩ B| / |A ∪ B|`` in ``[0, 1]``; ``0.0`` when either input is
        ``None`` or empty.
    """
    if list1 is None or list2 is None:
        return 0.0
    if len(list1) == 0 or len(list2) == 0:
        return 0.0

    # An LLM sometimes answers with a single string instead of a list.
    if isinstance(list1, str):
        list1 = [list1]
    if isinstance(list2, str):
        list2 = [list2]

    set1 = {str(item).lower() for item in list1}
    set2 = {str(item).lower() for item in list2}
    intersection = len(set1 & set2)
    # Both sets are non-empty here, so the union can never be zero.
    union = len(set1) + len(set2) - intersection
    return intersection / union


def extract_answer(text):
    """Return the first brace-delimited, non-nested ``{...}`` span in ``text``.

    Returns ``None`` when ``text`` contains no such span.
    """
    match = re.search(r'{[^{}]*}', text)
    return match.group() if match else None


def extract_by_splitting(text):
    """Fallback parser for answers that are not valid JSON.

    Slices the text on ``:``, ``"Diseases"``, ``"], "`` and ``}`` to isolate
    the two list literals and evaluates them with ``ast.literal_eval``.

    Returns
    -------
    dict
        ``{"Compounds": <list>, "Diseases": <list>}``.

    Raises
    ------
    IndexError, ValueError, SyntaxError
        If the text does not follow the expected layout.
    """
    compound_list = text.split(':')[1].split("Diseases")[0].split("], ")[0] + "]"
    disease_list = text.split(':')[-1].split("}")[0]
    # strip(): literal_eval rejects leading whitespace on Python < 3.10.
    return {
        "Compounds": ast.literal_eval(compound_list.strip()),
        "Diseases": ast.literal_eval(disease_list.strip()),
    }


def _parse_llm_answer(raw_answer):
    """Parse one raw LLM answer: strict JSON first, then the splitting fallback."""
    try:
        return json.loads(extract_answer(raw_answer))
    except (TypeError, ValueError):
        # TypeError: no {...} span found (json.loads(None)) or non-string input.
        # ValueError: the span is not valid JSON (json.JSONDecodeError).
        return extract_by_splitting(raw_answer)


def get_hyperparam_perf(files, parent_path=None):
    """Summarise LLM retrieval performance for each result file.

    Every CSV must provide the columns ``llm_answer``,
    ``compound_groundTruth``, ``disease_groundTruth`` and ``context_volume``.
    Rows without an ``llm_answer`` are ignored. The score of a row is the mean
    of the compound and disease Jaccard similarities.

    Parameters
    ----------
    files : iterable of str
        CSV file names, resolved relative to ``parent_path``.
    parent_path : str, optional
        Directory containing the files. Defaults to the notebook-level
        ``PARENT_PATH`` constant.

    Returns
    -------
    pandas.DataFrame
        One row per file with the columns ``performance_mean``,
        ``performance_std``, ``performance_sem`` and ``context_volume``.

    Raises
    ------
    ValueError
        If a file contains no row with an ``llm_answer``.
    """
    if parent_path is None:
        parent_path = PARENT_PATH

    llm_performance_list = []
    for file in tqdm(files):
        raw_df = pd.read_csv(os.path.join(parent_path, file))
        answered_df = raw_df.dropna(subset=["llm_answer"])
        if answered_df.empty:
            # Previously this hit a NameError or silently reused the previous
            # file's context volume through the leaked loop variable.
            raise ValueError(f"No rows with an 'llm_answer' in {file!r}")

        scores = []
        for _, row in answered_df.iterrows():
            cmp_gt = ast.literal_eval(row["compound_groundTruth"])
            disease_gt = ast.literal_eval(row["disease_groundTruth"])
            llm_answer = _parse_llm_answer(row["llm_answer"])
            cmp_similarity = jaccard_similarity(cmp_gt, llm_answer["Compounds"])
            disease_similarity = jaccard_similarity(disease_gt, llm_answer["Diseases"])
            scores.append(np.mean([cmp_similarity, disease_similarity]))

        # Taken from the last answered row, which is the value the original
        # loop-variable lookup produced.
        context_volume = answered_df["context_volume"].iloc[-1]
        llm_performance_list.append(
            (np.mean(scores), np.std(scores), sem(scores), context_volume)
        )

    return pd.DataFrame(
        llm_performance_list,
        columns=["performance_mean", "performance_std", "performance_sem", "context_volume"],
    )
# %% Configuration
# The following files can be obtained by running the
# run_single_disease_entity_hyperparameter_tuning.py script.
# Change PARENT_PATH and the file names to match where and how you saved them.

PARENT_PATH = "../data/results"

# Result files plotted under the 'all-MiniLM-L6-v2' label.
FILES_1 = [
    "minilm_based_single_disease_hyperparam_tuning_round_1_gpt_4.csv",
    "minilm_based_single_disease_hyperparam_tuning_round_2_gpt_4.csv",
    "minilm_based_single_disease_hyperparam_tuning_round_3_gpt_4.csv",
    "minilm_based_single_disease_hyperparam_tuning_round_4_gpt_4.csv",
    "minilm_based_single_disease_hyperparam_tuning_round_5_gpt_4.csv"
]

# Result files plotted under the 'PubMedBERT' label.
FILES_2 = [
    "pubmert_based_single_disease_hyperparam_tuning_round_1_gpt_4.csv",
    "pubmert_based_single_disease_hyperparam_tuning_round_2_gpt_4.csv",
    "pubmert_based_single_disease_hyperparam_tuning_round_3_gpt_4.csv",
    "pubmert_based_single_disease_hyperparam_tuning_round_4_gpt_4.csv",
    "pubmert_based_single_disease_hyperparam_tuning_round_5_gpt_4.csv"
]

# %% Aggregate performance per result file
mini_lm_perf = get_hyperparam_perf(FILES_1)
pubmedBert_perf = get_hyperparam_perf(FILES_2)

# %% Plot mean performance against context volume
params = mini_lm_perf.context_volume.values
mini_lm_mean_performance = mini_lm_perf.performance_mean.values
mini_lm_std_deviation = mini_lm_perf.performance_std.values
mini_lm_sem_value = mini_lm_perf.performance_sem.values

# Each curve uses its own x-values so a mismatch between the two runs'
# context volumes cannot silently misplace the PubMedBERT points.
pubmedBert_params = pubmedBert_perf.context_volume.values
pubmedBert_mean_performance = pubmedBert_perf.performance_mean.values
pubmedBert_std_deviation = pubmedBert_perf.performance_std.values
pubmedBert_sem_value = pubmedBert_perf.performance_sem.values

fig, ax = plt.subplots(figsize=(3, 4))
# Error bars are intentionally off (yerr=None); the std/sem arrays above are
# kept available in case they should be shown.
ax.errorbar(params, mini_lm_mean_performance, yerr=None, fmt='o-', capsize=5, label='all-MiniLM-L6-v2')
ax.errorbar(pubmedBert_params, pubmedBert_mean_performance, yerr=None, fmt='o-', capsize=5, label='PubMedBERT', color="darkred")
ax.set_xlabel('Context volume')
ax.set_ylabel('Mean Performance')
ax.grid(True)
ax.set_ylim(0.35, 0.8)
ax.legend(bbox_to_anchor=(0.2, 0.35))

# Save before show(): on non-inline backends the figure may already be
# closed once show() returns, which would write an empty file.
figures_dir = os.path.join(PARENT_PATH, "figures")
os.makedirs(figures_dir, exist_ok=True)
fig_filename = os.path.join(figures_dir, "context_volume_single_disease_prompt_miniLM_vs_PubMedBert.svg")
fig.savefig(fig_filename, format='svg', bbox_inches='tight')
plt.show()
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
performance_meanperformance_stdperformance_semcontext_volume
00.4197610.2174930.02528310
10.6229520.1900950.02209850
20.6614600.1779320.020684100
30.6670210.1791100.020821150
40.6686740.1769350.020568200
\n", "
# %% Per-file summary table for the MiniLM run
mini_lm_perf

# %% Overall mean performance of each model (averaged over the result files)
mini_lm_overall_mean = mini_lm_perf.performance_mean.mean()
pubmedBert_overall_mean = pubmedBert_perf.performance_mean.mean()

# %%
round(mini_lm_overall_mean, 2)

# %%
round(pubmedBert_overall_mean, 2)

# %% Relative improvement of PubMedBERT over MiniLM, in percent
# Computed from the unrounded means: rounding both operands to two decimals
# first would distort the ratio. Round only the final figure when reporting.
100 * (pubmedBert_overall_mean - mini_lm_overall_mean) / mini_lm_overall_mean