diff options
| author | chzhang <zch921005@126.com> | 2023-03-13 23:04:13 +0800 |
|---|---|---|
| committer | chzhang <zch921005@126.com> | 2023-03-13 23:04:13 +0800 |
| commit | f7409be04df00e8b10aede52c491cadeada2a491 (patch) | |
| tree | c4c766ff2e4fe5abfc16b93093fda3d7c04e7a22 /llm/tutorials/03_embedding_search.ipynb | |
| parent | 132fc2142a9c493134fce834d14f4772ebe846b4 (diff) | |
openai embedding
Diffstat (limited to 'llm/tutorials/03_embedding_search.ipynb')
| -rw-r--r-- | llm/tutorials/03_embedding_search.ipynb | 600 |
1 files changed, 600 insertions, 0 deletions
diff --git a/llm/tutorials/03_embedding_search.ipynb b/llm/tutorials/03_embedding_search.ipynb new file mode 100644 index 0000000..5225d75 --- /dev/null +++ b/llm/tutorials/03_embedding_search.ipynb @@ -0,0 +1,600 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "id": "12b9c4db", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-13T13:56:17.829502Z", + "start_time": "2023-03-13T13:56:03.465027Z" + }, + "collapsed": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Looking in indexes: http://mirrors.aliyun.com/pypi/simple/\n", + "Collecting tiktoken\n", + " Downloading http://mirrors.aliyun.com/pypi/packages/5c/76/03b8286cd264f9f5550229fe21f72abc89d431a9a3c887fc365763acc5a4/tiktoken-0.3.0-cp39-cp39-macosx_10_9_x86_64.whl (735 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m735.4/735.4 kB\u001b[0m \u001b[31m256.0 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: requests>=2.26.0 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from tiktoken) (2.28.1)\n", + "Requirement already satisfied: regex>=2022.1.18 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from tiktoken) (2022.7.9)\n", + "Collecting blobfile>=2\n", + " Downloading http://mirrors.aliyun.com/pypi/packages/c1/35/6b92aa0d86f26f0a8ab6959dd29ac4c7e96d5c1d948d4347bba12e07695a/blobfile-2.0.1-py3-none-any.whl (73 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m73.5/73.5 kB\u001b[0m \u001b[31m214.4 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: urllib3<3,>=1.25.3 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from blobfile>=2->tiktoken) (1.26.11)\n", + "Requirement already satisfied: filelock~=3.0 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from blobfile>=2->tiktoken) (3.6.0)\n", + "Collecting pycryptodomex~=3.8\n", + " Downloading http://mirrors.aliyun.com/pypi/packages/78/db/ec162a8fa1c7c8e03488616a01de59bb752b985f1c507ffb127b40b9d456/pycryptodomex-3.17-cp35-abi3-macosx_10_9_x86_64.whl (1.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m272.2 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: lxml~=4.9 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from blobfile>=2->tiktoken) (4.9.1)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from requests>=2.26.0->tiktoken) (2022.9.24)\n", + "Requirement already satisfied: charset-normalizer<3,>=2 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from requests>=2.26.0->tiktoken) (2.0.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from requests>=2.26.0->tiktoken) (3.3)\n", + "Installing collected packages: pycryptodomex, blobfile, tiktoken\n", + "Successfully installed blobfile-2.0.1 pycryptodomex-3.17 tiktoken-0.3.0\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "pip install tiktoken" + ] + }, + { + "cell_type": "markdown", + "id": "f56bbe1c", + "metadata": {}, + "source": [ + "## 认识数据集" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "id": "76150440", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-13T14:51:31.301676Z", + "start_time": "2023-03-13T14:51:31.297972Z" + } + }, + "outputs": [], + "source": [ + "# imports\n", + "import pandas as pd\n", + "import tiktoken\n", + "import openai\n", + "from openai.embeddings_utils import get_embedding\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "fa311d89", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-13T13:57:09.514327Z", + "start_time": "2023-03-13T13:57:09.510859Z" + } + }, + "outputs": [], + "source": [ + "# embedding model parameters\n", + "embedding_model = \"text-embedding-ada-002\"\n", + "embedding_encoding = \"cl100k_base\" # this the encoding for text-embedding-ada-002\n", + "max_tokens = 8191 # the maximum for text-embedding-ada-002 is 8191" + ] + }, + { + "cell_type": "code", + "execution_count": 85, + "id": "63c73803", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-13T14:52:36.612466Z", + "start_time": "2023-03-13T14:52:36.609031Z" + } + }, + "outputs": [], + "source": [ + "input_file = './data/fine_food_reviews_1k.csv'" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "id": "b4a220f1", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-13T14:54:11.245064Z", + "start_time": "2023-03-13T14:54:11.210401Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1000, 6)\n", + "(762, 6)\n" + ] + } + ], + "source": [ + "df = pd.read_csv(input_file, index_col=0)\n", + "df = df[[\"Time\", \"ProductId\", \"UserId\", \"Score\", \"Summary\", \"Text\"]]\n", + "df = df.sort_values('Time')\n", + "df.dropna(inplace=True)\n", + "print(df.shape)\n", + "df.drop_duplicates(subset=['Summary', 'Text'], keep='last', inplace=True)\n", + "print(df.shape)\n", + "df['Combined'] = 'Title: ' + df.Summary.str.strip() + '; Content: ' + df.Text.str.strip()" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "id": "a6c7bbd8", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-13T14:55:05.038778Z", + "start_time": "2023-03-13T14:55:04.891622Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "100" + ] + }, + "execution_count": 89, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "top_n = 100\n", + "encoding = tiktoken.get_encoding(embedding_encoding)\n", + "# omit reviews that are too long to embed\n", + "df[\"n_tokens\"] = df.Combined.apply(lambda x: len(encoding.encode(x)))\n", + "df = df[df.n_tokens <= max_tokens].tail(top_n)\n", + "len(df)" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "id": "262994d4", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-13T14:56:38.163617Z", + "start_time": "2023-03-13T14:55:27.448765Z" + } + }, + "outputs": [], + "source": [ + "openai.api_key = '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'\n", + "df['embedding'] = df.Combined.apply(lambda x: get_embedding(x, engine=embedding_model))" + ] + }, + { + "cell_type": "markdown", + "id": "9510f2b6", + "metadata": {}, + "source": [ + "## embedding" + ] + }, + { + "cell_type": "markdown", + "id": "e918cedc", + "metadata": {}, + "source": [ + "- dimension\n", + "- norm" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "id": "92261327", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-13T14:57:45.646492Z", + "start_time": "2023-03-13T14:57:45.620730Z" + } + }, + "outputs": [], + "source": [ + "df['embed_len'] = df.embedding.apply(lambda x: len(x))\n", + "df['embed_norm'] = df.embedding.apply(lambda x: np.linalg.norm(x))" + ] + }, + { + "cell_type": "code", + "execution_count": 92, + "id": "eaba5054", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-13T14:57:46.858110Z", + "start_time": "2023-03-13T14:57:46.795263Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>Time</th>\n", + " <th>ProductId</th>\n", + " <th>UserId</th>\n", + " <th>Score</th>\n", + " <th>Summary</th>\n", + " <th>Text</th>\n", + " <th>Combined</th>\n", + " <th>n_tokens</th>\n", + " <th>embedding</th>\n", + " <th>embed_len</th>\n", + " <th>embed_norm</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>650</th>\n", + " <td>1351209600</td>\n", + " <td>B0051O6P36</td>\n", + " <td>A1VC6419THHIET</td>\n", + " <td>5</td>\n", + " <td>Good for all cats.</td>\n", + " <td>I just got these treats last week and they're ...</td>\n", + " <td>Title: Good for all cats.; Content: I just got...</td>\n", + " <td>81</td>\n", + " <td>[-0.02040177956223488, -0.022390257567167282, ...</td>\n", + " <td>1536</td>\n", + " <td>1.0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>651</th>\n", + " <td>1351209600</td>\n", + " <td>B001EO5RSQ</td>\n", + " <td>A33W5JAFGHYRQZ</td>\n", + " <td>5</td>\n", + " <td>Love this Cereal!</td>\n", + " <td>There is nothing else like this on the market....</td>\n", + " <td>Title: Love this Cereal!; Content: There is no...</td>\n", + " <td>55</td>\n", + " <td>[-0.012976857833564281, -0.008588296361267567,...</td>\n", + " <td>1536</td>\n", + " <td>1.0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>652</th>\n", + " <td>1351209600</td>\n", + " <td>B0045H264C</td>\n", + " <td>A3IYSIAKYOMKTO</td>\n", + " <td>5</td>\n", + " <td>Wild Honey</td>\n", + " <td>This really is unfiltered honey made from wild...</td>\n", + " <td>Title: Wild Honey; Content: This really is unf...</td>\n", + " <td>107</td>\n", + " <td>[0.002022168133407831, -0.010228604078292847, ...</td>\n", + " <td>1536</td>\n", + " <td>1.0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>679</th>\n", + " <td>1351209600</td>\n", + " <td>B000UBD88A</td>\n", + " <td>AWRFQYLG7LQKJ</td>\n", + " <td>2</td>\n", + " <td>Not very strong</td>\n", + " <td>Not as strong as the regular dark coffee. Dis...</td>\n", + " <td>Title: Not very strong; Content: Not as strong...</td>\n", + " <td>45</td>\n", + " <td>[-0.0016124029643833637, -0.026590621098876, 0...</td>\n", + " <td>1536</td>\n", + " <td>1.0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>654</th>\n", + " <td>1351209600</td>\n", + " <td>B001XWRMAU</td>\n", + " <td>A1KWVBDHBG50VZ</td>\n", + " <td>5</td>\n", + " <td>Outstanding product!.....</td>\n", + " <td>Great flavor.....lotsa &#34;heat&#34;....I use...</td>\n", + " <td>Title: Outstanding product!.....; Content: Gre...</td>\n", + " <td>43</td>\n", + " <td>[-0.00573874544352293, 0.007031316868960857, 0...</td>\n", + " <td>1536</td>\n", + " <td>1.0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>...</th>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " <td>...</td>\n", + " </tr>\n", + " <tr>\n", + " <th>623</th>\n", + " <td>1351209600</td>\n", + " <td>B0000CFXYA</td>\n", + " <td>A3GS4GWPIBV0NT</td>\n", + " <td>1</td>\n", + " <td>Strange inflammation response</td>\n", + " <td>Truthfully wasn't crazy about the taste of the...</td>\n", + " <td>Title: Strange inflammation response; Content:...</td>\n", + " <td>110</td>\n", + " <td>[0.00011091353371739388, -0.00466986745595932,...</td>\n", + " <td>1536</td>\n", + " <td>1.0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>624</th>\n", + " <td>1351209600</td>\n", + " <td>B0001BH5YM</td>\n", + " <td>A1BZ3HMAKK0NC</td>\n", + " <td>5</td>\n", + " <td>My favorite and only MUSTARD</td>\n", + " <td>You've just got to experience this mustard... ...</td>\n", + " <td>Title: My favorite and only MUSTARD; Content:...</td>\n", + " <td>80</td>\n", + " <td>[-0.020869314670562744, -0.013138455338776112,...</td>\n", + " <td>1536</td>\n", + " <td>1.0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>625</th>\n", + " <td>1351209600</td>\n", + " <td>B0009ET7TC</td>\n", + " <td>A2FSDQY5AI6TNX</td>\n", + " <td>5</td>\n", + " <td>My furbabies LOVE these!</td>\n", + " <td>Shake the container and they come running. Eve...</td>\n", + " <td>Title: My furbabies LOVE these!; Content: Shak...</td>\n", + " <td>47</td>\n", + " <td>[-0.009749102406203747, -0.0068712360225617886...</td>\n", + " <td>1536</td>\n", + " <td>1.0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>619</th>\n", + " <td>1351209600</td>\n", + " <td>B007PA32L2</td>\n", + " <td>A15FF2P7RPKH6G</td>\n", + " <td>5</td>\n", + " <td>got this for the daughter</td>\n", + " <td>all i have heard since she got a kuerig is why...</td>\n", + " <td>Title: got this for the daughter; Content: all...</td>\n", + " <td>50</td>\n", + " <td>[-0.005320307798683643, 0.0009131018887273967,...</td>\n", + " <td>1536</td>\n", + " <td>1.0</td>\n", + " </tr>\n", + " <tr>\n", + " <th>999</th>\n", + " <td>1351209600</td>\n", + " <td>B001EQ5GEO</td>\n", + " <td>A3VYU0VO6DYV6I</td>\n", + " <td>5</td>\n", + " <td>I love Maui Coffee!</td>\n", + " <td>My first experience with Maui Coffee was bring...</td>\n", + " <td>Title: I love Maui Coffee!; Content: My first ...</td>\n", + " <td>118</td>\n", + " <td>[-0.006057822611182928, -0.015015840530395508,...</td>\n", + " <td>1536</td>\n", + " <td>1.0</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>100 rows × 11 columns</p>\n", + "</div>" + ], + "text/plain": [ + " Time ProductId UserId Score \\\n", + "650 1351209600 B0051O6P36 A1VC6419THHIET 5 \n", + "651 1351209600 B001EO5RSQ A33W5JAFGHYRQZ 5 \n", + "652 1351209600 B0045H264C A3IYSIAKYOMKTO 5 \n", + "679 1351209600 B000UBD88A AWRFQYLG7LQKJ 2 \n", + "654 1351209600 B001XWRMAU A1KWVBDHBG50VZ 5 \n", + ".. ... ... ... ... \n", + "623 1351209600 B0000CFXYA A3GS4GWPIBV0NT 1 \n", + "624 1351209600 B0001BH5YM A1BZ3HMAKK0NC 5 \n", + "625 1351209600 B0009ET7TC A2FSDQY5AI6TNX 5 \n", + "619 1351209600 B007PA32L2 A15FF2P7RPKH6G 5 \n", + "999 1351209600 B001EQ5GEO A3VYU0VO6DYV6I 5 \n", + "\n", + " Summary \\\n", + "650 Good for all cats. \n", + "651 Love this Cereal! \n", + "652 Wild Honey \n", + "679 Not very strong \n", + "654 Outstanding product!..... \n", + ".. ... \n", + "623 Strange inflammation response \n", + "624 My favorite and only MUSTARD \n", + "625 My furbabies LOVE these! \n", + "619 got this for the daughter \n", + "999 I love Maui Coffee! \n", + "\n", + " Text \\\n", + "650 I just got these treats last week and they're ... \n", + "651 There is nothing else like this on the market.... \n", + "652 This really is unfiltered honey made from wild... \n", + "679 Not as strong as the regular dark coffee. Dis... \n", + "654 Great flavor.....lotsa "heat"....I use... \n", + ".. ... \n", + "623 Truthfully wasn't crazy about the taste of the... \n", + "624 You've just got to experience this mustard... ... \n", + "625 Shake the container and they come running. Eve... \n", + "619 all i have heard since she got a kuerig is why... \n", + "999 My first experience with Maui Coffee was bring... \n", + "\n", + " Combined n_tokens \\\n", + "650 Title: Good for all cats.; Content: I just got... 81 \n", + "651 Title: Love this Cereal!; Content: There is no... 55 \n", + "652 Title: Wild Honey; Content: This really is unf... 107 \n", + "679 Title: Not very strong; Content: Not as strong... 45 \n", + "654 Title: Outstanding product!.....; Content: Gre... 43 \n", + ".. ... ... \n", + "623 Title: Strange inflammation response; Content:... 110 \n", + "624 Title: My favorite and only MUSTARD; Content:... 80 \n", + "625 Title: My furbabies LOVE these!; Content: Shak... 47 \n", + "619 Title: got this for the daughter; Content: all... 50 \n", + "999 Title: I love Maui Coffee!; Content: My first ... 118 \n", + "\n", + " embedding embed_len embed_norm \n", + "650 [-0.02040177956223488, -0.022390257567167282, ... 1536 1.0 \n", + "651 [-0.012976857833564281, -0.008588296361267567,... 1536 1.0 \n", + "652 [0.002022168133407831, -0.010228604078292847, ... 1536 1.0 \n", + "679 [-0.0016124029643833637, -0.026590621098876, 0... 1536 1.0 \n", + "654 [-0.00573874544352293, 0.007031316868960857, 0... 1536 1.0 \n", + ".. ... ... ... \n", + "623 [0.00011091353371739388, -0.00466986745595932,... 1536 1.0 \n", + "624 [-0.020869314670562744, -0.013138455338776112,... 1536 1.0 \n", + "625 [-0.009749102406203747, -0.0068712360225617886... 1536 1.0 \n", + "619 [-0.005320307798683643, 0.0009131018887273967,... 1536 1.0 \n", + "999 [-0.006057822611182928, -0.015015840530395508,... 1536 1.0 \n", + "\n", + "[100 rows x 11 columns]" + ] + }, + "execution_count": 92, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df" + ] + }, + { + "cell_type": "markdown", + "id": "b80cea9c", + "metadata": {}, + "source": [ + "## semantic search base text embedding" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "id": "b9280f2b", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-13T14:59:32.786418Z", + "start_time": "2023-03-13T14:59:31.781336Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "super coffee: Great coffee and so easy to brew. This coffee has great aroma and is good to the last drop. I actually like all the brands. This is the way coffee should taste!!\n", + "Delicious!!!!: A coffee treat. Now that my husband and I drink this coffee, there is no going back to the plain stuff ;).\n", + "Full- bodied without a bitter after-taste: This is my everyday coffee choice...a good all around crowd pleaser. Green mountain Sumatra would be my back-up-for-a-change-of-pace second choice...nice t\n" + ] + } + ], + "source": [ + "from openai.embeddings_utils import get_embedding, cosine_similarity\n", + "\n", + "# search through the reviews for a specific product\n", + "def search_reviews(df, query, n=3, pprint=True):\n", + " query_embed = get_embedding(\n", + " query,\n", + " engine=embedding_model\n", + " )\n", + " df[\"similarity\"] = df.embedding.apply(lambda x: cosine_similarity(x, query_embed))\n", + "\n", + " results = (\n", + " df.sort_values(\"similarity\", ascending=False)\n", + " .head(n)\n", + " .Combined.str.replace(\"Title: \", \"\")\n", + " .str.replace(\"; Content:\", \": \")\n", + " )\n", + " if pprint:\n", + " for r in results:\n", + " print(r[:200])\n", + " return results\n", + "\n", + "\n", + "results = search_reviews(df, \"good coffee\", n=3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "717a8311", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} |
