diff options
| author | zhang <zch921005@126.com> | 2022-06-23 23:11:56 +0800 |
|---|---|---|
| committer | zhang <zch921005@126.com> | 2022-06-23 23:11:56 +0800 |
| commit | 462bd2c944beed1667df30aa7fa626b972a5dfc9 (patch) | |
| tree | d4000b7217cada045bf27853787eb278d13b29f1 /learn_torch/utils | |
| parent | b2432d57f626a37ee790a83483bcc960048b0dac (diff) | |
pytorch hook
Diffstat (limited to 'learn_torch/utils')
| -rw-r--r-- | learn_torch/utils/hook.py | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/learn_torch/utils/hook.py b/learn_torch/utils/hook.py new file mode 100644 index 0000000..a97cee4 --- /dev/null +++ b/learn_torch/utils/hook.py @@ -0,0 +1,45 @@ + +import timm +import torch +from torch import nn + + +def print_shape(m, i, o): + #m: module, i: input, o: output + # print(m, i[0].shape, o.shape) + print(i[0].shape, '=>', o.shape) + + + + +def get_children(model: nn.Module): + # get children form model! + children = list(model.children()) + flatt_children = [] + if children == []: + # if model has no children; model is last child! :O + return model + else: + # look for children from children... to the last child! + for child in children: + try: + flatt_children.extend(get_children(child)) + except TypeError: + flatt_children.append(get_children(child)) + return flatt_children + + +model_name = 'vgg11' +model = timm.create_model(model_name, pretrained=True) + +flatt_children = get_children(model) +for layer in flatt_children: + layer.register_forward_hook(print_shape) + +# for layer in model.children(): +# layer.register_forward_hook(print_shape) + +# 4d: batch*channel*width*height +batch_input = torch.randn(4, 3, 299, 299) + +model(batch_input) |
