reformatting
Signed-off-by: Christian Munz <[email protected]>
chris-1187 committed Nov 18, 2024
1 parent 0a9fe2f commit db7e33a
Showing 2 changed files with 102 additions and 51 deletions.
@@ -124,13 +124,14 @@ def libraries():
def settings() -> dict:
return {}


@staticmethod
def _impute_missing_values_sp(df) -> PySparkDataFrame:
"""
Imputes missing values by Spline Interpolation
"""
data = np.array(df.select("Value").rdd.flatMap(lambda x: x).collect(), dtype=float)
data = np.array(
df.select("Value").rdd.flatMap(lambda x: x).collect(), dtype=float
)
mask = np.isnan(data)

x_data = np.arange(len(data))
@@ -142,17 +143,18 @@ def _impute_missing_values_sp(df) -> PySparkDataFrame:
data_imputed[mask] = spline(x_data[mask])
data_imputed_list = data_imputed.tolist()

imputed_rdd = df.rdd.zipWithIndex().map(lambda row: Row(
TagName=row[0][0],
EventTime=row[0][1],
Status=row[0][2],
Value=float(data_imputed_list[row[1]])
))
imputed_rdd = df.rdd.zipWithIndex().map(
lambda row: Row(
TagName=row[0][0],
EventTime=row[0][1],
Status=row[0][2],
Value=float(data_imputed_list[row[1]]),
)
)
imputed_df = imputed_rdd.toDF(df.schema)

return imputed_df


@staticmethod
def _flag_missing_values(df, tolerance_percentage) -> PySparkDataFrame:
"""
@@ -162,22 +164,38 @@ def _flag_missing_values(df, tolerance_percentage) -> PySparkDataFrame:
window_spec = Window.partitionBy("TagName").orderBy("EventTime")

df = df.withColumn("prev_event_time", F.lag("EventTime").over(window_spec))
df = df.withColumn("time_diff_seconds",
(F.unix_timestamp("EventTime") - F.unix_timestamp("prev_event_time")))
df = df.withColumn(
"time_diff_seconds",
(F.unix_timestamp("EventTime") - F.unix_timestamp("prev_event_time")),
)

df_diff = df.filter(F.col("time_diff_seconds").isNotNull())
interval_counts = df_diff.groupBy("time_diff_seconds").count()
most_frequent_interval = interval_counts.orderBy(F.desc("count")).first()
expected_interval = most_frequent_interval["time_diff_seconds"] if most_frequent_interval else None
expected_interval = (
most_frequent_interval["time_diff_seconds"]
if most_frequent_interval
else None
)

tolerance = (expected_interval * tolerance_percentage) / 100 if expected_interval else 0
tolerance = (
(expected_interval * tolerance_percentage) / 100 if expected_interval else 0
)

existing_timestamps = df.select("TagName", "EventTime").rdd \
.map(lambda row: (row["TagName"], row["EventTime"])).groupByKey().collectAsMap()
existing_timestamps = (
df.select("TagName", "EventTime")
.rdd.map(lambda row: (row["TagName"], row["EventTime"]))
.groupByKey()
.collectAsMap()
)

def generate_missing_timestamps(prev_event_time, event_time, tag_name):
# Check for first row
if prev_event_time is None or event_time is None or expected_interval is None:
if (
prev_event_time is None
or event_time is None
or expected_interval is None
):
return []

# Check against existing timestamps to avoid duplicates
@@ -196,26 +214,30 @@ def generate_missing_timestamps(prev_event_time, event_time, tag_name):

return missing_timestamps

generate_missing_timestamps_udf = udf(generate_missing_timestamps, ArrayType(TimestampType()))
generate_missing_timestamps_udf = udf(
generate_missing_timestamps, ArrayType(TimestampType())
)

df_with_missing = df.withColumn(
"missing_timestamps",
generate_missing_timestamps_udf("prev_event_time", "EventTime", "TagName")
generate_missing_timestamps_udf("prev_event_time", "EventTime", "TagName"),
)

df_missing_entries = df_with_missing.select(
"TagName",
F.explode("missing_timestamps").alias("EventTime"),
F.lit("Good").alias("Status"),
F.lit(float('nan')).cast(FloatType()).alias("Value")
F.lit(float("nan")).cast(FloatType()).alias("Value"),
)

df_combined = df.select("TagName", "EventTime", "Status", "Value").union(df_missing_entries).orderBy(
"EventTime")
df_combined = (
df.select("TagName", "EventTime", "Status", "Value")
.union(df_missing_entries)
.orderBy("EventTime")
)

return df_combined


@staticmethod
def _is_column_type(df, column_name, data_type):
"""
@@ -225,12 +247,13 @@ def _is_column_type(df, column_name, data_type):

return isinstance(type_.dataType, data_type)


def filter(self) -> PySparkDataFrame:
"""
Impute missing values based on [Spline Interpolation, ]
"""
if not all(col_ in self.df.columns for col_ in ["TagName", "EventTime", "Value"]):
if not all(
col_ in self.df.columns for col_ in ["TagName", "EventTime", "Value"]
):
raise ValueError("Columns not as expected")

if not self._is_column_type(self.df, "EventTime", TimestampType):
@@ -240,8 +263,8 @@ def filter(self) -> PySparkDataFrame:
"EventTime",
F.coalesce(
F.to_timestamp("EventTime", "yyyy-MM-dd HH:mm:ss.SSS"),
F.to_timestamp("EventTime", "dd.MM.yyyy HH:mm:ss")
)
F.to_timestamp("EventTime", "dd.MM.yyyy HH:mm:ss"),
),
)
if not self._is_column_type(self.df, "Value", FloatType):
self.df = self.df.withColumn("Value", self.df["Value"].cast(FloatType()))
@@ -257,8 +280,9 @@ def filter(self) -> PySparkDataFrame:
# Impute the missing values of flagged entries
imputed_df_sp = self._impute_missing_values_sp(flagged_df)

imputed_df_sp = imputed_df_sp.withColumn("EventTime", col("EventTime").cast("string")) \
.withColumn("Value", col("Value").cast("string"))
imputed_df_sp = imputed_df_sp.withColumn(
"EventTime", col("EventTime").cast("string")
).withColumn("Value", col("Value").cast("string"))

imputed_dfs.append(imputed_df_sp)

@@ -268,13 +292,15 @@ def filter(self) -> PySparkDataFrame:

return result_df


def _split_by_source(self) -> dict:
"""
Helper method to separate individual time series based on their source
"""
tag_names = self.df.select("TagName").distinct().collect()
tag_names = [row["TagName"] for row in tag_names]
source_dict = {tag: self.df.filter(col("TagName") == tag).orderBy("EventTime") for tag in tag_names}
source_dict = {
tag: self.df.filter(col("TagName") == tag).orderBy("EventTime")
for tag in tag_names
}

return source_dict
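The _split_by_source helper shown above simply builds one ordered DataFrame per TagName so that flagging and imputation run per time series. A rough illustration of the same grouping with plain Python records (tag names and values are made up):

from collections import defaultdict

rows = [
    ("TAG_A", "2024-01-01 00:00:00.000", "Good", 1.0),
    ("TAG_B", "2024-01-01 00:00:00.000", "Good", 10.0),
    ("TAG_A", "2024-01-01 01:00:00.000", "Good", 2.0),
]

source_dict = defaultdict(list)
for tag_name, event_time, status, value in rows:
    source_dict[tag_name].append((event_time, status, value))

# Each tag's series is then flagged and imputed on its own, and the
# per-tag results are unioned back together, mirroring filter() above.
for tag_name, series in sorted(source_dict.items()):
    print(tag_name, sorted(series))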
(second changed file)
@@ -30,30 +30,37 @@ def spark_session():

def test_missing_value_imputation(spark_session: SparkSession):

schema = StructType([
StructField("TagName", StringType(), True),
StructField("EventTime", StringType(), True),
StructField("Status", StringType(), True),
StructField("Value", StringType(), True)
])
schema = StructType(
[
StructField("TagName", StringType(), True),
StructField("EventTime", StringType(), True),
StructField("Status", StringType(), True),
StructField("Value", StringType(), True),
]
)

test_data = [
("A2PS64V0J.:ZUX09R", "2024-01-01 03:29:21.000", "Good", "1.0"),
("A2PS64V0J.:ZUX09R", "2024-01-01 07:32:55.000", "Good", "2.0"),
("A2PS64V0J.:ZUX09R", "2024-01-01 11:36:29.000", "Good", "3.0"),
("A2PS64V0J.:ZUX09R", "2024-01-01 15:39:03.000", "Good", "4.0"),
("A2PS64V0J.:ZUX09R", "2024-01-01 19:42:37.000", "Good", "5.0"),
#("A2PS64V0J.:ZUX09R", "2024-01-01 23:46:11.000", "Good", "6.0"), # Test values
# ("A2PS64V0J.:ZUX09R", "2024-01-01 23:46:11.000", "Good", "6.0"), # Test values
("A2PS64V0J.:ZUX09R", "2024-01-02 03:49:45.000", "Good", "7.0"),
("A2PS64V0J.:ZUX09R", "2024-01-02 07:53:11.000", "Good", "8.0"),
("A2PS64V0J.:ZUX09R", "2024-01-02 11:56:42.000", "Good", "9.0"),
("A2PS64V0J.:ZUX09R", "2024-01-02 16:00:12.000", "Good", "10.0"),
("A2PS64V0J.:ZUX09R", "2024-01-02 20:13:46.000", "Good", "11.0"), # Tolerance Test
(
"A2PS64V0J.:ZUX09R",
"2024-01-02 20:13:46.000",
"Good",
"11.0",
), # Tolerance Test
("A2PS64V0J.:ZUX09R", "2024-01-03 00:07:20.000", "Good", "12.0"),
#("A2PS64V0J.:ZUX09R", "2024-01-03 04:10:54.000", "Good", "13.0"),
#("A2PS64V0J.:ZUX09R", "2024-01-03 08:14:28.000", "Good", "14.0"),
# ("A2PS64V0J.:ZUX09R", "2024-01-03 04:10:54.000", "Good", "13.0"),
# ("A2PS64V0J.:ZUX09R", "2024-01-03 08:14:28.000", "Good", "14.0"),
("A2PS64V0J.:ZUX09R", "2024-01-03 12:18:02.000", "Good", "15.0"),
#("A2PS64V0J.:ZUX09R", "2024-01-03 16:21:36.000", "Good", "16.0"),
# ("A2PS64V0J.:ZUX09R", "2024-01-03 16:21:36.000", "Good", "16.0"),
("A2PS64V0J.:ZUX09R", "2024-01-03 20:25:10.000", "Good", "17.0"),
("A2PS64V0J.:ZUX09R", "2024-01-04 00:28:44.000", "Good", "18.0"),
("A2PS64V0J.:ZUX09R", "2024-01-04 04:32:18.000", "Good", "19.0"),
@@ -262,7 +269,9 @@ def test_missing_value_imputation(spark_session: SparkSession):
assert expected_df.columns == actual_df.columns
assert expected_df.schema == actual_df.schema

def assert_dataframe_similar(expected_df, actual_df, tolerance=1e-4, time_tolerance_seconds=5):
def assert_dataframe_similar(
expected_df, actual_df, tolerance=1e-4, time_tolerance_seconds=5
):

expected_df = expected_df.orderBy(["TagName", "EventTime"])
actual_df = actual_df.orderBy(["TagName", "EventTime"])
@@ -271,21 +280,37 @@ def assert_dataframe_similar(expected_df, actual_df, tolerance=1e-4, time_tolera
actual_df = actual_df.withColumn("Value", col("Value").cast("float"))

for expected_row, actual_row in zip(expected_df.collect(), actual_df.collect()):
for expected_val, actual_val, column_name in zip(expected_row, actual_row, expected_df.columns):
for expected_val, actual_val, column_name in zip(
expected_row, actual_row, expected_df.columns
):
if column_name == "Value":
assert abs(expected_val - actual_val) < tolerance, f"Value mismatch: {expected_val} != {actual_val}"
assert (
abs(expected_val - actual_val) < tolerance
), f"Value mismatch: {expected_val} != {actual_val}"
elif column_name == "EventTime":
expected_event_time = unix_timestamp(col("EventTime")).cast("timestamp")
actual_event_time = unix_timestamp(col("EventTime")).cast("timestamp")
expected_event_time = unix_timestamp(col("EventTime")).cast(
"timestamp"
)
actual_event_time = unix_timestamp(col("EventTime")).cast(
"timestamp"
)

time_diff = A(expected_event_time.cast("long") - actual_event_time.cast("long"))
time_diff = A(
expected_event_time.cast("long")
- actual_event_time.cast("long")
)
condition = time_diff <= time_tolerance_seconds

mismatched_rows = expected_df.join(actual_df, on=["TagName", "EventTime"], how="inner") \
.filter(~condition)
mismatched_rows = expected_df.join(
actual_df, on=["TagName", "EventTime"], how="inner"
).filter(~condition)

assert mismatched_rows.count() == 0, f"EventTime mismatch: {expected_val} != {actual_val} (tolerance: {time_tolerance_seconds}s)"
assert (
mismatched_rows.count() == 0
), f"EventTime mismatch: {expected_val} != {actual_val} (tolerance: {time_tolerance_seconds}s)"
else:
assert expected_val == actual_val, f"Mismatch in column '{column_name}': {expected_val} != {actual_val}"
assert (
expected_val == actual_val
), f"Mismatch in column '{column_name}': {expected_val} != {actual_val}"

assert_dataframe_similar(expected_df, actual_df, tolerance=1e-4)
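The comparison helper accepts small numeric and timestamp deviations rather than requiring exact equality. A simplified, self-contained version of the timestamp check, assuming plain Python strings instead of Spark columns (names are illustrative):

from datetime import datetime

def times_match(expected, actual, tolerance_seconds=5):
    fmt = "%Y-%m-%d %H:%M:%S.%f"
    diff = datetime.strptime(expected, fmt) - datetime.strptime(actual, fmt)
    return abs(diff.total_seconds()) <= tolerance_seconds

assert times_match("2024-01-01 03:29:21.000", "2024-01-01 03:29:23.500")
assert not times_match("2024-01-01 03:29:21.000", "2024-01-01 03:29:31.000")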
