{ "cells": [ { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2023-02-20T15:09:34.179156Z", "start_time": "2023-02-20T15:09:33.559589Z" } }, "outputs": [], "source": [ "import numpy as np\n", "import gym\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2023-02-20T15:09:36.809994Z", "start_time": "2023-02-20T15:09:36.373146Z" }, "scrolled": true }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAASUAAAEeCAYAAADM2gMZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAanklEQVR4nO3df1RUZf4H8PedX/waCL5JyI8EXVYJvrEWyAHtZAqtVIe1JNwDW4kUq6uVJ+nYuuvWVtseM7GjK99Wzipt2lpqqdBWKyZh/soFNV1/lK6ZirYgScqPgRnm+f4xwgoBAzgz95nh/TpnTgv3mXs/80Tvfe4z9z5XEUKAiEgWGrULICK6HkOJiKTCUCIiqTCUiEgqDCUikgpDiYikoutr47Bhw0RUVJSLSiGioaK6uvqSECK4p219hlJUVBSqqqqcUxURDVmKonzT2zaevhGRVBhKRCQVhhIRSYWhRERSYSgRkVQYSkQkFYYSEUmFoUREUmEoEZFUGEpEJBWGEhFJhaFERFJhKBGRVPpcJUBmQgjUXK1B9YVq7K/Zj8pvKnGs7hhaLC2wWC1ot7ZDq9FCp9HBR+eD2OBYTIyciKTwJCSEJSDcPxyKoqj9MYioG7cKJauw4pPTn2DZvmXYfXY3LFYL9Fo9GtsaYRXWH7S3WC2wWC0wWUzYfW439p7fC6PBiLb2Nug1ekwYMQHzk+cjdVQqNAoHjUQycItQutxyGWsOrkHh3kJcbbuKxrbGzm0tlpZ+78cqrLjSegUAYIIJH5/6GLvO7oK/wR8FKQXIuyMPQT5BDq+fiPpP6ethlImJiULNRd7OXzmPBeULsPnEZmgUDZrNzU47lq/eF1ZhxbSYaXj13lcRERDhtGMRDXWKolQLIRJ72iblOYsQAqsPrkbMyhhsPLoRJovJqYEEAM3mZpgsJmw4ugExK2Ow+uBq8OnBRK4nXSjVXKnBpL9OwryP5qHJ3ASLsLj0+BZhQZO5CfM+modJf52Emis1Lj0+0VAnVSiVHCpBzMoY7D63G03mJlVraTI3Yfe53YgpikHJoRJVayEaSqQIJSEEnvn4GTz54ZNoNDfCYnXt6Kg3FqsFjW2NePLDJzH/H/N5OkfkAqqHUru1HblbclF8oNjp80aD1WxuxqrqVZi5dSbare1ql0Pk0VS9JEAIgbytedh0fJO0gdSh2dyMjcc2AgBKppbwwksiJ1F1pDT/H/Px3vH3pA+kDh3BVLCtQO1SiDyWaqFUcqgExQeKVZ/QHqiOUzlOfhM5hyqhVHOlBk9/+LTbjJC6azY34+mPnublAkRO4PJQEkIg5/0cmNpNrj60Q7VaWvGL93/Bb+SIHMzlobTm0BpUX6iW5mv/wTJbzai6UMXTOCIHc2konb9yvvNKbU/QZG7CvI/n8TSOyIFcGkoLyheg1dLqykM6ncliwoLyBWqXQeQxXBZKl1suY/OJzS6/l83ZLFYL3j/xPi63XFa7FCKP4LJQWnNwjccupKZRNJxbInIQl6SEVVhRuLfQbS8BsKfZ3IzCPYU9rn5JRAPjklD65PQnuNp21fE7bgLwAYDXAbwM4DUAfwXw72vbBYAKAEsB/AFACYBax5cBAFfarmDH1zucs3OJ1NXVYc6cOYiKioKXlxdCQkKQmpqK8vJyAMD777+PKVOmIDg4GIqi4NNPP1W3YA/QV5+bzWY899xziI+Ph5+fH0JDQ5GTk4OzZ8+qXfagueTet2X7lnVZwtZh3gVgBjAVwP/AFlJnAHQMyHYD2AvgQQA3A6gE8BaApwB4ObaUxrZGFO4tRNqoNMfuWDKZmZlobm7G6tWrER0djdraWlRWVqK+vh4A0NTUhPHjx+ORRx7BY489pnK1nqGvPm9ubsaBAwfw29/+FmPHjsX333+PgoICpKen4/Dhw9Dp3GLF6y6cvhyuEAI3Lb7J8SOlFgCvAngUwI96OjCAQgBJAO6+9jszbKOpnwLocSHOGxPgFYCG5xo89mbdhoYGBAUFoby8HGlpfYfvpUuXEBwcjIqKCtxzzz2uKdADDaTPOxw7dgxxcXE4fPgwbr/9didXODiqLodbc7UGZqvZ8Ts2XHt9CVvYdHcZQCO6BpYeQCSAc44vBwDa2ttw4eoF5+xcAkajEUajEaWlpTCZ3PuKfHcxmD6/csX2cIygIPd8CIbTQ6n6QjUMWoPjd6yF7bTsMIDFAP4C4B8Azl/b3nG26NftfX7XbXMwg9aA6ovVztm5BHQ6Hd58802sW7cOgYGBSElJwbPPPovPP/9c7dI81kD7vK2tDQUFBcjIyEBEhHs+/MLpobS/Zr9z5pMAIBZAAYAcANGwjYD+AmCncw5nT1NbE/bX7Ffn4C6SmZmJCxcuoKysDPfddx/27NmD5ORk/PGPf1S7NI/V3z63WCx45JFH0NDQgJIS971ExelzSnetuQu7z+2+oX0MyFYAXwCYA2AlgHwA4ddtfxuAL4CHnHP4u0bchc9mfuacnUvqiSeewFtvvYXGxkYYDLZRMeeUnKt7n1ssFmRnZ+PIkSP49NNPMXz4cLVL7JOqc0rH6o45+xBdBQOwAjBee/37um1mAN8AuNV5h3f555VAbGwsLBYL55lc6Po+N5vN+PnPf47Dhw+joqJC+kCyx+nfFw7kCbYD0gxgA4A7AITA9hX/BdguAxgFwBtAMoDPAAyD7ZKAnbBNjjvxC4kWs5M+rwTq6+uRlZWFvLw8xMfHw9/fH1VVVViyZAlSU1MREBCA7777DmfPnkVDQwMA4NSpUwgMDMTw4cPd/j8WNdjrc19fXzz88MP45z//ibKyMiiKgm+//RYAcNNNN8HHx0flTzBwTg8lpy1RYgAQAeBzAN8BsAAIgC1wOi4BmADb6OhD2C4hiIDtEgIHX6N0Pad80ygJo9GI5ORkLF++HKdOnUJrayvCw8ORk5ODRYsWAQBKS0sxc+bMzvfk5+cDAF544QX8/ve/V6Nst2avz8+fP4+tW7cCABISErq8t6SkBLm5uSpUfWOcPqekeVEDgaGzEJoCBdYXeLsJUV9UnVPSarTOPoRUhtrnJXI0p4eSTuN+l7nfCL1Gr3YJRG7N6aHko3O/ibYb4aMfWp+XyNGcHkqxwbHOPoRUhtrnJXI0p4fSxMiJHru4W3daRYuJkRPVLoPIrTk9LZLCk2A0GJ19GCn4GfyQFJ6kdhlEbs3poZQQloC29jZnH0YKbe1tSAhNsN+QiHrl9FAK9w8fMt9IGbQGhPmHqV0GkVtzeigpioIJIyY4+zBSGH/reI9d4I3IVVwyAz0/eb7HzysZDUYUpBSoXQaR23PJlY2po1Lhb/Af3LpKOwEcAaBce/nAdh9bG2w35QZea/cAgBEA3oDtBtys6/axGbbVATrueZsC22JwR6/9XAvglmv/+w7YbuQdoACvAEweOXngbySiLlwSShpFg4KUAjz/6fMDe8zSOQBfAZgFW6VNANphu/H2awB7APziuvZ1sK3NfRa20Lp+wct7AcRde18ZgKfx3xt3XwHwq4F+qv/y1fuiIKVgyFz6QORMLvuvKO+OvIE/F+0qbAuydUSnH2yB1JsjAOJhW5f7RC9tIgBcGVgZ9liFFTPHzrTfkIjsclkoBfkE4aGYh6BTBjA4+xGA7wGsgO35bmfstD8K4H+vvf7VS5tTAGL6X4I9Oo0O02KmIcjHPRdpJ5KNS883lty7BF66ASxm5AXbqVsGbKOkjQAO9tK2BrZRVSBsi7xdxH+f/wYA5bCF23sA7hpg4X3w1nljyb1LHLdDoiHOpaEUERCB5fcth5+++yNG+qABMBLAJAD3AzjeS7t/AbgE29NylwNo7db2Xtjmke6FbR1vB/DT+2F5+nKEB4Tbb0xE/eLymdm8sXlIDEvs35ImlwDUX/fztwBu6qGdFbZTt18BeObaKxu2OabukmCbDD81oLJ/QK/RY1z4OM4lETmYy0NJURS8Pe1teGu97Tdug+3r/JUA/g+2b9fu6aHdWQD+6DoJHnmtffcH8yqwfet2gw9Y8dJ5Yd1D63ixJJGDOX053N6UHCrBkx8+ObBLBCThq/fFyvtXcpRENEiqLofbm5ljZ+KXd/4SvnpftUoYFD+9H2YlzGIgETmJqlf7LZuyDA/f9rDbBJOv3hcPxz6Mwp8Wql0KkcdSNZQURcGaqWuQFZslfTD56n2RFZuF1T9bzXkkIidS/b4IrUaLkqklmJUwS9pg8tX7YnbCbJRMLeHTSoicTPVQAmwjpmVTlmHl/SthNBileQKKXqOH0WDEyvtXonBKIUdIRC4gRSh1mDl2Jk7MPYEJt04Y2AWWTuCn98P4W8fjxNwTnNQmciGpQgkAwgPCUTGjAivuW2EbNQ3kXjkH0Gl0MBqMWHHfClTMqODV2kQuJl0oAbbTubw78nB87nFMj5sOb503fHXOnW/y1fnCW+eN6bHTcWLuCeTdkcfTNSIVyDF504uIgAi8nfk2LrdcRsmhEizdsxRX264ObrG4XhgNRgQYAlAwvgAzx87k3f5EKlPtiu7BsAordny9A4V7C7Hn3B60tbfBoDWgsa2xX2s1aRQNjAZj5/vG3zoeBSkFmDxyMhdoI3Khvq7olnqk1J1G0SBtVBrSRqVBCIELVy+g+mI19tfsR+U3lThWdwwt5haYrWa0W9uh1Wih1+jho/dBbHAsJkZORFJ4EhJCExDmH8bTMyIJuVUoXU9RFIQHhCM8IBw/G/MztcshIgfhOQsRSYWhRERSYSgRkVQYSkQkFYYSEUmFoUREUmEoEZFUGEpEJBWGEhFJhaFERFJhKBGRVBhKRCQVhhIRScVtVwnwSFxKRT19rCtGrsWREhFJhSMlGfH/tWkI40iJiKTCUCIiqTCUiEgqDCUikgpDiYikwlAiIqkwlIhIKgwlIpIKQ4mIpMJQIiKpMJSISCoMJSKSCkOJiKTCUCIiqTCUiEgqDCUikgpDiYikwlAiIqkwlIhIKgwlIpIKQ4mIpMJQIiKpMJSISCoMJSKSCkOJiKTCUCIiqTCUiEgqDCUikgpDiYikwlAiIqkwlIhIKgwlIpIKQ4mIpOLWoVRXV4c5c+YgKioKXl5eCAkJQWpqKsrLywEAv/vd7xATEwM/Pz8EBQUhNTUVe/bsUblq92avz683a9YsKIqCpUuXqlCp57DX57m5uVAUpcsrOTlZ5aoHT6d2ATciMzMTzc3NWL16NaKjo1FbW4vKykrU19cDAMaMGYOioiKMHDkSLS0teP3115Geno6TJ08iJCRE5erdk70+77Bp0ybs378fYWFhKlXqOfrT52lpaVi7dm3nzwaDQY1SHUMI0esrISFByOry5csCgCgvL+/3e77//nsBQHz88cdOrOwGALaXpPrb52fOnBFhYWHi2LFjIjIyUrz22msuqtDz9KfPZ8yYIR544AEXVnXjAFSJXnLHbU/fjEYjjEYjSktLYTKZ7LZva2tDcXExAgICMHbsWBdU6Hn60+cWiwXZ2dlYtGgRbrvtNhdX6Hn6+3e+a9cu3HLLLRg9ejTy8/NRW1vrwiodrLe0EpKPlIQQYtOmTSIoKEh4eXmJ5ORkUVBQIPbt29elTVlZmfDz8xOKooiwsDDx+eefq1RtP0g+UhLCfp//5je/ERkZGZ0/c6R04+z1+fr168XWrVvF4cOHRWlpqYiPjxdxcXHCZDKpWHXf0MdIya1DSQghWlpaxLZt28SLL74oUlJSBADxyiuvdG5vbGwUJ0+eFHv37hV5eXkiMjJSXLhwQcWK++AGoSRE731eUVEhwsLCRG1tbWdbhpJj2Ps7v15NTY3Q6XTivffec3GV/efRodTd448/LvR6vWhtbe1xe3R0tHjppZdcXFU/uUkoddfR5wsXLhSKogitVtv5AiA0Go0IDw9Xu0yPYu/vPCoqSixevNjFVfVfX6Hk1t++9SQ2NhYWiwUmk6nHbyCsVitaW1tVqMxzdfT57NmzkZOT02XblClTkJ2djfz8fJWq80x9/Z1funQJNTU1CA0NVam6G+O2oVRfX4+srCzk5eUhPj4e/v7+qKqqwpIlS5CamgoAWLRoETIyMhAaGoq6ujoUFRXh/PnzmD59usrVuyd7fT5ixIgfvEev12P48OEYM2aMChW7P3t9rtFo8OyzzyIzMxOhoaE4c+YMFi5ciFtuuQUPPfSQ2uUPituGktFoRHJyMpYvX45Tp06htbUV4eHhyMnJwaJFi6DT6XD06FGsWbMG9fX1uPnmmzFu3Djs3LkT8fHxapfvluz1OTmevT7XarU4cuQI3nrrLTQ0NCA0NBSTJk3Chg0b4O/vr3b5g6LYTu96lpiYKKqqqlxYzhCnKLZ/9vHvhMgTKIpSLYRI7Gmb216nRESeiaFERFJhKBGRVBhKRCQVhhIRSYWhRERSYSgRkVQYSkQkFYYSEUmFoUREUmEoEZFUGEpEJBWGEhFJhaFERFJhKBGRVBhKRCQVhhIRSYWhRERSYSgRkVQYSkQkFYYSEUmFoUREUmEoEZFUGEpEJBWGEhFJhaFERFJhKBGRVBhKRCQVhhIRSYWhRERSYSgRkVQYSkQkFYYSEUmFoUREUmEoEZFUGEpEJBWGEhFJhaFERFJhKBGRVBhKRCQVhhIRSYWhRERSYSgRkVR0fW6trgYUxUWlEKmAf9/S4UiJiKTS90gpIQGoqnJRKUQqEELtCoamPkaoHCkRkVQYSkQkFYYSEUmFoUREUmEoEZFUGEpEJBWGEhFJhaFERFJhKBGRVBhKRCQVhhIRSYWhRERSYSgRkVQYSkQkFYYSEUmFoUREUmEoEZFUGEpEJBWGEhFJhaFERFJhKBGRVBhKRCQVhhIRSYWhRERSYSgRkVQYSkQkFYYSEUmFoUREUmEoEZFUGEpEJBWGEhFJhaFERFJhKBGRVBhKRCQVtw6luro6zJkzB1FRUfDy8kJISAhSU1NRXl7e2earr77CtGnTEBgYCF9fX9x55504fvy4ilW7N3t9rihKj6+5c+eqXLn7stfnjY2NeOqppxAREQEfHx+MGTMGr7/+uspVD55O7QJuRGZmJpqbm7F69WpER0ejtrYWlZWVqK+vBwB8/fXXmDBhAh577DHs2LEDgYGBOHHiBIxGo8qVuy97fX7x4sUu7auqqpCRkYHp06erUa5HsNfn8+fPx/bt27F27VqMHDkSO3fuRH5+PoYNG4ZHH31U5eoHQQjR6yshIUHI6vLlywKAKC8v77VNdna2yMnJcWFVnq0/fd7dE088IUaPHu3Eqjxbf/o8Li5OPP/8811+d/fdd4u5c+c6u7xBA1Aleskdtz19MxqNMBqNKC0thclk+sF2q9WKsrIyxMbGIj09HcHBwRg3bhzeffddFar1DPb6vLvGxka88847yM/Pd0F1nqk/fX7XXXehrKwM586dAwDs2bMHhw4dQnp6uitLdZze0kpIPlISQohNmzaJoKAg4eXlJZKTk0VBQYHYt2+fEEKIixcvCgDC19dXFBYWioMHD4rCwkKh1WrFBx98oHLl7quvPu9u1apVwmAwiNraWhdX6Vns9Xlra6vIzc0VAIROpxM6nU688cYbKlZsH/oYKbl1KAkhREtLi9i2bZt48cUXRUpKigAgXnnlFVFTUyMAiOzs7C7ts7OzRXp6ukrVeobe+ry7xMREkZWVpUKFnqevPl+6dKkYPXq0KC0tFV988YX405/+JPz8/MRHH32kctW98+hQ6u7xxx8Xer1etLa2Cp1OJ15++eUu21966SURGxurUnWe6fo+73Dw4EEBQGzbtk3FyjxXR583NDQIvV4vtmzZ8oPtqampKlVnX1+h5LZzSr2JjY2FxWKByWTCuHHj8OWXX3bZ/tVXXyEyMlKl6jzT9X3eobi4GCNHjkRaWpqKlXmujj5XFAVmsxlarbbLdq1WC6vVqlJ1N6i3tBKSj5QuXbokJk2aJNauXSu++OILcfr0abFhwwYREhIi0tLShBBCbN68Wej1erFq1Spx8uRJUVxcLHQ6HeeUBqk/fS6EEE1NTSIgIED84Q9/ULFaz9CfPp84caKIi4sTFRUV4vTp06KkpER4e3uLFStWqFx97+CJp28mk0ksXLhQJCYmisDAQOHj4yOio6PFM888I+rr6zvblZSUiB//+MfC29tb3H777eJvf/ubilW7t/72+Zo1a4RWqxU1NTUqVusZ+tPnFy9eFLm5uSIsLEx4e3uLMWPGiNdee01YrVaVq+9dX6Gk2Lb3LDExUVRVVbls1EZEQ4OiKNVCiMSetnncnBIRuTeGEhFJhaFERFJhKBGRVBhKRCQVhhIRSYWhRERSYSgRkVQYSkQkFYYSkST+85//ICcnB6NGjUJCQgJSUlKwefNmAMCuXbuQlJSEmJgYxMTEoLi4uMt7LRYLgoOD8etf/7rL7++55x64210ZDCUiCQgh8OCDD+Luu+/G6dOnUV1djXfeeQfnz5/Ht99+i5ycHPz5z3/GiRMnsGvXLqxatQp///vfO99fXl6O0aNHY+PGjejr1jF3wFAiksCOHTtgMBgwe/bszt9FRkbiqaeeQlFREXJzc3HnnXcCAIYNG4YlS5Zg8eLFnW3Xr1+PefPmYcSIEdi7d6/L63ckhhKRBI4ePdoZOj1tS0hI6PK7xMREHD16FABgMpmwfft2ZGRkIDs7G+vXr3d6vc7EUCKS0Ny5c/GTn/wE48aNs9v2gw8+wKRJk+Dj44PMzExs2bIF7e3tLqjSORhKRBKIi4vDgQMHOn8uKirCJ598grq6OsTGxqK6urpL++rqasTFxQGwnbpt374dUVFRSEhIQH19PXbs2OHS+h2JoUQkgcmTJ8NkMuGNN97o/F1zczMA26jpzTffxKFDhwAA9fX1eO6557BgwQJcuXIFn332Gc6ePYszZ87gzJkzKCoqcutTOIYSkQQURcGWLVtQWVmJkSNHIikpCTNmzMCrr76K0NBQrFu3Dvn5+YiJicH48eORl5eHjIwMbN68GZMnT4aXl1fnvqZOnYqysjK0trYCAB544AFEREQgIiICWVlZan3EfuPKk0Tkclx5kojcBkOJiKTCUCIiqTCUiEgqDCUikgpDiYikwlAiIqkwlIhIKgwlIpIKQ4mIpMJQIiKpMJSISCoMJSKSCkOJiKTCUCIiqTCUiEgqDCUikgpDiYikwlAiIqkwlIhIKgwlIpIKQ4mIpMJQIiKpMJSISCoMJSKSCkOJiKTS52O7FUWpA/CN68ohoiEiUggR3NOGPkOJiMjVePpGRFJhKBGRVBhKRCQVhhIRSYWhRERS+X9kfBV3sq0wKgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig = plt.figure(figsize=(5, 5))\n", "ax = plt.gca()\n", "ax.set_xlim(0, 3)\n", "ax.set_ylim(0, 3)\n", "\n", "# plt.plot([1, 1], [0, 1], color='red', linewidth=2)\n", "# plt.plot([1, 2], [2, 2], color='red', linewidth=2)\n", "# plt.plot([2, 2], [2, 1], color='red', linewidth=2)\n", "# plt.plot([2, 3], [1, 1], color='red', linewidth=2)\n", "\n", "plt.plot([2, 3], [1, 1], color='red', linewidth=2)\n", "plt.plot([0, 1], [1, 1], color='red', linewidth=2)\n", "plt.plot([1, 1], [1, 2], color='red', linewidth=2)\n", "plt.plot([1, 2], [2, 2], color='red', linewidth=2)\n", "\n", "plt.text(0.5, 2.5, 'S0', size=14, ha='center')\n", "plt.text(1.5, 2.5, 'S1', size=14, ha='center')\n", "plt.text(2.5, 2.5, 'S2', size=14, ha='center')\n", "plt.text(0.5, 1.5, 'S3', size=14, ha='center')\n", "plt.text(1.5, 1.5, 'S4', size=14, ha='center')\n", "plt.text(2.5, 1.5, 'S5', size=14, ha='center')\n", "plt.text(0.5, 0.5, 'S6', size=14, ha='center')\n", "plt.text(1.5, 0.5, 'S7', size=14, ha='center')\n", "plt.text(2.5, 0.5, 'S8', size=14, ha='center')\n", "plt.text(0.5, 2.3, 'START', ha='center')\n", "plt.text(2.5, 0.3, 'GOAL', ha='center')\n", "# plt.axis('off')\n", "plt.tick_params(axis='both', which='both', \n", " bottom=False, top=False, \n", " right=False, left=False,\n", " labelbottom=False, labelleft=False\n", " )\n", "line, = ax.plot([0.5], [2.5], marker='o', color='g', markersize=60)" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2023-02-19T01:21:55.938956Z", "start_time": "2023-02-19T01:21:55.936601Z" } }, "source": [ "## basics" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- 基本概念及术语\n", " - algorithm\n", " - policy iteration\n", " - policy gradient\n", " - value iteration\n", " - sarsa: state, action, reward, state, action\n", " - Q-table:待学习(learning/iteration update)\n", " - row index: state; column index: action;\n", " - 不是概率分布,是 value\n", " - q-learning\n", " - reward\n", " - 特定时间 $t$ 给到的奖励 $R_t$ 称为即时奖励(immediate reward)\n", " - 未来的总奖励 $G_t$\n", " - $G_t=R_{t+1}+R_{t+2}+R_{t+3}+\\cdots$\n", " - $G_t=R_{t+1}+\\gamma R_{t+2}+\\gamma^2R_{t+3} + \\cdots + \\gamma^kR_{t+k+1}\\cdots$\n", " - 举例\n", " - $Q_\\pi(s=7,a=1)=R_{t+1}=1$\n", " - $Q_\\pi(s=7,a=0)=\\gamma^2$\n", " - action value,state value\n", " - bellman equation\n", " - 适用于状态价值函数(state value function),也适用于动作价值函数(action value function)\n", " - mdp:markov decision process\n", " - 马尔可夫性\n", " - $p(s_{t+1}|s_t)=p(s_{t+1}|{s_1,s_2,s_3,\\cdots,s_t})$\n", " - bellman equation 成立的前提条件" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2023-02-19T01:22:33.918570Z", "start_time": "2023-02-19T01:22:33.812489Z" } }, "source": [ "- $R_t$\n", "- $Q_{\\pi}(s,a)$:state action value function\n", " - Q table\n", " - 通过 Sarsa 算法迭代更新 $Q_{\\pi}(s,a)$\n", "- 对于强化学习而言\n", " - state, action, reward 都是需要精心设计的" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sarsa(state action reward state action)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2023-02-20T15:20:32.328550Z", "start_time": "2023-02-20T15:20:32.322841Z" } }, "outputs": [], "source": [ "# border & barrier\n", "# ↑, →, ↓, ←(顺时针)\n", "# row index: given state\n", "# col index: posible action\n", "# (state, action) matrix\n", "# 跟环境对齐\n", "theta_0 = np.asarray([[np.nan, 1, 1, np.nan], # s0\n", " [np.nan, 1, np.nan, 1], # s1\n", " [np.nan, np.nan, 1, 1], # s2\n", " [1, np.nan, np.nan, np.nan], # s3 \n", " [np.nan, 1, 1, np.nan], # s4\n", " [1, np.nan, np.nan, 1], # s5\n", " [np.nan, 1, np.nan, np.nan], # s6 \n", " [1, 1, np.nan, 1]] # s7\n", " )" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2023-02-20T15:20:33.445898Z", "start_time": "2023-02-20T15:20:33.443279Z" } }, "outputs": [], "source": [ "n_states, n_actions = theta_0.shape" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2023-02-20T15:22:28.342788Z", "start_time": "2023-02-20T15:22:28.340225Z" } }, "outputs": [], "source": [ "# Q table, 状态是离散的(s0-s7),动作也是离散的(上右下左)\n", "Q = np.random.rand(n_states, n_actions) * theta_0" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2023-02-20T15:22:29.712906Z", "start_time": "2023-02-20T15:22:29.702001Z" } }, "outputs": [ { "data": { "text/plain": [ "array([[ nan, 0.05711349, 0.7487308 , nan],\n", " [ nan, 0.21055329, nan, 0.34236513],\n", " [ nan, nan, 0.58900823, 0.12580755],\n", " [0.19742577, nan, nan, nan],\n", " [ nan, 0.46973656, 0.03069312, nan],\n", " [0.90690458, nan, nan, 0.67280829],\n", " [ nan, 0.18346137, nan, nan],\n", " [0.40679226, 0.88075109, nan, 0.38732487]])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Q" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### $\\epsilon$-greedy (explore, exploit)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2023-02-20T15:23:51.690714Z", "start_time": "2023-02-20T15:23:51.684199Z" } }, "outputs": [], "source": [ "# 基于占比, 最 naive 的概率化方式\n", "def cvt_theta_0_to_pi(theta):\n", " m, n = theta.shape\n", " pi = np.zeros((m, n))\n", " for r in range(m):\n", " pi[r, :] = theta[r, :] / np.nansum(theta[r, :])\n", " return np.nan_to_num(pi)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2023-02-20T15:23:53.973666Z", "start_time": "2023-02-20T15:23:53.970002Z" } }, "outputs": [], "source": [ "pi_0 = cvt_theta_0_to_pi(theta_0)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2023-02-20T15:23:59.386828Z", "start_time": "2023-02-20T15:23:59.382946Z" }, "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "array([[0. , 0.5 , 0.5 , 0. ],\n", " [0. , 0.5 , 0. , 0.5 ],\n", " [0. , 0. , 0.5 , 0.5 ],\n", " [1. , 0. , 0. , 0. ],\n", " [0. , 0.5 , 0.5 , 0. ],\n", " [0.5 , 0. , 0. , 0.5 ],\n", " [0. , 1. , 0. , 0. ],\n", " [0.33333333, 0.33333333, 0. , 0.33333333]])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pi_0" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2023-02-20T15:24:58.768881Z", "start_time": "2023-02-20T15:24:58.763948Z" } }, "outputs": [ { "data": { "text/plain": [ "array([[ nan, 0.05711349, 0.7487308 , nan],\n", " [ nan, 0.21055329, nan, 0.34236513],\n", " [ nan, nan, 0.58900823, 0.12580755],\n", " [0.19742577, nan, nan, nan],\n", " [ nan, 0.46973656, 0.03069312, nan],\n", " [0.90690458, nan, nan, 0.67280829],\n", " [ nan, 0.18346137, nan, nan],\n", " [0.40679226, 0.88075109, nan, 0.38732487]])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Q" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "ExecuteTime": { "end_time": "2023-02-19T01:33:27.896249Z", "start_time": "2023-02-19T01:33:27.893096Z" } }, "outputs": [], "source": [ "# epsilon-greedy\n", "def get_action(s, Q, eps, pi_0):\n", " action_space = list(range(4))\n", " # eps, explore\n", " if np.random.rand() < eps:\n", " action = np.random.choice(action_space, p=pi_0[s, :])\n", " else:\n", " # 1-eps, exploit\n", " action = np.nanargmax(Q[s, :])\n", " return action" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sarsa (update $Q_\\pi(s,a)$)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "理想情况下:\n", "$$\n", "Q(s_t,a_t) = R_{t+1}+\\gamma Q(s_{t+1}, a_{t+1})\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- td (temporal difference error)\n", " - $R_{t+1}+\\gamma Q(s_{t+1},a_{t+1})-Q(s_t,a_t)$\n", "- final update equation\n", " - $Q(s_t,a_t)=Q(s_t,a_t)+\\eta\\cdot(R_{t+1}+\\gamma Q(s_{t+1},a_{t+1})-Q(s_t,a_t))$\n", " - $s_t,a_t,r_{t+1},s_{t+1},a_{t+1}$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- 折扣(discount factor,$\\gamma$)\n", " - 有助于缩短步数(更快地结束任务);" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "ExecuteTime": { "end_time": "2023-02-19T02:15:30.250510Z", "start_time": "2023-02-19T02:15:30.247289Z" } }, "outputs": [], "source": [ "def sarsa(s, a, r, s_next, a_next, Q, eta, gamma):\n", " if s_next == 8:\n", " Q[s, a] = Q[s, a] + eta * (r - Q[s, a])\n", " else:\n", " Q[s, a] = Q[s, a] + eta * (r + gamma * Q[s_next, a_next] - Q[s, a])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 解决 maze 问题" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2023-02-20T15:28:38.124516Z", "start_time": "2023-02-20T15:28:38.116963Z" } }, "outputs": [], "source": [ "# 维护着状态,以及 step 函数的返回\n", "class MazeEnv(gym.Env):\n", " def __init__(self):\n", " self.state = 0\n", " pass\n", " \n", " def reset(self):\n", " self.state = 0\n", " return self.state\n", " \n", " def step(self, action):\n", " if action == 0:\n", " self.state -= 3\n", " elif action == 1:\n", " self.state += 1\n", " elif action == 2:\n", " self.state += 3\n", " elif action == 3:\n", " self.state -= 1\n", " done = False\n", " reward = 0\n", " if self.state == 8:\n", " done = True\n", " reward = 1\n", " # state, reward, done, _\n", " return self.state, reward, done, {}" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "ExecuteTime": { "end_time": "2023-02-20T15:33:26.797516Z", "start_time": "2023-02-20T15:33:26.783342Z" } }, "outputs": [], "source": [ "# 动作策略选择,基于当前环境的状态\n", "class Agent:\n", " def __init__(self):\n", " self.action_space = list(range(4))\n", " self.theta_0 = np.asarray([[np.nan, 1, 1, np.nan], # s0\n", " [np.nan, 1, np.nan, 1], # s1\n", " [np.nan, np.nan, 1, 1], # s2\n", " [1, np.nan, np.nan, np.nan], # s3 \n", " [np.nan, 1, 1, np.nan], # s4\n", " [1, np.nan, np.nan, 1], # s5\n", " [np.nan, 1, np.nan, np.nan], # s6 \n", " [1, 1, np.nan, 1]] # s7\n", " )\n", " self.pi = self._cvt_theta_to_pi()\n", "# self.pi = self._softmax_cvt_theta_to_pi()\n", "# self.theta = self.theta_0\n", "\n", " self.Q = np.random.rand(*self.theta_0.shape) * self.theta_0\n", " self.eta = 0.1\n", " self.gamma = 0.9\n", " self.eps = 0.5\n", " \n", " def _cvt_theta_to_pi(self):\n", " m, n = self.theta_0.shape\n", " pi = np.zeros((m, n))\n", " for r in range(m):\n", " pi[r, :] = self.theta_0[r, :] / np.nansum(self.theta_0[r, :])\n", " return np.nan_to_num(pi)\n", " \n", "# def _softmax_cvt_theta_to_pi(self, beta=1.):\n", "# m, n = self.theta.shape\n", "# pi = np.zeros((m, n))\n", "# exp_theta = np.exp(self.theta*beta)\n", "# for r in range(m):\n", "# pi[r, :] = exp_theta[r, :] / np.nansum(exp_theta[r, :])\n", "# return np.nan_to_num(pi)\n", " \n", " def get_action(self, s):\n", " # eps, explore\n", " if np.random.rand() < self.eps:\n", " action = np.random.choice(self.action_space, p=self.pi[s, :])\n", " else:\n", " # 1-eps, exploit\n", " action = np.nanargmax(self.Q[s, :])\n", " return action\n", " \n", " def sarsa(self, s, a, r, s_next, a_next):\n", " if s_next == 8:\n", " self.Q[s, a] = self.Q[s, a] + self.eta * (r - self.Q[s, a])\n", " else:\n", " self.Q[s, a] = self.Q[s, a] + self.eta * (r + self.gamma * self.Q[s_next, a_next] - self.Q[s, a])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 完成训练及更新" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "ExecuteTime": { "end_time": "2023-02-20T15:38:08.370694Z", "start_time": "2023-02-20T15:38:08.296970Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1 0.48151692420235004 25\n", "2 0.788238427054541 51\n", "3 0.5193178707316495 73\n", "4 0.09357958369502378 23\n", "5 0.07459493697310315 7\n", "6 0.06891781188632101 7\n", "7 0.06585711141363843 7\n", "8 0.06265613413272358 7\n", "9 0.060100868881304614 7\n", "10 0.059287337144400654 7\n", "11 0.058488284106318955 7\n", "12 0.05750697273897609 7\n", "13 0.05634890213948618 7\n", "14 0.055022881405307755 7\n", "15 0.0535403609535367 7\n", "16 0.051914820609968326 7\n", "17 0.0501612179940748 7\n", "18 0.048295498355812816 7\n", "19 0.04633416520669903 7\n", "20 0.04509458086047741 7\n", "21 0.0443781660421414 7\n", "22 0.043566741048947755 7\n", "23 0.04266900608905749 7\n", "24 0.04169365931899821 7\n", "25 0.04064935278238613 7\n", "26 0.039544641494567434 7\n", "27 0.03877432598312053 7\n", "28 0.03836919489516444 7\n", "29 0.037932322993017575 7\n", "30 0.03746384259607333 7\n", "31 0.036964222794984625 7\n", "32 0.03643423557551101 7\n", "33 0.03587492225074168 7\n", "34 0.03528756047376702 7\n", "35 0.034673632089539 7\n", "36 0.034034792062826946 7\n", "37 0.03337283869172736 7\n", "38 0.032689685285438086 7\n", "39 0.031987333452767774 7\n", "40 0.03126784811552158 7\n", "41 0.030533334329512607 7\n", "42 0.02978591596630603 7\n", "43 0.02902771628138473 7\n", "44 0.02826084036961629 7\n", "45 0.0274873594868823 7\n", "46 0.026709297197557824 7\n", "47 0.02592861729120166 7\n", "48 0.025147213398248935 7\n", "49 0.024366900223515675 7\n", "50 0.023589406307802152 7\n", "51 0.022816368221570238 7\n", "52 0.022049326090399224 7\n", "53 0.021289720349440266 7\n", "54 0.020538889623210932 7\n", "55 0.019798069627555426 7\n", "56 0.01906839299226787 7\n", "57 0.018350889905525525 7\n", "58 0.01764648948474906 7\n", "59 0.01695602178262262 7\n", "60 0.016280220341618468 7\n", "61 0.01561972521537619 7\n", "62 0.014975086380526359 7\n", "63 0.01434676746795055 7\n", "64 0.013735149747947817 7\n", "65 0.013140536309220685 7\n", "66 0.012563156376991624 7\n", "67 0.012003169720803564 7\n", "68 0.011460671107647646 7\n", "69 0.01093569476093581 7\n", "70 0.010428218790478228 7\n", "71 0.009938169563017962 7\n", "72 0.009465425986993514 7\n", "73 0.009009823689056895 7\n", "74 0.00857115906343553 7\n", "75 0.00814919317852647 7\n", "76 0.007743655528121618 7\n", "77 0.0073542476174153215 7\n", "78 0.006980646376435584 7\n", "79 0.006622507395769972 7\n", "80 0.0062794679814730525 7\n", "81 0.005951150027797558 7\n", "82 0.00563716270797221 7\n", "83 0.005337104984597274 7\n", "84 0.005050567942414208 7\n", "85 0.004777136947210048 7\n", "86 0.004516393635464522 7\n", "87 0.004267917740051419 7\n", "88 0.004031288757878282 7\n", "89 0.0038060874658013555 7\n", "90 0.0035918972914945613 7\n", "91 0.0033883055462049505 7\n", "92 0.003194904526485076 7\n", "93 0.00301129249208143 7\n", "94 0.002837074527182959 7\n", "95 0.002671863292188603 7\n", "96 0.0025152796730716354 7\n", "97 0.002366953335290245 7\n", "98 0.002226523189017504 7\n", "99 0.002093637772283996 7\n", "100 0.001967955558394241 7\n", "101 0.0018491451937413528 7\n" ] } ], "source": [ "maze = MazeEnv()\n", "agent = Agent()\n", "epoch = 0\n", "while True:\n", " old_Q = np.nanmax(agent.Q, axis=1)\n", " s = maze.reset()\n", " a = agent.get_action(s)\n", " s_a_history = [[s, np.nan]]\n", " while True:\n", " # s, a \n", " s_a_history[-1][1] = a\n", " s_next, reward, done, _ = maze.step(a, )\n", " # s_next, a_next\n", " s_a_history.append([s_next, np.nan])\n", " if done:\n", " a_next = np.nan\n", " else:\n", " a_next = agent.get_action(s_next)\n", "# print(s, a, reward, s_next, a_next)\n", " agent.sarsa(s, a, reward, s_next, a_next)\n", "# print(agent.pi)\n", " if done:\n", " break\n", " else:\n", " a = a_next\n", " s = maze.state\n", "\n", " # s_s_history, agent.Q\n", " update = np.sum(np.abs(np.nanmax(agent.Q, axis=1) - old_Q))\n", " epoch +=1\n", " agent.eps /= 2\n", " print(epoch, update, len(s_a_history))\n", " if epoch > 100 or update < 1e-5:\n", " break\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "ExecuteTime": { "end_time": "2023-02-20T15:38:33.582285Z", "start_time": "2023-02-20T15:38:33.568692Z" } }, "outputs": [ { "data": { "text/plain": [ "array([[ nan, 0.5741703 , 0.28654803, nan],\n", " [ nan, 0.64891149, nan, 0.31090434],\n", " [ nan, nan, 0.72647909, 0.04574432],\n", " [0.24347619, nan, nan, nan],\n", " [ nan, 0.34181286, 0.89988623, nan],\n", " [0.21120543, nan, nan, 0.80934131],\n", " [ nan, 0.61078998, nan, nan],\n", " [0.19954603, 0.99999028, nan, 0.64883227]])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "agent.Q" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 可视化" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "ExecuteTime": { "end_time": "2023-02-20T15:38:57.716413Z", "start_time": "2023-02-20T15:38:57.705590Z" } }, "outputs": [], "source": [ "from matplotlib import animation\n", "from IPython.display import HTML" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "ExecuteTime": { "end_time": "2023-02-20T15:38:58.576025Z", "start_time": "2023-02-20T15:38:58.572588Z" } }, "outputs": [], "source": [ "def init():\n", " line.set_data([], [])\n", " return (line, )\n", "def animate(i):\n", " state = s_a_history[i][0]\n", " x = (state % 3)+0.5\n", " y = 2.5 - int(state/3)\n", " line.set_data(x, y)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "ExecuteTime": { "end_time": "2023-02-20T15:39:00.248954Z", "start_time": "2023-02-20T15:39:00.244740Z" } }, "outputs": [], "source": [ "anim = animation.FuncAnimation(fig, animate, init_func=init, frames=len(s_a_history), interval=200, repeat=False)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "ExecuteTime": { "end_time": "2023-02-20T15:39:06.778021Z", "start_time": "2023-02-20T15:39:06.064615Z" } }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "\n", "
\n", " \n", "
\n", " \n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
\n", "
\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "HTML(anim.to_jshtml())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" }, "toc": { "base_numbering": 1.0, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 2 }