# Code cells of the notebook rl/tutorials/03_MazeEnv_Agent.ipynb:
# a random-policy agent walking through a 3x3 maze.
import gym
import numpy as np
import matplotlib.pyplot as plt


class MazeEnv(gym.Env):
    """3x3 grid maze: keeps the current cell and implements ``step``.

    Cells are numbered row by row from 0 to 8; an episode starts in cell 0
    and ends when cell 8 is reached. An action shifts the cell index by a
    fixed amount (0: -3, 1: +1, 2: +3, 3: -1), i.e. one step up, right,
    down or left on a grid that is three cells wide.

    The environment knows nothing about walls or the grid border; it relies
    on the caller to choose only moves that are legal in the current cell.
    """

    GOAL_STATE = 8
    # Change of the cell index caused by each action.
    _MOVES = {0: -3, 1: 1, 2: 3, 3: -1}

    def __init__(self):
        self.action_space = gym.spaces.Discrete(len(self._MOVES))
        self.observation_space = gym.spaces.Discrete(self.GOAL_STATE + 1)
        self.state = 0

    def reset(self):
        """Put the agent back in cell 0 and return that state."""
        self.state = 0
        return self.state

    def step(self, action):
        """Apply ``action`` and return ``(state, reward, done, info)``.

        The reward is always 1 and ``info`` is always an empty dict;
        ``done`` becomes True once the goal cell is reached.

        Raises:
            ValueError: if ``action`` is not one of 0, 1, 2, 3.
        """
        try:
            delta = self._MOVES[action]
        except (KeyError, TypeError):
            # Previously an unknown action was silently treated as a no-op.
            raise ValueError(
                f"invalid action {action!r}; expected one of {sorted(self._MOVES)}"
            ) from None
        self.state += delta
        done = self.state == self.GOAL_STATE
        return self.state, 1, done, {}


class Agent:
    """Chooses an action from the current environment state.

    ``theta_0`` has one row per non-terminal state (s0..s7) and one column
    per action (0..3). A 1 marks a move the agent may take from that state,
    NaN marks a move it must never take. ``pi`` is the resulting policy:
    uniform over the allowed moves of each state, 0 for the others.
    """

    def __init__(self):
        self.actions = list(range(4))
        self.theta_0 = np.asarray([[np.nan, 1, 1, np.nan],       # s0
                                   [np.nan, 1, np.nan, 1],       # s1
                                   [np.nan, np.nan, 1, 1],       # s2
                                   [1, np.nan, np.nan, np.nan],  # s3
                                   [np.nan, 1, 1, np.nan],       # s4
                                   [1, np.nan, np.nan, 1],       # s5
                                   [np.nan, 1, np.nan, np.nan],  # s6
                                   [1, 1, np.nan, 1]])           # s7
        self.pi = self._cvt_theta_0_to_pi(self.theta_0)

    def _cvt_theta_0_to_pi(self, theta):
        """Normalize each row of ``theta`` into probabilities (NaN -> 0)."""
        # keepdims=True keeps the row sums as a column so the division
        # broadcasts row by row.
        row_totals = np.nansum(theta, axis=1, keepdims=True)
        return np.nan_to_num(theta / row_totals)

    def choose_action(self, state):
        """Sample an action for ``state`` according to the policy ``pi``."""
        return np.random.choice(self.actions, p=self.pi[state, :])


env = MazeEnv()
state = env.reset()
agent = Agent()

# Roll out one episode with the random policy, recording every action taken
# and every state visited (including the start state).
done = False
action_history = []
state_history = [state]
while not done:
    action = agent.choose_action(state)
    state, reward, done, _ = env.step(action)
    action_history.append(action)
    state_history.append(state)
