diff options
| author | zhang <zch921005@126.com> | 2022-05-21 14:23:49 +0800 |
|---|---|---|
| committer | zhang <zch921005@126.com> | 2022-05-21 14:23:49 +0800 |
| commit | 678fab50280b647d95213a9695d07c49542696f2 (patch) | |
| tree | 74ca60de14311a8a2ff58dbf82d9b7574c9cd3ef /learn_torch/learn_nn/custom_module.py | |
| parent | 2180c68999eb8dc0c7bcec015b2703f5b8b20223 (diff) | |
0521
Diffstat (limited to 'learn_torch/learn_nn/custom_module.py')
| -rw-r--r-- | learn_torch/learn_nn/custom_module.py | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/learn_torch/learn_nn/custom_module.py b/learn_torch/learn_nn/custom_module.py new file mode 100644 index 0000000..7052b14 --- /dev/null +++ b/learn_torch/learn_nn/custom_module.py @@ -0,0 +1,21 @@ +import torch +from torch import nn + + +class MySeq(torch.nn.Module): + def __init__(self, *args): + super().__init__() + for block in args: + self._modules[block] = block + + def forward(self, X): + for block in self._modules.values(): + X = block(X) + return X + + +if __name__ == '__main__': + X = torch.rand(2, 20) + net = MySeq(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10)) + net(X) + |
