author     zhang <zch921005@126.com>    2022-07-04 21:44:58 +0800
committer  zhang <zch921005@126.com>    2022-07-04 21:44:58 +0800
commit     b116d9f9a5e0aa2bcf48d86fdd9bba92ecf1b526 (patch)
tree       5b4fcdd51a71d5f9763796d9a646d5d73ddc4b60 /fine_tune
parent     e3c114a02cb5e704324459a7a0eb2601fdd937e2 (diff)
torch.no_grad vs. requires_grad
Diffstat (limited to 'fine_tune')
-rw-r--r--  fine_tune/bert/bert_fine_tune.ipynb  850
1 file changed, 850 insertions, 0 deletions
diff --git a/fine_tune/bert/bert_fine_tune.ipynb b/fine_tune/bert/bert_fine_tune.ipynb
new file mode 100644
index 0000000..6554ba0
--- /dev/null
+++ b/fine_tune/bert/bert_fine_tune.ipynb
@@ -0,0 +1,850 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-25T08:55:39.339410Z",
+ "start_time": "2022-06-25T08:55:38.522277Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import re\n",
+ "from tqdm import tqdm\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-25T08:55:45.550205Z",
+ "start_time": "2022-06-25T08:55:45.521302Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ " # Load data and set labels\n",
+ "data_complaint = pd.read_csv('data/complaint1700.csv')\n",
+ "data_complaint['label'] = 0\n",
+ "data_non_complaint = pd.read_csv('data/noncomplaint1700.csv')\n",
+ "data_non_complaint['label'] = 1\n",
+ "\n",
+ "# Concatenate complaining and non-complaining data\n",
+ "data = pd.concat([data_complaint, data_non_complaint], axis=0).reset_index(drop=True)\n",
+ "\n",
+ "# Drop 'airline' column\n",
+ "data.drop(['airline'], inplace=True, axis=1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-25T08:55:53.310710Z",
+ "start_time": "2022-06-25T08:55:53.295841Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>id</th>\n",
+ " <th>tweet</th>\n",
+ " <th>label</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>2579</th>\n",
+ " <td>82091</td>\n",
+ " <td>@AlaskaAir @RSherman_25 Thank you so much!! Ca...</td>\n",
+ " <td>1</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>657</th>\n",
+ " <td>147575</td>\n",
+ " <td>@DeltaAssist hi. I lost my sunglasses on a fli...</td>\n",
+ " <td>0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>1971</th>\n",
+ " <td>23890</td>\n",
+ " <td>Flights to #PuertoRico booked on @JetBlue! Can...</td>\n",
+ " <td>1</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>3312</th>\n",
+ " <td>160070</td>\n",
+ " <td>@united Do you offer open-ended tickets? CLT-B...</td>\n",
+ " <td>1</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>1034</th>\n",
+ " <td>63946</td>\n",
+ " <td>So @AmericanAir I'm going to need you all to g...</td>\n",
+ " <td>0</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " id tweet label\n",
+ "2579 82091 @AlaskaAir @RSherman_25 Thank you so much!! Ca... 1\n",
+ "657 147575 @DeltaAssist hi. I lost my sunglasses on a fli... 0\n",
+ "1971 23890 Flights to #PuertoRico booked on @JetBlue! Can... 1\n",
+ "3312 160070 @united Do you offer open-ended tickets? CLT-B... 1\n",
+ "1034 63946 So @AmericanAir I'm going to need you all to g... 0"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data.sample(5)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-25T08:56:10.130909Z",
+ "start_time": "2022-06-25T08:56:08.960476Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "No GPU available, using the CPU instead.\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "\n",
+ "if torch.cuda.is_available(): \n",
+ " device = torch.device(\"cuda\")\n",
+ " print(f'There are {torch.cuda.device_count()} GPU(s) available.')\n",
+ " print('Device name:', torch.cuda.get_device_name(0))\n",
+ "\n",
+ "else:\n",
+ " print('No GPU available, using the CPU instead.')\n",
+ " device = torch.device(\"cpu\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-25T08:56:37.054745Z",
+ "start_time": "2022-06-25T08:56:37.051239Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def text_preprocessing(text):\n",
+ " \"\"\"\n",
+ " - Remove entity mentions (eg. '@united')\n",
+ " - Correct errors (eg. '&amp;' to '&')\n",
+ " @param text (str): a string to be processed.\n",
+ " @return text (Str): the processed string.\n",
+ " \"\"\"\n",
+ " # Remove '@name'\n",
+ " text = re.sub(r'(@.*?)[\\s]', ' ', text)\n",
+ "\n",
+ " # Replace '&amp;' with '&'\n",
+ " text = re.sub(r'&amp;', '&', text)\n",
+ "\n",
+ " # Remove trailing whitespace\n",
+ " text = re.sub(r'\\s+', ' ', text).strip()\n",
+ "\n",
+ " return text"
+ ]
+ },
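+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# A quick, illustrative check of `text_preprocessing` on a made-up tweet (the handle\n",
+    "# and wording below are hypothetical, not a row from the dataset): it should drop the\n",
+    "# '@united' mention, unescape '&amp;' and collapse the extra whitespace.\n",
+    "print(text_preprocessing('@united I lost my bag &amp; nobody answers.  Help?'))"
+   ]
+  },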
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-25T08:57:01.039579Z",
+ "start_time": "2022-06-25T08:57:00.483491Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from sklearn.model_selection import train_test_split\n",
+ "\n",
+ "X = data.tweet.values\n",
+ "y = data.label.values\n",
+ "\n",
+ "X_train, X_val, y_train, y_val =\\\n",
+ " train_test_split(X, y, test_size=0.1, random_state=2020)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-25T08:59:17.779189Z",
+ "start_time": "2022-06-25T08:59:17.759147Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Load test data\n",
+ "test_data = pd.read_csv('data/test_data.csv')\n",
+ "\n",
+ "# Keep important columns\n",
+ "test_data = test_data[['id', 'tweet']]\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-25T09:00:01.125406Z",
+ "start_time": "2022-06-25T08:59:48.736061Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from transformers import BertTokenizer\n",
+ "\n",
+ "# Load the BERT tokenizer\n",
+ "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-25T09:00:06.381689Z",
+ "start_time": "2022-06-25T09:00:01.928757Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Max length: 68\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Concatenate train data and test data\n",
+ "all_tweets = np.concatenate([data.tweet.values, test_data.tweet.values])\n",
+ "\n",
+ "# Encode our concatenated data\n",
+ "encoded_tweets = [tokenizer.encode(sent, add_special_tokens=True) for sent in all_tweets]\n",
+ "\n",
+ "# Find the maximum length\n",
+ "max_len = max([len(sent) for sent in encoded_tweets])\n",
+ "print('Max length: ', max_len)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-25T09:00:18.264259Z",
+ "start_time": "2022-06-25T09:00:18.261607Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "max_len = 64"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-25T09:00:34.288131Z",
+ "start_time": "2022-06-25T09:00:34.282880Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Create a function to tokenize a set of texts\n",
+ "def preprocessing_for_bert(data):\n",
+ " \"\"\"Perform required preprocessing steps for pretrained BERT.\n",
+ " @param data (np.array): Array of texts to be processed.\n",
+ " @return input_ids (torch.Tensor): Tensor of token ids to be fed to a model.\n",
+ " @return attention_masks (torch.Tensor): Tensor of indices specifying which\n",
+ " tokens should be attended to by the model.\n",
+ " \"\"\"\n",
+ " # Create empty lists to store outputs\n",
+ " input_ids = []\n",
+ " attention_masks = []\n",
+ "\n",
+ " # For every sentence...\n",
+ " for sent in data:\n",
+ " # `encode_plus` will:\n",
+ " # (1) Tokenize the sentence\n",
+ " # (2) Add the `[CLS]` and `[SEP]` token to the start and end\n",
+ " # (3) Truncate/Pad sentence to max length\n",
+ " # (4) Map tokens to their IDs\n",
+ " # (5) Create attention mask\n",
+ " # (6) Return a dictionary of outputs\n",
+ " encoded_sent = tokenizer.encode_plus(\n",
+ " text=text_preprocessing(sent), # Preprocess sentence\n",
+ " add_special_tokens=True, # Add `[CLS]` and `[SEP]`\n",
+ " max_length = max_len, # Max length to truncate/pad\n",
+ " pad_to_max_length=True, # Pad sentence to max length\n",
+ " #return_tensors='pt', # Return PyTorch tensor\n",
+ " return_attention_mask=True # Return attention mask\n",
+ " )\n",
+ " \n",
+ " # Add the outputs to the lists\n",
+ " input_ids.append(encoded_sent.get('input_ids'))\n",
+ " attention_masks.append(encoded_sent.get('attention_mask'))\n",
+ "\n",
+ " # Convert lists to tensors\n",
+ " input_ids = torch.tensor(input_ids)\n",
+ " attention_masks = torch.tensor(attention_masks)\n",
+ "\n",
+ " return input_ids, attention_masks"
+ ]
+ },
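+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# A minimal sketch of the non-deprecated form of the `encode_plus` call used in\n",
+    "# `preprocessing_for_bert`: `padding='max_length'` plus `truncation=True` replaces the\n",
+    "# deprecated `pad_to_max_length=True` and avoids the FutureWarning shown in the next\n",
+    "# cell's output. The sample tweet is hypothetical.\n",
+    "encoded = tokenizer.encode_plus(\n",
+    "    text=text_preprocessing('@united thanks &amp; see you soon'),\n",
+    "    add_special_tokens=True,\n",
+    "    max_length=max_len,\n",
+    "    padding='max_length',\n",
+    "    truncation=True,\n",
+    "    return_attention_mask=True\n",
+    ")\n",
+    "print(len(encoded['input_ids']), sum(encoded['attention_mask']))"
+   ]
+  },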
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-25T09:00:52.948339Z",
+ "start_time": "2022-06-25T09:00:51.094279Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n",
+ "/Users/chunhuizhang/anaconda3/lib/python3.6/site-packages/transformers/tokenization_utils_base.py:2217: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n",
+ " FutureWarning,\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Original: @united I'm having issues. Yesterday I rebooked for 24 hours after I was supposed to fly, now I can't log on &amp; check in. Can you help?\n",
+ "Token IDs: [101, 1045, 1005, 1049, 2383, 3314, 1012, 7483, 1045, 2128, 8654, 2098, 2005, 2484, 2847, 2044, 1045, 2001, 4011, 2000, 4875, 1010, 2085, 1045, 2064, 1005, 1056, 8833, 2006, 1004, 4638, 1999, 1012, 2064, 2017, 2393, 1029, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
+ "Tokenizing data...\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Print sentence 0 and its encoded token ids\n",
+ "token_ids = list(preprocessing_for_bert([X[0]])[0].squeeze().numpy())\n",
+ "print('Original: ', X[0])\n",
+ "print('Token IDs: ', token_ids)\n",
+ "\n",
+ "# Run function `preprocessing_for_bert` on the train set and the validation set\n",
+ "print('Tokenizing data...')\n",
+ "train_inputs, train_masks = preprocessing_for_bert(X_train)\n",
+ "val_inputs, val_masks = preprocessing_for_bert(X_val)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-25T09:01:08.103694Z",
+ "start_time": "2022-06-25T09:01:08.088666Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[ 101, 2054, 1005, ..., 0, 0, 0],\n",
+ " [ 101, 2054, 1996, ..., 0, 0, 0],\n",
+ " [ 101, 3524, 2054, ..., 0, 0, 0],\n",
+ " ...,\n",
+ " [ 101, 3294, 17203, ..., 0, 0, 0],\n",
+ " [ 101, 1045, 5223, ..., 0, 0, 0],\n",
+ " [ 101, 1998, 2009, ..., 0, 0, 0]])"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_inputs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-25T09:01:14.520761Z",
+ "start_time": "2022-06-25T09:01:14.516264Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[1, 1, 1, ..., 0, 0, 0],\n",
+ " [1, 1, 1, ..., 0, 0, 0],\n",
+ " [1, 1, 1, ..., 0, 0, 0],\n",
+ " ...,\n",
+ " [1, 1, 1, ..., 0, 0, 0],\n",
+ " [1, 1, 1, ..., 0, 0, 0],\n",
+ " [1, 1, 1, ..., 0, 0, 0]])"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_masks"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-25T09:01:26.298148Z",
+ "start_time": "2022-06-25T09:01:26.291846Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler\n",
+ "\n",
+ "# Convert other data types to torch.Tensor\n",
+ "train_labels = torch.tensor(y_train)\n",
+ "val_labels = torch.tensor(y_val)\n",
+ "\n",
+ "# For fine-tuning BERT, the authors recommend a batch size of 16 or 32.\n",
+ "batch_size = 32\n",
+ "\n",
+ "# Create the DataLoader for our training set\n",
+ "train_data = TensorDataset(train_inputs, train_masks, train_labels)\n",
+ "train_sampler = RandomSampler(train_data)\n",
+ "train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)\n",
+ "\n",
+ "# Create the DataLoader for our validation set\n",
+ "val_data = TensorDataset(val_inputs, val_masks, val_labels)\n",
+ "val_sampler = SequentialSampler(val_data)\n",
+ "val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=batch_size)"
+ ]
+ },
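+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# A quick sanity check (illustrative only): pull one batch from the training DataLoader\n",
+    "# and confirm the shapes are (batch_size, max_len) for the ids and masks and\n",
+    "# (batch_size,) for the labels.\n",
+    "b_ids, b_mask, b_lbls = next(iter(train_dataloader))\n",
+    "print(b_ids.shape, b_mask.shape, b_lbls.shape)"
+   ]
+  },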
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-25T09:01:39.358697Z",
+ "start_time": "2022-06-25T09:01:39.336610Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from transformers import BertModel\n",
+ "\n",
+ "# Create the BertClassfier class\n",
+ "class BertClassifier(nn.Module):\n",
+ " \"\"\"Bert Model for Classification Tasks.\n",
+ " \"\"\"\n",
+ " def __init__(self, freeze_bert=False):\n",
+ " \"\"\"\n",
+ " @param bert: a BertModel object\n",
+ " @param classifier: a torch.nn.Module classifier\n",
+ " @param freeze_bert (bool): Set `False` to fine-tune the BERT model\n",
+ " \"\"\"\n",
+ " super(BertClassifier, self).__init__()\n",
+ " # Specify hidden size of BERT, hidden size of our classifier, and number of labels\n",
+ " D_in, H, D_out = 768, 50, 2\n",
+ "\n",
+ " # Instantiate BERT model\n",
+ " self.bert = BertModel.from_pretrained('bert-base-uncased')\n",
+ "\n",
+ " # Instantiate an one-layer feed-forward classifier\n",
+ " self.classifier = nn.Sequential(\n",
+ " nn.Linear(D_in, H),\n",
+ " nn.ReLU(),\n",
+ " #nn.Dropout(0.5),\n",
+ " nn.Linear(H, D_out)\n",
+ " )\n",
+ "\n",
+ " # Freeze the BERT model\n",
+ " if freeze_bert:\n",
+ " for param in self.bert.parameters():\n",
+ " param.requires_grad = False\n",
+ " \n",
+ " def forward(self, input_ids, attention_mask):\n",
+ " \"\"\"\n",
+ " Feed input to BERT and the classifier to compute logits.\n",
+ " @param input_ids (torch.Tensor): an input tensor with shape (batch_size,\n",
+ " max_length)\n",
+ " @param attention_mask (torch.Tensor): a tensor that hold attention mask\n",
+ " information with shape (batch_size, max_length)\n",
+ " @return logits (torch.Tensor): an output tensor with shape (batch_size,\n",
+ " num_labels)\n",
+ " \"\"\"\n",
+ " # Feed input to BERT\n",
+ " outputs = self.bert(input_ids=input_ids,\n",
+ " attention_mask=attention_mask)\n",
+ " \n",
+ " # Extract the last hidden state of the token `[CLS]` for classification task\n",
+ " last_hidden_state_cls = outputs[0][:, 0, :]\n",
+ "\n",
+ " # Feed input to classifier to compute logits\n",
+ " logits = self.classifier(last_hidden_state_cls)\n",
+ "\n",
+ " return logits"
+ ]
+ },
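+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# A minimal sketch, on a toy tensor rather than on BERT, of the two gradient switches\n",
+    "# used in this notebook: `requires_grad=False` freezes a parameter everywhere (as\n",
+    "# `freeze_bert=True` does above), while `torch.no_grad()` only disables graph building\n",
+    "# inside its context (as `evaluate` does further below).\n",
+    "w = torch.randn(3, requires_grad=True)\n",
+    "\n",
+    "y = (w * 2).sum()\n",
+    "print(y.requires_grad)          # True: a normal forward pass records a graph\n",
+    "\n",
+    "with torch.no_grad():           # temporarily stop recording, e.g. during evaluation\n",
+    "    y_eval = (w * 2).sum()\n",
+    "print(y_eval.requires_grad)     # False: nothing to backpropagate through\n",
+    "\n",
+    "w.requires_grad_(False)         # exclude w from gradient tracking from now on\n",
+    "y_frozen = (w * 2).sum()\n",
+    "print(y_frozen.requires_grad)   # False: w is frozen, like BERT with freeze_bert=True"
+   ]
+  },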
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-25T09:02:30.655449Z",
+ "start_time": "2022-06-25T09:02:30.644656Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from transformers import AdamW, get_linear_schedule_with_warmup\n",
+ "\n",
+ "def initialize_model(epochs=4):\n",
+ " \"\"\"Initialize the Bert Classifier, the optimizer and the learning rate scheduler.\n",
+ " \"\"\"\n",
+ " # Instantiate Bert Classifier\n",
+ " bert_classifier = BertClassifier(freeze_bert=False)\n",
+ "\n",
+ " # Tell PyTorch to run the model on GPU\n",
+ " bert_classifier.to(device)\n",
+ "\n",
+ " # Create the optimizer\n",
+ " optimizer = AdamW(bert_classifier.parameters(),\n",
+ " lr=5e-5, # Default learning rate\n",
+ " eps=1e-8 # Default epsilon value\n",
+ " )\n",
+ "\n",
+ " # Total number of training steps\n",
+ " total_steps = len(train_dataloader) * epochs\n",
+ "\n",
+ " # Set up the learning rate scheduler\n",
+ " scheduler = get_linear_schedule_with_warmup(optimizer,\n",
+ " num_warmup_steps=0, # Default value\n",
+ " num_training_steps=total_steps)\n",
+ " return bert_classifier, optimizer, scheduler"
+ ]
+ },
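+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# A minimal sketch of the schedule configured above: with num_warmup_steps=0 the\n",
+    "# learning rate decays linearly from 5e-5 towards 0 over the training steps. A single\n",
+    "# dummy parameter stands in for the BERT weights, so no model needs to be loaded.\n",
+    "dummy_param = nn.Parameter(torch.zeros(1))\n",
+    "opt_sketch = AdamW([dummy_param], lr=5e-5, eps=1e-8)\n",
+    "steps_sketch = len(train_dataloader) * 2  # 2 epochs, matching the run below\n",
+    "sched_sketch = get_linear_schedule_with_warmup(opt_sketch,\n",
+    "                                               num_warmup_steps=0,\n",
+    "                                               num_training_steps=steps_sketch)\n",
+    "lrs = []\n",
+    "for _ in range(steps_sketch):\n",
+    "    opt_sketch.step()\n",
+    "    sched_sketch.step()\n",
+    "    lrs.append(opt_sketch.param_groups[0]['lr'])\n",
+    "print(lrs[0], lrs[steps_sketch // 2], lrs[-1])  # roughly 5e-5, 2.5e-5, 0.0"
+   ]
+  },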
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-25T09:12:53.857467Z",
+ "start_time": "2022-06-25T09:12:53.844730Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import random\n",
+ "import time\n",
+ "\n",
+ "# Specify loss function\n",
+ "loss_fn = nn.CrossEntropyLoss()\n",
+ "\n",
+ "def set_seed(seed_value=42):\n",
+ " \"\"\"Set seed for reproducibility.\n",
+ " \"\"\"\n",
+ " random.seed(seed_value)\n",
+ " np.random.seed(seed_value)\n",
+ " torch.manual_seed(seed_value)\n",
+ " torch.cuda.manual_seed_all(seed_value)\n",
+ "\n",
+ "def train(model, train_dataloader, val_dataloader=None, epochs=4, evaluation=False):\n",
+ " \"\"\"Train the BertClassifier model.\n",
+ " \"\"\"\n",
+ " # Start training loop\n",
+ " print(\"Start training...\\n\")\n",
+ " for epoch_i in range(epochs):\n",
+ " # =======================================\n",
+ " # Training\n",
+ " # =======================================\n",
+ " # Print the header of the result table\n",
+ " print(f\"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Val Loss':^10} | {'Val Acc':^9} | {'Elapsed':^9}\")\n",
+ " print(\"-\"*70)\n",
+ "\n",
+ " # Measure the elapsed time of each epoch\n",
+ " t0_epoch, t0_batch = time.time(), time.time()\n",
+ "\n",
+ " # Reset tracking variables at the beginning of each epoch\n",
+ " total_loss, batch_loss, batch_counts = 0, 0, 0\n",
+ "\n",
+ " # Put the model into the training mode\n",
+ " model.train()\n",
+ "\n",
+ " # For each batch of training data...\n",
+ " for step, batch in enumerate(train_dataloader):\n",
+ " batch_counts +=1\n",
+ " # Load batch to GPU\n",
+ " b_input_ids, b_attn_mask, b_labels = tuple(t.to(device) for t in batch)\n",
+ "\n",
+ " # Zero out any previously calculated gradients\n",
+ " model.zero_grad()\n",
+ "\n",
+ " # Perform a forward pass. This will return logits.\n",
+ " logits = model(b_input_ids, b_attn_mask)\n",
+ "\n",
+ " # Compute loss and accumulate the loss values\n",
+ " loss = loss_fn(logits, b_labels)\n",
+ " batch_loss += loss.item()\n",
+ " total_loss += loss.item()\n",
+ "\n",
+ " # Perform a backward pass to calculate gradients\n",
+ " loss.backward()\n",
+ "\n",
+ " # Clip the norm of the gradients to 1.0 to prevent \"exploding gradients\"\n",
+ " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
+ "\n",
+ " # Update parameters and the learning rate\n",
+ " optimizer.step()\n",
+ " scheduler.step()\n",
+ "\n",
+ " # Print the loss values and time elapsed for every 20 batches\n",
+ " if (step % 5 == 0 and step != 0) or (step == len(train_dataloader) - 1):\n",
+ " # Calculate time elapsed for 20 batches\n",
+ " time_elapsed = time.time() - t0_batch\n",
+ "\n",
+ " # Print training results\n",
+ " print(f\"{epoch_i + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {'-':^10} | {'-':^9} | {time_elapsed:^9.2f}\")\n",
+ "\n",
+ " # Reset batch tracking variables\n",
+ " batch_loss, batch_counts = 0, 0\n",
+ " t0_batch = time.time()\n",
+ "\n",
+ " # Calculate the average loss over the entire training data\n",
+ " avg_train_loss = total_loss / len(train_dataloader)\n",
+ "\n",
+ " print(\"-\"*70)\n",
+ " # =======================================\n",
+ " # Evaluation\n",
+ " # =======================================\n",
+ " if evaluation == True:\n",
+ " # After the completion of each training epoch, measure the model's performance\n",
+ " # on our validation set.\n",
+ " val_loss, val_accuracy = evaluate(model, val_dataloader)\n",
+ "\n",
+ " # Print performance over the entire training data\n",
+ " time_elapsed = time.time() - t0_epoch\n",
+ " \n",
+ " print(f\"{epoch_i + 1:^7} | {'-':^7} | {avg_train_loss:^12.6f} | {val_loss:^10.6f} | {val_accuracy:^9.2f} | {time_elapsed:^9.2f}\")\n",
+ " print(\"-\"*70)\n",
+ " print(\"\\n\")\n",
+ " \n",
+ " print(\"Training complete!\")\n",
+ "\n",
+ "\n",
+ "def evaluate(model, val_dataloader):\n",
+ " \"\"\"After the completion of each training epoch, measure the model's performance\n",
+ " on our validation set.\n",
+ " \"\"\"\n",
+ " # Put the model into the evaluation mode. The dropout layers are disabled during\n",
+ " # the test time.\n",
+ " model.eval()\n",
+ "\n",
+ " # Tracking variables\n",
+ " val_accuracy = []\n",
+ " val_loss = []\n",
+ "\n",
+ " # For each batch in our validation set...\n",
+ " for batch in val_dataloader:\n",
+ " # Load batch to GPU\n",
+ " b_input_ids, b_attn_mask, b_labels = tuple(t.to(device) for t in batch)\n",
+ "\n",
+ " # Compute logits\n",
+ " with torch.no_grad():\n",
+ " logits = model(b_input_ids, b_attn_mask)\n",
+ "\n",
+ " # Compute loss\n",
+ " loss = loss_fn(logits, b_labels)\n",
+ " val_loss.append(loss.item())\n",
+ "\n",
+ " # Get the predictions\n",
+ " preds = torch.argmax(logits, dim=1).flatten()\n",
+ "\n",
+ " # Calculate the accuracy rate\n",
+ " accuracy = (preds == b_labels).cpu().numpy().mean() * 100\n",
+ " val_accuracy.append(accuracy)\n",
+ "\n",
+ " # Compute the average accuracy and loss over the validation set.\n",
+ " val_loss = np.mean(val_loss)\n",
+ " val_accuracy = np.mean(val_accuracy)\n",
+ "\n",
+ " return val_loss, val_accuracy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-25T09:54:42.915823Z",
+ "start_time": "2022-06-25T09:12:58.201795Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias']\n",
+ "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Start training...\n",
+ "\n",
+ " Epoch | Batch | Train Loss | Val Loss | Val Acc | Elapsed \n",
+ "----------------------------------------------------------------------\n",
+ " 1 | 5 | 0.685621 | - | - | 79.83 \n",
+ " 1 | 10 | 0.669755 | - | - | 60.98 \n",
+ " 1 | 15 | 0.636234 | - | - | 60.81 \n",
+ " 1 | 20 | 0.603634 | - | - | 60.93 \n",
+ " 1 | 25 | 0.573587 | - | - | 60.90 \n",
+ " 1 | 30 | 0.594417 | - | - | 61.13 \n",
+ " 1 | 35 | 0.577772 | - | - | 60.90 \n",
+ " 1 | 40 | 0.564313 | - | - | 61.48 \n",
+ " 1 | 45 | 0.467151 | - | - | 61.49 \n",
+ " 1 | 50 | 0.477319 | - | - | 61.96 \n",
+ " 1 | 55 | 0.487433 | - | - | 61.16 \n",
+ " 1 | 60 | 0.407353 | - | - | 60.69 \n",
+ " 1 | 65 | 0.532023 | - | - | 60.32 \n",
+ " 1 | 70 | 0.529829 | - | - | 60.40 \n",
+ " 1 | 75 | 0.443143 | - | - | 60.19 \n",
+ " 1 | 80 | 0.432639 | - | - | 60.50 \n",
+ " 1 | 85 | 0.506726 | - | - | 60.32 \n",
+ " 1 | 90 | 0.442358 | - | - | 60.45 \n",
+ " 1 | 95 | 0.539845 | - | - | 55.92 \n",
+ "----------------------------------------------------------------------\n",
+ " 1 | - | 0.536889 | 0.426069 | 80.68 | 1208.44 \n",
+ "----------------------------------------------------------------------\n",
+ "\n",
+ "\n",
+ " Epoch | Batch | Train Loss | Val Loss | Val Acc | Elapsed \n",
+ "----------------------------------------------------------------------\n",
+ " 2 | 5 | 0.325858 | - | - | 82.61 \n",
+ " 2 | 10 | 0.307421 | - | - | 66.54 \n",
+ " 2 | 15 | 0.317675 | - | - | 64.03 \n",
+ " 2 | 20 | 0.326934 | - | - | 63.01 \n",
+ " 2 | 25 | 0.325949 | - | - | 63.21 \n",
+ " 2 | 30 | 0.341778 | - | - | 63.01 \n",
+ " 2 | 35 | 0.273992 | - | - | 63.04 \n",
+ " 2 | 40 | 0.247418 | - | - | 64.86 \n",
+ " 2 | 45 | 0.264468 | - | - | 64.09 \n",
+ " 2 | 50 | 0.370117 | - | - | 63.30 \n",
+ " 2 | 55 | 0.256397 | - | - | 66.72 \n",
+ " 2 | 60 | 0.298844 | - | - | 65.71 \n",
+ " 2 | 65 | 0.371179 | - | - | 70.41 \n",
+ " 2 | 70 | 0.191519 | - | - | 66.28 \n",
+ " 2 | 75 | 0.354108 | - | - | 66.15 \n",
+ " 2 | 80 | 0.326162 | - | - | 66.60 \n",
+ " 2 | 85 | 0.273768 | - | - | 67.04 \n",
+ " 2 | 90 | 0.270890 | - | - | 65.44 \n",
+ " 2 | 95 | 0.294375 | - | - | 60.52 \n",
+ "----------------------------------------------------------------------\n",
+ " 2 | - | 0.302293 | 0.425262 | 81.70 | 1293.09 \n",
+ "----------------------------------------------------------------------\n",
+ "\n",
+ "\n",
+ "Training complete!\n"
+ ]
+ }
+ ],
+ "source": [
+ "set_seed(42) # Set seed for reproducibility\n",
+ "bert_classifier, optimizer, scheduler = initialize_model(epochs=2)\n",
+ "train(bert_classifier, train_dataloader, val_dataloader, epochs=2, evaluation=True)"
+ ]
+ },
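+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# A minimal sketch of batched inference on the test tweets loaded earlier. The helper\n",
+    "# name `bert_predict` is introduced here for illustration only; it runs the fine-tuned\n",
+    "# classifier under torch.no_grad and returns class probabilities.\n",
+    "import torch.nn.functional as F\n",
+    "\n",
+    "def bert_predict(model, dataloader):\n",
+    "    model.eval()\n",
+    "    all_logits = []\n",
+    "    for batch in dataloader:\n",
+    "        b_input_ids, b_attn_mask = tuple(t.to(device) for t in batch)\n",
+    "        with torch.no_grad():  # inference only: no gradient graph needed\n",
+    "            all_logits.append(model(b_input_ids, b_attn_mask))\n",
+    "    return F.softmax(torch.cat(all_logits, dim=0), dim=1).cpu().numpy()\n",
+    "\n",
+    "test_inputs, test_masks = preprocessing_for_bert(test_data.tweet.values)\n",
+    "test_dataset = TensorDataset(test_inputs, test_masks)\n",
+    "test_dataloader = DataLoader(test_dataset,\n",
+    "                             sampler=SequentialSampler(test_dataset),\n",
+    "                             batch_size=batch_size)\n",
+    "probs = bert_predict(bert_classifier, test_dataloader)\n",
+    "print(probs[:5])"
+   ]
+  },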
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}