{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. 图像数据处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-06-22T15:20:04.751148Z",
     "start_time": "2022-06-22T15:20:03.148807Z"
    }
   },
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "from torchvision import transforms as T\n",
    "import timm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-06-22T15:17:08.148080Z",
     "start_time": "2022-06-22T15:17:08.142540Z"
    }
   },
   "outputs": [],
   "source": [
    "# Force 3-channel RGB so grayscale / RGBA / CMYK files still match the\n",
    "# 3-channel Normalize statistics used below.\n",
    "img1 = Image.open('./imgs/cat.jpg').convert('RGB')\n",
    "img2 = Image.open('./imgs/cat2.jpg').convert('RGB')\n",
    "img3 = Image.open('./imgs/face.jpg').convert('RGB')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-06-22T15:20:48.808680Z",
     "start_time": "2022-06-22T15:20:48.804760Z"
    }
   },
   "outputs": [],
   "source": [
    "trans = T.Compose([\n",
    "    T.Resize((256, 256)),\n",
    "    T.ToTensor(),\n",
    "    T.Normalize(mean=timm.data.IMAGENET_DEFAULT_MEAN, std=timm.data.IMAGENET_DEFAULT_STD)\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-06-22T15:21:36.222222Z",
     "start_time": "2022-06-22T15:21:36.112742Z"
    }
   },
   "outputs": [],
   "source": [
    "# unsqueeze(0) adds the leading batch dimension before the model call.\n",
    "t_img1 = trans(img1).unsqueeze(0)\n",
    "t_img2 = trans(img2).unsqueeze(0)\n",
    "t_img3 = trans(img3).unsqueeze(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. 加载模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-06-22T15:25:37.025421Z",
     "start_time": "2022-06-22T15:25:37.021239Z"
    }
   },
   "outputs": [],
   "source": [
    "from torchvision import models\n",
    "import torch\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-06-22T15:18:02.323997Z",
     "start_time": "2022-06-22T15:18:00.711753Z"
    }
   },
   "outputs": [],
   "source": [
    "model = models.resnet152(pretrained=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-06-22T15:19:05.325328Z",
     "start_time": "2022-06-22T15:19:05.321574Z"
    }
   },
   "outputs": [],
   "source": [
    "layer = model._modules['avgpool']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_feature(model, layer, t_img):\n",
    "    \"\"\"Return the output of `layer` for one image as a 1-D feature vector.\n",
    "\n",
    "    A forward hook is attached to `layer`, `model` is run once on `t_img`\n",
    "    without gradient tracking, and the hook is removed again.\n",
    "\n",
    "    Args:\n",
    "        model: network to run; it is switched to eval mode.\n",
    "        layer: sub-module of `model` whose output is captured.\n",
    "        t_img: input batch holding exactly one image.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape (C,), where C is dimension 1 of the layer output.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: if the layer output holds more than C elements\n",
    "            (e.g. batch size > 1), because the reshape cannot succeed.\n",
    "        IndexError: if `layer` is never called during the forward pass.\n",
    "    \"\"\"\n",
    "    captured = []\n",
    "\n",
    "    def _hook(module, inputs, output):\n",
    "        # Keep a detached copy so the vector outlives the forward pass.\n",
    "        captured.append(output.detach().reshape(output.shape[1]).clone())\n",
    "\n",
    "    model.eval()  # inference mode, e.g. BatchNorm uses its running statistics\n",
    "    handle = layer.register_forward_hook(_hook)\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            model(t_img)\n",
    "    finally:\n",
    "        handle.remove()  # never leave the hook attached, even on failure\n",
    "    return captured[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One vector per image; names follow the img1 / img2 / img3 numbering.\n",
    "feature_vec = extract_feature(model, layer, t_img1)\n",
    "feature_vec2 = extract_feature(model, layer, t_img2)\n",
    "feature_vec3 = extract_feature(model, layer, t_img3)\n",
    "\n",
    "# The avgpool output of ResNet-152 has 2048 channels.\n",
    "assert feature_vec.shape == feature_vec2.shape == feature_vec3.shape == (2048,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-06-22T15:25:27.967597Z",
     "start_time": "2022-06-22T15:25:27.962756Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.1978, 0.4018, 0.0321, ..., 0.0981, 0.2331, 0.2016])"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "feature_vec3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. 输出处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-06-22T15:25:48.350136Z",
     "start_time": "2022-06-22T15:25:48.347187Z"
    }
   },
   "outputs": [],
   "source": [
    "cos_sim = nn.CosineSimilarity(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-06-22T15:26:16.628814Z",
     "start_time": "2022-06-22T15:26:16.623559Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.6978])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cos_sim(feature_vec.unsqueeze(0), feature_vec2.unsqueeze(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-06-22T15:26:43.703193Z",
     "start_time": "2022-06-22T15:26:43.698088Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.5045])"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cos_sim(feature_vec.unsqueeze(0), feature_vec3.unsqueeze(0))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}