diff options
| author | zhang <zch921005@126.com> | 2022-08-21 09:41:53 +0800 |
|---|---|---|
| committer | zhang <zch921005@126.com> | 2022-08-21 09:41:53 +0800 |
| commit | 94b6d3246c72eb3cae58a2fd18771e3c2c3e7cb2 (patch) | |
| tree | e23a287289cdca8020fab062cdebffa60b021d7f /rl/gym_demo/lunar/utils.py | |
| parent | 756b736ca374dc6ef2adadce101f380e10f06c4e (diff) | |
copy
Diffstat (limited to 'rl/gym_demo/lunar/utils.py')
| -rw-r--r-- | rl/gym_demo/lunar/utils.py | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/rl/gym_demo/lunar/utils.py b/rl/gym_demo/lunar/utils.py new file mode 100644 index 0000000..e881c76 --- /dev/null +++ b/rl/gym_demo/lunar/utils.py @@ -0,0 +1,36 @@ +import matplotlib.pyplot as plt +import numpy as np +import gym + +def plotLearning(x, scores, epsilons, filename, lines=None): + fig=plt.figure() + ax=fig.add_subplot(111, label="1") + ax2=fig.add_subplot(111, label="2", frame_on=False) + + ax.plot(x, epsilons, color="C0") + ax.set_xlabel("Game", color="C0") + ax.set_ylabel("Epsilon", color="C0") + ax.tick_params(axis='x', colors="C0") + ax.tick_params(axis='y', colors="C0") + + N = len(scores) + running_avg = np.empty(N) + for t in range(N): + running_avg[t] = np.mean(scores[max(0, t-20):(t+1)]) + + ax2.scatter(x, running_avg, color="C1") + #ax2.xaxis.tick_top() + ax2.axes.get_xaxis().set_visible(False) + ax2.yaxis.tick_right() + #ax2.set_xlabel('x label 2', color="C1") + ax2.set_ylabel('Score', color="C1") + #ax2.xaxis.set_label_position('top') + ax2.yaxis.set_label_position('right') + #ax2.tick_params(axis='x', colors="C1") + ax2.tick_params(axis='y', colors="C1") + + if lines is not None: + for line in lines: + plt.axvline(x=line) + + plt.savefig(filename)
\ No newline at end of file |
