Skip to content

Commit 7a30e03

Browse files
committed
Renamed the ModelKinds enum to ModelTypes
1 parent 1234e42 commit 7a30e03

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

models_container.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
from spark_launcher import SparkLauncher
88

99

10-
class ModelKinds(Enum):
10+
class ModelTypes(Enum):
1111
CLASSIFICATION = 'classification'
1212
REGRESSION = 'regression'
1313

1414

1515
class Model(object):
16-
def __init__(self, model, name='', kind=ModelKinds.CLASSIFICATION):
16+
def __init__(self, model, name='', kind=ModelTypes.CLASSIFICATION):
1717
self.model = model
1818
self.name = name
1919
self.kind = kind
@@ -39,7 +39,7 @@ def __init__(self):
3939
@property
4040
def classification(self):
4141
"""Returns the classification models"""
42-
return self._get_models_of_kind(kind=ModelKinds.CLASSIFICATION)
42+
return self._get_models_of_kind(kind=ModelTypes.CLASSIFICATION)
4343

4444

4545
def fit(self, data: DataFrame, kind="*"):

tests/test_model_evaluator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from data_preprocessor import DataPreprocessor
88
from model_evaluator import ModelEvaluator
9-
from models_container import ModelsContainer, ModelKinds
9+
from models_container import ModelsContainer, ModelTypes
1010

1111
"""Expected results with real data
1212
model dataset AUC_ROC AUC_PR
@@ -34,7 +34,7 @@ def test_several_classification_models_fitting(preprocessor_train_data):
3434
preprocessor.prepare_to_model(target_col='income', to_strip=' .')
3535

3636
models = ModelsContainer()
37-
models.fit(preprocessor.train_encoded_df, kind=ModelKinds.CLASSIFICATION)
37+
models.fit(preprocessor.train_encoded_df, kind=ModelTypes.CLASSIFICATION)
3838
expected_results = [
3939
{"model": models.logistic_class.fitted_model,
4040
"metrics": {"areaUnderROC": 0.770414, "areaUnderPR": 0.646093}, },

0 commit comments

Comments
 (0)