Skip to content

Commit 4ddbf91

Browse files
committed
Merge remote-tracking branch 'origin/fix_normed_dtype' into develop
2 parents cf2abc6 + 9962e78 commit 4ddbf91

File tree

3 files changed

+19
-7
lines changed

3 files changed

+19
-7
lines changed

src/sed/binning/binning.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,7 @@ def normalization_histogram_from_timed_dataframe(
465465
axis: str,
466466
bin_centers: np.ndarray,
467467
time_unit: float,
468+
**kwds,
468469
) -> xr.DataArray:
469470
"""Get a normalization histogram from a timed dataframe.
470471
@@ -475,17 +476,12 @@ def normalization_histogram_from_timed_dataframe(
475476
histogram.
476477
bin_centers (np.ndarray): Bin centers used for binning of the axis.
477478
time_unit (float): Time unit the data frame entries are based on.
479+
**kwds: Additional keyword arguments passed to the bin_dataframe function.
478480
479481
Returns:
480482
xr.DataArray: Calculated normalization histogram.
481483
"""
482-
bins = df[axis].map_partitions(
483-
pd.cut,
484-
bins=bin_centers_to_bin_edges(bin_centers),
485-
)
486-
487-
histogram = df[axis].groupby([bins]).count().compute().values * time_unit
488-
# histogram = bin_dataframe(df, axes=[axis], bins=[bin_centers]) * time_unit
484+
histogram = bin_dataframe(df, axes=[axis], bins=[bin_centers], **kwds) * time_unit
489485

490486
data_array = xr.DataArray(
491487
data=histogram,

src/sed/core/processor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2283,6 +2283,8 @@ def compute(
22832283
)
22842284
# if the axes are named correctly, xarray figures out the normalization correctly
22852285
self._normalized = self._binned / self._normalization_histogram
2286+
# Cast the normalized data to the dtype of the binned data
2287+
self._normalized.data = self._normalized.data.astype(self._binned.data.dtype)
22862288
self._attributes.add(
22872289
self._normalization_histogram.values,
22882290
name="normalization_histogram",
@@ -2375,13 +2377,25 @@ def get_normalization_histogram(
23752377
axis,
23762378
self._binned.coords[axis].values,
23772379
self._config["dataframe"]["timed_dataframe_unit_time"],
2380+
hist_mode=self.config["binning"]["hist_mode"],
2381+
mode=self.config["binning"]["mode"],
2382+
pbar=self.config["binning"]["pbar"],
2383+
n_cores=self.config["core"]["num_cores"],
2384+
threads_per_worker=self.config["binning"]["threads_per_worker"],
2385+
threadpool_api=self.config["binning"]["threadpool_API"],
23782386
)
23792387
else:
23802388
self._normalization_histogram = normalization_histogram_from_timed_dataframe(
23812389
self._timed_dataframe,
23822390
axis,
23832391
self._binned.coords[axis].values,
23842392
self._config["dataframe"]["timed_dataframe_unit_time"],
2393+
hist_mode=self.config["binning"]["hist_mode"],
2394+
mode=self.config["binning"]["mode"],
2395+
pbar=self.config["binning"]["pbar"],
2396+
n_cores=self.config["core"]["num_cores"],
2397+
threads_per_worker=self.config["binning"]["threads_per_worker"],
2398+
threadpool_api=self.config["binning"]["threadpool_API"],
23852399
)
23862400

23872401
return self._normalization_histogram

tests/test_processor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,8 @@ def test_compute_with_normalization() -> None:
10081008
processor.binned.data,
10091009
(processor.normalized * processor.normalization_histogram).data,
10101010
)
1011+
# check dtype
1012+
assert processor.normalized.dtype == processor.binned.dtype
10111013
# bin only second dataframe partition
10121014
result2 = processor.compute(
10131015
bins=bins,

0 commit comments

Comments
 (0)