summary | refs | log | tree | commit | diff
path: root/rl/gym_demo/lunar/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'rl/gym_demo/lunar/main.py')
-rw-r--r-- rl/gym_demo/lunar/main.py 35
1 files changed, 35 insertions, 0 deletions
diff --git a/rl/gym_demo/lunar/main.py b/rl/gym_demo/lunar/main.py
new file mode 100644
index 0000000..b718b71
--- /dev/null
+++ b/rl/gym_demo/lunar/main.py
@@ -0,0 +1,35 @@
+import gym
+from dqn import Agent
+from utils import plotLearning
+import numpy as np
+
+if __name__ == '__main__':  # script entry point: run n_games episodes, then plot
+    env = gym.make('LunarLander-v2')
+    agent = Agent(gamma=0.99, epsilon=1.0, batch_size=64, n_actions=4, eps_end=0.01,
+                  input_dims=[8], lr=0.001)
+    scores, eps_history = [], []  # per-episode total reward and agent.epsilon
+    n_games = 500
+
+    for i in range(n_games):
+        score = 0
+        done = False
+        observation = env.reset()
+        while not done:  # step until env.step reports done
+            action = agent.choose_action(observation)
+            observation_, reward, done, info = env.step(action)
+            score += reward
+            agent.store_transition(observation, action, reward,
+                                   observation_, done)
+            agent.learn()  # called once per environment step
+            observation = observation_  # the next observation becomes the current one
+        scores.append(score)
+        eps_history.append(agent.epsilon)
+
+        avg_score = np.mean(scores[-100:])  # mean over at most the last 100 episodes
+
+        print('episode ', i, 'score %.2f' % score,
+              'average score %.2f' % avg_score,
+              'epsilon %.2f' % agent.epsilon)
+    x = [i + 1 for i in range(n_games)]  # 1-based episode numbers
+    filename = 'lunar_lander.png'
+    plotLearning(x, scores, eps_history, filename)