summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--rl/tutorials/actor_critic.ipynb183
1 files changed, 147 insertions, 36 deletions
diff --git a/rl/tutorials/actor_critic.ipynb b/rl/tutorials/actor_critic.ipynb
index 8ea81dc..90e61bb 100644
--- a/rl/tutorials/actor_critic.ipynb
+++ b/rl/tutorials/actor_critic.ipynb
@@ -168,7 +168,8 @@
"\\begin{split}\n",
"\\nabla_\\theta J(\\theta)&=\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t) (Q_w(s_t,a_t)-V_v(s_t))]\\\\\n",
"&=\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t) A(s_t,a_t)]\\\\\n",
- "&=\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t) \\left(r_{t+1}+\\gamma V_v(s_{t+1})-V_v(s_t)\\right)]\n",
+ "&=\\mathbb E_\\tau[\\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t) \\left(r_{t+1}+\\gamma V_v(s_{t+1})-V_v(s_t)\\right)]\\\\\n",
+ "&\\sim \\sum_{t=0}^{T-1}\\nabla_\\theta\\log\\pi_\\theta(a_t|s_t)A(s_t,a_t)\n",
"\\end{split}\n",
"$$"
]
@@ -198,12 +199,12 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 1,
"id": "95577b00",
"metadata": {
"ExecuteTime": {
- "end_time": "2023-08-02T14:44:50.267749Z",
- "start_time": "2023-08-02T14:44:48.612105Z"
+ "end_time": "2023-08-14T12:23:02.885907Z",
+ "start_time": "2023-08-14T12:23:01.085301Z"
}
},
"outputs": [],
@@ -222,12 +223,12 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 2,
"id": "e5e4bf5c",
"metadata": {
"ExecuteTime": {
- "end_time": "2023-08-02T14:45:03.631947Z",
- "start_time": "2023-08-02T14:45:03.625967Z"
+ "end_time": "2023-08-14T12:23:04.055208Z",
+ "start_time": "2023-08-14T12:23:04.051613Z"
}
},
"outputs": [],
@@ -244,22 +245,22 @@
},
{
"cell_type": "code",
- "execution_count": 33,
+ "execution_count": 3,
"id": "7adb32c3",
"metadata": {
"ExecuteTime": {
- "end_time": "2023-08-02T15:02:01.173668Z",
- "start_time": "2023-08-02T15:02:01.157049Z"
+ "end_time": "2023-08-14T12:23:05.500821Z",
+ "start_time": "2023-08-14T12:23:05.450333Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
- "array([-0.01507673, 0.00588999, 0.00869466, 0.00153444])"
+ "array([0.00262747, 0.04244922, 0.03201457, 0.0220412 ])"
]
},
- "execution_count": 33,
+ "execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@@ -276,12 +277,12 @@
},
{
"cell_type": "code",
- "execution_count": 37,
+ "execution_count": 4,
"id": "920fe998",
"metadata": {
"ExecuteTime": {
- "end_time": "2023-08-02T15:09:09.477744Z",
- "start_time": "2023-08-02T15:09:09.464116Z"
+ "end_time": "2023-08-14T12:23:08.880810Z",
+ "start_time": "2023-08-14T12:23:08.869140Z"
}
},
"outputs": [],
@@ -321,12 +322,12 @@
},
{
"cell_type": "code",
- "execution_count": 38,
+ "execution_count": 19,
"id": "b67cf056",
"metadata": {
"ExecuteTime": {
- "end_time": "2023-08-02T15:09:10.870382Z",
- "start_time": "2023-08-02T15:09:10.859615Z"
+ "end_time": "2023-08-14T12:49:34.881302Z",
+ "start_time": "2023-08-14T12:49:34.874061Z"
}
},
"outputs": [],
@@ -337,12 +338,12 @@
},
{
"cell_type": "code",
- "execution_count": 41,
+ "execution_count": 20,
"id": "8821ac79",
"metadata": {
"ExecuteTime": {
- "end_time": "2023-08-02T15:18:44.585795Z",
- "start_time": "2023-08-02T15:18:44.541188Z"
+ "end_time": "2023-08-14T12:51:16.845696Z",
+ "start_time": "2023-08-14T12:49:37.970632Z"
}
},
"outputs": [
@@ -350,14 +351,46 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "torch.Size([1, 1]) tensor([[0.0335]], grad_fn=<AddmmBackward0>)\n",
- "torch.Size([1, 2]) tensor([[0.5095, 0.4905]], grad_fn=<SoftmaxBackward0>)\n",
- "() 0.03351228\n",
- "(1, 2) [[0.5095114 0.49048853]]\n"
+ "episode: 0, reward: 15.0, steps: 14, average steps: 14.0\n",
+ "episode: 100, reward: 28.0, steps: 27, average steps: 29.9\n",
+ "episode: 200, reward: 38.0, steps: 37, average steps: 47.9\n",
+ "episode: 300, reward: 42.0, steps: 41, average steps: 34.3\n",
+ "episode: 400, reward: 53.0, steps: 52, average steps: 33.8\n",
+ "episode: 500, reward: 59.0, steps: 58, average steps: 70.3\n",
+ "episode: 600, reward: 74.0, steps: 73, average steps: 82.4\n",
+ "episode: 700, reward: 47.0, steps: 46, average steps: 114.2\n",
+ "episode: 800, reward: 181.0, steps: 180, average steps: 134.0\n",
+ "episode: 900, reward: 110.0, steps: 109, average steps: 119.2\n",
+ "episode: 1000, reward: 200.0, steps: 199, average steps: 143.2\n",
+ "episode: 1100, reward: 99.0, steps: 98, average steps: 108.0\n",
+ "episode: 1200, reward: 173.0, steps: 172, average steps: 152.3\n",
+ "episode: 1300, reward: 179.0, steps: 178, average steps: 128.4\n",
+ "episode: 1400, reward: 200.0, steps: 199, average steps: 130.3\n",
+ "episode: 1500, reward: 162.0, steps: 161, average steps: 132.8\n",
+ "episode: 1600, reward: 84.0, steps: 83, average steps: 140.9\n",
+ "episode: 1700, reward: 200.0, steps: 199, average steps: 193.7\n",
+ "episode: 1800, reward: 129.0, steps: 128, average steps: 135.4\n",
+ "episode: 1900, reward: 126.0, steps: 125, average steps: 144.8\n",
+ "episode: 2000, reward: 48.0, steps: 47, average steps: 148.7\n",
+ "episode: 2100, reward: 200.0, steps: 199, average steps: 179.0\n",
+ "episode: 2200, reward: 190.0, steps: 189, average steps: 196.2\n",
+ "episode: 2300, reward: 200.0, steps: 199, average steps: 160.0\n",
+ "episode: 2400, reward: 200.0, steps: 199, average steps: 199.0\n",
+ "episode: 2500, reward: 200.0, steps: 199, average steps: 195.8\n",
+ "episode: 2600, reward: 200.0, steps: 199, average steps: 193.0\n",
+ "episode: 2700, reward: 200.0, steps: 199, average steps: 186.4\n",
+ "episode: 2800, reward: 105.0, steps: 104, average steps: 161.9\n",
+ "episode: 2900, reward: 200.0, steps: 199, average steps: 187.7\n"
]
}
],
"source": [
+ "\n",
+ "# len == max_episodes\n",
+ "all_rewards = []\n",
+ "all_steps = []\n",
+ "ma_steps = []\n",
+ "\n",
"for episode in range(max_episodes):\n",
" \n",
" # same length\n",
@@ -365,21 +398,23 @@
" log_probs = [] \n",
" values = []\n",
" rewards = []\n",
+ " entropy_term = 0\n",
" \n",
" state = env.reset()\n",
" \n",
" for step in range(num_steps):\n",
" value, policy_dist = ac(state)\n",
" \n",
- " print(value.shape, value)\n",
- " print(policy_dist.shape, policy_dist)\n",
+ "# print(value.shape, value)\n",
+ "# print(policy_dist.shape, policy_dist)\n",
" \n",
" value = value.detach().numpy()[0, 0]\n",
" dist = policy_dist.detach().numpy()\n",
" \n",
- " print(value.shape, value)\n",
- " print(dist.shape, dist)\n",
+ "# print(value.shape, value)\n",
+ "# print(dist.shape, dist)\n",
" \n",
+ "        # Sample the action index (0 or 1) at random according to the policy probabilities\n",
" action = np.random.choice(num_actions, p=np.squeeze(dist))\n",
" log_prob = torch.log(policy_dist.squeeze(0)[action])\n",
" \n",
@@ -389,19 +424,90 @@
" values.append(value)\n",
" log_probs.append(log_prob)\n",
" \n",
+ " entropy_term += -np.sum(dist * np.log(dist))\n",
+ "# print(dist, np.log(dist), entropy_term)\n",
+ " \n",
" state = new_state\n",
" \n",
" if done or step == num_steps - 1:\n",
- " \n",
- " \n",
- " break\n",
- " break"
+ " q_value, _ = ac(new_state)\n",
+ " q_value = q_value.detach().numpy()[0, 0]\n",
+ " all_rewards.append(np.sum(rewards))\n",
+ " all_steps.append(step)\n",
+ " ma_steps.append(np.mean(all_steps[-10:]))\n",
+ " if episode % 100 == 0:\n",
+ " print(f'episode: {episode}, reward: {np.sum(rewards)}, steps: {step}, average steps: {ma_steps[-1]}')\n",
+ " break\n",
+ " \n",
+ "    # Build the bootstrapped discounted return for every step of the episode\n",
+ "    q_values = np.zeros_like(values)\n",
+ "    for t in reversed(range(len(values))):\n",
+ "        # Backward recursion: G_t = r_t + GAMMA * G_{t+1}, seeded with the critic's\n",
+ "        # value of the last state; use this step's reward, not always rewards[-1]\n",
+ "        q_value = rewards[t] + GAMMA * q_value\n",
+ "        q_values[t] = q_value\n",
+ " \n",
+ " values = torch.FloatTensor(values)\n",
+ " q_values = torch.FloatTensor(q_values)\n",
+ " log_probs = torch.stack(log_probs)\n",
+ "    # NOTE(review): values and entropy_term come from detached numpy data, so critic_loss and the entropy term add no gradient\n",
+ "    advantage = q_values - values\n",
+ " \n",
+ " # scalar objective to minimize\n",
+ " actor_loss = -(log_probs*advantage).mean()\n",
+ " critic_loss = 0.5*advantage.pow(2).mean()\n",
+ " loss = actor_loss + critic_loss + 0.001*entropy_term\n",
+ " \n",
+ " ac_opt.zero_grad()\n",
+ " loss.backward()\n",
+ " ac_opt.step()"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 27,
"id": "ec50f7eb",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-08-14T12:53:59.718552Z",
+ "start_time": "2023-08-14T12:53:59.492891Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[<matplotlib.lines.Line2D at 0x7f4cd00fbbb0>]"
+ ]
+ },
+ "execution_count": 27,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ "<Figure size 640x480 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "# Smooth the per-episode rewards with a 10-episode rolling mean;\n",
+ "# the leading entries are NaN until the window has filled.\n",
+ "ma_rewards = pd.Series(all_rewards).rolling(10).mean()\n",
+ "plt.plot(list(ma_rewards))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0cc63d5a",
"metadata": {},
"outputs": [],
"source": []
@@ -434,9 +540,14 @@
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
- "toc_position": {},
+ "toc_position": {
+ "height": "calc(100% - 180px)",
+ "left": "10px",
+ "top": "150px",
+ "width": "336px"
+ },
"toc_section_display": true,
- "toc_window_display": false
+ "toc_window_display": true
}
},
"nbformat": 4,