Skip to content

Commit

Permalink
Merge pull request #546 from OpenCOMPES/flash_normalization_fixes-zain
Browse files Browse the repository at this point in the history
allow both timed dataframe formats
  • Loading branch information
zain-sohail authored Jan 10, 2025
2 parents 8d9bcd5 + 87fa1fa commit b3e79c3
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
41 changes: 34 additions & 7 deletions src/sed/loader/flash/buffer_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def __init__(
extend_aux=True,
)
self.metadata: dict = {}
self.filter_timed_by_electron: bool = None

def _schema_check(self, files: list[Path], expected_schema_set: set) -> None:
"""
Expand Down Expand Up @@ -159,6 +160,30 @@ def _schema_check(self, files: list[Path], expected_schema_set: set) -> None:
"Please check the configuration file or set force_recreate to True.",
)

def _create_timed_dataframe(self, df: dd.DataFrame) -> dd.DataFrame:
    """Build the timed dataframe from the combined dataframe.

    The result contains only the columns listed in ``self.fill_channels``.
    When ``self.filter_timed_by_electron`` is truthy, rows with a missing
    value in any per-electron channel are dropped before the columns are
    selected; otherwise every row is kept.

    Args:
        df (dd.DataFrame): The input dataframe containing all data.

    Returns:
        dd.DataFrame: The timed dataframe, restricted to rows whose third
        index level equals 0.
    """
    # Columns that make up the timed dataframe.
    columns = self.fill_channels

    source = df
    if self.filter_timed_by_electron:
        # Keep only rows for which the per-electron channels hold data.
        per_electron = get_channels(self._config, "per_electron")
        source = df.dropna(subset=per_electron)

    selected = source[columns]

    # Select index value 0 on the third level, i.e. one row per event.
    # NOTE(review): presumably the third index level is the electronId -- confirm.
    return selected.loc[:, :, 0]

def _save_buffer_file(self, paths: dict[str, Path]) -> None:
"""
Creates the electron and timed buffer files from the raw H5 file.
Expand All @@ -170,26 +195,24 @@ def _save_buffer_file(self, paths: dict[str, Path]) -> None:
Args:
paths (dict[str, Path]): Dictionary containing the paths to the H5 and buffer files.
"""

# Create a DataFrameCreator instance and the h5 file
# Create a DataFrameCreator instance and get the h5 file
df = DataFrameCreator(config_dataframe=self._config, h5_path=paths["raw"]).df

# forward fill all the non-electron channels
df[self.fill_channels] = df[self.fill_channels].ffill()

# Reset the index of the DataFrame and save both the electron and timed dataframes
# electron resolved dataframe
# Save electron resolved dataframe
electron_channels = get_channels(self._config, "per_electron")
dtypes = get_dtypes(self._config, df.columns.values)
df.dropna(subset=electron_channels).astype(dtypes).reset_index().to_parquet(
paths["electron"],
)

# timed dataframe
# drop the electron channels and only take rows with the first electronId
df_timed = df.dropna(subset=electron_channels)[self.fill_channels].loc[:, :, 0]
# Create and save timed dataframe
df_timed = self._create_timed_dataframe(df)
dtypes = get_dtypes(self._config, df_timed.columns.values)
df_timed.astype(dtypes).reset_index().to_parquet(paths["timed"])

logger.debug(f"Processed {paths['raw'].stem}")

def _save_buffer_files(self, force_recreate: bool, debug: bool) -> None:
Expand Down Expand Up @@ -272,6 +295,7 @@ def process_and_load_dataframe(
suffix: str = "",
debug: bool = False,
remove_invalid_files: bool = False,
filter_timed_by_electron: bool = True,
) -> tuple[dd.DataFrame, dd.DataFrame]:
"""
Runs the buffer file creation process.
Expand All @@ -284,11 +308,14 @@ def process_and_load_dataframe(
force_recreate (bool): Flag to force recreation of buffer files.
suffix (str): Suffix for buffer file names.
debug (bool): Flag to enable debug mode.
remove_invalid_files (bool): Flag to remove invalid files.
filter_timed_by_electron (bool): Flag to filter timed data by valid electron events.
Returns:
Tuple[dd.DataFrame, dd.DataFrame]: The electron and timed dataframes.
"""
self.fp = BufferFilePaths(self._config, h5_paths, folder, suffix, remove_invalid_files)
self.filter_timed_by_electron = filter_timed_by_electron

if not force_recreate:
schema_set = set(
Expand Down
5 changes: 5 additions & 0 deletions src/sed/loader/flash/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def read_dataframe(
ftype: str = "h5",
metadata: dict = {},
collect_metadata: bool = False,
filter_timed_by_electron: bool = True,
**kwds,
) -> tuple[dd.DataFrame, dd.DataFrame, dict]:
"""
Expand All @@ -319,6 +320,9 @@ def read_dataframe(
ftype (str, optional): The file extension type. Defaults to "h5".
metadata (dict, optional): Additional metadata. Defaults to an empty dict.
collect_metadata (bool, optional): Whether to collect metadata. Defaults to False.
filter_timed_by_electron (bool, optional): When True, the timed dataframe will only
contain data points where valid electron events were detected. When False, all
timed data points are included regardless of electron detection. Defaults to True.
Keyword Args:
detector (str, optional): The detector to use. Defaults to "".
Expand Down Expand Up @@ -391,6 +395,7 @@ def read_dataframe(
suffix=detector,
debug=debug,
remove_invalid_files=remove_invalid_files,
filter_timed_by_electron=filter_timed_by_electron,
)

if self.instrument == "wespe":
Expand Down

0 comments on commit b3e79c3

Please sign in to comment.