summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--basics/python_basics.py10
-rw-r--r--fun_tools/bar_animation.py35
-rw-r--r--learn_torch/basics/crossentropy_loss.py10
-rw-r--r--learn_torch/utils/vgg_hook.py44
-rw-r--r--pretrained/cnn/timm_vgg.py3
5 files changed, 102 insertions, 0 deletions
diff --git a/basics/python_basics.py b/basics/python_basics.py
new file mode 100644
index 0000000..bf780cc
--- /dev/null
+++ b/basics/python_basics.py
@@ -0,0 +1,10 @@
+
+def check_num_id(i):
+ a = i
+ b = i
+ print(i, id(a), id(b))
+
+
+if __name__ == '__main__':
+ for i in range(200, 260):
+ check_num_id(i)
diff --git a/fun_tools/bar_animation.py b/fun_tools/bar_animation.py
new file mode 100644
index 0000000..bff49d7
--- /dev/null
+++ b/fun_tools/bar_animation.py
@@ -0,0 +1,35 @@
+import bar_chart_race as bcr
+df = bcr.load_dataset('covid19_tutorial')
+bcr.bar_chart_race(
+ df=df,
+ filename='./images/covid19_horiz.mp4',
+ orientation='v',
+ sort='desc',
+ n_bars=8,
+ fixed_order=False,
+ fixed_max=True,
+ steps_per_period=20,
+ period_length=500,
+ # end_period_pause=0,
+ interpolate_period=False,
+ period_label={'x': .98, 'y': .3, 'ha': 'right', 'va': 'center'},
+ # period_template='%B %d, %Y',
+ period_summary_func=lambda v, r: {'x': .98, 'y': .2,
+ 's': f'Total deaths: {v.sum():,.0f}',
+ 'ha': 'right', 'size': 11},
+ perpendicular_bar_func='median',
+ # colors='dark12',
+ title='COVID-19 Deaths by Country',
+ bar_size=.95,
+ # bar_textposition='inside',
+ # bar_texttemplate='{x:,.0f}',
+ # bar_label_font=7,
+ # tick_label_font=7,
+ # tick_template='{x:,.0f}',
+ shared_fontdict=None,
+ scale='linear',
+ fig=None,
+ writer=None,
+ bar_kwargs={'alpha': .7},
+ # fig_kwargs={'figsize': (6, 3.5), 'dpi': 144},
+ filter_column_colors=False) \ No newline at end of file
diff --git a/learn_torch/basics/crossentropy_loss.py b/learn_torch/basics/crossentropy_loss.py
new file mode 100644
index 0000000..67ccfa5
--- /dev/null
+++ b/learn_torch/basics/crossentropy_loss.py
@@ -0,0 +1,10 @@
+
+from torch import nn
+import torch
+
+if __name__ == '__main__':
+ input = torch.randn(3, 5)
+ target = torch.empty(3, dtype=torch.long).random_(5)
+ loss = nn.CrossEntropyLoss()
+ output = loss(input, target)
+ print(output) \ No newline at end of file
diff --git a/learn_torch/utils/vgg_hook.py b/learn_torch/utils/vgg_hook.py
new file mode 100644
index 0000000..9311f1e
--- /dev/null
+++ b/learn_torch/utils/vgg_hook.py
@@ -0,0 +1,44 @@
+
+import timm
+import torch
+from torch import nn
+
+
+def print_shape(m, i, o):
+ #m: module, i: input, o: output
+ # print(m, i[0].shape, o.shape)
+ print(i[0].shape, m, o.shape)
+
+
+def get_children(model: nn.Module):
+ # get children form model!
+ children = list(model.children())
+ flatt_children = []
+ if children == []:
+ # if model has no children; model is last child! :O
+ return model
+ else:
+ # look for children from children... to the last child!
+ for child in children:
+ try:
+ flatt_children.extend(get_children(child))
+ except TypeError:
+ flatt_children.append(get_children(child))
+ return flatt_children
+
+
+# model_name = 'vgg11'
+model_name = 'resnet34'
+model = timm.create_model(model_name, pretrained=True)
+
+flatt_children = get_children(model)
+for layer in flatt_children:
+ layer.register_forward_hook(print_shape)
+
+# for layer in model.children():
+# layer.register_forward_hook(print_shape)
+
+# 4d: batch*channel*width*height
+batch_input = torch.randn(4, 3, 300, 300)
+
+model(batch_input)
diff --git a/pretrained/cnn/timm_vgg.py b/pretrained/cnn/timm_vgg.py
new file mode 100644
index 0000000..4b0c791
--- /dev/null
+++ b/pretrained/cnn/timm_vgg.py
@@ -0,0 +1,3 @@
+
+import timm
+