 basics/python/default_mutable_parameters.ipynb | 428
 fun_math/r_w.py                                |  66
 fun_math/random_walk.py                        |  14
 rl/gym_demo/carl_pole.py                       |  60
 rl/gym_demo/lunar/dqn.py                       | 112
 rl/gym_demo/lunar/main.py                      |  35
 rl/gym_demo/lunar/utils.py                     |  36
 rl/gym_demo/taxi.py                            |   0
 8 files changed, 715 insertions(+), 36 deletions(-)
diff --git a/basics/python/default_mutable_parameters.ipynb b/basics/python/default_mutable_parameters.ipynb
new file mode 100644
index 0000000..c133f86
--- /dev/null
+++ b/basics/python/default_mutable_parameters.ipynb
@@ -0,0 +1,428 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "## 1. Case 1: a mutable default argument (a list)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-21T01:27:26.778739Z",
+ "start_time": "2022-08-21T01:27:26.765873Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def add_employee(emp, emp_list = []):\n",
+ " emp_list.append(emp)\n",
+ " print(emp_list)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-21T01:27:36.118235Z",
+ "start_time": "2022-08-21T01:27:36.113245Z"
+ },
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['zhang']\n"
+ ]
+ }
+ ],
+ "source": [
+ "add_employee('zhang')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-21T01:27:39.612702Z",
+ "start_time": "2022-08-21T01:27:39.601724Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['zhang', 'li']\n"
+ ]
+ }
+ ],
+ "source": [
+ "add_employee('li')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-21T01:28:01.433747Z",
+ "start_time": "2022-08-21T01:28:01.430249Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['li']\n"
+ ]
+ }
+ ],
+ "source": [
+ "add_employee('li', [])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-21T01:28:15.988491Z",
+ "start_time": "2022-08-21T01:28:15.984261Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['zhang', 'li', 'li']\n"
+ ]
+ }
+ ],
+ "source": [
+ "add_employee('li')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-21T01:30:02.765148Z",
+ "start_time": "2022-08-21T01:30:02.744801Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(['zhang', 'li', 'li'],)"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "add_employee.__defaults__"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-21T01:30:44.627317Z",
+ "start_time": "2022-08-21T01:30:44.623666Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def add_employee(emp, emp_list = []):\n",
+ " print(add_employee.__defaults__)\n",
+ " emp_list.append(emp)\n",
+ " print(emp_list)\n",
+ " print(add_employee.__defaults__)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-21T01:30:46.009755Z",
+ "start_time": "2022-08-21T01:30:46.002688Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "([],)\n",
+ "['zhang']\n",
+ "(['zhang'],)\n"
+ ]
+ }
+ ],
+ "source": [
+ "add_employee('zhang')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-21T01:31:09.233929Z",
+ "start_time": "2022-08-21T01:31:09.230161Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(['zhang'],)\n",
+ "['zhang', 'li']\n",
+ "(['zhang', 'li'],)\n"
+ ]
+ }
+ ],
+ "source": [
+ "add_employee('li')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-21T01:31:45.175683Z",
+ "start_time": "2022-08-21T01:31:45.172977Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(['zhang', 'li'],)\n",
+ "['li']\n",
+ "(['zhang', 'li'],)\n"
+ ]
+ }
+ ],
+ "source": [
+ "add_employee('li', [])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-21T01:33:18.260279Z",
+ "start_time": "2022-08-21T01:33:18.254488Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def add_employee_fixed(emp, emp_list = None):\n",
+ " if emp_list is None:\n",
+ " emp_list = []\n",
+ " emp_list.append(emp)\n",
+ " print(emp_list)\n",
+ " print(add_employee_fixed.__defaults__)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-21T01:33:31.485064Z",
+ "start_time": "2022-08-21T01:33:31.481392Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(None,)\n",
+ "['zhang']\n",
+ "(None,)\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(add_employee_fixed.__defaults__)\n",
+ "add_employee_fixed('zhang')"
+ ]
+ },
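+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The default list is created once, when the `def` statement runs, and is stored in `add_employee.__defaults__`; every call that omits `emp_list` appends to that same object. Using a `None` sentinel and creating the list inside the function gives each call its own fresh list."
+   ]
+  },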
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "## 2. Case 2: a default expression evaluated at definition time"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-21T01:28:52.580680Z",
+ "start_time": "2022-08-21T01:28:52.576724Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from datetime import datetime\n",
+ "import time"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-21T01:34:32.671514Z",
+ "start_time": "2022-08-21T01:34:32.668551Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def print_time(time_to_print=datetime.now()):\n",
+ " print(time_to_print.strftime('%b %d, %Y %H:%M:%S'))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-21T01:34:48.345468Z",
+ "start_time": "2022-08-21T01:34:48.341096Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(datetime.datetime(2022, 8, 21, 9, 34, 32, 669555),)"
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "print_time.__defaults__"
+ ]
+ },
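+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "`datetime.now()` is evaluated only once, when `print_time` is defined, so the timestamp frozen in `__defaults__` above is reused by every call that omits the argument. The `None`-sentinel version below recomputes the time on each call."
+   ]
+  },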
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-21T01:35:35.002020Z",
+ "start_time": "2022-08-21T01:35:34.998605Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def print_time_fixed(time_to_print=None):\n",
+ " if time_to_print is None:\n",
+ " time_to_print = datetime.now()\n",
+ " print(time_to_print.strftime('%b %d, %Y %H:%M:%S'))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-21T01:35:39.491909Z",
+ "start_time": "2022-08-21T01:35:39.488710Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Aug 21, 2022 09:35:39\n"
+ ]
+ }
+ ],
+ "source": [
+ "print_time_fixed()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-21T01:35:43.785204Z",
+ "start_time": "2022-08-21T01:35:43.782140Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Aug 21, 2022 09:35:43\n"
+ ]
+ }
+ ],
+ "source": [
+ "print_time_fixed()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/fun_math/r_w.py b/fun_math/r_w.py
new file mode 100644
index 0000000..c1eb3ed
--- /dev/null
+++ b/fun_math/r_w.py
@@ -0,0 +1,66 @@
+
+import matplotlib.pyplot as plt
+import matplotlib.animation as animation
+import numpy as np
+
+np.random.seed(1234)
+
+
+def random_walk(N):
+ """
+ Simulates a discrete random walk
+ :param int N : the number of steps to take
+ """
+ # event space: set of possible increments
+ increments = np.array([1, -1])
+    # the probability of taking a +1 step
+    p = 0.5
+
+    # the epsilon values (per-step increments); p must be passed by keyword,
+    # since the third positional argument of np.random.choice is `replace`
+    random_increments = np.random.choice(increments, N, p=[p, 1 - p])
+ # calculate the random walk
+ random_walk = np.cumsum(random_increments)
+
+ # return the entire walk and the increments
+ return random_walk, random_increments
+
+
+# generate a random walk
+N = 500
+X, epsilon = random_walk(N)
+
+# normalize the random walk using the Central Limit Theorem
+# X = X * np.sqrt(1. / N)
+
+fig = plt.figure(figsize=(21, 10))
+ax = plt.axes(xlim=(0, N), ylim=(np.min(X) - 0.5, np.max(X) + 0.5))
+line, = ax.plot([], [], lw=2, color='#0492C2')
+ax.set_xticks(np.arange(0, N+1, 50))
+ax.set_yticks(np.arange(np.min(X) - 0.5, np.max(X) + 0.5, 0.2))
+ax.set_title('2D Random Walk', fontsize=22)
+ax.set_xlabel('Steps', fontsize=18)
+ax.set_ylabel('Value', fontsize=18)
+ax.tick_params(labelsize=16)
+ax.grid(True, which='major', linestyle='--', color='black', alpha=0.4)
+
+# initialization function
+def init():
+ # creating an empty plot/frame
+ line.set_data([], [])
+ return line,
+
+# lists to store x and y axis points
+xdata, ydata = [], []
+
+# animation function
+def animate(i):
+ y = X[i]
+ # appending new points to x, y axes points list
+ xdata.append(i)
+ ydata.append(y)
+ line.set_data(xdata, ydata)
+ return line,
+
+# call the animator
+anim = animation.FuncAnimation(fig, animate, init_func=init, frames=N, interval=20, blit=True)
+anim.save('random_walk.gif', writer='imagemagick')
\ No newline at end of file
diff --git a/fun_math/random_walk.py b/fun_math/random_walk.py
new file mode 100644
index 0000000..b425fcb
--- /dev/null
+++ b/fun_math/random_walk.py
@@ -0,0 +1,14 @@
+
+import random
+
+
+# N, S, W, E
+choices = [(0, 1), (0, -1), (-1, 0), (1, 0)]
+
+
+def rand_walk(n, p0=None):
+ if p0 is None:
+ p0 = [0, 0]
+ pass
+
+
diff --git a/rl/gym_demo/carl_pole.py b/rl/gym_demo/carl_pole.py
index 146c337..dd22d1d 100644
--- a/rl/gym_demo/carl_pole.py
+++ b/rl/gym_demo/carl_pole.py
@@ -2,46 +2,34 @@
import gym
import numpy as np
-class BespokeAgent:
- def __init__(self, env):
- pass
+env_name = 'CartPole-v1'
- def decide(self, observation):
- position, velocity = observation
- lb = min(-0.09*(position + 0.25) ** 2 + 0.03, 0.3*(position + 0.9)**4 - 0.008)
- ub = -0.07*(position + 0.38) ** 2 + 0.07
- if lb < velocity < ub:
- action = 2
- else:
- action = 0
- # print('observation: {}, lb: {}, ub: {} => action: {}'.format(observation, lb, ub, action))
- return action
+env = gym.make(env_name)
- def learn(self, *argg):
- pass
+class Agent:
+ def __init__(self, env):
+ self.action_size = env.action_space.n
-def play(i, agent, env, render=True, train=False):
- episode_reward = 0
- observation = env.reset()
- while True:
- if render:
- env.render()
- action = agent.decide(observation)
- next_observation, reward, done, _ = env.step(action)
- episode_reward += reward
- if train:
- agent.learn(observation, action, reward, done)
- if done:
- env.close()
- break
- observation = next_observation
- print(i, episode_reward)
- return i, episode_reward
+ def action_policy(self, observation):
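+        # simple heuristic: push the cart toward the side the pole leans (action 0 = left, 1 = right)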
+ pos, vel, angle, _ = observation
+ if angle < 0:
+ return 0
+ return 1
if __name__ == '__main__':
- env = gym.make('MountainCar-v0')
- agent = BespokeAgent(env)
- rewards = [play(i, agent, env) for i in range(100)]
- print(rewards)
+
+ observation = env.reset()
+ agent = Agent(env)
+ reward_history = []
+ for _ in range(100):
+ # env.render()
+ # action = agent.action_policy(observation)
+ action = env.action_space.sample()
+ observation, reward, done, info = env.step(action)
+ reward_history.append(reward)
+ if done:
+ # env.env.close()
+ env.reset()
+ print(reward_history, np.mean(reward_history))
diff --git a/rl/gym_demo/lunar/dqn.py b/rl/gym_demo/lunar/dqn.py
new file mode 100644
index 0000000..a757a92
--- /dev/null
+++ b/rl/gym_demo/lunar/dqn.py
@@ -0,0 +1,112 @@
+import torch as T
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import numpy as np
+
+
+class DeepQNetwork(nn.Module):
+ def __init__(self, lr, input_dims, fc1_dims, fc2_dims,
+ n_actions):
+ super(DeepQNetwork, self).__init__()
+ self.input_dims = input_dims
+ self.fc1_dims = fc1_dims
+ self.fc2_dims = fc2_dims
+ self.n_actions = n_actions
+ self.fc1 = nn.Linear(*self.input_dims, self.fc1_dims)
+ self.fc2 = nn.Linear(self.fc1_dims, self.fc2_dims)
+ self.fc3 = nn.Linear(self.fc2_dims, self.n_actions)
+
+ self.optimizer = optim.Adam(self.parameters(), lr=lr)
+ self.loss = nn.MSELoss()
+ self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
+ self.to(self.device)
+
+ def forward(self, state):
+ x = F.relu(self.fc1(state))
+ x = F.relu(self.fc2(x))
+ actions = self.fc3(x)
+
+ return actions
+
+
+class Agent:
+ def __init__(self, gamma, epsilon, lr, input_dims, batch_size, n_actions,
+ max_mem_size=100000, eps_end=0.05, eps_dec=5e-4):
+ self.gamma = gamma
+ self.epsilon = epsilon
+ self.eps_min = eps_end
+ self.eps_dec = eps_dec
+ self.lr = lr
+ self.action_space = [i for i in range(n_actions)]
+ self.mem_size = max_mem_size
+ self.batch_size = batch_size
+ self.mem_cntr = 0
+ self.iter_cntr = 0
+ self.replace_target = 100
+
+ self.Q_eval = DeepQNetwork(lr, n_actions=n_actions,
+ input_dims=input_dims,
+ fc1_dims=256, fc2_dims=256)
+ self.state_memory = np.zeros((self.mem_size, *input_dims),
+ dtype=np.float32)
+ self.new_state_memory = np.zeros((self.mem_size, *input_dims),
+ dtype=np.float32)
+ self.action_memory = np.zeros(self.mem_size, dtype=np.int32)
+ self.reward_memory = np.zeros(self.mem_size, dtype=np.float32)
+        self.terminal_memory = np.zeros(self.mem_size, dtype=bool)
+
+ def store_transition(self, state, action, reward, state_, terminal):
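+        # write the transition into the circular replay buffer, overwriting the oldest entry once full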
+ index = self.mem_cntr % self.mem_size
+ self.state_memory[index] = state
+ self.new_state_memory[index] = state_
+ self.reward_memory[index] = reward
+ self.action_memory[index] = action
+ self.terminal_memory[index] = terminal
+
+ self.mem_cntr += 1
+
+ def choose_action(self, observation):
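+        # epsilon-greedy: exploit the current Q-network with probability 1 - epsilon, otherwise act randomly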
+ if np.random.random() > self.epsilon:
+ state = T.tensor([observation]).to(self.Q_eval.device)
+ actions = self.Q_eval.forward(state)
+ action = T.argmax(actions).item()
+ else:
+ action = np.random.choice(self.action_space)
+
+ return action
+
+ def learn(self):
+ if self.mem_cntr < self.batch_size:
+ return
+
+ self.Q_eval.optimizer.zero_grad()
+
+ max_mem = min(self.mem_cntr, self.mem_size)
+
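+        # sample a minibatch of transition indices (without replacement) from the filled part of the buffer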
+ batch = np.random.choice(max_mem, self.batch_size, replace=False)
+ batch_index = np.arange(self.batch_size, dtype=np.int32)
+
+ state_batch = T.tensor(self.state_memory[batch]).to(self.Q_eval.device)
+ new_state_batch = T.tensor(
+ self.new_state_memory[batch]).to(self.Q_eval.device)
+ action_batch = self.action_memory[batch]
+ reward_batch = T.tensor(
+ self.reward_memory[batch]).to(self.Q_eval.device)
+ terminal_batch = T.tensor(
+ self.terminal_memory[batch]).to(self.Q_eval.device)
+
+ q_eval = self.Q_eval.forward(state_batch)[batch_index, action_batch]
+ q_next = self.Q_eval.forward(new_state_batch)
+ q_next[terminal_batch] = 0.0
+
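+        # TD target: r + gamma * max_a' Q(s', a'); terminal next-states were zeroed above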
+ q_target = reward_batch + self.gamma*T.max(q_next, dim=1)[0]
+
+ loss = self.Q_eval.loss(q_target, q_eval).to(self.Q_eval.device)
+ loss.backward()
+ self.Q_eval.optimizer.step()
+
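+        # decay epsilon linearly down to eps_min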
+ self.iter_cntr += 1
+ self.epsilon = self.epsilon - self.eps_dec \
+ if self.epsilon > self.eps_min else self.eps_min
+
diff --git a/rl/gym_demo/lunar/main.py b/rl/gym_demo/lunar/main.py
new file mode 100644
index 0000000..b718b71
--- /dev/null
+++ b/rl/gym_demo/lunar/main.py
@@ -0,0 +1,35 @@
+import gym
+from dqn import Agent
+from utils import plotLearning
+import numpy as np
+
+if __name__ == '__main__':
+ env = gym.make('LunarLander-v2')
+ agent = Agent(gamma=0.99, epsilon=1.0, batch_size=64, n_actions=4, eps_end=0.01,
+ input_dims=[8], lr=0.001)
+ scores, eps_history = [], []
+ n_games = 500
+
+ for i in range(n_games):
+ score = 0
+ done = False
+ observation = env.reset()
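+        # run one episode: act, store the transition, and learn from a replay minibatch at every step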
+ while not done:
+ action = agent.choose_action(observation)
+ observation_, reward, done, info = env.step(action)
+ score += reward
+ agent.store_transition(observation, action, reward,
+ observation_, done)
+ agent.learn()
+ observation = observation_
+ scores.append(score)
+ eps_history.append(agent.epsilon)
+
+ avg_score = np.mean(scores[-100:])
+
+ print('episode ', i, 'score %.2f' % score,
+ 'average score %.2f' % avg_score,
+ 'epsilon %.2f' % agent.epsilon)
+ x = [i + 1 for i in range(n_games)]
+ filename = 'lunar_lander.png'
+ plotLearning(x, scores, eps_history, filename)
diff --git a/rl/gym_demo/lunar/utils.py b/rl/gym_demo/lunar/utils.py
new file mode 100644
index 0000000..e881c76
--- /dev/null
+++ b/rl/gym_demo/lunar/utils.py
@@ -0,0 +1,36 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import gym
+
+def plotLearning(x, scores, epsilons, filename, lines=None):
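+    # overlay two y-axes on one figure: epsilon per game (left) and a running-average score over the last ~20 games (right)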
+    fig = plt.figure()
+    ax = fig.add_subplot(111, label="1")
+    ax2 = fig.add_subplot(111, label="2", frame_on=False)
+
+ ax.plot(x, epsilons, color="C0")
+ ax.set_xlabel("Game", color="C0")
+ ax.set_ylabel("Epsilon", color="C0")
+ ax.tick_params(axis='x', colors="C0")
+ ax.tick_params(axis='y', colors="C0")
+
+ N = len(scores)
+ running_avg = np.empty(N)
+ for t in range(N):
+ running_avg[t] = np.mean(scores[max(0, t-20):(t+1)])
+
+ ax2.scatter(x, running_avg, color="C1")
+ #ax2.xaxis.tick_top()
+ ax2.axes.get_xaxis().set_visible(False)
+ ax2.yaxis.tick_right()
+ #ax2.set_xlabel('x label 2', color="C1")
+ ax2.set_ylabel('Score', color="C1")
+ #ax2.xaxis.set_label_position('top')
+ ax2.yaxis.set_label_position('right')
+ #ax2.tick_params(axis='x', colors="C1")
+ ax2.tick_params(axis='y', colors="C1")
+
+ if lines is not None:
+ for line in lines:
+ plt.axvline(x=line)
+
+    plt.savefig(filename)
\ No newline at end of file
diff --git a/rl/gym_demo/taxi.py b/rl/gym_demo/taxi.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/rl/gym_demo/taxi.py