summaryrefslogtreecommitdiff
path: root/dl/tutorials/03_cnn_on_cifar10.ipynb
diff options
context:
space:
mode:
authorchunhuizhang <zch921005@126.com>2023-02-13 22:33:25 +0800
committerchunhuizhang <zch921005@126.com>2023-02-13 22:33:25 +0800
commit13d8981e9b949046d7e457d7c654802a426eab94 (patch)
tree827b5aec2f44e4b3bf6250c9420d0e33ca0ae0a3 /dl/tutorials/03_cnn_on_cifar10.ipynb
parent67595034b5184ce3cfb5b286ae0da23aa736ade6 (diff)
cnn on cifar10
Diffstat (limited to 'dl/tutorials/03_cnn_on_cifar10.ipynb')
-rw-r--r--dl/tutorials/03_cnn_on_cifar10.ipynb673
1 files changed, 673 insertions, 0 deletions
diff --git a/dl/tutorials/03_cnn_on_cifar10.ipynb b/dl/tutorials/03_cnn_on_cifar10.ipynb
new file mode 100644
index 0000000..ca9bc2b
--- /dev/null
+++ b/dl/tutorials/03_cnn_on_cifar10.ipynb
@@ -0,0 +1,673 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "36600735",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-13T14:22:25.389365Z",
+ "start_time": "2023-02-13T14:22:25.385967Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from torch import nn\n",
+ "from torchvision import datasets, transforms\n",
+ "from datetime import datetime"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f7333d83",
+ "metadata": {},
+ "source": [
+ "## parameters"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "8d31ddee",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-13T14:16:31.908271Z",
+ "start_time": "2023-02-13T14:16:31.901670Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# dataset\n",
+ "input_shape = 32\n",
+ "num_classes = 10\n",
+ "\n",
+ "# hyper \n",
+ "batch_size = 64\n",
+ "num_epochs = 5\n",
+ "learning_rate = 1e-3\n",
+ "\n",
+ "# gpu\n",
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "72dcd7cf",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-13T14:15:42.457815Z",
+ "start_time": "2023-02-13T14:15:42.443734Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "device(type='cuda')"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "device"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "856f2db1",
+ "metadata": {},
+ "source": [
+ "## dataset 与 dataloader"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "421d5a1d",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-13T14:16:04.373037Z",
+ "start_time": "2023-02-13T14:16:02.822820Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Files already downloaded and verified\n",
+ "Files already downloaded and verified\n"
+ ]
+ }
+ ],
+ "source": [
+ "train_dataset = datasets.CIFAR10(root='../data/', \n",
+ " download=True, \n",
+ " train=True, \n",
+ " transform=transforms.ToTensor())\n",
+ "test_dataset = datasets.CIFAR10(root='../data/', \n",
+ " download=True, \n",
+ " train=False, \n",
+ " transform=transforms.ToTensor())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "c656663b",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-13T14:16:11.223316Z",
+ "start_time": "2023-02-13T14:16:11.218212Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, \n",
+ " shuffle=True, \n",
+ " batch_size=batch_size)\n",
+ "test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, \n",
+ " shuffle=False, \n",
+ " batch_size=batch_size)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "f77ea73d",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-13T14:16:15.712110Z",
+ "start_time": "2023-02-13T14:16:15.691181Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "images, labels = next(iter(train_dataloader))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "d03b9505",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-13T14:16:20.441587Z",
+ "start_time": "2023-02-13T14:16:20.434654Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([64, 3, 32, 32])"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# batch_size, channels, h, w\n",
+ "images.shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "20b8c4b4",
+ "metadata": {},
+ "source": [
+ "## model arch"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6a3182ef",
+ "metadata": {},
+ "source": [
+ "- cnn: channel 不断增加,shape 不断减少的过程\n",
+ " - 最好是 *2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "521ddf96",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-13T14:24:52.447992Z",
+ "start_time": "2023-02-13T14:24:52.432526Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "class CNN(nn.Module):\n",
+ " def __init__(self, input_shape, in_channels, num_classes):\n",
+ " super(CNN, self).__init__()\n",
+ " # conv2d: (b, 1, 28, 28) => (b, 16, 28, 28)\n",
+ " # maxpool2d: (b, 16, 28, 28) => (b, 16, 14, 14)\n",
+ " self.cnn1 = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=16, \n",
+ " kernel_size=5, padding=2, stride=1), \n",
+ " nn.BatchNorm2d(16), \n",
+ " nn.ReLU(), \n",
+ " nn.MaxPool2d(kernel_size=2, stride=2))\n",
+ " \n",
+ " # conv2d: (b, 16, 14, 14) => (b, 32, 14, 14)\n",
+ " # maxpool2d: (b, 32, 14, 14) => (b, 32, 7, 7)\n",
+ " self.cnn2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32, \n",
+ " kernel_size=5, padding=2, stride=1), \n",
+ " nn.BatchNorm2d(32), \n",
+ " nn.ReLU(), \n",
+ " nn.MaxPool2d(kernel_size=2, stride=2))\n",
+ " \n",
+ " self.cnn3 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=64, \n",
+ " kernel_size=5, padding=2, stride=1), \n",
+ " nn.BatchNorm2d(64), \n",
+ " nn.ReLU(), \n",
+ " nn.MaxPool2d(kernel_size=2, stride=2))\n",
+ " self.cnn4 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, \n",
+ " kernel_size=5, padding=2, stride=1), \n",
+ " nn.BatchNorm2d(128), \n",
+ " nn.ReLU(), \n",
+ " nn.MaxPool2d(kernel_size=2, stride=2))\n",
+ " self.cnn5 = nn.Sequential(nn.Conv2d(in_channels=128, out_channels=256, \n",
+ " kernel_size=5, padding=2, stride=1), \n",
+ " nn.BatchNorm2d(256), \n",
+ " nn.ReLU(), \n",
+ " nn.MaxPool2d(kernel_size=2, stride=2))\n",
+ " # (b, 32, 7, 7) => (b, 32*7*7)\n",
+ " # (b, 32*7*7) => (b, 10)\n",
+ " self.fc = nn.Linear(256*(input_shape//32)*(input_shape//32), num_classes)\n",
+ "\n",
+ " \n",
+ " def forward(self, x):\n",
+ " # (b, 1, 28, 28) => (b, 16, 14, 14)\n",
+ " out = self.cnn1(x)\n",
+ " # (b, 16, 14, 14) => (b, 32, 7, 7)\n",
+ " out = self.cnn2(out)\n",
+ " # (b, 32, 7, 7) => (b, 32*7*7)\n",
+ " out = self.cnn3(out)\n",
+ " out = self.cnn4(out)\n",
+ " out = self.cnn5(out)\n",
+ " out = out.reshape(out.size(0), -1)\n",
+ " out = self.fc(out)\n",
+ " return out\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6ab15298",
+ "metadata": {},
+ "source": [
+ "### torchsummary"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "4e0b41cd",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-12T07:57:32.036968Z",
+ "start_time": "2023-02-12T07:57:30.861860Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\r\n",
+ "Requirement already satisfied: torchsummary in /home/chzhang/anaconda3/envs/ldm/lib/python3.8/site-packages (1.5.1)\r\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install torchsummary"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "78aa128c",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-13T14:17:07.748715Z",
+ "start_time": "2023-02-13T14:17:07.742454Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from torchsummary import summary"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "306be169",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-13T14:24:57.381789Z",
+ "start_time": "2023-02-13T14:24:57.352963Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "model = CNN(input_shape=input_shape, num_classes=num_classes, in_channels=3).to(device)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "a4812e64",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-13T14:25:00.958163Z",
+ "start_time": "2023-02-13T14:25:00.937024Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "----------------------------------------------------------------\n",
+ " Layer (type) Output Shape Param #\n",
+ "================================================================\n",
+ " Conv2d-1 [64, 16, 32, 32] 1,216\n",
+ " BatchNorm2d-2 [64, 16, 32, 32] 32\n",
+ " ReLU-3 [64, 16, 32, 32] 0\n",
+ " MaxPool2d-4 [64, 16, 16, 16] 0\n",
+ " Conv2d-5 [64, 32, 16, 16] 12,832\n",
+ " BatchNorm2d-6 [64, 32, 16, 16] 64\n",
+ " ReLU-7 [64, 32, 16, 16] 0\n",
+ " MaxPool2d-8 [64, 32, 8, 8] 0\n",
+ " Conv2d-9 [64, 64, 8, 8] 51,264\n",
+ " BatchNorm2d-10 [64, 64, 8, 8] 128\n",
+ " ReLU-11 [64, 64, 8, 8] 0\n",
+ " MaxPool2d-12 [64, 64, 4, 4] 0\n",
+ " Conv2d-13 [64, 128, 4, 4] 204,928\n",
+ " BatchNorm2d-14 [64, 128, 4, 4] 256\n",
+ " ReLU-15 [64, 128, 4, 4] 0\n",
+ " MaxPool2d-16 [64, 128, 2, 2] 0\n",
+ " Conv2d-17 [64, 256, 2, 2] 819,456\n",
+ " BatchNorm2d-18 [64, 256, 2, 2] 512\n",
+ " ReLU-19 [64, 256, 2, 2] 0\n",
+ " MaxPool2d-20 [64, 256, 1, 1] 0\n",
+ " Linear-21 [64, 10] 2,570\n",
+ "================================================================\n",
+ "Total params: 1,093,258\n",
+ "Trainable params: 1,093,258\n",
+ "Non-trainable params: 0\n",
+ "----------------------------------------------------------------\n",
+ "Input size (MB): 0.75\n",
+ "Forward/backward pass size (MB): 50.38\n",
+ "Params size (MB): 4.17\n",
+ "Estimated Total Size (MB): 55.30\n",
+ "----------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "summary(model, input_size=(3, 32, 32), batch_size=batch_size)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "872d454c",
+ "metadata": {},
+ "source": [
+ "## model train"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "3e0babe0",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-13T14:25:08.929629Z",
+ "start_time": "2023-02-13T14:25:08.924328Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "criterion = nn.CrossEntropyLoss()\n",
+ "optimzer = torch.optim.Adam(model.parameters(), lr=learning_rate)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "bc2ad9e6",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-13T14:17:49.220427Z",
+ "start_time": "2023-02-13T14:17:49.214834Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "total_batch = len(train_dataloader)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "1a018d2b",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-13T14:17:57.621018Z",
+ "start_time": "2023-02-13T14:17:57.613265Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "782"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "total_batch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "3bd485d4",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-13T14:18:14.619380Z",
+ "start_time": "2023-02-13T14:18:14.611993Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "781"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(train_dataset)//batch_size"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "id": "636c01fb",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-13T14:25:45.077613Z",
+ "start_time": "2023-02-13T14:25:12.098487Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2023-02-13 22:25:13.614828, 1/5, 100/782: 1.6673\n",
+ "2023-02-13 22:25:15.118278, 1/5, 200/782: 1.3886\n",
+ "2023-02-13 22:25:16.620148, 1/5, 300/782: 1.4714\n",
+ "2023-02-13 22:25:18.121508, 1/5, 400/782: 0.9283\n",
+ "2023-02-13 22:25:19.624122, 1/5, 500/782: 1.1564\n",
+ "2023-02-13 22:25:21.128540, 1/5, 600/782: 0.9533\n",
+ "2023-02-13 22:25:22.634274, 1/5, 700/782: 1.3402\n",
+ "2023-02-13 22:25:25.374337, 2/5, 100/782: 1.0992\n",
+ "2023-02-13 22:25:26.100186, 2/5, 200/782: 0.8210\n",
+ "2023-02-13 22:25:26.750454, 2/5, 300/782: 0.8591\n",
+ "2023-02-13 22:25:27.398051, 2/5, 400/782: 0.7915\n",
+ "2023-02-13 22:25:28.047762, 2/5, 500/782: 0.9805\n",
+ "2023-02-13 22:25:28.697176, 2/5, 600/782: 0.7892\n",
+ "2023-02-13 22:25:29.345190, 2/5, 700/782: 0.7435\n",
+ "2023-02-13 22:25:30.524457, 3/5, 100/782: 0.8694\n",
+ "2023-02-13 22:25:31.172259, 3/5, 200/782: 0.6262\n",
+ "2023-02-13 22:25:31.820744, 3/5, 300/782: 0.7870\n",
+ "2023-02-13 22:25:32.471617, 3/5, 400/782: 0.8567\n",
+ "2023-02-13 22:25:33.119088, 3/5, 500/782: 0.7772\n",
+ "2023-02-13 22:25:33.768611, 3/5, 600/782: 0.7215\n",
+ "2023-02-13 22:25:34.417291, 3/5, 700/782: 0.6598\n",
+ "2023-02-13 22:25:35.598934, 4/5, 100/782: 0.5161\n",
+ "2023-02-13 22:25:36.245626, 4/5, 200/782: 0.5731\n",
+ "2023-02-13 22:25:36.892254, 4/5, 300/782: 0.4769\n",
+ "2023-02-13 22:25:37.541522, 4/5, 400/782: 0.5950\n",
+ "2023-02-13 22:25:38.188024, 4/5, 500/782: 0.6995\n",
+ "2023-02-13 22:25:38.835657, 4/5, 600/782: 0.6877\n",
+ "2023-02-13 22:25:39.483413, 4/5, 700/782: 0.7518\n",
+ "2023-02-13 22:25:40.665808, 5/5, 100/782: 0.3128\n",
+ "2023-02-13 22:25:41.311349, 5/5, 200/782: 0.4721\n",
+ "2023-02-13 22:25:41.958491, 5/5, 300/782: 0.5731\n",
+ "2023-02-13 22:25:42.605772, 5/5, 400/782: 0.5555\n",
+ "2023-02-13 22:25:43.251503, 5/5, 500/782: 0.4549\n",
+ "2023-02-13 22:25:43.899441, 5/5, 600/782: 0.7287\n",
+ "2023-02-13 22:25:44.544882, 5/5, 700/782: 0.4048\n"
+ ]
+ }
+ ],
+ "source": [
+ "for epoch in range(num_epochs):\n",
+ " for batch_idx, (images, labels) in enumerate(train_dataloader):\n",
+ " images = images.to(device)\n",
+ " labels = labels.to(device)\n",
+ " \n",
+ " # forward\n",
+ " out = model(images)\n",
+ " loss = criterion(out, labels)\n",
+ " \n",
+ " # backward\n",
+ " optimzer.zero_grad()\n",
+ " loss.backward()\n",
+ " optimzer.step() # 更细 模型参数\n",
+ " \n",
+ " if (batch_idx+1) % 100 == 0:\n",
+ " print(f'{datetime.now()}, {epoch+1}/{num_epochs}, {batch_idx+1}/{total_batch}: {loss.item():.4f}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5e7078dd",
+ "metadata": {},
+ "source": [
+ "## model evaluation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "id": "9d99e59a",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-13T14:26:04.834781Z",
+ "start_time": "2023-02-13T14:26:04.219699Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "7399/10000=0.7399\n"
+ ]
+ }
+ ],
+ "source": [
+ "total = 0\n",
+ "correct = 0\n",
+ "for images, labels in test_dataloader:\n",
+ " images = images.to(device)\n",
+ " labels = labels.to(device)\n",
+ " out = model(images)\n",
+ " preds = torch.argmax(out, dim=1)\n",
+ " \n",
+ " total += images.size(0)\n",
+ " correct += (preds == labels).sum().item()\n",
+ "print(f'{correct}/{total}={correct/total}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "96272f61",
+ "metadata": {},
+ "source": [
+ "## model save"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "2e68f24f",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-12T08:08:08.828595Z",
+ "start_time": "2023-02-12T08:08:08.817961Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "torch.save(model.state_dict(), 'cnn_mnist.ckpt')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.5"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {},
+ "toc_section_display": true,
+ "toc_window_display": true
+ },
+ "varInspector": {
+ "cols": {
+ "lenName": 16,
+ "lenType": 16,
+ "lenVar": 40
+ },
+ "kernels_config": {
+ "python": {
+ "delete_cmd_postfix": "",
+ "delete_cmd_prefix": "del ",
+ "library": "var_list.py",
+ "varRefreshCmd": "print(var_dic_list())"
+ },
+ "r": {
+ "delete_cmd_postfix": ") ",
+ "delete_cmd_prefix": "rm(",
+ "library": "var_list.r",
+ "varRefreshCmd": "cat(var_dic_list()) "
+ }
+ },
+ "types_to_exclude": [
+ "module",
+ "function",
+ "builtin_function_or_method",
+ "instance",
+ "_Feature"
+ ],
+ "window_display": false
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}