{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0d4e6c1a-5b7f-4c2e-9a31-6f2b8e7d1c90",
   "metadata": {},
   "source": [
    "# KG-RAG answers with token usage tracking\n",
    "\n",
    "For every row of `data/rag_comparison_data.csv` this notebook retrieves KG-RAG context, asks the chat model the original question and its perturbed variant, and records each answer together with the total tokens the call consumed. The combined table is written to `data/results/kg_rag_output.csv`.\n",
    "\n",
    "**Cost note:** each question triggers one chat-completion call. In the previously recorded run, one pass over 100 questions took roughly 10-11 minutes of wall time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b86c2320-71ed-4223-9df7-0b9281cb652c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# All relative paths below are resolved from the parent of the directory the kernel started in.\n",
    "# The target is computed only once so that re-running this cell does not climb one more level.\n",
    "if 'REPO_ROOT' not in globals():\n",
    "    REPO_ROOT = os.path.abspath('..')\n",
    "os.chdir(REPO_ROOT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e9dc80f-43a6-4d8d-9d99-343bc6515ff8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import openai\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Explicit names instead of `import *`, so it is visible where each helper comes from.\n",
    "from kg_rag.utility import (\n",
    "    config_data,\n",
    "    load_chroma,\n",
    "    load_sentence_transformer,\n",
    "    retrieve_context,\n",
    "    system_prompts,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b991006-9e91-4db1-9c11-62cbf1d9c356",
   "metadata": {},
   "source": [
    "## Choose the LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ae38918-24e1-4a28-b4e5-461eda38002c",
   "metadata": {},
   "outputs": [],
   "source": [
    "LLM_MODEL = 'gpt-4-32k'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db3c5056-15d6-4608-87c8-1e897dc4075e",
   "metadata": {},
   "source": [
    "## Configure KG-RAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdf4d8fd-2265-4237-ba85-06a3efbf8145",
   "metadata": {},
   "outputs": [],
   "source": [
    "SYSTEM_PROMPT = system_prompts[\"KG_RAG_BASED_TEXT_GENERATION\"]\n",
    "CONTEXT_VOLUME = int(config_data[\"CONTEXT_VOLUME\"])\n",
    "QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data[\"QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD\"])\n",
    "QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data[\"QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY\"])\n",
    "VECTOR_DB_PATH = config_data[\"VECTOR_DB_PATH\"]\n",
    "NODE_CONTEXT_PATH = config_data[\"NODE_CONTEXT_PATH\"]\n",
    "SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data[\"SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL\"]\n",
    "SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data[\"SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL\"]\n",
    "# Cast like the other numeric settings so the API always receives a number.\n",
    "TEMPERATURE = float(config_data[\"LLM_TEMPERATURE\"])\n",
    "\n",
    "CHAT_MODEL_ID = LLM_MODEL\n",
    "CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID\n",
    "EDGE_EVIDENCE = True\n",
    "\n",
    "TEST_DATA_PATH = 'data/rag_comparison_data.csv'\n",
    "QUESTION_COLUMN = 'question'\n",
    "# NOTE(review): assumed name of the perturbed-question column in the test CSV -- confirm against the file.\n",
    "PERTURBED_QUESTION_COLUMN = 'question_perturbed'\n",
    "SAVE_DIR = 'data/results'\n",
    "SAVE_FILENAME = 'kg_rag_output.csv'\n",
    "\n",
    "vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)\n",
    "embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL)\n",
    "node_context_df = pd.read_csv(NODE_CONTEXT_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "547cf664-8b48-4f19-a232-09a5b2fa4ffa",
   "metadata": {},
   "source": [
    "## Load test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00fa2491-901e-44ea-8109-2a60b23771ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv(TEST_DATA_PATH)\n",
    "\n",
    "# Fail before any API call is made if a question column is absent.\n",
    "missing_columns = {QUESTION_COLUMN, PERTURBED_QUESTION_COLUMN} - set(data.columns)\n",
    "assert not missing_columns, f\"test data is missing expected column(s): {sorted(missing_columns)}\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39c207c9-49be-449b-9b70-a92cdf8095d3",
   "metadata": {},
   "source": [
    "## Function for chat completion with token usage tracking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ca41e38-79fb-4f68-aa16-db1785b6551f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat_completion_with_token_usage(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature):\n",
    "    \"\"\"Send one system + user message pair to the chat model.\n",
    "\n",
    "    Args:\n",
    "        instruction: Text sent as the user message.\n",
    "        system_prompt: Text sent as the system message.\n",
    "        chat_model_id: Value passed to the API as `model`.\n",
    "        chat_deployment_id: Value passed to the API as `deployment_id`.\n",
    "        temperature: Value passed to the API as `temperature`.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(content, total_tokens)`: the message content of the first\n",
    "        choice and the `total_tokens` figure from the response's usage record.\n",
    "    \"\"\"\n",
    "    response = openai.ChatCompletion.create(\n",
    "        temperature=temperature,\n",
    "        deployment_id=chat_deployment_id,\n",
    "        model=chat_model_id,\n",
    "        messages=[\n",
    "            {\"role\": \"system\", \"content\": system_prompt},\n",
    "            {\"role\": \"user\", \"content\": instruction}\n",
    "        ]\n",
    "    )\n",
    "    return response['choices'][0]['message']['content'], response.usage.total_tokens"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c1f7a92-3e0b-4d68-8b2a-7f4e9d0c6a13",
   "metadata": {},
   "source": [
    "## Function for running KG-RAG over a list of questions\n",
    "\n",
    "The baseline and perturbed runs differ only in which question column they read, so both go through the same helper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e27b4d60-91c8-4f35-a6d1-2b8c5e3f7a49",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_kg_rag(questions):\n",
    "    \"\"\"Answer every question with KG-RAG and record the tokens each call used.\n",
    "\n",
    "    For each question, context is retrieved with `retrieve_context` using the\n",
    "    settings from the configuration cell, prepended to the question, and sent\n",
    "    to the chat model through `chat_completion_with_token_usage`.\n",
    "\n",
    "    Args:\n",
    "        questions: Sized sequence (e.g. a list) of question strings.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(answers, tokens_used)` of two lists aligned with `questions`.\n",
    "    \"\"\"\n",
    "    answers = []\n",
    "    tokens_used = []\n",
    "    # A sized sequence lets tqdm report progress against the total.\n",
    "    for question in tqdm(questions):\n",
    "        context = retrieve_context(\n",
    "            question,\n",
    "            vectorstore,\n",
    "            embedding_function_for_context_retrieval,\n",
    "            node_context_df,\n",
    "            CONTEXT_VOLUME,\n",
    "            QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD,\n",
    "            QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY,\n",
    "            EDGE_EVIDENCE,\n",
    "        )\n",
    "        enriched_prompt = \"Context: \"+ context + \"\\n\" + \"Question: \" + question\n",
    "        output, token_usage = chat_completion_with_token_usage(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE)\n",
    "        answers.append(output)\n",
    "        tokens_used.append(token_usage)\n",
    "    return answers, tokens_used"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b2bbab7-72f6-414b-bdd0-0eab4ed842f2",
   "metadata": {},
   "source": [
    "## Run on test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "637671b2-a06c-4fe4-a7a6-855b0ba48fcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "kg_rag_answer, total_tokens_used = run_kg_rag(data[QUESTION_COLUMN].tolist())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18e4b72c-c2a5-4b1a-8100-7ad831eb1401",
   "metadata": {},
   "source": [
    "## Run on perturbed test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a042aa2-2366-4d49-a694-efd6d7b4616b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# Reads the perturbed questions; reading QUESTION_COLUMN here would merely repeat the baseline run.\n",
    "kg_rag_answer_perturbed, total_tokens_used_perturbed = run_kg_rag(data[PERTURBED_QUESTION_COLUMN].tolist())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c902260-1d9e-4a52-a377-f0c002c91e16",
   "metadata": {},
   "source": [
    "## Save the result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d510de56-dd39-4742-8a5a-9bb934690d95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a new frame rather than mutating `data`, so the run cells can be re-executed safely.\n",
    "results = data.assign(\n",
    "    kg_rag_answer=kg_rag_answer,\n",
    "    total_tokens_used=total_tokens_used,\n",
    "    kg_rag_answer_perturbed=kg_rag_answer_perturbed,\n",
    "    total_tokens_used_perturbed=total_tokens_used_perturbed,\n",
    ")\n",
    "\n",
    "os.makedirs(SAVE_DIR, exist_ok=True)\n",
    "results.to_csv(os.path.join(SAVE_DIR, SAVE_FILENAME), index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}