summaryrefslogtreecommitdiff
path: root/dl/tutorials/04_vgg_on_cifar10.ipynb
diff options
context:
space:
mode:
authorzhang <zch921005@126.com>2023-02-22 00:22:36 +0800
committerzhang <zch921005@126.com>2023-02-22 00:22:36 +0800
commit12ac2c65cb05b15461592333e338feaf98cbe7cb (patch)
tree2245481b2b85dff6643d207d05a1cc32ec572047 /dl/tutorials/04_vgg_on_cifar10.ipynb
parent37506402a98eba9bf9d06760a1010fa17adb39e4 (diff)
vgg
Diffstat (limited to 'dl/tutorials/04_vgg_on_cifar10.ipynb')
-rw-r--r--dl/tutorials/04_vgg_on_cifar10.ipynb887
1 files changed, 887 insertions, 0 deletions
diff --git a/dl/tutorials/04_vgg_on_cifar10.ipynb b/dl/tutorials/04_vgg_on_cifar10.ipynb
new file mode 100644
index 0000000..cbf491c
--- /dev/null
+++ b/dl/tutorials/04_vgg_on_cifar10.ipynb
@@ -0,0 +1,887 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "36600735",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:07:27.625977Z",
+ "start_time": "2023-02-21T16:07:27.623719Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from torch import nn\n",
+ "import torchvision\n",
+ "from torchvision import models\n",
+ "from torchvision import datasets, transforms\n",
+ "from datetime import datetime\n",
+ "from utils import get_mean_and_std"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "daeda546",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:01:51.590225Z",
+ "start_time": "2023-02-21T16:01:51.562576Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1.10.2\n",
+ "0.11.3\n",
+ "True\n",
+ "NVIDIA A10\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Environment check: library versions and GPU availability/name.\n",
+ "print(torch.__version__)\n",
+ "print(torchvision.__version__)\n",
+ "print(torch.cuda.is_available())\n",
+ "print(torch.cuda.get_device_name())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "df51273a",
+ "metadata": {},
+ "source": [
+ "## VGG-16 (pretrained)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "15ffacfc",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:02:34.019695Z",
+ "start_time": "2023-02-21T16:02:33.088531Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Load VGG-16 with ImageNet-pretrained weights.\n",
+ "# NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13 (use `weights=`);\n",
+ "# fine here since this notebook runs torchvision 0.11.3.\n",
+ "vgg = models.vgg16(pretrained=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "7291eaa0",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:02:38.446266Z",
+ "start_time": "2023-02-21T16:02:38.440109Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "VGG(\n",
+ " (features): Sequential(\n",
+ " (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (3): ReLU(inplace=True)\n",
+ " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+ " (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (6): ReLU(inplace=True)\n",
+ " (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (8): ReLU(inplace=True)\n",
+ " (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+ " (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (11): ReLU(inplace=True)\n",
+ " (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (13): ReLU(inplace=True)\n",
+ " (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (15): ReLU(inplace=True)\n",
+ " (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+ " (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (18): ReLU(inplace=True)\n",
+ " (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (20): ReLU(inplace=True)\n",
+ " (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (22): ReLU(inplace=True)\n",
+ " (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+ " (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (25): ReLU(inplace=True)\n",
+ " (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (27): ReLU(inplace=True)\n",
+ " (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (29): ReLU(inplace=True)\n",
+ " (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+ " )\n",
+ " (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))\n",
+ " (classifier): Sequential(\n",
+ " (0): Linear(in_features=25088, out_features=4096, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.5, inplace=False)\n",
+ " (3): Linear(in_features=4096, out_features=4096, bias=True)\n",
+ " (4): ReLU(inplace=True)\n",
+ " (5): Dropout(p=0.5, inplace=False)\n",
+ " (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "vgg"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "5353f3f1",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:03:11.962435Z",
+ "start_time": "2023-02-21T16:03:11.960524Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from torchsummary import summary"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "b2120b07",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:03:30.107291Z",
+ "start_time": "2023-02-21T16:03:29.845682Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "----------------------------------------------------------------\n",
+ " Layer (type) Output Shape Param #\n",
+ "================================================================\n",
+ " Conv2d-1 [-1, 64, 224, 224] 1,792\n",
+ " ReLU-2 [-1, 64, 224, 224] 0\n",
+ " Conv2d-3 [-1, 64, 224, 224] 36,928\n",
+ " ReLU-4 [-1, 64, 224, 224] 0\n",
+ " MaxPool2d-5 [-1, 64, 112, 112] 0\n",
+ " Conv2d-6 [-1, 128, 112, 112] 73,856\n",
+ " ReLU-7 [-1, 128, 112, 112] 0\n",
+ " Conv2d-8 [-1, 128, 112, 112] 147,584\n",
+ " ReLU-9 [-1, 128, 112, 112] 0\n",
+ " MaxPool2d-10 [-1, 128, 56, 56] 0\n",
+ " Conv2d-11 [-1, 256, 56, 56] 295,168\n",
+ " ReLU-12 [-1, 256, 56, 56] 0\n",
+ " Conv2d-13 [-1, 256, 56, 56] 590,080\n",
+ " ReLU-14 [-1, 256, 56, 56] 0\n",
+ " Conv2d-15 [-1, 256, 56, 56] 590,080\n",
+ " ReLU-16 [-1, 256, 56, 56] 0\n",
+ " MaxPool2d-17 [-1, 256, 28, 28] 0\n",
+ " Conv2d-18 [-1, 512, 28, 28] 1,180,160\n",
+ " ReLU-19 [-1, 512, 28, 28] 0\n",
+ " Conv2d-20 [-1, 512, 28, 28] 2,359,808\n",
+ " ReLU-21 [-1, 512, 28, 28] 0\n",
+ " Conv2d-22 [-1, 512, 28, 28] 2,359,808\n",
+ " ReLU-23 [-1, 512, 28, 28] 0\n",
+ " MaxPool2d-24 [-1, 512, 14, 14] 0\n",
+ " Conv2d-25 [-1, 512, 14, 14] 2,359,808\n",
+ " ReLU-26 [-1, 512, 14, 14] 0\n",
+ " Conv2d-27 [-1, 512, 14, 14] 2,359,808\n",
+ " ReLU-28 [-1, 512, 14, 14] 0\n",
+ " Conv2d-29 [-1, 512, 14, 14] 2,359,808\n",
+ " ReLU-30 [-1, 512, 14, 14] 0\n",
+ " MaxPool2d-31 [-1, 512, 7, 7] 0\n",
+ "AdaptiveAvgPool2d-32 [-1, 512, 7, 7] 0\n",
+ " Linear-33 [-1, 4096] 102,764,544\n",
+ " ReLU-34 [-1, 4096] 0\n",
+ " Dropout-35 [-1, 4096] 0\n",
+ " Linear-36 [-1, 4096] 16,781,312\n",
+ " ReLU-37 [-1, 4096] 0\n",
+ " Dropout-38 [-1, 4096] 0\n",
+ " Linear-39 [-1, 1000] 4,097,000\n",
+ "================================================================\n",
+ "Total params: 138,357,544\n",
+ "Trainable params: 138,357,544\n",
+ "Non-trainable params: 0\n",
+ "----------------------------------------------------------------\n",
+ "Input size (MB): 0.57\n",
+ "Forward/backward pass size (MB): 218.78\n",
+ "Params size (MB): 527.79\n",
+ "Estimated Total Size (MB): 747.15\n",
+ "----------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "summary(vgg, input_size=(3, 224, 224), device='cpu')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "88f4b5b0",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:05:36.359811Z",
+ "start_time": "2023-02-21T16:05:36.355617Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "VGG(\n",
+ " (features): Sequential(\n",
+ " (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (3): ReLU(inplace=True)\n",
+ " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+ " (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (6): ReLU(inplace=True)\n",
+ " (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (8): ReLU(inplace=True)\n",
+ " (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+ " (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (11): ReLU(inplace=True)\n",
+ " (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (13): ReLU(inplace=True)\n",
+ " (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (15): ReLU(inplace=True)\n",
+ " (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+ " (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (18): ReLU(inplace=True)\n",
+ " (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (20): ReLU(inplace=True)\n",
+ " (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (22): ReLU(inplace=True)\n",
+ " (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+ " (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (25): ReLU(inplace=True)\n",
+ " (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (27): ReLU(inplace=True)\n",
+ " (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (29): ReLU(inplace=True)\n",
+ " (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+ " )\n",
+ " (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))\n",
+ " (classifier): Sequential(\n",
+ " (0): Linear(in_features=25088, out_features=4096, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.5, inplace=False)\n",
+ " (3): Linear(in_features=4096, out_features=4096, bias=True)\n",
+ " (4): ReLU(inplace=True)\n",
+ " (5): Dropout(p=0.5, inplace=False)\n",
+ " (6): Linear(in_features=4096, out_features=10, bias=True)\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Replace the 1000-way ImageNet head with a 10-way head for CIFAR-10.\n",
+ "# The new Linear layer is randomly initialized and will be learned during fine-tuning.\n",
+ "in_features = vgg.classifier[6].in_features\n",
+ "vgg.classifier[6] = nn.Linear(in_features, 10)\n",
+ "vgg"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "60fb0bd4",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:05:57.728656Z",
+ "start_time": "2023-02-21T16:05:57.577214Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "----------------------------------------------------------------\n",
+ " Layer (type) Output Shape Param #\n",
+ "================================================================\n",
+ " Conv2d-1 [-1, 64, 224, 224] 1,792\n",
+ " ReLU-2 [-1, 64, 224, 224] 0\n",
+ " Conv2d-3 [-1, 64, 224, 224] 36,928\n",
+ " ReLU-4 [-1, 64, 224, 224] 0\n",
+ " MaxPool2d-5 [-1, 64, 112, 112] 0\n",
+ " Conv2d-6 [-1, 128, 112, 112] 73,856\n",
+ " ReLU-7 [-1, 128, 112, 112] 0\n",
+ " Conv2d-8 [-1, 128, 112, 112] 147,584\n",
+ " ReLU-9 [-1, 128, 112, 112] 0\n",
+ " MaxPool2d-10 [-1, 128, 56, 56] 0\n",
+ " Conv2d-11 [-1, 256, 56, 56] 295,168\n",
+ " ReLU-12 [-1, 256, 56, 56] 0\n",
+ " Conv2d-13 [-1, 256, 56, 56] 590,080\n",
+ " ReLU-14 [-1, 256, 56, 56] 0\n",
+ " Conv2d-15 [-1, 256, 56, 56] 590,080\n",
+ " ReLU-16 [-1, 256, 56, 56] 0\n",
+ " MaxPool2d-17 [-1, 256, 28, 28] 0\n",
+ " Conv2d-18 [-1, 512, 28, 28] 1,180,160\n",
+ " ReLU-19 [-1, 512, 28, 28] 0\n",
+ " Conv2d-20 [-1, 512, 28, 28] 2,359,808\n",
+ " ReLU-21 [-1, 512, 28, 28] 0\n",
+ " Conv2d-22 [-1, 512, 28, 28] 2,359,808\n",
+ " ReLU-23 [-1, 512, 28, 28] 0\n",
+ " MaxPool2d-24 [-1, 512, 14, 14] 0\n",
+ " Conv2d-25 [-1, 512, 14, 14] 2,359,808\n",
+ " ReLU-26 [-1, 512, 14, 14] 0\n",
+ " Conv2d-27 [-1, 512, 14, 14] 2,359,808\n",
+ " ReLU-28 [-1, 512, 14, 14] 0\n",
+ " Conv2d-29 [-1, 512, 14, 14] 2,359,808\n",
+ " ReLU-30 [-1, 512, 14, 14] 0\n",
+ " MaxPool2d-31 [-1, 512, 7, 7] 0\n",
+ "AdaptiveAvgPool2d-32 [-1, 512, 7, 7] 0\n",
+ " Linear-33 [-1, 4096] 102,764,544\n",
+ " ReLU-34 [-1, 4096] 0\n",
+ " Dropout-35 [-1, 4096] 0\n",
+ " Linear-36 [-1, 4096] 16,781,312\n",
+ " ReLU-37 [-1, 4096] 0\n",
+ " Dropout-38 [-1, 4096] 0\n",
+ " Linear-39 [-1, 10] 40,970\n",
+ "================================================================\n",
+ "Total params: 134,301,514\n",
+ "Trainable params: 134,301,514\n",
+ "Non-trainable params: 0\n",
+ "----------------------------------------------------------------\n",
+ "Input size (MB): 0.57\n",
+ "Forward/backward pass size (MB): 218.77\n",
+ "Params size (MB): 512.32\n",
+ "Estimated Total Size (MB): 731.67\n",
+ "----------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "summary(vgg, input_size=(3, 224, 224), device='cpu')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f7333d83",
+ "metadata": {},
+ "source": [
+ "## parameters"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "8d31ddee",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:06:28.035219Z",
+ "start_time": "2023-02-21T16:06:28.032859Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# dataset: CIFAR-10 has 10 classes (native resolution 32x32)\n",
+ "# input_shape = 32\n",
+ "num_classes = 10\n",
+ "\n",
+ "# hyper-parameters\n",
+ "batch_size = 64\n",
+ "num_epochs = 5\n",
+ "learning_rate = 1e-3\n",
+ "\n",
+ "# gpu: prefer CUDA when available, otherwise fall back to CPU\n",
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "72dcd7cf",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:06:30.620769Z",
+ "start_time": "2023-02-21T16:06:30.618149Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "device(type='cuda')"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "device"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "856f2db1",
+ "metadata": {},
+ "source": [
+ "## dataset and dataloader"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "421d5a1d",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:07:08.617211Z",
+ "start_time": "2023-02-21T16:07:07.313136Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Files already downloaded and verified\n",
+ "Files already downloaded and verified\n"
+ ]
+ }
+ ],
+ "source": [
+ "# First pass: load CIFAR-10 with only ToTensor, so per-channel mean/std\n",
+ "# can be computed below on the raw [0, 1] pixel values (no normalization yet).\n",
+ "train_dataset = datasets.CIFAR10(root='../data/', \n",
+ "                                 download=True, \n",
+ "                                 train=True, \n",
+ "                                 transform=transforms.ToTensor())\n",
+ "test_dataset = datasets.CIFAR10(root='../data/', \n",
+ "                                 download=True, \n",
+ "                                 train=False, \n",
+ "                                 transform=transforms.ToTensor())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "1da83915",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:08:15.650602Z",
+ "start_time": "2023-02-21T16:07:37.819122Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "==> Computing mean and std..\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 50000/50000 [00:37<00:00, 1321.94it/s]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "(tensor([0.4914, 0.4822, 0.4465]), tensor([0.2023, 0.1994, 0.2010]))"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "get_mean_and_std(train_dataset)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "f2379f10",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:09:30.085725Z",
+ "start_time": "2023-02-21T16:09:30.083236Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Final input pipeline: upscale 32x32 -> 224x224 (the input size VGG was\n",
+ "# trained with), then normalize with the CIFAR-10 statistics computed above.\n",
+ "transform = transforms.Compose([\n",
+ "    transforms.Resize(size=(224, 224)),\n",
+ "    transforms.ToTensor(),\n",
+ "    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))\n",
+ "])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "00a568cd",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:09:51.421012Z",
+ "start_time": "2023-02-21T16:09:50.112455Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Files already downloaded and verified\n",
+ "Files already downloaded and verified\n"
+ ]
+ }
+ ],
+ "source": [
+ "train_dataset = datasets.CIFAR10(root='../data/', \n",
+ " download=True, \n",
+ " train=True, \n",
+ " transform=transform)\n",
+ "test_dataset = datasets.CIFAR10(root='../data/', \n",
+ " download=True, \n",
+ " train=False, \n",
+ " transform=transform)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "c656663b",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:09:53.259344Z",
+ "start_time": "2023-02-21T16:09:53.256872Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Shuffle only the training split; keep the test order fixed for reproducibility.\n",
+ "train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, \n",
+ "                                               shuffle=True, \n",
+ "                                               batch_size=batch_size)\n",
+ "test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, \n",
+ "                                               shuffle=False, \n",
+ "                                               batch_size=batch_size)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "f77ea73d",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:09:55.520559Z",
+ "start_time": "2023-02-21T16:09:55.462862Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "images, labels = next(iter(train_dataloader))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "d03b9505",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:09:56.669215Z",
+ "start_time": "2023-02-21T16:09:56.666615Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([64, 3, 224, 224])"
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# batch_size, channels, h, w\n",
+ "images.shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "20b8c4b4",
+ "metadata": {},
+ "source": [
+ "## model arch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "d69e7f3e",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:10:48.937581Z",
+ "start_time": "2023-02-21T16:10:44.764451Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Rebuild the fine-tuning model from scratch: pretrained VGG-16 backbone\n",
+ "# with a fresh 10-way classification head, moved to the selected device.\n",
+ "vgg = models.vgg16(pretrained=True)\n",
+ "in_features = vgg.classifier[6].in_features\n",
+ "vgg.classifier[6] = nn.Linear(in_features, 10)\n",
+ "vgg = vgg.to(device)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "872d454c",
+ "metadata": {},
+ "source": [
+ "## model train/fine-tune"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "3e0babe0",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:13:24.682178Z",
+ "start_time": "2023-02-21T16:13:24.679390Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Cross-entropy loss over the 10 classes (expects raw logits).\n",
+ "criterion = nn.CrossEntropyLoss()\n",
+ "# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
+ "# SGD with momentum and weight decay for fine-tuning all parameters.\n",
+ "optimizer = torch.optim.SGD(vgg.parameters(), lr = learning_rate, momentum=0.9,weight_decay=5e-4)\n",
+ "total_batch = len(train_dataloader)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "636c01fb",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:20:46.902487Z",
+ "start_time": "2023-02-21T16:13:25.728051Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2023-02-22 00:13:56.818252, 1/5, 100/782: 0.4080, acc: 0.859375\n",
+ "2023-02-22 00:14:28.367012, 1/5, 200/782: 0.4214, acc: 0.90625\n",
+ "2023-02-22 00:15:00.193995, 1/5, 300/782: 0.5746, acc: 0.828125\n",
+ "2023-02-22 00:15:32.204011, 1/5, 400/782: 0.3419, acc: 0.875\n",
+ "2023-02-22 00:16:04.007379, 1/5, 500/782: 0.3597, acc: 0.859375\n",
+ "2023-02-22 00:16:35.717492, 1/5, 600/782: 0.4721, acc: 0.796875\n",
+ "2023-02-22 00:17:07.485007, 1/5, 700/782: 0.1224, acc: 0.953125\n",
+ "2023-02-22 00:18:05.079364, 2/5, 100/782: 0.1760, acc: 0.921875\n",
+ "2023-02-22 00:18:36.859823, 2/5, 200/782: 0.1521, acc: 0.921875\n",
+ "2023-02-22 00:19:08.638372, 2/5, 300/782: 0.2511, acc: 0.90625\n",
+ "2023-02-22 00:19:40.422421, 2/5, 400/782: 0.1399, acc: 0.9375\n",
+ "2023-02-22 00:20:12.218806, 2/5, 500/782: 0.1915, acc: 0.90625\n",
+ "2023-02-22 00:20:44.003433, 2/5, 600/782: 0.1382, acc: 0.9375\n"
+ ]
+ },
+ {
+ "ename": "KeyboardInterrupt",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m<ipython-input-25-9a44dc355816>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;31m# 标准的处理,用 validate data;这个过程是监督训练过程\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0mn_corrects\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0macc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mn_corrects\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
+ ]
+ }
+ ],
+ "source": [
+ "# Fine-tune the whole network on CIFAR-10.\n",
+ "vgg.train()  # ensure dropout is active (matters if this cell is re-run after evaluation)\n",
+ "for epoch in range(num_epochs):\n",
+ "    for batch_idx, (images, labels) in enumerate(train_dataloader):\n",
+ "        images = images.to(device)\n",
+ "        labels = labels.to(device)\n",
+ "        \n",
+ "        # forward\n",
+ "        out = vgg(images)\n",
+ "        loss = criterion(out, labels)\n",
+ "        \n",
+ "        # Batch accuracy on the training data; the standard practice is to\n",
+ "        # monitor a held-out validation set instead (e.g. for early stopping).\n",
+ "        n_corrects = (out.argmax(axis=1) == labels).sum().item()\n",
+ "        acc = n_corrects/labels.size(0)\n",
+ "        \n",
+ "        # backward\n",
+ "        optimizer.zero_grad()\n",
+ "        loss.backward()\n",
+ "        optimizer.step() # update model parameters\n",
+ "        \n",
+ "        if (batch_idx+1) % 100 == 0:\n",
+ "            print(f'{datetime.now()}, {epoch+1}/{num_epochs}, {batch_idx+1}/{total_batch}: {loss.item():.4f}, acc: {acc}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5e7078dd",
+ "metadata": {},
+ "source": [
+ "## model evaluation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "ef80d16f",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:21:06.582989Z",
+ "start_time": "2023-02-21T16:20:55.012423Z"
+ }
+ },
+ "outputs": [
+ {
+ "ename": "KeyboardInterrupt",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m<ipython-input-27-a604fa16a1e8>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mtotal\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mcorrect\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mimages\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtest_dataloader\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mimages\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimages\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 519\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sampler_iter\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 520\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 521\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 522\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 523\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 559\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 560\u001b[0m \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 561\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 562\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 563\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/torchvision/datasets/cifar.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 121\u001b[0;31m \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 122\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtarget_transform\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/torchvision/transforms/transforms.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, img)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransforms\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1095\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1096\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1097\u001b[0;31m \u001b[0mforward_call\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_tracing_state\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1098\u001b[0m \u001b[0;31m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1099\u001b[0m \u001b[0;31m# this function, and just call forward.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
+ ]
+ }
+ ],
+ "source": [
+ "# Evaluate test-set accuracy.\n",
+ "vgg.eval()  # switch Dropout (and any BatchNorm) layers to inference mode — otherwise accuracy is wrong\n",
+ "total = 0\n",
+ "correct = 0\n",
+ "with torch.no_grad():  # no gradients needed for evaluation; avoids building autograd graphs and saves GPU memory\n",
+ "    for images, labels in test_dataloader:\n",
+ "        images = images.to(device)\n",
+ "        labels = labels.to(device)\n",
+ "        out = vgg(images)\n",
+ "        preds = torch.argmax(out, dim=1)\n",
+ "\n",
+ "        total += images.size(0)\n",
+ "        correct += (preds == labels).sum().item()\n",
+ "print(f'{correct}/{total}={correct/total}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "96272f61",
+ "metadata": {},
+ "source": [
+ "## Model save"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2e68f24f",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-21T16:21:06.587517Z",
+ "start_time": "2023-02-21T16:21:03.590Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Persist the trained weights (state_dict only, not the full module).\n",
+ "# NOTE(review): filename says 'cnn' but this notebook trains VGG — consider 'vgg_cifar.ckpt';\n",
+ "# kept as-is in case other notebooks load this exact path.\n",
+ "torch.save(vgg.state_dict(), 'cnn_cifar.ckpt')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.13"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {},
+ "toc_section_display": true,
+ "toc_window_display": true
+ },
+ "varInspector": {
+ "cols": {
+ "lenName": 16,
+ "lenType": 16,
+ "lenVar": 40
+ },
+ "kernels_config": {
+ "python": {
+ "delete_cmd_postfix": "",
+ "delete_cmd_prefix": "del ",
+ "library": "var_list.py",
+ "varRefreshCmd": "print(var_dic_list())"
+ },
+ "r": {
+ "delete_cmd_postfix": ") ",
+ "delete_cmd_prefix": "rm(",
+ "library": "var_list.r",
+ "varRefreshCmd": "cat(var_dic_list()) "
+ }
+ },
+ "types_to_exclude": [
+ "module",
+ "function",
+ "builtin_function_or_method",
+ "instance",
+ "_Feature"
+ ],
+ "window_display": false
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}