2
2
3
3
from pyspark .ml .classification import LogisticRegression , RandomForestClassifier , GBTClassifier , \
4
4
LinearSVC , NaiveBayes
5
+ from pyspark .sql import DataFrame
5
6
6
7
from spark_launcher import SparkLauncher
7
8
@@ -16,6 +17,7 @@ def __init__(self, model, name='', kind=ModelKinds.CLASSIFICATION):
16
17
self .model = model
17
18
self .name = name
18
19
self .kind = kind
20
+ self .fitted_model = None
19
21
20
22
21
23
class ModelsContainer (object ):
@@ -36,12 +38,40 @@ def __init__(self):
36
38
37
39
@property
def classification(self):
    """List the models whose kind is ``ModelKinds.CLASSIFICATION``.

    Delegates to ``_get_models_of_kind``; the result is rebuilt on every
    access.
    """
    wanted_kind = ModelKinds.CLASSIFICATION
    return self._get_models_of_kind(kind=wanted_kind)
43
+
44
+
45
def fit(self, data: "DataFrame", kind="*"):
    """Fit the wrapped models on ``data`` and keep each result on its wrapper.

    Args:
        data: Training data, passed unchanged to ``fit`` of every selected
            model's underlying estimator.
        kind: Kind of models to fit.  The default ``"*"`` selects every
            wrapped model regardless of its kind.

    Returns:
        None.  Each result is stored on the wrapper's ``fitted_model``
        attribute.
    """
    if kind == "*":
        # The loop below needs wrapper objects (it reads ``.model`` and
        # writes ``.fitted_model``), so wrappers are picked by their ``kind``
        # attribute, mirroring ``_get_models_of_kind``.  ``_all_models_dict``
        # selects by the class path of the raw estimators instead.
        # NOTE(review): presumably wrappers do not match ``model_path`` and
        # every wrapper carries a non-None ``kind`` -- confirm.
        models = [obj for obj in self.__dict__.values()
                  if getattr(obj, "kind", None) is not None]
    else:
        models = self._get_models_of_kind(kind)

    for wrapper in models:
        wrapper.fitted_model = wrapper.model.fit(data)
54
+
55
+
56
@property
def fitted_models(self):
    """Return the ``fitted_model`` attribute of every wrapped model.

    Entries are ``None`` for wrappers whose ``fitted_model`` has not been
    set yet.  The list is rebuilt on every access.
    """
    # Wrappers are picked by their ``kind`` attribute because only wrappers
    # carry ``fitted_model``; ``_all_models_dict`` selects by the class path
    # of the raw estimators instead.
    # NOTE(review): presumably every wrapper has a non-None ``kind`` -- confirm.
    return [obj.fitted_model for obj in self.__dict__.values()
            if getattr(obj, "kind", None) is not None]
41
59
42
60
43
61
def _wrap_models(self):
    """Replace each attribute found by ``_all_models_dict`` with a ``Model``.

    The wrapper receives the original object as ``model`` and the attribute
    name as ``name``, and is stored back under that same attribute name.
    """
    # ``_all_models_dict`` builds a fresh dict, so reassigning attributes
    # while walking it does not disturb the iteration.
    pending = self._all_models_dict
    for attr_name in pending:
        setattr(self, attr_name, Model(model=pending[attr_name], name=attr_name))
67
+
68
+
69
+ @property
70
+ def _all_models_dict (self ):
71
+ return {name : obj for name , obj in self .__dict__ .items ()
72
+ if self .model_path in str (obj .__class__ )}
73
+
74
+
75
+ def _get_models_of_kind (self , kind ):
76
+ return [obj for name , obj in self .__dict__ .items ()
77
+ if getattr (obj , "kind" , None ) == kind ]
0 commit comments