summaryrefslogtreecommitdiff
path: root/dl
diff options
context:
space:
mode:
Diffstat (limited to 'dl')
-rw-r--r--dl/tutorials/02_cnn_on_mnist.ipynb590
1 files changed, 590 insertions, 0 deletions
diff --git a/dl/tutorials/02_cnn_on_mnist.ipynb b/dl/tutorials/02_cnn_on_mnist.ipynb
new file mode 100644
index 0000000..e53672f
--- /dev/null
+++ b/dl/tutorials/02_cnn_on_mnist.ipynb
@@ -0,0 +1,590 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "36600735",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-12T07:42:26.334897Z",
+ "start_time": "2023-02-12T07:42:25.748244Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from torch import nn\n",
+ "from torchvision import datasets, transforms"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f7333d83",
+ "metadata": {},
+ "source": [
+ "## parameters"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "8d31ddee",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-12T08:01:39.702046Z",
+ "start_time": "2023-02-12T08:01:39.696278Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# dataset\n",
+ "input_shape = 28\n",
+ "num_classes = 10\n",
+ "\n",
+ "# hyper \n",
+ "batch_size = 64\n",
+ "num_epochs = 5\n",
+ "learning_rate = 1e-3\n",
+ "\n",
+ "# gpu\n",
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "72dcd7cf",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-12T07:59:43.271511Z",
+ "start_time": "2023-02-12T07:59:43.264504Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "device(type='cuda')"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "device"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "856f2db1",
+ "metadata": {},
+ "source": [
+ "## dataset 与 dataloader"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "421d5a1d",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-12T07:42:58.378258Z",
+ "start_time": "2023-02-12T07:42:58.318280Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "train_dataset = datasets.MNIST(root='../data/', \n",
+ " download=True, \n",
+ " train=True, \n",
+ " transform=transforms.ToTensor())\n",
+ "test_dataset = datasets.MNIST(root='../data/', \n",
+ " download=True, \n",
+ " train=False, \n",
+ " transform=transforms.ToTensor())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "c656663b",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-12T07:43:16.544394Z",
+ "start_time": "2023-02-12T07:43:16.538929Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, \n",
+ " shuffle=True, \n",
+ " batch_size=batch_size)\n",
+ "test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, \n",
+ " shuffle=False, \n",
+ " batch_size=batch_size)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "f77ea73d",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-12T07:43:18.810041Z",
+ "start_time": "2023-02-12T07:43:18.790215Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "images, labels = next(iter(train_dataloader))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "d03b9505",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-12T07:44:03.036073Z",
+ "start_time": "2023-02-12T07:44:03.029497Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([64, 1, 28, 28])"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# batch_size, channels, h, w\n",
+ "images.shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "20b8c4b4",
+ "metadata": {},
+ "source": [
+ "## model arch"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6a3182ef",
+ "metadata": {},
+ "source": [
+ "- cnn: channel 不断增加,shape 不断减少的过程\n",
+ " - 最好是 *2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "521ddf96",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-12T08:05:01.738399Z",
+ "start_time": "2023-02-12T08:05:01.731637Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "class CNN(nn.Module):\n",
+ " def __init__(self, input_shape, in_channels, num_classes):\n",
+ " super(CNN, self).__init__()\n",
+ " # conv2d: (b, 1, 28, 28) => (b, 16, 28, 28)\n",
+ " # maxpool2d: (b, 16, 28, 28) => (b, 16, 14, 14)\n",
+ " self.cnn1 = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=16, \n",
+ " kernel_size=5, padding=2, stride=1), \n",
+ " nn.BatchNorm2d(16), \n",
+ " nn.ReLU(), \n",
+ " nn.MaxPool2d(kernel_size=2, stride=2))\n",
+ " \n",
+ " # conv2d: (b, 16, 14, 14) => (b, 32, 14, 14)\n",
+ " # maxpool2d: (b, 32, 14, 14) => (b, 32, 7, 7)\n",
+ " self.cnn2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32, \n",
+ " kernel_size=5, padding=2, stride=1), \n",
+ " nn.BatchNorm2d(32), \n",
+ " nn.ReLU(), \n",
+ " nn.MaxPool2d(kernel_size=2, stride=2))\n",
+ " # (b, 32, 7, 7) => (b, 32*7*7)\n",
+ " # (b, 32*7*7) => (b, 10)\n",
+ " self.fc = nn.Linear(32*(input_shape//4)*(input_shape//4), num_classes)\n",
+ "\n",
+ " \n",
+ " def forward(self, x):\n",
+ " # (b, 1, 28, 28) => (b, 16, 14, 14)\n",
+ " out = self.cnn1(x)\n",
+ " # (b, 16, 14, 14) => (b, 32, 7, 7)\n",
+ " out = self.cnn2(out)\n",
+ " # (b, 32, 7, 7) => (b, 32*7*7)\n",
+ " out = out.reshape(out.size(0), -1)\n",
+ " out = self.fc(out)\n",
+ " return out\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6ab15298",
+ "metadata": {},
+ "source": [
+ "### torchsummary"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "4e0b41cd",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-12T07:57:32.036968Z",
+ "start_time": "2023-02-12T07:57:30.861860Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\r\n",
+ "Requirement already satisfied: torchsummary in /home/chzhang/anaconda3/envs/ldm/lib/python3.8/site-packages (1.5.1)\r\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install torchsummary"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "78aa128c",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-12T07:57:38.852830Z",
+ "start_time": "2023-02-12T07:57:38.847705Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from torchsummary import summary"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "306be169",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-12T08:05:05.639359Z",
+ "start_time": "2023-02-12T08:05:05.623987Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "model = CNN(input_shape=input_shape, num_classes=num_classes, in_channels=1).to(device)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "a4812e64",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-12T07:59:54.714382Z",
+ "start_time": "2023-02-12T07:59:54.222808Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "----------------------------------------------------------------\n",
+ " Layer (type) Output Shape Param #\n",
+ "================================================================\n",
+ " Conv2d-1 [64, 16, 28, 28] 416\n",
+ " BatchNorm2d-2 [64, 16, 28, 28] 32\n",
+ " ReLU-3 [64, 16, 28, 28] 0\n",
+ " MaxPool2d-4 [64, 16, 14, 14] 0\n",
+ " Conv2d-5 [64, 32, 14, 14] 12,832\n",
+ " BatchNorm2d-6 [64, 32, 14, 14] 64\n",
+ " ReLU-7 [64, 32, 14, 14] 0\n",
+ " MaxPool2d-8 [64, 32, 7, 7] 0\n",
+ " Linear-9 [64, 10] 15,690\n",
+ "================================================================\n",
+ "Total params: 29,034\n",
+ "Trainable params: 29,034\n",
+ "Non-trainable params: 0\n",
+ "----------------------------------------------------------------\n",
+ "Input size (MB): 0.19\n",
+ "Forward/backward pass size (MB): 29.86\n",
+ "Params size (MB): 0.11\n",
+ "Estimated Total Size (MB): 30.17\n",
+ "----------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "summary(model, input_size=(1, 28, 28), batch_size=batch_size)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "872d454c",
+ "metadata": {},
+ "source": [
+ "## model train"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "3e0babe0",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-12T08:05:09.085935Z",
+ "start_time": "2023-02-12T08:05:09.080188Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "criterion = nn.CrossEntropyLoss()\n",
+ "optimzer = torch.optim.Adam(model.parameters(), lr=learning_rate)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "bc2ad9e6",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-12T08:04:31.938635Z",
+ "start_time": "2023-02-12T08:04:31.933653Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "total_batch = len(train_dataloader)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "636c01fb",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-12T08:05:42.541533Z",
+ "start_time": "2023-02-12T08:05:11.254793Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1/5, 100/938: 0.2435\n",
+ "1/5, 200/938: 0.0455\n",
+ "1/5, 300/938: 0.0351\n",
+ "1/5, 400/938: 0.2363\n",
+ "1/5, 500/938: 0.0385\n",
+ "1/5, 600/938: 0.0285\n",
+ "1/5, 700/938: 0.0369\n",
+ "1/5, 800/938: 0.0443\n",
+ "1/5, 900/938: 0.1301\n",
+ "2/5, 100/938: 0.0114\n",
+ "2/5, 200/938: 0.0374\n",
+ "2/5, 300/938: 0.0232\n",
+ "2/5, 400/938: 0.0625\n",
+ "2/5, 500/938: 0.0478\n",
+ "2/5, 600/938: 0.0270\n",
+ "2/5, 700/938: 0.0067\n",
+ "2/5, 800/938: 0.0709\n",
+ "2/5, 900/938: 0.0186\n",
+ "3/5, 100/938: 0.0077\n",
+ "3/5, 200/938: 0.0060\n",
+ "3/5, 300/938: 0.0492\n",
+ "3/5, 400/938: 0.0459\n",
+ "3/5, 500/938: 0.0081\n",
+ "3/5, 600/938: 0.0468\n",
+ "3/5, 700/938: 0.0169\n",
+ "3/5, 800/938: 0.0386\n",
+ "3/5, 900/938: 0.0110\n",
+ "4/5, 100/938: 0.0027\n",
+ "4/5, 200/938: 0.0240\n",
+ "4/5, 300/938: 0.0047\n",
+ "4/5, 400/938: 0.0035\n",
+ "4/5, 500/938: 0.0002\n",
+ "4/5, 600/938: 0.0321\n",
+ "4/5, 700/938: 0.0060\n",
+ "4/5, 800/938: 0.0004\n",
+ "4/5, 900/938: 0.0597\n",
+ "5/5, 100/938: 0.0049\n",
+ "5/5, 200/938: 0.0140\n",
+ "5/5, 300/938: 0.0072\n",
+ "5/5, 400/938: 0.0270\n",
+ "5/5, 500/938: 0.0177\n",
+ "5/5, 600/938: 0.0005\n",
+ "5/5, 700/938: 0.0047\n",
+ "5/5, 800/938: 0.0440\n",
+ "5/5, 900/938: 0.0094\n"
+ ]
+ }
+ ],
+ "source": [
+ "for epoch in range(num_epochs):\n",
+ " for batch_idx, (images, labels) in enumerate(train_dataloader):\n",
+ " images = images.to(device)\n",
+ " labels = labels.to(device)\n",
+ " \n",
+ " # forward\n",
+ " out = model(images)\n",
+ " loss = criterion(out, labels)\n",
+ " \n",
+ " # backward\n",
+ " optimzer.zero_grad()\n",
+ " loss.backward()\n",
+ " optimzer.step() # 更细 模型参数\n",
+ " \n",
+ " if (batch_idx+1) % 100 == 0:\n",
+ " print(f'{epoch+1}/{num_epochs}, {batch_idx+1}/{total_batch}: {loss.item():.4f}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5e7078dd",
+ "metadata": {},
+ "source": [
+ "## model evaluation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "9d99e59a",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-12T08:07:52.577181Z",
+ "start_time": "2023-02-12T08:07:52.228875Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "9906/10000=0.9906\n"
+ ]
+ }
+ ],
+ "source": [
+ "total = 0\n",
+ "correct = 0\n",
+ "for images, labels in test_dataloader:\n",
+ " images = images.to(device)\n",
+ " labels = labels.to(device)\n",
+ " out = model(images)\n",
+ " preds = torch.argmax(out, dim=1)\n",
+ " \n",
+ " total += images.size(0)\n",
+ " correct += (preds == labels).sum().item()\n",
+ "print(f'{correct}/{total}={correct/total}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "96272f61",
+ "metadata": {},
+ "source": [
+ "## model save"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "2e68f24f",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-02-12T08:08:08.828595Z",
+ "start_time": "2023-02-12T08:08:08.817961Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "torch.save(model.state_dict(), 'cnn_mnist.ckpt')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.5"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {},
+ "toc_section_display": true,
+ "toc_window_display": true
+ },
+ "varInspector": {
+ "cols": {
+ "lenName": 16,
+ "lenType": 16,
+ "lenVar": 40
+ },
+ "kernels_config": {
+ "python": {
+ "delete_cmd_postfix": "",
+ "delete_cmd_prefix": "del ",
+ "library": "var_list.py",
+ "varRefreshCmd": "print(var_dic_list())"
+ },
+ "r": {
+ "delete_cmd_postfix": ") ",
+ "delete_cmd_prefix": "rm(",
+ "library": "var_list.r",
+ "varRefreshCmd": "cat(var_dic_list()) "
+ }
+ },
+ "types_to_exclude": [
+ "module",
+ "function",
+ "builtin_function_or_method",
+ "instance",
+ "_Feature"
+ ],
+ "window_display": false
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}