diff options
Diffstat (limited to 'rl/gym_demo/lunar/main.py')
| -rw-r--r-- | rl/gym_demo/lunar/main.py | 35 |
1 files changed, 35 insertions, 0 deletions
"""Train the ``dqn.Agent`` on gym's ``LunarLander-v2`` and plot the results.

Runs a fixed number of episodes, prints per-episode statistics, and hands the
collected scores and epsilon values to ``utils.plotLearning``.
"""
import gym
import numpy as np

from dqn import Agent
from utils import plotLearning


def _run_episode(env, agent):
    """Play one episode to completion and return its accumulated reward.

    After every environment step the transition is stored on the agent and
    ``agent.learn()`` is called, in that order.
    """
    total_reward = 0
    state = env.reset()
    finished = False
    while not finished:
        action = agent.choose_action(state)
        # step() is unpacked as a 4-tuple; the info entry is not used.
        next_state, reward, finished, _info = env.step(action)
        total_reward += reward
        agent.store_transition(state, action, reward, next_state, finished)
        agent.learn()
        state = next_state
    return total_reward


def main():
    """Run the training loop and write the learning-curve plot."""
    env = gym.make('LunarLander-v2')
    agent = Agent(
        gamma=0.99,
        epsilon=1.0,
        batch_size=64,
        n_actions=4,
        eps_end=0.01,
        input_dims=[8],
        lr=0.001,
    )
    n_games = 500
    scores = []
    eps_history = []

    for episode in range(n_games):
        score = _run_episode(env, agent)
        scores.append(score)
        eps_history.append(agent.epsilon)

        # Mean over the most recent (up to) 100 episode scores.
        avg_score = np.mean(scores[-100:])

        print('episode ', episode, 'score %.2f' % score,
              'average score %.2f' % avg_score,
              'epsilon %.2f' % agent.epsilon)

    # 1-based episode numbers passed as the first plotLearning argument.
    x = list(range(1, n_games + 1))
    filename = 'lunar_lander.png'
    plotLearning(x, scores, eps_history, filename)


if __name__ == '__main__':
    main()
