{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "945c420e-bb44-4ffb-b899-e049caf0d918",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Work from the parent directory so the relative paths used below (e.g. 'data/...') resolve.\n",
    "# NOTE: not idempotent - re-running this cell moves up one more level; restart the kernel instead.\n",
    "os.chdir('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f2bdefb3-3e59-409a-81b4-2e9ffbdfdb1a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/kg_rag_test_2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      " from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "# The wildcard import supplies config_data, load_chroma and disease_entity_extractor_v2 used below.\n",
    "from kg_rag.utility import *\n",
    "from tqdm import tqdm\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "19fc98b9-64a8-40c0-9e5a-92b4392e6969",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation set: the loop below reads the 'text' and 'node_hits' columns of each row.\n",
    "data = pd.read_csv('data/dataset_for_entity_retrieval_accuracy_analysis.csv')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "2851be4c-2a76-4f6d-b5f4-118e8122b155",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vector store location and embedding model name both come from the project config.\n",
    "VECTOR_DB_PATH = config_data[\"VECTOR_DB_PATH\"]\n",
    "SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data[\"SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL\"]\n",
    "\n",
    "vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "7255fbab-d8b4-43a3-b870-9d67ad79d061",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "322it [00:05, 56.20it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 4.74 s, sys: 896 ms, total: 5.64 s\n",
      "Wall time: 5.73 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "# A question counts as correctly retrieved when, for at least one extracted entity,\n",
    "# the top-1 vector-store hit equals the value in that row's 'node_hits' column.\n",
    "correct_retrieval = 0\n",
    "\n",
    "# iterrows() is a generator without a length, so pass total= to get a real progress bar.\n",
    "for _index, row in tqdm(data.iterrows(), total=len(data)):\n",
    "    question = row['text']\n",
    "    # Treat an empty/None extraction as 'no entities' so a single failed extraction\n",
    "    # is scored as a miss instead of aborting the whole run with a TypeError.\n",
    "    entities = disease_entity_extractor_v2(question) or []\n",
    "    for entity in entities:\n",
    "        node_search_result = vectorstore.similarity_search_with_score(entity, k=1)\n",
    "        # Each result is indexed as (document, score); an empty result list is a miss.\n",
    "        if node_search_result and node_search_result[0][0].page_content == row['node_hits']:\n",
    "            correct_retrieval += 1\n",
    "            # Count each question at most once.\n",
    "            break\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "2f997335-bff7-431c-bbd8-608513eddcc7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Retrieval accuracy is 99.7%\n"
     ]
    }
   ],
   "source": [
    "# Percentage of dataset rows (questions) with a correct top-1 node retrieval.\n",
    "retrieval_accuracy = 100 * correct_retrieval / data.shape[0]\n",
    "print(f'Retrieval accuracy is {round(retrieval_accuracy,1)}%')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afe971ab-b8b9-4c88-9657-c588813b412f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}