From 37506402a98eba9bf9d06760a1010fa17adb39e4 Mon Sep 17 00:00:00 2001 From: chzhang Date: Mon, 20 Feb 2023 23:43:44 +0800 Subject: sarsa --- rl/tutorials/05_value_iteration_sarsa.ipynb | 2034 +++++++++++++++++++++++++++ 1 file changed, 2034 insertions(+) create mode 100644 rl/tutorials/05_value_iteration_sarsa.ipynb (limited to 'rl/tutorials') diff --git a/rl/tutorials/05_value_iteration_sarsa.ipynb b/rl/tutorials/05_value_iteration_sarsa.ipynb new file mode 100644 index 0000000..2f28a1b --- /dev/null +++ b/rl/tutorials/05_value_iteration_sarsa.ipynb @@ -0,0 +1,2034 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-20T15:09:34.179156Z", + "start_time": "2023-02-20T15:09:33.559589Z" + } + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import gym\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-20T15:09:36.809994Z", + "start_time": "2023-02-20T15:09:36.373146Z" + }, + "scrolled": true + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
# Draw the 3x3 maze: red wall segments, one label per cell, and a green
# marker for the agent that the animation cell moves around later.
fig = plt.figure(figsize=(5, 5))
ax = plt.gca()
ax.set_xlim(0, 3)
ax.set_ylim(0, 3)

# Each wall is an (xs, ys) pair describing one line segment.
walls = [
    ([2, 3], [1, 1]),
    ([0, 1], [1, 1]),
    ([1, 1], [1, 2]),
    ([1, 2], [2, 2]),
]
for xs, ys in walls:
    ax.plot(xs, ys, color='red', linewidth=2)

# States are numbered row by row from the top-left cell (S0) to the
# bottom-right cell (S8); labels sit at the cell centres.
for idx in range(9):
    row, col = divmod(idx, 3)
    ax.text(col + 0.5, 2.5 - row, f'S{idx}', size=14, ha='center')
ax.text(0.5, 2.3, 'START', ha='center')
ax.text(2.5, 0.3, 'GOAL', ha='center')

# Hide every tick and tick label so that only the grid content is shown.
ax.tick_params(
    axis='both', which='both',
    bottom=False, top=False,
    right=False, left=False,
    labelbottom=False, labelleft=False,
)

