import pytest
from pyspark.mllib.evaluation import BinaryClassificationMetrics

+from data_preprocessor import DataPreprocessor
from model_evaluator import ModelEvaluator
from models_container import ModelsContainer, ModelKinds

@@ -28,17 +29,30 @@ def test_model_evaluator_with_linear_regression_and_full_train_data(logistic_mod
def test_several_classification_models_fitting(preprocessor_train_data):
-    preprocessor_train_data.prepare_to_model(target_col='income', to_strip=' .')
-    evaluator = ModelEvaluator(metrics_class=BinaryClassificationMetrics)
+    df = preprocessor_train_data.train_df.sample(0.1)
+    preprocessor = DataPreprocessor(train_df=df, test_df=df)
+    preprocessor.prepare_to_model(target_col='income', to_strip=' .')
+
    models = ModelsContainer()
-    models.fit(preprocessor_train_data.train_encoded_df, kind=ModelKinds.CLASSIFICATION)
-    evaluator.compare({"train": preprocessor_train_data.train_encoded_df}, models=models.fitted_models)
-    print('kk')
+    models.fit(preprocessor.train_encoded_df, kind=ModelKinds.CLASSIFICATION)
+    expected_results = [
+        {"model": models.logistic_class.fitted_model,
+         "metrics": {"areaUnderROC": 0.770414, "areaUnderPR": 0.646093}},
+        {"model": models.random_forest_class.fitted_model,
+         "metrics": {"areaUnderROC": 0.674751, "areaUnderPR": 0.664931}},
+        {"model": models.gbt_class.fitted_model,
+         "metrics": {"areaUnderROC": 0.811643, "areaUnderPR": 0.746147}},
+        {"model": models.svm_class.fitted_model,
+         "metrics": {"areaUnderROC": 0.750627, "areaUnderPR": 0.645328}},
+        {"model": models.naive_bayes_class.fitted_model,
+         "metrics": {"areaUnderROC": 0.615000, "areaUnderPR": 0.504709}},
+    ]
+    for result in expected_results:
+        _check_evaluation(preprocessor=preprocessor, model=result["model"], metrics=result["metrics"])


def _check_evaluation(preprocessor, model, metrics: Dict[str, float]):
-    metrics_class = BinaryClassificationMetrics
-    evaluator = ModelEvaluator(metrics_class=metrics_class)
+    evaluator = ModelEvaluator(metrics_class=BinaryClassificationMetrics)
    # The purpose of this parameter is to prove names can be arbitrary in the compare method
    dataframes_sets = [['train', 'test'], ['train1', 'test1']]
    for dataframes in dataframes_sets:
@@ -51,4 +65,5 @@ def _check_evaluation(preprocessor, model, metrics: Dict[str, float]):
        for metric in metrics:
            assert metric in comparison
            for dataframe in dataframes:
-                assert comparison[metric][evaluator.index_key(dataframe, model)] == pytest.approx(metrics[metric])
+                assert comparison[metric][evaluator.index_key(dataframe, model)] == pytest.approx(metrics[metric],
+                                                                                                  abs=0.035)
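Note on the assertions above: ModelEvaluator.compare and index_key are project-local helpers, and the shape of the comparison they produce is only inferred from this test, not from documented API. A minimal sketch of the assumed usage follows; the dataframe name "train" and the choice of logistic_class as the inspected model are illustrative:

    evaluator = ModelEvaluator(metrics_class=BinaryClassificationMetrics)
    # Assumed shape: compare() accepts arbitrarily named dataframes plus fitted models and
    # returns a mapping of metric name -> {index_key(dataframe_name, model): value}.
    comparison = evaluator.compare({"train": preprocessor.train_encoded_df}, models=models.fitted_models)
    roc = comparison["areaUnderROC"][evaluator.index_key("train", models.logistic_class.fitted_model)]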