summaryrefslogtreecommitdiff
path: root/learn_torch/seq/test_rnn.py
blob: 5a7baf249cd5d4f9439937ff005b7f7a04ef0c8d (plain)
1
2
3
4
5
6
7
8
import torch
from torch import nn

def demo_rnn(seq_len=5, batch_size=3, input_size=10, hidden_size=20, num_layers=2):
    """Run a stacked ``nn.RNN`` over random input and return its outputs.

    ``nn.RNN`` is constructed with its default ``batch_first=False``, so the
    input is laid out as ``(seq_len, batch_size, input_size)`` and the initial
    hidden state as ``(num_layers, batch_size, hidden_size)``.

    Args:
        seq_len: Number of time steps in the random input sequence.
        batch_size: Number of sequences in the batch.
        input_size: Feature dimension of each input time step.
        hidden_size: Feature dimension of the RNN hidden state.
        num_layers: Number of stacked recurrent layers.

    Returns:
        A tuple ``(output, h_n)``: ``output`` has shape
        ``(seq_len, batch_size, hidden_size)`` (last layer's hidden state at
        every time step) and ``h_n`` has shape
        ``(num_layers, batch_size, hidden_size)`` (final hidden state of every
        layer).
    """
    rnn = nn.RNN(input_size, hidden_size, num_layers)
    x = torch.randn(seq_len, batch_size, input_size)
    # The hidden state's last dimension must be hidden_size, not input_size;
    # otherwise nn.RNN rejects it with a RuntimeError about the hidden size.
    h0 = torch.randn(num_layers, batch_size, hidden_size)
    return rnn(x, h0)


if __name__ == '__main__':
    output, h_n = demo_rnn()
    print(output.shape)  # torch.Size([5, 3, 20])
    print(h_n.shape)  # torch.Size([2, 3, 20])