Skip to content

Commit ae9ddb0

Browse files
committed
Models container: created a first-iteration container supporting classification models
1 parent 12a666a commit ae9ddb0

File tree

3 files changed

+69
-2
lines changed

3 files changed

+69
-2
lines changed

models_container.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from enum import Enum
2+
3+
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier, GBTClassifier, \
4+
LinearSVC, NaiveBayes
5+
6+
from spark_launcher import SparkLauncher
7+
8+
9+
class ModelKinds(Enum):
    """Kinds of learning task a wrapped model can be tagged with."""

    CLASSIFICATION = 'classification'
    REGRESSION = 'regression'
12+
13+
14+
class Model(object):
    """Pair a model object with a name and a task kind.

    Args:
        model: The object being wrapped; stored unchanged on ``self.model``.
        name: Label for the model; defaults to an empty string.
        kind: ``ModelKinds`` member describing the task; defaults to
            ``ModelKinds.CLASSIFICATION``.
    """

    def __init__(self, model, name='', kind=ModelKinds.CLASSIFICATION):
        self.model = model
        self.name = name
        self.kind = kind

    def __repr__(self):
        # Readable representation so failing assertions and logs show which
        # wrapper is involved instead of a bare object address.
        return '{}(name={!r}, kind={!r}, model={!r})'.format(
            type(self).__name__, self.name, self.kind, self.model)
19+
20+
21+
class ModelsContainer(object):
    """Container of Spark ML estimators, each wrapped in a ``Model``.

    ``__init__`` stores one estimator per instance attribute and then calls
    ``_wrap_models``, which replaces every attribute whose class lives in the
    ``model_path`` package with a ``Model`` named after that attribute.
    """

    # Package whose classes are treated as models by ``_wrap_models``.
    model_path = "pyspark.ml"

    def __init__(self):
        # NOTE(review): SparkLauncher is defined in spark_launcher and is not
        # visible here; presumably it provides the Spark session - confirm.
        self.spark = SparkLauncher()

        # Models
        self.logistic_class = LogisticRegression(maxIter=20)
        self.random_forest_class = RandomForestClassifier(cacheNodeIds=True)
        self.gbt_class = GBTClassifier(cacheNodeIds=True)
        self.svm_class = LinearSVC(maxIter=20)
        self.naive_bayes_class = NaiveBayes()

        self._wrap_models()

    @property
    def classification(self):
        """Return the attributes whose ``kind`` is ``ModelKinds.CLASSIFICATION``.

        Returns:
            list: Instance attribute values, in attribute order; attributes
            without a ``kind`` attribute are skipped.
        """
        return [obj for obj in vars(self).values()
                if getattr(obj, "kind", None) == ModelKinds.CLASSIFICATION]

    def _wrap_models(self):
        """Replace every ``model_path`` estimator attribute with a ``Model``.

        An attribute counts as a model when its class is defined in the
        ``model_path`` package itself or in one of its submodules. Matching
        on the module (rather than a substring of the class repr) keeps
        look-alike modules such as ``pyspark.mllib`` from being wrapped.
        """
        package = self.model_path
        submodule_prefix = package + "."
        # Iterate over a snapshot: attributes are reassigned inside the loop,
        # and the live ``__dict__`` view must not be mutated while iterating.
        for name, obj in list(vars(self).items()):
            module = getattr(type(obj), "__module__", "") or ""
            if module == package or module.startswith(submodule_prefix):
                setattr(self, name, Model(model=obj, name=name))

tests/test_models_container.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from models_container import ModelsContainer
2+
3+
4+
def test_example_models_init():
    """Smoke test: constructing a ``ModelsContainer`` must not raise."""
    ModelsContainer()
6+
7+
8+
def test_get_classification_models():
    """``classification`` exposes exactly the five classifiers, once each."""
    names = [model.name for model in ModelsContainer().classification]
    # Compare sorted lists rather than sets so a duplicated model name fails
    # the test instead of being silently collapsed.
    assert sorted(names) == [
        "gbt_class",
        "logistic_class",
        "naive_bayes_class",
        "random_forest_class",
        "svm_class",
    ]

todo_list.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,14 @@
1414
- [x] prepare the data frame by applying all transformations
1515
(cleaning, encoding, etc)
1616
- [x] obtain evaluation metrics for a single model
17-
- [ ] fit several classification models
18-
- [ ] compare all classification models
17+
- [ ] **fit and compare several classification models without tuning**
18+
- [ ] create an object container for the models
19+
- [ ] initialize the models with default hyperparameters
20+
- [ ] fit and compare the results with the evaluator
21+
- [ ] fit and compare several classification models with tuning and cross-validation
22+
- [ ] be able to pass a list of candidate values for each hyperparameter
23+
- [ ] tune and obtain the best hyperparameter set per model
24+
- [ ] compare the tuned models with the evaluator
1925
- [ ] prepare data for regression
2026
- [ ] fit regression model(s)
2127
- [ ] obtain the regression metrics and compare the models

0 commit comments

Comments
 (0)