@@ -46,7 +46,7 @@ class ModelBase:
46
46
clf_timing = Timing ()
47
47
48
48
def __init__ (self , ** kwargs ):
49
- self ._plot_label_dic = {}
49
+ self ._plot_label_dict = {}
50
50
self ._title = self ._name = None
51
51
self ._metrics , self ._available_metrics = [], {
52
52
"acc" : ClassifierBase .acc
@@ -135,9 +135,9 @@ def scatter2d(self, x, y, padding=0.5, title=None):
135
135
y_max += y_padding
136
136
137
137
if labels .ndim == 1 :
138
- if not self ._plot_label_dic :
139
- self ._plot_label_dic = {c : i for i , c in enumerate (set (labels ))}
140
- dic = self ._plot_label_dic
138
+ if not self ._plot_label_dict :
139
+ self ._plot_label_dict = {c : i for i , c in enumerate (set (labels ))}
140
+ dic = self ._plot_label_dict
141
141
n_label = len (dic )
142
142
labels = np .array ([dic [label ] for label in labels ])
143
143
else :
@@ -181,9 +181,9 @@ def scatter3d(self, x, y, padding=0.1, title=None):
181
181
182
182
def transform_arr (arr ):
183
183
if arr .ndim == 1 :
184
- _dic = {c : i for i , c in enumerate (set (arr ))}
185
- n_dim = len (_dic )
186
- arr = np .array ([_dic [label ] for label in arr ])
184
+ dic = {c : i for i , c in enumerate (set (arr ))}
185
+ n_dim = len (dic )
186
+ arr = np .array ([dic [label ] for label in arr ])
187
187
else :
188
188
n_dim = arr .shape [1 ]
189
189
arr = np .argmax (arr , axis = 1 )
@@ -197,15 +197,15 @@ def transform_arr(arr):
197
197
198
198
labels , n_label = transform_arr (labels )
199
199
colors = plt .cm .rainbow ([i / n_label for i in range (n_label )])[labels ]
200
- _indices = [labels == i for i in range (n_label )]
201
- _scatters = []
200
+ indices = [labels == i for i in range (n_label )]
201
+ scatters = []
202
202
fig = plt .figure ()
203
203
plt .title (title )
204
204
ax = fig .add_subplot (111 , projection = '3d' )
205
- for _index in _indices :
206
- _scatters .append (ax .scatter (axis [0 ][_index ], axis [1 ][_index ], axis [2 ][_index ], c = colors [_index ]))
207
- ax .legend (_scatters , ["$c_{}$" .format ("{" + str (i ) + "}" ) for i in range (len (_scatters ))],
208
- ncol = math .ceil (math .sqrt (len (_scatters ))), fontsize = 8 )
205
+ for _index in indices :
206
+ scatters .append (ax .scatter (axis [0 ][_index ], axis [1 ][_index ], axis [2 ][_index ], c = colors [_index ]))
207
+ ax .legend (scatters , ["$c_{}$" .format ("{" + str (i ) + "}" ) for i in range (len (scatters ))],
208
+ ncol = math .ceil (math .sqrt (len (scatters ))), fontsize = 8 )
209
209
plt .show ()
210
210
211
211
# Util
@@ -344,9 +344,9 @@ def get_base(_nx, _ny):
344
344
z = self .predict (base_matrix ).reshape ((nx , ny ))
345
345
346
346
if labels .ndim == 1 :
347
- if not self ._plot_label_dic :
348
- self ._plot_label_dic = {c : i for i , c in enumerate (set (labels ))}
349
- dic = self ._plot_label_dic
347
+ if not self ._plot_label_dict :
348
+ self ._plot_label_dict = {c : i for i , c in enumerate (set (labels ))}
349
+ dic = self ._plot_label_dict
350
350
n_label = len (dic )
351
351
labels = np .array ([dic [label ] for label in labels ])
352
352
else :
@@ -366,9 +366,9 @@ def get_base(_nx, _ny):
366
366
plt .contour (xf , yf , z , c = 'k-' , levels = [0 ])
367
367
plt .scatter (axis [0 ], axis [1 ], c = colors )
368
368
if emphasize is not None :
369
- _indices = np .array ([False ] * len (axis [0 ]))
370
- _indices [np .asarray (emphasize )] = True
371
- plt .scatter (axis [0 ][_indices ], axis [1 ][_indices ], s = 80 ,
369
+ indices = np .array ([False ] * len (axis [0 ]))
370
+ indices [np .asarray (emphasize )] = True
371
+ plt .scatter (axis [0 ][indices ], axis [1 ][indices ], s = 80 ,
372
372
facecolors = "None" , zorder = 10 )
373
373
if extra is not None :
374
374
plt .scatter (* np .asarray (extra ).T , s = 80 , zorder = 25 , facecolors = "red" )
@@ -411,9 +411,9 @@ def get_base(_nx, _ny):
411
411
print ("Drawing figures..." )
412
412
xy_xf , xy_yf = np .meshgrid (xf , yf , sparse = True )
413
413
if labels .ndim == 1 :
414
- if not self ._plot_label_dic :
415
- self ._plot_label_dic = {c : i for i , c in enumerate (set (labels ))}
416
- dic = self ._plot_label_dic
414
+ if not self ._plot_label_dict :
415
+ self ._plot_label_dict = {c : i for i , c in enumerate (set (labels ))}
416
+ dic = self ._plot_label_dict
417
417
n_label = len (dic )
418
418
labels = np .array ([dic [label ] for label in labels ])
419
419
else :
@@ -439,9 +439,9 @@ def get_base(_nx, _ny):
439
439
plt .contour (xf , yf , z , c = 'k-' , levels = [0 ])
440
440
plt .scatter (axis [0 ], axis [1 ], c = colors )
441
441
if emphasize is not None :
442
- _indices = np .array ([False ] * len (axis [0 ]))
443
- _indices [np .asarray (emphasize )] = True
444
- plt .scatter (axis [0 ][_indices ], axis [1 ][_indices ], s = 80 ,
442
+ indices = np .array ([False ] * len (axis [0 ]))
443
+ indices [np .asarray (emphasize )] = True
444
+ plt .scatter (axis [0 ][indices ], axis [1 ][indices ], s = 80 ,
445
445
facecolors = "None" , zorder = 10 )
446
446
if extra is not None :
447
447
plt .scatter (* np .asarray (extra ).T , s = 80 , zorder = 25 , facecolors = "red" )
@@ -503,9 +503,9 @@ def get_base(_nx, _ny, _nz):
503
503
504
504
def transform_arr (arr ):
505
505
if arr .ndim == 1 :
506
- _dic = {c : i for i , c in enumerate (set (arr ))}
507
- n_dim = len (_dic )
508
- arr = np .array ([_dic [label ] for label in arr ])
506
+ dic = {c : i for i , c in enumerate (set (arr ))}
507
+ n_dim = len (dic )
508
+ arr = np .array ([dic [label ] for label in arr ])
509
509
else :
510
510
n_dim = arr .shape [1 ]
511
511
arr = np .argmax (arr , axis = 1 )
@@ -566,9 +566,9 @@ def _draw(_ax, _x, _xf, _y, _yf, _z):
566
566
def _emphasize (_ax , axis0 , axis1 , _c ):
567
567
_ax .scatter (axis0 , axis1 , c = _c )
568
568
if emphasize is not None :
569
- _indices = np .array ([False ] * len (axis [0 ]))
570
- _indices [np .asarray (emphasize )] = True
571
- _ax .scatter (axis0 [_indices ], axis1 [_indices ], s = 80 ,
569
+ indices = np .array ([False ] * len (axis [0 ]))
570
+ indices [np .asarray (emphasize )] = True
571
+ _ax .scatter (axis0 [indices ], axis1 [indices ], s = 80 ,
572
572
facecolors = "None" , zorder = 10 )
573
573
574
574
def _extra (_ax , axis0 , axis1 , _c , _ex0 , _ex1 ):
0 commit comments