summaryrefslogtreecommitdiff
path: root/learn_torch/learn_nn/custom_module.py
diff options
context:
space:
mode:
authorzhang <zch921005@126.com>2022-05-21 14:23:49 +0800
committerzhang <zch921005@126.com>2022-05-21 14:23:49 +0800
commit678fab50280b647d95213a9695d07c49542696f2 (patch)
tree74ca60de14311a8a2ff58dbf82d9b7574c9cd3ef /learn_torch/learn_nn/custom_module.py
parent2180c68999eb8dc0c7bcec015b2703f5b8b20223 (diff)
0521
Diffstat (limited to 'learn_torch/learn_nn/custom_module.py')
-rw-r--r--learn_torch/learn_nn/custom_module.py21
1 files changed, 21 insertions, 0 deletions
diff --git a/learn_torch/learn_nn/custom_module.py b/learn_torch/learn_nn/custom_module.py
new file mode 100644
index 0000000..7052b14
--- /dev/null
+++ b/learn_torch/learn_nn/custom_module.py
@@ -0,0 +1,21 @@
+import torch
+from torch import nn
+
+
+class MySeq(torch.nn.Module):
+ def __init__(self, *args):
+ super().__init__()
+ for block in args:
+ self._modules[block] = block
+
+ def forward(self, X):
+ for block in self._modules.values():
+ X = block(X)
+ return X
+
+
+if __name__ == '__main__':
+ X = torch.rand(2, 20)
+ net = MySeq(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))
+ net(X)
+