summaryrefslogtreecommitdiff
path: root/pretrained/swin/demo.py
diff options
context:
space:
mode:
Diffstat (limited to 'pretrained/swin/demo.py')
-rw-r--r--pretrained/swin/demo.py36
1 files changed, 36 insertions, 0 deletions
diff --git a/pretrained/swin/demo.py b/pretrained/swin/demo.py
new file mode 100644
index 0000000..6f8d87b
--- /dev/null
+++ b/pretrained/swin/demo.py
@@ -0,0 +1,36 @@
+import timm
+import torch
+import torchvision
+from PIL import Image
+from torchvision import transforms as T
+import json
+
+if __name__ == '__main__':
+ trans_ = T.Compose(
+ [
+ T.Resize(256),
+ T.CenterCrop(224),
+ T.ToTensor(),
+ T.Normalize(timm.data.IMAGENET_DEFAULT_MEAN, timm.data.IMAGENET_DEFAULT_STD)]
+ )
+
+ image = Image.open('./house_finch.jpg')
+ transformed = trans_(image)
+ print(transformed.shape)
+ batch_input = transformed.unsqueeze(0)
+ print(batch_input.shape)
+
+ model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
+
+ with torch.no_grad():
+ output = model(batch_input)
+
+ print(output)
+ print(output.shape)
+ labels = json.load(open('./imagenet_labels.json'))
+ class_ = output.argmax(dim=1)
+ print(class_)
+ print(class_.data)
+ print(output.sum(dim=1))
+ print(labels[class_])
+