Skip to content

Commit

Permalink
Merge pull request #114 from amosproj/refactor/64-de-normalization-tests
Browse files Browse the repository at this point in the history
refactor/64 de normalization tests
  • Loading branch information
dh1542 authored Jan 8, 2025
2 parents 0277a5c + 869db2f commit e7eac02
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from abc import abstractmethod
from pyspark.sql import DataFrame as PySparkDataFrame
from typing import List
from pyspark.sql.types import DoubleType, StructField, StructType
from ....input_validator import InputValidator
from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.interfaces import (
DataManipulationBaseInterface,
Expand Down Expand Up @@ -67,15 +68,15 @@ class NormalizationBaseClass(DataManipulationBaseInterface, InputValidator):
def __init__(
    self, df: PySparkDataFrame, column_names: List[str], in_place: bool = False
) -> None:
    """Initialize the normalization over the given DataFrame columns.

    Args:
        df: The PySpark DataFrame containing the columns to normalize.
        column_names: Names of the columns to normalize; each must exist
            in ``df`` and hold ``DoubleType`` values.
        in_place: If True, normalization presumably overwrites the original
            columns rather than adding derived ones — confirm against the
            subclass implementations.

    Raises:
        ValueError: If any name in ``column_names`` is not a column of
            ``df``, or if schema validation of the target columns fails.
    """
    for column_name in column_names:
        # Fail fast with a clear message before any Spark work happens.
        if column_name not in df.columns:
            raise ValueError("{} not found in the DataFrame.".format(column_name))

    self.df = df
    self.column_names = column_names
    self.in_place = in_place

    # Every target column must be DoubleType; validate() (InputValidator)
    # enforces this and rejects e.g. string-typed columns.
    expected_schema = StructType(
        [StructField(column_name, DoubleType()) for column_name in column_names]
    )
    self.validate(expected_schema)

@staticmethod
def system_type():
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pandas.io.formats.format import math
import pytest
import os

from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame
Expand Down Expand Up @@ -49,6 +51,19 @@ def test_nonexistent_column_normalization(spark_session: SparkSession):
NormalizationMean(input_df, column_names=["NonexistingColumn"], in_place=True)


def test_wrong_column_type_normalization(spark_session: SparkSession):
    """A column of strings must be rejected by the schema validation."""
    # Single-column frame whose values are strings rather than doubles.
    rows = [("a",), ("b",)]
    input_df = spark_session.createDataFrame(rows, ["Value"])

    with pytest.raises(ValueError):
        NormalizationMean(input_df, column_names=["Value"])


def test_non_inplace_normalization(spark_session: SparkSession):
input_df = spark_session.createDataFrame(
[
Expand Down Expand Up @@ -105,8 +120,6 @@ def test_idempotence_with_positive_values(
expected_df = input_df.alias("input_df")
helper_assert_idempotence(class_to_test, input_df, expected_df)

class_to_test(input_df, column_names=["Value"], in_place=True)


@pytest.mark.parametrize("class_to_test", NormalizationBaseClass.__subclasses__())
def test_idempotence_with_zero_values(
Expand All @@ -127,6 +140,21 @@ def test_idempotence_with_zero_values(
helper_assert_idempotence(class_to_test, input_df, expected_df)


@pytest.mark.parametrize("class_to_test", NormalizationBaseClass.__subclasses__())
def test_idempotence_with_large_data_set(
    spark_session: SparkSession, class_to_test: NormalizationBaseClass
):
    """Normalize-then-denormalize on a larger CSV data set must reproduce
    the original values, for every NormalizationBaseClass subclass."""
    # Resolve the fixture CSV relative to this test module's location.
    base_path = os.path.dirname(__file__)
    file_path = os.path.join(base_path, "../../test_data.csv")

    input_df = spark_session.read.option("header", "true").csv(file_path)
    # CSV columns load as strings; normalization requires DoubleType.
    input_df = input_df.withColumn("Value", input_df["Value"].cast("double"))
    assert input_df.count() > 0, "Dataframe was not loaded correctly"

    expected_df = input_df.alias("input_df")
    helper_assert_idempotence(class_to_test, input_df, expected_df)


def helper_assert_idempotence(
class_to_test: NormalizationBaseClass,
input_df: DataFrame,
Expand All @@ -145,6 +173,12 @@ def helper_assert_idempotence(

assert expected_df.columns == actual_df.columns
assert expected_df.schema == actual_df.schema
assert expected_df.collect() == actual_df.collect()

for row1, row2 in zip(expected_df.collect(), actual_df.collect()):
for col1, col2 in zip(row1, row2):
if isinstance(col1, float) and isinstance(col2, float):
assert math.isclose(col1, col2, rel_tol=1e-9)
else:
assert col1 == col2
except ZeroDivisionError:
pass

0 comments on commit e7eac02

Please sign in to comment.