From 73c194f304f827b55081b15524479f82a1b7d94c Mon Sep 17 00:00:00 2001 From: maszhongming Date: Tue, 16 Sep 2025 15:15:29 -0500 Subject: Initial commit --- notebooks/kg_rag_token_usage_tracking.ipynb | 288 ++++++++++++++++++++++++++++ 1 file changed, 288 insertions(+) create mode 100644 notebooks/kg_rag_token_usage_tracking.ipynb (limited to 'notebooks/kg_rag_token_usage_tracking.ipynb') diff --git a/notebooks/kg_rag_token_usage_tracking.ipynb b/notebooks/kg_rag_token_usage_tracking.ipynb new file mode 100644 index 0000000..1840b99 --- /dev/null +++ b/notebooks/kg_rag_token_usage_tracking.ipynb @@ -0,0 +1,288 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "b86c2320-71ed-4223-9df7-0b9281cb652c", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.chdir('..')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8e9dc80f-43a6-4d8d-9d99-343bc6515ff8", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/root/anaconda3/envs/kg_rag_test_2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from kg_rag.utility import *\n", + "from tqdm import tqdm\n" + ] + }, + { + "cell_type": "markdown", + "id": "3b991006-9e91-4db1-9c11-62cbf1d9c356", + "metadata": {}, + "source": [ + "## Choose the LLM" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5ae38918-24e1-4a28-b4e5-461eda38002c", + "metadata": {}, + "outputs": [], + "source": [ + "LLM_MODEL = 'gpt-4-32k'\n" + ] + }, + { + "cell_type": "markdown", + "id": "db3c5056-15d6-4608-87c8-1e897dc4075e", + "metadata": {}, + "source": [ + "## Configure KG-RAG" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "fdf4d8fd-2265-4237-ba85-06a3efbf8145", + "metadata": {}, + "outputs": [], + "source": [ + "SYSTEM_PROMPT = system_prompts[\"KG_RAG_BASED_TEXT_GENERATION\"]\n", + "CONTEXT_VOLUME = int(config_data[\"CONTEXT_VOLUME\"])\n", + "QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data[\"QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD\"])\n", + "QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data[\"QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY\"])\n", + "VECTOR_DB_PATH = config_data[\"VECTOR_DB_PATH\"]\n", + "NODE_CONTEXT_PATH = config_data[\"NODE_CONTEXT_PATH\"]\n", + "SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data[\"SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL\"]\n", + "SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data[\"SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL\"]\n", + "TEMPERATURE = config_data[\"LLM_TEMPERATURE\"]\n", + "\n", + "CHAT_MODEL_ID = LLM_MODEL\n", + "EDGE_EVIDENCE = True\n", + "\n", + "CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID\n", + "\n", + "\n", + "vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)\n", + "embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL)\n", + "node_context_df = pd.read_csv(NODE_CONTEXT_PATH)\n" + ] + }, + { + "cell_type": "markdown", + "id": "547cf664-8b48-4f19-a232-09a5b2fa4ffa", + "metadata": {}, + "source": [ + "## Load test data" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "00fa2491-901e-44ea-8109-2a60b23771ba", + "metadata": {}, + "outputs": [], + "source": [ + "data = pd.read_csv('data/rag_comparison_data.csv')\n" + ] + }, + { + "cell_type": "markdown", + "id": "39c207c9-49be-449b-9b70-a92cdf8095d3", + "metadata": {}, + "source": [ + "## Function for chat completion with token usage tracking" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "8ca41e38-79fb-4f68-aa16-db1785b6551f", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def chat_completion_with_token_usage(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature):\n", + " response = openai.ChatCompletion.create(\n", + " temperature=temperature,\n", + " deployment_id=chat_deployment_id,\n", + " model=chat_model_id,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": system_prompt},\n", + " {\"role\": \"user\", \"content\": instruction}\n", + " ]\n", + " )\n", + " return response['choices'][0]['message']['content'], response.usage.total_tokens\n" + ] + }, + { + "cell_type": "markdown", + "id": "4b2bbab7-72f6-414b-bdd0-0eab4ed842f2", + "metadata": {}, + "source": [ + "## Run on test data" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "637671b2-a06c-4fe4-a7a6-855b0ba48fcd", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100it [11:13, 6.74s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 3min 37s, sys: 9.86 s, total: 3min 47s\n", + "Wall time: 11min 13s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "kg_rag_answer = []\n", + "total_tokens_used = []\n", + "\n", + "for index, row in tqdm(data.iterrows()):\n", + " question = row['question']\n", + " context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, EDGE_EVIDENCE)\n", + " enriched_prompt = \"Context: \"+ context + \"\\n\" + \"Question: \" + question\n", + " output, token_usage = chat_completion_with_token_usage(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE)\n", + " kg_rag_answer.append(output)\n", + " total_tokens_used.append(token_usage)\n", + " \n", + "data.loc[:,'kg_rag_answer'] = kg_rag_answer\n", + "data.loc[:, 'total_tokens_used'] = total_tokens_used\n" + ] + }, + { + "cell_type": "markdown", + "id": "18e4b72c-c2a5-4b1a-8100-7ad831eb1401", + "metadata": {}, + "source": [ + "## Run on perturbed test data" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8a042aa2-2366-4d49-a694-efd6d7b4616b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100it [09:49, 5.90s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 3min 36s, sys: 9.04 s, total: 3min 45s\n", + "Wall time: 9min 49s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "kg_rag_answer = []\n", + "total_tokens_used = []\n", + "\n", + "for index, row in tqdm(data.iterrows()):\n", + " question = row['question']\n", + " context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, EDGE_EVIDENCE)\n", + " enriched_prompt = \"Context: \"+ context + \"\\n\" + \"Question: \" + question\n", + " output, token_usage = chat_completion_with_token_usage(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE)\n", + " kg_rag_answer.append(output)\n", + " total_tokens_used.append(token_usage)\n", + " \n", + "data.loc[:,'kg_rag_answer_perturbed'] = kg_rag_answer\n", + "data.loc[:, 'total_tokens_used_perturbed'] = total_tokens_used\n" + ] + }, + { + "cell_type": "markdown", + "id": "9c902260-1d9e-4a52-a377-f0c002c91e16", + "metadata": {}, + "source": [ + "## Save the result" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "d510de56-dd39-4742-8a5a-9bb934690d95", + "metadata": {}, + "outputs": [], + "source": [ + "save_path = 'data/results'\n", + "os.makedirs(save_path, exist_ok=True)\n", + "data.to_csv(os.path.join(save_path, 'kg_rag_output.csv'), index=False)\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} -- cgit v1.2.3