author	lanchunhui <zch921005@126.com>	2023-08-02 23:20:47 +0800
committer	lanchunhui <zch921005@126.com>	2023-08-02 23:20:47 +0800
commit	186a22521af4c1f56abfaf227d2e85ca03d24c5d (patch)
tree	ff3fe32ec55cd260ac5da599a164869a9e286dd7 /rl/tutorials
parent	5bff7e2bcc303bfa6caee9d0b95bc21540d4c279 (diff)
update: notes
Diffstat (limited to 'rl/tutorials')
-rw-r--r--  rl/tutorials/actor_critic.ipynb  390
1 files changed, 385 insertions, 5 deletions
diff --git a/rl/tutorials/actor_critic.ipynb b/rl/tutorials/actor_critic.ipynb
index 0b357d7..1e19940 100644
--- a/rl/tutorials/actor_critic.ipynb
+++ b/rl/tutorials/actor_critic.ipynb
@@ -16,12 +16,392 @@
"- references\n",
" - https://github.com/pytorch/examples/tree/main/reinforcement_learning\n",
" - https://towardsdatascience.com/understanding-actor-critic-methods-931b97b6df3f\n",
- " - https://lilianweng.github.io/posts/2018-04-08-policy-gradient/\n",
- "- Actor - Critic\n",
- " - Actor\n",
- " - The policy gradient method is also the “actor” part of Actor-Critic methods \n",
- " - Critic"
+ " - https://lilianweng.github.io/posts/2018-04-08-policy-gradient/"
]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4f3f3454",
+ "metadata": {},
+ "source": [
+ "### Policy gradient"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "23ecd402",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-08-02T14:08:09.463801Z",
+ "start_time": "2023-08-02T14:08:09.445731Z"
+ }
+ },
+ "source": [
+ "- REINFORCE: noisy gradients & high variance (of gradients)\n",
+ " - update the policy parameter ($\\theta$) through Monte Carlo updates (i.e. taking random samples)\n",
+ " - This introduces in inherent high variability in \n",
+ " - log probabilities (log of the policy distribution): $\\log\\pi_\\theta(𝑎_𝑡|𝑠_𝑡)$\n",
+ " - cumulative reward values: $G_t$\n",
+ " - because each trajectories during training can deviate from each other at great degrees.\n",
+ "- cumulative reward == 0\n",
+ " - The essence of policy gradient is increasing the probabilities for “good” actions and decreasing those of “bad” actions in the policy distribution;\n",
+ " - both “goods” and “bad” actions with will not be learned if the cumulative reward is 0.\n",
+ "\n",
+ "$$\n",
+ "\\nabla_\\theta J(\\theta)=\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t)G_t]\n",
+ "$$"
+ ]
+ },
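+ {
+ "cell_type": "markdown",
+ "id": "a1000001",
+ "metadata": {},
+ "source": [
+ "A minimal sketch (not taken from the referenced repos) of how the Monte Carlo return $G_t$ and the REINFORCE surrogate loss could be computed for one episode; the helper `reinforce_loss` and the dummy inputs are illustrative only."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a1000002",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "def reinforce_loss(log_probs, rewards, gamma=0.99):\n",
+ "    # discounted Monte Carlo returns G_t, accumulated backwards through the episode\n",
+ "    returns, G = [], 0.0\n",
+ "    for r in reversed(rewards):\n",
+ "        G = r + gamma * G\n",
+ "        returns.insert(0, G)\n",
+ "    returns = torch.tensor(returns)\n",
+ "    # surrogate loss: -sum_t log pi(a_t|s_t) * G_t (its gradient is the REINFORCE estimate)\n",
+ "    return -(torch.stack(log_probs) * returns).sum()\n",
+ "\n",
+ "# toy check with dummy log-probabilities and rewards\n",
+ "dummy_log_probs = [torch.log(torch.tensor(0.5, requires_grad=True)) for _ in range(3)]\n",
+ "print(reinforce_loss(dummy_log_probs, [1.0, 1.0, 1.0]))"
+ ]
+ },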
+ {
+ "cell_type": "markdown",
+ "id": "a0cfc6b9",
+ "metadata": {},
+ "source": [
+ "### introduce a baseline $b(s)$"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5fd3c0a8",
+ "metadata": {},
+ "source": [
+ "$$\n",
+ "\\nabla_\\theta J(\\theta)=\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t)(G_t-b(s_t))]\n",
+ "$$"
+ ]
+ },
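+ {
+ "cell_type": "markdown",
+ "id": "b1000001",
+ "metadata": {},
+ "source": [
+ "A hedged illustration of the baseline idea: here $b(s_t)$ is the simplest possible state-independent baseline, the mean return of the episode, rather than a learned value function; the helper name is illustrative."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b1000002",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "def reinforce_loss_with_baseline(log_probs, returns):\n",
+ "    # returns: 1-D tensor of precomputed G_t values for one episode\n",
+ "    baseline = returns.mean()            # crude state-independent baseline b\n",
+ "    advantages = returns - baseline      # (G_t - b) has lower variance than G_t\n",
+ "    return -(torch.stack(log_probs) * advantages).sum()"
+ ]
+ },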
+ {
+ "cell_type": "markdown",
+ "id": "bb95429f",
+ "metadata": {},
+ "source": [
+ "### Actor Critic"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "49ba68fc",
+ "metadata": {},
+ "source": [
+ "- Critic\n",
+ " - estimates the value function.\n",
+ " - action-value: $Q$ value\n",
+ " - state-value: $V$ value\n",
+ " - average general action value at the given state\n",
+ " - $Q_w(s_t,a_t)$ => Critic neural network,回归一个 value 值;\n",
+ " - Q Actor Critic\n",
+ "- Actor\n",
+ " - The policy gradient method is also the “actor” part of Actor-Critic methods \n",
+ "\n",
+ "- both the Critic and Actor functions are parameterized with neural networks. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ddc3152b",
+ "metadata": {},
+ "source": [
+ "\n",
+ "$$\n",
+ "\\begin{split}\n",
+ "\\nabla_\\theta J(\\theta)&=\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t)G_t]\\\\\n",
+ "&=\\mathbb E_{s_0,a_0,\\cdots,s_t,a_t}[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t)] \\mathbb E_{r_{t+1},s_{t+1},\\cdots,r_T,s_T}[G_t]\\\\\n",
+ "&=\\mathbb E_{s_0,a_0,\\cdots,s_t,a_t}[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t)] Q(s_t,a_t)\\\\\n",
+ "&=\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t) Q_w(s_t,a_t)]\n",
+ "\\end{split}\n",
+ "$$"
+ ]
+ },
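+ {
+ "cell_type": "markdown",
+ "id": "c1000001",
+ "metadata": {},
+ "source": [
+ "A small sketch of the corresponding actor update: the critic estimates $Q_w(s_t,a_t)$ replace the sampled return in the surrogate loss. The critic values are detached so only the actor is updated here; the helper name is illustrative."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c1000002",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "def q_actor_critic_actor_loss(log_probs, q_values):\n",
+ "    # q_values: list of critic estimates Q_w(s_t, a_t) for the visited (s_t, a_t)\n",
+ "    q = torch.stack(q_values).detach()   # treat the critic output as a constant target\n",
+ "    return -(torch.stack(log_probs) * q).sum()"
+ ]
+ },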
+ {
+ "cell_type": "markdown",
+ "id": "f4aeb4c2",
+ "metadata": {},
+ "source": [
+ "### subtract baseline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cc792fc3",
+ "metadata": {},
+ "source": [
+ "$$\n",
+ "A(s_t,a_t) = Q_w(s_t,a_t)-V_v(s_t)\n",
+ "$$\n",
+ "\n",
+ "- using the V function as the baseline function, \n",
+ "- we subtract the $Q$ value term with the $V$ value.\n",
+ "- how much better it is to take a specific action compared to the average, general action at the given state. \n",
+ " - **advantage value**"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3a672c92",
+ "metadata": {},
+ "source": [
+ "$$\n",
+ "\\begin{split}\n",
+ "&Q(s_t,a_t)=\\mathbb E[r_{t+1}+\\gamma V(s_{t+1})]\\\\\n",
+ "&A(s_t,a_t)=r_{t+1}+\\gamma V_v(s_{t+1})-V_v(s_t)\n",
+ "\\end{split}\n",
+ "$$"
+ ]
+ },
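+ {
+ "cell_type": "markdown",
+ "id": "d1000001",
+ "metadata": {},
+ "source": [
+ "Because $Q(s_t,a_t)$ can be bootstrapped as $r_{t+1}+\gamma V(s_{t+1})$, the advantage needs only a single transition and a single value network; a tiny numeric sketch with made-up critic outputs:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d1000002",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# one-step advantage estimate: A(s_t, a_t) = r_{t+1} + gamma * V(s_{t+1}) - V(s_t)\n",
+ "gamma = 0.99\n",
+ "v_s, v_s_next, reward = 1.20, 1.50, 1.0   # made-up critic outputs and reward\n",
+ "advantage = reward + gamma * v_s_next - v_s\n",
+ "print(advantage)  # 1.0 + 0.99 * 1.50 - 1.20 = 1.285"
+ ]
+ },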
+ {
+ "cell_type": "markdown",
+ "id": "f8261b90",
+ "metadata": {},
+ "source": [
+ "### Advantage Actor Critic (A2C)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "06839e87",
+ "metadata": {},
+ "source": [
+ "$$\n",
+ "\\begin{split}\n",
+ "\\nabla_\\theta J(\\theta)&=\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t) (Q_w(s_t,a_t)-V_v(s_t))]\\\\\n",
+ "&=\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t) A(s_t,a_t)]\\\\\n",
+ "&=\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t) \\left(r_{t+1}+\\gamma V_v(s_{t+1})-V_v(s_t)\\right)]\n",
+ "\\end{split}\n",
+ "$$"
+ ]
+ },
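+ {
+ "cell_type": "markdown",
+ "id": "e1000001",
+ "metadata": {},
+ "source": [
+ "A sketch (under assumptions, not taken from the reference code) of how this gradient is commonly turned into per-transition losses: the actor term uses the detached one-step advantage, and the critic is regressed toward the TD target; the function name and arguments are illustrative."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e1000002",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "def a2c_losses(log_prob, value, next_value, reward, gamma=0.99):\n",
+ "    # value, next_value: critic outputs V_v(s_t), V_v(s_{t+1}) as 0-d tensors\n",
+ "    td_target = reward + gamma * next_value.detach()\n",
+ "    advantage = td_target - value.detach()        # r + gamma*V(s_next) - V(s), no gradient\n",
+ "    actor_loss = -log_prob * advantage            # policy-gradient term\n",
+ "    critic_loss = F.mse_loss(value, td_target)    # move V_v(s_t) toward the TD target\n",
+ "    return actor_loss, critic_loss"
+ ]
+ },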
+ {
+ "cell_type": "markdown",
+ "id": "84e866dc",
+ "metadata": {},
+ "source": [
+ "## implemention"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "ca7c2b9b",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-08-02T14:44:43.218346Z",
+ "start_time": "2023-08-02T14:44:43.212181Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "#!pip install -U gym==0.15.3"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "95577b00",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-08-02T14:44:50.267749Z",
+ "start_time": "2023-08-02T14:44:48.612105Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "import torch \n",
+ "import gym\n",
+ "import numpy as np \n",
+ "import torch.nn as nn\n",
+ "import torch.optim as optim\n",
+ "import torch.nn.functional as F\n",
+ "from torch.autograd import Variable\n",
+ "import matplotlib.pyplot as plt\n",
+ "import pandas as pd"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "e5e4bf5c",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-08-02T14:45:03.631947Z",
+ "start_time": "2023-08-02T14:45:03.625967Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# hyperparameters\n",
+ "hidden_size = 256\n",
+ "learning_rate = 3e-4\n",
+ "\n",
+ "# Constants\n",
+ "GAMMA = 0.99\n",
+ "num_steps = 300\n",
+ "max_episodes = 3000"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "7adb32c3",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-08-02T15:02:01.173668Z",
+ "start_time": "2023-08-02T15:02:01.157049Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([-0.01507673, 0.00588999, 0.00869466, 0.00153444])"
+ ]
+ },
+ "execution_count": 33,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "env = gym.make(\"CartPole-v0\")\n",
+ "\n",
+ "# 4-d 连续\n",
+ "num_inputs = env.observation_space.shape[0]\n",
+ "# 左右离散\n",
+ "num_actions = env.action_space.n\n",
+ "env.reset()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "id": "920fe998",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-08-02T15:09:09.477744Z",
+ "start_time": "2023-08-02T15:09:09.464116Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "class ActorCritic(nn.Module):\n",
+ " def __init__(self, num_inputs, num_actions, hidden_size):\n",
+ " super(ActorCritic, self).__init__()\n",
+ " self.num_actions = num_actions\n",
+ " \n",
+ " # num_inputs: shape of state\n",
+ " # critic nn: state => value\n",
+ " self.critic_ln1 = nn.Linear(num_inputs, hidden_size)\n",
+ " self.critic_ln2 = nn.Linear(hidden_size, 1)\n",
+ " \n",
+ " # actor nn: state => action, policy\n",
+ " self.actor_ln1 = nn.Linear(num_inputs, hidden_size)\n",
+ " self.actor_ln2 = nn.Linear(hidden_size, num_actions)\n",
+ " \n",
+ " def forward(self, state):\n",
+ " # (4, ) => (1, 4)\n",
+ " # ndarray => Variable\n",
+ " # state = Variable(torch.from_numpy(state).float().unsqueeze(0))\n",
+ " state = torch.tensor(state, requires_grad=True, dtype=torch.float32).unsqueeze(0)\n",
+ " \n",
+ " # forward of critic network\n",
+ " # (1, 4) => (1, 256)\n",
+ " value = F.relu(self.critic_ln1(state))\n",
+ " # (1, 256) => (1, 1)\n",
+ " value = self.critic_ln2(value)\n",
+ " \n",
+ " # (1, 4) => (1, 256)\n",
+ " policy_dist = F.relu(self.actor_ln1(state))\n",
+ " # (1, 256) => (1, 2)\n",
+ " policy_dist = F.softmax(self.actor_ln2(policy_dist), dim=1)\n",
+ " return value, policy_dist"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "id": "b67cf056",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-08-02T15:09:10.870382Z",
+ "start_time": "2023-08-02T15:09:10.859615Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "ac = ActorCritic(num_inputs, num_actions, hidden_size, ) \n",
+ "ac_opt = optim.Adam(ac.parameters(), lr=learning_rate)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "id": "8821ac79",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-08-02T15:18:44.585795Z",
+ "start_time": "2023-08-02T15:18:44.541188Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([1, 1]) tensor([[0.0335]], grad_fn=<AddmmBackward0>)\n",
+ "torch.Size([1, 2]) tensor([[0.5095, 0.4905]], grad_fn=<SoftmaxBackward0>)\n",
+ "() 0.03351228\n",
+ "(1, 2) [[0.5095114 0.49048853]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "for episode in range(max_episodes):\n",
+ " \n",
+ " # same length\n",
+ " # index means timestamp: t\n",
+ " log_probs = [] \n",
+ " values = []\n",
+ " rewards = []\n",
+ " \n",
+ " state = env.reset()\n",
+ " \n",
+ " for step in range(num_steps):\n",
+ " value, policy_dist = ac(state)\n",
+ " \n",
+ " print(value.shape, value)\n",
+ " print(policy_dist.shape, policy_dist)\n",
+ " \n",
+ " value = value.detach().numpy()[0, 0]\n",
+ " dist = policy_dist.detach().numpy()\n",
+ " \n",
+ " print(value.shape, value)\n",
+ " print(dist.shape, dist)\n",
+ " \n",
+ " action = np.random.choice(num_actions, p=np.squeeze(dist))\n",
+ " log_prob = torch.log(policy_dist.squeeze(0)[action])\n",
+ " \n",
+ " new_state, reward, done, _ = env.step(action)\n",
+ " \n",
+ " rewards.append(reward)\n",
+ " values.append(value)\n",
+ " log_probs.append(log_prob)\n",
+ " \n",
+ " state = new_state\n",
+ " \n",
+ " if done or step == num_steps - 1:\n",
+ " \n",
+ " \n",
+ " break\n",
+ " break"
+ ]
+ },
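+ {
+ "cell_type": "markdown",
+ "id": "f1000001",
+ "metadata": {},
+ "source": [
+ "The loop above stops early for inspection. Below is a hedged sketch of the per-episode update that could fill in the marked spot, assuming `values` collects the critic tensor outputs (rather than the detached numpy copies) and reusing `ac_opt` and `GAMMA` from the earlier cells; `update_actor_critic` is an illustrative helper that uses the discounted return as the critic target."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f1000002",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "def update_actor_critic(log_probs, values, rewards, optimizer, gamma=0.99):\n",
+ "    # discounted returns, used here as Monte Carlo targets for the critic\n",
+ "    q_vals, q = [], 0.0\n",
+ "    for r in reversed(rewards):\n",
+ "        q = r + gamma * q\n",
+ "        q_vals.insert(0, q)\n",
+ "    q_vals = torch.tensor(q_vals, dtype=torch.float32)\n",
+ "    \n",
+ "    values = torch.stack(values).reshape(-1)   # V_v(s_t), gradient kept for the critic\n",
+ "    log_probs = torch.stack(log_probs)\n",
+ "    \n",
+ "    advantage = q_vals - values.detach()       # no gradient through the baseline\n",
+ "    actor_loss = -(log_probs * advantage).mean()\n",
+ "    critic_loss = F.mse_loss(values, q_vals)\n",
+ "    \n",
+ "    loss = actor_loss + critic_loss\n",
+ "    optimizer.zero_grad()\n",
+ "    loss.backward()\n",
+ "    optimizer.step()\n",
+ "    return loss.item()\n",
+ "\n",
+ "# intended call site (inside the loop above): update_actor_critic(log_probs, values, rewards, ac_opt, GAMMA)"
+ ]
+ },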
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ec50f7eb",
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {