In [1]:
import torch
from torch import nn
from torchvision import datasets, transforms

## parameters

In [19]:
# Dataset geometry: side length of the square input images and number of target classes.
input_shape = 28
num_classes = 10

# Training hyper-parameters.
batch_size = 64
num_epochs = 5
learning_rate = 1e-3

# Reproducibility: seed PyTorch's random number generator before any weights are
# initialised or any shuffled DataLoader is iterated, so Restart & Run All gives
# comparable results from run to run.
seed = 42
torch.manual_seed(seed)

# Compute device: use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [15]:
# Bare expression so the notebook displays which device was selected.
device

device(type='cuda')

## dataset 与 dataloader

In [3]:
# MNIST train / test splits. Both are stored under the same root directory,
# downloaded on first use, and converted to tensors by ToTensor().
mnist_options = dict(root='../data/', download=True, transform=transforms.ToTensor())
train_dataset = datasets.MNIST(train=True, **mnist_options)
test_dataset = datasets.MNIST(train=False, **mnist_options)

In [4]:
# Batch both splits with the configured batch size. Only the training loader is
# shuffled; the test loader keeps dataset order.
loader_cls = torch.utils.data.DataLoader
train_dataloader = loader_cls(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = loader_cls(test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
# Pull a single batch from the training loader to inspect what the model will receive.
images, labels = next(iter(train_dataloader))

In [7]:
# Shape of one batch; axes are (batch_size, channels, height, width).
images.shape

torch.Size([64, 1, 28, 28])

## model arch

- cnn: channel 数不断增加、feature map 的空间尺寸 (h, w) 不断减小的过程
    - 每经过一个 conv block，channel 数最好 ×2（如 16 → 32），空间尺寸 ÷2（如 28 → 14 → 7）

In [23]:
class CNN(nn.Module):
    """Two-stage convolutional classifier.

    Each stage is Conv2d(5x5, padding 2, stride 1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2, stride 2): the convolution keeps height/width and the pooling
    halves them. The flattened output of the second stage feeds one linear
    layer that produces ``num_classes`` scores per sample.

    Args:
        input_shape: side length of the square input image (e.g. 28).
        in_channels: number of channels of the input image (e.g. 1).
        num_classes: size of the output score vector.
    """

    def __init__(self, input_shape, in_channels, num_classes):
        super().__init__()
        # (b, in_channels, s, s) => (b, 16, s/2, s/2)
        self.cnn1 = self._conv_stage(in_channels, 16)
        # (b, 16, s/2, s/2) => (b, 32, s/4, s/4)
        self.cnn2 = self._conv_stage(16, 32)
        # Two 2x poolings leave a side of input_shape // 4 (7 for a 28-pixel input).
        pooled_side = input_shape // 4
        # (b, 32 * pooled_side * pooled_side) => (b, num_classes)
        self.fc = nn.Linear(32 * pooled_side * pooled_side, num_classes)

    @staticmethod
    def _conv_stage(channels_in, channels_out):
        """Build one conv -> batch-norm -> ReLU -> max-pool stage."""
        layers = [
            nn.Conv2d(channels_in, channels_out, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(channels_out),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of images (b, c, h, w) to class scores (b, num_classes)."""
        features = self.cnn2(self.cnn1(x))
        # Flatten everything except the batch axis before the linear layer.
        flat = features.reshape(features.size(0), -1)
        return self.fc(flat)
    

### torchsummary

In [9]:
!pip install torchsummary

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple


In [10]:
from torchsummary import summary

In [24]:
# Build the network for single-channel input and move it to the selected device.
model = CNN(input_shape=input_shape, num_classes=num_classes, in_channels=1).to(device)

In [17]:
# Print a per-layer summary. The input size is derived from the configured
# input_shape (rather than a literal 28) so it stays consistent with the model.
summary(model, input_size=(1, input_shape, input_shape), batch_size=batch_size)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [64, 16, 28, 28]             416
       BatchNorm2d-2           [64, 16, 28, 28]              32
              ReLU-3           [64, 16, 28, 28]               0
         MaxPool2d-4           [64, 16, 14, 14]               0
            Conv2d-5           [64, 32, 14, 14]          12,832
       BatchNorm2d-6           [64, 32, 14, 14]              64
              ReLU-7           [64, 32, 14, 14]               0
         MaxPool2d-8             [64, 32, 7, 7]               0
            Linear-9                   [64, 10]          15,690
Total params: 29,034
Trainable params: 29,034
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.19
Forward/backward pass size (MB): 29.86
Params size (MB): 0.11
Estimated Total Size (MB): 30.17
-------------------------------------------

## model train

In [25]:
# Loss function applied to the model's output scores and the target labels.
criterion = nn.CrossEntropyLoss()
# Adam optimiser over all model parameters with the configured learning rate.
# NOTE: the variable is spelled `optimzer` (sic); the training loop refers to it by
# this exact name, so both places must be renamed together if the typo is fixed.
optimzer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [21]:
# Number of mini-batches per epoch; shown in the training loop's progress message.
total_batch = len(train_dataloader)

In [26]:
# Train for num_epochs passes over the training set, logging the loss every 100 batches.
for epoch in range(num_epochs):
    # Explicitly enter training mode so BatchNorm uses batch statistics and updates
    # its running averages, even if the model was switched to eval mode earlier
    # (e.g. when this cell is re-run after the evaluation cell).
    model.train()
    for batch_idx, (images, labels) in enumerate(train_dataloader):
        images = images.to(device)
        labels = labels.to(device)

        # forward
        out = model(images)
        loss = criterion(out, labels)

        # backward: clear stale gradients, back-propagate, then apply the update
        optimzer.zero_grad()
        loss.backward()
        optimzer.step()   # update the model parameters

        if (batch_idx+1) % 100 == 0:
            print(f'{epoch+1}/{num_epochs}, {batch_idx+1}/{total_batch}: {loss.item():.4f}')

1/5, 100/938: 0.2435
1/5, 200/938: 0.0455
1/5, 300/938: 0.0351
1/5, 400/938: 0.2363
1/5, 500/938: 0.0385
1/5, 600/938: 0.0285
1/5, 700/938: 0.0369
1/5, 800/938: 0.0443
1/5, 900/938: 0.1301
2/5, 100/938: 0.0114
2/5, 200/938: 0.0374
2/5, 300/938: 0.0232
2/5, 400/938: 0.0625
2/5, 500/938: 0.0478
2/5, 600/938: 0.0270
2/5, 700/938: 0.0067
2/5, 800/938: 0.0709
2/5, 900/938: 0.0186
3/5, 100/938: 0.0077
3/5, 200/938: 0.0060
3/5, 300/938: 0.0492
3/5, 400/938: 0.0459
3/5, 500/938: 0.0081
3/5, 600/938: 0.0468
3/5, 700/938: 0.0169
3/5, 800/938: 0.0386
3/5, 900/938: 0.0110
4/5, 100/938: 0.0027
4/5, 200/938: 0.0240
4/5, 300/938: 0.0047
4/5, 400/938: 0.0035
4/5, 500/938: 0.0002
4/5, 600/938: 0.0321
4/5, 700/938: 0.0060
4/5, 800/938: 0.0004
4/5, 900/938: 0.0597
5/5, 100/938: 0.0049
5/5, 200/938: 0.0140
5/5, 300/938: 0.0072
5/5, 400/938: 0.0270
5/5, 500/938: 0.0177
5/5, 600/938: 0.0005
5/5, 700/938: 0.0047
5/5, 800/938: 0.0440
5/5, 900/938: 0.0094


## model evaluation

In [27]:
# Accuracy on the test set.
total = 0
correct = 0
# Switch to evaluation mode so BatchNorm uses its learned running statistics instead
# of per-batch statistics (and does not update them from test data).
model.eval()
# No gradients are needed for evaluation; disabling autograd saves memory and time.
with torch.no_grad():
    for images, labels in test_dataloader:
        images = images.to(device)
        labels = labels.to(device)
        out = model(images)
        # Predicted class = index of the largest score along the class axis.
        preds = torch.argmax(out, dim=1)

        total += images.size(0)
        correct += (preds == labels).sum().item()
print(f'{correct}/{total}={correct/total}')

9906/10000=0.9906


## model save

In [29]:
# Save only the model's state_dict (not the module object itself) to a
# checkpoint file in the current working directory.
torch.save(model.state_dict(), 'cnn_mnist.ckpt')