{ "cells": [ { "cell_type": "code", "execution_count": 13, "id": "36600735", "metadata": { "ExecuteTime": { "end_time": "2023-02-21T16:07:27.625977Z", "start_time": "2023-02-21T16:07:27.623719Z" } }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "import torchvision\n", "from torchvision import models\n", "from torchvision import datasets, transforms\n", "from datetime import datetime\n", "from utils import get_mean_and_std" ] }, { "cell_type": "code", "execution_count": 2, "id": "daeda546", "metadata": { "ExecuteTime": { "end_time": "2023-02-21T16:01:51.590225Z", "start_time": "2023-02-21T16:01:51.562576Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.10.2\n", "0.11.3\n", "True\n", "NVIDIA A10\n" ] } ], "source": [ "print(torch.__version__)\n", "print(torchvision.__version__)\n", "print(torch.cuda.is_available())\n", "print(torch.cuda.get_device_name())" ] }, { "cell_type": "markdown", "id": "df51273a", "metadata": {}, "source": [ "## vgg" ] }, { "cell_type": "code", "execution_count": 3, "id": "15ffacfc", "metadata": { "ExecuteTime": { "end_time": "2023-02-21T16:02:34.019695Z", "start_time": "2023-02-21T16:02:33.088531Z" } }, "outputs": [], "source": [ "vgg = models.vgg16(pretrained=True)" ] }, { "cell_type": "code", "execution_count": 4, "id": "7291eaa0", "metadata": { "ExecuteTime": { "end_time": "2023-02-21T16:02:38.446266Z", "start_time": "2023-02-21T16:02:38.440109Z" } }, "outputs": [ { "data": { "text/plain": [ "VGG(\n", " (features): Sequential(\n", " (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (3): ReLU(inplace=True)\n", " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (6): ReLU(inplace=True)\n", " (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 
1), padding=(1, 1))\n", " (8): ReLU(inplace=True)\n", " (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (11): ReLU(inplace=True)\n", " (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (13): ReLU(inplace=True)\n", " (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (15): ReLU(inplace=True)\n", " (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (18): ReLU(inplace=True)\n", " (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (20): ReLU(inplace=True)\n", " (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (22): ReLU(inplace=True)\n", " (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (25): ReLU(inplace=True)\n", " (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (27): ReLU(inplace=True)\n", " (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (29): ReLU(inplace=True)\n", " (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " )\n", " (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))\n", " (classifier): Sequential(\n", " (0): Linear(in_features=25088, out_features=4096, bias=True)\n", " (1): ReLU(inplace=True)\n", " (2): Dropout(p=0.5, inplace=False)\n", " (3): Linear(in_features=4096, out_features=4096, bias=True)\n", " (4): ReLU(inplace=True)\n", " (5): Dropout(p=0.5, inplace=False)\n", " (6): Linear(in_features=4096, out_features=1000, bias=True)\n", " )\n", ")" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vgg" ] }, { "cell_type": "code", "execution_count": 5, 
"id": "5353f3f1", "metadata": { "ExecuteTime": { "end_time": "2023-02-21T16:03:11.962435Z", "start_time": "2023-02-21T16:03:11.960524Z" } }, "outputs": [], "source": [ "from torchsummary import summary" ] }, { "cell_type": "code", "execution_count": 6, "id": "b2120b07", "metadata": { "ExecuteTime": { "end_time": "2023-02-21T16:03:30.107291Z", "start_time": "2023-02-21T16:03:29.845682Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------------------------------------\n", " Layer (type) Output Shape Param #\n", "================================================================\n", " Conv2d-1 [-1, 64, 224, 224] 1,792\n", " ReLU-2 [-1, 64, 224, 224] 0\n", " Conv2d-3 [-1, 64, 224, 224] 36,928\n", " ReLU-4 [-1, 64, 224, 224] 0\n", " MaxPool2d-5 [-1, 64, 112, 112] 0\n", " Conv2d-6 [-1, 128, 112, 112] 73,856\n", " ReLU-7 [-1, 128, 112, 112] 0\n", " Conv2d-8 [-1, 128, 112, 112] 147,584\n", " ReLU-9 [-1, 128, 112, 112] 0\n", " MaxPool2d-10 [-1, 128, 56, 56] 0\n", " Conv2d-11 [-1, 256, 56, 56] 295,168\n", " ReLU-12 [-1, 256, 56, 56] 0\n", " Conv2d-13 [-1, 256, 56, 56] 590,080\n", " ReLU-14 [-1, 256, 56, 56] 0\n", " Conv2d-15 [-1, 256, 56, 56] 590,080\n", " ReLU-16 [-1, 256, 56, 56] 0\n", " MaxPool2d-17 [-1, 256, 28, 28] 0\n", " Conv2d-18 [-1, 512, 28, 28] 1,180,160\n", " ReLU-19 [-1, 512, 28, 28] 0\n", " Conv2d-20 [-1, 512, 28, 28] 2,359,808\n", " ReLU-21 [-1, 512, 28, 28] 0\n", " Conv2d-22 [-1, 512, 28, 28] 2,359,808\n", " ReLU-23 [-1, 512, 28, 28] 0\n", " MaxPool2d-24 [-1, 512, 14, 14] 0\n", " Conv2d-25 [-1, 512, 14, 14] 2,359,808\n", " ReLU-26 [-1, 512, 14, 14] 0\n", " Conv2d-27 [-1, 512, 14, 14] 2,359,808\n", " ReLU-28 [-1, 512, 14, 14] 0\n", " Conv2d-29 [-1, 512, 14, 14] 2,359,808\n", " ReLU-30 [-1, 512, 14, 14] 0\n", " MaxPool2d-31 [-1, 512, 7, 7] 0\n", "AdaptiveAvgPool2d-32 [-1, 512, 7, 7] 0\n", " Linear-33 [-1, 4096] 102,764,544\n", " ReLU-34 [-1, 4096] 0\n", " Dropout-35 [-1, 4096] 0\n", " Linear-36 [-1, 4096] 
16,781,312\n", " ReLU-37 [-1, 4096] 0\n", " Dropout-38 [-1, 4096] 0\n", " Linear-39 [-1, 1000] 4,097,000\n", "================================================================\n", "Total params: 138,357,544\n", "Trainable params: 138,357,544\n", "Non-trainable params: 0\n", "----------------------------------------------------------------\n", "Input size (MB): 0.57\n", "Forward/backward pass size (MB): 218.78\n", "Params size (MB): 527.79\n", "Estimated Total Size (MB): 747.15\n", "----------------------------------------------------------------\n" ] } ], "source": [ "summary(vgg, input_size=(3, 224, 224), device='cpu')" ] }, { "cell_type": "code", "execution_count": 8, "id": "88f4b5b0", "metadata": { "ExecuteTime": { "end_time": "2023-02-21T16:05:36.359811Z", "start_time": "2023-02-21T16:05:36.355617Z" } }, "outputs": [ { "data": { "text/plain": [ "VGG(\n", " (features): Sequential(\n", " (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (3): ReLU(inplace=True)\n", " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (6): ReLU(inplace=True)\n", " (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (8): ReLU(inplace=True)\n", " (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (11): ReLU(inplace=True)\n", " (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (13): ReLU(inplace=True)\n", " (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (15): ReLU(inplace=True)\n", " (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 
1))\n", " (18): ReLU(inplace=True)\n", " (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (20): ReLU(inplace=True)\n", " (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (22): ReLU(inplace=True)\n", " (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (25): ReLU(inplace=True)\n", " (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (27): ReLU(inplace=True)\n", " (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (29): ReLU(inplace=True)\n", " (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " )\n", " (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))\n", " (classifier): Sequential(\n", " (0): Linear(in_features=25088, out_features=4096, bias=True)\n", " (1): ReLU(inplace=True)\n", " (2): Dropout(p=0.5, inplace=False)\n", " (3): Linear(in_features=4096, out_features=4096, bias=True)\n", " (4): ReLU(inplace=True)\n", " (5): Dropout(p=0.5, inplace=False)\n", " (6): Linear(in_features=4096, out_features=10, bias=True)\n", " )\n", ")" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "in_features = vgg.classifier[6].in_features\n", "vgg.classifier[6] = nn.Linear(in_features, 10)\n", "vgg" ] }, { "cell_type": "code", "execution_count": 9, "id": "60fb0bd4", "metadata": { "ExecuteTime": { "end_time": "2023-02-21T16:05:57.728656Z", "start_time": "2023-02-21T16:05:57.577214Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------------------------------------\n", " Layer (type) Output Shape Param #\n", "================================================================\n", " Conv2d-1 [-1, 64, 224, 224] 1,792\n", " ReLU-2 [-1, 64, 224, 224] 0\n", " Conv2d-3 [-1, 64, 224, 224] 36,928\n", " ReLU-4 [-1, 64, 224, 224] 
0\n", " MaxPool2d-5 [-1, 64, 112, 112] 0\n", " Conv2d-6 [-1, 128, 112, 112] 73,856\n", " ReLU-7 [-1, 128, 112, 112] 0\n", " Conv2d-8 [-1, 128, 112, 112] 147,584\n", " ReLU-9 [-1, 128, 112, 112] 0\n", " MaxPool2d-10 [-1, 128, 56, 56] 0\n", " Conv2d-11 [-1, 256, 56, 56] 295,168\n", " ReLU-12 [-1, 256, 56, 56] 0\n", " Conv2d-13 [-1, 256, 56, 56] 590,080\n", " ReLU-14 [-1, 256, 56, 56] 0\n", " Conv2d-15 [-1, 256, 56, 56] 590,080\n", " ReLU-16 [-1, 256, 56, 56] 0\n", " MaxPool2d-17 [-1, 256, 28, 28] 0\n", " Conv2d-18 [-1, 512, 28, 28] 1,180,160\n", " ReLU-19 [-1, 512, 28, 28] 0\n", " Conv2d-20 [-1, 512, 28, 28] 2,359,808\n", " ReLU-21 [-1, 512, 28, 28] 0\n", " Conv2d-22 [-1, 512, 28, 28] 2,359,808\n", " ReLU-23 [-1, 512, 28, 28] 0\n", " MaxPool2d-24 [-1, 512, 14, 14] 0\n", " Conv2d-25 [-1, 512, 14, 14] 2,359,808\n", " ReLU-26 [-1, 512, 14, 14] 0\n", " Conv2d-27 [-1, 512, 14, 14] 2,359,808\n", " ReLU-28 [-1, 512, 14, 14] 0\n", " Conv2d-29 [-1, 512, 14, 14] 2,359,808\n", " ReLU-30 [-1, 512, 14, 14] 0\n", " MaxPool2d-31 [-1, 512, 7, 7] 0\n", "AdaptiveAvgPool2d-32 [-1, 512, 7, 7] 0\n", " Linear-33 [-1, 4096] 102,764,544\n", " ReLU-34 [-1, 4096] 0\n", " Dropout-35 [-1, 4096] 0\n", " Linear-36 [-1, 4096] 16,781,312\n", " ReLU-37 [-1, 4096] 0\n", " Dropout-38 [-1, 4096] 0\n", " Linear-39 [-1, 10] 40,970\n", "================================================================\n", "Total params: 134,301,514\n", "Trainable params: 134,301,514\n", "Non-trainable params: 0\n", "----------------------------------------------------------------\n", "Input size (MB): 0.57\n", "Forward/backward pass size (MB): 218.77\n", "Params size (MB): 512.32\n", "Estimated Total Size (MB): 731.67\n", "----------------------------------------------------------------\n" ] } ], "source": [ "summary(vgg, input_size=(3, 224, 224), device='cpu')" ] }, { "cell_type": "markdown", "id": "f7333d83", "metadata": {}, "source": [ "## parameters" ] }, { "cell_type": "code", "execution_count": 10, "id": "8d31ddee", 
"metadata": { "ExecuteTime": { "end_time": "2023-02-21T16:06:28.035219Z", "start_time": "2023-02-21T16:06:28.032859Z" } }, "outputs": [], "source": [ "# dataset\n", "# input_shape = 32\n", "num_classes = 10\n", "\n", "# hyper \n", "batch_size = 64\n", "num_epochs = 5\n", "learning_rate = 1e-3\n", "\n", "# gpu\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" ] }, { "cell_type": "code", "execution_count": 11, "id": "72dcd7cf", "metadata": { "ExecuteTime": { "end_time": "2023-02-21T16:06:30.620769Z", "start_time": "2023-02-21T16:06:30.618149Z" } }, "outputs": [ { "data": { "text/plain": [ "device(type='cuda')" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device" ] }, { "cell_type": "markdown", "id": "856f2db1", "metadata": {}, "source": [ "## dataset 与 dataloader" ] }, { "cell_type": "code", "execution_count": 12, "id": "421d5a1d", "metadata": { "ExecuteTime": { "end_time": "2023-02-21T16:07:08.617211Z", "start_time": "2023-02-21T16:07:07.313136Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "Files already downloaded and verified\n" ] } ], "source": [ "train_dataset = datasets.CIFAR10(root='../data/', \n", " download=True, \n", " train=True, \n", " transform=transforms.ToTensor())\n", "test_dataset = datasets.CIFAR10(root='../data/', \n", " download=True, \n", " train=False, \n", " transform=transforms.ToTensor())" ] }, { "cell_type": "code", "execution_count": 14, "id": "1da83915", "metadata": { "ExecuteTime": { "end_time": "2023-02-21T16:08:15.650602Z", "start_time": "2023-02-21T16:07:37.819122Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "==> Computing mean and std..\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50000/50000 [00:37<00:00, 1321.94it/s]\n" ] }, { "data": { "text/plain": [ "(tensor([0.4914, 0.4822, 0.4465]), tensor([0.2023, 0.1994, 0.2010]))" ] 
}, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_mean_and_std(train_dataset)" ] }, { "cell_type": "code", "execution_count": 16, "id": "f2379f10", "metadata": { "ExecuteTime": { "end_time": "2023-02-21T16:09:30.085725Z", "start_time": "2023-02-21T16:09:30.083236Z" } }, "outputs": [], "source": [ "transform = transforms.Compose([\n", " transforms.Resize(size=(224, 224)),\n", " transforms.ToTensor(),\n", " transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))\n", "])" ] }, { "cell_type": "code", "execution_count": 17, "id": "00a568cd", "metadata": { "ExecuteTime": { "end_time": "2023-02-21T16:09:51.421012Z", "start_time": "2023-02-21T16:09:50.112455Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "Files already downloaded and verified\n" ] } ], "source": [ "train_dataset = datasets.CIFAR10(root='../data/', \n", " download=True, \n", " train=True, \n", " transform=transform)\n", "test_dataset = datasets.CIFAR10(root='../data/', \n", " download=True, \n", " train=False, \n", " transform=transform)" ] }, { "cell_type": "code", "execution_count": 18, "id": "c656663b", "metadata": { "ExecuteTime": { "end_time": "2023-02-21T16:09:53.259344Z", "start_time": "2023-02-21T16:09:53.256872Z" } }, "outputs": [], "source": [ "train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, \n", " shuffle=True, \n", " batch_size=batch_size)\n", "test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, \n", " shuffle=False, \n", " batch_size=batch_size)" ] }, { "cell_type": "code", "execution_count": 19, "id": "f77ea73d", "metadata": { "ExecuteTime": { "end_time": "2023-02-21T16:09:55.520559Z", "start_time": "2023-02-21T16:09:55.462862Z" } }, "outputs": [], "source": [ "images, labels = next(iter(train_dataloader))" ] }, { "cell_type": "code", "execution_count": 20, "id": "d03b9505", "metadata": { "ExecuteTime": { 
"end_time": "2023-02-21T16:09:56.669215Z", "start_time": "2023-02-21T16:09:56.666615Z" } }, "outputs": [ { "data": { "text/plain": [ "torch.Size([64, 3, 224, 224])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# batch_size, channels, h, w\n", "images.shape" ] }, { "cell_type": "markdown", "id": "20b8c4b4", "metadata": {}, "source": [ "## model arch" ] }, { "cell_type": "code", "execution_count": 21, "id": "d69e7f3e", "metadata": { "ExecuteTime": { "end_time": "2023-02-21T16:10:48.937581Z", "start_time": "2023-02-21T16:10:44.764451Z" } }, "outputs": [], "source": [ "vgg = models.vgg16(pretrained=True)\n", "in_features = vgg.classifier[6].in_features\n", "vgg.classifier[6] = nn.Linear(in_features, 10)\n", "vgg = vgg.to(device)" ] }, { "cell_type": "markdown", "id": "872d454c", "metadata": {}, "source": [ "## model train/fine-tune" ] }, { "cell_type": "code", "execution_count": 24, "id": "3e0babe0", "metadata": { "ExecuteTime": { "end_time": "2023-02-21T16:13:24.682178Z", "start_time": "2023-02-21T16:13:24.679390Z" } }, "outputs": [], "source": [ "criterion = nn.CrossEntropyLoss()\n", "# optimizer = torch.optim.Adam(vgg.parameters(), lr=learning_rate)\n", "optimizer = torch.optim.SGD(vgg.parameters(), lr = learning_rate, momentum=0.9,weight_decay=5e-4)\n", "total_batch = len(train_dataloader)" ] }, { "cell_type": "code", "execution_count": 25, "id": "636c01fb", "metadata": { "ExecuteTime": { "end_time": "2023-02-21T16:20:46.902487Z", "start_time": "2023-02-21T16:13:25.728051Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2023-02-22 00:13:56.818252, 1/5, 100/782: 0.4080, acc: 0.859375\n", "2023-02-22 00:14:28.367012, 1/5, 200/782: 0.4214, acc: 0.90625\n", "2023-02-22 00:15:00.193995, 1/5, 300/782: 0.5746, acc: 0.828125\n", "2023-02-22 00:15:32.204011, 1/5, 400/782: 0.3419, acc: 0.875\n", "2023-02-22 00:16:04.007379, 1/5, 500/782: 0.3597, acc: 0.859375\n", "2023-02-22 00:16:35.717492, 1/5, 
600/782: 0.4721, acc: 0.796875\n", "2023-02-22 00:17:07.485007, 1/5, 700/782: 0.1224, acc: 0.953125\n", "2023-02-22 00:18:05.079364, 2/5, 100/782: 0.1760, acc: 0.921875\n", "2023-02-22 00:18:36.859823, 2/5, 200/782: 0.1521, acc: 0.921875\n", "2023-02-22 00:19:08.638372, 2/5, 300/782: 0.2511, acc: 0.90625\n", "2023-02-22 00:19:40.422421, 2/5, 400/782: 0.1399, acc: 0.9375\n", "2023-02-22 00:20:12.218806, 2/5, 500/782: 0.1915, acc: 0.90625\n", "2023-02-22 00:20:44.003433, 2/5, 600/782: 0.1382, acc: 0.9375\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;31m# 标准的处理,用 validate data;这个过程是监督训练过程\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0mn_corrects\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0macc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mn_corrects\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m 
\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "for epoch in range(num_epochs):\n", " for batch_idx, (images, labels) in enumerate(train_dataloader):\n", " images = images.to(device)\n", " labels = labels.to(device)\n", " \n", " # forward\n", " out = vgg(images)\n", " loss = criterion(out, labels)\n", " \n", " # standard practice would use validation data here; this supervised monitoring is used for early stopping\n", " n_corrects = (out.argmax(axis=1) == labels).sum().item()\n", " acc = n_corrects/labels.size(0)\n", " \n", " # backward\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step() # update model parameters\n", " \n", " if (batch_idx+1) % 100 == 0:\n", " print(f'{datetime.now()}, {epoch+1}/{num_epochs}, {batch_idx+1}/{total_batch}: {loss.item():.4f}, acc: {acc}')" ] }, { "cell_type": "markdown", "id": "5e7078dd", "metadata": {}, "source": [ "## model evaluation" ] }, { "cell_type": "code", "execution_count": 27, "id": "ef80d16f", "metadata": { "ExecuteTime": { "end_time": "2023-02-21T16:21:06.582989Z", "start_time": "2023-02-21T16:20:55.012423Z" } }, "outputs": [ { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mtotal\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mcorrect\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mimages\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtest_dataloader\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m 
\u001b[0mimages\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimages\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 519\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sampler_iter\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 520\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 521\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 522\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 523\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 
"\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 559\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 560\u001b[0m \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 561\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 562\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 563\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/torchvision/datasets/cifar.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 121\u001b[0;31m \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 122\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtarget_transform\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/torchvision/transforms/transforms.py\u001b[0m in 
\u001b[0;36m__call__\u001b[0;34m(self, img)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransforms\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/envs/torch/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1095\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1096\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1097\u001b[0;31m \u001b[0mforward_call\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_tracing_state\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1098\u001b[0m \u001b[0;31m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1099\u001b[0m \u001b[0;31m# this function, and just call forward.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "total = 0\n", "correct = 0\n", "for images, labels in test_dataloader:\n", " images = images.to(device)\n", " labels = labels.to(device)\n", " out = vgg(images)\n", " preds = torch.argmax(out, dim=1)\n", " \n", " total += images.size(0)\n", " correct += (preds == labels).sum().item()\n", "print(f'{correct}/{total}={correct/total}')" ] }, { "cell_type": "markdown", "id": "96272f61", "metadata": {}, "source": [ "## model save" ] }, { "cell_type": "code", "execution_count": null, "id": "2e68f24f", "metadata": { "ExecuteTime": { "end_time": "2023-02-21T16:21:06.587517Z", "start_time": "2023-02-21T16:21:03.590Z" } }, "outputs": [], "source": [ "torch.save(vgg.state_dict(), 'cnn_cifar.ckpt')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": true }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { 
"delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }