From 441c21f3ca4b949395e3460032ae5c6cc0f06342 Mon Sep 17 00:00:00 2001 From: chzhang Date: Wed, 15 Mar 2023 23:37:06 +0800 Subject: torch.gather --- learn_torch/basics/torch.gather.ipynb | 287 ++++++++++++++++++++++++++++++++++ 1 file changed, 287 insertions(+) create mode 100644 learn_torch/basics/torch.gather.ipynb (limited to 'learn_torch/basics') diff --git a/learn_torch/basics/torch.gather.ipynb b/learn_torch/basics/torch.gather.ipynb new file mode 100644 index 0000000..f1d1c57 --- /dev/null +++ b/learn_torch/basics/torch.gather.ipynb @@ -0,0 +1,287 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "3e1b8170", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-15T14:53:42.582632Z", + "start_time": "2023-03-15T14:53:40.053444Z" + } + }, + "outputs": [], + "source": [ + "import torch" + ] + }, + { + "cell_type": "markdown", + "id": "895b4a69", + "metadata": {}, + "source": [ + "## 基本介绍" + ] + }, + { + "cell_type": "markdown", + "id": "55041716", + "metadata": {}, + "source": [ + "- 复杂的接口,一些抽象的接口的深入理解\n", + " - 理解复杂的代码;\n", + " - 有助于构建复杂的神经网络,且清晰,不出错;\n", + " \n", + "- 基于 index 在 原始 tensor 上选择(select)\n", + " - 行选多个(行选1个是一种特例),提供的是列值(`dim=1`);\n", + " - 选多个时,每一行选取的个数必须一致;\n", + " - 列选多个(列选1个是一种特例),提供的是行值(`dim=0`);\n", + "\n", + "- 要求,\n", + " - input tensor 与 index 的维数(`ndim`)相同\n", + " - 对所有 `d != dim`,要求 `index.size(d) <= input.size(d)`;输出的 shape 与 index 相同\n", + "\n", + "- 应用场景\n", + " - batch sample 多分类问题的分类输出,获得各个true label 上的 score(logits)/probs(行选1个)\n", + " - batch sample: n*c\n", + " - index(true label): n*1\n", + " - dqn:batch state 的 batch action 的输出,回归问题;\n", + " ```\n", + " state_action_values = policy_net(state_batch).gather(1, action_batch)\n", + " ```\n", + " - state_batch: n*d\n", + " - action_batch: n*1\n", + " - state_action_values: n*1" + ] + }, + { + "cell_type": "markdown", + "id": "78106c0f", + "metadata": {}, + "source": [ + "## 官网示例" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "ce2df535", + "metadata": { + "ExecuteTime": { + 
"end_time": "2023-03-15T15:28:55.429980Z", + "start_time": "2023-03-15T15:28:55.425003Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[1, 2],\n", + " [3, 4]])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t = torch.tensor([[1, 2], \n", + " [3, 4]])\n", + "t" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "1d63e435", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-15T15:29:32.016891Z", + "start_time": "2023-03-15T15:29:32.011386Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[1, 1],\n", + " [4, 3]])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.gather(t, 1, torch.tensor([[0, 0], \n", + " [1, 0]]))" + ] + }, + { + "cell_type": "markdown", + "id": "89f41a09", + "metadata": {}, + "source": [ + "[0, 0]\\\n", + "[1, 0] => \n", + "> [(0, 0), (0, 0)]\\\n", + "> [(1, 1), (1, 0)]" + ] + }, + { + "cell_type": "markdown", + "id": "913c87ac", + "metadata": {}, + "source": [ + "## 复杂例子" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "0fac917e", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-15T15:30:24.244197Z", + "start_time": "2023-03-15T15:30:24.238490Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[1, 2, 3],\n", + " [4, 5, 6],\n", + " [7, 8, 9]])" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t = torch.arange(1, 10).view(3, 3)\n", + "t" + ] + }, + { + "cell_type": "markdown", + "id": "5d90c9d6", + "metadata": {}, + "source": [ + "- 如果想选择 (1,3), (5, 6)\n", + " - 0, 2\n", + " - 1, 2" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "c422f32e", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-15T15:31:16.503161Z", + "start_time": "2023-03-15T15:31:16.497969Z" + } + }, + "outputs": [ + { + "data": { + 
"text/plain": [ + "tensor([[1, 3],\n", + " [5, 6]])" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t.gather(1, torch.tensor(((0, 2), (1, 2))))" + ] + }, + { + "cell_type": "markdown", + "id": "fdfd1c88", + "metadata": {}, + "source": [ + "- 如果想选择 (1, 4), (2, 8),提供的是行值\n", + " - (0, 1), \n", + " - (0, 2)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "b4199432", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-15T15:32:01.607615Z", + "start_time": "2023-03-15T15:32:01.602219Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[1, 2],\n", + " [4, 8]])" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t.gather(0, torch.tensor(((0, 0), (1, 2))))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a02c334", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} -- cgit v1.2.3