summaryrefslogtreecommitdiff
path: root/cv/pretrained/features.py
blob: f47912634cf0eb96bca0eecff37b4ef707ba9e4c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17

import timm
import torch
from torch import nn


# Compare the output shape of a pretrained timm model's full forward pass
# with the shape returned by model.forward_features() on the same batch.
# (Per timm's API, forward_features is the backbone output before the
# pooling/classifier head; verify against the installed timm version.)
# Change model_name to another timm architecture (e.g. 'resnet18') to compare.
model_name = 'xception41'
model = timm.create_model(model_name, pretrained=True)
# nn.Module starts in training mode. Switch to eval so BatchNorm layers use
# their stored running statistics instead of overwriting them with statistics
# of the random batch below, and so dropout is disabled.
model.eval()

# Random batch of 2 three-channel 299x299 images. 299 is the resolution this
# script was written for; presumably matches xception41's expected input
# size (TODO confirm via the model's default config).
images = torch.randn(2, 3, 299, 299)

# Inference only: skip building the autograd graph to save memory/time.
with torch.no_grad():
    o1 = model(images)
    print(o1.shape)

    o2 = model.forward_features(images)
    print(o2.shape)