diff options
| author | zhang <zch921005@126.com> | 2023-02-22 00:22:36 +0800 |
|---|---|---|
| committer | zhang <zch921005@126.com> | 2023-02-22 00:22:36 +0800 |
| commit | 12ac2c65cb05b15461592333e338feaf98cbe7cb (patch) | |
| tree | 2245481b2b85dff6643d207d05a1cc32ec572047 /dl/tutorials/04_vgg_on_cifar10.ipynb | |
| parent | 37506402a98eba9bf9d06760a1010fa17adb39e4 (diff) | |
vgg
Diffstat (limited to 'dl/tutorials/04_vgg_on_cifar10.ipynb')
| -rw-r--r-- | dl/tutorials/04_vgg_on_cifar10.ipynb | 887 |
1 files changed, 887 insertions, 0 deletions
diff --git a/dl/tutorials/04_vgg_on_cifar10.ipynb b/dl/tutorials/04_vgg_on_cifar10.ipynb new file mode 100644 index 0000000..cbf491c --- /dev/null +++ b/dl/tutorials/04_vgg_on_cifar10.ipynb @@ -0,0 +1,887 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 13, + "id": "36600735", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:07:27.625977Z", + "start_time": "2023-02-21T16:07:27.623719Z" + } + }, + "outputs": [], + "source": [ + "import torch\n", + "from torch import nn\n", + "import torchvision\n", + "from torchvision import models\n", + "from torchvision import datasets, transforms\n", + "from datetime import datetime\n", + "from utils import get_mean_and_std" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "daeda546", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:01:51.590225Z", + "start_time": "2023-02-21T16:01:51.562576Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.10.2\n", + "0.11.3\n", + "True\n", + "NVIDIA A10\n" + ] + } + ], + "source": [ + "print(torch.__version__)\n", + "print(torchvision.__version__)\n", + "print(torch.cuda.is_available())\n", + "print(torch.cuda.get_device_name())" + ] + }, + { + "cell_type": "markdown", + "id": "df51273a", + "metadata": {}, + "source": [ + "## vgg" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "15ffacfc", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:02:34.019695Z", + "start_time": "2023-02-21T16:02:33.088531Z" + } + }, + "outputs": [], + "source": [ + "vgg = models.vgg16(pretrained=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7291eaa0", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:02:38.446266Z", + "start_time": "2023-02-21T16:02:38.440109Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "VGG(\n", + " (features): Sequential(\n", + " (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (1): ReLU(inplace=True)\n", + " (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (3): ReLU(inplace=True)\n", + " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (6): ReLU(inplace=True)\n", + " (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (8): ReLU(inplace=True)\n", + " (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (11): ReLU(inplace=True)\n", + " (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (13): ReLU(inplace=True)\n", + " (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (15): ReLU(inplace=True)\n", + " (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (18): ReLU(inplace=True)\n", + " (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (20): ReLU(inplace=True)\n", + " (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (22): ReLU(inplace=True)\n", + " (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (25): ReLU(inplace=True)\n", + " (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (27): ReLU(inplace=True)\n", + " (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (29): ReLU(inplace=True)\n", + " (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))\n", + " (classifier): Sequential(\n", + " (0): Linear(in_features=25088, out_features=4096, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.5, inplace=False)\n", + " (3): Linear(in_features=4096, out_features=4096, bias=True)\n", + " (4): ReLU(inplace=True)\n", + " (5): Dropout(p=0.5, inplace=False)\n", + " (6): Linear(in_features=4096, out_features=1000, bias=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vgg" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "5353f3f1", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:03:11.962435Z", + "start_time": "2023-02-21T16:03:11.960524Z" + } + }, + "outputs": [], + "source": [ + "from torchsummary import summary" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b2120b07", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:03:30.107291Z", + "start_time": "2023-02-21T16:03:29.845682Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------------------------------------\n", + " Layer (type) Output Shape Param #\n", + "================================================================\n", + " Conv2d-1 [-1, 64, 224, 224] 1,792\n", + " ReLU-2 [-1, 64, 224, 224] 0\n", + " Conv2d-3 [-1, 64, 224, 224] 36,928\n", + " ReLU-4 [-1, 64, 224, 224] 0\n", + " MaxPool2d-5 [-1, 64, 112, 112] 0\n", + " Conv2d-6 [-1, 128, 112, 112] 73,856\n", + " ReLU-7 [-1, 128, 112, 112] 0\n", + " Conv2d-8 [-1, 128, 112, 112] 147,584\n", + " ReLU-9 [-1, 128, 112, 112] 0\n", + " MaxPool2d-10 [-1, 128, 56, 56] 0\n", + " Conv2d-11 [-1, 256, 56, 56] 295,168\n", + " ReLU-12 [-1, 256, 56, 56] 0\n", + " Conv2d-13 [-1, 256, 56, 56] 590,080\n", + " ReLU-14 [-1, 256, 56, 56] 0\n", + " Conv2d-15 [-1, 256, 56, 56] 590,080\n", + " ReLU-16 [-1, 256, 56, 56] 0\n", + " MaxPool2d-17 [-1, 256, 28, 28] 0\n", + " Conv2d-18 [-1, 512, 28, 28] 1,180,160\n", + " ReLU-19 [-1, 512, 28, 28] 0\n", + " Conv2d-20 [-1, 512, 28, 28] 2,359,808\n", + " ReLU-21 [-1, 512, 28, 28] 0\n", + " Conv2d-22 [-1, 512, 28, 28] 2,359,808\n", + " ReLU-23 [-1, 512, 28, 28] 0\n", + " MaxPool2d-24 [-1, 512, 14, 14] 0\n", + " Conv2d-25 [-1, 512, 14, 14] 2,359,808\n", + " ReLU-26 [-1, 512, 14, 14] 0\n", + " Conv2d-27 [-1, 512, 14, 14] 2,359,808\n", + " ReLU-28 [-1, 512, 14, 14] 0\n", + " Conv2d-29 [-1, 512, 14, 14] 2,359,808\n", + " ReLU-30 [-1, 512, 14, 14] 0\n", + " MaxPool2d-31 [-1, 512, 7, 7] 0\n", + "AdaptiveAvgPool2d-32 [-1, 512, 7, 7] 0\n", + " Linear-33 [-1, 4096] 102,764,544\n", + " ReLU-34 [-1, 4096] 0\n", + " Dropout-35 [-1, 4096] 0\n", + " Linear-36 [-1, 4096] 16,781,312\n", + " ReLU-37 [-1, 4096] 0\n", + " Dropout-38 [-1, 4096] 0\n", + " Linear-39 [-1, 1000] 4,097,000\n", + "================================================================\n", + "Total params: 138,357,544\n", + "Trainable params: 138,357,544\n", + "Non-trainable params: 0\n", + "----------------------------------------------------------------\n", + "Input size (MB): 0.57\n", + "Forward/backward pass size (MB): 218.78\n", + "Params size (MB): 527.79\n", + "Estimated Total Size (MB): 747.15\n", + "----------------------------------------------------------------\n" + ] + } + ], + "source": [ + "summary(vgg, input_size=(3, 224, 224), device='cpu')" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "88f4b5b0", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:05:36.359811Z", + "start_time": "2023-02-21T16:05:36.355617Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "VGG(\n", + " (features): Sequential(\n", + " (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (1): ReLU(inplace=True)\n", + " (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (3): ReLU(inplace=True)\n", + " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (6): ReLU(inplace=True)\n", + " (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (8): ReLU(inplace=True)\n", + " (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (11): ReLU(inplace=True)\n", + " (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (13): ReLU(inplace=True)\n", + " (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (15): ReLU(inplace=True)\n", + " (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (18): ReLU(inplace=True)\n", + " (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (20): ReLU(inplace=True)\n", + " (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (22): ReLU(inplace=True)\n", + " (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (25): ReLU(inplace=True)\n", + " (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (27): ReLU(inplace=True)\n", + " (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (29): ReLU(inplace=True)\n", + " (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " )\n", + " (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))\n", + " (classifier): Sequential(\n", + " (0): Linear(in_features=25088, out_features=4096, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.5, inplace=False)\n", + " (3): Linear(in_features=4096, out_features=4096, bias=True)\n", + " (4): ReLU(inplace=True)\n", + " (5): Dropout(p=0.5, inplace=False)\n", + " (6): Linear(in_features=4096, out_features=10, bias=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "in_features = vgg.classifier[6].in_features\n", + "vgg.classifier[6] = nn.Linear(in_features, 10)\n", + "vgg" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "60fb0bd4", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:05:57.728656Z", + "start_time": "2023-02-21T16:05:57.577214Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------------------------------------\n", + " Layer (type) Output Shape Param #\n", + "================================================================\n", + " Conv2d-1 [-1, 64, 224, 224] 1,792\n", + " ReLU-2 [-1, 64, 224, 224] 0\n", + " Conv2d-3 [-1, 64, 224, 224] 36,928\n", + " ReLU-4 [-1, 64, 224, 224] 0\n", + " MaxPool2d-5 [-1, 64, 112, 112] 0\n", + " Conv2d-6 [-1, 128, 112, 112] 73,856\n", + " ReLU-7 [-1, 128, 112, 112] 0\n", + " Conv2d-8 [-1, 128, 112, 112] 147,584\n", + " ReLU-9 [-1, 128, 112, 112] 0\n", + " MaxPool2d-10 [-1, 128, 56, 56] 0\n", + " Conv2d-11 [-1, 256, 56, 56] 295,168\n", + " ReLU-12 [-1, 256, 56, 56] 0\n", + " Conv2d-13 [-1, 256, 56, 56] 590,080\n", + " ReLU-14 [-1, 256, 56, 56] 0\n", + " Conv2d-15 [-1, 256, 56, 56] 590,080\n", + " ReLU-16 [-1, 256, 56, 56] 0\n", + " MaxPool2d-17 [-1, 256, 28, 28] 0\n", + " Conv2d-18 [-1, 512, 28, 28] 1,180,160\n", + " ReLU-19 [-1, 512, 28, 28] 0\n", + " Conv2d-20 [-1, 512, 28, 28] 2,359,808\n", + " ReLU-21 [-1, 512, 28, 28] 0\n", + " Conv2d-22 [-1, 512, 28, 28] 2,359,808\n", + " ReLU-23 [-1, 512, 28, 28] 0\n", + " MaxPool2d-24 [-1, 512, 14, 14] 0\n", + " Conv2d-25 [-1, 512, 14, 14] 2,359,808\n", + " ReLU-26 [-1, 512, 14, 14] 0\n", + " Conv2d-27 [-1, 512, 14, 14] 2,359,808\n", + " ReLU-28 [-1, 512, 14, 14] 0\n", + " Conv2d-29 [-1, 512, 14, 14] 2,359,808\n", + " ReLU-30 [-1, 512, 14, 14] 0\n", + " MaxPool2d-31 [-1, 512, 7, 7] 0\n", + "AdaptiveAvgPool2d-32 [-1, 512, 7, 7] 0\n", + " Linear-33 [-1, 4096] 102,764,544\n", + " ReLU-34 [-1, 4096] 0\n", + " Dropout-35 [-1, 4096] 0\n", + " Linear-36 [-1, 4096] 16,781,312\n", + " ReLU-37 [-1, 4096] 0\n", + " Dropout-38 [-1, 4096] 0\n", + " Linear-39 [-1, 10] 40,970\n", + "================================================================\n", + "Total params: 134,301,514\n", + "Trainable params: 134,301,514\n", + "Non-trainable params: 0\n", + "----------------------------------------------------------------\n", + "Input size (MB): 0.57\n", + "Forward/backward pass size (MB): 218.77\n", + "Params size (MB): 512.32\n", + "Estimated Total Size (MB): 731.67\n", + "----------------------------------------------------------------\n" + ] + } + ], + "source": [ + "summary(vgg, input_size=(3, 224, 224), device='cpu')" + ] + }, + { + "cell_type": "markdown", + "id": "f7333d83", + "metadata": {}, + "source": [ + "## parameters" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "8d31ddee", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:06:28.035219Z", + "start_time": "2023-02-21T16:06:28.032859Z" + } + }, + "outputs": [], + "source": [ + "# dataset\n", + "# input_shape = 32\n", + "num_classes = 10\n", + "\n", + "# hyper \n", + "batch_size = 64\n", + "num_epochs = 5\n", + "learning_rate = 1e-3\n", + "\n", + "# gpu\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "72dcd7cf", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:06:30.620769Z", + "start_time": "2023-02-21T16:06:30.618149Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "device(type='cuda')" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "device" + ] + }, + { + "cell_type": "markdown", + "id": "856f2db1", + "metadata": {}, + "source": [ + "## dataset 与 dataloader" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "421d5a1d", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:07:08.617211Z", + "start_time": "2023-02-21T16:07:07.313136Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Files already downloaded and verified\n", + "Files already downloaded and verified\n" + ] + } + ], + "source": [ + "train_dataset = datasets.CIFAR10(root='../data/', \n", + " download=True, \n", + " train=True, \n", + " transform=transforms.ToTensor())\n", + "test_dataset = datasets.CIFAR10(root='../data/', \n", + " download=True, \n", + " train=False, \n", + " transform=transforms.ToTensor())" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "1da83915", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:08:15.650602Z", + "start_time": "2023-02-21T16:07:37.819122Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "==> Computing mean and std..\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 50000/50000 [00:37<00:00, 1321.94it/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "(tensor([0.4914, 0.4822, 0.4465]), tensor([0.2023, 0.1994, 0.2010]))" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "get_mean_and_std(train_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "f2379f10", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:09:30.085725Z", + "start_time": "2023-02-21T16:09:30.083236Z" + } + }, + "outputs": [], + "source": [ + "transform = transforms.Compose([\n", + " transforms.Resize(size=(224, 224)),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "00a568cd", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:09:51.421012Z", + "start_time": "2023-02-21T16:09:50.112455Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Files already downloaded and verified\n", + "Files already downloaded and verified\n" + ] + } + ], + "source": [ + "train_dataset = datasets.CIFAR10(root='../data/', \n", + " download=True, \n", + " train=True, \n", + " transform=transform)\n", + "test_dataset = datasets.CIFAR10(root='../data/', \n", + " download=True, \n", + " train=False, \n", + " transform=transform)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "c656663b", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:09:53.259344Z", + "start_time": "2023-02-21T16:09:53.256872Z" + } + }, + "outputs": [], + "source": [ + "train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, \n", + " shuffle=True, \n", + " batch_size=batch_size)\n", + "test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, \n", + " shuffle=False, \n", + " batch_size=batch_size)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "f77ea73d", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:09:55.520559Z", + "start_time": "2023-02-21T16:09:55.462862Z" + } + }, + "outputs": [], + "source": [ + "images, labels = next(iter(train_dataloader))" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "d03b9505", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:09:56.669215Z", + "start_time": "2023-02-21T16:09:56.666615Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([64, 3, 224, 224])" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# batch_size, channels, h, w\n", + "images.shape" + ] + }, + { + "cell_type": "markdown", + "id": "20b8c4b4", + "metadata": {}, + "source": [ + "## model arch" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "d69e7f3e", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:10:48.937581Z", + "start_time": "2023-02-21T16:10:44.764451Z" + } + }, + "outputs": [], + "source": [ + "vgg = models.vgg16(pretrained=True)\n", + "in_features = vgg.classifier[6].in_features\n", + "vgg.classifier[6] = nn.Linear(in_features, 10)\n", + "vgg = vgg.to(device)" + ] + }, + { + "cell_type": "markdown", + "id": "872d454c", + "metadata": {}, + "source": [ + "## model train/fine-tune" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "3e0babe0", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:13:24.682178Z", + "start_time": "2023-02-21T16:13:24.679390Z" + } + }, + "outputs": [], + "source": [ + "criterion = nn.CrossEntropyLoss()\n", + "# optimzier = torch.optim.Adam(model.parameters(), lr=learning_rate)\n", + "optimizer = torch.optim.SGD(vgg.parameters(), lr = learning_rate, momentum=0.9,weight_decay=5e-4)\n", + "total_batch = len(train_dataloader)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "636c01fb", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:20:46.902487Z", + "start_time": "2023-02-21T16:13:25.728051Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-02-22 00:13:56.818252, 1/5, 100/782: 0.4080, acc: 0.859375\n", + "2023-02-22 00:14:28.367012, 1/5, 200/782: 0.4214, acc: 0.90625\n", + "2023-02-22 00:15:00.193995, 1/5, 300/782: 0.5746, acc: 0.828125\n", + "2023-02-22 00:15:32.204011, 1/5, 400/782: 0.3419, acc: 0.875\n", + "2023-02-22 00:16:04.007379, 1/5, 500/782: 0.3597, acc: 0.859375\n", + "2023-02-22 00:16:35.717492, 1/5, 600/782: 0.4721, acc: 0.796875\n", + "2023-02-22 00:17:07.485007, 1/5, 700/782: 0.1224, acc: 0.953125\n", + "2023-02-22 00:18:05.079364, 2/5, 100/782: 0.1760, acc: 0.921875\n", + "2023-02-22 00:18:36.859823, 2/5, 200/782: 0.1521, acc: 0.921875\n", + "2023-02-22 00:19:08.638372, 2/5, 300/782: 0.2511, acc: 0.90625\n", + "2023-02-22 00:19:40.422421, 2/5, 400/782: 0.1399, acc: 0.9375\n", + "2023-02-22 00:20:12.218806, 2/5, 500/782: 0.1915, acc: 0.90625\n", + "2023-02-22 00:20:44.003433, 2/5, 600/782: 0.1382, acc: 0.9375\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-25-9a44dc355816>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;31m# 标准的处理,用 validate data;这个过程是监督训练过程\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0mn_corrects\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0macc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mn_corrects\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "for epoch in range(num_epochs):\n", + " for batch_idx, (images, labels) in enumerate(train_dataloader):\n", + " images = images.to(device)\n", + " labels = labels.to(device)\n", + " \n", + " # forward\n", + " out = vgg(images)\n", + " loss = criterion(out, labels)\n", + " \n", + " # 标准的处理,用 validate data;这个过程是监督训练过程,用于 early stop\n", + " n_corrects = (out.argmax(axis=1) == labels).sum().item()\n", + " acc = n_corrects/labels.size(0)\n", + " \n", + " # backward\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step() # 更细 模型参数\n", + " \n", + " if (batch_idx+1) % 100 == 0:\n", + " print(f'{datetime.now()}, {epoch+1}/{num_epochs}, {batch_idx+1}/{total_batch}: {loss.item():.4f}, acc: {acc}')" + ] + }, + { + "cell_type": "markdown", + "id": "5e7078dd", + "metadata": {}, + "source": [ + "## model evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "ef80d16f", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:21:06.582989Z", + "start_time": "2023-02-21T16:20:55.012423Z" + } + }, + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-27-a604fa16a1e8>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mtotal\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mcorrect\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mimages\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtest_dataloader\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mimages\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimages\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 519\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sampler_iter\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 520\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 521\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 522\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 523\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 559\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 560\u001b[0m \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 561\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 562\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 563\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/torchvision/datasets/cifar.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 121\u001b[0;31m \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 122\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtarget_transform\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/torchvision/transforms/transforms.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, img)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransforms\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1095\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1096\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1097\u001b[0;31m \u001b[0mforward_call\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_tracing_state\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1098\u001b[0m \u001b[0;31m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1099\u001b[0m \u001b[0;31m# this function, and just call forward.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "total = 0\n", + "correct = 0\n", + "for images, labels in test_dataloader:\n", + " images = images.to(device)\n", + " labels = labels.to(device)\n", + " out = vgg(images)\n", + " preds = torch.argmax(out, dim=1)\n", + " \n", + " total += images.size(0)\n", + " correct += (preds == labels).sum().item()\n", + "print(f'{correct}/{total}={correct/total}')" + ] + }, + { + "cell_type": "markdown", + "id": "96272f61", + "metadata": {}, + "source": [ + "## model save" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2e68f24f", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-21T16:21:06.587517Z", + "start_time": "2023-02-21T16:21:03.590Z" + } + }, + "outputs": [], + "source": [ + "torch.save(vgg.state_dict(), 'cnn_cifar.ckpt')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.13" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true + }, + "varInspector": { + "cols": { + "lenName": 16, + "lenType": 16, + "lenVar": 40 + }, + "kernels_config": { + "python": { + "delete_cmd_postfix": "", + "delete_cmd_prefix": "del ", + "library": "var_list.py", + "varRefreshCmd": "print(var_dic_list())" + }, + "r": { + "delete_cmd_postfix": ") ", + "delete_cmd_prefix": "rm(", + "library": "var_list.r", + "varRefreshCmd": "cat(var_dic_list()) " + } + }, + "types_to_exclude": [ + "module", + "function", + "builtin_function_or_method", + "instance", + "_Feature" + ], + "window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} |
