| author | lanchunhui <zch921005@126.com> | 2023-08-02 23:20:47 +0800 |
|---|---|---|
| committer | lanchunhui <zch921005@126.com> | 2023-08-02 23:20:47 +0800 |
| commit | 186a22521af4c1f56abfaf227d2e85ca03d24c5d (patch) | |
| tree | ff3fe32ec55cd260ac5da599a164869a9e286dd7 /rl/tutorials/actor_critic.ipynb | |
| parent | 5bff7e2bcc303bfa6caee9d0b95bc21540d4c279 (diff) | |
update: notes
Diffstat (limited to 'rl/tutorials/actor_critic.ipynb')
| -rw-r--r-- | rl/tutorials/actor_critic.ipynb | 390 |
1 files changed, 385 insertions, 5 deletions
diff --git a/rl/tutorials/actor_critic.ipynb b/rl/tutorials/actor_critic.ipynb
index 0b357d7..1e19940 100644
--- a/rl/tutorials/actor_critic.ipynb
+++ b/rl/tutorials/actor_critic.ipynb
@@ -16,12 +16,392 @@

The commit replaces the short Actor-Critic outline in the opening cell with the expanded notes and implementation below.

- references
  - https://github.com/pytorch/examples/tree/main/reinforcement_learning
  - https://towardsdatascience.com/understanding-actor-critic-methods-931b97b6df3f
  - https://lilianweng.github.io/posts/2018-04-08-policy-gradient/

### Policy gradient

- REINFORCE: noisy gradients & high variance (of the gradients)
  - updates the policy parameters ($\theta$) through Monte Carlo updates (i.e. randomly sampled trajectories)
  - This introduces inherent high variability in
    - the log probabilities (log of the policy distribution): $\log\pi_\theta(a_t|s_t)$
    - the cumulative reward values: $G_t$
  - because trajectories sampled during training can deviate from each other to a great degree.
- cumulative reward == 0
  - The essence of policy gradient is to increase the probabilities of “good” actions and decrease those of “bad” actions in the policy distribution;
  - neither “good” nor “bad” actions will be learned if the cumulative reward is 0.

$$
\nabla_\theta J(\theta)=\mathbb E_\tau\left[\sum_{t=0}^{T-1}\nabla_\theta\log\pi_\theta(a_t|s_t)\,G_t\right]
$$

### Introduce a baseline $b(s)$

$$
\nabla_\theta J(\theta)=\mathbb E_\tau\left[\sum_{t=0}^{T-1}\nabla_\theta\log\pi_\theta(a_t|s_t)\,(G_t-b(s_t))\right]
$$

### Actor Critic

- Critic
  - estimates the value function:
    - action-value: the $Q$ value
    - state-value: the $V$ value (the average, general action value at the given state)
  - $Q_w(s_t,a_t)$ => the Critic neural network regresses a single scalar value;
  - this variant is the Q Actor-Critic
- Actor
  - The policy gradient method is also the “actor” part of Actor-Critic methods
- Both the Critic and Actor functions are parameterized with neural networks.
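Read as code, the two gradient expressions above amount to a surrogate loss over one sampled trajectory. A minimal PyTorch sketch (the function name, the mean-return baseline, and the tensor layout are illustrative assumptions, not code from the notebook):

```python
import torch

def reinforce_with_baseline_loss(log_probs: torch.Tensor, returns: torch.Tensor) -> torch.Tensor:
    """Surrogate loss whose gradient is sum_t grad_theta log pi(a_t|s_t) * (G_t - b).

    log_probs: 1-D tensor of log pi_theta(a_t|s_t) collected along one trajectory
    returns:   1-D tensor of the cumulative rewards G_t for the same timesteps
    """
    baseline = returns.mean()          # a simple constant baseline b
    advantages = returns - baseline    # G_t - b(s_t)
    # minus sign: optimizers minimize, while J(theta) is to be maximized
    return -(log_probs * advantages).sum()
```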
" + ] + }, + { + "cell_type": "markdown", + "id": "ddc3152b", + "metadata": {}, + "source": [ + "\n", + "$$\n", + "\\begin{split}\n", + "\\nabla_\\theta J(\\theta)&=\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t)G_t]\\\\\n", + "&=\\mathbb E_{s_0,a_0,\\cdots,s_t,a_t}[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t)] \\mathbb E_{r_{t+1},s_{t+1},\\cdots,r_T,s_T}[G_t]\\\\\n", + "&=\\mathbb E_{s_0,a_0,\\cdots,s_t,a_t}[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t)] Q(s_t,a_t)\\\\\n", + "&=\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t) Q_w(s_t,a_t)]\n", + "\\end{split}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "f4aeb4c2", + "metadata": {}, + "source": [ + "### subtract baseline" + ] + }, + { + "cell_type": "markdown", + "id": "cc792fc3", + "metadata": {}, + "source": [ + "$$\n", + "A(s_t,a_t) = Q_w(s_t,a_t)-V_v(s_t)\n", + "$$\n", + "\n", + "- using the V function as the baseline function, \n", + "- we subtract the $Q$ value term with the $V$ value.\n", + "- how much better it is to take a specific action compared to the average, general action at the given state. \n", + " - **advantage value**" + ] + }, + { + "cell_type": "markdown", + "id": "3a672c92", + "metadata": {}, + "source": [ + "$$\n", + "\\begin{split}\n", + "&Q(s_t,a_t)=\\mathbb E[r_{t+1}+\\gamma V(s_{t+1})]\\\\\n", + "&A(s_t,a_t)=r_{t+1}+\\gamma V_v(s_{t+1})-V_v(s_t)\n", + "\\end{split}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "f8261b90", + "metadata": {}, + "source": [ + "### Advantage Actor Critic (A2C)" + ] + }, + { + "cell_type": "markdown", + "id": "06839e87", + "metadata": {}, + "source": [ + "$$\n", + "\\begin{split}\n", + "\\nabla_\\theta J(\\theta)&=\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t) (Q_w(s_t,a_t)-V_v(s_t))]\\\\\n", + "&=\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t) A(s_t,a_t)]\\\\\n", + "&=\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t) \\left(r_{t+1}+\\gamma V_v(s_{t+1})-V_v(s_t)\\right)]\n", + "\\end{split}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "84e866dc", + "metadata": {}, + "source": [ + "## implemention" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ca7c2b9b", + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-02T14:44:43.218346Z", + "start_time": "2023-08-02T14:44:43.212181Z" + } + }, + "outputs": [], + "source": [ + "#!pip install -U gym==0.15.3" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "95577b00", + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-02T14:44:50.267749Z", + "start_time": "2023-08-02T14:44:48.612105Z" + } + }, + "outputs": [], + "source": [ + "import sys\n", + "import torch \n", + "import gym\n", + "import numpy as np \n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torch.nn.functional as F\n", + "from torch.autograd import Variable\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "e5e4bf5c", + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-02T14:45:03.631947Z", + "start_time": "2023-08-02T14:45:03.625967Z" + } + }, + "outputs": [], + "source": [ + "# hyperparameters\n", + "hidden_size = 256\n", + "learning_rate = 3e-4\n", + "\n", + "# Constants\n", + "GAMMA = 0.99\n", + "num_steps = 300\n", + "max_episodes = 3000" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": 
"7adb32c3", + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-02T15:02:01.173668Z", + "start_time": "2023-08-02T15:02:01.157049Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([-0.01507673, 0.00588999, 0.00869466, 0.00153444])" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "env = gym.make(\"CartPole-v0\")\n", + "\n", + "# 4-d 连续\n", + "num_inputs = env.observation_space.shape[0]\n", + "# 左右离散\n", + "num_actions = env.action_space.n\n", + "env.reset()" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "920fe998", + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-02T15:09:09.477744Z", + "start_time": "2023-08-02T15:09:09.464116Z" + } + }, + "outputs": [], + "source": [ + "class ActorCritic(nn.Module):\n", + " def __init__(self, num_inputs, num_actions, hidden_size):\n", + " super(ActorCritic, self).__init__()\n", + " self.num_actions = num_actions\n", + " \n", + " # num_inputs: shape of state\n", + " # critic nn: state => value\n", + " self.critic_ln1 = nn.Linear(num_inputs, hidden_size)\n", + " self.critic_ln2 = nn.Linear(hidden_size, 1)\n", + " \n", + " # actor nn: state => action, policy\n", + " self.actor_ln1 = nn.Linear(num_inputs, hidden_size)\n", + " self.actor_ln2 = nn.Linear(hidden_size, num_actions)\n", + " \n", + " def forward(self, state):\n", + " # (4, ) => (1, 4)\n", + " # ndarray => Variable\n", + " # state = Variable(torch.from_numpy(state).float().unsqueeze(0))\n", + " state = torch.tensor(state, requires_grad=True, dtype=torch.float32).unsqueeze(0)\n", + " \n", + " # forward of critic network\n", + " # (1, 4) => (1, 256)\n", + " value = F.relu(self.critic_ln1(state))\n", + " # (1, 256) => (1, 1)\n", + " value = self.critic_ln2(value)\n", + " \n", + " # (1, 4) => (1, 256)\n", + " policy_dist = F.relu(self.actor_ln1(state))\n", + " # (1, 256) => (1, 2)\n", + " policy_dist = F.softmax(self.actor_ln2(policy_dist), dim=1)\n", + " return value, policy_dist" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "b67cf056", + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-02T15:09:10.870382Z", + "start_time": "2023-08-02T15:09:10.859615Z" + } + }, + "outputs": [], + "source": [ + "ac = ActorCritic(num_inputs, num_actions, hidden_size, ) \n", + "ac_opt = optim.Adam(ac.parameters(), lr=learning_rate)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "8821ac79", + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-02T15:18:44.585795Z", + "start_time": "2023-08-02T15:18:44.541188Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 1]) tensor([[0.0335]], grad_fn=<AddmmBackward0>)\n", + "torch.Size([1, 2]) tensor([[0.5095, 0.4905]], grad_fn=<SoftmaxBackward0>)\n", + "() 0.03351228\n", + "(1, 2) [[0.5095114 0.49048853]]\n" + ] + } + ], + "source": [ + "for episode in range(max_episodes):\n", + " \n", + " # same length\n", + " # index means timestamp: t\n", + " log_probs = [] \n", + " values = []\n", + " rewards = []\n", + " \n", + " state = env.reset()\n", + " \n", + " for step in range(num_steps):\n", + " value, policy_dist = ac(state)\n", + " \n", + " print(value.shape, value)\n", + " print(policy_dist.shape, policy_dist)\n", + " \n", + " value = value.detach().numpy()[0, 0]\n", + " dist = policy_dist.detach().numpy()\n", + " \n", + " print(value.shape, value)\n", + " print(dist.shape, dist)\n", + " \n", + " action = np.random.choice(num_actions, 
```python
for episode in range(max_episodes):

    # the three lists stay the same length;
    # their index is the timestep t
    log_probs = []
    values = []
    rewards = []

    state = env.reset()

    for step in range(num_steps):
        value, policy_dist = ac(state)

        print(value.shape, value)
        print(policy_dist.shape, policy_dist)

        value = value.detach().numpy()[0, 0]
        dist = policy_dist.detach().numpy()

        print(value.shape, value)
        print(dist.shape, dist)

        action = np.random.choice(num_actions, p=np.squeeze(dist))
        log_prob = torch.log(policy_dist.squeeze(0)[action])

        new_state, reward, done, _ = env.step(action)

        rewards.append(reward)
        values.append(value)
        log_probs.append(log_prob)

        state = new_state

        if done or step == num_steps - 1:
            # TODO: compute the returns/advantage and update the networks here
            # (a sketch of this update follows below)
            break
        break  # debug: stop after a single step for now, to inspect the printed shapes
```

Output:

```
torch.Size([1, 1]) tensor([[0.0335]], grad_fn=<AddmmBackward0>)
torch.Size([1, 2]) tensor([[0.5095, 0.4905]], grad_fn=<SoftmaxBackward0>)
() 0.03351228
(1, 2) [[0.5095114 0.49048853]]
```
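The loop above stops before any parameter update. Below is a minimal sketch of what that update could look like, following the A2C gradient above; the helper name `a2c_update`, the `last_value` argument, and the assumption that `values` keeps the critic outputs as tensors (rather than the detached numpy scalars collected above, which could not backpropagate a critic loss) are assumptions, not part of the notebook:

```python
def a2c_update(optimizer, log_probs, values, rewards, last_value, gamma=GAMMA):
    """One A2C update from a finished rollout (sketch, see assumptions above).

    log_probs:  list of scalar tensors  log pi_theta(a_t|s_t)
    values:     list of (1, 1) tensors  V_v(s_t), kept with their graph
    rewards:    list of floats          r_{t+1}
    last_value: float bootstrap for the state after the rollout (0.0 if done)
    """
    # bootstrapped discounted returns: Q_t ~= r_{t+1} + gamma * Q_{t+1}
    returns, R = [], last_value
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)

    returns = torch.tensor(returns, dtype=torch.float32)   # (T,)
    values = torch.cat(values).view(-1)                    # (T,)
    log_probs = torch.stack(log_probs)                     # (T,)

    advantage = returns - values                           # A_t = Q_t - V_v(s_t)
    actor_loss = (-log_probs * advantage.detach()).mean()  # policy-gradient term
    critic_loss = 0.5 * advantage.pow(2).mean()            # value regression

    loss = actor_loss + critic_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

It would be called inside the `if done or step == num_steps - 1:` branch, e.g. with `last_value = 0.0` when the episode terminated and `last_value = ac(new_state)[0].item()` otherwise, before moving on to the next episode.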
