Diffstat (limited to 'hugface')
-rw-r--r--  hugface/01_tokenizer_sentiment_analysis.ipynb  485
-rw-r--r--  hugface/basics.py                                33
2 files changed, 518 insertions(+), 0 deletions(-)
diff --git a/hugface/01_tokenizer_sentiment_analysis.ipynb b/hugface/01_tokenizer_sentiment_analysis.ipynb
new file mode 100644
index 0000000..5d3d36f
--- /dev/null
+++ b/hugface/01_tokenizer_sentiment_analysis.ipynb
@@ -0,0 +1,485 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1. tokenizer, 构造输入"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- tokenizer, model: 相匹配,tokenizer outputs => model input\n",
+ "- Auto\\*Tokenizer, AutoModel\\*:Generic type\n",
+ "- tokenizer:服务于 model input\n",
+ " - len(input_ids) == len(attention_mask)\n",
+ " - tokenizer(test_senteces[0], ): tokenizer.\\_\\_call\\_\\_:encode\n",
+ " - tokenizer.encode == tokenizer.tokenize + tokenizer.convert_tokens_to_ids\n",
+ " - tokenizer.decode\n",
+ " - tokenizer 工作的原理其实就是 tokenizer.vocab:字典,存储了 token => id 的映射关系\n",
+ " - tokenizer.special_tokens_map\n",
+ " - attention mask 与 padding 相匹配;"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-19T02:35:23.828856Z",
+ "start_time": "2022-06-19T02:35:23.825866Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "test_senteces = ['today is not that bad', 'today is so bad', 'so good']\n",
+ "model_name = 'distilbert-base-uncased-finetuned-sst-2-english'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-19T02:20:03.633366Z",
+ "start_time": "2022-06-19T02:19:58.789769Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/chunhuizhang/anaconda3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
+ " return f(*args, **kwds)\n"
+ ]
+ }
+ ],
+ "source": [
+ "from transformers import AutoTokenizer, AutoModelForSequenceClassification"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-19T02:21:21.136712Z",
+ "start_time": "2022-06-19T02:21:07.860869Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
+ "model = AutoModelForSequenceClassification.from_pretrained(model_name)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-19T02:23:20.499775Z",
+ "start_time": "2022-06-19T02:23:20.495749Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "batch_input = tokenizer(test_senteces, truncation=True, padding=True, return_tensors='pt')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-19T02:23:51.629797Z",
+ "start_time": "2022-06-19T02:23:51.624465Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'input_ids': tensor([[ 101, 2651, 2003, 2025, 2008, 2919, 102],\n",
+ " [ 101, 2651, 2003, 2061, 2919, 102, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1],\n",
+ " [1, 1, 1, 1, 1, 1, 0]])}"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "batch_input"
+ ]
+ },
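+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A quick check of the mask/padding point above (a sketch; `pad_token_id` is 0 for this checkpoint): the attention mask is 0 exactly where the input ids are `[PAD]`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "# masked-out positions coincide exactly with the [PAD] positions\n",
+ "assert torch.equal(batch_input['attention_mask'] == 0,\n",
+ "                   batch_input['input_ids'] == tokenizer.pad_token_id)"
+ ]
+ },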
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-19T02:24:27.711057Z",
+ "start_time": "2022-06-19T02:24:27.705957Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'input_ids': [101, 2651, 2003, 2025, 2008, 2919, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer(test_senteces[0], )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-19T02:26:02.232367Z",
+ "start_time": "2022-06-19T02:26:02.227981Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[101, 2651, 2003, 2025, 2008, 2919, 102]"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.encode(test_senteces[0], )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-19T02:27:01.802959Z",
+ "start_time": "2022-06-19T02:27:01.798355Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[2651, 2003, 2025, 2008, 2919]"
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.convert_tokens_to_ids(tokenizer.tokenize(test_senteces[0]))"
+ ]
+ },
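+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The intermediate token strings themselves; note that `encode` additionally wraps them in `[CLS]`/`[SEP]`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# wordpiece tokens before they are mapped to ids\n",
+ "tokenizer.tokenize(test_sentences[0])"
+ ]
+ },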
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-19T02:28:38.186096Z",
+ "start_time": "2022-06-19T02:28:38.180739Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'[CLS] today is not that bad [SEP]'"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.decode([101, 2651, 2003, 2025, 2008, 2919, 102])"
+ ]
+ },
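+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As noted above, the mapping behind all of this is `tokenizer.vocab`, a plain token => id dict. A minimal sketch ('today' is a single wordpiece in this vocab):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# vocab size matches model.config.vocab_size (30522)\n",
+ "print(len(tokenizer.vocab))\n",
+ "# same id (2651) as seen in the input_ids above\n",
+ "print(tokenizer.vocab['today'])"
+ ]
+ },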
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-19T02:31:37.843855Z",
+ "start_time": "2022-06-19T02:31:37.838167Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "dict_values(['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'])"
+ ]
+ },
+ "execution_count": 28,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.special_tokens_map.values()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-19T02:31:17.531347Z",
+ "start_time": "2022-06-19T02:31:17.525170Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[100, 102, 0, 101, 103]"
+ ]
+ },
+ "execution_count": 27,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.convert_tokens_to_ids([special for special in tokenizer.special_tokens_map.values()])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-19T02:37:00.609769Z",
+ "start_time": "2022-06-19T02:37:00.605463Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "batch_input = tokenizer(test_senteces, truncation=True, padding=True, return_tensors='pt')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2. model,调用模型"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-19T02:38:19.070764Z",
+ "start_time": "2022-06-19T02:38:19.068147Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn.functional as F"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-19T02:39:56.465031Z",
+ "start_time": "2022-06-19T02:39:56.459505Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DistilBertConfig {\n",
+ " \"_name_or_path\": \"distilbert-base-uncased-finetuned-sst-2-english\",\n",
+ " \"activation\": \"gelu\",\n",
+ " \"architectures\": [\n",
+ " \"DistilBertForSequenceClassification\"\n",
+ " ],\n",
+ " \"attention_dropout\": 0.1,\n",
+ " \"dim\": 768,\n",
+ " \"dropout\": 0.1,\n",
+ " \"finetuning_task\": \"sst-2\",\n",
+ " \"hidden_dim\": 3072,\n",
+ " \"id2label\": {\n",
+ " \"0\": \"NEGATIVE\",\n",
+ " \"1\": \"POSITIVE\"\n",
+ " },\n",
+ " \"initializer_range\": 0.02,\n",
+ " \"label2id\": {\n",
+ " \"NEGATIVE\": 0,\n",
+ " \"POSITIVE\": 1\n",
+ " },\n",
+ " \"max_position_embeddings\": 512,\n",
+ " \"model_type\": \"distilbert\",\n",
+ " \"n_heads\": 12,\n",
+ " \"n_layers\": 6,\n",
+ " \"output_past\": true,\n",
+ " \"pad_token_id\": 0,\n",
+ " \"qa_dropout\": 0.1,\n",
+ " \"seq_classif_dropout\": 0.2,\n",
+ " \"sinusoidal_pos_embds\": false,\n",
+ " \"tie_weights_\": true,\n",
+ " \"transformers_version\": \"4.11.2\",\n",
+ " \"vocab_size\": 30522\n",
+ "}"
+ ]
+ },
+ "execution_count": 46,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-19T02:40:46.821090Z",
+ "start_time": "2022-06-19T02:40:46.723297Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "SequenceClassifierOutput(loss=None, logits=tensor([[-3.4620, 3.6118],\n",
+ " [ 4.7508, -3.7899],\n",
+ " [-4.1938, 4.5566]]), hidden_states=None, attentions=None)\n",
+ "tensor([[8.4632e-04, 9.9915e-01],\n",
+ " [9.9980e-01, 1.9531e-04],\n",
+ " [1.5837e-04, 9.9984e-01]])\n",
+ "tensor([1, 0, 1])\n",
+ "['POSITIVE', 'NEGATIVE', 'POSITIVE']\n"
+ ]
+ }
+ ],
+ "source": [
+ "with torch.no_grad():\n",
+ " outputs = model(**batch_input)\n",
+ " print(outputs)\n",
+ " scores = F.softmax(outputs.logits, dim=1)\n",
+ " print(scores)\n",
+ " labels = torch.argmax(scores, dim=1)\n",
+ " print(labels)\n",
+ " labels = [model.config.id2label[id] for id in labels.tolist()]\n",
+ " print(labels)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-19T02:10:28.942776Z",
+ "start_time": "2022-06-19T02:10:28.939489Z"
+ }
+ },
+ "source": [
+ "### 3. parse output,输出解析"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-19T02:40:57.429808Z",
+ "start_time": "2022-06-19T02:40:57.328151Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "SequenceClassifierOutput(loss=None, logits=tensor([[-3.4620, 3.6118],\n",
+ " [ 4.7508, -3.7899],\n",
+ " [-4.1938, 4.5566]]), hidden_states=None, attentions=None)\n",
+ "tensor([[8.4632e-04, 9.9915e-01],\n",
+ " [9.9980e-01, 1.9531e-04],\n",
+ " [1.5837e-04, 9.9984e-01]])\n",
+ "tensor([1, 0, 1])\n",
+ "['POSITIVE', 'NEGATIVE', 'POSITIVE']\n"
+ ]
+ }
+ ],
+ "source": [
+ "with torch.no_grad():\n",
+ " outputs = model(**batch_input)\n",
+ " print(outputs)\n",
+ " scores = F.softmax(outputs.logits, dim=1)\n",
+ " print(scores)\n",
+ " labels = torch.argmax(scores, dim=1)\n",
+ " print(labels)\n",
+ " labels = [model.config.id2label[id] for id in labels.tolist()]\n",
+ " print(labels)"
+ ]
+ }
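+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The high-level `pipeline` API wraps exactly these steps (tokenize, forward, softmax, argmax, id2label); a minimal sketch of the equivalent call:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import pipeline\n",
+ "\n",
+ "clf = pipeline('sentiment-analysis', model=model, tokenizer=tokenizer)\n",
+ "clf(test_sentences)"
+ ]
+ }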
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/hugface/basics.py b/hugface/basics.py
new file mode 100644
index 0000000..acba6ff
--- /dev/null
+++ b/hugface/basics.py
@@ -0,0 +1,33 @@
+import torch
+import torch.nn.functional as F
+
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+model_name = 'distilbert-base-uncased-finetuned-sst-2-english'
+# model_name = 'bert-base-uncased'
+
+model = AutoModelForSequenceClassification.from_pretrained(model_name)
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+test_sentences = ['today is not that bad', 'today is so bad']
+
+# Equivalent high-level API:
+# from transformers import pipeline
+# clf = pipeline('sentiment-analysis', model=model, tokenizer=tokenizer)
+# print(clf(test_sentences))
+
+# padding='max_length' pads every sequence to max_length (512 here);
+# padding=True would pad only to the longest sequence in the batch.
+batch = tokenizer(test_sentences, padding='max_length', truncation=True, max_length=512, return_tensors='pt')
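+
+# Sanity check (a sketch): fixed-width padding yields one length-512 row per sentence.
+print(batch['input_ids'].shape)  # torch.Size([2, 512])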
+
+with torch.no_grad():
+    outputs = model(**batch)
+    print(outputs)
+    # logits => probabilities => argmax index => label string
+    scores = F.softmax(outputs.logits, dim=1)
+    labels = torch.argmax(scores, dim=1)
+    labels = [model.config.id2label[label_id] for label_id in labels.tolist()]
+    print(labels)