summaryrefslogtreecommitdiff
path: root/learn_torch
diff options
context:
space:
mode:
authorzhang <zch921005@126.com>2022-06-23 23:11:56 +0800
committerzhang <zch921005@126.com>2022-06-23 23:11:56 +0800
commit462bd2c944beed1667df30aa7fa626b972a5dfc9 (patch)
treed4000b7217cada045bf27853787eb278d13b29f1 /learn_torch
parentb2432d57f626a37ee790a83483bcc960048b0dac (diff)
pytorch hook
Diffstat (limited to 'learn_torch')
-rw-r--r--learn_torch/utils/hook.py45
1 files changed, 45 insertions, 0 deletions
diff --git a/learn_torch/utils/hook.py b/learn_torch/utils/hook.py
new file mode 100644
index 0000000..a97cee4
--- /dev/null
+++ b/learn_torch/utils/hook.py
@@ -0,0 +1,45 @@
+
+import timm
+import torch
+from torch import nn
+
+
+def print_shape(m, i, o):
+ #m: module, i: input, o: output
+ # print(m, i[0].shape, o.shape)
+ print(i[0].shape, '=>', o.shape)
+
+
+
+
+def get_children(model: nn.Module):
+ # get children form model!
+ children = list(model.children())
+ flatt_children = []
+ if children == []:
+ # if model has no children; model is last child! :O
+ return model
+ else:
+ # look for children from children... to the last child!
+ for child in children:
+ try:
+ flatt_children.extend(get_children(child))
+ except TypeError:
+ flatt_children.append(get_children(child))
+ return flatt_children
+
+
+model_name = 'vgg11'
+model = timm.create_model(model_name, pretrained=True)
+
+flatt_children = get_children(model)
+for layer in flatt_children:
+ layer.register_forward_hook(print_shape)
+
+# for layer in model.children():
+# layer.register_forward_hook(print_shape)
+
+# 4d: batch*channel*width*height
+batch_input = torch.randn(4, 3, 299, 299)
+
+model(batch_input)