{ "cells": [ { "cell_type": "markdown", "id": "8d0665e6", "metadata": {}, "source": [ "## basics" ] }, { "cell_type": "markdown", "id": "7eeae6f6", "metadata": {}, "source": [ "- references\n", "    - https://github.com/pytorch/examples/tree/main/reinforcement_learning\n", "    - https://towardsdatascience.com/understanding-actor-critic-methods-931b97b6df3f\n", "    - https://lilianweng.github.io/posts/2018-04-08-policy-gradient/" ] }, { "cell_type": "markdown", "id": "4f3f3454", "metadata": {}, "source": [ "### Policy gradient" ] }, { "cell_type": "markdown", "id": "23ecd402", "metadata": { "ExecuteTime": { "end_time": "2023-08-02T14:08:09.463801Z", "start_time": "2023-08-02T14:08:09.445731Z" } }, "source": [ "- REINFORCE: noisy gradients and high variance (of the gradient estimates)\n", "    - REINFORCE updates the policy parameters ($\\theta$) through Monte Carlo updates (i.e. using returns from randomly sampled trajectories).\n", "    - Trajectories sampled during training can differ greatly from one another, which introduces inherent high variability in\n", "        - the log probabilities (log of the policy distribution): $\\log\\pi_\\theta(a_t|s_t)$\n", "        - the cumulative reward values: $G_t$\n", "- cumulative reward equal to 0\n", "    - The essence of the policy gradient is to increase the probabilities of “good” actions and decrease those of “bad” actions in the policy distribution;\n", "    - if the cumulative reward $G_t$ is 0, the term $\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t)G_t$ vanishes, so neither “good” nor “bad” actions are learned.\n", "\n", "$$\n", "\\nabla_\\theta J(\\theta)=\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t)G_t]\n", "$$" ] }, { "cell_type": "markdown", "id": "a0cfc6b9", "metadata": {}, "source": [ "### introduce a baseline $b(s)$" ] }, { "cell_type": "markdown", "id": "5fd3c0a8", "metadata": {}, "source": [ "$$\n", "\\nabla_\\theta J(\\theta)=\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t)(G_t-b(s_t))]\n", "$$\n", "\n", "- A baseline that depends only on the state does not change the expected gradient, since $\\mathbb E_{a\\sim\\pi_\\theta}[\\nabla_\\theta\\log\\pi_\\theta(a|s)]\\,b(s)=b(s)\\nabla_\\theta\\sum_a\\pi_\\theta(a|s)=b(s)\\nabla_\\theta 1=0$;\n", "- a well-chosen baseline (such as the state value $V(s_t)$) reduces the variance of the gradient estimate." ] }, { "cell_type": "markdown", "id": "bb95429f", "metadata": {}, "source": [ "### Actor Critic" ] 
}, { "cell_type": "markdown", "id": "49ba68fc", "metadata": {}, "source": [ "- AC\n", "    - Actor: $\\pi(a|s)$\n", "    - Critic: $Q(s, a)$\n", "- Critic\n", "    - estimates the value function, either\n", "        - the action-value: $Q$ value, or\n", "        - the state-value: $V$ value\n", "            - the average action value at the given state, $V(s)=\\mathbb E_{a\\sim\\pi}[Q(s,a)]$\n", "    - $Q_w(s_t,a_t)$ => the Critic neural network, which regresses a scalar value;\n", "        - using $Q_w$ as the critic gives the Q Actor Critic\n", "- Actor\n", "    - The policy gradient method is the “actor” part of Actor-Critic methods.\n", "\n", "- Both the Critic and the Actor functions are parameterized with neural networks." ] }, { "cell_type": "markdown", "id": "ddc3152b", "metadata": {}, "source": [ "\n", "$$\n", "\\begin{split}\n", "\\nabla_\\theta J(\\theta)&=\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t)G_t]\\\\\n", "&=\\sum_{t=0}^{T-1}\\mathbb E_{s_0,a_0,\\cdots,s_t,a_t}[\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t)\\,\\mathbb E_{r_{t+1},s_{t+1},\\cdots,r_T,s_T}[G_t]]\\\\\n", "&=\\sum_{t=0}^{T-1}\\mathbb E_{s_0,a_0,\\cdots,s_t,a_t}[\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t)\\,Q(s_t,a_t)]\\\\\n", "&\\approx\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t) Q_w(s_t,a_t)]\n", "\\end{split}\n", "$$" ] }, { "cell_type": "markdown", "id": "f4aeb4c2", "metadata": {}, "source": [ "### subtract baseline" ] }, { "cell_type": "markdown", "id": "cc792fc3", "metadata": {}, "source": [ "$$\n", "A(s_t,a_t) = Q_w(s_t,a_t)-V_v(s_t)\n", "$$\n", "\n", "- Using the $V$ function as the baseline function,\n", "- we subtract the $V$ value from the $Q$ value.\n", "- The difference measures how much better it is to take a specific action than the average action at the given state.\n", "    - This is the **advantage value**." ] },
{ "cell_type": "markdown", "id": "3a672c92", "metadata": {}, "source": [ "$$\n", "\\begin{split}\n", "&Q(s_t,a_t)=\\mathbb E[r_{t+1}+\\gamma V(s_{t+1})\\mid s_t,a_t]\\\\\n", "&A(s_t,a_t)\\approx r_{t+1}+\\gamma V_v(s_{t+1})-V_v(s_t)\n", "\\end{split}\n", "$$\n", "\n", "- Replacing the expectation with a single sampled transition turns the advantage into the TD error, which only needs the state-value critic $V_v$, not a separate $Q_w$ network." ] }, { "cell_type": "markdown", "id": "f8261b90", "metadata": {}, "source": [ "### Advantage Actor Critic (A2C)" ] }, { "cell_type": "markdown", "id": "06839e87", "metadata": {}, "source": [ "$$\n", "\\begin{split}\n", "\\nabla_\\theta J(\\theta)&\\approx\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t) (Q_w(s_t,a_t)-V_v(s_t))]\\\\\n", "&=\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t) A(s_t,a_t)]\\\\\n", "&\\approx\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t) \\left(r_{t+1}+\\gamma V_v(s_{t+1})-V_v(s_t)\\right)]\n", "\\end{split}\n", "$$" ] }, { "cell_type": "markdown", "id": "84e866dc", "metadata": {}, "source": [ "## implementation" ] }, { "cell_type": "code", "execution_count": 4, "id": "ca7c2b9b", "metadata": { "ExecuteTime": { "end_time": "2023-08-02T14:44:43.218346Z", "start_time": "2023-08-02T14:44:43.212181Z" } }, "outputs": [], "source": [ "#!pip install -U gym==0.15.3" ] }, { "cell_type": "code", "execution_count": 5, "id": "95577b00", "metadata": { "ExecuteTime": { "end_time": "2023-08-02T14:44:50.267749Z", "start_time": "2023-08-02T14:44:48.612105Z" } }, "outputs": [], "source": [ "import sys\n", "import torch\n", "import gym\n", "import numpy as np\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "import torch.nn.functional as F\n", "import matplotlib.pyplot as plt\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 6, "id": "e5e4bf5c", "metadata": { "ExecuteTime": { "end_time": "2023-08-02T14:45:03.631947Z", "start_time": "2023-08-02T14:45:03.625967Z" } }, "outputs": [], "source": [ "# hyperparameters\n", "hidden_size = 256\n", 
"learning_rate = 3e-4\n", "\n", "# Constants\n", "GAMMA = 0.99\n", "num_steps = 300\n", "max_episodes = 3000" ] }, { "cell_type": "code", "execution_count": 33, "id": "7adb32c3", "metadata": { "ExecuteTime": { "end_time": "2023-08-02T15:02:01.173668Z", "start_time": "2023-08-02T15:02:01.157049Z" } }, "outputs": [ { "data": { "text/plain": [ "array([-0.01507673, 0.00588999, 0.00869466, 0.00153444])" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env = gym.make(\"CartPole-v0\")\n", "\n", "# 4-d continuous observation space\n", "num_inputs = env.observation_space.shape[0]\n", "# discrete action space: push left / push right\n", "num_actions = env.action_space.n\n", "env.reset()" ] }, { "cell_type": "code", "execution_count": 37, "id": "920fe998", "metadata": { "ExecuteTime": { "end_time": "2023-08-02T15:09:09.477744Z", "start_time": "2023-08-02T15:09:09.464116Z" } }, "outputs": [], "source": [ "class ActorCritic(nn.Module):\n", "    def __init__(self, num_inputs, num_actions, hidden_size):\n", "        super(ActorCritic, self).__init__()\n", "        self.num_actions = num_actions\n", "\n", "        # num_inputs: dimension of the state\n", "        # critic nn: state => value V(s)\n", "        self.critic_ln1 = nn.Linear(num_inputs, hidden_size)\n", "        self.critic_ln2 = nn.Linear(hidden_size, 1)\n", "\n", "        # actor nn: state => policy (a probability for each action)\n", "        self.actor_ln1 = nn.Linear(num_inputs, hidden_size)\n", "        self.actor_ln2 = nn.Linear(hidden_size, num_actions)\n", "\n", "    def forward(self, state):\n", "        # (4, ) => (1, 4)\n", "        # ndarray => float32 tensor with a batch dimension; the input itself needs no gradient\n", "        state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)\n", "\n", "        # forward of critic network\n", "        # (1, 4) => (1, 256)\n", "        value = F.relu(self.critic_ln1(state))\n", "        # (1, 256) => (1, 1)\n", "        value = self.critic_ln2(value)\n", "\n", "        # forward of actor network\n", "        # (1, 4) => (1, 256)\n", "        policy_dist = F.relu(self.actor_ln1(state))\n", "        # (1, 256) => (1, 2)\n", "        policy_dist = F.softmax(self.actor_ln2(policy_dist), dim=1)\n", "        return value, policy_dist" ] }, { "cell_type": "code", "execution_count": 38, "id": "b67cf056", "metadata": { "ExecuteTime": { "end_time": "2023-08-02T15:09:10.870382Z", "start_time": "2023-08-02T15:09:10.859615Z" } }, "outputs": [], "source": [ "ac = ActorCritic(num_inputs, num_actions, hidden_size)\n", "ac_opt = optim.Adam(ac.parameters(), lr=learning_rate)" ] },
{ "cell_type": "code", "execution_count": null, "id": "8821ac79", "metadata": { "ExecuteTime": { "end_time": "2023-08-02T15:18:44.585795Z", "start_time": "2023-08-02T15:18:44.541188Z" } }, "outputs": [], "source": [ "all_rewards = []\n", "\n", "for episode in range(max_episodes):\n", "    # the three lists have the same length;\n", "    # index i corresponds to time step t = i\n", "    log_probs = []\n", "    values = []\n", "    rewards = []\n", "\n", "    state = env.reset()\n", "\n", "    for step in range(num_steps):\n", "        # value: (1, 1), critic estimate V(s_t); policy_dist: (1, num_actions), pi(.|s_t)\n", "        value, policy_dist = ac(state)\n", "\n", "        # sample a_t ~ pi(.|s_t) and keep its log probability for the actor loss\n", "        dist = policy_dist.detach().numpy()\n", "        action = np.random.choice(num_actions, p=np.squeeze(dist))\n", "        log_prob = torch.log(policy_dist.squeeze(0)[action])\n", "\n", "        new_state, reward, done, _ = env.step(action)\n", "\n", "        rewards.append(reward)\n", "        # keep V(s_t) as a tensor (not a detached number) so the critic receives gradients\n", "        values.append(value.squeeze())\n", "        log_probs.append(log_prob)\n", "\n", "        state = new_state\n", "\n", "        if done or step == num_steps - 1:\n", "            break\n", "\n", "    # value of the last state: 0 if the episode terminated, otherwise bootstrap with V(s_T)\n", "    with torch.no_grad():\n", "        q_val = 0.0 if done else ac(state)[0].item()\n", "\n", "    # discounted returns, computed backwards in time, as estimates of Q(s_t, a_t)\n", "    # (a multi-step version of the one-step target r_{t+1} + GAMMA * V(s_{t+1}))\n", "    q_vals = []\n", "    for r in reversed(rewards):\n", "        q_val = r + GAMMA * q_val\n", "        q_vals.insert(0, q_val)\n", "    q_vals = torch.tensor(q_vals, dtype=torch.float32)\n", "\n", "    values = torch.stack(values)\n", "    log_probs = torch.stack(log_probs)\n", "\n", "    # advantage A(s_t, a_t) = Q(s_t, a_t) - V(s_t)\n", "    advantage = q_vals - values\n", "    # actor: policy gradient weighted by the advantage (detached so it does not train the critic)\n", "    actor_loss = -(log_probs * advantage.detach()).mean()\n", "    # critic: regress V(s_t) towards the return estimates\n", "    critic_loss = 0.5 * advantage.pow(2).mean()\n", "    ac_loss = actor_loss + critic_loss\n", "\n", "    ac_opt.zero_grad()\n", "    ac_loss.backward()\n", "    ac_opt.step()\n", "\n", "    all_rewards.append(sum(rewards))\n", "    if episode % 100 == 0:\n", "        print(f\"episode {episode}: total reward {sum(rewards)}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "ec50f7eb", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", 
"name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }