"""Heuristic agent for the Gym ``MountainCar-v0`` environment.

NOTE(review): the file is named ``carl_pole.py`` but the code targets
MountainCar-v0 (3 discrete actions, 2-component observation), not
CartPole — consider renaming the file.

The agent uses a fixed, closed-form policy (no learning): it pushes
right while the car's velocity lies inside a position-dependent band,
otherwise it pushes left — presumably an energy-pumping heuristic
(TODO confirm against the source it was taken from).

Written against the classic Gym API (``env.step`` returning a 4-tuple);
newer gym/gymnasium releases return 5 values and would need adapting.
"""


class BespokeAgent:
    """Hand-crafted, non-learning policy for MountainCar-v0."""

    def __init__(self, env):
        # The policy is closed-form; ``env`` is accepted only so the
        # constructor matches the usual agent(env) convention.
        pass

    def decide(self, observation):
        """Choose an action from a ``(position, velocity)`` observation.

        Returns 2 (push right) when the velocity falls strictly between
        the position-dependent lower and upper bounds, else 0 (push left).
        Action 1 (no push) is never used.
        """
        position, velocity = observation
        # Lower/upper velocity bounds as polynomials in position; the
        # constants define the hand-tuned switching surface.
        lb = min(-0.09 * (position + 0.25) ** 2 + 0.03,
                 0.3 * (position + 0.9) ** 4 - 0.008)
        ub = -0.07 * (position + 0.38) ** 2 + 0.07
        return 2 if lb < velocity < ub else 0

    def learn(self, *args):
        """No-op: the policy is fixed and does not learn."""
        pass


def play(i, agent, env, render=True, train=False):
    """Run one full episode and return ``(i, episode_reward)``.

    Args:
        i: Episode index, echoed back in the return value and printed.
        agent: Object providing ``decide(observation)`` and
            ``learn(observation, action, reward, done)``.
        env: Gym environment with the classic 4-tuple ``step`` API.
        render: Render each frame when True.
        train: Call ``agent.learn`` after every step when True
            (a no-op for ``BespokeAgent``).
    """
    episode_reward = 0
    observation = env.reset()
    while True:
        if render:
            env.render()
        action = agent.decide(observation)
        next_observation, reward, done, _ = env.step(action)
        episode_reward += reward
        if train:
            agent.learn(observation, action, reward, done)
        if done:
            # Close the render window at episode end before breaking out.
            env.close()
            break
        observation = next_observation
    print(i, episode_reward)
    return i, episode_reward


if __name__ == '__main__':
    # gym is third-party and only needed by the script entry point, so
    # import it here: the agent and play() stay importable without it.
    import gym

    env = gym.make('MountainCar-v0')
    agent = BespokeAgent(env)
    rewards = [play(i, agent, env) for i in range(100)]
    print(rewards)