Diffstat (limited to 'learn_torch/basics')
-rw-r--r-- learn_torch/basics/torch.gather.ipynb | 287
1 files changed, 287 insertions, 0 deletions
diff --git a/learn_torch/basics/torch.gather.ipynb b/learn_torch/basics/torch.gather.ipynb
new file mode 100644
index 0000000..f1d1c57
--- /dev/null
+++ b/learn_torch/basics/torch.gather.ipynb
@@ -0,0 +1,287 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "3e1b8170",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-15T14:53:42.582632Z",
+ "start_time": "2023-03-15T14:53:40.053444Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import torch"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "895b4a69",
+ "metadata": {},
+ "source": [
+ "## Basic introduction"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "55041716",
+ "metadata": {},
+ "source": [
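+ "- Why study a complex, abstract interface in depth\n",
+ " - it makes complex code easier to read;\n",
+ " - it helps you build complex neural networks clearly and without mistakes;\n",
+ " \n",
+ "- `gather` selects values from the original tensor according to an index tensor\n",
+ " - Pick several entries per row (one per row is a special case) by supplying column indices (`dim=1`);\n",
+ " - the number of picks must be the same for every row;\n",
+ " - Pick several entries per column (one per column is a special case) by supplying row indices (`dim=0`);\n",
+ "\n",
+ "- Requirements\n",
+ " - `input` and `index` must have the same number of dimensions\n",
+ " - `index.size(d) <= input.size(d)` for every dimension `d != dim`; the output has the same shape as `index`\n",
+ "\n",
+ "- Use cases\n",
+ " - Multi-class classification over a batch: read the score (logits) or probability at each sample's true label (one pick per row)\n",
+ " - batch output: n*c\n",
+ " - index (true labels): n*1\n",
+ " - DQN: the network's output for each state in the batch at the action taken in that state (a regression problem);\n",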
+ " ```\n",
+ " state_action_values = policy_net(state_batch).gather(1, action_batch)\n",
+ " ```\n",
+ " - state_batch: n*d\n",
+ " - action_batch: n*1\n",
+ " - state_action_values: n*1"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "78106c0f",
+ "metadata": {},
+ "source": [
+ "## Example from the official documentation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "ce2df535",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-15T15:28:55.429980Z",
+ "start_time": "2023-03-15T15:28:55.425003Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[1, 2],\n",
+ " [3, 4]])"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t = torch.tensor([[1, 2], \n",
+ " [3, 4]])\n",
+ "t"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "1d63e435",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-15T15:29:32.016891Z",
+ "start_time": "2023-03-15T15:29:32.011386Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[1, 1],\n",
+ " [4, 3]])"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "torch.gather(t, 1, torch.tensor([[0, 0], \n",
+ " [1, 0]]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "89f41a09",
+ "metadata": {},
+ "source": [
+ "With `dim=1`, the index rows [0, 0] and [1, 0] expand to (row, column) coordinates:\n",
+ "\n",
+ "> [(0, 0), (0, 0)]\\\n",
+ "> [(1, 1), (1, 0)]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "913c87ac",
+ "metadata": {},
+ "source": [
+ "## A more complex example"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "0fac917e",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-15T15:30:24.244197Z",
+ "start_time": "2023-03-15T15:30:24.238490Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[1, 2, 3],\n",
+ " [4, 5, 6],\n",
+ " [7, 8, 9]])"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t = torch.arange(1, 10).view(3, 3)\n",
+ "t"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5d90c9d6",
+ "metadata": {},
+ "source": [
+ "- To select (1, 3) from row 0 and (5, 6) from row 1, supply column indices (`dim=1`):\n",
+ " - row 0: 0, 2\n",
+ " - row 1: 1, 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "c422f32e",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-15T15:31:16.503161Z",
+ "start_time": "2023-03-15T15:31:16.497969Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[1, 3],\n",
+ " [5, 6]])"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t.gather(1, torch.tensor(((0, 2), (1, 2))))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fdfd1c88",
+ "metadata": {},
+ "source": [
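+ "- To select (1, 4) from column 0 and (2, 8) from column 1, supply row indices (`dim=0`); each pair below is one column of the index tensor:\n",
+ " - column 0: (0, 1)\n",
+ " - column 1: (0, 2)"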
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "b4199432",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-15T15:32:01.607615Z",
+ "start_time": "2023-03-15T15:32:01.602219Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[1, 2],\n",
+ " [4, 8]])"
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t.gather(0, torch.tensor(((0, 0), (1, 2))))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1a02c334",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.13"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {},
+ "toc_section_display": true,
+ "toc_window_display": false
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
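
The indexing rule behind the notebook's `gather` calls can be spelled out without PyTorch. The sketch below is a pure-Python reference for the 2-D case only. `gather_2d` is a hypothetical helper written for illustration and is not part of the notebook or of PyTorch, and the `scores`/`labels` values are made-up sample data. It reproduces the `dim=1` and `dim=0` selections from the notebook, plus the n*c scores with n*1 labels lookup described in the introduction. If `torch` happens to be installed, it also checks one case against `Tensor.gather`.

```python
# Pure-Python sketch of 2-D gather semantics.
# gather_2d is a hypothetical helper for illustration, not a PyTorch API.

def gather_2d(src, dim, index):
    """Return out with the same shape as index, where
    out[i][j] = src[index[i][j]][j] if dim == 0
    out[i][j] = src[i][index[i][j]] if dim == 1
    """
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1 for a 2-D input")
    out = []
    for i, index_row in enumerate(index):
        out_row = []
        for j, k in enumerate(index_row):
            # The index value k replaces the coordinate along `dim`;
            # the other coordinate is the position (i, j) inside `index`.
            out_row.append(src[k][j] if dim == 0 else src[i][k])
        out.append(out_row)
    return out


t = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]

# dim=1: column indices, one list per row.
row_pick = gather_2d(t, 1, [[0, 2], [1, 2]])   # [[1, 3], [5, 6]]

# dim=0: row indices; column 0 uses rows (0, 1), column 1 uses rows (0, 2).
col_pick = gather_2d(t, 0, [[0, 0], [1, 2]])   # [[1, 2], [4, 8]]

# Classification-style lookup: n*c scores, n*1 true labels (illustrative data).
scores = [[0.1, 0.7, 0.2],
          [0.8, 0.1, 0.1]]
labels = [[1], [0]]
true_scores = gather_2d(scores, 1, labels)     # [[0.7], [0.8]]

print(row_pick, col_pick, true_scores)

# Optional cross-check against PyTorch, skipped when torch is not installed.
try:
    import torch
except ImportError:
    torch = None

if torch is not None:
    expected = torch.tensor(t).gather(1, torch.tensor([[0, 2], [1, 2]]))
    assert expected.tolist() == row_pick
```

Writing the rule as an explicit loop makes the asymmetry visible: the index value replaces only the coordinate along `dim`, while the other coordinate comes from where the value sits inside the index. That is why the `dim=0` index in the notebook is laid out column by column.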