# --- Working directory ------------------------------------------------------
import os

# Switch to the parent directory, presumably so that the `kg_rag` import and
# the relative 'data/...' path below resolve -- TODO confirm.
# NOTE(review): this is not idempotent. Re-running the cell climbs one more
# directory each time, so restart the kernel before re-running.
os.chdir('..')

# --- Imports ----------------------------------------------------------------
# The wildcard import stays FIRST: every explicit import below then wins over
# any same-named symbol it exports, exactly as in the original cell order.
from kg_rag.utility import *

# Names the notebook uses directly. They previously reached the namespace only
# through the wildcard import above; importing them explicitly documents the
# dependency. NOTE(review): assumes the wildcard exported these same modules.
import json
import time

import torch
from transformers import AutoTokenizer

from tqdm import tqdm
import pandas as pd
import spacy
import scispacy
# NOTE(review): EntityLinker is not referenced by name in this notebook;
# presumably this import is what makes the "scispacy_linker" pipe name
# available to nlp.add_pipe() below -- verify before removing.
from scispacy.linking import EntityLinker
from transformers import pipeline
from transformers import AutoModelForTokenClassification
from IPython.display import clear_output

# --- List the NER methods to compare ----------------------------------------
method_list = ['gpt', 'biomed-ner-all', 'scispacy']

# --- Load spacy and bert based models ---------------------------------------
nlp = spacy.load("en_core_sci_sm")
nlp.add_pipe("scispacy_linker", config={"resolve_abbreviations": True, "linker_name": "umls"})

# `config_data` is provided by the kg_rag.utility wildcard import.
biomed_ner_all_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all",
                                                         revision="main",
                                                         cache_dir=config_data['LLM_CACHE_DIR'])
biomed_ner_all_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all",
                                                                       torch_dtype=torch.float16,
                                                                       revision="main",
                                                                       cache_dir=config_data['LLM_CACHE_DIR']
                                                                       )
# Drop whatever the loading calls above printed so the saved notebook stays small.
clear_output()

