diff options
| -rw-r--r-- | basics/python_basics.py | 10 | ||||
| -rw-r--r-- | fun_tools/bar_animation.py | 35 | ||||
| -rw-r--r-- | learn_torch/basics/crossentropy_loss.py | 10 | ||||
| -rw-r--r-- | learn_torch/utils/vgg_hook.py | 44 | ||||
| -rw-r--r-- | pretrained/cnn/timm_vgg.py | 3 |
5 files changed, 102 insertions, 0 deletions
diff --git a/basics/python_basics.py b/basics/python_basics.py new file mode 100644 index 0000000..bf780cc --- /dev/null +++ b/basics/python_basics.py @@ -0,0 +1,10 @@ + +def check_num_id(i): + a = i + b = i + print(i, id(a), id(b)) + + +if __name__ == '__main__': + for i in range(200, 260): + check_num_id(i) diff --git a/fun_tools/bar_animation.py b/fun_tools/bar_animation.py new file mode 100644 index 0000000..bff49d7 --- /dev/null +++ b/fun_tools/bar_animation.py @@ -0,0 +1,35 @@ +import bar_chart_race as bcr +df = bcr.load_dataset('covid19_tutorial') +bcr.bar_chart_race( + df=df, + filename='./images/covid19_horiz.mp4', + orientation='v', + sort='desc', + n_bars=8, + fixed_order=False, + fixed_max=True, + steps_per_period=20, + period_length=500, + # end_period_pause=0, + interpolate_period=False, + period_label={'x': .98, 'y': .3, 'ha': 'right', 'va': 'center'}, + # period_template='%B %d, %Y', + period_summary_func=lambda v, r: {'x': .98, 'y': .2, + 's': f'Total deaths: {v.sum():,.0f}', + 'ha': 'right', 'size': 11}, + perpendicular_bar_func='median', + # colors='dark12', + title='COVID-19 Deaths by Country', + bar_size=.95, + # bar_textposition='inside', + # bar_texttemplate='{x:,.0f}', + # bar_label_font=7, + # tick_label_font=7, + # tick_template='{x:,.0f}', + shared_fontdict=None, + scale='linear', + fig=None, + writer=None, + bar_kwargs={'alpha': .7}, + # fig_kwargs={'figsize': (6, 3.5), 'dpi': 144}, + filter_column_colors=False)
\ No newline at end of file diff --git a/learn_torch/basics/crossentropy_loss.py b/learn_torch/basics/crossentropy_loss.py new file mode 100644 index 0000000..67ccfa5 --- /dev/null +++ b/learn_torch/basics/crossentropy_loss.py @@ -0,0 +1,10 @@ + +from torch import nn +import torch + +if __name__ == '__main__': + input = torch.randn(3, 5) + target = torch.empty(3, dtype=torch.long).random_(5) + loss = nn.CrossEntropyLoss() + output = loss(input, target) + print(output)
\ No newline at end of file diff --git a/learn_torch/utils/vgg_hook.py b/learn_torch/utils/vgg_hook.py new file mode 100644 index 0000000..9311f1e --- /dev/null +++ b/learn_torch/utils/vgg_hook.py @@ -0,0 +1,44 @@ + +import timm +import torch +from torch import nn + + +def print_shape(m, i, o): + #m: module, i: input, o: output + # print(m, i[0].shape, o.shape) + print(i[0].shape, m, o.shape) + + +def get_children(model: nn.Module): + # get children form model! + children = list(model.children()) + flatt_children = [] + if children == []: + # if model has no children; model is last child! :O + return model + else: + # look for children from children... to the last child! + for child in children: + try: + flatt_children.extend(get_children(child)) + except TypeError: + flatt_children.append(get_children(child)) + return flatt_children + + +# model_name = 'vgg11' +model_name = 'resnet34' +model = timm.create_model(model_name, pretrained=True) + +flatt_children = get_children(model) +for layer in flatt_children: + layer.register_forward_hook(print_shape) + +# for layer in model.children(): +# layer.register_forward_hook(print_shape) + +# 4d: batch*channel*width*height +batch_input = torch.randn(4, 3, 300, 300) + +model(batch_input) diff --git a/pretrained/cnn/timm_vgg.py b/pretrained/cnn/timm_vgg.py new file mode 100644 index 0000000..4b0c791 --- /dev/null +++ b/pretrained/cnn/timm_vgg.py @@ -0,0 +1,3 @@ + +import timm + |
