summaryrefslogtreecommitdiff
path: root/learn_torch/seq/test_rnn.py
diff options
context:
space:
mode:
Diffstat (limited to 'learn_torch/seq/test_rnn.py')
-rw-r--r--learn_torch/seq/test_rnn.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/learn_torch/seq/test_rnn.py b/learn_torch/seq/test_rnn.py
new file mode 100644
index 0000000..5a7baf2
--- /dev/null
+++ b/learn_torch/seq/test_rnn.py
@@ -0,0 +1,8 @@
+import torch
+from torch import nn
+
+if __name__ == '__main__':
+ rnn = nn.RNN(10, 20, 2)
+ input = torch.randn(5, 3, 10)
+ h0 = torch.randn(2, 3, 10)
+ rnn(input, h0)