diff options
Diffstat (limited to 'rl/gym_demo/lunar/utils.py')
| -rw-r--r-- | rl/gym_demo/lunar/utils.py | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/rl/gym_demo/lunar/utils.py b/rl/gym_demo/lunar/utils.py new file mode 100644 index 0000000..e881c76 --- /dev/null +++ b/rl/gym_demo/lunar/utils.py @@ -0,0 +1,36 @@ +import matplotlib.pyplot as plt +import numpy as np +import gym + +def plotLearning(x, scores, epsilons, filename, lines=None): + fig=plt.figure() + ax=fig.add_subplot(111, label="1") + ax2=fig.add_subplot(111, label="2", frame_on=False) + + ax.plot(x, epsilons, color="C0") + ax.set_xlabel("Game", color="C0") + ax.set_ylabel("Epsilon", color="C0") + ax.tick_params(axis='x', colors="C0") + ax.tick_params(axis='y', colors="C0") + + N = len(scores) + running_avg = np.empty(N) + for t in range(N): + running_avg[t] = np.mean(scores[max(0, t-20):(t+1)]) + + ax2.scatter(x, running_avg, color="C1") + #ax2.xaxis.tick_top() + ax2.axes.get_xaxis().set_visible(False) + ax2.yaxis.tick_right() + #ax2.set_xlabel('x label 2', color="C1") + ax2.set_ylabel('Score', color="C1") + #ax2.xaxis.set_label_position('top') + ax2.yaxis.set_label_position('right') + #ax2.tick_params(axis='x', colors="C1") + ax2.tick_params(axis='y', colors="C1") + + if lines is not None: + for line in lines: + plt.axvline(x=line) + + plt.savefig(filename)
\ No newline at end of file |
