summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cv/pretrained/features.py17
-rw-r--r--learn_torch/utils/hook.py45
2 files changed, 62 insertions, 0 deletions
diff --git a/cv/pretrained/features.py b/cv/pretrained/features.py
new file mode 100644
index 0000000..f479126
--- /dev/null
+++ b/cv/pretrained/features.py
@@ -0,0 +1,17 @@
+
+import timm
+import torch
+from torch import nn
+
+
+model_name = 'xception41'
+# model_name = 'resnet18'
+model = timm.create_model(model_name, pretrained=True)
+
+input = torch.randn(2, 3, 299, 299)
+
+o1 = model(input)
+print(o1.shape)
+
+o2 = model.forward_features(input)
+print(o2.shape) \ No newline at end of file
diff --git a/learn_torch/utils/hook.py b/learn_torch/utils/hook.py
new file mode 100644
index 0000000..a97cee4
--- /dev/null
+++ b/learn_torch/utils/hook.py
@@ -0,0 +1,45 @@
+
+import timm
+import torch
+from torch import nn
+
+
+def print_shape(m, i, o):
+ #m: module, i: input, o: output
+ # print(m, i[0].shape, o.shape)
+ print(i[0].shape, '=>', o.shape)
+
+
+
+
+def get_children(model: nn.Module):
+ # get children form model!
+ children = list(model.children())
+ flatt_children = []
+ if children == []:
+ # if model has no children; model is last child! :O
+ return model
+ else:
+ # look for children from children... to the last child!
+ for child in children:
+ try:
+ flatt_children.extend(get_children(child))
+ except TypeError:
+ flatt_children.append(get_children(child))
+ return flatt_children
+
+
+model_name = 'vgg11'
+model = timm.create_model(model_name, pretrained=True)
+
+flatt_children = get_children(model)
+for layer in flatt_children:
+ layer.register_forward_hook(print_shape)
+
+# for layer in model.children():
+# layer.register_forward_hook(print_shape)
+
+# 4d: batch*channel*width*height
+batch_input = torch.randn(4, 3, 299, 299)
+
+model(batch_input)