diff options
| author | chunhuizhang <zch921005@126.com> | 2023-02-12 16:20:39 +0800 |
|---|---|---|
| committer | chunhuizhang <zch921005@126.com> | 2023-02-12 16:20:39 +0800 |
| commit | 67595034b5184ce3cfb5b286ae0da23aa736ade6 (patch) | |
| tree | 53e47dd6e44f7a88fc60f3445c12251aa5479c90 /dl/tutorials | |
| parent | 76896af96a12d36bb530eebc4309c29c8eae66b0 (diff) | |
cnn on mnist
Diffstat (limited to 'dl/tutorials')
| -rw-r--r-- | dl/tutorials/02_cnn_on_mnist.ipynb | 590 |
1 files changed, 590 insertions, 0 deletions
diff --git a/dl/tutorials/02_cnn_on_mnist.ipynb b/dl/tutorials/02_cnn_on_mnist.ipynb new file mode 100644 index 0000000..e53672f --- /dev/null +++ b/dl/tutorials/02_cnn_on_mnist.ipynb @@ -0,0 +1,590 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "36600735", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-12T07:42:26.334897Z", + "start_time": "2023-02-12T07:42:25.748244Z" + } + }, + "outputs": [], + "source": [ + "import torch\n", + "from torch import nn\n", + "from torchvision import datasets, transforms" + ] + }, + { + "cell_type": "markdown", + "id": "f7333d83", + "metadata": {}, + "source": [ + "## parameters" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "8d31ddee", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-12T08:01:39.702046Z", + "start_time": "2023-02-12T08:01:39.696278Z" + } + }, + "outputs": [], + "source": [ + "# dataset\n", + "input_shape = 28\n", + "num_classes = 10\n", + "\n", + "# hyper \n", + "batch_size = 64\n", + "num_epochs = 5\n", + "learning_rate = 1e-3\n", + "\n", + "# gpu\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "72dcd7cf", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-12T07:59:43.271511Z", + "start_time": "2023-02-12T07:59:43.264504Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "device(type='cuda')" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "device" + ] + }, + { + "cell_type": "markdown", + "id": "856f2db1", + "metadata": {}, + "source": [ + "## dataset 与 dataloader" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "421d5a1d", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-12T07:42:58.378258Z", + "start_time": "2023-02-12T07:42:58.318280Z" + } + }, + "outputs": [], + "source": [ + "train_dataset = datasets.MNIST(root='../data/', \n", + " download=True, \n", + " train=True, \n", + " transform=transforms.ToTensor())\n", + "test_dataset = datasets.MNIST(root='../data/', \n", + " download=True, \n", + " train=False, \n", + " transform=transforms.ToTensor())" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "c656663b", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-12T07:43:16.544394Z", + "start_time": "2023-02-12T07:43:16.538929Z" + } + }, + "outputs": [], + "source": [ + "train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, \n", + " shuffle=True, \n", + " batch_size=batch_size)\n", + "test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, \n", + " shuffle=False, \n", + " batch_size=batch_size)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f77ea73d", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-12T07:43:18.810041Z", + "start_time": "2023-02-12T07:43:18.790215Z" + } + }, + "outputs": [], + "source": [ + "images, labels = next(iter(train_dataloader))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "d03b9505", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-12T07:44:03.036073Z", + "start_time": "2023-02-12T07:44:03.029497Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([64, 1, 28, 28])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# batch_size, channels, h, w\n", + "images.shape" + ] + }, + { + "cell_type": "markdown", + "id": "20b8c4b4", + "metadata": {}, + "source": [ + "## model arch" + ] + }, + { + "cell_type": "markdown", + "id": "6a3182ef", + "metadata": {}, + "source": [ + "- cnn: channel 不断增加,shape 不断减少的过程\n", + " - 最好是 *2" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "521ddf96", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-12T08:05:01.738399Z", + "start_time": "2023-02-12T08:05:01.731637Z" + } + }, + "outputs": [], + "source": [ + "class CNN(nn.Module):\n", + " def __init__(self, input_shape, in_channels, num_classes):\n", + " super(CNN, self).__init__()\n", + " # conv2d: (b, 1, 28, 28) => (b, 16, 28, 28)\n", + " # maxpool2d: (b, 16, 28, 28) => (b, 16, 14, 14)\n", + " self.cnn1 = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=16, \n", + " kernel_size=5, padding=2, stride=1), \n", + " nn.BatchNorm2d(16), \n", + " nn.ReLU(), \n", + " nn.MaxPool2d(kernel_size=2, stride=2))\n", + " \n", + " # conv2d: (b, 16, 14, 14) => (b, 32, 14, 14)\n", + " # maxpool2d: (b, 32, 14, 14) => (b, 32, 7, 7)\n", + " self.cnn2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32, \n", + " kernel_size=5, padding=2, stride=1), \n", + " nn.BatchNorm2d(32), \n", + " nn.ReLU(), \n", + " nn.MaxPool2d(kernel_size=2, stride=2))\n", + " # (b, 32, 7, 7) => (b, 32*7*7)\n", + " # (b, 32*7*7) => (b, 10)\n", + " self.fc = nn.Linear(32*(input_shape//4)*(input_shape//4), num_classes)\n", + "\n", + " \n", + " def forward(self, x):\n", + " # (b, 1, 28, 28) => (b, 16, 14, 14)\n", + " out = self.cnn1(x)\n", + " # (b, 16, 14, 14) => (b, 32, 7, 7)\n", + " out = self.cnn2(out)\n", + " # (b, 32, 7, 7) => (b, 32*7*7)\n", + " out = out.reshape(out.size(0), -1)\n", + " out = self.fc(out)\n", + " return out\n", + " " + ] + }, + { + "cell_type": "markdown", + "id": "6ab15298", + "metadata": {}, + "source": [ + "### torchsummary" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "4e0b41cd", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-12T07:57:32.036968Z", + "start_time": "2023-02-12T07:57:30.861860Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\r\n", + "Requirement already satisfied: torchsummary in /home/chzhang/anaconda3/envs/ldm/lib/python3.8/site-packages (1.5.1)\r\n" + ] + } + ], + "source": [ + "!pip install torchsummary" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "78aa128c", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-12T07:57:38.852830Z", + "start_time": "2023-02-12T07:57:38.847705Z" + } + }, + "outputs": [], + "source": [ + "from torchsummary import summary" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "306be169", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-12T08:05:05.639359Z", + "start_time": "2023-02-12T08:05:05.623987Z" + } + }, + "outputs": [], + "source": [ + "model = CNN(input_shape=input_shape, num_classes=num_classes, in_channels=1).to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "a4812e64", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-12T07:59:54.714382Z", + "start_time": "2023-02-12T07:59:54.222808Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "----------------------------------------------------------------\n", + " Layer (type) Output Shape Param #\n", + "================================================================\n", + " Conv2d-1 [64, 16, 28, 28] 416\n", + " BatchNorm2d-2 [64, 16, 28, 28] 32\n", + " ReLU-3 [64, 16, 28, 28] 0\n", + " MaxPool2d-4 [64, 16, 14, 14] 0\n", + " Conv2d-5 [64, 32, 14, 14] 12,832\n", + " BatchNorm2d-6 [64, 32, 14, 14] 64\n", + " ReLU-7 [64, 32, 14, 14] 0\n", + " MaxPool2d-8 [64, 32, 7, 7] 0\n", + " Linear-9 [64, 10] 15,690\n", + "================================================================\n", + "Total params: 29,034\n", + "Trainable params: 29,034\n", + "Non-trainable params: 0\n", + "----------------------------------------------------------------\n", + "Input size (MB): 0.19\n", + "Forward/backward pass size (MB): 29.86\n", + "Params size (MB): 0.11\n", + "Estimated Total Size (MB): 30.17\n", + "----------------------------------------------------------------\n" + ] + } + ], + "source": [ + "summary(model, input_size=(1, 28, 28), batch_size=batch_size)" + ] + }, + { + "cell_type": "markdown", + "id": "872d454c", + "metadata": {}, + "source": [ + "## model train" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "3e0babe0", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-12T08:05:09.085935Z", + "start_time": "2023-02-12T08:05:09.080188Z" + } + }, + "outputs": [], + "source": [ + "criterion = nn.CrossEntropyLoss()\n", + "optimzer = torch.optim.Adam(model.parameters(), lr=learning_rate)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "bc2ad9e6", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-12T08:04:31.938635Z", + "start_time": "2023-02-12T08:04:31.933653Z" + } + }, + "outputs": [], + "source": [ + "total_batch = len(train_dataloader)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "636c01fb", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-12T08:05:42.541533Z", + "start_time": "2023-02-12T08:05:11.254793Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1/5, 100/938: 0.2435\n", + "1/5, 200/938: 0.0455\n", + "1/5, 300/938: 0.0351\n", + "1/5, 400/938: 0.2363\n", + "1/5, 500/938: 0.0385\n", + "1/5, 600/938: 0.0285\n", + "1/5, 700/938: 0.0369\n", + "1/5, 800/938: 0.0443\n", + "1/5, 900/938: 0.1301\n", + "2/5, 100/938: 0.0114\n", + "2/5, 200/938: 0.0374\n", + "2/5, 300/938: 0.0232\n", + "2/5, 400/938: 0.0625\n", + "2/5, 500/938: 0.0478\n", + "2/5, 600/938: 0.0270\n", + "2/5, 700/938: 0.0067\n", + "2/5, 800/938: 0.0709\n", + "2/5, 900/938: 0.0186\n", + "3/5, 100/938: 0.0077\n", + "3/5, 200/938: 0.0060\n", + "3/5, 300/938: 0.0492\n", + "3/5, 400/938: 0.0459\n", + "3/5, 500/938: 0.0081\n", + "3/5, 600/938: 0.0468\n", + "3/5, 700/938: 0.0169\n", + "3/5, 800/938: 0.0386\n", + "3/5, 900/938: 0.0110\n", + "4/5, 100/938: 0.0027\n", + "4/5, 200/938: 0.0240\n", + "4/5, 300/938: 0.0047\n", + "4/5, 400/938: 0.0035\n", + "4/5, 500/938: 0.0002\n", + "4/5, 600/938: 0.0321\n", + "4/5, 700/938: 0.0060\n", + "4/5, 800/938: 0.0004\n", + "4/5, 900/938: 0.0597\n", + "5/5, 100/938: 0.0049\n", + "5/5, 200/938: 0.0140\n", + "5/5, 300/938: 0.0072\n", + "5/5, 400/938: 0.0270\n", + "5/5, 500/938: 0.0177\n", + "5/5, 600/938: 0.0005\n", + "5/5, 700/938: 0.0047\n", + "5/5, 800/938: 0.0440\n", + "5/5, 900/938: 0.0094\n" + ] + } + ], + "source": [ + "for epoch in range(num_epochs):\n", + " for batch_idx, (images, labels) in enumerate(train_dataloader):\n", + " images = images.to(device)\n", + " labels = labels.to(device)\n", + " \n", + " # forward\n", + " out = model(images)\n", + " loss = criterion(out, labels)\n", + " \n", + " # backward\n", + " optimzer.zero_grad()\n", + " loss.backward()\n", + " optimzer.step() # 更细 模型参数\n", + " \n", + " if (batch_idx+1) % 100 == 0:\n", + " print(f'{epoch+1}/{num_epochs}, {batch_idx+1}/{total_batch}: {loss.item():.4f}')" + ] + }, + { + "cell_type": "markdown", + "id": "5e7078dd", + "metadata": {}, + "source": [ + "## model evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "9d99e59a", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-12T08:07:52.577181Z", + "start_time": "2023-02-12T08:07:52.228875Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "9906/10000=0.9906\n" + ] + } + ], + "source": [ + "total = 0\n", + "correct = 0\n", + "for images, labels in test_dataloader:\n", + " images = images.to(device)\n", + " labels = labels.to(device)\n", + " out = model(images)\n", + " preds = torch.argmax(out, dim=1)\n", + " \n", + " total += images.size(0)\n", + " correct += (preds == labels).sum().item()\n", + "print(f'{correct}/{total}={correct/total}')" + ] + }, + { + "cell_type": "markdown", + "id": "96272f61", + "metadata": {}, + "source": [ + "## model save" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "2e68f24f", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-12T08:08:08.828595Z", + "start_time": "2023-02-12T08:08:08.817961Z" + } + }, + "outputs": [], + "source": [ + "torch.save(model.state_dict(), 'cnn_mnist.ckpt')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true + }, + "varInspector": { + "cols": { + "lenName": 16, + "lenType": 16, + "lenVar": 40 + }, + "kernels_config": { + "python": { + "delete_cmd_postfix": "", + "delete_cmd_prefix": "del ", + "library": "var_list.py", + "varRefreshCmd": "print(var_dic_list())" + }, + "r": { + "delete_cmd_postfix": ") ", + "delete_cmd_prefix": "rm(", + "library": "var_list.r", + "varRefreshCmd": "cat(var_dic_list()) " + } + }, + "types_to_exclude": [ + "module", + "function", + "builtin_function_or_method", + "instance", + "_Feature" + ], + "window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} |