# --- Load evaluation dataset ------------------------------------------------
# The evaluation loop reads the 'text' and 'node_hits' columns of this frame.
data = pd.read_csv('data/dataset_for_entity_retrieval_accuracy_analysis.csv')
def entity_extraction(text, method):
    """Run one NER method on ``text`` and measure how long the call takes.

    Args:
        text: Sentence to extract disease mentions from.
        method: Name of the extractor to use: 'gpt', 'scispacy' or
            'biomed-ner-all'.

    Returns:
        Tuple ``(entity, run_time)``: whatever the selected extractor returned
        and the elapsed wall-clock seconds of that single call.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    # Built at call time so that an extractor re-defined in a later cell is
    # picked up without re-running this definition.
    extractors = {
        'gpt': disease_entity_extractor_compare_version,
        'scispacy': disease_entity_extractor_scispacy,
        'biomed-ner-all': disease_entity_extractor_biomed_ner,
    }
    try:
        extractor = extractors[method]
    except KeyError:
        # The original if/elif chain fell through here and died with an
        # UnboundLocalError on `entity`; fail with a readable message instead.
        raise ValueError(
            f"Unknown NER method {method!r}; expected one of {sorted(extractors)}"
        ) from None

    # perf_counter is a monotonic, high-resolution clock meant for intervals.
    start_time = time.perf_counter()
    entity = extractor(text)
    run_time = time.perf_counter() - start_time
    return entity, run_time


def get_GPT_response_compare_version(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature=0):
    """Forward the arguments unchanged to ``fetch_GPT_response`` and return its result."""
    return fetch_GPT_response(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature)


def disease_entity_extractor_compare_version(text):
    """Ask the GPT model for the diseases mentioned in ``text``.

    The instruction sent is the disease-extraction system prompt followed by
    ``"Sentence : " + text``; the same system prompt is also passed as the
    system prompt argument.

    Args:
        text: Sentence to analyse.

    Returns:
        The value stored under the "Diseases" key of the JSON response, or
        ``None`` when the response is not valid JSON or has no such key.
    """
    chat_model_id, chat_deployment_id = get_gpt35()
    system_prompt = system_prompts["DISEASE_ENTITY_EXTRACTION"]
    prompt_updated = system_prompt + "\n" + "Sentence : " + text
    resp = get_GPT_response_compare_version(prompt_updated, system_prompt, chat_model_id, chat_deployment_id, temperature=0)
    try:
        return json.loads(resp)["Diseases"]
    except (TypeError, ValueError, KeyError):
        # TypeError: resp is not a string, or the JSON is not an object.
        # ValueError: malformed JSON (json.JSONDecodeError subclasses it).
        # KeyError: a JSON object without a "Diseases" entry.
        # Unlike a bare `except:`, KeyboardInterrupt/SystemExit still propagate.
        return None


def disease_entity_extractor_scispacy(text):
    """Return the entity mentions in ``text`` whose top UMLS link is a disease type.

    Args:
        text: Sentence to analyse with the module-level ``nlp`` pipeline.

    Returns:
        List of ``ent.text`` strings, in document order, for entities that have
        at least one linker candidate and whose first candidate carries one of
        the accepted semantic types.
    """
    # NOTE(review): presumably the UMLS semantic types "Disease or Syndrome"
    # (T047) and "Neoplastic Process" (T191) -- verify against the UMLS
    # semantic network.
    disease_semantic_types = {"T047", "T191"}
    # Loop-invariant lookup, hoisted out of the per-entity loop.
    cui_to_entity = nlp.get_pipe("scispacy_linker").kb.cui_to_entity

    entity = []
    for ent in nlp(text).ents:
        candidates = ent._.kb_ents
        if not candidates:
            continue  # the linker proposed nothing for this mention
        # Only the first candidate is considered; each candidate is indexed as
        # (cui, ...) here.
        top_cui = candidates[0][0]
        if any(sem_type in disease_semantic_types for sem_type in cui_to_entity[top_cui].types):
            entity.append(ent.text)
    return entity


# Entity groups of the biomedical NER model that count as a disease mention.
_BIOMED_NER_LABELS = frozenset({'Disease_disorder', 'Sign_symptom'})

# Lazily built token-classification pipeline; re-running this cell resets it.
_biomed_ner_pipeline = None


def _get_biomed_ner_pipeline():
    """Build the biomedical NER pipeline on first use and reuse it afterwards."""
    global _biomed_ner_pipeline
    if _biomed_ner_pipeline is None:
        _biomed_ner_pipeline = pipeline("ner", model=biomed_ner_all_model, tokenizer=biomed_ner_all_tokenizer, aggregation_strategy="simple", device=0)
    return _biomed_ner_pipeline


def disease_entity_extractor_biomed_ner(text):
    """Return the words the biomedical NER model tags as disease or symptom.

    The pipeline object used to be re-created on every call, so its
    construction cost was part of every timed extraction. It is now created
    once, on the first call, and reused for all later calls.

    Args:
        text: Sentence to analyse.

    Returns:
        List of non-empty ``word`` values of the predictions whose
        ``entity_group`` is 'Disease_disorder' or 'Sign_symptom', in the order
        the pipeline returned them.
    """
    out = _get_biomed_ner_pipeline()(text)
    # The trailing truthiness check keeps the original `filter(None, ...)`
    # behaviour of dropping empty words.
    return [pred['word'] for pred in out if pred['entity_group'] in _BIOMED_NER_LABELS and pred['word']]
# Run every NER method over every evaluation sentence and collect one record
# per (sentence, method) pair. In the notebook this cell is executed under the
# `%%time` cell magic, which reports the total wall time of the comparison.
comparison_out = []
for method_index, method in enumerate(method_list):
    print(f'Processing method : {method}, {method_index+1}/{len(method_list)}')
    # Only two columns are needed, so iterate over them directly instead of
    # materialising a Series per row with iterrows(). `total` lets tqdm show
    # progress and an ETA rather than a bare iteration counter.
    for text, node_hits in tqdm(zip(data['text'], data['node_hits']), total=len(data), desc=method):
        entity, run_time = entity_extraction(text, method)
        comparison_out.append((text, node_hits, entity, run_time, method))

# One row per (sentence, method): the input, its expected node, the entities
# the method returned and the seconds spent on that single extraction.
comparison_out_df = pd.DataFrame(comparison_out, columns=['input_text', 'node_hits', 'entity_extracted', 'run_time_per_text', 'ner_method'])
| \n", " | input_text | \n", "node_hits | \n", "entity_extracted | \n", "run_time_per_text | \n", "ner_method | \n", "
|---|---|---|---|---|---|
| 0 | \n", "LIRAGLUTIDE TREATS OBESITY | \n", "obesity | \n", "[obesity] | \n", "2.214761 | \n", "gpt | \n", "
| 1 | \n", "disease ontology identifier for central diabet... | \n", "central diabetes insipidus | \n", "[central diabetes insipidus] | \n", "0.549497 | \n", "gpt | \n", "
| 2 | \n", "Xeroderma pigmentosum, group G is not associat... | \n", "xeroderma pigmentosum | \n", "[Xeroderma pigmentosum] | \n", "0.926769 | \n", "gpt | \n", "
| 3 | \n", "cherubism is not a autosomal dominant disease | \n", "cherubism | \n", "[cherubism, autosomal dominant disease] | \n", "0.675068 | \n", "gpt | \n", "
| 4 | \n", "MASA SYNDROME (DISORDER) IS NOT ASSOCIATED WIT... | \n", "MASA syndrome | \n", "[MASA SYNDROME] | \n", "0.465556 | \n", "gpt | \n", "
| ... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
| 961 | \n", "antineoplastic agents treats osteosarcoma | \n", "osteosarcoma | \n", "[osteosarcoma] | \n", "0.012946 | \n", "scispacy | \n", "
| 962 | \n", "timothy syndrome associates gene cacna1c | \n", "Timothy syndrome | \n", "[syndrome] | \n", "0.011308 | \n", "scispacy | \n", "
| 963 | \n", "piebaldism is a autosomal dominant disease | \n", "piebaldism | \n", "[autosomal dominant disease] | \n", "0.012271 | \n", "scispacy | \n", "
| 964 | \n", "Disease ontology identifier for Loeys-Dietz sy... | \n", "Loeys-Dietz syndrome | \n", "[Loeys-Dietz syndrome] | \n", "0.012468 | \n", "scispacy | \n", "
| 965 | \n", "NOONAN SYNDROME ASSOCIATES GENE PTPN11 | \n", "Noonan syndrome | \n", "[NOONAN SYNDROME] | \n", "0.010858 | \n", "scispacy | \n", "
966 rows × 5 columns
\n", "