Skip to content

Commit 7680a4b

Browse files
committed
Getting factors and numeric columns
1 parent 5eeb4c1 commit 7680a4b

File tree

4 files changed

+61
-21
lines changed

4 files changed

+61
-21
lines changed

data_preprocessor.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,19 @@
1+
from typing import List
2+
3+
14
class DataPreprocessor(object):
25
def __init__(self, train_df=None, test_df=None):
36
self.train_df = train_df
47
self.test_df = test_df
8+
9+
10+
def get_factors(self):
11+
return self._get_cols_by_types(types=['string'])
12+
13+
14+
def get_numeric_columns(self):
15+
return self._get_cols_by_types(types=['double', 'int'])
16+
17+
18+
def _get_cols_by_types(self, types: List[str] = None):
19+
return [col for col, data_type in self.train_df.dtypes if data_type in types]

settings.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
1+
# Adult data set configs
2+
ADULT_COLUMN_NAMES = ["age",
3+
"workclass",
4+
"fnlwgt",
5+
"education",
6+
"education_num",
7+
"marital_status",
8+
"occupation",
9+
"relationship",
10+
"race",
11+
"sex",
12+
"capital_gain",
13+
"capital_loss",
14+
"hours_per_week",
15+
"native_country",
16+
"income"]
17+
18+
ADULT_TRAIN_DATA = "data/adult.data"
19+
ADULT_TEST_DATA = "data/adult.test"
20+
21+
# Test configurations
122
TEST_DATA_PATH = "data/data_example_for_tests.csv"
2-
COLUMN_NAMES = ["age",
3-
"workclass",
4-
"fnlwgt",
5-
"education",
6-
"education_num",
7-
"marital_status",
8-
"occupation",
9-
"relationship",
10-
"race",
11-
"sex",
12-
"capital_gain",
13-
"capital_loss",
14-
"hours_per_week",
15-
"native_country",
16-
"income"]

tests/test_data_loader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pyspark.sql import DataFrame
33

44
from data_loader import DataLoader
5-
from settings import TEST_DATA_PATH, COLUMN_NAMES
5+
from settings import TEST_DATA_PATH, ADULT_COLUMN_NAMES
66

77

88
@pytest.fixture
@@ -11,9 +11,9 @@ def data_loader():
1111

1212

1313
def test_data_loader_loads_data_frame(data_loader):
14-
df = data_loader.load_relative(path=TEST_DATA_PATH, columns=COLUMN_NAMES)
14+
df = data_loader.load_relative(path=TEST_DATA_PATH, columns=ADULT_COLUMN_NAMES)
1515
assert isinstance(df, DataFrame)
16-
assert df.columns == COLUMN_NAMES
16+
assert df.columns == ADULT_COLUMN_NAMES
1717
# Check some values from the first row
1818
first_row = df.first()
1919
assert first_row.income == ' <=50K'

tests/test_data_preprocessor.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,33 @@
55

66
from data_loader import DataLoader
77
from data_preprocessor import DataPreprocessor
8-
from settings import TEST_DATA_PATH, COLUMN_NAMES
8+
from settings import TEST_DATA_PATH, ADULT_COLUMN_NAMES
99

1010

1111
@pytest.fixture
1212
def preprocessor():
13-
df = DataLoader().load_relative(path=TEST_DATA_PATH, columns=COLUMN_NAMES)
13+
df = DataLoader().load_relative(path=TEST_DATA_PATH, columns=ADULT_COLUMN_NAMES)
1414
return DataPreprocessor(train_df=df, test_df=df)
1515

1616

17+
def test_preprocessor_get_factors(preprocessor):
18+
# Example 1
19+
factors_example = ['workclass', 'education', 'marital_status', 'occupation', 'relationship', 'race', 'sex',
20+
'native_country', 'income']
21+
assert preprocessor.get_factors() == factors_example
22+
# Example 2 (rename last column)
23+
preprocessor.train_df = preprocessor.train_df.withColumnRenamed('income', 'income2')
24+
factors_example[-1] = "income2"
25+
assert preprocessor.get_factors() == factors_example
26+
27+
28+
def test_preprocessor_get_numeric_columns(preprocessor):
29+
numeric_cols = ['age', 'fnlwgt', 'education_num', 'capital_gain', 'capital_loss', 'hours_per_week']
30+
assert preprocessor.get_numeric_columns() == numeric_cols
31+
preprocessor.train_df = preprocessor.train_df.withColumnRenamed('capital_loss', 'capital_super_loss')
32+
numeric_cols[-2] = 'capital_super_loss'
33+
assert preprocessor.get_numeric_columns() == numeric_cols
34+
35+
1736
def test_data_preprocessor_explore_factors(preprocessor):
18-
pass
37+
pass # preprocessor.explore_factors()

0 commit comments

Comments
 (0)