Diffstat (limited to 'fine_tune/bert')
| -rw-r--r-- | fine_tune/bert/bert_fine_tune.ipynb | 850 |
1 files changed, 850 insertions, 0 deletions
diff --git a/fine_tune/bert/bert_fine_tune.ipynb b/fine_tune/bert/bert_fine_tune.ipynb
new file mode 100644
index 0000000..6554ba0
--- /dev/null
+++ b/fine_tune/bert/bert_fine_tune.ipynb
@@ -0,0 +1,850 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2022-06-25T08:55:39.339410Z",
+     "start_time": "2022-06-25T08:55:38.522277Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import re\n",
+    "from tqdm import tqdm\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import matplotlib.pyplot as plt\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2022-06-25T08:55:45.550205Z",
+     "start_time": "2022-06-25T08:55:45.521302Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "# Load data and set labels\n",
+    "data_complaint = pd.read_csv('data/complaint1700.csv')\n",
+    "data_complaint['label'] = 0\n",
+    "data_non_complaint = pd.read_csv('data/noncomplaint1700.csv')\n",
+    "data_non_complaint['label'] = 1\n",
+    "\n",
+    "# Concatenate complaining and non-complaining data\n",
+    "data = pd.concat([data_complaint, data_non_complaint], axis=0).reset_index(drop=True)\n",
+    "\n",
+    "# Drop 'airline' column\n",
+    "data.drop(['airline'], inplace=True, axis=1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2022-06-25T08:55:53.310710Z",
+     "start_time": "2022-06-25T08:55:53.295841Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>id</th>\n",
+       "      <th>tweet</th>\n",
+       "      <th>label</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>2579</th>\n",
+       "      <td>82091</td>\n",
+       "      <td>@AlaskaAir @RSherman_25 Thank you so much!! Ca...</td>\n",
+       "      <td>1</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>657</th>\n",
+       "      <td>147575</td>\n",
+       "      <td>@DeltaAssist hi. I lost my sunglasses on a fli...</td>\n",
+       "      <td>0</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1971</th>\n",
+       "      <td>23890</td>\n",
+       "      <td>Flights to #PuertoRico booked on @JetBlue! Can...</td>\n",
+       "      <td>1</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3312</th>\n",
+       "      <td>160070</td>\n",
+       "      <td>@united Do you offer open-ended tickets? CLT-B...</td>\n",
+       "      <td>1</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1034</th>\n",
+       "      <td>63946</td>\n",
+       "      <td>So @AmericanAir I'm going to need you all to g...</td>\n",
+       "      <td>0</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "          id                                              tweet  label\n",
+       "2579   82091  @AlaskaAir @RSherman_25 Thank you so much!! Ca...      1\n",
+       "657   147575  @DeltaAssist hi. I lost my sunglasses on a fli...      0\n",
+       "1971   23890  Flights to #PuertoRico booked on @JetBlue! Can...      1\n",
+       "3312  160070  @united Do you offer open-ended tickets? CLT-B...      1\n",
+       "1034   63946  So @AmericanAir I'm going to need you all to g...      0"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "data.sample(5)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2022-06-25T08:56:10.130909Z",
+     "start_time": "2022-06-25T08:56:08.960476Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "No GPU available, using the CPU instead.\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "\n",
+    "if torch.cuda.is_available():\n",
+    "    device = torch.device(\"cuda\")\n",
+    "    print(f'There are {torch.cuda.device_count()} GPU(s) available.')\n",
+    "    print('Device name:', torch.cuda.get_device_name(0))\n",
+    "\n",
+    "else:\n",
+    "    print('No GPU available, using the CPU instead.')\n",
+    "    device = torch.device(\"cpu\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2022-06-25T08:56:37.054745Z",
+     "start_time": "2022-06-25T08:56:37.051239Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "def text_preprocessing(text):\n",
+    "    \"\"\"\n",
+    "    - Remove entity mentions (e.g. '@united')\n",
+    "    - Correct errors (e.g. '&amp;' to '&')\n",
+    "    @param    text (str): a string to be processed.\n",
+    "    @return   text (str): the processed string.\n",
+    "    \"\"\"\n",
+    "    # Remove '@name'\n",
+    "    text = re.sub(r'(@.*?)[\\s]', ' ', text)\n",
+    "\n",
+    "    # Replace '&amp;' with '&'\n",
+    "    text = re.sub(r'&amp;', '&', text)\n",
+    "\n",
+    "    # Collapse runs of whitespace and strip leading/trailing spaces\n",
+    "    text = re.sub(r'\\s+', ' ', text).strip()\n",
+    "\n",
+    "    return text"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2022-06-25T08:57:01.039579Z",
+     "start_time": "2022-06-25T08:57:00.483491Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "from sklearn.model_selection import train_test_split\n",
+    "\n",
+    "X = data.tweet.values\n",
+    "y = data.label.values\n",
+    "\n",
+    "X_train, X_val, y_train, y_val =\\\n",
+    "    train_test_split(X, y, test_size=0.1, random_state=2020)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2022-06-25T08:59:17.779189Z",
+     "start_time": "2022-06-25T08:59:17.759147Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "# Load test data\n",
+    "test_data = pd.read_csv('data/test_data.csv')\n",
+    "\n",
+    "# Keep important columns\n",
+    "test_data = test_data[['id', 'tweet']]\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2022-06-25T09:00:01.125406Z",
+     "start_time": "2022-06-25T08:59:48.736061Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "from transformers import BertTokenizer\n",
+    "\n",
+    "# Load the BERT tokenizer\n",
+    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2022-06-25T09:00:06.381689Z",
+     "start_time": "2022-06-25T09:00:01.928757Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Max length:  68\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Concatenate train data and test data\n",
+    "all_tweets = np.concatenate([data.tweet.values, test_data.tweet.values])\n",
+    "\n",
+    "# Encode our concatenated data\n",
+    "encoded_tweets = [tokenizer.encode(sent, add_special_tokens=True) for sent in all_tweets]\n",
+    "\n",
+    "# Find the maximum length\n",
+    "max_len = max([len(sent) for sent in encoded_tweets])\n",
+    "print('Max length: ', max_len)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2022-06-25T09:00:18.264259Z",
+     "start_time": "2022-06-25T09:00:18.261607Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "max_len = 64"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2022-06-25T09:00:34.288131Z",
+     "start_time": "2022-06-25T09:00:34.282880Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "# Create a function to tokenize a set of texts\n",
+    "def preprocessing_for_bert(data):\n",
+    "    \"\"\"Perform required preprocessing steps for pretrained BERT.\n",
+    "    @param    data (np.array): Array of texts to be processed.\n",
+    "    @return   input_ids (torch.Tensor): Tensor of token ids to be fed to a model.\n",
+    "    @return   attention_masks (torch.Tensor): Tensor of indices specifying which\n",
+    "                  tokens should be attended to by the model.\n",
+    "    \"\"\"\n",
+    "    # Create empty lists to store outputs\n",
+    "    input_ids = []\n",
+    "    attention_masks = []\n",
+    "\n",
+    "    # For every sentence...\n",
+    "    for sent in data:\n",
+    "        # `encode_plus` will:\n",
+    "        #    (1) Tokenize the sentence\n",
+    "        #    (2) Add the `[CLS]` and `[SEP]` token to the start and end\n",
+    "        #    (3) Truncate/Pad sentence to max length\n",
+    "        #    (4) Map tokens to their IDs\n",
+    "        #    (5) Create attention mask\n",
+    "        #    (6) Return a dictionary of outputs\n",
+    "        encoded_sent = tokenizer.encode_plus(\n",
+    "            text=text_preprocessing(sent),  # Preprocess sentence\n",
+    "            add_special_tokens=True,        # Add `[CLS]` and `[SEP]`\n",
+    "            max_length=max_len,             # Max length to truncate/pad\n",
+    "            pad_to_max_length=True,         # Pad sentence to max length\n",
+    "            #return_tensors='pt',           # Return PyTorch tensor\n",
+    "            return_attention_mask=True      # Return attention mask\n",
+    "            )\n",
+    "\n",
+    "        # Add the outputs to the lists\n",
+    "        input_ids.append(encoded_sent.get('input_ids'))\n",
+    "        attention_masks.append(encoded_sent.get('attention_mask'))\n",
+    "\n",
+    "    # Convert lists to tensors\n",
+    "    input_ids = torch.tensor(input_ids)\n",
+    "    attention_masks = torch.tensor(attention_masks)\n",
+    "\n",
+    "    return input_ids, attention_masks"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2022-06-25T09:00:52.948339Z",
+     "start_time": "2022-06-25T09:00:51.094279Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n",
+      "/Users/chunhuizhang/anaconda3/lib/python3.6/site-packages/transformers/tokenization_utils_base.py:2217: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n",
+      "  FutureWarning,\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Original:  @united I'm having issues. Yesterday I rebooked for 24 hours after I was supposed to fly, now I can't log on & check in. Can you help?\n",
+      "Token IDs:  [101, 1045, 1005, 1049, 2383, 3314, 1012, 7483, 1045, 2128, 8654, 2098, 2005, 2484, 2847, 2044, 1045, 2001, 4011, 2000, 4875, 1010, 2085, 1045, 2064, 1005, 1056, 8833, 2006, 1004, 4638, 1999, 1012, 2064, 2017, 2393, 1029, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
+      "Tokenizing data...\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Print sentence 0 and its encoded token ids\n",
+    "token_ids = list(preprocessing_for_bert([X[0]])[0].squeeze().numpy())\n",
+    "print('Original: ', X[0])\n",
+    "print('Token IDs: ', token_ids)\n",
+    "\n",
+    "# Run function `preprocessing_for_bert` on the train set and the validation set\n",
+    "print('Tokenizing data...')\n",
+    "train_inputs, train_masks = preprocessing_for_bert(X_train)\n",
+    "val_inputs, val_masks = preprocessing_for_bert(X_val)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2022-06-25T09:01:08.103694Z",
+     "start_time": "2022-06-25T09:01:08.088666Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[  101,  2054,  1005,  ...,     0,     0,     0],\n",
+       "        [  101,  2054,  1996,  ...,     0,     0,     0],\n",
+       "        [  101,  3524,  2054,  ...,     0,     0,     0],\n",
+       "        ...,\n",
+       "        [  101,  3294, 17203,  ...,     0,     0,     0],\n",
+       "        [  101,  1045,  5223,  ...,     0,     0,     0],\n",
+       "        [  101,  1998,  2009,  ...,     0,     0,     0]])"
+      ]
+     },
+     "execution_count": 18,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "train_inputs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2022-06-25T09:01:14.520761Z",
+     "start_time": "2022-06-25T09:01:14.516264Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
+       "        [1, 1, 1,  ..., 0, 0, 0],\n",
+       "        [1, 1, 1,  ..., 0, 0, 0],\n",
+       "        ...,\n",
+       "        [1, 1, 1,  ..., 0, 0, 0],\n",
+       "        [1, 1, 1,  ..., 0, 0, 0],\n",
+       "        [1, 1, 1,  ..., 0, 0, 0]])"
+      ]
+     },
+     "execution_count": 19,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "train_masks"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2022-06-25T09:01:26.298148Z",
+     "start_time": "2022-06-25T09:01:26.291846Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler\n",
+    "\n",
+    "# Convert other data types to torch.Tensor\n",
+    "train_labels = torch.tensor(y_train)\n",
+    "val_labels = torch.tensor(y_val)\n",
+    "\n",
+    "# For fine-tuning BERT, the authors recommend a batch size of 16 or 32.\n",
+    "batch_size = 32\n",
+    "\n",
+    "# Create the DataLoader for our training set\n",
+    "train_data = TensorDataset(train_inputs, train_masks, train_labels)\n",
+    "train_sampler = RandomSampler(train_data)\n",
+    "train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)\n",
+    "\n",
+    "# Create the DataLoader for our validation set\n",
+    "val_data = TensorDataset(val_inputs, val_masks, val_labels)\n",
+    "val_sampler = SequentialSampler(val_data)\n",
+    "val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=batch_size)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2022-06-25T09:01:39.358697Z",
+     "start_time": "2022-06-25T09:01:39.336610Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "from transformers import BertModel\n",
+    "\n",
+    "# Create the BertClassifier class\n",
+    "class BertClassifier(nn.Module):\n",
+    "    \"\"\"Bert Model for Classification Tasks.\n",
+    "    \"\"\"\n",
+    "    def __init__(self, freeze_bert=False):\n",
+    "        \"\"\"\n",
+    "        @param    bert: a BertModel object\n",
+    "        @param    classifier: a torch.nn.Module classifier\n",
+    "        @param    freeze_bert (bool): Set `False` to fine-tune the BERT model\n",
+    "        \"\"\"\n",
+    "        super(BertClassifier, self).__init__()\n",
+    "        # Specify hidden size of BERT, hidden size of our classifier, and number of labels\n",
+    "        D_in, H, D_out = 768, 50, 2\n",
+    "\n",
+    "        # Instantiate BERT model\n",
+    "        self.bert = BertModel.from_pretrained('bert-base-uncased')\n",
+    "\n",
+    "        # Instantiate a one-layer feed-forward classifier\n",
+    "        self.classifier = nn.Sequential(\n",
+    "            nn.Linear(D_in, H),\n",
+    "            nn.ReLU(),\n",
+    "            #nn.Dropout(0.5),\n",
+    "            nn.Linear(H, D_out)\n",
+    "        )\n",
+    "\n",
+    "        # Freeze the BERT model\n",
+    "        if freeze_bert:\n",
+    "            for param in self.bert.parameters():\n",
+    "                param.requires_grad = False\n",
+    "\n",
+    "    def forward(self, input_ids, attention_mask):\n",
+    "        \"\"\"\n",
+    "        Feed input to BERT and the classifier to compute logits.\n",
+    "        @param    input_ids (torch.Tensor): an input tensor with shape (batch_size,\n",
+    "                      max_length)\n",
+    "        @param    attention_mask (torch.Tensor): a tensor that holds attention mask\n",
+    "                      information with shape (batch_size, max_length)\n",
+    "        @return   logits (torch.Tensor): an output tensor with shape (batch_size,\n",
+    "                      num_labels)\n",
+    "        \"\"\"\n",
+    "        # Feed input to BERT\n",
+    "        outputs = self.bert(input_ids=input_ids,\n",
+    "                            attention_mask=attention_mask)\n",
+    "\n",
+    "        # Extract the last hidden state of the token `[CLS]` for classification task\n",
+    "        last_hidden_state_cls = outputs[0][:, 0, :]\n",
+    "\n",
+    "        # Feed input to classifier to compute logits\n",
+    "        logits = self.classifier(last_hidden_state_cls)\n",
+    "\n",
+    "        return logits"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2022-06-25T09:02:30.655449Z",
+     "start_time": "2022-06-25T09:02:30.644656Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "from transformers import AdamW, get_linear_schedule_with_warmup\n",
+    "\n",
+    "def initialize_model(epochs=4):\n",
+    "    \"\"\"Initialize the Bert Classifier, the optimizer and the learning rate scheduler.\n",
+    "    \"\"\"\n",
+    "    # Instantiate Bert Classifier\n",
+    "    bert_classifier = BertClassifier(freeze_bert=False)\n",
+    "\n",
+    "    # Move the model to the device chosen above (GPU if available, else CPU)\n",
+    "    bert_classifier.to(device)\n",
+    "\n",
+    "    # Create the optimizer\n",
+    "    optimizer = AdamW(bert_classifier.parameters(),\n",
+    "                      lr=5e-5,    # Default learning rate\n",
+    "                      eps=1e-8    # Default epsilon value\n",
+    "                      )\n",
+    "\n",
+    "    # Total number of training steps\n",
+    "    total_steps = len(train_dataloader) * epochs\n",
+    "\n",
+    "    # Set up the learning rate scheduler\n",
+    "    scheduler = get_linear_schedule_with_warmup(optimizer,\n",
+    "                                                num_warmup_steps=0, # Default value\n",
+    "                                                num_training_steps=total_steps)\n",
+    "    return bert_classifier, optimizer, scheduler"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2022-06-25T09:12:53.857467Z",
+     "start_time": "2022-06-25T09:12:53.844730Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "import random\n",
+    "import time\n",
+    "\n",
+    "# Specify loss function\n",
+    "loss_fn = nn.CrossEntropyLoss()\n",
+    "\n",
+    "def set_seed(seed_value=42):\n",
+    "    \"\"\"Set seed for reproducibility.\n",
+    "    \"\"\"\n",
+    "    random.seed(seed_value)\n",
+    "    np.random.seed(seed_value)\n",
+    "    torch.manual_seed(seed_value)\n",
+    "    torch.cuda.manual_seed_all(seed_value)\n",
+    "\n",
+    "def train(model, train_dataloader, val_dataloader=None, epochs=4, evaluation=False):\n",
+    "    \"\"\"Train the BertClassifier model.\n",
+    "    \"\"\"\n",
+    "    # Start training loop\n",
+    "    print(\"Start training...\\n\")\n",
+    "    for epoch_i in range(epochs):\n",
+    "        # =======================================\n",
+    "        #               Training\n",
+    "        # =======================================\n",
+    "        # Print the header of the result table\n",
+    "        print(f\"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Val Loss':^10} | {'Val Acc':^9} | {'Elapsed':^9}\")\n",
+    "        print(\"-\"*70)\n",
+    "\n",
+    "        # Measure the elapsed time of each epoch\n",
+    "        t0_epoch, t0_batch = time.time(), time.time()\n",
+    "\n",
+    "        # Reset tracking variables at the beginning of each epoch\n",
+    "        total_loss, batch_loss, batch_counts = 0, 0, 0\n",
+    "\n",
+    "        # Put the model into the training mode\n",
+    "        model.train()\n",
+    "\n",
+    "        # For each batch of training data...\n",
+    "        for step, batch in enumerate(train_dataloader):\n",
+    "            batch_counts += 1\n",
+    "            # Load batch to the device\n",
+    "            b_input_ids, b_attn_mask, b_labels = tuple(t.to(device) for t in batch)\n",
+    "\n",
+    "            # Zero out any previously calculated gradients\n",
+    "            model.zero_grad()\n",
+    "\n",
+    "            # Perform a forward pass. This will return logits.\n",
+    "            logits = model(b_input_ids, b_attn_mask)\n",
+    "\n",
+    "            # Compute loss and accumulate the loss values\n",
+    "            loss = loss_fn(logits, b_labels)\n",
+    "            batch_loss += loss.item()\n",
+    "            total_loss += loss.item()\n",
+    "\n",
+    "            # Perform a backward pass to calculate gradients\n",
+    "            loss.backward()\n",
+    "\n",
+    "            # Clip the norm of the gradients to 1.0 to prevent \"exploding gradients\"\n",
+    "            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
+    "\n",
+    "            # Update parameters and the learning rate\n",
+    "            optimizer.step()\n",
+    "            scheduler.step()\n",
+    "\n",
+    "            # Print the loss values and time elapsed every 5 batches\n",
+    "            if (step % 5 == 0 and step != 0) or (step == len(train_dataloader) - 1):\n",
+    "                # Calculate time elapsed for the last 5 batches\n",
+    "                time_elapsed = time.time() - t0_batch\n",
+    "\n",
+    "                # Print training results\n",
+    "                print(f\"{epoch_i + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {'-':^10} | {'-':^9} | {time_elapsed:^9.2f}\")\n",
+    "\n",
+    "                # Reset batch tracking variables\n",
+    "                batch_loss, batch_counts = 0, 0\n",
+    "                t0_batch = time.time()\n",
+    "\n",
+    "        # Calculate the average loss over the entire training data\n",
+    "        avg_train_loss = total_loss / len(train_dataloader)\n",
+    "\n",
+    "        print(\"-\"*70)\n",
+    "        # =======================================\n",
+    "        #               Evaluation\n",
+    "        # =======================================\n",
+    "        if evaluation:\n",
+    "            # After the completion of each training epoch, measure the model's performance\n",
+    "            # on our validation set.\n",
+    "            val_loss, val_accuracy = evaluate(model, val_dataloader)\n",
+    "\n",
+    "            # Print performance over the entire training data\n",
+    "            time_elapsed = time.time() - t0_epoch\n",
+    "\n",
+    "            print(f\"{epoch_i + 1:^7} | {'-':^7} | {avg_train_loss:^12.6f} | {val_loss:^10.6f} | {val_accuracy:^9.2f} | {time_elapsed:^9.2f}\")\n",
+    "            print(\"-\"*70)\n",
+    "        print(\"\\n\")\n",
+    "\n",
+    "    print(\"Training complete!\")\n",
+    "\n",
+    "\n",
+    "def evaluate(model, val_dataloader):\n",
+    "    \"\"\"After the completion of each training epoch, measure the model's performance\n",
+    "    on our validation set.\n",
+    "    \"\"\"\n",
+    "    # Put the model into the evaluation mode. The dropout layers are disabled\n",
+    "    # at test time.\n",
+    "    model.eval()\n",
+    "\n",
+    "    # Tracking variables\n",
+    "    val_accuracy = []\n",
+    "    val_loss = []\n",
+    "\n",
+    "    # For each batch in our validation set...\n",
+    "    for batch in val_dataloader:\n",
+    "        # Load batch to the device\n",
+    "        b_input_ids, b_attn_mask, b_labels = tuple(t.to(device) for t in batch)\n",
+    "\n",
+    "        # Compute logits\n",
+    "        with torch.no_grad():\n",
+    "            logits = model(b_input_ids, b_attn_mask)\n",
+    "\n",
+    "        # Compute loss\n",
+    "        loss = loss_fn(logits, b_labels)\n",
+    "        val_loss.append(loss.item())\n",
+    "\n",
+    "        # Get the predictions\n",
+    "        preds = torch.argmax(logits, dim=1).flatten()\n",
+    "\n",
+    "        # Calculate the accuracy rate\n",
+    "        accuracy = (preds == b_labels).cpu().numpy().mean() * 100\n",
+    "        val_accuracy.append(accuracy)\n",
+    "\n",
+    "    # Compute the average accuracy and loss over the validation set.\n",
+    "    val_loss = np.mean(val_loss)\n",
+    "    val_accuracy = np.mean(val_accuracy)\n",
+    "\n",
+    "    return val_loss, val_accuracy"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2022-06-25T09:54:42.915823Z",
+     "start_time": "2022-06-25T09:12:58.201795Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias']\n",
+      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Start training...\n",
+      "\n",
+      " Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed \n",
+      "----------------------------------------------------------------------\n",
+      "   1    |    5    |   0.685621   |     -      |     -     |   79.83  \n",
+      "   1    |   10    |   0.669755   |     -      |     -     |   60.98  \n",
+      "   1    |   15    |   0.636234   |     -      |     -     |   60.81  \n",
+      "   1    |   20    |   0.603634   |     -      |     -     |   60.93  \n",
+      "   1    |   25    |   0.573587   |     -      |     -     |   60.90  \n",
+      "   1    |   30    |   0.594417   |     -      |     -     |   61.13  \n",
+      "   1    |   35    |   0.577772   |     -      |     -     |   60.90  \n",
+      "   1    |   40    |   0.564313   |     -      |     -     |   61.48  \n",
+      "   1    |   45    |   0.467151   |     -      |     -     |   61.49  \n",
+      "   1    |   50    |   0.477319   |     -      |     -     |   61.96  \n",
+      "   1    |   55    |   0.487433   |     -      |     -     |   61.16  \n",
+      "   1    |   60    |   0.407353   |     -      |     -     |   60.69  \n",
+      "   1    |   65    |   0.532023   |     -      |     -     |   60.32  \n",
+      "   1    |   70    |   0.529829   |     -      |     -     |   60.40  \n",
+      "   1    |   75    |   0.443143   |     -      |     -     |   60.19  \n",
+      "   1    |   80    |   0.432639   |     -      |     -     |   60.50  \n",
+      "   1    |   85    |   0.506726   |     -      |     -     |   60.32  \n",
+      "   1    |   90    |   0.442358   |     -      |     -     |   60.45  \n",
+      "   1    |   95    |   0.539845   |     -      |     -     |   55.92  \n",
+      "----------------------------------------------------------------------\n",
+      "   1    |    -    |   0.536889   |  0.426069  |   80.68   |  1208.44 \n",
+      "----------------------------------------------------------------------\n",
+      "\n",
+      "\n",
+      " Epoch  |  Batch  |  Train Loss  |  Val Loss  |  Val Acc  |  Elapsed \n",
+      "----------------------------------------------------------------------\n",
+      "   2    |    5    |   0.325858   |     -      |     -     |   82.61  \n",
+      "   2    |   10    |   0.307421   |     -      |     -     |   66.54  \n",
+      "   2    |   15    |   0.317675   |     -      |     -     |   64.03  \n",
+      "   2    |   20    |   0.326934   |     -      |     -     |   63.01  \n",
+      "   2    |   25    |   0.325949   |     -      |     -     |   63.21  \n",
+      "   2    |   30    |   0.341778   |     -      |     -     |   63.01  \n",
+      "   2    |   35    |   0.273992   |     -      |     -     |   63.04  \n",
+      "   2    |   40    |   0.247418   |     -      |     -     |   64.86  \n",
+      "   2    |   45    |   0.264468   |     -      |     -     |   64.09  \n",
+      "   2    |   50    |   0.370117   |     -      |     -     |   63.30  \n",
+      "   2    |   55    |   0.256397   |     -      |     -     |   66.72  \n",
+      "   2    |   60    |   0.298844   |     -      |     -     |   65.71  \n",
+      "   2    |   65    |   0.371179   |     -      |     -     |   70.41  \n",
+      "   2    |   70    |   0.191519   |     -      |     -     |   66.28  \n",
+      "   2    |   75    |   0.354108   |     -      |     -     |   66.15  \n",
+      "   2    |   80    |   0.326162   |     -      |     -     |   66.60  \n",
+      "   2    |   85    |   0.273768   |     -      |     -     |   67.04  \n",
+      "   2    |   90    |   0.270890   |     -      |     -     |   65.44  \n",
+      "   2    |   95    |   0.294375   |     -      |     -     |   60.52  \n",
+      "----------------------------------------------------------------------\n",
+      "   2    |    -    |   0.302293   |  0.425262  |   81.70   |  1293.09 \n",
+      "----------------------------------------------------------------------\n",
+      "\n",
+      "\n",
+      "Training complete!\n"
+     ]
+    }
+   ],
+   "source": [
+    "set_seed(42)    # Set seed for reproducibility\n",
+    "bert_classifier, optimizer, scheduler = initialize_model(epochs=2)\n",
+    "train(bert_classifier, train_dataloader, val_dataloader, epochs=2, evaluation=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.8"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
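Note: the FutureWarning recorded in the tokenization cell comes from the deprecated `pad_to_max_length` argument. A minimal fix, using the replacement arguments the warning itself recommends (only this `encode_plus` call changes; the rest of `preprocessing_for_bert` stays as committed):

    encoded_sent = tokenizer.encode_plus(
        text=text_preprocessing(sent),  # Preprocess sentence
        add_special_tokens=True,        # Add `[CLS]` and `[SEP]`
        max_length=max_len,             # Max length to truncate/pad
        padding='max_length',           # Replaces deprecated `pad_to_max_length=True`
        truncation=True,                # Make truncation to `max_length` explicit
        return_attention_mask=True      # Return attention mask
        )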

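The commit tokenizes `test_data` but stops after training. For reference, a minimal inference sketch, not part of this commit: `bert_predict` is a hypothetical helper, while `preprocessing_for_bert`, `bert_classifier`, `device`, `batch_size`, `TensorDataset`, `DataLoader`, and `SequentialSampler` are all defined in the notebook above.

    import torch.nn.functional as F

    def bert_predict(model, dataloader):
        """Run the fine-tuned classifier and return class probabilities."""
        model.eval()
        all_logits = []
        for batch in dataloader:
            # Test batches carry only input ids and attention masks (no labels)
            b_input_ids, b_attn_mask = tuple(t.to(device) for t in batch)
            with torch.no_grad():
                all_logits.append(model(b_input_ids, b_attn_mask))
        # Concatenate logits over batches and convert to probabilities
        return F.softmax(torch.cat(all_logits, dim=0), dim=1).cpu().numpy()

    # Tokenize the test tweets and wrap them in a sequential DataLoader
    test_inputs, test_masks = preprocessing_for_bert(test_data.tweet.values)
    test_dataset = TensorDataset(test_inputs, test_masks)
    test_dataloader = DataLoader(test_dataset,
                                 sampler=SequentialSampler(test_dataset),
                                 batch_size=batch_size)

    probs = bert_predict(bert_classifier, test_dataloader)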