From 7b1fe580f089f55dd04319088a7e805eb1e38fed Mon Sep 17 00:00:00 2001 From: chzhang Date: Thu, 23 Mar 2023 23:14:42 +0800 Subject: openai tokenizer & logit bias --- .../05_openai_tokenizer_logit_bias_logprobs.ipynb | 1229 ++++++++++++++++++++ 1 file changed, 1229 insertions(+) create mode 100644 llm/tutorials/05_openai_tokenizer_logit_bias_logprobs.ipynb (limited to 'llm/tutorials') diff --git a/llm/tutorials/05_openai_tokenizer_logit_bias_logprobs.ipynb b/llm/tutorials/05_openai_tokenizer_logit_bias_logprobs.ipynb new file mode 100644 index 0000000..cedcb66 --- /dev/null +++ b/llm/tutorials/05_openai_tokenizer_logit_bias_logprobs.ipynb @@ -0,0 +1,1229 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6da657b7", + "metadata": {}, + "source": [ + "## Counting Tokens " + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "id": "a02cfa31", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T14:52:24.046989Z", + "start_time": "2023-03-23T14:52:05.424260Z" + }, + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "Looking in indexes: http://mirrors.aliyun.com/pypi/simple/\n", + "Requirement already satisfied: tiktoken in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (0.3.0)\n", + "Requirement already satisfied: blobfile>=2 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from tiktoken) (2.0.1)\n", + "Requirement already satisfied: regex>=2022.1.18 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from tiktoken) (2022.7.9)\n", + "Requirement already satisfied: requests>=2.26.0 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from tiktoken) (2.28.1)\n", + "Requirement already satisfied: urllib3<3,>=1.25.3 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from blobfile>=2->tiktoken) (1.26.11)\n", + "Requirement already satisfied: lxml~=4.9 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from blobfile>=2->tiktoken) (4.9.1)\n", + "Requirement already satisfied: pycryptodomex~=3.8 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from blobfile>=2->tiktoken) (3.17)\n", + "Requirement already satisfied: filelock~=3.0 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from blobfile>=2->tiktoken) (3.6.0)\n", + "Requirement already satisfied: charset-normalizer<3,>=2 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from requests>=2.26.0->tiktoken) (2.0.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from requests>=2.26.0->tiktoken) (3.3)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from requests>=2.26.0->tiktoken) (2022.9.24)\n", + 
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "Looking in indexes: http://mirrors.aliyun.com/pypi/simple/\n", + "Requirement already satisfied: torch in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (1.13.1)\n", + "Requirement already satisfied: typing_extensions in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from torch) (4.3.0)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "Looking in indexes: http://mirrors.aliyun.com/pypi/simple/\n", + "Requirement already satisfied: transformers in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (4.27.2)\n", + "Requirement already satisfied: numpy>=1.17 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from transformers) (1.21.5)\n", + "Requirement already satisfied: filelock in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from transformers) (3.6.0)\n", + "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from transformers) (0.13.2)\n", + "Requirement already satisfied: tqdm>=4.27 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from transformers) (4.64.1)\n", + "Requirement already satisfied: pyyaml>=5.1 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from transformers) (6.0)\n", + "Requirement already satisfied: 
packaging>=20.0 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from transformers) (21.3)\n", + "Requirement already satisfied: regex!=2019.12.17 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from transformers) (2022.7.9)\n", + "Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from transformers) (0.13.2)\n", + "Requirement already satisfied: requests in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from transformers) (2.28.1)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.11.0->transformers) (4.3.0)\n", + "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from packaging>=20.0->transformers) (3.0.9)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers) (1.26.11)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers) (2022.9.24)\n", + "Requirement already satisfied: charset-normalizer<3,>=2 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers) (2.0.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers) (3.3)\n" + ] + } + ], + "source": [ + "!pip install tiktoken\n", + "!pip install torch\n", + "!pip install transformers" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "id": "7e43a64e", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T14:52:26.002426Z", + "start_time": "2023-03-23T14:52:25.996155Z" + } + }, + "outputs": [], + "source": [ + "import torch\n", + "from transformers import AutoTokenizer" + ] 
+ }, + { + "cell_type": "markdown", + "id": "4b8f1876", + "metadata": {}, + "source": [ + "### transformers(huggingface) autotokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "id": "171acb93", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T14:53:03.861367Z", + "start_time": "2023-03-23T14:52:58.546570Z" + } + }, + "outputs": [], + "source": [ + "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "89f66c9a", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T13:32:47.367309Z", + "start_time": "2023-03-23T13:32:47.364534Z" + } + }, + "outputs": [], + "source": [ + "text = \"\"\"Many words map to one token, but some don't: indivisible.\n", + "\n", + "Unicode characters like emojis may be split into many tokens containing the underlying bytes: 🤚🏾\n", + "\n", + "Sequences of characters commonly found next to each other may be grouped together: 1234567890\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "id": "14c5765f", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T14:54:39.465273Z", + "start_time": "2023-03-23T14:54:39.458465Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 7085, 2456, 3975, 284, 530, 11241, 11, 475, 617, 836,\n", + " 470, 25, 773, 452, 12843, 13, 198, 198, 3118, 291,\n", + " 1098, 3435, 588, 795, 13210, 271, 743, 307, 6626, 656,\n", + " 867, 16326, 7268, 262, 10238, 9881, 25, 12520, 97, 248,\n", + " 8582, 237, 122, 198, 198, 44015, 3007, 286, 3435, 8811,\n", + " 1043, 1306, 284, 1123, 584, 743, 307, 32824, 1978, 25,\n", + " 17031, 2231, 30924, 3829])" + ] + }, + "execution_count": 98, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "input_ids = torch.tensor(tokenizer.encode(text))\n", + "input_ids" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "id": "f143ea77", + "metadata": { + "ExecuteTime": { + "end_time": 
"2023-03-23T14:55:13.289227Z", + "start_time": "2023-03-23T14:55:13.284166Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([64])" + ] + }, + "execution_count": 100, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "input_ids.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 101, + "id": "178abc80", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T14:55:34.672382Z", + "start_time": "2023-03-23T14:55:34.662809Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 64])" + ] + }, + "execution_count": 101, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0)\n", + "input_ids.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 102, + "id": "7b4391ec", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T14:55:46.349903Z", + "start_time": "2023-03-23T14:55:46.343341Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['Many',\n", + " 'Ġwords',\n", + " 'Ġmap',\n", + " 'Ġto',\n", + " 'Ġone',\n", + " 'Ġtoken',\n", + " ',',\n", + " 'Ġbut',\n", + " 'Ġsome',\n", + " 'Ġdon',\n", + " \"'t\",\n", + " ':',\n", + " 'Ġind',\n", + " 'iv',\n", + " 'isible',\n", + " '.',\n", + " 'Ċ',\n", + " 'Ċ',\n", + " 'Un',\n", + " 'ic',\n", + " 'ode',\n", + " 'Ġcharacters',\n", + " 'Ġlike',\n", + " 'Ġem',\n", + " 'oj',\n", + " 'is',\n", + " 'Ġmay',\n", + " 'Ġbe',\n", + " 'Ġsplit',\n", + " 'Ġinto',\n", + " 'Ġmany',\n", + " 'Ġtokens',\n", + " 'Ġcontaining',\n", + " 'Ġthe',\n", + " 'Ġunderlying',\n", + " 'Ġbytes',\n", + " ':',\n", + " 'ĠðŁ',\n", + " '¤',\n", + " 'ļ',\n", + " 'ðŁ',\n", + " 'ı',\n", + " '¾',\n", + " 'Ċ',\n", + " 'Ċ',\n", + " 'Sequ',\n", + " 'ences',\n", + " 'Ġof',\n", + " 'Ġcharacters',\n", + " 'Ġcommonly',\n", + " 'Ġfound',\n", + " 'Ġnext',\n", + " 'Ġto',\n", + " 'Ġeach',\n", + " 'Ġother',\n", + " 'Ġmay',\n", + " 'Ġbe',\n", + " 'Ġgrouped',\n", + " 'Ġtogether',\n", + " 
':',\n", + " 'Ġ123',\n", + " '45',\n", + " '678',\n", + " '90']" + ] + }, + "execution_count": 102, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.convert_ids_to_tokens(input_ids.tolist()[0])" + ] + }, + { + "cell_type": "markdown", + "id": "c45b3612", + "metadata": {}, + "source": [ + "### tiktoken" + ] + }, + { + "cell_type": "code", + "execution_count": 103, + "id": "e8b21268", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T14:56:46.742017Z", + "start_time": "2023-03-23T14:56:46.739355Z" + } + }, + "outputs": [], + "source": [ + "import tiktoken" + ] + }, + { + "cell_type": "code", + "execution_count": 105, + "id": "971fcb64", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T14:57:30.301936Z", + "start_time": "2023-03-23T14:57:30.297321Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 105, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "encoding = tiktoken.get_encoding('gpt2')\n", + "encoding" + ] + }, + { + "cell_type": "code", + "execution_count": 106, + "id": "c427ace3", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T14:57:50.210515Z", + "start_time": "2023-03-23T14:57:50.206666Z" + } + }, + "outputs": [], + "source": [ + "def num_tokens_from_string(string: str, encoding_name='gpt2') -> int:\n", + " \"\"\"Returns the list of token ids for a text string (not a count, despite the function name).\"\"\"\n", + " encoding = tiktoken.get_encoding(encoding_name)\n", + " token_ids = encoding.encode(string)\n", + " return token_ids" + ] + }, + { + "cell_type": "code", + "execution_count": 109, + "id": "3e36c32f", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T14:58:17.624225Z", + "start_time": "2023-03-23T14:58:17.617847Z" + }, + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 7085, 2456, 3975, 284, 530, 11241, 11, 475, 617, 836,\n", + " 470, 25, 773, 452, 12843, 13, 198, 198, 3118, 291,\n", + " 1098, 3435, 
588, 795, 13210, 271, 743, 307, 6626, 656,\n", + " 867, 16326, 7268, 262, 10238, 9881, 25, 12520, 97, 248,\n", + " 8582, 237, 122, 198, 198, 44015, 3007, 286, 3435, 8811,\n", + " 1043, 1306, 284, 1123, 584, 743, 307, 32824, 1978, 25,\n", + " 17031, 2231, 30924, 3829])" + ] + }, + "execution_count": 109, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "token_ids = torch.tensor(num_tokens_from_string(text))\n", + "token_ids" + ] + }, + { + "cell_type": "code", + "execution_count": 110, + "id": "6c7d988c", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T14:58:31.666008Z", + "start_time": "2023-03-23T14:58:31.660228Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 64])" + ] + }, + "execution_count": 110, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "token_ids = torch.tensor(num_tokens_from_string(text)).unsqueeze(0)\n", + "token_ids.shape" + ] + }, + { + "cell_type": "markdown", + "id": "16410ec8", + "metadata": {}, + "source": [ + "## Tokenizing Text" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e94cc37a", + "metadata": {}, + "outputs": [], + "source": [ + "# chunks = break_up_file_to_chunks(filename)\n", + "# for i, chunk in enumerate(chunks):\n", + "# prompt_request = \"Summarize this meeting transcript: \" + tokenizer.decode(chunks[i])\n", + "# response = openai.Completion.create(\n", + "# model=\"text-davinci-003\",\n", + "# prompt=prompt_request,\n", + "# temperature=.5,\n", + "# max_tokens=500,\n", + "# top_p=1,\n", + "# frequency_penalty=0,\n", + "# presence_penalty=0\n", + "# )\n", + " \n", + "# prompt_response.append(response[\"choices\"][0][\"text\"])\n", + "# prompt_tokens.append(response[\"usage\"][\"total_tokens\"])\n", + "\n", + "# total_usage = sum(prompt_tokens)" + ] + }, + { + "cell_type": "markdown", + "id": "52535acb", + "metadata": {}, + "source": [ + "## test" + ] + }, + { + "cell_type": "code", + "execution_count": 
114, + "id": "ecc0fa7a", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T15:02:20.550454Z", + "start_time": "2023-03-23T15:02:20.546001Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[3666, 4004, 3124, 318, 2266, 13]\n", + "[3666, 4004, 3124, 318, 2266, 13]\n" + ] + } + ], + "source": [ + "print(tokenizer.encode('My favorite color is red.'))\n", + "print(tiktoken.encoding_for_model('gpt2').encode('My favorite color is red.'))" + ] + }, + { + "cell_type": "code", + "execution_count": 115, + "id": "51f21bb3", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T15:03:17.395349Z", + "start_time": "2023-03-23T15:03:17.390482Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[7738, 318, 616, 4004, 3124, 13]\n", + "[7738, 318, 616, 4004, 3124, 13]\n" + ] + } + ], + "source": [ + "print(tokenizer.encode('Red is my favorite color.'))\n", + "print(tiktoken.encoding_for_model('gpt2').encode('Red is my favorite color.'))" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "95ba81ab", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T13:46:45.621813Z", + "start_time": "2023-03-23T13:46:45.617543Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['My', 'Ġfavorite', 'Ġcolor', 'Ġis', 'Ġred', '.']" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.convert_ids_to_tokens(tokenizer.encode('My favorite color is red.'))" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "3e2762b5", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T13:47:57.125761Z", + "start_time": "2023-03-23T13:47:57.121320Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[2266]" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.encode(' red')" + ] + }, + { + "cell_type": "code", + 
"execution_count": 47, + "id": "5c57e75e", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T13:48:21.284562Z", + "start_time": "2023-03-23T13:48:21.280295Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[2297]" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.encode(' Red')" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "7f9e6948", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T13:48:36.829697Z", + "start_time": "2023-03-23T13:48:36.825399Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[7738]" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.encode('Red')" + ] + }, + { + "cell_type": "markdown", + "id": "d2c9ce11", + "metadata": {}, + "source": [ + "## logit_bias (logprobs)" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "id": "24ff6772", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T14:16:49.514101Z", + "start_time": "2023-03-23T14:16:49.511267Z" + } + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 117, + "id": "38b92b96", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T15:07:33.484099Z", + "start_time": "2023-03-23T15:07:33.479060Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[6342]" + ] + }, + "execution_count": 117, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.encode(' Paris')" + ] + }, + { + "cell_type": "code", + "execution_count": 116, + "id": "18091fde", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T15:07:15.532849Z", + "start_time": "2023-03-23T15:07:15.527278Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['ĠParis']" + ] + }, + "execution_count": 116, + "metadata": {}, + "output_type": "execute_result" + } + 
], + "source": [ + "tokenizer.convert_ids_to_tokens([6342])" + ] + }, + { + "cell_type": "code", + "execution_count": 118, + "id": "69a64160", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T15:07:47.432779Z", + "start_time": "2023-03-23T15:07:47.427456Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[40313]" + ] + }, + "execution_count": 118, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.encode('Paris')" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "6443a649", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T14:06:56.563477Z", + "start_time": "2023-03-23T14:06:55.687258Z" + } + }, + "outputs": [], + "source": [ + "import openai\n", + "openai.api_key = 'YOUR_OPENAI_API_KEY'" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "id": "390b9b4e", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T14:32:47.398402Z", + "start_time": "2023-03-23T14:32:45.812646Z" + } + }, + "outputs": [], + "source": [ + "resp = openai.Completion.create(engine='davinci', \n", + " prompt='q: What is the capital of france?\\na: ', \n", + " logprobs=5, \n", + " stop='\\n', \n", + " temperature=0, \n", + " logit_bias={6342:-100, 40313:-100}\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 119, + "id": "95495ef1", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T15:08:56.899012Z", + "start_time": "2023-03-23T15:08:56.893837Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + " JSON: {\n", + " \"choices\": [\n", + " {\n", + " \"finish_reason\": \"stop\",\n", + " \"index\": 0,\n", + " \"logprobs\": {\n", + " \"text_offset\": [\n", + " 37,\n", + " 38,\n", + " 44,\n", + " 47,\n", + " 49,\n", + " 57,\n", + " 58,\n", + " 62,\n", + " 64,\n", + " 69\n", + " ],\n", + " \"token_logprobs\": [\n", + " -0.47720006,\n", + " -2.6506393,\n", + " -1.2077931,\n", + " -0.84418315,\n", + " -0.5668232,\n", + " 
-1.1204532,\n", + " -0.5845836,\n", + " -0.01427887,\n", + " -0.0745326,\n", + " -0.38981625\n", + " ],\n", + " \"tokens\": [\n", + " \"\\u00a0\",\n", + " \"France\",\n", + " \" is\",\n", + " \" a\",\n", + " \" country\",\n", + " \",\",\n", + " \" not\",\n", + " \" a\",\n", + " \" city\",\n", + " \".\"\n", + " ],\n", + " \"top_logprobs\": [\n", + " {\n", + " \"\\t\": -3.8812888,\n", + " \"\\n\": -3.4891934,\n", + " \"________\": -2.3949227,\n", + " \"________________________________\": -3.8888516,\n", + " \"\\u00a0\": -0.47720006\n", + " },\n", + " {\n", + " \"B\": -3.7005355,\n", + " \"France\": -2.6506393,\n", + " \"Q\": -3.373843,\n", + " \"The\": -3.3061914,\n", + " \"par\": -2.7321594\n", + " },\n", + " {\n", + " \"\\n\": -2.514796,\n", + " \" does\": -2.5521305,\n", + " \" doesn\": -2.3279827,\n", + " \" has\": -1.6931255,\n", + " \" is\": -1.2077931\n", + " },\n", + " {\n", + " \" a\": -0.84418315,\n", + " \" divided\": -2.6524067,\n", + " \" not\": -1.3749524,\n", + " \" one\": -3.5888045,\n", + " \" the\": -2.4301894\n", + " },\n", + " {\n", + " \" capital\": -3.908288,\n", + " \" city\": -3.1317036,\n", + " \" country\": -0.5668232,\n", + " \" republic\": -2.3707287,\n", + " \"\\u00a0\": -3.943017\n", + " },\n", + " {\n", + " \" and\": -2.1657336,\n", + " \" in\": -1.8537258,\n", + " \" not\": -3.2263339,\n", + " \",\": -1.1204532,\n", + " \".\": -1.7335285\n", + " },\n", + " {\n", + " \" and\": -2.314171,\n", + " \" it\": -3.8719406,\n", + " \" not\": -0.5845836,\n", + " \" so\": -1.8761828,\n", + " \" the\": -3.3711283\n", + " },\n", + " {\n", + " \" a\": -0.01427887,\n", + " \" an\": -5.497788,\n", + " \" capital\": -5.5760226,\n", + " \" its\": -7.0022845,\n", + " \" the\": -6.8639107\n", + " },\n", + " {\n", + " \" capital\": -3.229291,\n", + " \" city\": -0.0745326,\n", + " \" person\": -4.1651354,\n", + " \" place\": -5.017827,\n", + " \" state\": -6.3108892\n", + " },\n", + " {\n", + " \"\\n\": -2.8642426,\n", + " \" so\": -4.4286876,\n", + " 
\"!\": -3.5753324,\n", + " \",\": -1.9522477,\n", + " \".\": -0.38981625\n", + " }\n", + " ]\n", + " },\n", + " \"text\": \"\\u00a0France is a country, not a city.\"\n", + " }\n", + " ],\n", + " \"created\": 1679581966,\n", + " \"id\": \"cmpl-6xGHetKXoitkQOoJAoxknDvTuvkdI\",\n", + " \"model\": \"davinci\",\n", + " \"object\": \"text_completion\",\n", + " \"usage\": {\n", + " \"completion_tokens\": 10,\n", + " \"prompt_tokens\": 14,\n", + " \"total_tokens\": 24\n", + " }\n", + "}" + ] + }, + "execution_count": 119, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "resp" + ] + }, + { + "cell_type": "code", + "execution_count": 120, + "id": "f9a95d7e", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T15:10:51.710578Z", + "start_time": "2023-03-23T15:10:51.702828Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'\\xa0France is a country, not a city.'" + ] + }, + "execution_count": 120, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "resp[\"choices\"][0]['text']" + ] + }, + { + "cell_type": "code", + "execution_count": 121, + "id": "66e8d1ec", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T15:11:10.180884Z", + "start_time": "2023-03-23T15:11:10.164436Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
tokenstoken_logprobstop_logprobstext_offset
0-0.477200{'________________________________': -3.888851...37
1France-2.650639{'par': -2.7321594, 'France': -2.6506393, 'B':...38
2is-1.207793{' doesn': -2.3279827, '\n", + "': -2.514796, ' is': ...44
3a-0.844183{' a': -0.84418315, ' the': -2.4301894, ' not'...47
4country-0.566823{' country': -0.5668232, ' republic': -2.37072...49
5,-1.120453{',': -1.1204532, '.': -1.7335285, ' in': -1.8...57
6not-0.584584{' the': -3.3711283, ' and': -2.314171, ' it':...58
7a-0.014279{' a': -0.01427887, ' the': -6.8639107, ' an':...62
8city-0.074533{' person': -4.1651354, ' state': -6.3108892, ...64
9.-0.389816{'!': -3.5753324, ',': -1.9522477, '.': -0.389...69
\n", + "
" + ], + "text/plain": [ + " tokens token_logprobs \\\n", + "0   -0.477200 \n", + "1 France -2.650639 \n", + "2 is -1.207793 \n", + "3 a -0.844183 \n", + "4 country -0.566823 \n", + "5 , -1.120453 \n", + "6 not -0.584584 \n", + "7 a -0.014279 \n", + "8 city -0.074533 \n", + "9 . -0.389816 \n", + "\n", + " top_logprobs text_offset \n", + "0 {'________________________________': -3.888851... 37 \n", + "1 {'par': -2.7321594, 'France': -2.6506393, 'B':... 38 \n", + "2 {' doesn': -2.3279827, '\n", + "': -2.514796, ' is': ... 44 \n", + "3 {' a': -0.84418315, ' the': -2.4301894, ' not'... 47 \n", + "4 {' country': -0.5668232, ' republic': -2.37072... 49 \n", + "5 {',': -1.1204532, '.': -1.7335285, ' in': -1.8... 57 \n", + "6 {' the': -3.3711283, ' and': -2.314171, ' it':... 58 \n", + "7 {' a': -0.01427887, ' the': -6.8639107, ' an':... 62 \n", + "8 {' person': -4.1651354, ' state': -6.3108892, ... 64 \n", + "9 {'!': -3.5753324, ',': -1.9522477, '.': -0.389... 69 " + ] + }, + "execution_count": 121, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.DataFrame(resp[\"choices\"][0][\"logprobs\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 122, + "id": "2dd6bb17", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-23T15:12:00.288365Z", + "start_time": "2023-03-23T15:12:00.277073Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
logprob%
par-2.7321596.507861
France-2.6506397.060606
B-3.7005352.471029
The-3.3061913.665551
Q-3.3738433.425773
\n", + "
" + ], + "text/plain": [ + " logprob %\n", + "par -2.732159 6.507861\n", + "France -2.650639 7.060606\n", + "B -3.700535 2.471029\n", + "The -3.306191 3.665551\n", + "Q -3.373843 3.425773" + ] + }, + "execution_count": 122, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "scores = pd.DataFrame([resp[\"choices\"][0][\"logprobs\"][\"top_logprobs\"][1]]).T\n", + "scores.columns = [\"logprob\"]\n", + "scores[\"%\"] = scores[\"logprob\"].apply(lambda x: 100*np.e**x)\n", + "scores" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} -- cgit v1.2.3