Commit 2906037

refactor
Signed-off-by: Minh Khue Tran <[email protected]>
Minh Khue Tran committed Nov 26, 2024
1 parent aae37fd commit 2906037
Showing 1 changed file with 37 additions and 3 deletions.
@@ -14,11 +14,11 @@
 
 from pyspark.sql import DataFrame as PySparkDataFrame
 from pyspark.sql import functions as F
-from ...interfaces import WranglerBaseInterface
+from ...interfaces import TransformerInterface
 from ...._pipeline_utils.models import Libraries, SystemType
 
 
-class OneHotEncoding(WranglerBaseInterface):
+class OneHotEncoding(TransformerInterface):
     """
     Performs One-Hot Encoding on a specified column of a PySpark DataFrame.
@@ -69,8 +69,42 @@ def libraries():
     @staticmethod
     def settings() -> dict:
         return {}
 
+    def pre_transform_validation(self):
+        """
+        Validate the input data before transformation.
+        - Check if the specified column exists in the DataFrame.
+        - If no values are provided, check if the distinct values can be computed.
+        - Ensure the DataFrame is not empty.
+        """
+        if self.df is None or self.df.count() == 0:
+            raise ValueError("The DataFrame is empty.")
+
+        if self.column not in self.df.columns:
+            raise ValueError(f"Column '{self.column}' does not exist in the DataFrame.")
+
+        if not self.values:
+            distinct_values = [row[self.column] for row in self.df.select(self.column).distinct().collect()]
+            if not distinct_values:
+                raise ValueError(f"No distinct values found in column '{self.column}'.")
+            self.values = distinct_values
+
+    def post_transform_validation(self):
+        """
+        Validate the result after transformation.
+        - Ensure that new columns have been added based on the distinct values.
+        - Verify the transformed DataFrame contains the expected number of columns.
+        """
+        expected_columns = [f"{self.column}_{value if value is not None else 'None'}" for value in self.values]
+        missing_columns = [col for col in expected_columns if col not in self.df.columns]
+
+        if missing_columns:
+            raise ValueError(f"Missing columns in the transformed DataFrame: {missing_columns}")
+
+        if self.df.count() == 0:
+            raise ValueError("The transformed DataFrame is empty.")
+
-    def filter(self) -> PySparkDataFrame:
+    def transform(self) -> PySparkDataFrame:
         if not self.values:
             self.values = [
                 row[self.column]
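For orientation, a minimal usage sketch of the refactored transformer (not part of the commit). The constructor signature OneHotEncoding(df=..., column=..., values=...) is assumed from the self.df, self.column, and self.values attributes the diff references, and the explicit validate-transform-validate call order is an assumption about how a TransformerInterface implementation is driven:

    # Hypothetical usage sketch; the constructor signature and the call order
    # are assumptions, not confirmed by the commit.
    # OneHotEncoding is the class from the diff above; its import path depends
    # on the package layout and is omitted here.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([("a",), ("b",), ("a",)], ["category"])

    encoder = OneHotEncoding(df=df, column="category", values=None)
    encoder.pre_transform_validation()   # raises ValueError on empty/missing input; infers values = ["a", "b"]
    encoded = encoder.transform()        # expected to add category_a and category_b columns
    encoder.post_transform_validation()  # checks self.df, so this assumes transform() updates the instance's DataFrame
    encoded.show()

Note that post_transform_validation inspects self.df rather than a returned DataFrame, so callers relying on it implicitly assume transform() updates the instance's df attribute.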