summaryrefslogtreecommitdiff
path: root/learn_torch/utils/vgg_hook.py
diff options
context:
space:
mode:
authorzhang <zch921005@126.com>2022-07-24 20:25:48 +0800
committerzhang <zch921005@126.com>2022-07-24 20:25:48 +0800
commit92d3bc06bad13095df6515111bba45e73f701018 (patch)
tree5730478fe92b39f3b909843546291d0eced774d0 /learn_torch/utils/vgg_hook.py
parente9945ee44d8c46d93d50f023f49e79f3ba532583 (diff)
wordpiece
Diffstat (limited to 'learn_torch/utils/vgg_hook.py')
-rw-r--r--learn_torch/utils/vgg_hook.py44
1 files changed, 44 insertions, 0 deletions
diff --git a/learn_torch/utils/vgg_hook.py b/learn_torch/utils/vgg_hook.py
new file mode 100644
index 0000000..9311f1e
--- /dev/null
+++ b/learn_torch/utils/vgg_hook.py
@@ -0,0 +1,44 @@
+
+import timm
+import torch
+from torch import nn
+
+
+def print_shape(m, i, o):
+ #m: module, i: input, o: output
+ # print(m, i[0].shape, o.shape)
+ print(i[0].shape, m, o.shape)
+
+
+def get_children(model: nn.Module):
+ # get children form model!
+ children = list(model.children())
+ flatt_children = []
+ if children == []:
+ # if model has no children; model is last child! :O
+ return model
+ else:
+ # look for children from children... to the last child!
+ for child in children:
+ try:
+ flatt_children.extend(get_children(child))
+ except TypeError:
+ flatt_children.append(get_children(child))
+ return flatt_children
+
+
+# model_name = 'vgg11'
+model_name = 'resnet34'
+model = timm.create_model(model_name, pretrained=True)
+
+flatt_children = get_children(model)
+for layer in flatt_children:
+ layer.register_forward_hook(print_shape)
+
+# for layer in model.children():
+# layer.register_forward_hook(print_shape)
+
+# 4d: batch*channel*width*height
+batch_input = torch.randn(4, 3, 300, 300)
+
+model(batch_input)