From 92d3bc06bad13095df6515111bba45e73f701018 Mon Sep 17 00:00:00 2001 From: zhang Date: Sun, 24 Jul 2022 20:25:48 +0800 Subject: wordpiece --- learn_torch/utils/vgg_hook.py | 44 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 learn_torch/utils/vgg_hook.py (limited to 'learn_torch/utils') diff --git a/learn_torch/utils/vgg_hook.py b/learn_torch/utils/vgg_hook.py new file mode 100644 index 0000000..9311f1e --- /dev/null +++ b/learn_torch/utils/vgg_hook.py @@ -0,0 +1,44 @@ + +import timm +import torch +from torch import nn + + +def print_shape(m, i, o): + #m: module, i: input, o: output + # print(m, i[0].shape, o.shape) + print(i[0].shape, m, o.shape) + + +def get_children(model: nn.Module): + # get children form model! + children = list(model.children()) + flatt_children = [] + if children == []: + # if model has no children; model is last child! :O + return model + else: + # look for children from children... to the last child! + for child in children: + try: + flatt_children.extend(get_children(child)) + except TypeError: + flatt_children.append(get_children(child)) + return flatt_children + + +# model_name = 'vgg11' +model_name = 'resnet34' +model = timm.create_model(model_name, pretrained=True) + +flatt_children = get_children(model) +for layer in flatt_children: + layer.register_forward_hook(print_shape) + +# for layer in model.children(): +# layer.register_forward_hook(print_shape) + +# 4d: batch*channel*width*height +batch_input = torch.randn(4, 3, 300, 300) + +model(batch_input) -- cgit v1.2.3