@@ -63,11 +63,24 @@ def _build_model_compute_statistics(fset_path, model_type, model_params,
63
63
if params_to_optimize :
64
64
model = GridSearchCV (model , params_to_optimize )
65
65
model .fit (fset , data ['labels' ])
66
- score = model .score (fset , data ['labels' ])
66
+
67
+ metrics = {}
68
+ metrics ['train_score' ] = model .score (fset , data ['labels' ])
69
+
67
70
best_params = model .best_params_ if params_to_optimize else {}
68
71
joblib .dump (model , model_path )
69
72
70
- return score , best_params
73
+ if model_type == 'RandomForestClassifier' :
74
+ if params_to_optimize :
75
+ model = model .best_estimator_
76
+ if hasattr (model , 'oob_score_' ):
77
+ metrics ['oob_score' ] = model .oob_score_
78
+ if hasattr (model , 'feature_importances_' ):
79
+ metrics ['feature_importances' ] = dict (zip (
80
+ fset .columns .get_level_values (0 ).tolist (),
81
+ model .feature_importances_ .tolist ()))
82
+
83
+ return metrics , best_params
71
84
72
85
73
86
class ModelHandler (BaseHandler ):
@@ -84,12 +97,12 @@ def get(self, model_id=None):
84
97
@auth_or_token
85
98
async def _await_model_statistics (self , model_stats_future , model ):
86
99
try :
87
- score , best_params = await model_stats_future
100
+ model_metrics , best_params = await model_stats_future
88
101
89
102
model = DBSession ().merge (model )
90
103
model .task_id = None
91
104
model .finished = datetime .datetime .now ()
92
- model .train_score = score
105
+ model .metrics = model_metrics
93
106
model .params .update (best_params )
94
107
DBSession ().commit ()
95
108
0 commit comments