From 27b293fe54caf4849c7617f6a7ab5fdb6d6a536a Mon Sep 17 00:00:00 2001 From: Felipe Trost Date: Mon, 6 Jan 2025 06:34:37 +0100 Subject: [PATCH 1/4] refactor(normalization): use input validator Signed-off-by: Felipe Trost --- .../spark/normalization/normalization.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/normalization/normalization.py b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/normalization/normalization.py index e9e28f6c5..00033cc29 100644 --- a/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/normalization/normalization.py +++ b/src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/normalization/normalization.py @@ -14,6 +14,7 @@ from abc import abstractmethod from pyspark.sql import DataFrame as PySparkDataFrame from typing import List +from pyspark.sql.types import DoubleType, StructField, StructType from ....input_validator import InputValidator from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.interfaces import ( DataManipulationBaseInterface, @@ -67,15 +68,15 @@ class NormalizationBaseClass(DataManipulationBaseInterface, InputValidator): def __init__( self, df: PySparkDataFrame, column_names: List[str], in_place: bool = False ) -> None: - - for column_name in column_names: - if not column_name in df.columns: - raise ValueError("{} not found in the DataFrame.".format(column_name)) - self.df = df self.column_names = column_names self.in_place = in_place + EXPECTED_SCHEMA = StructType( + [StructField(column_name, DoubleType()) for column_name in column_names] + ) + self.validate(EXPECTED_SCHEMA) + @staticmethod def system_type(): """ From 3321dc66e3f0b65926853b6f7873f42e2ae9f8f2 Mon Sep 17 00:00:00 2001 From: Felipe Trost Date: Mon, 6 Jan 2025 06:36:54 +0100 Subject: [PATCH 2/4] refactor(test/normalization): add tolerance for data frame comparison Signed-off-by: 
Felipe Trost --- .../data_manipulation/spark/test_normalization.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_normalization.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_normalization.py index c1ee19b22..33bf82e77 100644 --- a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_normalization.py +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_normalization.py @@ -145,6 +145,12 @@ def helper_assert_idempotence( assert expected_df.columns == actual_df.columns assert expected_df.schema == actual_df.schema - assert expected_df.collect() == actual_df.collect() + + for row1, row2 in zip(expected_df.collect(), actual_df.collect()): + for col1, col2 in zip(row1, row2): + if isinstance(col1, float) and isinstance(col2, float): + assert math.isclose(col1, col2, rel_tol=1e-9) + else: + assert col1 == col2 except ZeroDivisionError: pass From 94d335ca64cb81462185a999053001d4b46c15de Mon Sep 17 00:00:00 2001 From: Felipe Trost Date: Mon, 6 Jan 2025 06:37:49 +0100 Subject: [PATCH 3/4] test(normalization): test idempotence with large data set Signed-off-by: Felipe Trost --- .../spark/test_normalization.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_normalization.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_normalization.py index 33bf82e77..283f08982 100644 --- a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_normalization.py +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_normalization.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import math import pytest +import os from pyspark.sql import SparkSession from pyspark.sql.dataframe import DataFrame @@ -105,8 +107,6 @@ def test_idempotence_with_positive_values( expected_df = input_df.alias("input_df") helper_assert_idempotence(class_to_test, input_df, expected_df) - class_to_test(input_df, column_names=["Value"], in_place=True) - @pytest.mark.parametrize("class_to_test", NormalizationBaseClass.__subclasses__()) def test_idempotence_with_zero_values( @@ -127,6 +127,21 @@ def test_idempotence_with_zero_values( helper_assert_idempotence(class_to_test, input_df, expected_df) +@pytest.mark.parametrize("class_to_test", NormalizationBaseClass.__subclasses__()) +def test_idempotence_with_large_data_set( + spark_session: SparkSession, class_to_test: NormalizationBaseClass +): + base_path = os.path.dirname(__file__) + file_path = os.path.join(base_path, "../../test_data.csv") + input_df = spark_session.read.option("header", "true").csv(file_path) + input_df = input_df.withColumn("Value", input_df["Value"].cast("double")) + assert input_df.count() > 0, "Dataframe was not loaded correctly" + input_df.show() + + expected_df = input_df.alias("input_df") + helper_assert_idempotence(class_to_test, input_df, expected_df) + + def helper_assert_idempotence( class_to_test: NormalizationBaseClass, input_df: DataFrame, From 869db2f0743262432eb45e976e16bb6ae82e45dc Mon Sep 17 00:00:00 2001 From: Felipe Trost Date: Mon, 6 Jan 2025 06:40:02 +0100 Subject: [PATCH 4/4] test(normalization): test wrong type Signed-off-by: Felipe Trost --- .../data_manipulation/spark/test_normalization.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_normalization.py b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_normalization.py index 283f08982..27087015e 100644 --- 
a/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_normalization.py +++ b/tests/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/test_normalization.py @@ -51,6 +51,19 @@ def test_nonexistent_column_normalization(spark_session: SparkSession): NormalizationMean(input_df, column_names=["NonexistingColumn"], in_place=True) +def test_wrong_column_type_normalization(spark_session: SparkSession): + input_df = spark_session.createDataFrame( + [ + ("a",), + ("b",), + ], + ["Value"], + ) + + with pytest.raises(ValueError): + NormalizationMean(input_df, column_names=["Value"]) + + def test_non_inplace_normalization(spark_session: SparkSession): input_df = spark_session.createDataFrame( [