summaryrefslogtreecommitdiff
path: root/notebooks/kg_rag_token_usage_tracking.ipynb
diff options
context:
space:
mode:
authormaszhongming <mingz5@illinois.edu>2025-09-16 15:15:29 -0500
committermaszhongming <mingz5@illinois.edu>2025-09-16 15:15:29 -0500
commit73c194f304f827b55081b15524479f82a1b7d94c (patch)
tree5e8660e421915420892c5eca18f1ad680f80a861 /notebooks/kg_rag_token_usage_tracking.ipynb
Initial commit
Diffstat (limited to 'notebooks/kg_rag_token_usage_tracking.ipynb')
-rw-r--r--notebooks/kg_rag_token_usage_tracking.ipynb288
1 files changed, 288 insertions, 0 deletions
diff --git a/notebooks/kg_rag_token_usage_tracking.ipynb b/notebooks/kg_rag_token_usage_tracking.ipynb
new file mode 100644
index 0000000..1840b99
--- /dev/null
+++ b/notebooks/kg_rag_token_usage_tracking.ipynb
@@ -0,0 +1,288 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "b86c2320-71ed-4223-9df7-0b9281cb652c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.chdir('..')\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "8e9dc80f-43a6-4d8d-9d99-343bc6515ff8",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/root/anaconda3/envs/kg_rag_test_2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ }
+ ],
+ "source": [
+ "from kg_rag.utility import *\n",
+ "from tqdm import tqdm\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3b991006-9e91-4db1-9c11-62cbf1d9c356",
+ "metadata": {},
+ "source": [
+ "## Choose the LLM"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "5ae38918-24e1-4a28-b4e5-461eda38002c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "LLM_MODEL = 'gpt-4-32k'\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "db3c5056-15d6-4608-87c8-1e897dc4075e",
+ "metadata": {},
+ "source": [
+ "## Configure KG-RAG"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "fdf4d8fd-2265-4237-ba85-06a3efbf8145",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "SYSTEM_PROMPT = system_prompts[\"KG_RAG_BASED_TEXT_GENERATION\"]\n",
+ "CONTEXT_VOLUME = int(config_data[\"CONTEXT_VOLUME\"])\n",
+ "QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD = float(config_data[\"QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD\"])\n",
+ "QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY = float(config_data[\"QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY\"])\n",
+ "VECTOR_DB_PATH = config_data[\"VECTOR_DB_PATH\"]\n",
+ "NODE_CONTEXT_PATH = config_data[\"NODE_CONTEXT_PATH\"]\n",
+ "SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL = config_data[\"SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL\"]\n",
+ "SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL = config_data[\"SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL\"]\n",
+ "TEMPERATURE = config_data[\"LLM_TEMPERATURE\"]\n",
+ "\n",
+ "CHAT_MODEL_ID = LLM_MODEL\n",
+ "EDGE_EVIDENCE = True\n",
+ "\n",
+ "CHAT_DEPLOYMENT_ID = CHAT_MODEL_ID\n",
+ "\n",
+ "\n",
+ "vectorstore = load_chroma(VECTOR_DB_PATH, SENTENCE_EMBEDDING_MODEL_FOR_NODE_RETRIEVAL)\n",
+ "embedding_function_for_context_retrieval = load_sentence_transformer(SENTENCE_EMBEDDING_MODEL_FOR_CONTEXT_RETRIEVAL)\n",
+ "node_context_df = pd.read_csv(NODE_CONTEXT_PATH)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "547cf664-8b48-4f19-a232-09a5b2fa4ffa",
+ "metadata": {},
+ "source": [
+ "## Load test data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "00fa2491-901e-44ea-8109-2a60b23771ba",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data = pd.read_csv('data/rag_comparison_data.csv')\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "39c207c9-49be-449b-9b70-a92cdf8095d3",
+ "metadata": {},
+ "source": [
+ "## Function for chat completion with token usage tracking"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "8ca41e38-79fb-4f68-aa16-db1785b6551f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "def chat_completion_with_token_usage(instruction, system_prompt, chat_model_id, chat_deployment_id, temperature):\n",
+ " response = openai.ChatCompletion.create(\n",
+ " temperature=temperature,\n",
+ " deployment_id=chat_deployment_id,\n",
+ " model=chat_model_id,\n",
+ " messages=[\n",
+ " {\"role\": \"system\", \"content\": system_prompt},\n",
+ " {\"role\": \"user\", \"content\": instruction}\n",
+ " ]\n",
+ " )\n",
+ " return response['choices'][0]['message']['content'], response.usage.total_tokens\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4b2bbab7-72f6-414b-bdd0-0eab4ed842f2",
+ "metadata": {},
+ "source": [
+ "## Run on test data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "637671b2-a06c-4fe4-a7a6-855b0ba48fcd",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100it [11:13, 6.74s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 3min 37s, sys: 9.86 s, total: 3min 47s\n",
+ "Wall time: 11min 13s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "\n",
+ "kg_rag_answer = []\n",
+ "total_tokens_used = []\n",
+ "\n",
+ "for index, row in tqdm(data.iterrows()):\n",
+ " question = row['question']\n",
+ " context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, EDGE_EVIDENCE)\n",
+ " enriched_prompt = \"Context: \"+ context + \"\\n\" + \"Question: \" + question\n",
+ " output, token_usage = chat_completion_with_token_usage(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE)\n",
+ " kg_rag_answer.append(output)\n",
+ " total_tokens_used.append(token_usage)\n",
+ " \n",
+ "data.loc[:,'kg_rag_answer'] = kg_rag_answer\n",
+ "data.loc[:, 'total_tokens_used'] = total_tokens_used\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "18e4b72c-c2a5-4b1a-8100-7ad831eb1401",
+ "metadata": {},
+ "source": [
+ "## Run on perturbed test data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "8a042aa2-2366-4d49-a694-efd6d7b4616b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100it [09:49, 5.90s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 3min 36s, sys: 9.04 s, total: 3min 45s\n",
+ "Wall time: 9min 49s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "\n",
+ "kg_rag_answer = []\n",
+ "total_tokens_used = []\n",
+ "\n",
+ "for index, row in tqdm(data.iterrows()):\n",
+ " question = row['question']\n",
+ " context = retrieve_context(question, vectorstore, embedding_function_for_context_retrieval, node_context_df, CONTEXT_VOLUME, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD, QUESTION_VS_CONTEXT_MINIMUM_SIMILARITY, EDGE_EVIDENCE)\n",
+ " enriched_prompt = \"Context: \"+ context + \"\\n\" + \"Question: \" + question\n",
+ " output, token_usage = chat_completion_with_token_usage(enriched_prompt, SYSTEM_PROMPT, CHAT_MODEL_ID, CHAT_DEPLOYMENT_ID, temperature=TEMPERATURE)\n",
+ " kg_rag_answer.append(output)\n",
+ " total_tokens_used.append(token_usage)\n",
+ " \n",
+ "data.loc[:,'kg_rag_answer_perturbed'] = kg_rag_answer\n",
+ "data.loc[:, 'total_tokens_used_perturbed'] = total_tokens_used\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9c902260-1d9e-4a52-a377-f0c002c91e16",
+ "metadata": {},
+ "source": [
+ "## Save the result"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "d510de56-dd39-4742-8a5a-9bb934690d95",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "save_path = 'data/results'\n",
+ "os.makedirs(save_path, exist_ok=True)\n",
+ "data.to_csv(os.path.join(save_path, 'kg_rag_output.csv'), index=False)\n",
+ "\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}