Skip to content

Commit 4dfa9e3

Browse files
committed
Add OOB score and feature importance chart to displayed model metrics
1 parent 96fb017 commit 4dfa9e3

File tree

5 files changed

+61
-9
lines changed

5 files changed

+61
-9
lines changed

cesium_app/handlers/model.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,24 @@ def _build_model_compute_statistics(fset_path, model_type, model_params,
6363
if params_to_optimize:
6464
model = GridSearchCV(model, params_to_optimize)
6565
model.fit(fset, data['labels'])
66-
score = model.score(fset, data['labels'])
66+
67+
metrics = {}
68+
metrics['train_score'] = model.score(fset, data['labels'])
69+
6770
best_params = model.best_params_ if params_to_optimize else {}
6871
joblib.dump(model, model_path)
6972

70-
return score, best_params
73+
if model_type == 'RandomForestClassifier':
74+
if params_to_optimize:
75+
model = model.best_estimator_
76+
if hasattr(model, 'oob_score_'):
77+
metrics['oob_score'] = model.oob_score_
78+
if hasattr(model, 'feature_importances_'):
79+
metrics['feature_importances'] = dict(zip(
80+
fset.columns.get_level_values(0).tolist(),
81+
model.feature_importances_.tolist()))
82+
83+
return metrics, best_params
7184

7285

7386
class ModelHandler(BaseHandler):
@@ -84,12 +97,12 @@ def get(self, model_id=None):
8497
@auth_or_token
8598
async def _await_model_statistics(self, model_stats_future, model):
8699
try:
87-
score, best_params = await model_stats_future
100+
model_metrics, best_params = await model_stats_future
88101

89102
model = DBSession().merge(model)
90103
model.task_id = None
91104
model.finished = datetime.datetime.now()
92-
model.train_score = score
105+
model.metrics = model_metrics
93106
model.params.update(best_params)
94107
DBSession().commit()
95108

cesium_app/models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class Model(Base):
8989
file_uri = sa.Column(sa.String(), nullable=True, index=True)
9090
task_id = sa.Column(sa.String())
9191
finished = sa.Column(sa.DateTime)
92-
train_score = sa.Column(sa.Float)
92+
metrics = sa.Column(sa.JSON, nullable=True)
9393

9494
featureset = relationship('Featureset')
9595
project = relationship('Project')

package.json

+2
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
"bokehjs": "^0.12.5",
1010
"bootstrap": "^3.3.7",
1111
"bootstrap-css": "^3.0.0",
12+
"chart.js": "^2.7.1",
1213
"css-loader": "^0.26.2",
1314
"exports-loader": "^0.6.4",
1415
"imports-loader": "^0.7.1",
1516
"jquery": "^3.1.1",
1617
"prop-types": "^15.5.10",
1718
"react": "^15.1.0",
19+
"react-chartjs-2": "^2.7.0",
1820
"react-dom": "^15.1.0",
1921
"react-redux": "^5.0.3",
2022
"react-tabs": "^0.8.2",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import React from 'react';
2+
import { HorizontalBar } from 'react-chartjs-2';
3+
4+
5+
const FeatureImportancesBarchart = props => {
6+
const sorted_features = Object.keys(props.data).sort(
7+
(a, b) => props.data[b] - props.data[a]).slice(0, 15);
8+
const values = sorted_features.map(
9+
feature => props.data[feature].toFixed(3));
10+
const data = {
11+
labels: sorted_features,
12+
datasets: [
13+
{
14+
label: 'Feature Importance',
15+
backgroundColor: '#2222ff',
16+
hoverBackgroundColor: '#5555ff',
17+
data: values
18+
}
19+
]
20+
};
21+
22+
return (
23+
<div>
24+
<HorizontalBar data={data} />
25+
</div>
26+
);
27+
};
28+
29+
export default FeatureImportancesBarchart;

static/js/components/Models.jsx

+12-4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import Expand from './Expand';
1111
import Delete from './Delete';
1212
import { $try, reformatDatetime } from '../utils';
1313
import FoldableRow from './FoldableRow';
14+
import FeatureImportances from './FeatureImportances';
1415

1516

1617
const ModelsTab = props => (
@@ -169,7 +170,7 @@ let ModelInfo = props => (
169170
<tr>
170171
<th>Model Type</th>
171172
<th>Hyperparameters</th>
172-
<th>Training Data Score</th>
173+
{Object.keys(props.model.metrics).map(metric => <th>{metric}</th>)}
173174
</tr>
174175
</thead>
175176
<tbody>
@@ -191,9 +192,16 @@ let ModelInfo = props => (
191192
</tbody>
192193
</table>
193194
</td>
194-
<td>
195-
{props.model.train_score}
196-
</td>
195+
{
196+
Object.keys(props.model.metrics).map(metric => (
197+
<td>
198+
{
199+
metric == 'feature_importances' ?
200+
<FeatureImportances data={props.model.metrics[metric]} /> :
201+
props.model.metrics[metric]
202+
}
203+
</td>))
204+
}
197205
</tr>
198206
</tbody>
199207
</table>

0 commit comments

Comments
 (0)