{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2022-07-10T02:29:24.616302Z", "start_time": "2022-07-10T02:29:18.081012Z" } }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from transformers import BertModel, BertTokenizer" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2022-07-10T02:29:41.979562Z", "start_time": "2022-07-10T02:29:24.618327Z" }, "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] } ], "source": [ "model_name = 'bert-base-uncased'\n", "\n", "tokenizer = BertTokenizer.from_pretrained(model_name)\n", "model = BertModel.from_pretrained(model_name, output_hidden_states=True)" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2022-07-10T02:23:46.325851Z", "start_time": "2022-07-10T02:23:46.322584Z" } }, "source": [ "### 1. 
input" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2022-07-10T02:43:56.693361Z", "start_time": "2022-07-10T02:43:56.691053Z" } }, "outputs": [], "source": [ "text = \"After stealing money from the bank vault, the bank robber was seen \" \\\n", "    \"fishing on the Mississippi river bank.\"" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2022-07-10T02:44:09.053456Z", "start_time": "2022-07-10T02:44:09.049462Z" } }, "outputs": [], "source": [ "token_input = tokenizer(text, return_tensors='pt')" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2022-07-10T02:44:10.582118Z", "start_time": "2022-07-10T02:44:10.558526Z" } }, "outputs": [ { "data": { "text/plain": [ "{'input_ids': tensor([[ 101, 2044, 11065, 2769, 2013, 1996, 2924, 11632, 1010, 1996,\n", " 2924, 27307, 2001, 2464, 5645, 2006, 1996, 5900, 2314, 2924,\n", " 1012, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "token_input" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2022-07-10T02:45:01.400030Z", "start_time": "2022-07-10T02:45:01.395465Z" } }, "outputs": [ { "data": { "text/plain": [ "(tensor([[ 101, 2044, 11065, 2769, 2013, 1996, 2924, 11632, 1010, 1996,\n", " 2924, 27307, 2001, 2464, 5645, 2006, 1996, 5900, 2314, 2924,\n", " 1012, 102]]), torch.Size([1, 22]))" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "token_input['input_ids'], token_input['input_ids'].shape" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "- batch_size = 1: there is only one sentence, and its sequence length is 22 (no truncation or padding applied)\n", "- the tokenizer wraps the sentence with the special tokens [CLS] (id 101) and [SEP] (id 102)" ] },
{ "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2022-07-10T02:23:36.420705Z", "start_time": "2022-07-10T02:23:36.418239Z" } }, "source": [ "### 2. model forward" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "- forward pass\n", "    - embeddings => encoder => pooler" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2022-07-10T02:46:02.848067Z", "start_time": "2022-07-10T02:46:02.703744Z" } }, "outputs": [], "source": [ "model.eval()\n", "with torch.no_grad():\n", "    outputs = model(**token_input)" ] },
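{ "cell_type": "markdown", "metadata": {}, "source": [ "The fields of `outputs` can also be read by name instead of by position. This is a minimal sketch assuming a transformers 4.x-style API in which the model returns a `ModelOutput` (`return_dict=True` by default), so the attributes `last_hidden_state`, `pooler_output` and `hidden_states` are available; they correspond to `outputs[0]`, `outputs[1]` and `outputs[2]` used below." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# equivalent named access to the three fields inspected by index in the next section\n", "# (assumes the forward call above returned a ModelOutput, i.e. return_dict=True)\n", "print(outputs.last_hidden_state.shape)   # batch_size x seq_len x hidden_size\n", "print(outputs.pooler_output.shape)       # batch_size x hidden_size\n", "print(len(outputs.hidden_states))        # 1 embedding output + 12 encoder layers = 13" ] },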
{ "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2022-07-10T02:23:44.578754Z", "start_time": "2022-07-10T02:23:44.576132Z" } }, "source": [ "### 3. output" ] },
{ "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2022-07-10T02:40:09.069362Z", "start_time": "2022-07-10T02:40:09.065437Z" } }, "source": [ "- len(outputs) == 3\n", "- outputs[0]\n", "    - last_hidden_state, shape: batch_size\\*seq_len\\*hidden_size (1\\*22\\*768)\n", "- outputs[1]\n", "    - pooler_output, shape: batch_size\\*hidden_size (1\\*768)\n", "    - last-layer hidden state of the first token of the sequence (the classification token [CLS]), further processed by a Linear layer and a Tanh activation (the pooler)\n", "- outputs[2] (requires model.config.output_hidden_states = True)\n", "    - type: tuple\n", "    - one tensor for the output of the embeddings, plus one for the output of each of the 12 encoder layers\n", "    - (1+12)\\*(batch_size\\*seq_len\\*hidden_size) = 13\\*1\\*22\\*768\n", "\n", "\n", "- outputs[0] == outputs[2][-1]\n", "- outputs[1] == model.pooler(outputs[2][-1])\n", "- outputs[2][0] == model.embeddings(token_input['input_ids'], token_input['token_type_ids'])" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2022-07-10T02:46:38.864887Z", "start_time": "2022-07-10T02:46:38.861110Z" } }, "outputs": [ { "data": { "text/plain": [ "3" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(outputs)" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": { "ExecuteTime": { "end_time": "2022-07-10T02:48:44.742132Z", "start_time": "2022-07-10T02:48:44.736806Z" } }, "outputs": [ { "data": { "text/plain": [ "(tuple, 13)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(outputs[2]), len(outputs[2])" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": { "ExecuteTime": { "end_time": "2022-07-10T02:50:23.068317Z", "start_time": "2022-07-10T02:50:23.059660Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[[True, True, True, ..., True, True, True],\n", " [True, True, True, ..., True, True, True],\n", " [True, True, True, ..., True, True, True],\n", " ...,\n", " [True, True, True, ..., True, True, True],\n", " [True, True, True, ..., True, True, True],\n", " [True, True, True, ..., True, True, True]]])" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "outputs[0] == outputs[2][-1]" ] },
{ "cell_type": "code", "execution_count": 25, "metadata": { "ExecuteTime": { "end_time": "2022-07-10T02:54:20.023961Z", "start_time": "2022-07-10T02:54:20.014364Z" }, "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "tensor([[[True, True, True, ..., True, True, True],\n", " [True, True, True, ..., True, True, True],\n", " [True, True, True, ..., True, True, True],\n", " ...,\n", " [True, True, True, ..., True, True, True],\n", " [True, True, True, ..., True, True, True],\n", " [True, True, True, ..., True, True, True]]])" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "outputs[2][0] == model.embeddings(token_input['input_ids'], token_input['token_type_ids'])" ] },
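{ "cell_type": "markdown", "metadata": {}, "source": [ "The remaining claim, `outputs[1] == model.pooler(outputs[2][-1])`, can be checked the same way. A minimal sketch, assuming `model.pooler` is the standard `BertPooler` submodule (a Linear layer plus Tanh applied at the [CLS] position); `torch.allclose` is used rather than `==` to be robust to floating-point differences." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# pooler_output should equal the pooler applied to the last hidden state\n", "with torch.no_grad():\n", "    recomputed_pooled = model.pooler(outputs[2][-1])\n", "print(torch.allclose(outputs[1], recomputed_pooled))" ] },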
{ "cell_type": "code", "execution_count": 27, "metadata": { "ExecuteTime": { "end_time": "2022-07-10T02:56:10.311946Z", "start_time": "2022-07-10T02:56:10.307742Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 torch.Size([1, 22, 768])\n", "1 torch.Size([1, 22, 768])\n", "2 torch.Size([1, 22, 768])\n", "3 torch.Size([1, 22, 768])\n", "4 torch.Size([1, 22, 768])\n", "5 torch.Size([1, 22, 768])\n", "6 torch.Size([1, 22, 768])\n", "7 torch.Size([1, 22, 768])\n", "8 torch.Size([1, 22, 768])\n", "9 torch.Size([1, 22, 768])\n", "10 torch.Size([1, 22, 768])\n", "11 torch.Size([1, 22, 768])\n", "12 torch.Size([1, 22, 768])\n" ] } ], "source": [ "for i in range(len(outputs[2])):\n", "    print(i, outputs[2][i].shape)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }