summaryrefslogtreecommitdiff
path: root/learn_torch/learn_nn/custom_module.py
blob: 7052b14b9faf052b5bd4e089f7cb43433a46456a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from torch import nn


class MySeq(torch.nn.Module):
    """Minimal re-implementation of ``nn.Sequential``.

    Each positional argument is registered as a child module. ``forward``
    applies the children in the order they were given, feeding each one the
    previous one's output.
    """

    def __init__(self, *args):
        super().__init__()
        for idx, block in enumerate(args):
            # Children must be registered under *string* names: nn.Module
            # builds qualified names by string concatenation in
            # parameters()/named_parameters()/state_dict()/repr(). Using the
            # module object itself as the key makes those raise TypeError.
            # The positional index also keeps a module passed more than once
            # as separate entries, so it is applied that many times.
            self._modules[str(idx)] = block

    def forward(self, X):
        """Run ``X`` through every registered child in insertion order.

        Args:
            X: input passed to the first child.

        Returns:
            The output of the last child. If no children were given, ``X``
            is returned unchanged.
        """
        for block in self._modules.values():
            X = block(X)
        return X


if __name__ == '__main__':
    # Smoke test: push a random mini-batch (2 rows of 20 features) through a
    # small two-layer MLP assembled with MySeq. The input tensor is sampled
    # before the layers are constructed, so the RNG draw order is preserved.
    X = torch.rand(2, 20)
    net = MySeq(
        nn.Linear(20, 256),
        nn.ReLU(),
        nn.Linear(256, 10),
    )
    net(X)