{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "3e1b8170", "metadata": { "ExecuteTime": { "end_time": "2023-03-15T14:53:42.582632Z", "start_time": "2023-03-15T14:53:40.053444Z" } }, "outputs": [], "source": [ "import torch" ] }, { "cell_type": "markdown", "id": "895b4a69", "metadata": {}, "source": [ "## Introduction" ] }, { "cell_type": "markdown", "id": "55041716", "metadata": {}, "source": [ "- Why dig into a complex, fairly abstract interface like `gather`\n", "    - it helps when reading complex code;\n", "    - it helps when building complex neural networks cleanly and without mistakes;\n", "\n", "- `gather` selects elements from the original tensor based on an index tensor\n", "    - select several elements per row (one per row is a special case): the index holds column indices (`dim=1`);\n", "    - when selecting several per row, every row must select the same number of elements;\n", "    - select several elements per column (one per column is a special case): the index holds row indices (`dim=0`);\n", "\n", "- Requirements\n", "    - the input tensor and `index` have the same number of dimensions;\n", "    - `index.size(d) <= input.size(d)` for every dimension `d != dim`;\n", "\n", "- Use cases\n", "    - multi-class classification on a batch: pick the score (logit) or probability at each sample's true label (one element per row)\n", "        - batch scores: n*c\n", "        - index (true label): n*1\n", "    - DQN: for each state in the batch, pick the network output for the action taken in that state (a regression problem);\n", "      ```\n", "      state_action_values = policy_net(state_batch).gather(1, action_batch)\n", "      ```\n", "        - state_batch: n*d\n", "        - action_batch: n*1\n", "        - state_action_values: n*1" ] }, { "cell_type": "markdown", "id": "78106c0f", "metadata": {}, "source": [ "## Official example" ] }, { "cell_type": "code", "execution_count": 16, "id": "ce2df535", "metadata": { "ExecuteTime": { "end_time": "2023-03-15T15:28:55.429980Z", "start_time": "2023-03-15T15:28:55.425003Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[1, 2],\n", " [3, 4]])" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t = torch.tensor([[1, 2], \n", " [3, 4]])\n", "t" ] }, { "cell_type": "code", "execution_count": 17, "id": "1d63e435", "metadata": { "ExecuteTime": { "end_time": "2023-03-15T15:29:32.016891Z", "start_time": "2023-03-15T15:29:32.011386Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[1, 1],\n", " [4, 3]])" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# dim=1: out[i][j] = t[i][index[i][j]]\n", "torch.gather(t, 1, torch.tensor([[0, 0], \n", 
" [1, 0]]))" ] }, { "cell_type": "markdown", "id": "89f41a09", "metadata": {}, "source": [ "[0, 0]\\\n", "[1, 0] => \n", "> [(0, 0), (0, 0)]\\\n", "> [(1, 1), (1, 0)]" ] }, { "cell_type": "markdown", "id": "913c87ac", "metadata": {}, "source": [ "## 复杂例子" ] }, { "cell_type": "code", "execution_count": 18, "id": "0fac917e", "metadata": { "ExecuteTime": { "end_time": "2023-03-15T15:30:24.244197Z", "start_time": "2023-03-15T15:30:24.238490Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[1, 2, 3],\n", " [4, 5, 6],\n", " [7, 8, 9]])" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t = torch.arange(1, 10).view(3, 3)\n", "t" ] }, { "cell_type": "markdown", "id": "5d90c9d6", "metadata": {}, "source": [ "- 如果想选择 (1,3), (5, 6)\n", " - 0, 2\n", " - 1, 2" ] }, { "cell_type": "code", "execution_count": 19, "id": "c422f32e", "metadata": { "ExecuteTime": { "end_time": "2023-03-15T15:31:16.503161Z", "start_time": "2023-03-15T15:31:16.497969Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[1, 3],\n", " [5, 6]])" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t.gather(1, torch.tensor(((0, 2), (1, 2))))" ] }, { "cell_type": "markdown", "id": "fdfd1c88", "metadata": {}, "source": [ "- 如果想选择 (1, 4), (2, 8),提供的是行值\n", " - (0, 1), \n", " - (0, 2)" ] }, { "cell_type": "code", "execution_count": 20, "id": "b4199432", "metadata": { "ExecuteTime": { "end_time": "2023-03-15T15:32:01.607615Z", "start_time": "2023-03-15T15:32:01.602219Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[1, 2],\n", " [4, 8]])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t.gather(0, torch.tensor(((0, 0), (1, 2))))" ] }, { "cell_type": "code", "execution_count": null, "id": "1a02c334", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, 
"language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }