
Commit 12a666a

Model evaluator, produces right metrics for single classification model
1 parent 3987cf3 commit 12a666a

5 files changed: +44 −20 lines

5 files changed

+44
-20
lines changed

data_preprocessor.py

Lines changed: 22 additions & 8 deletions
@@ -15,6 +15,9 @@ def __init__(self, train_df: DataFrame = None, test_df: DataFrame = None):
         self.train_encoded_df = None
         self.test_encoded_df = None
 
+        self.one_hot_suffix = '_vec'
+        self.indexed_suffix = "_cat"
+
 
     def explore_factors(self):
         """Generates a dictionary of one pandas dataframe per column
@@ -88,22 +91,33 @@ def assemble_features(self, *columns, out_name='features'):
     def prepare_to_model(self, target_col: str, to_strip=' '):
         """Runs all cleaning and encoding steps to generate
         dataframes ready to use in modeling"""
+        # if target_col in self.factors:
+        #     target_col += indexed_suffix
         self.strip_columns(*self.factors, to_strip=to_strip)
-        self.string_index(*self.factors, suffix='_cat')
+        self.string_index(*self.factors, suffix=self.indexed_suffix)
         # one-hot encode indexed factors, except target
-        to_one_hot_encode = [fac + "_cat" for fac in self.factors if fac != target_col]
-        self.one_hot_encode(*to_one_hot_encode, suffix='_vec')
+        self.one_hot_encode(*self._one_hot_encode_columns(target_col), suffix=self.one_hot_suffix)
         # assemble all together with numeric columns into features (except target if it's numeric)
-        to_assemble = [col for col in self.numeric_columns if col != target_col]
-        to_assemble += [col for col, data_type in self.train_df.dtypes if "_cat_vec" in col]
-        self.assemble_features(*to_assemble)
-        if target_col in self.factors:
-            target_col += "_cat"
+        self.assemble_features(*self._columns_to_assemble(target_col))
 
+        if target_col in self.factors:
+            target_col += self.indexed_suffix
         self.train_encoded_df = self._select_to_model(self.train_df, target_col)
         self.test_encoded_df = self._select_to_model(self.test_df, target_col)
 
 
+    def _one_hot_encode_columns(self, target_col):
+        return [fac + self.indexed_suffix for fac in self.factors if fac != target_col]
+
+
+    def _columns_to_assemble(self, target_col):
+        numeric = [col for col in self.numeric_columns
+                   if col != target_col and not col.endswith(self.indexed_suffix)]
+        one_hot_encoded = [col for col, data_type in self.train_df.dtypes
+                           if self.indexed_suffix + self.one_hot_suffix in col]
+        return numeric + one_hot_encoded
+
+
     @property
     def factors(self) -> List[str]:
         return self._get_cols_by_types(types=['string'])
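
The net effect of this refactor: the "_cat" (string-indexed) and "_vec" (one-hot) suffixes are defined once on the instance, and _columns_to_assemble now excludes numeric columns ending with the indexed suffix, so the raw index columns are no longer assembled into the feature vector alongside their one-hot counterparts. For reference, the sketch below rebuilds the same naming convention directly with Spark ML; it assumes string_index, one_hot_encode and assemble_features wrap StringIndexer, OneHotEncoder and VectorAssembler (those wrappers are not part of this commit), and the toy columns are only illustrative.

# Editorial sketch, not part of the commit: the "_cat"/"_vec" suffix convention,
# assuming the preprocessor's helpers wrap Spark ML's feature stages.
from pyspark.ml import Pipeline
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("Private", 40, "<=50K"), ("State-gov", 50, ">50K")],
    ["workclass", "hours_per_week", "income"],
)

indexed_suffix, one_hot_suffix = "_cat", "_vec"   # same defaults as in __init__
stages = [
    # "workclass" -> "workclass_cat" (numeric index)
    StringIndexer(inputCol="workclass", outputCol="workclass" + indexed_suffix),
    # "workclass_cat" -> "workclass_cat_vec" (one-hot vector)
    OneHotEncoder(inputCol="workclass" + indexed_suffix,
                  outputCol="workclass" + indexed_suffix + one_hot_suffix),
    # only plain numeric columns and the "_cat_vec" columns go into "features";
    # the intermediate "_cat" columns stay out, mirroring _columns_to_assemble
    VectorAssembler(inputCols=["hours_per_week",
                               "workclass" + indexed_suffix + one_hot_suffix],
                    outputCol="features"),
]
encoded = Pipeline(stages=stages).fit(df).transform(df)
encoded.select("workclass_cat", "workclass_cat_vec", "features").show()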

model_evaluator.py

Lines changed: 17 additions & 8 deletions
@@ -14,12 +14,21 @@ def __init__(self, metrics_class=None):
         self.metrics_class = metrics_class if metrics_class else BinaryClassificationMetrics
 
 
-    @staticmethod
-    def compare(data_frames: Dict[str, DataFrame], models: list):
-        output_value = 1.0
-        df = data_frames[list(data_frames.keys())[0]]
-        if len(df.take(20)) > 10:
-            output_value = 0.768680
+    def compare(self, data_frames: Dict[str, DataFrame], models: list):
+        # per model and per data frame calculate all the metrics
+        index = []
         metrics = ["areaUnderROC", "areaUnderPR"]
-        return pandas.DataFrame({metric: [output_value for m in metrics] for metric in metrics},
-                                index=[key for key in data_frames.keys()])
+        data = {metric: [] for metric in metrics}
+        for model in models:
+            for df_name, df in data_frames.items():
+                index.append(self.index_key(df_name, model))
+                evaluator = self.metrics_class(model.transform(df).select('prediction', 'label').rdd)
+                for metric in metrics:
+                    data[metric].append(getattr(evaluator, metric))
+
+        return pandas.DataFrame(data, index=index)
+
+
+    @staticmethod
+    def index_key(df_name, model):
+        return model.__class__.__name__ + "_" + df_name
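
The rewritten compare now actually scores each model on each data frame instead of returning a hard-coded value: it transforms the frame, keeps only the prediction and label columns, wraps the resulting RDD in the configured metrics class (mllib's BinaryClassificationMetrics by default), and collects areaUnderROC / areaUnderPR into a pandas DataFrame whose rows are keyed by index_key. The self-contained sketch below replays that flow for a single (model, data frame) pair; the tiny dataset and variable names are illustrative, not taken from this repository.

# Editorial sketch, not part of the commit: what compare() computes for one
# (model, data frame) pair, using toy data.
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.linalg import Vectors
from pyspark.mllib.evaluation import BinaryClassificationMetrics
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
train = spark.createDataFrame(
    [(Vectors.dense([0.0]), 0.0), (Vectors.dense([0.2]), 0.0),
     (Vectors.dense([0.9]), 1.0), (Vectors.dense([1.0]), 1.0)],
    ["features", "label"],
)
model = LogisticRegression(maxIter=10, regParam=0, elasticNetParam=0).fit(train)

# same selection as in compare(): hard predictions and labels, as an RDD
scored = model.transform(train).select("prediction", "label")
metrics = BinaryClassificationMetrics(scored.rdd)

row_key = model.__class__.__name__ + "_" + "train"   # what index_key produces
print(row_key, metrics.areaUnderROC, metrics.areaUnderPR)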

tests/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -34,6 +34,6 @@ def _get_preprocessor(data_path: str):
 
 def _fit_logistic(preprocessor):
     preprocessor.prepare_to_model(target_col='income', to_strip=' .')
-    lr = LogisticRegression(maxIter=10, regParam=0.1, elasticNetParam=0.2)
+    lr = LogisticRegression(maxIter=10, regParam=0, elasticNetParam=0)
    fit_model = lr.fit(preprocessor.train_encoded_df)
     return fit_model
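
A short note on this fixture change: in Spark ML's LogisticRegression, regParam is the overall regularization strength and elasticNetParam is the L1/L2 mixing ratio, so setting regParam to 0 removes the penalty term entirely (making elasticNetParam irrelevant) and fits a plain, unregularized model; that is presumably why the reference metric values in tests/test_model_evaluator.py are updated in the same commit. A minimal illustration:

# Editorial illustration, not part of the commit. With regParam=0 the penalty
# term regParam * (a * ||w||_1 + (1 - a) / 2 * ||w||_2^2), where a is
# elasticNetParam, drops out of the objective, leaving unregularized logistic
# regression.
from pyspark.ml.classification import LogisticRegression
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()   # needed before creating ML stages
lr = LogisticRegression(maxIter=10, regParam=0, elasticNetParam=0)
print(lr.explainParam("regParam"))           # regularization strength (>= 0)
print(lr.explainParam("elasticNetParam"))    # ElasticNet mixing ratio in [0, 1]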

tests/test_model_evaluator.py

Lines changed: 3 additions & 2 deletions
@@ -22,7 +22,8 @@ def test_model_evaluator_with_linear_regression(logistic_model, preprocessor):
 
 def test_model_evaluator_with_linear_regression_and_full_train_data(logistic_model_train_data, preprocessor_train_data):
     _check_evaluation(preprocessor=preprocessor_train_data,
-                      model=logistic_model_train_data, metrics={"areaUnderROC": 0.768680, "areaUnderPR": 0.640418})
+                      model=logistic_model_train_data,
+                      metrics={"areaUnderROC": 0.764655781, "areaUnderPR": 0.63384702449})
 
 
 def _check_evaluation(preprocessor, model, metrics: Dict[str, float]):
@@ -39,4 +40,4 @@ def _check_evaluation(preprocessor, model, metrics: Dict[str, float]):
     for metric in metrics:
         assert metric in comparison
         for dataframe in dataframes:
-            assert comparison[metric][dataframe] == pytest.approx(metrics[metric])
+            assert comparison[metric][evaluator.index_key(dataframe, model)] == pytest.approx(metrics[metric])
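
Two details of the updated assertion: the comparison table is now indexed by index_key (model class name plus data frame name) rather than by the bare frame name, and the long reference decimals are still compared with pytest.approx, whose default relative tolerance is 1e-6. A small standalone illustration of that tolerance, reusing the value from the test above:

# Editorial illustration, not part of the commit: pytest.approx defaults to a
# relative tolerance of 1e-6, so the expected metrics only need to be stable to
# roughly six significant digits.
import pytest

assert 0.764655781 == pytest.approx(0.7646558)          # within the default rel=1e-6
assert 0.764655781 == pytest.approx(0.7647, rel=1e-3)   # explicitly looser tolerance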

todo_list.md

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 - [x] fit a simple classification model
 - [x] prepare the data frame by applying all transformations
       (cleaning, encoding, etc)
-- [ ] **obtain evaluation metrics for the model**
+- [x] obtain evaluation metrics for a single model
 - [ ] fit several classification models
 - [ ] compare all classification models
 - [ ] prepare data for regression
