{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. tokenizer, 构造输入" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- tokenizer, model: 相匹配,tokenizer outputs => model input\n", "- Auto\\*Tokenizer, AutoModel\\*:Generic type\n", "- tokenizer:服务于 model input\n", " - len(input_ids) == len(attention_mask)\n", " - tokenizer(test_senteces[0], ): tokenizer.\\_\\_call\\_\\_:encode\n", " - tokenizer.encode == tokenizer.tokenize + tokenizer.convert_tokens_to_ids, 再在首尾加上 special tokens: [CLS] ... [SEP]\n", " - tokenizer.decode\n", " - tokenizer 工作的原理其实就是 tokenizer.vocab:字典,存储了 token => id 的映射关系\n", " - tokenizer.special_tokens_map\n", " - attention mask 与 padding 相匹配;" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "ExecuteTime": { "end_time": "2022-06-19T02:35:23.828856Z", "start_time": "2022-06-19T02:35:23.825866Z" } }, "outputs": [], "source": [ "test_senteces = ['today is not that bad', 'today is so bad', 'so good']\n", "model_name = 'distilbert-base-uncased-finetuned-sst-2-english'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2022-06-19T02:20:03.633366Z", "start_time": "2022-06-19T02:19:58.789769Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/chunhuizhang/anaconda3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. 
Expected 192 from C header, got 216 from PyObject\n", " return f(*args, **kwds)\n" ] } ], "source": [ "from transformers import AutoTokenizer, AutoModelForSequenceClassification" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2022-06-19T02:21:21.136712Z", "start_time": "2022-06-19T02:21:07.860869Z" } }, "outputs": [], "source": [ "# Tokenizer and model are both loaded from the same checkpoint name.\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "model = AutoModelForSequenceClassification.from_pretrained(model_name)" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": { "ExecuteTime": { "end_time": "2022-06-19T02:23:20.499775Z", "start_time": "2022-06-19T02:23:20.495749Z" } }, "outputs": [], "source": [ "# Tokenize the whole batch once; this batch_input is what the model cell consumes later.\n", "# return_tensors='pt' asks for PyTorch tensors instead of Python lists.\n", "batch_input = tokenizer(test_senteces, truncation=True, padding=True, return_tensors='pt')" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": { "ExecuteTime": { "end_time": "2022-06-19T02:23:51.629797Z", "start_time": "2022-06-19T02:23:51.624465Z" } }, "outputs": [ { "data": { "text/plain": [ "{'input_ids': tensor([[ 101, 2651, 2003, 2025, 2008, 2919, 102],\n", " [ 101, 2651, 2003, 2061, 2919, 102, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1],\n", " [1, 1, 1, 1, 1, 1, 0]])}" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch_input" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": { "ExecuteTime": { "end_time": "2022-06-19T02:24:27.711057Z", "start_time": "2022-06-19T02:24:27.705957Z" } }, "outputs": [ { "data": { "text/plain": [ "{'input_ids': [101, 2651, 2003, 2025, 2008, 2919, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Without return_tensors, calling the tokenizer on one string yields plain Python lists.\n", "tokenizer(test_senteces[0], )" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": { "ExecuteTime": { "end_time": "2022-06-19T02:26:02.232367Z", "start_time": "2022-06-19T02:26:02.227981Z" } }, "outputs": [ { "data": { "text/plain": [ "[101, 2651, 2003, 2025, 2008, 2919, 102]" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# encode() returns only the ids, including the special tokens [CLS]=101 and [SEP]=102.\n", "tokenizer.encode(test_senteces[0], )" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": { "ExecuteTime": { "end_time": "2022-06-19T02:27:01.802959Z", "start_time": "2022-06-19T02:27:01.798355Z" } }, "outputs": [ { "data": { "text/plain": [ "[2651, 2003, 2025, 2008, 2919]" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# tokenize() + convert_tokens_to_ids() gives the same ids, but without [CLS]/[SEP].\n", "tokenizer.convert_tokens_to_ids(tokenizer.tokenize(test_senteces[0]))" ] },
{ "cell_type": "code", "execution_count": 21, "metadata": { "ExecuteTime": { "end_time": "2022-06-19T02:28:38.186096Z", "start_time": "2022-06-19T02:28:38.180739Z" } }, "outputs": [ { "data": { "text/plain": [ "'[CLS] today is not that bad [SEP]'" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer.decode([101, 2651, 2003, 2025, 2008, 2919, 102])" ] },
{ "cell_type": "code", "execution_count": 28, "metadata": { "ExecuteTime": { "end_time": "2022-06-19T02:31:37.843855Z", "start_time": "2022-06-19T02:31:37.838167Z" } }, "outputs": [ { "data": { "text/plain": [ "dict_values(['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'])" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer.special_tokens_map.values()" ] },
{ "cell_type": "code", "execution_count": 27, "metadata": { "ExecuteTime": { "end_time": "2022-06-19T02:31:17.531347Z", "start_time": "2022-06-19T02:31:17.525170Z" } }, "outputs": [ { "data": { "text/plain": [ "[100, 102, 0, 101, 103]" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Ids of the special tokens, in the same order as special_tokens_map.values().\n", "tokenizer.convert_tokens_to_ids(list(tokenizer.special_tokens_map.values()))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### 2. model,调用模型" ] },
{ "cell_type": "code", "execution_count": 41, "metadata": { "ExecuteTime": { "end_time": "2022-06-19T02:38:19.070764Z", "start_time": "2022-06-19T02:38:19.068147Z" } }, "outputs": [], "source": [ "import torch\n", "import torch.nn.functional as F" ] },
{ "cell_type": "code", "execution_count": 46, "metadata": { "ExecuteTime": { "end_time": "2022-06-19T02:39:56.465031Z", "start_time": "2022-06-19T02:39:56.459505Z" } }, "outputs": [ { "data": { "text/plain": [ "DistilBertConfig {\n", " \"_name_or_path\": \"distilbert-base-uncased-finetuned-sst-2-english\",\n", " \"activation\": \"gelu\",\n", " \"architectures\": [\n", " \"DistilBertForSequenceClassification\"\n", " ],\n", " \"attention_dropout\": 0.1,\n", " \"dim\": 768,\n", " \"dropout\": 0.1,\n", " \"finetuning_task\": \"sst-2\",\n", " \"hidden_dim\": 3072,\n", " \"id2label\": {\n", " \"0\": \"NEGATIVE\",\n", " \"1\": \"POSITIVE\"\n", " },\n", " \"initializer_range\": 0.02,\n", " \"label2id\": {\n", " \"NEGATIVE\": 0,\n", " \"POSITIVE\": 1\n", " },\n", " \"max_position_embeddings\": 512,\n", " \"model_type\": \"distilbert\",\n", " \"n_heads\": 12,\n", " \"n_layers\": 6,\n", " \"output_past\": true,\n", " \"pad_token_id\": 0,\n", " \"qa_dropout\": 0.1,\n", " \"seq_classif_dropout\": 0.2,\n", " \"sinusoidal_pos_embds\": false,\n", " \"tie_weights_\": true,\n", " \"transformers_version\": \"4.11.2\",\n", " \"vocab_size\": 30522\n", "}" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.config" ] },
{ "cell_type": "code", "execution_count": 48, "metadata": { "ExecuteTime": { "end_time": "2022-06-19T02:40:46.821090Z", "start_time": "2022-06-19T02:40:46.723297Z" } }, "outputs": 
[ { "name": "stdout", "output_type": "stream", "text": [ "SequenceClassifierOutput(loss=None, logits=tensor([[-3.4620, 3.6118],\n", " [ 4.7508, -3.7899],\n", " [-4.1938, 4.5566]]), hidden_states=None, attentions=None)\n" ] } ], "source": [ "# Forward pass only: torch.no_grad() disables gradient tracking for inference.\n", "with torch.no_grad():\n", "    outputs = model(**batch_input)\n", "print(outputs)" ] },
{ "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2022-06-19T02:10:28.942776Z", "start_time": "2022-06-19T02:10:28.939489Z" } }, "source": [ "### 3. parse output,输出解析" ] },
{ "cell_type": "code", "execution_count": 49, "metadata": { "ExecuteTime": { "end_time": "2022-06-19T02:40:57.429808Z", "start_time": "2022-06-19T02:40:57.328151Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[8.4632e-04, 9.9915e-01],\n", " [9.9980e-01, 1.9531e-04],\n", " [1.5837e-04, 9.9984e-01]])\n", "tensor([1, 0, 1])\n", "['POSITIVE', 'NEGATIVE', 'POSITIVE']\n" ] } ], "source": [ "# Parse the logits computed in the model cell above (no second forward pass):\n", "# softmax over dim=1 turns each row of logits into per-class probabilities,\n", "# argmax picks the predicted class index, and config.id2label maps it to a label.\n", "scores = F.softmax(outputs.logits, dim=1)\n", "print(scores)\n", "label_ids = torch.argmax(scores, dim=1)\n", "print(label_ids)\n", "# label_id (not id) so the Python builtin id() is not shadowed.\n", "labels = [model.config.id2label[label_id] for label_id in label_ids.tolist()]\n", "print(labels)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": 
"python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }