-rw-r--r--  hugface/01_tokenizer_sentiment_analysis.ipynb | 485
-rw-r--r--  hugface/basics.py                             |  33
2 files changed, 518 insertions, 0 deletions
diff --git a/hugface/01_tokenizer_sentiment_analysis.ipynb b/hugface/01_tokenizer_sentiment_analysis.ipynb
new file mode 100644
index 0000000..5d3d36f
--- /dev/null
+++ b/hugface/01_tokenizer_sentiment_analysis.ipynb

### 1. Tokenizer: constructing the model input

- tokenizer and model must match: the tokenizer's outputs are the model's inputs
- `AutoTokenizer`, `AutoModel*`: generic factory classes that resolve to the concrete classes for a given checkpoint
- the tokenizer exists to serve the model input (see the sketch after this list):
    - `len(input_ids) == len(attention_mask)`
    - `tokenizer(test_sentences[0])` goes through `tokenizer.__call__`, i.e. encode
    - `tokenizer.encode` == `tokenizer.tokenize` + `tokenizer.convert_tokens_to_ids`, plus the special tokens `[CLS]`/`[SEP]`
    - `tokenizer.decode` inverts the mapping back to text
    - under the hood the tokenizer is essentially `tokenizer.vocab`: a dict holding the token => id mapping
    - `tokenizer.special_tokens_map` lists the special tokens
    - the attention mask mirrors the padding: padded positions get mask 0
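Since the heavy lifting is just the vocab lookup, a minimal sketch (assuming the same `distilbert-base-uncased-finetuned-sst-2-english` checkpoint is reachable) can verify that `encode` is `[CLS]` + tokenize + vocab lookup + `[SEP]`:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    'distilbert-base-uncased-finetuned-sst-2-english')

sentence = 'today is not that bad'

# encode adds the special tokens automatically
ids_encode = tokenizer.encode(sentence)

# manual route: tokenize -> vocab lookup, then wrap with [CLS]/[SEP]
tokens = tokenizer.tokenize(sentence)
ids_manual = [tokenizer.vocab[tokenizer.cls_token]] \
    + [tokenizer.vocab[t] for t in tokens] \
    + [tokenizer.vocab[tokenizer.sep_token]]

assert ids_encode == ids_manual  # [101, 2651, 2003, 2025, 2008, 2919, 102]
```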
"execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer(test_senteces[0], )" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "ExecuteTime": { + "end_time": "2022-06-19T02:26:02.232367Z", + "start_time": "2022-06-19T02:26:02.227981Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[101, 2651, 2003, 2025, 2008, 2919, 102]" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.encode(test_senteces[0], )" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "ExecuteTime": { + "end_time": "2022-06-19T02:27:01.802959Z", + "start_time": "2022-06-19T02:27:01.798355Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[2651, 2003, 2025, 2008, 2919]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.convert_tokens_to_ids(tokenizer.tokenize(test_senteces[0]))" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "ExecuteTime": { + "end_time": "2022-06-19T02:28:38.186096Z", + "start_time": "2022-06-19T02:28:38.180739Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'[CLS] today is not that bad [SEP]'" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.decode([101, 2651, 2003, 2025, 2008, 2919, 102])" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "ExecuteTime": { + "end_time": "2022-06-19T02:31:37.843855Z", + "start_time": "2022-06-19T02:31:37.838167Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_values(['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'])" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.special_tokens_map.values()" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "ExecuteTime": { + "end_time": "2022-06-19T02:31:17.531347Z", + "start_time": "2022-06-19T02:31:17.525170Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[100, 102, 0, 101, 103]" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.convert_tokens_to_ids([special for special in tokenizer.special_tokens_map.values()])" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "ExecuteTime": { + "end_time": "2022-06-19T02:37:00.609769Z", + "start_time": "2022-06-19T02:37:00.605463Z" + } + }, + "outputs": [], + "source": [ + "batch_input = tokenizer(test_senteces, truncation=True, padding=True, return_tensors='pt')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. 
### 2. Model: running the model

```python
import torch
import torch.nn.functional as F
```

```python
model.config
```

```
DistilBertConfig {
  "_name_or_path": "distilbert-base-uncased-finetuned-sst-2-english",
  "activation": "gelu",
  "architectures": [
    "DistilBertForSequenceClassification"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "finetuning_task": "sst-2",
  "hidden_dim": 3072,
  "id2label": {
    "0": "NEGATIVE",
    "1": "POSITIVE"
  },
  "initializer_range": 0.02,
  "label2id": {
    "NEGATIVE": 0,
    "POSITIVE": 1
  },
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "output_past": true,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.11.2",
  "vocab_size": 30522
}
```

```python
with torch.no_grad():
    outputs = model(**batch_input)
    print(outputs)
    scores = F.softmax(outputs.logits, dim=1)
    print(scores)
    labels = torch.argmax(scores, dim=1)
    print(labels)
    labels = [model.config.id2label[i] for i in labels.tolist()]
    print(labels)
```

```
SequenceClassifierOutput(loss=None, logits=tensor([[-3.4620,  3.6118],
        [ 4.7508, -3.7899],
        [-4.1938,  4.5566]]), hidden_states=None, attentions=None)
tensor([[8.4632e-04, 9.9915e-01],
        [9.9980e-01, 1.9531e-04],
        [1.5837e-04, 9.9984e-01]])
tensor([1, 0, 1])
['POSITIVE', 'NEGATIVE', 'POSITIVE']
```
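One detail worth noting: softmax is monotonic, so taking the argmax of the raw logits gives the same predicted labels as taking it after softmax; softmax is only needed when you want probability-like scores. A minimal sketch, reusing `model`, `batch_input`, `torch`, and `F` from the cells above:

```python
with torch.no_grad():
    logits = model(**batch_input).logits

# argmax over logits == argmax over softmax(logits)
assert torch.equal(torch.argmax(logits, dim=1),
                   torch.argmax(F.softmax(logits, dim=1), dim=1))
```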
### 3. Parsing the output

The cell above already walks the full parsing chain: raw `logits` -> `F.softmax` for per-class probabilities -> `torch.argmax` for the predicted class ids -> `model.config.id2label` for the human-readable labels. The notebook re-runs the identical cell under this heading and produces the same result:

```
tensor([1, 0, 1])
['POSITIVE', 'NEGATIVE', 'POSITIVE']
```

diff --git a/hugface/basics.py b/hugface/basics.py
new file mode 100644
index 0000000..acba6ff
--- /dev/null
+++ b/hugface/basics.py

```python
import torch
import torch.nn.functional as F
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_name = 'distilbert-base-uncased-finetuned-sst-2-english'
# model_name = 'bert-base-uncased'

model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

test_sentences = ['today is not that bad', 'today is so bad']

# high-level alternative to the manual steps below:
# clf = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
# res = clf(test_sentences)
# print(res)

# pad every sequence to the model's maximum length (512)
batch = tokenizer(test_sentences, padding='max_length', truncation=True,
                  max_length=512, return_tensors='pt')

with torch.no_grad():
    outputs = model(**batch)
    print(outputs)
    scores = F.softmax(outputs.logits, dim=1)
    labels = torch.argmax(scores, dim=1)
    labels = [model.config.id2label[i] for i in labels.tolist()]
    print(labels)
```
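The commented-out lines in basics.py point at the high-level route: `pipeline("sentiment-analysis", ...)` bundles tokenization, the forward pass, softmax, and the `id2label` lookup into a single call. A sketch of that route, reusing `model` and `tokenizer` from the script; the scores shown are what the manual path produced above, so the pipeline output should match them up to rounding:

```python
from transformers import pipeline

clf = pipeline('sentiment-analysis', model=model, tokenizer=tokenizer)
print(clf(['today is not that bad', 'today is so bad']))
# e.g. [{'label': 'POSITIVE', 'score': 0.9991...},
#       {'label': 'NEGATIVE', 'score': 0.9998...}]
```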
