summaryrefslogtreecommitdiff
path: root/cv/img2vec/image_feature_extractor_similarity.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'cv/img2vec/image_feature_extractor_similarity.ipynb')
-rw-r--r--cv/img2vec/image_feature_extractor_similarity.ipynb275
1 files changed, 275 insertions, 0 deletions
diff --git a/cv/img2vec/image_feature_extractor_similarity.ipynb b/cv/img2vec/image_feature_extractor_similarity.ipynb
new file mode 100644
index 0000000..a9fc867
--- /dev/null
+++ b/cv/img2vec/image_feature_extractor_similarity.ipynb
@@ -0,0 +1,275 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1. 图像数据处理"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-22T15:20:04.751148Z",
+ "start_time": "2022-06-22T15:20:03.148807Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from PIL import Image\n",
+ "from torchvision import transforms as T\n",
+ "import timm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-22T15:17:08.148080Z",
+ "start_time": "2022-06-22T15:17:08.142540Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "img1 = Image.open('./imgs/cat.jpg')\n",
+ "img2 = Image.open('./imgs/cat2.jpg')\n",
+ "img3 = Image.open('./imgs/face.jpg')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-22T15:20:48.808680Z",
+ "start_time": "2022-06-22T15:20:48.804760Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "trans = T.Compose([\n",
+ " T.Resize((256, 256)),\n",
+ " T.ToTensor(),\n",
+ " T.Normalize(mean=timm.data.IMAGENET_DEFAULT_MEAN, std=timm.data.IMAGENET_DEFAULT_STD)\n",
+ "])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-22T15:21:36.222222Z",
+ "start_time": "2022-06-22T15:21:36.112742Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "t_img1 = trans(img1).unsqueeze(0)\n",
+ "t_img2 = trans(img2).unsqueeze(0)\n",
+ "t_img3 = trans(img3).unsqueeze(0)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2. 加载模型"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-22T15:25:37.025421Z",
+ "start_time": "2022-06-22T15:25:37.021239Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from torchvision import models\n",
+ "import torch\n",
+ "from torch import nn"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-22T15:18:02.323997Z",
+ "start_time": "2022-06-22T15:18:00.711753Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "model = models.resnet152(pretrained=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-22T15:19:05.325328Z",
+ "start_time": "2022-06-22T15:19:05.321574Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "layer = model._modules['avgpool']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-22T15:25:15.871839Z",
+ "start_time": "2022-06-22T15:25:15.110680Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([1, 2048, 1, 1])\n"
+ ]
+ }
+ ],
+ "source": [
+ "model.eval()\n",
+ "with torch.no_grad():\n",
+ " feature_vec3 = torch.zeros(2048)\n",
+ " def copy(m, i, o):\n",
+ " print(o.shape)\n",
+ " feature_vec3.copy_(o.reshape(o.shape[1]))\n",
+ " h = layer.register_forward_hook(copy)\n",
+ " model(t_img3) # forward\n",
+ " h.remove()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-22T15:25:27.967597Z",
+ "start_time": "2022-06-22T15:25:27.962756Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([0.1978, 0.4018, 0.0321, ..., 0.0981, 0.2331, 0.2016])"
+ ]
+ },
+ "execution_count": 33,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "feature_vec3"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3. 输出处理"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-22T15:25:48.350136Z",
+ "start_time": "2022-06-22T15:25:48.347187Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "cos_sim = nn.CosineSimilarity(dim=1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-22T15:26:16.628814Z",
+ "start_time": "2022-06-22T15:26:16.623559Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([0.6978])"
+ ]
+ },
+ "execution_count": 37,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "cos_sim(feature_vec.unsqueeze(0), feature_vec2.unsqueeze(0))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-22T15:26:43.703193Z",
+ "start_time": "2022-06-22T15:26:43.698088Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([0.5045])"
+ ]
+ },
+ "execution_count": 38,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "cos_sim(feature_vec.unsqueeze(0), feature_vec3.unsqueeze(0))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}