-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimages.py
42 lines (34 loc) · 1.37 KB
/
images.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import matplotlib.pyplot as plt
import numpy as np
from minatar import Environment
import matplotlib
def rolling(rewards, dist=500):
    """Return the trailing moving average of ``rewards``.

    Element ``i`` of the result is the mean of
    ``rewards[max(0, i - dist + 1) : i + 1]``, i.e. of the last ``dist``
    values up to and including index ``i``. While fewer than ``dist`` values
    have been seen, the mean covers every value so far.

    Args:
        rewards: sequence of numbers.
        dist: window length; must be at least 1.

    Returns:
        A list with one average per input element (empty for empty input).

    Raises:
        ValueError: if ``dist`` is less than 1.
    """
    from itertools import accumulate

    if dist < 1:
        raise ValueError(f"dist must be >= 1, got {dist}")
    # prefix[k] == sum(rewards[:k]), so every window sum is one subtraction and
    # the whole pass is O(len(rewards)) instead of O(len(rewards) * dist).
    # For float inputs the results may differ from a direct sum in the last
    # few bits; integer inputs are exact.
    prefix = [0, *accumulate(rewards)]
    rolled = []
    for end in range(1, len(prefix)):
        start = max(0, end - dist)
        rolled.append((prefix[end] - prefix[start]) / (end - start))
    return rolled
def show_rewards(title, train_rewards, test_rewards=None, quantity=15000, save_dir=None):
    """Plot raw and rolling-average reward curves, then display the figure.

    Rolling averages are computed with ``rolling`` over the full histories.
    Only afterwards are all series trimmed to their last ``quantity`` points,
    so the averages at the left edge of the plot still reflect earlier values.

    Line colours: raw train = pink, raw test = aqua, rolling train = red,
    rolling test = blue.

    Args:
        title: plot title; also used as the file stem when ``save_dir`` is set.
        train_rewards: sequence of training rewards, one per x-axis index.
        test_rewards: sequence of test rewards. Defaults to ``train_rewards``
            itself. The x-axis is derived from ``train_rewards`` only, so this
            is expected to have the same length.
        quantity: maximum number of trailing points to draw; trimming happens
            only when ``len(train_rewards) > quantity``.
        save_dir: optional directory. If given, the figure is also written to
            ``<save_dir>/<title>.png`` before it is shown.
    """
    # ``is None`` instead of ``== None``: for a NumPy array ``==`` compares
    # element-wise, and using that result in ``if`` raises ValueError.
    if test_rewards is None:
        test_rewards = train_rewards
    t = list(range(len(train_rewards)))
    rolled_train = rolling(train_rewards)
    rolled_test = rolling(test_rewards)
    if len(train_rewards) > quantity:
        # Trim after averaging so the visible averages use the full history.
        train_rewards = train_rewards[-quantity:]
        rolled_train = rolled_train[-quantity:]
        test_rewards = test_rewards[-quantity:]
        rolled_test = rolled_test[-quantity:]
        t = t[-quantity:]
    # Use a dedicated figure so nothing already on the current pyplot figure
    # ends up in this plot, and close exactly that figure afterwards.
    fig, ax = plt.subplots()
    ax.set_title(title)
    ax.plot(t, train_rewards, color="pink")
    ax.plot(t, test_rewards, color="aqua")
    ax.plot(t, rolled_train, color="red")
    ax.plot(t, rolled_test, color="blue")
    if save_dir is not None:
        from pathlib import Path

        # Save before show(): some backends clear the figure once it is shown.
        fig.savefig(Path(save_dir) / f"{title}.png")
    plt.show()
    plt.close(fig)