diff options
Diffstat (limited to 'learn_torch/learn_nn/custom_module.py')
| -rw-r--r-- | learn_torch/learn_nn/custom_module.py | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/learn_torch/learn_nn/custom_module.py b/learn_torch/learn_nn/custom_module.py new file mode 100644 index 0000000..7052b14 --- /dev/null +++ b/learn_torch/learn_nn/custom_module.py @@ -0,0 +1,21 @@ +import torch +from torch import nn + + +class MySeq(torch.nn.Module): + def __init__(self, *args): + super().__init__() + for block in args: + self._modules[block] = block + + def forward(self, X): + for block in self._modules.values(): + X = block(X) + return X + + +if __name__ == '__main__': + X = torch.rand(2, 20) + net = MySeq(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10)) + net(X) + |
