summaryrefslogtreecommitdiff
path: root/rl/tutorials/01_env.ipynb
diff options
context:
space:
mode:
authorchzhang <zch921005@126.com>2022-12-04 19:48:02 +0800
committerchzhang <zch921005@126.com>2022-12-04 19:48:02 +0800
commitc50696c00293adf872785ac048f00fd995eb54d2 (patch)
tree9d4bb59ba7f02db651b150216fd2d00e12fff30e /rl/tutorials/01_env.ipynb
parent8a4203f66b826fc82b481e2f999cc0816e366d76 (diff)
rl envs
Diffstat (limited to 'rl/tutorials/01_env.ipynb')
-rw-r--r--rl/tutorials/01_env.ipynb367
1 files changed, 367 insertions, 0 deletions
diff --git a/rl/tutorials/01_env.ipynb b/rl/tutorials/01_env.ipynb
new file mode 100644
index 0000000..2202a91
--- /dev/null
+++ b/rl/tutorials/01_env.ipynb
@@ -0,0 +1,367 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. 依赖"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-12-04T11:43:09.702221Z",
+ "start_time": "2022-12-04T11:43:09.500450Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import gym"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. 环境(以 cartpole 为例)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-12-04T11:43:10.882322Z",
+ "start_time": "2022-12-04T11:43:10.873917Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "env_name = 'CartPole-v0'\n",
+ "env = gym.make(env_name)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-12-04T10:39:54.826089Z",
+ "start_time": "2022-12-04T10:39:54.823921Z"
+ }
+ },
+ "source": [
+ "### 2.1 env 成员"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- 环境定义了动作空间及状态空间\n",
+ "- 此外还需要(step(action)):\n",
+ "$$\n",
+ "\\begin{split}\n",
+ "&R(s_t, a_t)=r_t\\\\\n",
+ "&P(s_t,a_t)=s_{t+1}\n",
+ "\\end{split}\n",
+ "$$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-12-04T11:36:42.119972Z",
+ "start_time": "2022-12-04T11:36:42.107426Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "<TimeLimit<CartPoleEnv<CartPole-v0>>>"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "env"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-12-04T11:36:49.966685Z",
+ "start_time": "2022-12-04T11:36:49.962150Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Discrete(2)"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "env.action_space"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-12-04T11:37:07.789565Z",
+ "start_time": "2022-12-04T11:37:07.785005Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "0"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "env.action_space.sample()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-12-04T11:37:20.687378Z",
+ "start_time": "2022-12-04T11:37:20.683155Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Box(-3.4028234663852886e+38, 3.4028234663852886e+38, (4,), float32)"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "env.observation_space"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-12-04T11:37:36.789515Z",
+ "start_time": "2022-12-04T11:37:36.785299Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38],\n",
+ " dtype=float32)"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "env.observation_space.low"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-12-04T11:37:46.991799Z",
+ "start_time": "2022-12-04T11:37:46.987407Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38],\n",
+ " dtype=float32)"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "env.observation_space.high"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-12-04T11:37:55.306094Z",
+ "start_time": "2022-12-04T11:37:55.300230Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([ 1.6534705e+00, -1.3081517e+38, -2.1108967e-01, 1.9319253e+38],\n",
+ " dtype=float32)"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "env.observation_space.sample()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.2 用 action 与 env 交互"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- 是一个 loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-12-04T11:40:48.246501Z",
+ "start_time": "2022-12-04T11:40:48.203554Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "env.step??"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-12-04T11:41:52.201203Z",
+ "start_time": "2022-12-04T11:41:49.284589Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "total reward: 21.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "done = False\n",
+ "score = 0\n",
+ "state = env.reset()\n",
+ "\n",
+ "while not done:\n",
+ " env.render()\n",
+ " action = env.action_space.sample()\n",
+ " observation, reward, done, info = env.step(action)\n",
+ " score += reward\n",
+ "print(f'total reward: {score}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-12-04T11:43:16.749930Z",
+ "start_time": "2022-12-04T11:43:12.658103Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "epoch: 1, total reward: 29.0\n",
+ "epoch: 2, total reward: 22.0\n",
+ "epoch: 3, total reward: 9.0\n",
+ "epoch: 4, total reward: 34.0\n",
+ "epoch: 5, total reward: 23.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "for epoch in range(1, 5+1):\n",
+ " done = False\n",
+ " score = 0\n",
+ " state = env.reset()\n",
+ "\n",
+ " while not done:\n",
+ " env.render()\n",
+ " action = env.action_space.sample()\n",
+ " observation, reward, done, info = env.step(action)\n",
+ " score += reward\n",
+ " print(f'epoch: {epoch}, total reward: {score}')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}