summaryrefslogtreecommitdiff
path: root/dl/tutorials/02_rnn_on_images.py
diff options
context:
space:
mode:
author chzhang <zch921005@126.com> 2023-02-07 22:59:47 +0800
committer chzhang <zch921005@126.com> 2023-02-07 22:59:47 +0800
commit 6c759a4574475f3ba6fc31ce67fc2d32ba0343bd (patch)
tree a5e10bc3b6b7fc4c0233a9089f273d241996e7cc /dl/tutorials/02_rnn_on_images.py
parent 3d9e845cee12d5f4cf45292f8cd5c9f0134ca70e (diff)
rnn on images
Diffstat (limited to 'dl/tutorials/02_rnn_on_images.py')
-rw-r--r-- dl/tutorials/02_rnn_on_images.py 102
1 files changed, 102 insertions, 0 deletions
diff --git a/dl/tutorials/02_rnn_on_images.py b/dl/tutorials/02_rnn_on_images.py
new file mode 100644
index 0000000..93d2651
--- /dev/null
+++ b/dl/tutorials/02_rnn_on_images.py
@@ -0,0 +1,102 @@
+import torch
+import torch.nn as nn
+import torchvision
+import torchvision.transforms as transforms
+
# Device configuration: use the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Hyper-parameters
# Input geometry: images are reshaped to (batch, sequence_length, input_size)
# before they are fed to the model.
sequence_length, input_size = 28, 28

# Network size.
hidden_size = 128
num_layers = 2
num_classes = 10

# Optimisation schedule.
batch_size = 100
num_epochs = 2
learning_rate = 0.003
+
# MNIST dataset
# Both splits live under the same root and convert images with ToTensor();
# only the training constructor asks for a download.
train_dataset = torchvision.datasets.MNIST(
    root='../data/',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root='../data/',
    train=False,
    transform=transforms.ToTensor(),
)

# Data loader
# Training batches are reshuffled each epoch; test batches keep dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)
+
+
+# Bidirectional recurrent neural network (many-to-one)
class BiRNN(nn.Module):
    """Bidirectional LSTM classifier (many-to-one).

    A stacked bidirectional LSTM reads the input sequence and a linear layer
    maps the LSTM output at the last time step to class scores.

    Args:
        input_size: Number of features per time step.
        hidden_size: Hidden units per LSTM direction.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output classes.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            batch_first=True, bidirectional=True)
        # Forward and backward features are concatenated: hidden_size * 2.
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        """Return class scores of shape (batch, num_classes).

        Args:
            x: Tensor of shape (batch, seq_length, input_size).
        """
        # Zero initial states, one slice per (layer, direction) pair. They are
        # created on the input's own device/dtype so the module does not rely
        # on any module-level ``device`` global and keeps working after
        # ``.to(...)`` / ``.double()``.
        state_shape = (self.num_layers * 2, x.size(0), self.hidden_size)
        h0 = torch.zeros(state_shape, dtype=x.dtype, device=x.device)
        c0 = torch.zeros(state_shape, dtype=x.dtype, device=x.device)

        # out: (batch, seq_length, hidden_size * 2)
        out, _ = self.lstm(x, (h0, c0))

        # Decode the output of the last time step.
        # NOTE(review): at the last step the backward direction has presumably
        # consumed only one element; consider combining out[:, -1, :hidden]
        # with out[:, 0, hidden:] -- confirm against the nn.LSTM docs first.
        return self.fc(out[:, -1, :])
+
+
# Build the network from the hyper-parameters and move it to the chosen device.
model = BiRNN(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_classes=num_classes,
)
model = model.to(device)

# Loss and optimizer: cross-entropy on the class scores, Adam over all
# model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+
# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    # Steps are counted from 1 so the counter can be logged directly.
    for step, (images, labels) in enumerate(train_loader, start=1):
        # View each image as a sequence of `sequence_length` feature vectors.
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize: clear old gradients, backpropagate, update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report progress only on every 100th step.
        if step % 100 != 0:
            continue
        print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
              .format(epoch + 1, num_epochs, step, total_step, loss.item()))
+
# Test the model
# Switch to evaluation mode so train-only layers (dropout, batch-norm) behave
# deterministically; gradients are not needed, so tracking is disabled too.
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Predicted class = index of the highest score along the class axis.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
+
# Save the model checkpoint: only the state dict is written, not the module
# object itself.
state = model.state_dict()
torch.save(state, 'model.ckpt')