In [27]:
import torch
from torch import nn
from torchvision import datasets, transforms
from datetime import datetime

## parameters

In [8]:
# dataset
input_shape = 32
num_classes = 10

# hyper 
batch_size = 64
num_epochs = 5
learning_rate = 1e-3

# gpu
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
device

device(type='cuda')

## dataset 与 dataloader

In [4]:
train_dataset = datasets.CIFAR10(root='../data/', 
                               download=True, 
                               train=True, 
                               transform=transforms.ToTensor())
test_dataset = datasets.CIFAR10(root='../data/', 
                               download=True, 
                               train=False, 
                               transform=transforms.ToTensor())

Files already downloaded and verified
Files already downloaded and verified


In [5]:
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                               shuffle=True, 
                                               batch_size=batch_size)
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                               shuffle=False, 
                                               batch_size=batch_size)

In [6]:
images, labels = next(iter(train_dataloader))

In [7]:
# batch_size, channels, h, w
images.shape

torch.Size([64, 3, 32, 32])

## model arch

- cnn: channel 不断增加，shape 不断减少的过程
    - 最好是 *2

In [30]:
class CNN(nn.Module):
    def __init__(self, input_shape, in_channels, num_classes):
        super(CNN, self).__init__()
        # conv2d: (b, 1, 28, 28) => (b, 16, 28, 28)
        # maxpool2d: (b, 16, 28, 28) => (b, 16, 14, 14)
        self.cnn1 = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=16, 
                                            kernel_size=5, padding=2, stride=1), 
                                  nn.BatchNorm2d(16), 
                                  nn.ReLU(), 
                                  nn.MaxPool2d(kernel_size=2, stride=2))
        
        # conv2d: (b, 16, 14, 14) => (b, 32, 14, 14)
        # maxpool2d: (b, 32, 14, 14) => (b, 32, 7, 7)
        self.cnn2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32, 
                                            kernel_size=5, padding=2, stride=1), 
                                  nn.BatchNorm2d(32), 
                                  nn.ReLU(), 
                                  nn.MaxPool2d(kernel_size=2, stride=2))
        
        self.cnn3 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=64, 
                                    kernel_size=5, padding=2, stride=1), 
                          nn.BatchNorm2d(64), 
                          nn.ReLU(), 
                          nn.MaxPool2d(kernel_size=2, stride=2))
        self.cnn4 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, 
                            kernel_size=5, padding=2, stride=1), 
                  nn.BatchNorm2d(128), 
                  nn.ReLU(), 
                  nn.MaxPool2d(kernel_size=2, stride=2))
        self.cnn5 = nn.Sequential(nn.Conv2d(in_channels=128, out_channels=256, 
                            kernel_size=5, padding=2, stride=1), 
                  nn.BatchNorm2d(256), 
                  nn.ReLU(), 
                  nn.MaxPool2d(kernel_size=2, stride=2))
        # (b, 32, 7, 7) => (b, 32*7*7)
        # (b, 32*7*7) => (b, 10)
        self.fc = nn.Linear(256*(input_shape//32)*(input_shape//32), num_classes)

    
    def forward(self, x):
        # (b, 1, 28, 28) => (b, 16, 14, 14)
        out = self.cnn1(x)
        # (b, 16, 14, 14) => (b, 32, 7, 7)
        out = self.cnn2(out)
        # (b, 32, 7, 7) => (b, 32*7*7)
        out = self.cnn3(out)
        out = self.cnn4(out)
        out = self.cnn5(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
    

### torchsummary

In [9]:
!pip install torchsummary

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple


In [10]:
from torchsummary import summary

In [31]:
model = CNN(input_shape=input_shape, num_classes=num_classes, in_channels=3).to(device)

In [32]:
summary(model, input_size=(3, 32, 32), batch_size=batch_size)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [64, 16, 32, 32]           1,216
       BatchNorm2d-2           [64, 16, 32, 32]              32
              ReLU-3           [64, 16, 32, 32]               0
         MaxPool2d-4           [64, 16, 16, 16]               0
            Conv2d-5           [64, 32, 16, 16]          12,832
       BatchNorm2d-6           [64, 32, 16, 16]              64
              ReLU-7           [64, 32, 16, 16]               0
         MaxPool2d-8             [64, 32, 8, 8]               0
            Conv2d-9             [64, 64, 8, 8]          51,264
      BatchNorm2d-10             [64, 64, 8, 8]             128
             ReLU-11             [64, 64, 8, 8]               0
        MaxPool2d-12             [64, 64, 4, 4]               0
           Conv2d-13            [64, 128, 4, 4]         204,928
      BatchNorm2d-14            [64, 12

## model train

In [33]:
criterion = nn.CrossEntropyLoss()
optimzer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [14]:
total_batch = len(train_dataloader)

In [15]:
total_batch

782

In [16]:
len(train_dataset)//batch_size

781

In [34]:
for epoch in range(num_epochs):
    for batch_idx, (images, labels) in enumerate(train_dataloader):
        images = images.to(device)
        labels = labels.to(device)
        
        # forward
        out = model(images)
        loss = criterion(out, labels)
        
        # backward
        optimzer.zero_grad()
        loss.backward()
        optimzer.step()   # 更细 模型参数
        
        if (batch_idx+1) % 100 == 0:
            print(f'{datetime.now()}, {epoch+1}/{num_epochs}, {batch_idx+1}/{total_batch}: {loss.item():.4f}')

2023-02-13 22:25:13.614828, 1/5, 100/782: 1.6673
2023-02-13 22:25:15.118278, 1/5, 200/782: 1.3886
2023-02-13 22:25:16.620148, 1/5, 300/782: 1.4714
2023-02-13 22:25:18.121508, 1/5, 400/782: 0.9283
2023-02-13 22:25:19.624122, 1/5, 500/782: 1.1564
2023-02-13 22:25:21.128540, 1/5, 600/782: 0.9533
2023-02-13 22:25:22.634274, 1/5, 700/782: 1.3402
2023-02-13 22:25:25.374337, 2/5, 100/782: 1.0992
2023-02-13 22:25:26.100186, 2/5, 200/782: 0.8210
2023-02-13 22:25:26.750454, 2/5, 300/782: 0.8591
2023-02-13 22:25:27.398051, 2/5, 400/782: 0.7915
2023-02-13 22:25:28.047762, 2/5, 500/782: 0.9805
2023-02-13 22:25:28.697176, 2/5, 600/782: 0.7892
2023-02-13 22:25:29.345190, 2/5, 700/782: 0.7435
2023-02-13 22:25:30.524457, 3/5, 100/782: 0.8694
2023-02-13 22:25:31.172259, 3/5, 200/782: 0.6262
2023-02-13 22:25:31.820744, 3/5, 300/782: 0.7870
2023-02-13 22:25:32.471617, 3/5, 400/782: 0.8567
2023-02-13 22:25:33.119088, 3/5, 500/782: 0.7772
2023-02-13 22:25:33.768611, 3/5, 600/782: 0.7215
2023-02-13 22:25:34.

## model evaluation

In [35]:
total = 0
correct = 0
for images, labels in test_dataloader:
    images = images.to(device)
    labels = labels.to(device)
    out = model(images)
    preds = torch.argmax(out, dim=1)
    
    total += images.size(0)
    correct += (preds == labels).sum().item()
print(f'{correct}/{total}={correct/total}')

7399/10000=0.7399


## model save

In [29]:
torch.save(model.state_dict(), 'cnn_mnist.ckpt')