Skip to content

Commit bc3fd03

Browse files
committed
Models container, midway to train and evaluate several models
1 parent ae9ddb0 commit bc3fd03

File tree

3 files changed

+50
-9
lines changed

3 files changed

+50
-9
lines changed

models_container.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier, GBTClassifier, \
44
LinearSVC, NaiveBayes
5+
from pyspark.sql import DataFrame
56

67
from spark_launcher import SparkLauncher
78

@@ -16,6 +17,7 @@ def __init__(self, model, name='', kind=ModelKinds.CLASSIFICATION):
1617
self.model = model
1718
self.name = name
1819
self.kind = kind
20+
self.fitted_model = None
1921

2022

2123
class ModelsContainer(object):
@@ -36,12 +38,40 @@ def __init__(self):
3638

3739
@property
3840
def classification(self):
39-
return [obj for name, obj in self.__dict__.items()
40-
if getattr(obj, "kind", None) == ModelKinds.CLASSIFICATION]
41+
"""Returns the classification models"""
42+
return self._get_models_of_kind(kind=ModelKinds.CLASSIFICATION)
43+
44+
45+
def fit(self, data: DataFrame, kind="*"):
46+
"""Loops though all models of some kind and generates fitted models"""
47+
if kind == "*":
48+
models = self._all_models_dict.values()
49+
else:
50+
models = self._get_models_of_kind(kind)
51+
52+
for model in models:
53+
model.fitted_model = model.model.fit(data)
54+
55+
56+
@property
57+
def fitted_models(self):
58+
return [model.fitted_model for model in self._all_models_dict.values()]
4159

4260

4361
def _wrap_models(self):
44-
for name, obj in self.__dict__.items():
45-
if self.model_path in str(obj.__class__):
46-
wrapped = Model(model=obj, name=name)
47-
setattr(self, name, wrapped)
62+
"""Wraps the pyspark model in our own Model class that
63+
provides some metadata and perhaps extra functionality"""
64+
for name, obj in self._all_models_dict.items():
65+
wrapped = Model(model=obj, name=name)
66+
setattr(self, name, wrapped)
67+
68+
69+
@property
70+
def _all_models_dict(self):
71+
return {name: obj for name, obj in self.__dict__.items()
72+
if self.model_path in str(obj.__class__)}
73+
74+
75+
def _get_models_of_kind(self, kind):
76+
return [obj for name, obj in self.__dict__.items()
77+
if getattr(obj, "kind", None) == kind]

tests/test_model_evaluator.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pyspark.mllib.evaluation import BinaryClassificationMetrics
66

77
from model_evaluator import ModelEvaluator
8+
from models_container import ModelsContainer, ModelKinds
89

910
"""Expected results with real data
1011
model dataset AUC_ROC AUC_PR
@@ -15,7 +16,7 @@
1516
CLASSIFICATION_METRICS = ["areaUnderROC", "areaUnderPR"]
1617

1718

18-
def test_model_evaluator_with_linear_regression_and_tiny_dataset(logistic_model, preprocessor):
    """Perfect separation on the tiny dataset gives AUC of 1.0 on both metrics."""
    expected = {"areaUnderROC": 1., "areaUnderPR": 1.}
    _check_evaluation(preprocessor=preprocessor, model=logistic_model, metrics=expected)
2122

@@ -26,9 +27,19 @@ def test_model_evaluator_with_linear_regression_and_full_train_data(logistic_mod
2627
metrics={"areaUnderROC": 0.764655781, "areaUnderPR": 0.63384702449})
2728

2829

30+
def test_several_classification_models_fitting(preprocessor_train_data):
    """Fits every classification model and runs the evaluator over them.

    Removed a leftover debug ``print('kk')`` statement.
    """
    preprocessor_train_data.prepare_to_model(target_col='income', to_strip=' .')
    evaluator = ModelEvaluator(metrics_class=BinaryClassificationMetrics)
    models = ModelsContainer()
    models.fit(preprocessor_train_data.train_encoded_df, kind=ModelKinds.CLASSIFICATION)
    # TODO(review): assert on the comparison result; for now this only
    # verifies the fit/compare pipeline runs end to end without raising.
    evaluator.compare({"train": preprocessor_train_data.train_encoded_df},
                      models=models.fitted_models)
37+
38+
2939
def _check_evaluation(preprocessor, model, metrics: Dict[str, float]):
3040
metrics_class = BinaryClassificationMetrics
3141
evaluator = ModelEvaluator(metrics_class=metrics_class)
42+
# The purpose of this parameter is to prove names can be arbitrary in the compare method
3243
dataframes_sets = [['train', 'test'], ['train1', 'test1']]
3344
for dataframes in dataframes_sets:
3445
comparison = evaluator.compare(

todo_list.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
(cleaning, encoding, etc)
1616
- [x] obtain evaluation metrics for a single model
1717
- [ ] **fit and compare several classification models without tuning**
18-
- [ ] create an object container for the models
19-
- [ ] initialize the models with default hyperparameters
18+
- [x] create an object container for the models
19+
- [x] initialize the models with default hyperparameters
2020
- [ ] fit and compare the results with the evaluator
2121
- [ ] fit and compare several classification models with tuning and cross-validation
2222
- [ ] be able to pass a list of hyperparameters values for each hyperparameter

0 commit comments

Comments
 (0)