summaryrefslogtreecommitdiff
path: root/learn_torch/utils/hook.py
blob: a97cee438188dc6367532fd88b56fef1722fcbb4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45

import timm
import torch
from torch import nn


def print_shape(m, i, o):
    """Print the input and output shapes of one module call.

    The (module, input, output) signature is the one expected of a callback
    passed to ``register_forward_hook``.

    Args:
        m: The module the hook fired on (not used here).
        i: The module's inputs; only ``i[0].shape`` is read.
        o: The module's output; only ``o.shape`` is read.
    """
    shape_in = i[0].shape
    shape_out = o.shape
    print(shape_in, '=>', shape_out)




def get_children(model: nn.Module):
    """Return the leaf modules of ``model`` as a flat list, depth-first.

    A leaf is a module whose ``children()`` iterator is empty. The result is
    always a list, so it can be iterated safely even when ``model`` itself is
    a single layer (in which case the list is ``[model]``).

    Args:
        model: The root module to flatten.

    Returns:
        A list of leaf ``nn.Module`` objects, in the order ``children()``
        yields them at each level of the hierarchy.
    """
    children = list(model.children())
    if not children:
        # No sub-modules: the module itself is the leaf. Wrap it in a list so
        # callers (and the recursion below) always receive the same type.
        return [model]

    flatt_children = []
    for child in children:
        # Every recursive call returns a list, so a plain extend is enough;
        # no need to probe with try/except to tell leaves from containers.
        flatt_children.extend(get_children(child))
    return flatt_children


model_name = 'vgg11'
model = timm.create_model(model_name, pretrained=True)

# Hook the leaf modules rather than only the model's direct children, so the
# shape change across each individual layer is printed.
flatt_children = get_children(model)
for layer in flatt_children:
    layer.register_forward_hook(print_shape)

# 4-D input in PyTorch's (batch, channel, height, width) layout.
batch_input = torch.randn(4, 3, 299, 299)

# The output is discarded; only the shapes printed by the hooks matter, so
# skip autograd bookkeeping for this forward pass.
with torch.no_grad():
    model(batch_input)