-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathE2P_analyze_real.py
60 lines (46 loc) · 1.41 KB
/
E2P_analyze_real.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
Plot E2 - Visualize classification results for real-world streams
"""
import os

import numpy as np
import matplotlib.pyplot as plt
# Global NumPy seed. Nothing in this script draws random numbers, so it has
# no effect on the figure; it is kept so the module's global state is unchanged.
np.random.seed(1233)

# Measure groups -- one heatmap row each, plotted top to bottom in this order.
measures = [
    "clustering",
    "complexity",
    "concept",
    "general",
    "info-theory",
    "itemset",
    "landmarking",
    "model-based",
    "statistical",
]

# Real-world streams -- one subplot each (2 x 3 grid), filled in this order.
real_streams = [
    "covtypeNorm-1-2vsAll",
    "electricity",
    "poker-lsn-1-2vsAll",
    "INSECTS-abrupt",
    "INSECTS-gradual",
    "INSECTS-incremental",
]

# Base classifiers -- one heatmap column each, left to right.
base_clfs = ["GNB", "KNN", "SVM", "DT", "MLP"]

# Not referenced anywhere else in this script.
n_drift_types, stream_reps = 3, 5
# Load the classification scores and build one (measures x classifiers)
# heatmap per real-world stream, saved to figures/fig_clf/real.png.
res = np.load('results/real_clf.npy')  # measures, datasets, reps, folds, clfs

# Average over every axis between the dataset axis (1) and the trailing
# classifier axis. For the documented 5-D layout this averages reps and
# folds; for a 4-D array it is exactly np.mean(res, axis=2).
res_mean = np.mean(res, axis=tuple(range(2, res.ndim - 1)))
print(res_mean.shape)

# Each per-stream slice below must be 2-D (measures x classifiers).
if res_mean.ndim != 3:
    raise ValueError(
        "Expected results with at least (measures, datasets, ..., clfs) axes; "
        f"got array of shape {res.shape}"
    )

fig, ax = plt.subplots(2, 3, figsize=(13, 11), sharex=True, sharey=True)
ax = ax.ravel()
plt.suptitle('Real-world', fontsize=18)

for dataset_id, dataset in enumerate(real_streams):
    axx = ax[dataset_id]
    r = res_mean[:, dataset_id]  # rows: measures, columns: classifiers
    axx.imshow(r, vmin=0.05, vmax=1., cmap='Blues')

    # Print each cell's score; use white text on the darker (>= 0.5) cells
    # so it stays legible against the Blues colormap.
    for row, _measure in enumerate(measures):
        for col, _clf in enumerate(base_clfs):
            value = r[row, col]
            axx.text(col, row, f"{value:.3f}", va='center', ha='center',
                     c='black' if value < 0.5 else 'white', fontsize=11)

    axx.set_title(dataset)
    axx.set_xticks(np.arange(len(base_clfs)), base_clfs)
    axx.set_yticks(np.arange(len(measures)), measures)

plt.tight_layout()
# Create the output directory so savefig does not fail on a fresh checkout.
os.makedirs('figures/fig_clf', exist_ok=True)
plt.savefig('figures/fig_clf/real.png')