summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--hugface/02_tokenizer_encode_plus_token_type_ids.ipynb617
1 files changed, 617 insertions, 0 deletions
diff --git a/hugface/02_tokenizer_encode_plus_token_type_ids.ipynb b/hugface/02_tokenizer_encode_plus_token_type_ids.ipynb
new file mode 100644
index 0000000..06a0a58
--- /dev/null
+++ b/hugface/02_tokenizer_encode_plus_token_type_ids.ipynb
@@ -0,0 +1,617 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1. 加载模型"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 80,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T14:56:22.932982Z",
+ "start_time": "2022-06-20T14:56:22.930078Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from transformers import BertTokenizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 81,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T14:56:31.483548Z",
+ "start_time": "2022-06-20T14:56:31.480598Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "model_name = 'bert-base-uncased'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 82,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T14:56:50.395407Z",
+ "start_time": "2022-06-20T14:56:42.028614Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "tokenizer = BertTokenizer.from_pretrained(model_name)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 83,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T14:56:51.186108Z",
+ "start_time": "2022-06-20T14:56:51.181772Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "PreTrainedTokenizer(name_or_path='bert-base-uncased', vocab_size=30522, model_max_len=512, is_fast=False, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})"
+ ]
+ },
+ "execution_count": 83,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 85,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T14:57:32.114749Z",
+ "start_time": "2022-06-20T14:57:32.109702Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'unk_token': '[UNK]',\n",
+ " 'sep_token': '[SEP]',\n",
+ " 'pad_token': '[PAD]',\n",
+ " 'cls_token': '[CLS]',\n",
+ " 'mask_token': '[MASK]'}"
+ ]
+ },
+ "execution_count": 85,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.special_tokens_map"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 86,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T14:58:08.894192Z",
+ "start_time": "2022-06-20T14:58:08.891071Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "special_tokens = list(tokenizer.special_tokens_map.values())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 87,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T14:58:11.659126Z",
+ "start_time": "2022-06-20T14:58:11.653540Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']"
+ ]
+ },
+ "execution_count": 87,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "special_tokens"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 88,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T14:58:28.232023Z",
+ "start_time": "2022-06-20T14:58:28.227861Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[100, 102, 0, 101, 103]"
+ ]
+ },
+ "execution_count": 88,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.convert_tokens_to_ids(special_tokens)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 89,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T14:59:20.727082Z",
+ "start_time": "2022-06-20T14:59:20.722891Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[101, 100, 102, 0, 101, 103, 102]"
+ ]
+ },
+ "execution_count": 89,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.encode(special_tokens)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 90,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T14:59:48.802903Z",
+ "start_time": "2022-06-20T14:59:48.798063Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'[CLS] [UNK] [SEP] [PAD] [CLS] [MASK] [SEP]'"
+ ]
+ },
+ "execution_count": 90,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.decode([101, 100, 102, 0, 101, 103, 102])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2. 认识文本语料"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- newsgroups_train.DESCR\n",
+ "- newsgroups_train.data\n",
+ "- newsgroups_train.target\n",
+ "- newsgroups_train.target_names"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 91,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T15:00:35.577160Z",
+ "start_time": "2022-06-20T15:00:35.574035Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from sklearn.datasets import fetch_20newsgroups"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 92,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T15:00:42.921294Z",
+ "start_time": "2022-06-20T15:00:42.653867Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "newsgroups_train = fetch_20newsgroups(subset='train')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 95,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T15:01:21.114905Z",
+ "start_time": "2022-06-20T15:01:21.110446Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "list"
+ ]
+ },
+ "execution_count": 95,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "type(newsgroups_train.data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 96,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T15:01:28.937567Z",
+ "start_time": "2022-06-20T15:01:28.933964Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "11314"
+ ]
+ },
+ "execution_count": 96,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(newsgroups_train.data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 98,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T15:01:40.160373Z",
+ "start_time": "2022-06-20T15:01:40.155860Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "11314"
+ ]
+ },
+ "execution_count": 98,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(newsgroups_train.target) "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 99,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T15:01:57.389065Z",
+ "start_time": "2022-06-20T15:01:57.385988Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from collections import Counter"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 100,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T15:02:03.345302Z",
+ "start_time": "2022-06-20T15:02:03.337241Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Counter({7: 594,\n",
+ " 4: 578,\n",
+ " 1: 584,\n",
+ " 14: 593,\n",
+ " 16: 546,\n",
+ " 13: 594,\n",
+ " 3: 590,\n",
+ " 2: 591,\n",
+ " 8: 598,\n",
+ " 19: 377,\n",
+ " 6: 585,\n",
+ " 0: 480,\n",
+ " 12: 591,\n",
+ " 5: 593,\n",
+ " 10: 600,\n",
+ " 9: 597,\n",
+ " 15: 599,\n",
+ " 17: 564,\n",
+ " 18: 465,\n",
+ " 11: 595})"
+ ]
+ },
+ "execution_count": 100,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "Counter(newsgroups_train.target)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 101,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T15:02:20.691128Z",
+ "start_time": "2022-06-20T15:02:20.686558Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['alt.atheism',\n",
+ " 'comp.graphics',\n",
+ " 'comp.os.ms-windows.misc',\n",
+ " 'comp.sys.ibm.pc.hardware',\n",
+ " 'comp.sys.mac.hardware',\n",
+ " 'comp.windows.x',\n",
+ " 'misc.forsale',\n",
+ " 'rec.autos',\n",
+ " 'rec.motorcycles',\n",
+ " 'rec.sport.baseball',\n",
+ " 'rec.sport.hockey',\n",
+ " 'sci.crypt',\n",
+ " 'sci.electronics',\n",
+ " 'sci.med',\n",
+ " 'sci.space',\n",
+ " 'soc.religion.christian',\n",
+ " 'talk.politics.guns',\n",
+ " 'talk.politics.mideast',\n",
+ " 'talk.politics.misc',\n",
+ " 'talk.religion.misc']"
+ ]
+ },
+ "execution_count": 101,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "newsgroups_train.target_names"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3. tokenizer 补充"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T14:34:05.319605Z",
+ "start_time": "2022-06-20T14:34:05.303549Z"
+ }
+ },
+ "source": [
+ "- input_ids, attention_mask\n",
+ " - mask:bert 另外一个预训练任务,mlm;\n",
+ "- encode_plus, token_type_ids\n",
+ " - token_type_ids: 0:表示第一句,1:表示第二句;可以通过 tokenizer()(tokenizer.\\_\\_call\\_\\_:都是0的);也可以通过 encode_plus 生成/返回(前一句为0,后一句为1);\n",
+ " - 句子对,一般使用在 nsp(next sentence predict,bert 预训练任务)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 102,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T15:02:48.909751Z",
+ "start_time": "2022-06-20T15:02:48.906613Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "test_news = newsgroups_train.data[:3]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 107,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T15:03:04.579584Z",
+ "start_time": "2022-06-20T15:03:04.575396Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "1981"
+ ]
+ },
+ "execution_count": 107,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(test_news[2])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 115,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T15:07:13.117543Z",
+ "start_time": "2022-06-20T15:07:13.108526Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'input_ids': [101, 2013, 1024, 3393, 2099, 2595, 3367, 1030, 11333, 2213, 1012, 8529, 2094, 1012, 3968, 2226, 1006, 2073, 1005, 1055, 2026, 2518, 1007, 3395, 1024, 2054, 2482, 2003, 2023, 999, 1029, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}"
+ ]
+ },
+ "execution_count": 115,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# single sentence 级别的\n",
+ "tokenizer(test_news[0], truncation=True, max_length=32)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 118,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T15:08:30.245910Z",
+ "start_time": "2022-06-20T15:08:30.231817Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "{'input_ids': [101, 2013, 1024, 3393, 2099, 2595, 3367, 1030, 11333, 2213, 1012, 8529, 2094, 1012, 3968, 2226, 102, 2013, 1024, 3124, 5283, 2080, 1030, 9806, 1012, 1057, 1012, 2899, 1012, 3968, 2226, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}"
+ ]
+ },
+ "execution_count": 118,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# sentence pair 级别\n",
+ "tokenizer.encode_plus(text=test_news[0], text_pair=test_news[1], max_length=32, truncation=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 119,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-20T15:08:51.172868Z",
+ "start_time": "2022-06-20T15:08:51.167536Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'[CLS] from : lerxst @ wam. umd. edu [SEP] from : guykuo @ carson. u. washington. edu [SEP]'"
+ ]
+ },
+ "execution_count": 119,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.decode([101, 2013, 1024, 3393, 2099, 2595, 3367, 1030, 11333, 2213, 1012, 8529, 2094, 1012, 3968, 2226, 102, 2013, 1024, 3124, 5283, 2080, 1030, 9806, 1012, 1057, 1012, 2899, 1012, 3968, 2226, 102])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}