summaryrefslogtreecommitdiff
path: root/notebooks/mcq_cypher_rag_eval.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'notebooks/mcq_cypher_rag_eval.ipynb')
-rw-r--r--notebooks/mcq_cypher_rag_eval.ipynb216
1 files changed, 216 insertions, 0 deletions
diff --git a/notebooks/mcq_cypher_rag_eval.ipynb b/notebooks/mcq_cypher_rag_eval.ipynb
new file mode 100644
index 0000000..63e2da8
--- /dev/null
+++ b/notebooks/mcq_cypher_rag_eval.ipynb
@@ -0,0 +1,216 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "a5caaba0-2c5d-4555-90eb-a7f8db555c2d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import openai\n",
+ "from dotenv import load_dotenv, find_dotenv\n",
+ "import os\n",
+ "import json\n",
+ "from tqdm import tqdm\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "c17b66f4-c35b-4751-8dbd-44b827a81a39",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cypher_rag = pd.read_csv('../data/results/cypher_rag_mcq_output.csv')\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "4e8372a7-d61b-4b27-8161-3ab084c091e8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "llm_for_evaluation = 'gpt-4'\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "3930c00b-4647-4444-a2b6-f6a1030bd388",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "load_dotenv(os.path.join(os.path.expanduser('~'), '.gpt_config.env'))\n",
+ "api_key = os.environ.get('API_KEY')\n",
+ "api_version = os.environ.get('API_VERSION')\n",
+ "resource_endpoint = os.environ.get('RESOURCE_ENDPOINT')\n",
+ "openai.api_type = 'azure'\n",
+ "openai.api_key = api_key\n",
+ "if resource_endpoint:\n",
+ " openai.api_base = resource_endpoint\n",
+ "if api_version:\n",
+ " openai.api_version = api_version"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "77dbdd34-f3a5-40b4-b0c7-1b6d2d6c5835",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def fetch_GPT_response(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature=0):\n",
+ " # print('Calling OpenAI...')\n",
+ " response = openai.ChatCompletion.create(\n",
+ " temperature=temperature,\n",
+ " deployment_id=chat_deployment_id,\n",
+ " model=chat_model_id,\n",
+ " messages=[\n",
+ " {\"role\": \"system\", \"content\": system_prompt},\n",
+ " {\"role\": \"user\", \"content\": instruction}\n",
+ " ]\n",
+ " )\n",
+ " if 'choices' in response \\\n",
+ " and isinstance(response['choices'], list) \\\n",
+ " and len(response) >= 0 \\\n",
+ " and 'message' in response['choices'][0] \\\n",
+ " and 'content' in response['choices'][0]['message']:\n",
+ " return response['choices'][0]['message']['content']\n",
+ " else:\n",
+ " return 'Unexpected response'\n",
+ "\n",
+ "def get_GPT_response(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature=0):\n",
+ " return fetch_GPT_response(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature)\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "8285a1c7-4654-4728-b0d0-2a1300af26df",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "system_prompt = '''\n",
+ "You are an expert evaluator of mcq question. You will be provided with an MCQ question which has 5 options out of which one option is correct.\n",
+ "You will also be provided with the correct label.\n",
+ "You will then provided with an answer.\n",
+ "You need to verify if that answer is correct or not, by checking it with the question and the label.\n",
+ "If the answer is correct, you return True, if the answer is not correct, you return False.\n",
+ "Return your answer in JSON format as follows:\n",
+ "{\"response\": <\"True\"/\"False\">}\n",
+ "\n",
+ "Example 1:\n",
+ "Question: Out of the given list, which Gene is associated with psoriasis and Takayasu's arteritis. Given list is: SHTN1, HLA-B, SLC14A2, BTBD9, DTNB\n",
+ "Label: HLA-B\n",
+ "Answer: The gene HLA-B is associated with both psoriasis and Takayasu's arteritis.\n",
+ "Your response: {\"response\": \"True\", \"reason\": \"Answer is in agreement with the label.\"}\n",
+ "\n",
+ "Example 2:\n",
+ "Question: Out of the given list, which Gene is associated with psoriasis and myelodysplastic syndrome. Given list is: NOD2, CHEK2, HLA-B, GCKR, PKNOX2\n",
+ "Label: HLA-B\n",
+ "Answer: The genes associated with psoriasis and myelodysplastic syndrome from the given list are HLA-B and CHEK2.\n",
+ "Your response: {\"response\": \"False\", \"reason\":\"Answer returns two options HLA-B and CHEK2, but the label says the correct answer is only HLA-B\"}\n",
+ "'''"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "76dfe699-2678-442c-b005-ec76a5989d69",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "306it [11:29, 2.25s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 4.33 s, sys: 473 ms, total: 4.81 s\n",
+ "Wall time: 11min 29s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "\n",
+ "cypher_rag_eval = []\n",
+ "for index, row in tqdm(cypher_rag.iterrows()):\n",
+ " question = row.text\n",
+ " label = row.label\n",
+ " answer = row.cypher_rag_answer\n",
+ " prompt = f'''\n",
+ " Question: {question}\n",
+ " Label: {label}\n",
+ " Answer: {answer}\n",
+ " '''\n",
+ " out = json.loads(get_GPT_response(prompt, system_prompt, llm_for_evaluation, llm_for_evaluation))\n",
+ " cypher_rag_eval.append(out[\"response\"])\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "c788084a-42df-48ab-9b3f-8feeef320a5a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cypher_rag['cypher_rag_eval'] = cypher_rag_eval"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "afbdf643-59b7-4af1-bf95-fe032f0dc4d9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cypher_rag.to_csv('../data/results/cypher_rag_mcq_output_with_eval.csv', index=False)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "86bcafaf-3945-4d63-a0a9-227736d7c672",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python (cypher_rag)",
+ "language": "python",
+ "name": "cypher_rag"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}