{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "36600735", "metadata": { "ExecuteTime": { "end_time": "2023-02-12T07:42:26.334897Z", "start_time": "2023-02-12T07:42:25.748244Z" } }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from torchvision import datasets, transforms" ] }, { "cell_type": "markdown", "id": "f7333d83", "metadata": {}, "source": [ "## parameters" ] }, { "cell_type": "code", "execution_count": 19, "id": "8d31ddee", "metadata": { "ExecuteTime": { "end_time": "2023-02-12T08:01:39.702046Z", "start_time": "2023-02-12T08:01:39.696278Z" } }, "outputs": [], "source": [ "# dataset\n", "input_shape = 28\n", "num_classes = 10\n", "\n", "# hyper \n", "batch_size = 64\n", "num_epochs = 5\n", "learning_rate = 1e-3\n", "\n", "# gpu\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" ] }, { "cell_type": "code", "execution_count": 15, "id": "72dcd7cf", "metadata": { "ExecuteTime": { "end_time": "2023-02-12T07:59:43.271511Z", "start_time": "2023-02-12T07:59:43.264504Z" } }, "outputs": [ { "data": { "text/plain": [ "device(type='cuda')" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device" ] }, { "cell_type": "markdown", "id": "856f2db1", "metadata": {}, "source": [ "## dataset 与 dataloader" ] }, { "cell_type": "code", "execution_count": 3, "id": "421d5a1d", "metadata": { "ExecuteTime": { "end_time": "2023-02-12T07:42:58.378258Z", "start_time": "2023-02-12T07:42:58.318280Z" } }, "outputs": [], "source": [ "train_dataset = datasets.MNIST(root='../data/', \n", " download=True, \n", " train=True, \n", " transform=transforms.ToTensor())\n", "test_dataset = datasets.MNIST(root='../data/', \n", " download=True, \n", " train=False, \n", " transform=transforms.ToTensor())" ] }, { "cell_type": "code", "execution_count": 4, "id": "c656663b", "metadata": { "ExecuteTime": { "end_time": "2023-02-12T07:43:16.544394Z", "start_time": "2023-02-12T07:43:16.538929Z" } }, "outputs": [], "source": [ "train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, \n", " shuffle=True, \n", " batch_size=batch_size)\n", "test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, \n", " shuffle=False, \n", " batch_size=batch_size)" ] }, { "cell_type": "code", "execution_count": 5, "id": "f77ea73d", "metadata": { "ExecuteTime": { "end_time": "2023-02-12T07:43:18.810041Z", "start_time": "2023-02-12T07:43:18.790215Z" } }, "outputs": [], "source": [ "images, labels = next(iter(train_dataloader))" ] }, { "cell_type": "code", "execution_count": 7, "id": "d03b9505", "metadata": { "ExecuteTime": { "end_time": "2023-02-12T07:44:03.036073Z", "start_time": "2023-02-12T07:44:03.029497Z" } }, "outputs": [ { "data": { "text/plain": [ "torch.Size([64, 1, 28, 28])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# batch_size, channels, h, w\n", "images.shape" ] }, { "cell_type": "markdown", "id": "20b8c4b4", "metadata": {}, "source": [ "## model arch" ] }, { "cell_type": "markdown", "id": "6a3182ef", "metadata": {}, "source": [ "- cnn: channel 不断增加,shape 不断减少的过程\n", " - 最好是 *2" ] }, { "cell_type": "code", "execution_count": 23, "id": "521ddf96", "metadata": { "ExecuteTime": { "end_time": "2023-02-12T08:05:01.738399Z", "start_time": "2023-02-12T08:05:01.731637Z" } }, "outputs": [], "source": [ "class CNN(nn.Module):\n", " def __init__(self, input_shape, in_channels, num_classes):\n", " super(CNN, self).__init__()\n", " # conv2d: (b, 1, 28, 28) => (b, 16, 28, 28)\n", " # maxpool2d: (b, 16, 28, 28) => (b, 16, 14, 14)\n", " self.cnn1 = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=16, \n", " kernel_size=5, padding=2, stride=1), \n", " nn.BatchNorm2d(16), \n", " nn.ReLU(), \n", " nn.MaxPool2d(kernel_size=2, stride=2))\n", " \n", " # conv2d: (b, 16, 14, 14) => (b, 32, 14, 14)\n", " # maxpool2d: (b, 32, 14, 14) => (b, 32, 7, 7)\n", " self.cnn2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32, \n", " kernel_size=5, padding=2, stride=1), \n", " nn.BatchNorm2d(32), \n", " nn.ReLU(), \n", " nn.MaxPool2d(kernel_size=2, stride=2))\n", " # (b, 32, 7, 7) => (b, 32*7*7)\n", " # (b, 32*7*7) => (b, 10)\n", " self.fc = nn.Linear(32*(input_shape//4)*(input_shape//4), num_classes)\n", "\n", " \n", " def forward(self, x):\n", " # (b, 1, 28, 28) => (b, 16, 14, 14)\n", " out = self.cnn1(x)\n", " # (b, 16, 14, 14) => (b, 32, 7, 7)\n", " out = self.cnn2(out)\n", " # (b, 32, 7, 7) => (b, 32*7*7)\n", " out = out.reshape(out.size(0), -1)\n", " out = self.fc(out)\n", " return out\n", " " ] }, { "cell_type": "markdown", "id": "6ab15298", "metadata": {}, "source": [ "### torchsummary" ] }, { "cell_type": "code", "execution_count": 9, "id": "4e0b41cd", "metadata": { "ExecuteTime": { "end_time": "2023-02-12T07:57:32.036968Z", "start_time": "2023-02-12T07:57:30.861860Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\r\n", "Requirement already satisfied: torchsummary in /home/chzhang/anaconda3/envs/ldm/lib/python3.8/site-packages (1.5.1)\r\n" ] } ], "source": [ "!pip install torchsummary" ] }, { "cell_type": "code", "execution_count": 10, "id": "78aa128c", "metadata": { "ExecuteTime": { "end_time": "2023-02-12T07:57:38.852830Z", "start_time": "2023-02-12T07:57:38.847705Z" } }, "outputs": [], "source": [ "from torchsummary import summary" ] }, { "cell_type": "code", "execution_count": 24, "id": "306be169", "metadata": { "ExecuteTime": { "end_time": "2023-02-12T08:05:05.639359Z", "start_time": "2023-02-12T08:05:05.623987Z" } }, "outputs": [], "source": [ "model = CNN(input_shape=input_shape, num_classes=num_classes, in_channels=1).to(device)" ] }, { "cell_type": "code", "execution_count": 17, "id": "a4812e64", "metadata": { "ExecuteTime": { "end_time": "2023-02-12T07:59:54.714382Z", "start_time": "2023-02-12T07:59:54.222808Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------------------------------------\n", " Layer (type) Output Shape Param #\n", "================================================================\n", " Conv2d-1 [64, 16, 28, 28] 416\n", " BatchNorm2d-2 [64, 16, 28, 28] 32\n", " ReLU-3 [64, 16, 28, 28] 0\n", " MaxPool2d-4 [64, 16, 14, 14] 0\n", " Conv2d-5 [64, 32, 14, 14] 12,832\n", " BatchNorm2d-6 [64, 32, 14, 14] 64\n", " ReLU-7 [64, 32, 14, 14] 0\n", " MaxPool2d-8 [64, 32, 7, 7] 0\n", " Linear-9 [64, 10] 15,690\n", "================================================================\n", "Total params: 29,034\n", "Trainable params: 29,034\n", "Non-trainable params: 0\n", "----------------------------------------------------------------\n", "Input size (MB): 0.19\n", "Forward/backward pass size (MB): 29.86\n", "Params size (MB): 0.11\n", "Estimated Total Size (MB): 30.17\n", "----------------------------------------------------------------\n" ] } ], "source": [ "summary(model, input_size=(1, 28, 28), batch_size=batch_size)" ] }, { "cell_type": "markdown", "id": "872d454c", "metadata": {}, "source": [ "## model train" ] }, { "cell_type": "code", "execution_count": 25, "id": "3e0babe0", "metadata": { "ExecuteTime": { "end_time": "2023-02-12T08:05:09.085935Z", "start_time": "2023-02-12T08:05:09.080188Z" } }, "outputs": [], "source": [ "criterion = nn.CrossEntropyLoss()\n", "optimzer = torch.optim.Adam(model.parameters(), lr=learning_rate)" ] }, { "cell_type": "code", "execution_count": 21, "id": "bc2ad9e6", "metadata": { "ExecuteTime": { "end_time": "2023-02-12T08:04:31.938635Z", "start_time": "2023-02-12T08:04:31.933653Z" } }, "outputs": [], "source": [ "total_batch = len(train_dataloader)" ] }, { "cell_type": "code", "execution_count": 26, "id": "636c01fb", "metadata": { "ExecuteTime": { "end_time": "2023-02-12T08:05:42.541533Z", "start_time": "2023-02-12T08:05:11.254793Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1/5, 100/938: 0.2435\n", "1/5, 200/938: 0.0455\n", "1/5, 300/938: 0.0351\n", "1/5, 400/938: 0.2363\n", "1/5, 500/938: 0.0385\n", "1/5, 600/938: 0.0285\n", "1/5, 700/938: 0.0369\n", "1/5, 800/938: 0.0443\n", "1/5, 900/938: 0.1301\n", "2/5, 100/938: 0.0114\n", "2/5, 200/938: 0.0374\n", "2/5, 300/938: 0.0232\n", "2/5, 400/938: 0.0625\n", "2/5, 500/938: 0.0478\n", "2/5, 600/938: 0.0270\n", "2/5, 700/938: 0.0067\n", "2/5, 800/938: 0.0709\n", "2/5, 900/938: 0.0186\n", "3/5, 100/938: 0.0077\n", "3/5, 200/938: 0.0060\n", "3/5, 300/938: 0.0492\n", "3/5, 400/938: 0.0459\n", "3/5, 500/938: 0.0081\n", "3/5, 600/938: 0.0468\n", "3/5, 700/938: 0.0169\n", "3/5, 800/938: 0.0386\n", "3/5, 900/938: 0.0110\n", "4/5, 100/938: 0.0027\n", "4/5, 200/938: 0.0240\n", "4/5, 300/938: 0.0047\n", "4/5, 400/938: 0.0035\n", "4/5, 500/938: 0.0002\n", "4/5, 600/938: 0.0321\n", "4/5, 700/938: 0.0060\n", "4/5, 800/938: 0.0004\n", "4/5, 900/938: 0.0597\n", "5/5, 100/938: 0.0049\n", "5/5, 200/938: 0.0140\n", "5/5, 300/938: 0.0072\n", "5/5, 400/938: 0.0270\n", "5/5, 500/938: 0.0177\n", "5/5, 600/938: 0.0005\n", "5/5, 700/938: 0.0047\n", "5/5, 800/938: 0.0440\n", "5/5, 900/938: 0.0094\n" ] } ], "source": [ "for epoch in range(num_epochs):\n", " for batch_idx, (images, labels) in enumerate(train_dataloader):\n", " images = images.to(device)\n", " labels = labels.to(device)\n", " \n", " # forward\n", " out = model(images)\n", " loss = criterion(out, labels)\n", " \n", " # backward\n", " optimzer.zero_grad()\n", " loss.backward()\n", " optimzer.step() # 更细 模型参数\n", " \n", " if (batch_idx+1) % 100 == 0:\n", " print(f'{epoch+1}/{num_epochs}, {batch_idx+1}/{total_batch}: {loss.item():.4f}')" ] }, { "cell_type": "markdown", "id": "5e7078dd", "metadata": {}, "source": [ "## model evaluation" ] }, { "cell_type": "code", "execution_count": 27, "id": "9d99e59a", "metadata": { "ExecuteTime": { "end_time": "2023-02-12T08:07:52.577181Z", "start_time": "2023-02-12T08:07:52.228875Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "9906/10000=0.9906\n" ] } ], "source": [ "total = 0\n", "correct = 0\n", "for images, labels in test_dataloader:\n", " images = images.to(device)\n", " labels = labels.to(device)\n", " out = model(images)\n", " preds = torch.argmax(out, dim=1)\n", " \n", " total += images.size(0)\n", " correct += (preds == labels).sum().item()\n", "print(f'{correct}/{total}={correct/total}')" ] }, { "cell_type": "markdown", "id": "96272f61", "metadata": {}, "source": [ "## model save" ] }, { "cell_type": "code", "execution_count": 29, "id": "2e68f24f", "metadata": { "ExecuteTime": { "end_time": "2023-02-12T08:08:08.828595Z", "start_time": "2023-02-12T08:08:08.817961Z" } }, "outputs": [], "source": [ "torch.save(model.state_dict(), 'cnn_mnist.ckpt')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": true }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }