summaryrefslogtreecommitdiff
path: root/llm/tutorials/03_embedding_search.ipynb
diff options
context:
space:
mode:
authorchzhang <zch921005@126.com>2023-03-13 23:04:13 +0800
committerchzhang <zch921005@126.com>2023-03-13 23:04:13 +0800
commitf7409be04df00e8b10aede52c491cadeada2a491 (patch)
treec4c766ff2e4fe5abfc16b93093fda3d7c04e7a22 /llm/tutorials/03_embedding_search.ipynb
parent132fc2142a9c493134fce834d14f4772ebe846b4 (diff)
openai embedding
Diffstat (limited to 'llm/tutorials/03_embedding_search.ipynb')
-rw-r--r--llm/tutorials/03_embedding_search.ipynb600
1 files changed, 600 insertions, 0 deletions
diff --git a/llm/tutorials/03_embedding_search.ipynb b/llm/tutorials/03_embedding_search.ipynb
new file mode 100644
index 0000000..5225d75
--- /dev/null
+++ b/llm/tutorials/03_embedding_search.ipynb
@@ -0,0 +1,600 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "12b9c4db",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-13T13:56:17.829502Z",
+ "start_time": "2023-03-13T13:56:03.465027Z"
+ },
+ "collapsed": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Looking in indexes: http://mirrors.aliyun.com/pypi/simple/\n",
+ "Collecting tiktoken\n",
+ " Downloading http://mirrors.aliyun.com/pypi/packages/5c/76/03b8286cd264f9f5550229fe21f72abc89d431a9a3c887fc365763acc5a4/tiktoken-0.3.0-cp39-cp39-macosx_10_9_x86_64.whl (735 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m735.4/735.4 kB\u001b[0m \u001b[31m256.0 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: requests>=2.26.0 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from tiktoken) (2.28.1)\n",
+ "Requirement already satisfied: regex>=2022.1.18 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from tiktoken) (2022.7.9)\n",
+ "Collecting blobfile>=2\n",
+ " Downloading http://mirrors.aliyun.com/pypi/packages/c1/35/6b92aa0d86f26f0a8ab6959dd29ac4c7e96d5c1d948d4347bba12e07695a/blobfile-2.0.1-py3-none-any.whl (73 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m73.5/73.5 kB\u001b[0m \u001b[31m214.4 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: urllib3<3,>=1.25.3 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from blobfile>=2->tiktoken) (1.26.11)\n",
+ "Requirement already satisfied: filelock~=3.0 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from blobfile>=2->tiktoken) (3.6.0)\n",
+ "Collecting pycryptodomex~=3.8\n",
+ " Downloading http://mirrors.aliyun.com/pypi/packages/78/db/ec162a8fa1c7c8e03488616a01de59bb752b985f1c507ffb127b40b9d456/pycryptodomex-3.17-cp35-abi3-macosx_10_9_x86_64.whl (1.6 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m272.2 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: lxml~=4.9 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from blobfile>=2->tiktoken) (4.9.1)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from requests>=2.26.0->tiktoken) (2022.9.24)\n",
+ "Requirement already satisfied: charset-normalizer<3,>=2 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from requests>=2.26.0->tiktoken) (2.0.4)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /Users/chunhuizhang/opt/anaconda3/lib/python3.9/site-packages (from requests>=2.26.0->tiktoken) (3.3)\n",
+ "Installing collected packages: pycryptodomex, blobfile, tiktoken\n",
+ "Successfully installed blobfile-2.0.1 pycryptodomex-3.17 tiktoken-0.3.0\n",
+ "Note: you may need to restart the kernel to use updated packages.\n"
+ ]
+ }
+ ],
+ "source": [
+ "pip install tiktoken"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f56bbe1c",
+ "metadata": {},
+ "source": [
+ "## 认识数据集"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 84,
+ "id": "76150440",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-13T14:51:31.301676Z",
+ "start_time": "2023-03-13T14:51:31.297972Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# imports\n",
+ "import pandas as pd\n",
+ "import tiktoken\n",
+ "import openai\n",
+ "from openai.embeddings_utils import get_embedding\n",
+ "import numpy as np"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "fa311d89",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-13T13:57:09.514327Z",
+ "start_time": "2023-03-13T13:57:09.510859Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# embedding model parameters\n",
+ "embedding_model = \"text-embedding-ada-002\"\n",
+ "embedding_encoding = \"cl100k_base\" # this the encoding for text-embedding-ada-002\n",
+ "max_tokens = 8191 # the maximum for text-embedding-ada-002 is 8191"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 85,
+ "id": "63c73803",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-13T14:52:36.612466Z",
+ "start_time": "2023-03-13T14:52:36.609031Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "input_file = './data/fine_food_reviews_1k.csv'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 88,
+ "id": "b4a220f1",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-13T14:54:11.245064Z",
+ "start_time": "2023-03-13T14:54:11.210401Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(1000, 6)\n",
+ "(762, 6)\n"
+ ]
+ }
+ ],
+ "source": [
+ "df = pd.read_csv(input_file, index_col=0)\n",
+ "df = df[[\"Time\", \"ProductId\", \"UserId\", \"Score\", \"Summary\", \"Text\"]]\n",
+ "df = df.sort_values('Time')\n",
+ "df.dropna(inplace=True)\n",
+ "print(df.shape)\n",
+ "df.drop_duplicates(subset=['Summary', 'Text'], keep='last', inplace=True)\n",
+ "print(df.shape)\n",
+ "df['Combined'] = 'Title: ' + df.Summary.str.strip() + '; Content: ' + df.Text.str.strip()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 89,
+ "id": "a6c7bbd8",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-13T14:55:05.038778Z",
+ "start_time": "2023-03-13T14:55:04.891622Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "100"
+ ]
+ },
+ "execution_count": 89,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "top_n = 100\n",
+ "encoding = tiktoken.get_encoding(embedding_encoding)\n",
+ "# omit reviews that are too long to embed\n",
+ "df[\"n_tokens\"] = df.Combined.apply(lambda x: len(encoding.encode(x)))\n",
+ "df = df[df.n_tokens <= max_tokens].tail(top_n)\n",
+ "len(df)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 90,
+ "id": "262994d4",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-13T14:56:38.163617Z",
+ "start_time": "2023-03-13T14:55:27.448765Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "openai.api_key = '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'\n",
+ "df['embedding'] = df.Combined.apply(lambda x: get_embedding(x, engine=embedding_model))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9510f2b6",
+ "metadata": {},
+ "source": [
+ "## embedding"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e918cedc",
+ "metadata": {},
+ "source": [
+ "- dimension\n",
+ "- norm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 91,
+ "id": "92261327",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-13T14:57:45.646492Z",
+ "start_time": "2023-03-13T14:57:45.620730Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "df['embed_len'] = df.embedding.apply(lambda x: len(x))\n",
+ "df['embed_norm'] = df.embedding.apply(lambda x: np.linalg.norm(x))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 92,
+ "id": "eaba5054",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-13T14:57:46.858110Z",
+ "start_time": "2023-03-13T14:57:46.795263Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>Time</th>\n",
+ " <th>ProductId</th>\n",
+ " <th>UserId</th>\n",
+ " <th>Score</th>\n",
+ " <th>Summary</th>\n",
+ " <th>Text</th>\n",
+ " <th>Combined</th>\n",
+ " <th>n_tokens</th>\n",
+ " <th>embedding</th>\n",
+ " <th>embed_len</th>\n",
+ " <th>embed_norm</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>650</th>\n",
+ " <td>1351209600</td>\n",
+ " <td>B0051O6P36</td>\n",
+ " <td>A1VC6419THHIET</td>\n",
+ " <td>5</td>\n",
+ " <td>Good for all cats.</td>\n",
+ " <td>I just got these treats last week and they're ...</td>\n",
+ " <td>Title: Good for all cats.; Content: I just got...</td>\n",
+ " <td>81</td>\n",
+ " <td>[-0.02040177956223488, -0.022390257567167282, ...</td>\n",
+ " <td>1536</td>\n",
+ " <td>1.0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>651</th>\n",
+ " <td>1351209600</td>\n",
+ " <td>B001EO5RSQ</td>\n",
+ " <td>A33W5JAFGHYRQZ</td>\n",
+ " <td>5</td>\n",
+ " <td>Love this Cereal!</td>\n",
+ " <td>There is nothing else like this on the market....</td>\n",
+ " <td>Title: Love this Cereal!; Content: There is no...</td>\n",
+ " <td>55</td>\n",
+ " <td>[-0.012976857833564281, -0.008588296361267567,...</td>\n",
+ " <td>1536</td>\n",
+ " <td>1.0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>652</th>\n",
+ " <td>1351209600</td>\n",
+ " <td>B0045H264C</td>\n",
+ " <td>A3IYSIAKYOMKTO</td>\n",
+ " <td>5</td>\n",
+ " <td>Wild Honey</td>\n",
+ " <td>This really is unfiltered honey made from wild...</td>\n",
+ " <td>Title: Wild Honey; Content: This really is unf...</td>\n",
+ " <td>107</td>\n",
+ " <td>[0.002022168133407831, -0.010228604078292847, ...</td>\n",
+ " <td>1536</td>\n",
+ " <td>1.0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>679</th>\n",
+ " <td>1351209600</td>\n",
+ " <td>B000UBD88A</td>\n",
+ " <td>AWRFQYLG7LQKJ</td>\n",
+ " <td>2</td>\n",
+ " <td>Not very strong</td>\n",
+ " <td>Not as strong as the regular dark coffee. Dis...</td>\n",
+ " <td>Title: Not very strong; Content: Not as strong...</td>\n",
+ " <td>45</td>\n",
+ " <td>[-0.0016124029643833637, -0.026590621098876, 0...</td>\n",
+ " <td>1536</td>\n",
+ " <td>1.0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>654</th>\n",
+ " <td>1351209600</td>\n",
+ " <td>B001XWRMAU</td>\n",
+ " <td>A1KWVBDHBG50VZ</td>\n",
+ " <td>5</td>\n",
+ " <td>Outstanding product!.....</td>\n",
+ " <td>Great flavor.....lotsa &amp;#34;heat&amp;#34;....I use...</td>\n",
+ " <td>Title: Outstanding product!.....; Content: Gre...</td>\n",
+ " <td>43</td>\n",
+ " <td>[-0.00573874544352293, 0.007031316868960857, 0...</td>\n",
+ " <td>1536</td>\n",
+ " <td>1.0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>...</th>\n",
+ " <td>...</td>\n",
+ " <td>...</td>\n",
+ " <td>...</td>\n",
+ " <td>...</td>\n",
+ " <td>...</td>\n",
+ " <td>...</td>\n",
+ " <td>...</td>\n",
+ " <td>...</td>\n",
+ " <td>...</td>\n",
+ " <td>...</td>\n",
+ " <td>...</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>623</th>\n",
+ " <td>1351209600</td>\n",
+ " <td>B0000CFXYA</td>\n",
+ " <td>A3GS4GWPIBV0NT</td>\n",
+ " <td>1</td>\n",
+ " <td>Strange inflammation response</td>\n",
+ " <td>Truthfully wasn't crazy about the taste of the...</td>\n",
+ " <td>Title: Strange inflammation response; Content:...</td>\n",
+ " <td>110</td>\n",
+ " <td>[0.00011091353371739388, -0.00466986745595932,...</td>\n",
+ " <td>1536</td>\n",
+ " <td>1.0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>624</th>\n",
+ " <td>1351209600</td>\n",
+ " <td>B0001BH5YM</td>\n",
+ " <td>A1BZ3HMAKK0NC</td>\n",
+ " <td>5</td>\n",
+ " <td>My favorite and only MUSTARD</td>\n",
+ " <td>You've just got to experience this mustard... ...</td>\n",
+ " <td>Title: My favorite and only MUSTARD; Content:...</td>\n",
+ " <td>80</td>\n",
+ " <td>[-0.020869314670562744, -0.013138455338776112,...</td>\n",
+ " <td>1536</td>\n",
+ " <td>1.0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>625</th>\n",
+ " <td>1351209600</td>\n",
+ " <td>B0009ET7TC</td>\n",
+ " <td>A2FSDQY5AI6TNX</td>\n",
+ " <td>5</td>\n",
+ " <td>My furbabies LOVE these!</td>\n",
+ " <td>Shake the container and they come running. Eve...</td>\n",
+ " <td>Title: My furbabies LOVE these!; Content: Shak...</td>\n",
+ " <td>47</td>\n",
+ " <td>[-0.009749102406203747, -0.0068712360225617886...</td>\n",
+ " <td>1536</td>\n",
+ " <td>1.0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>619</th>\n",
+ " <td>1351209600</td>\n",
+ " <td>B007PA32L2</td>\n",
+ " <td>A15FF2P7RPKH6G</td>\n",
+ " <td>5</td>\n",
+ " <td>got this for the daughter</td>\n",
+ " <td>all i have heard since she got a kuerig is why...</td>\n",
+ " <td>Title: got this for the daughter; Content: all...</td>\n",
+ " <td>50</td>\n",
+ " <td>[-0.005320307798683643, 0.0009131018887273967,...</td>\n",
+ " <td>1536</td>\n",
+ " <td>1.0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>999</th>\n",
+ " <td>1351209600</td>\n",
+ " <td>B001EQ5GEO</td>\n",
+ " <td>A3VYU0VO6DYV6I</td>\n",
+ " <td>5</td>\n",
+ " <td>I love Maui Coffee!</td>\n",
+ " <td>My first experience with Maui Coffee was bring...</td>\n",
+ " <td>Title: I love Maui Coffee!; Content: My first ...</td>\n",
+ " <td>118</td>\n",
+ " <td>[-0.006057822611182928, -0.015015840530395508,...</td>\n",
+ " <td>1536</td>\n",
+ " <td>1.0</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "<p>100 rows × 11 columns</p>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " Time ProductId UserId Score \\\n",
+ "650 1351209600 B0051O6P36 A1VC6419THHIET 5 \n",
+ "651 1351209600 B001EO5RSQ A33W5JAFGHYRQZ 5 \n",
+ "652 1351209600 B0045H264C A3IYSIAKYOMKTO 5 \n",
+ "679 1351209600 B000UBD88A AWRFQYLG7LQKJ 2 \n",
+ "654 1351209600 B001XWRMAU A1KWVBDHBG50VZ 5 \n",
+ ".. ... ... ... ... \n",
+ "623 1351209600 B0000CFXYA A3GS4GWPIBV0NT 1 \n",
+ "624 1351209600 B0001BH5YM A1BZ3HMAKK0NC 5 \n",
+ "625 1351209600 B0009ET7TC A2FSDQY5AI6TNX 5 \n",
+ "619 1351209600 B007PA32L2 A15FF2P7RPKH6G 5 \n",
+ "999 1351209600 B001EQ5GEO A3VYU0VO6DYV6I 5 \n",
+ "\n",
+ " Summary \\\n",
+ "650 Good for all cats. \n",
+ "651 Love this Cereal! \n",
+ "652 Wild Honey \n",
+ "679 Not very strong \n",
+ "654 Outstanding product!..... \n",
+ ".. ... \n",
+ "623 Strange inflammation response \n",
+ "624 My favorite and only MUSTARD \n",
+ "625 My furbabies LOVE these! \n",
+ "619 got this for the daughter \n",
+ "999 I love Maui Coffee! \n",
+ "\n",
+ " Text \\\n",
+ "650 I just got these treats last week and they're ... \n",
+ "651 There is nothing else like this on the market.... \n",
+ "652 This really is unfiltered honey made from wild... \n",
+ "679 Not as strong as the regular dark coffee. Dis... \n",
+ "654 Great flavor.....lotsa &#34;heat&#34;....I use... \n",
+ ".. ... \n",
+ "623 Truthfully wasn't crazy about the taste of the... \n",
+ "624 You've just got to experience this mustard... ... \n",
+ "625 Shake the container and they come running. Eve... \n",
+ "619 all i have heard since she got a kuerig is why... \n",
+ "999 My first experience with Maui Coffee was bring... \n",
+ "\n",
+ " Combined n_tokens \\\n",
+ "650 Title: Good for all cats.; Content: I just got... 81 \n",
+ "651 Title: Love this Cereal!; Content: There is no... 55 \n",
+ "652 Title: Wild Honey; Content: This really is unf... 107 \n",
+ "679 Title: Not very strong; Content: Not as strong... 45 \n",
+ "654 Title: Outstanding product!.....; Content: Gre... 43 \n",
+ ".. ... ... \n",
+ "623 Title: Strange inflammation response; Content:... 110 \n",
+ "624 Title: My favorite and only MUSTARD; Content:... 80 \n",
+ "625 Title: My furbabies LOVE these!; Content: Shak... 47 \n",
+ "619 Title: got this for the daughter; Content: all... 50 \n",
+ "999 Title: I love Maui Coffee!; Content: My first ... 118 \n",
+ "\n",
+ " embedding embed_len embed_norm \n",
+ "650 [-0.02040177956223488, -0.022390257567167282, ... 1536 1.0 \n",
+ "651 [-0.012976857833564281, -0.008588296361267567,... 1536 1.0 \n",
+ "652 [0.002022168133407831, -0.010228604078292847, ... 1536 1.0 \n",
+ "679 [-0.0016124029643833637, -0.026590621098876, 0... 1536 1.0 \n",
+ "654 [-0.00573874544352293, 0.007031316868960857, 0... 1536 1.0 \n",
+ ".. ... ... ... \n",
+ "623 [0.00011091353371739388, -0.00466986745595932,... 1536 1.0 \n",
+ "624 [-0.020869314670562744, -0.013138455338776112,... 1536 1.0 \n",
+ "625 [-0.009749102406203747, -0.0068712360225617886... 1536 1.0 \n",
+ "619 [-0.005320307798683643, 0.0009131018887273967,... 1536 1.0 \n",
+ "999 [-0.006057822611182928, -0.015015840530395508,... 1536 1.0 \n",
+ "\n",
+ "[100 rows x 11 columns]"
+ ]
+ },
+ "execution_count": 92,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b80cea9c",
+ "metadata": {},
+ "source": [
+ "## semantic search base text embedding"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 93,
+ "id": "b9280f2b",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-13T14:59:32.786418Z",
+ "start_time": "2023-03-13T14:59:31.781336Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "super coffee: Great coffee and so easy to brew. This coffee has great aroma and is good to the last drop. I actually like all the brands. This is the way coffee should taste!!\n",
+ "Delicious!!!!: A coffee treat. Now that my husband and I drink this coffee, there is no going back to the plain stuff ;).\n",
+ "Full- bodied without a bitter after-taste: This is my everyday coffee choice...a good all around crowd pleaser. Green mountain Sumatra would be my back-up-for-a-change-of-pace second choice...nice t\n"
+ ]
+ }
+ ],
+ "source": [
+ "from openai.embeddings_utils import get_embedding, cosine_similarity\n",
+ "\n",
+ "# search through the reviews for a specific product\n",
+ "def search_reviews(df, query, n=3, pprint=True):\n",
+ " query_embed = get_embedding(\n",
+ " query,\n",
+ " engine=embedding_model\n",
+ " )\n",
+ " df[\"similarity\"] = df.embedding.apply(lambda x: cosine_similarity(x, query_embed))\n",
+ "\n",
+ " results = (\n",
+ " df.sort_values(\"similarity\", ascending=False)\n",
+ " .head(n)\n",
+ " .Combined.str.replace(\"Title: \", \"\")\n",
+ " .str.replace(\"; Content:\", \": \")\n",
+ " )\n",
+ " if pprint:\n",
+ " for r in results:\n",
+ " print(r[:200])\n",
+ " return results\n",
+ "\n",
+ "\n",
+ "results = search_reviews(df, \"good coffee\", n=3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "717a8311",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.13"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {},
+ "toc_section_display": true,
+ "toc_window_display": true
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}