summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--llm/tutorials/05_openai_tokenizer_logit_bias_logprobs.ipynb1229
1 files changed, 1229 insertions, 0 deletions
diff --git a/llm/tutorials/05_openai_tokenizer_logit_bias_logprobs.ipynb b/llm/tutorials/05_openai_tokenizer_logit_bias_logprobs.ipynb
new file mode 100644
index 0000000..cedcb66
--- /dev/null
+++ b/llm/tutorials/05_openai_tokenizer_logit_bias_logprobs.ipynb
@@ -0,0 +1,1229 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "6da657b7",
+ "metadata": {},
+ "source": [
+ "## Counting Tokens "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 95,
+ "id": "a02cfa31",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T14:52:24.046989Z",
+ "start_time": "2023-03-23T14:52:05.424260Z"
+ },
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+ "To disable this warning, you can either:\n",
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
+ "Looking in indexes: http://mirrors.aliyun.com/pypi/simple/\n",
+ "Requirement already satisfied: tiktoken in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (0.3.0)\n",
+ "Requirement already satisfied: blobfile>=2 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from tiktoken) (2.0.1)\n",
+ "Requirement already satisfied: regex>=2022.1.18 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from tiktoken) (2022.7.9)\n",
+ "Requirement already satisfied: requests>=2.26.0 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from tiktoken) (2.28.1)\n",
+ "Requirement already satisfied: urllib3<3,>=1.25.3 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from blobfile>=2->tiktoken) (1.26.11)\n",
+ "Requirement already satisfied: lxml~=4.9 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from blobfile>=2->tiktoken) (4.9.1)\n",
+ "Requirement already satisfied: pycryptodomex~=3.8 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from blobfile>=2->tiktoken) (3.17)\n",
+ "Requirement already satisfied: filelock~=3.0 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from blobfile>=2->tiktoken) (3.6.0)\n",
+ "Requirement already satisfied: charset-normalizer<3,>=2 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from requests>=2.26.0->tiktoken) (2.0.4)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from requests>=2.26.0->tiktoken) (3.3)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from requests>=2.26.0->tiktoken) (2022.9.24)\n",
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+ "To disable this warning, you can either:\n",
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
+ "Looking in indexes: http://mirrors.aliyun.com/pypi/simple/\n",
+ "Requirement already satisfied: torch in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (1.13.1)\n",
+ "Requirement already satisfied: typing_extensions in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from torch) (4.3.0)\n",
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+ "To disable this warning, you can either:\n",
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
+ "Looking in indexes: http://mirrors.aliyun.com/pypi/simple/\n",
+ "Requirement already satisfied: transformers in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (4.27.2)\n",
+ "Requirement already satisfied: numpy>=1.17 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from transformers) (1.21.5)\n",
+ "Requirement already satisfied: filelock in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from transformers) (3.6.0)\n",
+ "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from transformers) (0.13.2)\n",
+ "Requirement already satisfied: tqdm>=4.27 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from transformers) (4.64.1)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from transformers) (6.0)\n",
+ "Requirement already satisfied: packaging>=20.0 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from transformers) (21.3)\n",
+ "Requirement already satisfied: regex!=2019.12.17 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from transformers) (2022.7.9)\n",
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from transformers) (0.13.2)\n",
+ "Requirement already satisfied: requests in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from transformers) (2.28.1)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.11.0->transformers) (4.3.0)\n",
+ "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from packaging>=20.0->transformers) (3.0.9)\n",
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers) (1.26.11)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers) (2022.9.24)\n",
+ "Requirement already satisfied: charset-normalizer<3,>=2 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers) (2.0.4)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from requests->transformers) (3.3)\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install tiktoken\n",
+ "!pip install torch\n",
+ "!pip install transformers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 96,
+ "id": "7e43a64e",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T14:52:26.002426Z",
+ "start_time": "2023-03-23T14:52:25.996155Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from transformers import AutoTokenizer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4b8f1876",
+ "metadata": {},
+ "source": [
+ "### transformers(huggingface) autotokenizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 97,
+ "id": "171acb93",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T14:53:03.861367Z",
+ "start_time": "2023-03-23T14:52:58.546570Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "id": "89f66c9a",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T13:32:47.367309Z",
+ "start_time": "2023-03-23T13:32:47.364534Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "text = \"\"\"Many words map to one token, but some don't: indivisible.\n",
+ "\n",
+ "Unicode characters like emojis may be split into many tokens containing the underlying bytes: 🤚🏾\n",
+ "\n",
+ "Sequences of characters commonly found next to each other may be grouped together: 1234567890\"\"\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 98,
+ "id": "14c5765f",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T14:54:39.465273Z",
+ "start_time": "2023-03-23T14:54:39.458465Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([ 7085, 2456, 3975, 284, 530, 11241, 11, 475, 617, 836,\n",
+ " 470, 25, 773, 452, 12843, 13, 198, 198, 3118, 291,\n",
+ " 1098, 3435, 588, 795, 13210, 271, 743, 307, 6626, 656,\n",
+ " 867, 16326, 7268, 262, 10238, 9881, 25, 12520, 97, 248,\n",
+ " 8582, 237, 122, 198, 198, 44015, 3007, 286, 3435, 8811,\n",
+ " 1043, 1306, 284, 1123, 584, 743, 307, 32824, 1978, 25,\n",
+ " 17031, 2231, 30924, 3829])"
+ ]
+ },
+ "execution_count": 98,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "input_ids = torch.tensor(tokenizer.encode(text))\n",
+ "input_ids"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 100,
+ "id": "f143ea77",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T14:55:13.289227Z",
+ "start_time": "2023-03-23T14:55:13.284166Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([64])"
+ ]
+ },
+ "execution_count": 100,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "input_ids.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 101,
+ "id": "178abc80",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T14:55:34.672382Z",
+ "start_time": "2023-03-23T14:55:34.662809Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 64])"
+ ]
+ },
+ "execution_count": 101,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0)\n",
+ "input_ids.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 102,
+ "id": "7b4391ec",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T14:55:46.349903Z",
+ "start_time": "2023-03-23T14:55:46.343341Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['Many',\n",
+ " 'Ġwords',\n",
+ " 'Ġmap',\n",
+ " 'Ġto',\n",
+ " 'Ġone',\n",
+ " 'Ġtoken',\n",
+ " ',',\n",
+ " 'Ġbut',\n",
+ " 'Ġsome',\n",
+ " 'Ġdon',\n",
+ " \"'t\",\n",
+ " ':',\n",
+ " 'Ġind',\n",
+ " 'iv',\n",
+ " 'isible',\n",
+ " '.',\n",
+ " 'Ċ',\n",
+ " 'Ċ',\n",
+ " 'Un',\n",
+ " 'ic',\n",
+ " 'ode',\n",
+ " 'Ġcharacters',\n",
+ " 'Ġlike',\n",
+ " 'Ġem',\n",
+ " 'oj',\n",
+ " 'is',\n",
+ " 'Ġmay',\n",
+ " 'Ġbe',\n",
+ " 'Ġsplit',\n",
+ " 'Ġinto',\n",
+ " 'Ġmany',\n",
+ " 'Ġtokens',\n",
+ " 'Ġcontaining',\n",
+ " 'Ġthe',\n",
+ " 'Ġunderlying',\n",
+ " 'Ġbytes',\n",
+ " ':',\n",
+ " 'ĠðŁ',\n",
+ " '¤',\n",
+ " 'ļ',\n",
+ " 'ðŁ',\n",
+ " 'ı',\n",
+ " '¾',\n",
+ " 'Ċ',\n",
+ " 'Ċ',\n",
+ " 'Sequ',\n",
+ " 'ences',\n",
+ " 'Ġof',\n",
+ " 'Ġcharacters',\n",
+ " 'Ġcommonly',\n",
+ " 'Ġfound',\n",
+ " 'Ġnext',\n",
+ " 'Ġto',\n",
+ " 'Ġeach',\n",
+ " 'Ġother',\n",
+ " 'Ġmay',\n",
+ " 'Ġbe',\n",
+ " 'Ġgrouped',\n",
+ " 'Ġtogether',\n",
+ " ':',\n",
+ " 'Ġ123',\n",
+ " '45',\n",
+ " '678',\n",
+ " '90']"
+ ]
+ },
+ "execution_count": 102,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.convert_ids_to_tokens(input_ids.tolist()[0])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c45b3612",
+ "metadata": {},
+ "source": [
+ "### tiktoken"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 103,
+ "id": "e8b21268",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T14:56:46.742017Z",
+ "start_time": "2023-03-23T14:56:46.739355Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import tiktoken"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 105,
+ "id": "971fcb64",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T14:57:30.301936Z",
+ "start_time": "2023-03-23T14:57:30.297321Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "<Encoding 'gpt2'>"
+ ]
+ },
+ "execution_count": 105,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "encoding = tiktoken.get_encoding('gpt2')\n",
+ "encoding"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 106,
+ "id": "c427ace3",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T14:57:50.210515Z",
+ "start_time": "2023-03-23T14:57:50.206666Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def num_tokens_from_string(string: str, encoding_name='gpt2') -> int:\n",
+ " \"\"\"Returns the number of tokens in a text string.\"\"\"\n",
+ " encoding = tiktoken.get_encoding(encoding_name)\n",
+ " token_ids = encoding.encode(string)\n",
+ " return token_ids"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 109,
+ "id": "3e36c32f",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T14:58:17.624225Z",
+ "start_time": "2023-03-23T14:58:17.617847Z"
+ },
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([ 7085, 2456, 3975, 284, 530, 11241, 11, 475, 617, 836,\n",
+ " 470, 25, 773, 452, 12843, 13, 198, 198, 3118, 291,\n",
+ " 1098, 3435, 588, 795, 13210, 271, 743, 307, 6626, 656,\n",
+ " 867, 16326, 7268, 262, 10238, 9881, 25, 12520, 97, 248,\n",
+ " 8582, 237, 122, 198, 198, 44015, 3007, 286, 3435, 8811,\n",
+ " 1043, 1306, 284, 1123, 584, 743, 307, 32824, 1978, 25,\n",
+ " 17031, 2231, 30924, 3829])"
+ ]
+ },
+ "execution_count": 109,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "token_ids = torch.tensor(num_tokens_from_string(text))\n",
+ "token_ids"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 110,
+ "id": "6c7d988c",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T14:58:31.666008Z",
+ "start_time": "2023-03-23T14:58:31.660228Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 64])"
+ ]
+ },
+ "execution_count": 110,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "token_ids = torch.tensor(num_tokens_from_string(text)).unsqueeze(0)\n",
+ "token_ids.shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "16410ec8",
+ "metadata": {},
+ "source": [
+ "## Tokenizing Text"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e94cc37a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# chunks = break_up_file_to_chunks(filename)\n",
+ "# for i, chunk in enumerate(chunks):\n",
+ "# prompt_request = \"Summarize this meeting transcript: \" + tokenizer.decode(chunks[i])\n",
+ "# response = openai.Completion.create(\n",
+ "# model=\"text-davinci-003\",\n",
+ "# prompt=prompt_request,\n",
+ "# temperature=.5,\n",
+ "# max_tokens=500,\n",
+ "# top_p=1,\n",
+ "# frequency_penalty=0,\n",
+ "# presence_penalty=0\n",
+ "# )\n",
+ " \n",
+ "# prompt_response.append(response[\"choices\"][0][\"text\"])\n",
+ "# prompt_tokens.append(response[\"usage\"][\"total_tokens\"])\n",
+ "\n",
+ "# total_usage = sum(prompt_tokens)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "52535acb",
+ "metadata": {},
+ "source": [
+ "## test"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 114,
+ "id": "ecc0fa7a",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T15:02:20.550454Z",
+ "start_time": "2023-03-23T15:02:20.546001Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[3666, 4004, 3124, 318, 2266, 13]\n",
+ "[3666, 4004, 3124, 318, 2266, 13]\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(tokenizer.encode('My favorite color is red.'))\n",
+ "print(tiktoken.encoding_for_model('gpt2').encode('My favorite color is red.'))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 115,
+ "id": "51f21bb3",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T15:03:17.395349Z",
+ "start_time": "2023-03-23T15:03:17.390482Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[7738, 318, 616, 4004, 3124, 13]\n",
+ "[7738, 318, 616, 4004, 3124, 13]\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(tokenizer.encode('Red is my favorite color.'))\n",
+ "print(tiktoken.encoding_for_model('gpt2').encode('Red is my favorite color.'))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 45,
+ "id": "95ba81ab",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T13:46:45.621813Z",
+ "start_time": "2023-03-23T13:46:45.617543Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['My', 'Ġfavorite', 'Ġcolor', 'Ġis', 'Ġred', '.']"
+ ]
+ },
+ "execution_count": 45,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.convert_ids_to_tokens(tokenizer.encode('My favorite color is red.'))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "id": "3e2762b5",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T13:47:57.125761Z",
+ "start_time": "2023-03-23T13:47:57.121320Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[2266]"
+ ]
+ },
+ "execution_count": 46,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.encode(' red')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "id": "5c57e75e",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T13:48:21.284562Z",
+ "start_time": "2023-03-23T13:48:21.280295Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[2297]"
+ ]
+ },
+ "execution_count": 47,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.encode(' Red')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "id": "7f9e6948",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T13:48:36.829697Z",
+ "start_time": "2023-03-23T13:48:36.825399Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[7738]"
+ ]
+ },
+ "execution_count": 48,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.encode('Red')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d2c9ce11",
+ "metadata": {},
+ "source": [
+ "## logit_bias (logprobs)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 71,
+ "id": "24ff6772",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T14:16:49.514101Z",
+ "start_time": "2023-03-23T14:16:49.511267Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import numpy as np"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 117,
+ "id": "38b92b96",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T15:07:33.484099Z",
+ "start_time": "2023-03-23T15:07:33.479060Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[6342]"
+ ]
+ },
+ "execution_count": 117,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.encode(' Paris')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 116,
+ "id": "18091fde",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T15:07:15.532849Z",
+ "start_time": "2023-03-23T15:07:15.527278Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['ĠParis']"
+ ]
+ },
+ "execution_count": 116,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.convert_ids_to_tokens([6342])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 118,
+ "id": "69a64160",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T15:07:47.432779Z",
+ "start_time": "2023-03-23T15:07:47.427456Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[40313]"
+ ]
+ },
+ "execution_count": 118,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.encode('Paris')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 53,
+ "id": "6443a649",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T14:06:56.563477Z",
+ "start_time": "2023-03-23T14:06:55.687258Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import openai\n",
+ "openai.api_key = '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'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 93,
+ "id": "390b9b4e",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T14:32:47.398402Z",
+ "start_time": "2023-03-23T14:32:45.812646Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "resp = openai.Completion.create(engine='davinci', \n",
+ " prompt='q: What is the capital of france?\\na: ', \n",
+ " logprobs=5, \n",
+ " stop='\\n', \n",
+ " temperature=0, \n",
+ " logit_bias={6342:-100, 40313:-100}\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 119,
+ "id": "95495ef1",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T15:08:56.899012Z",
+ "start_time": "2023-03-23T15:08:56.893837Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "<OpenAIObject text_completion id=cmpl-6xGHetKXoitkQOoJAoxknDvTuvkdI at 0x144472ea0> JSON: {\n",
+ " \"choices\": [\n",
+ " {\n",
+ " \"finish_reason\": \"stop\",\n",
+ " \"index\": 0,\n",
+ " \"logprobs\": {\n",
+ " \"text_offset\": [\n",
+ " 37,\n",
+ " 38,\n",
+ " 44,\n",
+ " 47,\n",
+ " 49,\n",
+ " 57,\n",
+ " 58,\n",
+ " 62,\n",
+ " 64,\n",
+ " 69\n",
+ " ],\n",
+ " \"token_logprobs\": [\n",
+ " -0.47720006,\n",
+ " -2.6506393,\n",
+ " -1.2077931,\n",
+ " -0.84418315,\n",
+ " -0.5668232,\n",
+ " -1.1204532,\n",
+ " -0.5845836,\n",
+ " -0.01427887,\n",
+ " -0.0745326,\n",
+ " -0.38981625\n",
+ " ],\n",
+ " \"tokens\": [\n",
+ " \"\\u00a0\",\n",
+ " \"France\",\n",
+ " \" is\",\n",
+ " \" a\",\n",
+ " \" country\",\n",
+ " \",\",\n",
+ " \" not\",\n",
+ " \" a\",\n",
+ " \" city\",\n",
+ " \".\"\n",
+ " ],\n",
+ " \"top_logprobs\": [\n",
+ " {\n",
+ " \"\\t\": -3.8812888,\n",
+ " \"\\n\": -3.4891934,\n",
+ " \"________\": -2.3949227,\n",
+ " \"________________________________\": -3.8888516,\n",
+ " \"\\u00a0\": -0.47720006\n",
+ " },\n",
+ " {\n",
+ " \"B\": -3.7005355,\n",
+ " \"France\": -2.6506393,\n",
+ " \"Q\": -3.373843,\n",
+ " \"The\": -3.3061914,\n",
+ " \"par\": -2.7321594\n",
+ " },\n",
+ " {\n",
+ " \"\\n\": -2.514796,\n",
+ " \" does\": -2.5521305,\n",
+ " \" doesn\": -2.3279827,\n",
+ " \" has\": -1.6931255,\n",
+ " \" is\": -1.2077931\n",
+ " },\n",
+ " {\n",
+ " \" a\": -0.84418315,\n",
+ " \" divided\": -2.6524067,\n",
+ " \" not\": -1.3749524,\n",
+ " \" one\": -3.5888045,\n",
+ " \" the\": -2.4301894\n",
+ " },\n",
+ " {\n",
+ " \" capital\": -3.908288,\n",
+ " \" city\": -3.1317036,\n",
+ " \" country\": -0.5668232,\n",
+ " \" republic\": -2.3707287,\n",
+ " \"\\u00a0\": -3.943017\n",
+ " },\n",
+ " {\n",
+ " \" and\": -2.1657336,\n",
+ " \" in\": -1.8537258,\n",
+ " \" not\": -3.2263339,\n",
+ " \",\": -1.1204532,\n",
+ " \".\": -1.7335285\n",
+ " },\n",
+ " {\n",
+ " \" and\": -2.314171,\n",
+ " \" it\": -3.8719406,\n",
+ " \" not\": -0.5845836,\n",
+ " \" so\": -1.8761828,\n",
+ " \" the\": -3.3711283\n",
+ " },\n",
+ " {\n",
+ " \" a\": -0.01427887,\n",
+ " \" an\": -5.497788,\n",
+ " \" capital\": -5.5760226,\n",
+ " \" its\": -7.0022845,\n",
+ " \" the\": -6.8639107\n",
+ " },\n",
+ " {\n",
+ " \" capital\": -3.229291,\n",
+ " \" city\": -0.0745326,\n",
+ " \" person\": -4.1651354,\n",
+ " \" place\": -5.017827,\n",
+ " \" state\": -6.3108892\n",
+ " },\n",
+ " {\n",
+ " \"\\n\": -2.8642426,\n",
+ " \" so\": -4.4286876,\n",
+ " \"!\": -3.5753324,\n",
+ " \",\": -1.9522477,\n",
+ " \".\": -0.38981625\n",
+ " }\n",
+ " ]\n",
+ " },\n",
+ " \"text\": \"\\u00a0France is a country, not a city.\"\n",
+ " }\n",
+ " ],\n",
+ " \"created\": 1679581966,\n",
+ " \"id\": \"cmpl-6xGHetKXoitkQOoJAoxknDvTuvkdI\",\n",
+ " \"model\": \"davinci\",\n",
+ " \"object\": \"text_completion\",\n",
+ " \"usage\": {\n",
+ " \"completion_tokens\": 10,\n",
+ " \"prompt_tokens\": 14,\n",
+ " \"total_tokens\": 24\n",
+ " }\n",
+ "}"
+ ]
+ },
+ "execution_count": 119,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "resp"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 120,
+ "id": "f9a95d7e",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T15:10:51.710578Z",
+ "start_time": "2023-03-23T15:10:51.702828Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'\\xa0France is a country, not a city.'"
+ ]
+ },
+ "execution_count": 120,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "resp[\"choices\"][0]['text']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 121,
+ "id": "66e8d1ec",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T15:11:10.180884Z",
+ "start_time": "2023-03-23T15:11:10.164436Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>tokens</th>\n",
+ " <th>token_logprobs</th>\n",
+ " <th>top_logprobs</th>\n",
+ " <th>text_offset</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>0</th>\n",
+ " <td></td>\n",
+ " <td>-0.477200</td>\n",
+ " <td>{'________________________________': -3.888851...</td>\n",
+ " <td>37</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>1</th>\n",
+ " <td>France</td>\n",
+ " <td>-2.650639</td>\n",
+ " <td>{'par': -2.7321594, 'France': -2.6506393, 'B':...</td>\n",
+ " <td>38</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>2</th>\n",
+ " <td>is</td>\n",
+ " <td>-1.207793</td>\n",
+ " <td>{' doesn': -2.3279827, '\n",
+ "': -2.514796, ' is': ...</td>\n",
+ " <td>44</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>3</th>\n",
+ " <td>a</td>\n",
+ " <td>-0.844183</td>\n",
+ " <td>{' a': -0.84418315, ' the': -2.4301894, ' not'...</td>\n",
+ " <td>47</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>4</th>\n",
+ " <td>country</td>\n",
+ " <td>-0.566823</td>\n",
+ " <td>{' country': -0.5668232, ' republic': -2.37072...</td>\n",
+ " <td>49</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>5</th>\n",
+ " <td>,</td>\n",
+ " <td>-1.120453</td>\n",
+ " <td>{',': -1.1204532, '.': -1.7335285, ' in': -1.8...</td>\n",
+ " <td>57</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>6</th>\n",
+ " <td>not</td>\n",
+ " <td>-0.584584</td>\n",
+ " <td>{' the': -3.3711283, ' and': -2.314171, ' it':...</td>\n",
+ " <td>58</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>7</th>\n",
+ " <td>a</td>\n",
+ " <td>-0.014279</td>\n",
+ " <td>{' a': -0.01427887, ' the': -6.8639107, ' an':...</td>\n",
+ " <td>62</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>8</th>\n",
+ " <td>city</td>\n",
+ " <td>-0.074533</td>\n",
+ " <td>{' person': -4.1651354, ' state': -6.3108892, ...</td>\n",
+ " <td>64</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>9</th>\n",
+ " <td>.</td>\n",
+ " <td>-0.389816</td>\n",
+ " <td>{'!': -3.5753324, ',': -1.9522477, '.': -0.389...</td>\n",
+ " <td>69</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " tokens token_logprobs \\\n",
+ "0   -0.477200 \n",
+ "1 France -2.650639 \n",
+ "2 is -1.207793 \n",
+ "3 a -0.844183 \n",
+ "4 country -0.566823 \n",
+ "5 , -1.120453 \n",
+ "6 not -0.584584 \n",
+ "7 a -0.014279 \n",
+ "8 city -0.074533 \n",
+ "9 . -0.389816 \n",
+ "\n",
+ " top_logprobs text_offset \n",
+ "0 {'________________________________': -3.888851... 37 \n",
+ "1 {'par': -2.7321594, 'France': -2.6506393, 'B':... 38 \n",
+ "2 {' doesn': -2.3279827, '\n",
+ "': -2.514796, ' is': ... 44 \n",
+ "3 {' a': -0.84418315, ' the': -2.4301894, ' not'... 47 \n",
+ "4 {' country': -0.5668232, ' republic': -2.37072... 49 \n",
+ "5 {',': -1.1204532, '.': -1.7335285, ' in': -1.8... 57 \n",
+ "6 {' the': -3.3711283, ' and': -2.314171, ' it':... 58 \n",
+ "7 {' a': -0.01427887, ' the': -6.8639107, ' an':... 62 \n",
+ "8 {' person': -4.1651354, ' state': -6.3108892, ... 64 \n",
+ "9 {'!': -3.5753324, ',': -1.9522477, '.': -0.389... 69 "
+ ]
+ },
+ "execution_count": 121,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pd.DataFrame(resp[\"choices\"][0][\"logprobs\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 122,
+ "id": "2dd6bb17",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-23T15:12:00.288365Z",
+ "start_time": "2023-03-23T15:12:00.277073Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>logprob</th>\n",
+ " <th>%</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>par</th>\n",
+ " <td>-2.732159</td>\n",
+ " <td>6.507861</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>France</th>\n",
+ " <td>-2.650639</td>\n",
+ " <td>7.060606</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>B</th>\n",
+ " <td>-3.700535</td>\n",
+ " <td>2.471029</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>The</th>\n",
+ " <td>-3.306191</td>\n",
+ " <td>3.665551</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>Q</th>\n",
+ " <td>-3.373843</td>\n",
+ " <td>3.425773</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " logprob %\n",
+ "par -2.732159 6.507861\n",
+ "France -2.650639 7.060606\n",
+ "B -3.700535 2.471029\n",
+ "The -3.306191 3.665551\n",
+ "Q -3.373843 3.425773"
+ ]
+ },
+ "execution_count": 122,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "scores = pd.DataFrame([resp[\"choices\"][0][\"logprobs\"][\"top_logprobs\"][1]]).T\n",
+ "scores.columns = [\"logprob\"]\n",
+ "scores[\"%\"] = scores[\"logprob\"].apply(lambda x: 100*np.e**x)\n",
+ "scores"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.13"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {},
+ "toc_section_display": true,
+ "toc_window_display": true
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}