Diffstat (limited to 'rl/gym_demo/carl_pole.py')
| -rw-r--r-- | rl/gym_demo/carl_pole.py | 60 |
1 file changed, 24 insertions, 36 deletions
diff --git a/rl/gym_demo/carl_pole.py b/rl/gym_demo/carl_pole.py
index 146c337..dd22d1d 100644
--- a/rl/gym_demo/carl_pole.py
+++ b/rl/gym_demo/carl_pole.py
@@ -2,46 +2,34 @@
 import gym
 import numpy as np
 
-class BespokeAgent:
-    def __init__(self, env):
-        pass
+env_name = 'CartPole-v1'
 
-    def decide(self, observation):
-        position, velocity = observation
-        lb = min(-0.09*(position + 0.25) ** 2 + 0.03, 0.3*(position + 0.9)**4 - 0.008)
-        ub = -0.07*(position + 0.38) ** 2 + 0.07
-        if lb < velocity < ub:
-            action = 2
-        else:
-            action = 0
-        # print('observation: {}, lb: {}, ub: {} => action: {}'.format(observation, lb, ub, action))
-        return action
+env = gym.make(env_name)
 
-    def learn(self, *argg):
-        pass
+class Agent:
+    def __init__(self, env):
+        self.action_size = env.action_space.n
 
-def play(i, agent, env, render=True, train=False):
-    episode_reward = 0
-    observation = env.reset()
-    while True:
-        if render:
-            env.render()
-        action = agent.decide(observation)
-        next_observation, reward, done, _ = env.step(action)
-        episode_reward += reward
-        if train:
-            agent.learn(observation, action, reward, done)
-        if done:
-            env.close()
-            break
-        observation = next_observation
-    print(i, episode_reward)
-    return i, episode_reward
+    def action_policy(self, observation):
+        pos, vel, angle, _ = observation
+        if angle < 0:
+            return 0
+        return 1
 
 if __name__ == '__main__':
-    env = gym.make('MountainCar-v0')
-    agent = BespokeAgent(env)
-    rewards = [play(i, agent, env) for i in range(100)]
-    print(rewards)
+
+    observation = env.reset()
+    agent = Agent(env)
+    reward_history = []
+    for _ in range(100):
+        # env.render()
+        # action = agent.action_policy(observation)
+        action = env.action_space.sample()
+        observation, reward, done, info = env.step(action)
+        reward_history.append(reward)
+        if done:
+            # env.env.close()
+            env.reset()
+    print(reward_history, np.mean(reward_history))
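
For context, the committed script samples random actions and appends per-step rewards across episode boundaries; the angle-based Agent.action_policy it defines is left commented out. The sketch below (not part of this commit) shows one way that policy could drive the loop with rewards summed per episode instead. It assumes the same classic Gym API used in carl_pole.py (env.reset() returns the observation, env.step() returns four values); names such as episode_returns are illustrative only.

import gym
import numpy as np

env = gym.make('CartPole-v1')

class Agent:
    def __init__(self, env):
        self.action_size = env.action_space.n

    def action_policy(self, observation):
        # Push the cart toward the side the pole is leaning to.
        pos, vel, angle, angle_vel = observation
        return 0 if angle < 0 else 1

if __name__ == '__main__':
    agent = Agent(env)
    episode_returns = []          # one total reward per episode, not per step
    for _ in range(100):
        observation = env.reset()
        episode_return = 0.0
        done = False
        while not done:
            action = agent.action_policy(observation)
            observation, reward, done, info = env.step(action)
            episode_return += reward
        episode_returns.append(episode_return)
    print(episode_returns, np.mean(episode_returns))
    env.close()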
