summaryrefslogtreecommitdiff
path: root/rl/gym_demo/lunar/utils.py
diff options
context:
space:
mode:
authorzhang <zch921005@126.com>2022-08-21 09:41:53 +0800
committerzhang <zch921005@126.com>2022-08-21 09:41:53 +0800
commit94b6d3246c72eb3cae58a2fd18771e3c2c3e7cb2 (patch)
treee23a287289cdca8020fab062cdebffa60b021d7f /rl/gym_demo/lunar/utils.py
parent756b736ca374dc6ef2adadce101f380e10f06c4e (diff)
copy
Diffstat (limited to 'rl/gym_demo/lunar/utils.py')
-rw-r--r--rl/gym_demo/lunar/utils.py36
1 files changed, 36 insertions, 0 deletions
diff --git a/rl/gym_demo/lunar/utils.py b/rl/gym_demo/lunar/utils.py
new file mode 100644
index 0000000..e881c76
--- /dev/null
+++ b/rl/gym_demo/lunar/utils.py
@@ -0,0 +1,36 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import gym
+
def _running_average(scores, window):
    """Return the trailing mean of ``scores`` at every index.

    Entry ``t`` is the mean of ``scores[max(0, t - window):t + 1]``, i.e. the
    score at ``t`` together with up to ``window`` scores preceding it.
    """
    return np.array(
        [np.mean(scores[max(0, t - window):t + 1])
         for t in range(len(scores))],
        dtype=float,
    )


def plotLearning(x, scores, epsilons, filename, lines=None, window=20):
    """Plot epsilon and the running-average score, then save the figure.

    Two axes are stacked at the same subplot position: the first draws
    ``epsilons`` against ``x`` as a line (left y-axis, labelled "Epsilon"),
    the second draws the running average of ``scores`` against ``x`` as a
    scatter plot (right y-axis, labelled "Score").

    Args:
        x: x-axis values shared by both series (the axis is labelled "Game").
        scores: sequence of per-game scores; must support slicing.
        epsilons: values plotted on the left axis, one per entry of ``x``.
        filename: destination passed to ``Figure.savefig``.
        lines: optional iterable of x positions at which to draw vertical
            lines.
        window: number of preceding scores averaged together with the current
            one (the default of 20 gives a span of up to 21 scores).

    Raises:
        ValueError: if ``window`` is negative.
    """
    if window < 0:
        raise ValueError("window must be non-negative")

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, label="1")
        # Second axes on top of the first; no frame so the first stays visible.
        ax2 = fig.add_subplot(111, label="2", frame_on=False)

        ax.plot(x, epsilons, color="C0")
        ax.set_xlabel("Game", color="C0")
        ax.set_ylabel("Epsilon", color="C0")
        ax.tick_params(axis='x', colors="C0")
        ax.tick_params(axis='y', colors="C0")

        ax2.scatter(x, _running_average(scores, window), color="C1")
        # Only the first axes shows an x-axis; the overlay keeps just its
        # y-axis, moved to the right-hand side.
        ax2.axes.get_xaxis().set_visible(False)
        ax2.yaxis.tick_right()
        ax2.set_ylabel('Score', color="C1")
        ax2.yaxis.set_label_position('right')
        ax2.tick_params(axis='y', colors="C1")

        if lines is not None:
            for line in lines:
                # ax2 was the implicit "current axes" targeted by plt.axvline.
                ax2.axvline(x=line)

        fig.savefig(filename)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate memory.
        plt.close(fig)