Skip to content

Commit

Permalink
Add basic unit tests to module (#3)
Browse files Browse the repository at this point in the history
* Added some basic unit tests
  • Loading branch information
sllynn authored Aug 19, 2020
1 parent 4efff42 commit d9828ce
Show file tree
Hide file tree
Showing 6 changed files with 32,675 additions and 0 deletions.
2 changes: 2 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
setuptools>=49.2.1
numpy>=1.18.5
Empty file added sparkxgb/testing/__init__.py
Empty file.
28 changes: 28 additions & 0 deletions sparkxgb/testing/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pyspark
from pyspark.sql.session import SparkSession

# Scala binary version bundled with each supported Spark major release.
SPARK_SCALA_MAPPING = {
    "2": "2.11",
    "3": "2.12"
}

def default_session(conf=None):
    """Create (or reuse) a SparkSession configured with the XGBoost4J jars.

    Parameters
    ----------
    conf : dict, optional
        Spark configuration key/value pairs applied to the session builder.
        When omitted, a default configuration is used that pulls the
        ``xgboost4j`` and ``xgboost4j-spark`` Maven packages matching the
        installed PySpark's Scala binary version.

    Returns
    -------
    pyspark.sql.SparkSession
        The session produced by ``builder.getOrCreate()`` (may be an
        already-running session).

    Raises
    ------
    ValueError
        If the installed Spark major version has no known Scala mapping.
    """
    # Take the full major component, not just the first character, so this
    # keeps working if Spark ever ships a two-digit major version ("10.x").
    spark_major_version = pyspark.__version__.split(".")[0]
    try:
        scala_version = SPARK_SCALA_MAPPING[spark_major_version]
    except KeyError:
        # Fail with an actionable message instead of an opaque KeyError.
        raise ValueError(
            f"Unsupported Spark major version {spark_major_version!r}; "
            f"expected one of {sorted(SPARK_SCALA_MAPPING)}"
        ) from None

    mvn_group = "ml.dmlc"
    xgb_version = "1.0.0"
    xgboost4j_coords = f"{mvn_group}:xgboost4j_{scala_version}:{xgb_version}"
    xgboost4j_spark_coords = f"{mvn_group}:xgboost4j-spark_{scala_version}:{xgb_version}"

    if conf is None:
        conf = {
            "spark.jars.packages": ",".join([xgboost4j_coords, xgboost4j_spark_coords])
        }

    builder = SparkSession.builder.appName("spark-xgboost")
    for key, value in conf.items():
        builder = builder.config(key, value)

    session = builder.getOrCreate()

    return session
Empty file added sparkxgb/tests/__init__.py
Empty file.
83 changes: 83 additions & 0 deletions sparkxgb/tests/classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import unittest

from pyspark.sql.types import StringType

from sparkxgb.xgboost import XGBoostClassifier, XGBoostClassificationModel
from sparkxgb.testing.utils import default_session
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml.evaluation import BinaryClassificationEvaluator


class XGBClassifierTests(unittest.TestCase):
    """Integration tests for XGBoostClassifier on the UCI Adult dataset."""

    def setUp(self) -> None:
        """Load ./data/adult.data, index string columns, assemble features,
        and stash an 80/20 train/test split on the test instance."""
        self.spark = default_session()

        column_names = [
            "age", "workclass", "fnlwgt",
            "education", "education-num",
            "marital-status", "occupation",
            "relationship", "race", "sex",
            "capital-gain", "capital-loss",
            "hours-per-week", "native-country",
            "label"
        ]

        raw_df = (
            self.spark.read
            .csv(path="./data/adult.data", inferSchema=True)
            .toDF(*column_names)
            .repartition(200)
        )

        # Split the schema into string-typed columns (which need indexing)
        # and the remaining numeric columns (used as-is).
        string_cols = []
        numeric_cols = []
        for field in raw_df.schema.fields:
            if isinstance(field.dataType, StringType):
                string_cols.append(field.name)
            else:
                numeric_cols.append(field.name)

        indexed_cols = ["{}_ix".format(name) for name in string_cols]
        # The last string column is the label ("label" -> "label_ix"); every
        # other indexed column joins the numeric columns as a predictor.
        target = indexed_cols[-1]
        feature_cols = numeric_cols + indexed_cols[:-1]

        indexers = [
            StringIndexer(inputCol=src, outputCol=dst)
            for src, dst in zip(string_cols, indexed_cols)
        ]
        assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
        prep_model = Pipeline(stages=indexers + [assembler]).fit(raw_df)
        prepared_df = prep_model.transform(raw_df)

        self.train_sdf, self.test_sdf = prepared_df.randomSplit([0.8, 0.2], seed=1337)

    def test_binary_classifier(self):
        """Train a tiny binary XGBoost model and sanity-check its ROC AUC."""
        self.spark.sparkContext.setLogLevel("INFO")

        classifier = (
            XGBoostClassifier(
                eta=0.1,
                maxDepth=2,
                missing=0.0,
                objective="binary:logistic",
                numRound=5,
                numWorkers=2
            )
            .setFeaturesCol("features")
            .setLabelCol("label_ix")
        )

        evaluator = BinaryClassificationEvaluator(
            rawPredictionCol="rawPrediction",
            labelCol="label_ix"
        )

        # Deliberately train on a tiny sample; the assertion below is only a
        # smoke test that training and scoring run end to end.
        fitted = classifier.fit(self.train_sdf.limit(10))
        auc = evaluator.evaluate(fitted.transform(self.test_sdf))

        print(auc)

        self.assertIsInstance(fitted, XGBoostClassificationModel)
        self.assertGreater(auc, 0.3)


if __name__ == '__main__':
    # Allow running this test module directly (e.g. `python classifier.py`);
    # unittest.main() discovers and runs every test in this file.
    unittest.main()
Loading

0 comments on commit d9828ce

Please sign in to comment.