Finishes basic functionality & adds single test
Signed-off-by: Timm638 <[email protected]>
Timm638 committed Nov 5, 2024
1 parent e10b6a2 commit 5132208
Showing 3 changed files with 63 additions and 13 deletions.
@@ -44,9 +44,7 @@ class Denormalization(WranglerBaseInterface):
 
     def __init__(self, df: PySparkDataFrame, normalization_to_revert: Normalization) -> None:
         self.df = df
-        self.normalization_stage = normalization_to_revert
-
-        norm = Normalization(df, 0, )
+        self.normalization_to_revert = normalization_to_revert
 
     @staticmethod
     def system_type():
@@ -65,5 +63,5 @@ def libraries():
     def settings() -> dict:
         return {}
 
-    def filter(self) -> DataFrame:
+    def filter(self) -> PySparkDataFrame:
         return self.normalization_to_revert.denormalize(self.df)
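
With this change, Denormalization simply keeps a reference to the Normalization component that transformed the data and delegates filter() to that component's denormalize() method. A minimal round-trip sketch, assuming the rtdip_sdk API exactly as it appears in this commit (the members of the NormalizationMethod enum are not shown in the diff, so one is picked generically):

from pyspark.sql import SparkSession

from rtdip_sdk.pipelines.data_wranglers import (
    Denormalization,
    Normalization,
    NormalizationMethod,
)

spark = SparkSession.builder.master("local[2]").appName("example").getOrCreate()
df = spark.createDataFrame([(1.0,), (2.0,), (3.0,)], ["Value"])

# Any member serves for illustration; the enum's names are not visible in this diff.
method = list(NormalizationMethod)[0]

# Keep the normalization component around: it caches the statistics
# needed to undo the transformation later.
norm = Normalization(df, method, column_names=["Value"], in_place=True)
normalized_df = norm.filter()

# Denormalization.filter() delegates to norm.denormalize(normalized_df).
restored_df = Denormalization(normalized_df, norm).filter()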
@@ -13,6 +13,7 @@
 # limitations under the License.
 from enum import Enum
 
+from eth_abi.grammar import normalize
 from pyspark.sql import DataFrame as PySparkDataFrame
 from pyspark.sql import functions as F
 from typing import List
@@ -91,7 +92,7 @@ def settings() -> dict:
         return {}
 
     def filter(self):
-        pass
+        return self.normalize()
 
     def normalize(self) -> PySparkDataFrame:
         """
@@ -112,14 +113,14 @@ def normalize(self) -> PySparkDataFrame:
             normalized_df = self._mean_normalize(normalized_df, column)
         return normalized_df
 
-    def denormalize(self, df) -> PySparkDataFrame:
+    def denormalize(self, input_df) -> PySparkDataFrame:
         """
         Denormalizes the input DataFrame. Intended to be used by the denormalization component.
 
         Parameters:
-            df (DataFrame): Dataframe containing the current data.
+            input_df (DataFrame): Dataframe containing the current data.
         """
-        denormalized_df = self.df
+        denormalized_df = input_df
         if not self.in_place:
             for column in self.column_names:
                 denormalized_df = denormalized_df.drop(self._get_norm_column_name(column))
@@ -195,7 +196,7 @@ def _min_max_denormalize(self, df: PySparkDataFrame, column: str) -> PySparkDataFrame:
 
         return df.withColumn(
             store_column,
-            F.col(column) * (F.lit(max_val) - F.lit(min_val)) + F.lit(min_val)
+            (F.col(column) * (F.lit(max_val) - F.lit(min_val))) + F.lit(min_val)
         )
 
     def _mean_normalize(self, df: PySparkDataFrame, column: str) -> PySparkDataFrame:
@@ -220,12 +221,11 @@ def _mean_denormalize(self, df: PySparkDataFrame, column: str) -> PySparkDataFrame:
         Private method to revert Mean normalization to the specified column.
         Mean denormalization: normalized_value * (max - min) + mean = value
         """
-        mean_val = df.select(F.mean(F.col(column))).collect()[0][0]
-        min_val = df.select(F.min(F.col(column))).collect()[0][0]
-        max_val = df.select(F.max(F.col(column))).collect()[0][0]
+        mean_val = self.reversal_value[0]
+        min_val = self.reversal_value[1]
+        max_val = self.reversal_value[2]
 
         store_column = self._get_norm_column_name(column)
-        self.reversal_value = [mean_val, min_val, max_val]
 
         return df.withColumn(
             store_column,
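
The substantive fix in _mean_denormalize is where the statistics come from: mean, min, and max are now read from self.reversal_value, presumably cached while normalizing (the body of _mean_normalize is collapsed in this diff), instead of being recomputed from the frame being denormalized. Recomputing them at that point would use the statistics of the already-normalized column, whose mean is near zero, so the reversal could never reproduce the original values. A standalone sketch of that reasoning in plain PySpark, independent of the rtdip_sdk classes above:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[2]").appName("stats").getOrCreate()
df = spark.createDataFrame([(1.0,), (2.0,), (3.0,), (4.0,), (5.0,)], ["Value"])

# Statistics of the ORIGINAL column -- these must be captured before normalizing.
mean_val, min_val, max_val = df.select(
    F.mean("Value"), F.min("Value"), F.max("Value")
).first()

# Mean normalization: (value - mean) / (max - min)
normalized = df.withColumn(
    "Value", (F.col("Value") - F.lit(mean_val)) / (F.lit(max_val) - F.lit(min_val))
)

# Mean denormalization: normalized_value * (max - min) + mean.
# Recomputing the statistics from `normalized` here (mean ~ 0.0,
# min/max ~ -0.5/0.5 for this data) would not restore 1.0 .. 5.0.
restored = normalized.withColumn(
    "Value", F.col("Value") * (F.lit(max_val) - F.lit(min_val)) + F.lit(mean_val)
)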
@@ -0,0 +1,52 @@
# Copyright 2022 RTDIP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame

from rtdip_sdk.pipelines.data_wranglers import Normalization, NormalizationMethod, Denormalization


@pytest.fixture(scope="session")
def spark_session():
    return SparkSession.builder.master("local[2]").appName("test").getOrCreate()


@pytest.mark.parametrize("method", NormalizationMethod)
def test_idempotence_of_normalization(spark_session: SparkSession, method: NormalizationMethod):
    expected_df = spark_session.createDataFrame(
        [
            (1.0,),
            (2.0,),
            (3.0,),
            (4.0,),
            (5.0,),
        ],
        ["Value"],
    )

    df = expected_df.alias('df')

    normalization_component = Normalization(df, method, column_names=["Value"], in_place=True)
    actual_df = normalization_component.filter()

    denormalization_component = Denormalization(actual_df, normalization_component)
    actual_df = denormalization_component.filter()

    assert isinstance(actual_df, DataFrame)

    assert expected_df.columns == actual_df.columns
    assert expected_df.schema == actual_df.schema
    assert expected_df.collect() == actual_df.collect()
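
One caveat on the new test: the last assertion compares floating-point values for exact equality after a multiply-and-add round trip, which this data evidently tolerates but which may be brittle for less friendly inputs. A tolerance-based variant of that final check, using pytest's approx helper (a suggestion, not part of this commit):

    # Inside the test, replacing the last assertion:
    expected_values = [row["Value"] for row in expected_df.collect()]
    actual_values = [row["Value"] for row in actual_df.collect()]
    assert actual_values == pytest.approx(expected_values)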
