diff options
| author | maszhongming <mingz5@illinois.edu> | 2025-09-16 15:15:29 -0500 |
|---|---|---|
| committer | maszhongming <mingz5@illinois.edu> | 2025-09-16 15:15:29 -0500 |
| commit | 73c194f304f827b55081b15524479f82a1b7d94c (patch) | |
| tree | 5e8660e421915420892c5eca18f1ad680f80a861 /notebooks/mcq_cypher_rag_eval.ipynb | |
Initial commit
Diffstat (limited to 'notebooks/mcq_cypher_rag_eval.ipynb')
| -rw-r--r-- | notebooks/mcq_cypher_rag_eval.ipynb | 216 |
1 files changed, 216 insertions, 0 deletions
# notebooks/mcq_cypher_rag_eval.ipynb -- code cells
#
# Purpose: grade the answers stored in `cypher_rag_mcq_output.csv` with an LLM
# judge. For every row the question, the gold label and the generated answer
# are sent to the evaluation model, and the "response" field of its JSON reply
# is stored in a new `cypher_rag_eval` column.

# %% Imports
import pandas as pd
import openai
from dotenv import load_dotenv, find_dotenv
import os
import json
from tqdm import tqdm

# %% Load the answers to be graded
cypher_rag = pd.read_csv('../data/results/cypher_rag_mcq_output.csv')

# %% Evaluation model (passed below as both the model id and the deployment id)
llm_for_evaluation = 'gpt-4'

# %% Azure OpenAI credentials, read from ~/.gpt_config.env
load_dotenv(os.path.join(os.path.expanduser('~'), '.gpt_config.env'))
api_key = os.environ.get('API_KEY')
api_version = os.environ.get('API_VERSION')
resource_endpoint = os.environ.get('RESOURCE_ENDPOINT')
openai.api_type = 'azure'
openai.api_key = api_key
# Only override the SDK defaults when the corresponding variable is set.
if resource_endpoint:
    openai.api_base = resource_endpoint
if api_version:
    openai.api_version = api_version


# %% Helpers
def fetch_GPT_response(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature=0):
    """Send one system + user message pair to the chat model and return its reply.

    Args:
        instruction: Text sent as the user message.
        system_prompt: Text sent as the system message.
        chat_model_id: Value forwarded as ``model``.
        chat_deployment_id: Value forwarded as ``deployment_id``.
        temperature: Sampling temperature forwarded to the API (default 0).

    Returns:
        The ``content`` of the first choice's message, or the literal string
        ``'Unexpected response'`` when the reply has no usable first choice.
    """
    response = openai.ChatCompletion.create(
        temperature=temperature,
        deployment_id=chat_deployment_id,
        model=chat_model_id,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": instruction},
        ],
    )
    choices = response['choices'] if 'choices' in response else None
    # The list of choices itself must be non-empty before indexing into it.
    # (The previous guard, `len(response) >= 0`, was always true, so an empty
    # list raised IndexError instead of reaching the fallback.)
    if isinstance(choices, list) \
            and len(choices) > 0 \
            and 'message' in choices[0] \
            and 'content' in choices[0]['message']:
        return choices[0]['message']['content']
    return 'Unexpected response'


def get_GPT_response(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature=0):
    """Thin wrapper around ``fetch_GPT_response`` with the same arguments and return value."""
    return fetch_GPT_response(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature)


def parse_eval_response(raw_response):
    """Extract the ``"response"`` field from the evaluator's JSON reply.

    Args:
        raw_response: Text returned by ``get_GPT_response``.

    Returns:
        The value stored under ``"response"``, or ``None`` when the text is not
        valid JSON, is not a JSON object, or has no ``"response"`` key. This
        covers the ``'Unexpected response'`` fallback string.
    """
    try:
        parsed = json.loads(raw_response)
    except (TypeError, ValueError):
        # TypeError: reply was not a string; ValueError: reply was not valid JSON
        # (json.JSONDecodeError is a ValueError subclass).
        return None
    if not isinstance(parsed, dict):
        return None
    return parsed.get("response")


# %% Instructions for the LLM judge
# NOTE(review): the format line asks only for a "response" key while both
# examples also return "reason". The text is left untouched because changing
# it would change what the evaluator is asked to do.
system_prompt = '''
You are an expert evaluator of mcq question. You will be provided with an MCQ question which has 5 options out of which one option is correct.
You will also be provided with the correct label.
You will then provided with an answer.
You need to verify if that answer is correct or not, by checking it with the question and the label.
If the answer is correct, you return True, if the answer is not correct, you return False.
Return your answer in JSON format as follows:
{"response": <"True"/"False">}

Example 1:
Question: Out of the given list, which Gene is associated with psoriasis and Takayasu's arteritis. Given list is: SHTN1, HLA-B, SLC14A2, BTBD9, DTNB
Label: HLA-B
Answer: The gene HLA-B is associated with both psoriasis and Takayasu's arteritis.
Your response: {"response": "True", "reason": "Answer is in agreement with the label."}

Example 2:
Question: Out of the given list, which Gene is associated with psoriasis and myelodysplastic syndrome. Given list is: NOD2, CHEK2, HLA-B, GCKR, PKNOX2
Label: HLA-B
Answer: The genes associated with psoriasis and myelodysplastic syndrome from the given list are HLA-B and CHEK2.
Your response: {"response": "False", "reason":"Answer returns two options HLA-B and CHEK2, but the label says the correct answer is only HLA-B"}
'''

# %% Grade every row (one API call per row)
# The recorded run took about 11.5 minutes for 306 rows; run this cell under
# the `%%time` cell magic in Jupyter to reproduce the timing output.
cypher_rag_eval = []
for index, row in tqdm(cypher_rag.iterrows(), total=len(cypher_rag)):
    question = row['text']
    label = row['label']
    answer = row['cypher_rag_answer']
    prompt = f'''
    Question: {question}
    Label: {label}
    Answer: {answer}
    '''
    raw_reply = get_GPT_response(prompt, system_prompt, llm_for_evaluation, llm_for_evaluation)
    # A malformed reply is recorded as None rather than aborting the whole run.
    cypher_rag_eval.append(parse_eval_response(raw_reply))

unparsed_count = sum(verdict is None for verdict in cypher_rag_eval)
if unparsed_count:
    print(f'{unparsed_count} evaluator replies could not be parsed and were recorded as None')

# %% Attach the verdicts to the frame
cypher_rag['cypher_rag_eval'] = cypher_rag_eval

# %% Persist the graded results
cypher_rag.to_csv('../data/results/cypher_rag_mcq_output_with_eval.csv', index=False)
