Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/fix_normed_dtype' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
rettigl committed Feb 2, 2025
2 parents cf2abc6 + 9962e78 commit 4ddbf91
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
10 changes: 3 additions & 7 deletions src/sed/binning/binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ def normalization_histogram_from_timed_dataframe(
axis: str,
bin_centers: np.ndarray,
time_unit: float,
**kwds,
) -> xr.DataArray:
"""Get a normalization histogram from a timed dataframe.
Expand All @@ -475,17 +476,12 @@ def normalization_histogram_from_timed_dataframe(
histogram.
bin_centers (np.ndarray): Bin centers used for binning of the axis.
time_unit (float): Time unit the data frame entries are based on.
**kwds: Additional keyword arguments passed to the bin_dataframe function.
Returns:
xr.DataArray: Calculated normalization histogram.
"""
bins = df[axis].map_partitions(
pd.cut,
bins=bin_centers_to_bin_edges(bin_centers),
)

histogram = df[axis].groupby([bins]).count().compute().values * time_unit
# histogram = bin_dataframe(df, axes=[axis], bins=[bin_centers]) * time_unit
histogram = bin_dataframe(df, axes=[axis], bins=[bin_centers], **kwds) * time_unit

data_array = xr.DataArray(
data=histogram,
Expand Down
14 changes: 14 additions & 0 deletions src/sed/core/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2283,6 +2283,8 @@ def compute(
)
# if the axes are named correctly, xarray figures out the normalization correctly
self._normalized = self._binned / self._normalization_histogram
# Cast the normalized data back to the datatype of the binned data
self._normalized.data = self._normalized.data.astype(self._binned.data.dtype)
self._attributes.add(
self._normalization_histogram.values,
name="normalization_histogram",
Expand Down Expand Up @@ -2375,13 +2377,25 @@ def get_normalization_histogram(
axis,
self._binned.coords[axis].values,
self._config["dataframe"]["timed_dataframe_unit_time"],
hist_mode=self.config["binning"]["hist_mode"],
mode=self.config["binning"]["mode"],
pbar=self.config["binning"]["pbar"],
n_cores=self.config["core"]["num_cores"],
threads_per_worker=self.config["binning"]["threads_per_worker"],
threadpool_api=self.config["binning"]["threadpool_API"],
)
else:
self._normalization_histogram = normalization_histogram_from_timed_dataframe(
self._timed_dataframe,
axis,
self._binned.coords[axis].values,
self._config["dataframe"]["timed_dataframe_unit_time"],
hist_mode=self.config["binning"]["hist_mode"],
mode=self.config["binning"]["mode"],
pbar=self.config["binning"]["pbar"],
n_cores=self.config["core"]["num_cores"],
threads_per_worker=self.config["binning"]["threads_per_worker"],
threadpool_api=self.config["binning"]["threadpool_API"],
)

return self._normalization_histogram
Expand Down
2 changes: 2 additions & 0 deletions tests/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,8 @@ def test_compute_with_normalization() -> None:
processor.binned.data,
(processor.normalized * processor.normalization_histogram).data,
)
# check dtype
assert processor.normalized.dtype == processor.binned.dtype
# bin only second dataframe partition
result2 = processor.compute(
bins=bins,
Expand Down

0 comments on commit 4ddbf91

Please sign in to comment.