diff options
Diffstat (limited to 'learn_torch/seq/test_rnn.py')
| -rw-r--r-- | learn_torch/seq/test_rnn.py | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/learn_torch/seq/test_rnn.py b/learn_torch/seq/test_rnn.py new file mode 100644 index 0000000..5a7baf2 --- /dev/null +++ b/learn_torch/seq/test_rnn.py @@ -0,0 +1,8 @@ +import torch +from torch import nn + +if __name__ == '__main__': + rnn = nn.RNN(10, 20, 2) + input = torch.randn(5, 3, 10) + h0 = torch.randn(2, 3, 10) + rnn(input, h0) |