"metadata": { + "ExecuteTime": { + "end_time": "2022-12-11T10:39:44.186760Z", + "start_time": "2022-12-11T10:39:43.891921Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAASUAAAEeCAYAAADM2gMZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAanklEQVR4nO3df1RUZf4H8PedX/waCL5JyI8EXVYJvrEWyAHtZAqtVIe1JNwDW4kUq6uVJ+nYuuvWVtseM7GjK99Wzipt2lpqqdBWKyZh/soFNV1/lK6ZirYgScqPgRnm+f4xwgoBAzgz95nh/TpnTgv3mXs/80Tvfe4z9z5XEUKAiEgWGrULICK6HkOJiKTCUCIiqTCUiEgqDCUikgpDiYikoutr47Bhw0RUVJSLSiGioaK6uvqSECK4p219hlJUVBSqqqqcUxURDVmKonzT2zaevhGRVBhKRCQVhhIRSYWhRERSYSgRkVQYSkQkFYYSEUmFoUREUmEoEZFUGEpEJBWGEhFJhaFERFJhKBGRVPpcJUBmQgjUXK1B9YVq7K/Zj8pvKnGs7hhaLC2wWC1ot7ZDq9FCp9HBR+eD2OBYTIyciKTwJCSEJSDcPxyKoqj9MYioG7cKJauw4pPTn2DZvmXYfXY3LFYL9Fo9GtsaYRXWH7S3WC2wWC0wWUzYfW439p7fC6PBiLb2Nug1ekwYMQHzk+cjdVQqNAoHjUQycItQutxyGWsOrkHh3kJcbbuKxrbGzm0tlpZ+78cqrLjSegUAYIIJH5/6GLvO7oK/wR8FKQXIuyMPQT5BDq+fiPpP6ethlImJiULNRd7OXzmPBeULsPnEZmgUDZrNzU47lq/eF1ZhxbSYaXj13lcRERDhtGMRDXWKolQLIRJ72iblOYsQAqsPrkbMyhhsPLoRJovJqYEEAM3mZpgsJmw4ugExK2Ow+uBq8OnBRK4nXSjVXKnBpL9OwryP5qHJ3ASLsLj0+BZhQZO5CfM+modJf52Emis1Lj0+0VAnVSiVHCpBzMoY7D63G03mJlVraTI3Yfe53YgpikHJoRJVayEaSqQIJSEEnvn4GTz54ZNoNDfCYnXt6Kg3FqsFjW2NePLDJzH/H/N5OkfkAqqHUru1HblbclF8oNjp80aD1WxuxqrqVZi5dSbare1ql0Pk0VS9JEAIgbytedh0fJO0gdSh2dyMjcc2AgBKppbwwksiJ1F1pDT/H/Px3vH3pA+kDh3BVLCtQO1SiDyWaqFUcqgExQeKVZ/QHqiOUzlOfhM5hyqhVHOlBk9/+LTbjJC6azY34+mPnublAkRO4PJQEkIg5/0cmNpNrj60Q7VaWvGL93/Bb+SIHMzlobTm0BpUX6iW5mv/wTJbzai6UMXTOCIHc2konb9yvvNKbU/QZG7CvI/n8TSOyIFcGkoLyheg1dLqykM6ncliwoLyBWqXQeQxXBZKl1suY/OJzS6/l83ZLFYL3j/xPi63XFa7FCKP4LJQWnNwjccupKZRNJxbInIQl6SEVVhRuLfQbS8BsKfZ3IzCPYU9rn5JRAPjklD65PQnuNp21fE7bgLwAYDXAbwM4DUAfwXw72vbBYAKAEsB/AFACYBax5cBAFfarmDH1zucs3OJ1NXVYc6cOYiKioKXlxdCQkKQmpqK8vJyAMD777+PKVOmIDg4GIqi4NNPP1W3YA/QV5+bzWY899xziI+Ph5+fH0JDQ5GTk4OzZ8+qXfagueTet2X7lnVZwtZh3gVgBjAVwP/AFlJnAHQMyHYD2AvgQQA3A6gE8BaApwB4ObaUxrZGFO4tRNqo
NMfuWDKZmZlobm7G6tWrER0djdraWlRWVqK+vh4A0NTUhPHjx+ORRx7BY489pnK1nqGvPm9ubsaBAwfw29/+FmPHjsX333+PgoICpKen4/Dhw9Dp3GLF6y6cvhyuEAI3Lb7J8SOlFgCvAngUwI96OjCAQgBJAO6+9jszbKOpnwLocSHOGxPgFYCG5xo89mbdhoYGBAUFoby8HGlpfYfvpUuXEBwcjIqKCtxzzz2uKdADDaTPOxw7dgxxcXE4fPgwbr/9didXODiqLodbc7UGZqvZ8Ts2XHt9CVvYdHcZQCO6BpYeQCSAc44vBwDa2ttw4eoF5+xcAkajEUajEaWlpTCZ3PuKfHcxmD6/csX2cIygIPd8CIbTQ6n6QjUMWoPjd6yF7bTsMIDFAP4C4B8Azl/b3nG26NftfX7XbXMwg9aA6ovVztm5BHQ6Hd58802sW7cOgYGBSElJwbPPPovPP/9c7dI81kD7vK2tDQUFBcjIyEBEhHs+/MLpobS/Zr9z5pMAIBZAAYAcANGwjYD+AmCncw5nT1NbE/bX7Ffn4C6SmZmJCxcuoKysDPfddx/27NmD5ORk/PGPf1S7NI/V3z63WCx45JFH0NDQgJIS971ExelzSnetuQu7z+2+oX0MyFYAXwCYA2AlgHwA4ddtfxuAL4CHnHP4u0bchc9mfuacnUvqiSeewFtvvYXGxkYYDLZRMeeUnKt7n1ssFmRnZ+PIkSP49NNPMXz4cLVL7JOqc0rH6o45+xBdBQOwAjBee/37um1mAN8AuNV5h3f555VAbGwsLBYL55lc6Po+N5vN+PnPf47Dhw+joqJC+kCyx+nfFw7kCbYD0gxgA4A7AITA9hX/BdguAxgFwBtAMoDPAAyD7ZKAnbBNjjvxC4kWs5M+rwTq6+uRlZWFvLw8xMfHw9/fH1VVVViyZAlSU1MREBCA7777DmfPnkVDQwMA4NSpUwgMDMTw4cPd/j8WNdjrc19fXzz88MP45z//ibKyMiiKgm+//RYAcNNNN8HHx0flTzBwTg8lpy1RYgAQAeBzAN8BsAAIgC1wOi4BmADb6OhD2C4hiIDtEgIHX6N0Pad80ygJo9GI5ORkLF++HKdOnUJrayvCw8ORk5ODRYsWAQBKS0sxc+bMzvfk5+cDAF544QX8/ve/V6Nst2avz8+fP4+tW7cCABISErq8t6SkBLm5uSpUfWOcPqekeVEDgaGzEJoCBdYXeLsJUV9UnVPSarTOPoRUhtrnJXI0p4eSTuN+l7nfCL1Gr3YJRG7N6aHko3O/ibYb4aMfWp+XyNGcHkqxwbHOPoRUhtrnJXI0p4fSxMiJHru4W3daRYuJkRPVLoPIrTk9LZLCk2A0GJ19GCn4GfyQFJ6kdhlEbs3poZQQloC29jZnH0YKbe1tSAhNsN+QiHrl9FAK9w8fMt9IGbQGhPmHqV0GkVtzeigpioIJIyY4+zBSGH/reI9d4I3IVVwyAz0/eb7HzysZDUYUpBSoXQaR23PJlY2po1Lhb/Af3LpKOwEcAaBce/nAdh9bG2w35QZea/cAgBEA3oDtBtys6/axGbbVATrueZsC22JwR6/9XAvglmv/+w7YbuQdoACvAEweOXngbySiLlwSShpFg4KUAjz/6fMDe8zSOQBfAZgFW6VNANphu/H2awB7APziuvZ1sK3NfRa20Lp+wct7AcRde18ZgKfx3xt3XwHwq4F+qv/y1fuiIKVgyFz6QORMLvuvKO+OvIE/F+0qbAuydUSnH2yB1JsjAOJhW5f7RC9tIgBcGVgZ9liFFTPHzrTfkIjsclkoBfkE4aGYh6BTBjA4+xGA7wGsgO35bmfstD8K4H+vvf7VS5tTAGL6X4I9Oo0O02KmIcjHPRdpJ5KNS883lty7BF66ASxm5AXbqVsGbKOkjQAO9tK2BrZRVSBsi7xdxH+f/wYA5bCF23sA7hpg4X3w1nljyb1LHLdDoiHOpaEUERCB5fcth5+++yNG+qABMBLA
JAD3AzjeS7t/AbgE29NylwNo7db2Xtjmke6FbR1vB/DT+2F5+nKEB4Tbb0xE/eLymdm8sXlIDEvs35ImlwDUX/fztwBu6qGdFbZTt18BeObaKxu2OabukmCbDD81oLJ/QK/RY1z4OM4lETmYy0NJURS8Pe1teGu97Tdug+3r/JUA/g+2b9fu6aHdWQD+6DoJHnmtffcH8yqwfet2gw9Y8dJ5Yd1D63ixJJGDOX053N6UHCrBkx8+ObBLBCThq/fFyvtXcpRENEiqLofbm5ljZ+KXd/4SvnpftUoYFD+9H2YlzGIgETmJqlf7LZuyDA/f9rDbBJOv3hcPxz6Mwp8Wql0KkcdSNZQURcGaqWuQFZslfTD56n2RFZuF1T9bzXkkIidS/b4IrUaLkqklmJUwS9pg8tX7YnbCbJRMLeHTSoicTPVQAmwjpmVTlmHl/SthNBileQKKXqOH0WDEyvtXonBKIUdIRC4gRSh1mDl2Jk7MPYEJt04Y2AWWTuCn98P4W8fjxNwTnNQmciGpQgkAwgPCUTGjAivuW2EbNQ3kXjkH0Gl0MBqMWHHfClTMqODV2kQuJl0oAbbTubw78nB87nFMj5sOb503fHXOnW/y1fnCW+eN6bHTcWLuCeTdkcfTNSIVyDF504uIgAi8nfk2LrdcRsmhEizdsxRX264ObrG4XhgNRgQYAlAwvgAzx87k3f5EKlPtiu7BsAordny9A4V7C7Hn3B60tbfBoDWgsa2xX2s1aRQNjAZj5/vG3zoeBSkFmDxyMhdoI3Khvq7olnqk1J1G0SBtVBrSRqVBCIELVy+g+mI19tfsR+U3lThWdwwt5haYrWa0W9uh1Wih1+jho/dBbHAsJkZORFJ4EhJCExDmH8bTMyIJuVUoXU9RFIQHhCM8IBw/G/MztcshIgfhOQsRSYWhRERSYSgRkVQYSkQkFYYSEUmFoUREUmEoEZFUGEpEJBWGEhFJhaFERFJhKBGRVBhKRCQVhhIRScVtVwnwSFxKRT19rCtGrsWREhFJhSMlGfH/tWkI40iJiKTCUCIiqTCUiEgqDCUikgpDiYikwlAiIqkwlIhIKgwlIpIKQ4mIpMJQIiKpMJSISCoMJSKSCkOJiKTCUCIiqTCUiEgqDCUikgpDiYikwlAiIqkwlIhIKgwlIpIKQ4mIpMJQIiKpMJSISCoMJSKSCkOJiKTCUCIiqTCUiEgqDCUikgpDiYikwlAiIqkwlIhIKgwlIpIKQ4mIpOLWoVRXV4c5c+YgKioKXl5eCAkJQWpqKsrLywEAv/vd7xATEwM/Pz8EBQUhNTUVe/bsUblq92avz683a9YsKIqCpUuXqlCp57DX57m5uVAUpcsrOTlZ5aoHT6d2ATciMzMTzc3NWL16NaKjo1FbW4vKykrU19cDAMaMGYOioiKMHDkSLS0teP3115Geno6TJ08iJCRE5erdk70+77Bp0ybs378fYWFhKlXqOfrT52lpaVi7dm3nzwaDQY1SHUMI0esrISFByOry5csCgCgvL+/3e77//nsBQHz88cdOrOwGALaXpPrb52fOnBFhYWHi2LFjIjIyUrz22msuqtDz9KfPZ8yYIR544AEXVnXjAFSJXnLHbU/fjEYjjEYjSktLYTKZ7LZva2tDcXExAgICMHbsWBdU6Hn60+cWiwXZ2dlYtGgRbrvtNhdX6Hn6+3e+a9cu3HLLLRg9ejTy8/NRW1vrwiodrLe0EpKPlIQQYtOmTSIoKEh4eXmJ5ORkUVBQIPbt29elTVlZmfDz8xOKooiwsDDx+eefq1RtP0g+UhLCfp//5je/ERkZGZ0/c6R04+z1+fr168XWrVvF4cOHRWlpqYiPjxdxcXHCZDKpWHXf0MdIya1DSQghWlpaxLZt28SLL74oUlJSBADxyiuvdG5vbGwUJ0+eFHv37hV5eXkiMjJSXLhwQcWK++AGoSRE731eUVEhwsLCRG1tbWdbhpJj2Ps7v15NTY3Q6XTivffec3GV
/efRodTd448/LvR6vWhtbe1xe3R0tHjppZdcXFU/uUkoddfR5wsXLhSKogitVtv5AiA0Go0IDw9Xu0yPYu/vPCoqSixevNjFVfVfX6Hk1t++9SQ2NhYWiwUmk6nHbyCsVitaW1tVqMxzdfT57NmzkZOT02XblClTkJ2djfz8fJWq80x9/Z1funQJNTU1CA0NVam6G+O2oVRfX4+srCzk5eUhPj4e/v7+qKqqwpIlS5CamgoAWLRoETIyMhAaGoq6ujoUFRXh/PnzmD59usrVuyd7fT5ixIgfvEev12P48OEYM2aMChW7P3t9rtFo8OyzzyIzMxOhoaE4c+YMFi5ciFtuuQUPPfSQ2uUPituGktFoRHJyMpYvX45Tp06htbUV4eHhyMnJwaJFi6DT6XD06FGsWbMG9fX1uPnmmzFu3Djs3LkT8fHxapfvluz1OTmevT7XarU4cuQI3nrrLTQ0NCA0NBSTJk3Chg0b4O/vr3b5g6LYTu96lpiYKKqqqlxYzhCnKLZ/9vHvhMgTKIpSLYRI7Gmb216nRESeiaFERFJhKBGRVBhKRCQVhhIRSYWhRERSYSgRkVQYSkQkFYYSEUmFoUREUmEoEZFUGEpEJBWGEhFJhaFERFJhKBGRVBhKRCQVhhIRSYWhRERSYSgRkVQYSkQkFYYSEUmFoUREUmEoEZFUGEpEJBWGEhFJhaFERFJhKBGRVBhKRCQVhhIRSYWhRERSYSgRkVQYSkQkFYYSEUmFoUREUmEoEZFUGEpEJBWGEhFJhaFERFJhKBGRVBhKRCQVhhIRSYWhRERSYSgRkVR0fW6trgYUxUWlEKmAf9/S4UiJiKTS90gpIQGoqnJRKUQqEELtCoamPkaoHCkRkVQYSkQkFYYSEUmFoUREUmEoEZFUGEpEJBWGEhFJhaFERFJhKBGRVBhKRCQVhhIRSYWhRERSYSgRkVQYSkQkFYYSEUmFoUREUmEoEZFUGEpEJBWGEhFJhaFERFJhKBGRVBhKRCQVhhIRSYWhRERSYSgRkVQYSkQkFYYSEUmFoUREUmEoEZFUGEpEJBWGEhFJhaFERFJhKBGRVBhKRCQVtw6luro6zJkzB1FRUfDy8kJISAhSU1NRXl7e2earr77CtGnTEBgYCF9fX9x55504fvy4ilW7N3t9rihKj6+5c+eqXLn7stfnjY2NeOqppxAREQEfHx+MGTMGr7/+uspVD55O7QJuRGZmJpqbm7F69WpER0ejtrYWlZWVqK+vBwB8/fXXmDBhAh577DHs2LEDgYGBOHHiBIxGo8qVuy97fX7x4sUu7auqqpCRkYHp06erUa5HsNfn8+fPx/bt27F27VqMHDkSO3fuRH5+PoYNG4ZHH31U5eoHQQjR6yshIUHI6vLlywKAKC8v77VNdna2yMnJcWFVnq0/fd7dE088IUaPHu3Eqjxbf/o8Li5OPP/8811+d/fdd4u5c+c6u7xBA1Aleskdtz19MxqNMBqNKC0thclk+sF2q9WKsrIyxMbGIj09HcHBwRg3bhzeffddFar1DPb6vLvGxka88847yM/Pd0F1nqk/fX7XXXehrKwM586dAwDs2bMHhw4dQnp6uitLdZze0kpIPlISQohNmzaJoKAg4eXlJZKTk0VBQYHYt2+fEEKIixcvCgDC19dXFBYWioMHD4rCwkKh1WrFBx98oHLl7quvPu9u1apVwmAwiNraWhdX6Vns9Xlra6vIzc0VAIROpxM6nU688cYbKlZsH/oYKbl1KAkhREtLi9i2bZt48cUXRUpKigAgXnnlFVFTUyMAiOzs7C7ts7OzRXp6ukrVeobe+ry7xMREkZWVpUKFnqevPl+6dKkYPXq0KC0tFV988YX405/+JPz8/MRHH32kctW98+hQ6u7xxx8Xer1etLa2Cp1OJ15++eUu21966SURGxurUnWe6fo+73Dw4EEBQGzbtk3FyjxXR583NDQIvV4vtmzZ8oPtqampKlVnX1+h5LZzSr2J
jY2FxWKByWTCuHHj8OWXX3bZ/tVXXyEyMlKl6jzT9X3eobi4GCNHjkRaWpqKlXmujj5XFAVmsxlarbbLdq1WC6vVqlJ1N6i3tBKSj5QuXbokJk2aJNauXSu++OILcfr0abFhwwYREhIi0tLShBBCbN68Wej1erFq1Spx8uRJUVxcLHQ6HeeUBqk/fS6EEE1NTSIgIED84Q9/ULFaz9CfPp84caKIi4sTFRUV4vTp06KkpER4e3uLFStWqFx97+CJp28mk0ksXLhQJCYmisDAQOHj4yOio6PFM888I+rr6zvblZSUiB//+MfC29tb3H777eJvf/ubilW7t/72+Zo1a4RWqxU1NTUqVusZ+tPnFy9eFLm5uSIsLEx4e3uLMWPGiNdee01YrVaVq+9dX6Gk2Lb3LDExUVRVVbls1EZEQ4OiKNVCiMSetnncnBIRuTeGEhFJhaFERFJhKBGRVBhKRCQVhhIRSYWhRERSYSgRkVQYSkQkFYYSkST+85//ICcnB6NGjUJCQgJSUlKwefNmAMCuXbuQlJSEmJgYxMTEoLi4uMt7LRYLgoOD8etf/7rL7++55x64210ZDCUiCQgh8OCDD+Luu+/G6dOnUV1djXfeeQfnz5/Ht99+i5ycHPz5z3/GiRMnsGvXLqxatQp///vfO99fXl6O0aNHY+PGjejr1jF3wFAiksCOHTtgMBgwe/bszt9FRkbiqaeeQlFREXJzc3HnnXcCAIYNG4YlS5Zg8eLFnW3Xr1+PefPmYcSIEdi7d6/L63ckhhKRBI4ePdoZOj1tS0hI6PK7xMREHD16FABgMpmwfft2ZGRkIDs7G+vXr3d6vc7EUCKS0Ny5c/GTn/wE48aNs9v2gw8+wKRJk+Dj44PMzExs2bIF7e3tLqjSORhKRBKIi4vDgQMHOn8uKirCJ598grq6OsTGxqK6urpL++rqasTFxQGwnbpt374dUVFRSEhIQH19PXbs2OHS+h2JoUQkgcmTJ8NkMuGNN97o/F1zczMA26jpzTffxKFDhwAA9fX1eO6557BgwQJcuXIFn332Gc6ePYszZ87gzJkzKCoqcutTOIYSkQQURcGWLVtQWVmJkSNHIikpCTNmzMCrr76K0NBQrFu3Dvn5+YiJicH48eORl5eHjIwMbN68GZMnT4aXl1fnvqZOnYqysjK0trYCAB544AFEREQgIiICWVlZan3EfuPKk0Tkclx5kojcBkOJiKTCUCIiqTCUiEgqDCUikgpDiYikwlAiIqkwlIhIKgwlIpIKQ4mIpMJQIiKpMJSISCoMJSKSCkOJiKTCUCIiqTCUiEgqDCUikgpDiYikwlAiIqkwlIhIKgwlIpIKQ4mIpMJQIiKpMJSISCoMJSKSCkOJiKTS52O7FUWpA/CN68ohoiEiUggR3NOGPkOJiMjVePpGRFJhKBGRVBhKRCQVhhIRSYWhRERS+X9kfBV3sq0wKgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig = plt.figure(figsize=(5, 5))\n", + "ax = plt.gca()\n", + "ax.set_xlim(0, 3)\n", + "ax.set_ylim(0, 3)\n", + "\n", + "# plt.plot([1, 1], [0, 1], color='red', linewidth=2)\n", + "# plt.plot([1, 2], [2, 2], color='red', linewidth=2)\n", + "# plt.plot([2, 2], [2, 1], color='red', linewidth=2)\n", + "# plt.plot([2, 3], [1, 1], color='red', linewidth=2)\n", + "\n", + "plt.plot([2, 3], [1, 1], color='red', linewidth=2)\n", + "plt.plot([0, 1], [1, 1], color='red', linewidth=2)\n", + "plt.plot([1, 1], [1, 2], color='red', linewidth=2)\n", + "plt.plot([1, 2], [2, 2], color='red', linewidth=2)\n", + "\n", + "plt.text(0.5, 2.5, 'S0', size=14, ha='center')\n", + "plt.text(1.5, 2.5, 'S1', size=14, ha='center')\n", + "plt.text(2.5, 2.5, 'S2', size=14, ha='center')\n", + "plt.text(0.5, 1.5, 'S3', size=14, ha='center')\n", + "plt.text(1.5, 1.5, 'S4', size=14, ha='center')\n", + "plt.text(2.5, 1.5, 'S5', size=14, ha='center')\n", + "plt.text(0.5, 0.5, 'S6', size=14, ha='center')\n", + "plt.text(1.5, 0.5, 'S7', size=14, ha='center')\n", + "plt.text(2.5, 0.5, 'S8', size=14, ha='center')\n", + "plt.text(0.5, 2.3, 'START', ha='center')\n", + "plt.text(2.5, 0.3, 'GOAL', ha='center')\n", + "# plt.axis('off')\n", + "plt.tick_params(axis='both', which='both', \n", + " bottom=False, top=False, \n", + " right=False, left=False,\n", + " labelbottom=False, labelleft=False\n", + " )\n", + "line, = ax.plot([0.5], [2.5], marker='o', color='g', markersize=60)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "ExecuteTime": { + "end_time": "2022-12-11T10:39:57.536092Z", + "start_time": "2022-12-11T10:39:57.527088Z" + } + }, + "outputs": [], + "source": [ + "from matplotlib import animation\n", + "from IPython.display import HTML" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "ExecuteTime": { + "end_time": "2022-12-11T10:40:06.654608Z", + 
"start_time": "2022-12-11T10:40:06.650476Z" + } + }, + "outputs": [], + "source": [ + "def init():\n", + " line.set_data([], [])\n", + " return (line, )\n", + "def animate(i):\n", + " state = state_history[i]\n", + " x = (state % 3)+0.5\n", + " y = 2.5 - int(state/3)\n", + " line.set_data(x, y)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "ExecuteTime": { + "end_time": "2022-12-11T10:40:13.195375Z", + "start_time": "2022-12-11T10:40:13.192239Z" + } + }, + "outputs": [], + "source": [ + "anim = animation.FuncAnimation(fig, animate, init_func=init, frames=len(state_history), interval=200, repeat=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "ExecuteTime": { + "end_time": "2022-12-11T10:40:30.568860Z", + "start_time": "2022-12-11T10:40:23.429183Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + " \n", + "
\n", + " \n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "HTML(anim.to_jshtml())" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + "ExecuteTime": { + "end_time": "2022-12-11T10:41:18.663707Z", + "start_time": "2022-12-11T10:41:13.036491Z" + } + }, + "outputs": [], + "source": [ + "anim.save('save/maze_2.mp4')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/rl/tutorials/save/maze.mp4 b/rl/tutorials/save/maze.mp4 new file mode 100644 index 0000000..81cb67e Binary files /dev/null and b/rl/tutorials/save/maze.mp4 differ diff --git a/rl/tutorials/save/maze_0.mp4 b/rl/tutorials/save/maze_0.mp4 new file mode 100644 index 0000000..0e791b2 Binary files /dev/null and b/rl/tutorials/save/maze_0.mp4 differ diff --git a/rl/tutorials/save/maze_2.mp4 b/rl/tutorials/save/maze_2.mp4 new file mode 100644 index 0000000..bd35279 Binary files /dev/null and b/rl/tutorials/save/maze_2.mp4 differ -- cgit v1.2.3