# Agent marker, initially in the centre of S0.
line, = ax.plot([0.5], [2.5], marker='o', color='g', markersize=60)
Q-table:待学习(learning/iteration update)\n", + " - row index: state; column index: action;\n", + " - 不是概率分布,是 value\n", + " - q-learning\n", + " - reward\n", + " - 特定时间 $t$ 给到的奖励 $R_t$ 称为即时奖励(immediate reward)\n", + " - 未来的总奖励 $G_t$\n", + " - $G_t=R_{t+1}+R_{t+2}+R_{t+3}+\\cdots$\n", + " - $G_t=R_{t+1}+\\gamma R_{t+2}+\\gamma^2R_{t+3} + \\cdots + \\gamma^kR_{t+k+1}\\cdots$\n", + " - 举例\n", + " - $Q_\\pi(s=7,a=1)=R_{t+1}=1$\n", + " - $Q_\\pi(s=7,a=0)=\\gamma^2$\n", + " - action value,state value\n", + " - bellman equation\n", + " - 适用于状态价值函数(state value function),也适用于动作价值函数(action value function)\n", + " - mdp:markov decision process\n", + " - 马尔可夫性\n", + " - $p(s_{t+1}|s_t)=p(s_{t+1}|{s_1,s_2,s_3,\\cdots,s_t})$\n", + " - bellman equation 成立的前提条件" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-19T01:22:33.918570Z", + "start_time": "2023-02-19T01:22:33.812489Z" + } + }, + "source": [ + "- $R_t$\n", + "- $Q_{\\pi}(s,a)$:state action value function\n", + " - Q table\n", + " - 通过 Sarsa 算法迭代更新 $Q_{\\pi}(s,a)$\n", + "- 对于强化学习而言\n", + " - state, action, reward 都是需要精心设计的" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sarsa(state action reward state action)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2023-02-20T15:20:32.328550Z", + "start_time": "2023-02-20T15:20:32.322841Z" + } + }, + "outputs": [], + "source": [ + "# border & barrier\n", + "# ↑, →, ↓, ←(顺时针)\n", + "# row index: given state\n", + "# col index: posible action\n", + "# (state, action) matrix\n", + "# 跟环境对齐\n", + "theta_0 = np.asarray([[np.nan, 1, 1, np.nan], # s0\n", + " [np.nan, 1, np.nan, 1], # s1\n", + " [np.nan, np.nan, 1, 1], # s2\n", + " [1, np.nan, np.nan, np.nan], # s3 \n", + " [np.nan, 1, 1, np.nan], # s4\n", + " [1, np.nan, np.nan, 1], # s5\n", + " [np.nan, 1, np.nan, np.nan], # s6 \n", + " [1, 1, np.nan, 1]] # s7\n", + " )" + ] + }, + { + "cell_type": 
def cvt_theta_0_to_pi(theta):
    """Turn a (state, action) weight matrix into a row-wise policy.

    This is the naive "share of the row total" normalisation: every entry
    is divided by the NaN-ignoring sum of its row, and NaN entries (used
    by the notebook to mark actions blocked by a border or wall) become
    probability 0.

    Args:
        theta: 2-D array-like of shape (n_states, n_actions). NaN marks an
            unavailable action.

    Returns:
        A new float ndarray with the same shape as ``theta``. Rows whose
        NaN-ignoring total is non-zero sum to 1. Rows whose total is zero
        (for example, all NaN) are returned as all zeros, because they
        cannot be normalised.

    Raises:
        ValueError: If ``theta`` is not two-dimensional.
    """
    theta = np.asarray(theta, dtype=float)
    if theta.ndim != 2:
        raise ValueError(
            f"theta must be 2-D (states x actions), got shape {theta.shape}"
        )

    row_totals = np.nansum(theta, axis=1, keepdims=True)
    # Only divide rows that have a non-zero total; the others keep the
    # zeros from ``out`` instead of turning into inf/NaN.
    pi = np.divide(
        theta,
        row_totals,
        out=np.zeros_like(theta),
        where=row_totals != 0,
    )
    # Entries that were NaN in ``theta`` are still NaN here; map them to 0.
    return np.nan_to_num(pi)
def get_action(s, Q, eps, pi_0):
    """Choose an action for state ``s`` with an epsilon-greedy rule.

    With probability ``eps`` an action is sampled from the exploration
    policy row ``pi_0[s]``; otherwise the action with the largest non-NaN
    Q value in row ``Q[s]`` is taken.

    Args:
        s: Index of the current state (row of ``Q`` and ``pi_0``).
        Q: 2-D ndarray of action values; NaN marks unavailable actions.
        eps: Exploration probability.
        pi_0: 2-D ndarray whose row ``s`` is a probability distribution
            over the actions.

    Returns:
        The index of the selected action.

    Raises:
        ValueError: Propagated from ``np.nanargmax`` when the greedy branch
            meets a row of ``Q`` that is entirely NaN.
    """
    if np.random.rand() < eps:
        # Explore. The action set is taken from the policy row itself, so
        # tables with any number of actions work (not just four).
        actions = np.arange(pi_0.shape[1])
        return np.random.choice(actions, p=pi_0[s, :])
    # Exploit: best known action, skipping the NaN (blocked) entries.
    return np.nanargmax(Q[s, :])


def sarsa(s, a, r, s_next, a_next, Q, eta, gamma, terminal_state=8):
    """Apply one SARSA update to ``Q[s, a]`` in place.

    The update is ``Q[s, a] += eta * (target - Q[s, a])`` where the target
    is ``r + gamma * Q[s_next, a_next]``, or just ``r`` when ``s_next`` is
    the terminal state (no further reward can follow it).

    Args:
        s: State in which the action was taken.
        a: Action taken in ``s``.
        r: Reward received for the transition.
        s_next: State reached after the transition.
        a_next: Action already chosen for ``s_next``; it is not read when
            ``s_next`` is terminal, so any placeholder may be passed then.
        Q: 2-D ndarray of action values, modified in place.
        eta: Learning rate.
        gamma: Discount factor.
        terminal_state: State index treated as terminal. Defaults to 8,
            the goal cell used by this notebook.

    Returns:
        None. ``Q`` is updated in place.
    """
    if s_next == terminal_state:
        td_target = r
    else:
        td_target = r + gamma * Q[s_next, a_next]
    Q[s, a] = Q[s, a] + eta * (td_target - Q[s, a])
class Agent:
    """Tabular SARSA agent that picks actions from the current maze state.

    Actions are indexed clockwise: 0 = up, 1 = right, 2 = down, 3 = left.
    The agent keeps a Q table with one row per non-terminal state and one
    column per action; entries for moves blocked by a border or wall are
    NaN and are never selected.

    Args:
        eta: Learning rate applied to the temporal-difference error.
        gamma: Discount factor for the bootstrapped next-step value.
        eps: Initial exploration probability used by ``get_action``.
        terminal_state: State index treated as terminal by ``sarsa``.
            Defaults to 8, the goal cell.
    """

    def __init__(self, eta=0.1, gamma=0.9, eps=0.5, terminal_state=8):
        # Feasibility mask: rows are states s0..s7, columns are the four
        # actions; 1 = move allowed, NaN = blocked by a border or wall.
        self.theta_0 = np.asarray([
            [np.nan, 1, 1, np.nan],            # s0
            [np.nan, 1, np.nan, 1],            # s1
            [np.nan, np.nan, 1, 1],            # s2
            [1, np.nan, np.nan, np.nan],       # s3
            [np.nan, 1, 1, np.nan],            # s4
            [1, np.nan, np.nan, 1],            # s5
            [np.nan, 1, np.nan, np.nan],       # s6
            [1, 1, np.nan, 1],                 # s7
        ])
        # One action index per column of the mask.
        self.action_space = list(range(self.theta_0.shape[1]))
        # Exploration policy: uniform over the allowed moves of each state.
        self.pi = self._cvt_theta_to_pi()

        # Random initial Q values; multiplying by the mask keeps NaN on
        # every blocked move.
        self.Q = np.random.rand(*self.theta_0.shape) * self.theta_0
        self.eta = eta
        self.gamma = gamma
        self.eps = eps
        self.terminal_state = terminal_state

    def _cvt_theta_to_pi(self):
        """Normalise each row of ``theta_0`` into action probabilities.

        Returns:
            ndarray of the same shape as ``theta_0`` in which every entry
            is divided by the NaN-ignoring sum of its row and NaN entries
            are replaced by 0.
        """
        row_totals = np.nansum(self.theta_0, axis=1, keepdims=True)
        return np.nan_to_num(self.theta_0 / row_totals)

    def get_action(self, s):
        """Choose an action for state ``s`` with an epsilon-greedy rule.

        With probability ``self.eps`` an action is sampled from
        ``self.pi[s]`` (explore); otherwise the action with the largest
        non-NaN value in ``self.Q[s]`` is returned (exploit).
        """
        if np.random.rand() < self.eps:
            return np.random.choice(self.action_space, p=self.pi[s, :])
        return np.nanargmax(self.Q[s, :])

    def sarsa(self, s, a, r, s_next, a_next):
        """Apply one SARSA update to ``self.Q[s, a]`` in place.

        The target is ``r + gamma * Q[s_next, a_next]``, or just ``r``
        when ``s_next`` is the terminal state. ``a_next`` is not read in
        the terminal case, so a placeholder such as NaN may be passed.
        """
        if s_next == self.terminal_state:
            td_target = r
        else:
            td_target = r + self.gamma * self.Q[s_next, a_next]
        self.Q[s, a] = self.Q[s, a] + self.eta * (td_target - self.Q[s, a])
# Train the agent: one maze episode per epoch, updating Q with SARSA after
# every step, until the Q table stops changing or the epoch cap is hit.
maze = MazeEnv()
agent = Agent()
epoch = 0
while True:
    # Best value per state before the episode; compared afterwards to
    # measure how much this episode changed the table.
    old_Q = np.nanmax(agent.Q, axis=1)
    s = maze.reset()
    a = agent.get_action(s)
    # Trajectory as [state, action] pairs; the action slot of the newest
    # entry stays NaN until that action is actually taken.
    s_a_history = [[s, np.nan]]
    while True:
        s_a_history[-1][1] = a
        s_next, reward, done, _ = maze.step(a)
        s_a_history.append([s_next, np.nan])
        # No action is chosen in the terminal state; NaN is a placeholder
        # that the update does not read in that case.
        a_next = np.nan if done else agent.get_action(s_next)
        agent.sarsa(s, a, reward, s_next, a_next)
        if done:
            break
        a, s = a_next, maze.state

    # Total absolute change of the per-state best Q value in this epoch.
    update = np.abs(np.nanmax(agent.Q, axis=1) - old_Q).sum()
    epoch += 1
    # Halve the exploration rate after every episode.
    agent.eps /= 2
    print(epoch, update, len(s_a_history))
    if update < 1e-5 or epoch > 100:
        break
def init():
    """Clear the agent marker before the first animation frame.

    Returns:
        A one-element tuple holding the marker artist ``line``.
    """
    line.set_data([], [])
    return (line,)


def animate(i):
    """Move the agent marker to the cell visited at step ``i``.

    Reads the state from ``s_a_history[i][0]`` and maps it onto the 3x3
    grid: the column gives x, the row (counted from the top) gives y, and
    0.5 is added so the marker sits at the centre of the cell.

    Args:
        i: Frame number, used as an index into ``s_a_history``.

    Returns:
        A one-element tuple holding the marker artist ``line``, matching
        what ``init`` returns.
    """
    state = s_a_history[i][0]
    row, col = divmod(state, 3)
    # Pass one-element sequences rather than bare scalars: newer Matplotlib
    # releases reject scalar coordinates in Line2D.set_data.
    line.set_data([col + 0.5], [2.5 - row])
    return (line,)
\n", + " \n", + "
\n", + " \n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "HTML(anim.to_jshtml())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} -- cgit v1.2.3