Diffstat (limited to 'rl/gym_demo/carl_pole.py')
-rw-r--r--  rl/gym_demo/carl_pole.py  60
1 file changed, 24 insertions, 36 deletions
diff --git a/rl/gym_demo/carl_pole.py b/rl/gym_demo/carl_pole.py
index 146c337..dd22d1d 100644
--- a/rl/gym_demo/carl_pole.py
+++ b/rl/gym_demo/carl_pole.py
@@ -2,46 +2,34 @@
import gym
import numpy as np
-class BespokeAgent:
-    def __init__(self, env):
-        pass
+env_name = 'CartPole-v1'
-    def decide(self, observation):
-        position, velocity = observation
-        lb = min(-0.09*(position + 0.25) ** 2 + 0.03, 0.3*(position + 0.9)**4 - 0.008)
-        ub = -0.07*(position + 0.38) ** 2 + 0.07
-        if lb < velocity < ub:
-            action = 2
-        else:
-            action = 0
-        # print('observation: {}, lb: {}, ub: {} => action: {}'.format(observation, lb, ub, action))
-        return action
+env = gym.make(env_name)
-    def learn(self, *argg):
-        pass
+class Agent:
+    def __init__(self, env):
+        self.action_size = env.action_space.n
-def play(i, agent, env, render=True, train=False):
-    episode_reward = 0
-    observation = env.reset()
-    while True:
-        if render:
-            env.render()
-        action = agent.decide(observation)
-        next_observation, reward, done, _ = env.step(action)
-        episode_reward += reward
-        if train:
-            agent.learn(observation, action, reward, done)
-        if done:
-            env.close()
-            break
-        observation = next_observation
-    print(i, episode_reward)
-    return i, episode_reward
+    def action_policy(self, observation):
+        pos, vel, angle, _ = observation
+        if angle < 0:
+            return 0
+        return 1
if __name__ == '__main__':
-    env = gym.make('MountainCar-v0')
-    agent = BespokeAgent(env)
-    rewards = [play(i, agent, env) for i in range(100)]
-    print(rewards)
+
+    observation = env.reset()
+    agent = Agent(env)
+    reward_history = []
+    for _ in range(100):
+        # env.render()
+        # action = agent.action_policy(observation)
+        action = env.action_space.sample()
+        observation, reward, done, info = env.step(action)
+        reward_history.append(reward)
+        if done:
+            # env.env.close()
+            observation = env.reset()
+    print(reward_history, np.mean(reward_history))
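
For reference, here is a minimal sketch (not part of the commit) of how the new Agent.action_policy could drive complete episodes and record per-episode returns, rather than sampling random actions and logging per-step rewards as the committed loop does. It assumes the classic pre-0.26 gym API used in the diff, where env.reset() returns an observation and env.step() returns (observation, reward, done, info).

# Sketch, not part of the commit: evaluate the hand-coded policy per episode.
import gym
import numpy as np

env = gym.make('CartPole-v1')

class Agent:
    def __init__(self, env):
        self.action_size = env.action_space.n

    def action_policy(self, observation):
        # Push the cart toward the side the pole is leaning: left (0) if the
        # pole angle is negative, right (1) otherwise.
        pos, vel, angle, _ = observation
        return 0 if angle < 0 else 1

if __name__ == '__main__':
    agent = Agent(env)
    episode_returns = []
    for episode in range(100):
        observation = env.reset()
        episode_return = 0.0
        done = False
        while not done:
            action = agent.action_policy(observation)
            observation, reward, done, info = env.step(action)
            episode_return += reward
        episode_returns.append(episode_return)
    print(episode_returns, np.mean(episode_returns))
    env.close()

Tracking the return per episode, rather than per step, makes it easier to see whether a policy change actually helps, since CartPole's per-step reward is always 1.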