Skip to content
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/).

## [Unreleased]

### Changed
- `Data.norm_for_each`: allow for normlization for each combination of a set of variables

### Fixed
- fixed bug where numpy>2.4 broke `data.join`

Expand Down
21 changes: 12 additions & 9 deletions WrightTools/data/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ def get_axis(self, hint: str | int | Axis) -> Axis:

def norm_for_each(
self,
var: str | Variable | int,
*vars: str | Variable | int,
channel: str | Channel | int = 0,
new_channel: dict = {},
):
Expand All @@ -755,31 +755,31 @@ def norm_for_each(

Parameters
----------
var : str, int, or WrightTools.data.Variable
variable to apply normalization at each unique point.
*vars : str, int, or WrightTools.data.Variable
variable to apply normalization at each unique point. if many variables are given, normalization occurs across the joint pairs of values
channel : str, int or WrightTools.data.Channel (default 0)
channel to apply normalization. Channel should have more non-trivial dimensions than variable
new_channel : dict
Default is empty, and channel is overwriten with norm values.
If not empty, a new channel will be created.
Fields (e.g. name) can be supplied by supplying a dictionary (consult `Data.create_channel`).


Examples
--------
import WrightTools.datasets as ds
import WrightTools as wt

d = wt.open(ds.wt5.v1p0p1_MoS2_TrEE_movie)
d.norm_for_each("d2", "ai0") # equivalent to d.ai0[:] /= d.ai0[:].max(axis=(0,1))[None, None, :]
d.norm_for_each("d2", channel="ai0") # equivalent to d.ai0[:] /= d.ai0[:].max(axis=(0,1))[None, None, :]

"""
variable = self.get_var(var)
variables = [self.get_var(var) for var in vars]
channel = self.get_channel(channel)
trivial = {i for i, si in enumerate(variable.shape) if si == 1}
joint_shape = [max([v.shape[i] for v in variables]) for i in range(self.ndim)]
trivial = {i for i, si in enumerate(joint_shape) if si == 1}
if not trivial:
raise wt_exceptions.WrightToolsWarning(
f"Variable {variable.natural_name} and Channel {channel.natural_name} have the same shape {variable.shape}. "
f"variable(s) {[var.natural_name for var in variables]} and Channel {channel.natural_name} have the same shape {channel.shape}. "
+ "Produces a ones array channel."
)
# nontrivial = tuple({i for i in range(self.ndim)} - trivial)
Expand All @@ -790,7 +790,10 @@ def norm_for_each(
if not isinstance(new_channel, dict):
new_channel = {}
self.create_channel(
new_channel.pop("name", f"{channel.natural_name}_{variable.natural_name}_norm"),
new_channel.pop(
"name",
f"{channel.natural_name}_{''.join([f'v{self.variable_names.index(v.natural_name)}' for v in variables])}_norm",
),
values=new,
**new_channel,
)
Expand Down
18 changes: 13 additions & 5 deletions tests/data/norm_for_each.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,26 @@
def test_3D():
data = wt.open(datasets.wt5.v1p0p1_MoS2_TrEE_movie)

data.norm_for_each("w1", 0)
data.norm_for_each("w1", channel=0)
assert np.all(data.channels[0][:].max(axis=(0, 2)) == 1)

data.norm_for_each("d2", 0, new_channel={"name": "ai0_d2_norm"})
data.norm_for_each("d2", new_channel={"name": "ai0_d2_norm"})
assert np.all(data.channels[-1][:].max(axis=(0, 1)) == 1)

data.norm_for_each("d1", 0, new_channel=True)
data.norm_for_each("d1", new_channel=True)
assert data.channels[-1].natural_name == "ai0_v6_norm"

data.norm_for_each("d1", 0, new_channel={"name": "ai0_d1_norm1"})
data.norm_for_each("d1", new_channel={"name": "ai0_d1_norm"})
data.channels[0].normalize()
assert np.all(np.isclose(data.ai0_d1_norm[:], data.channels[0][:]))
assert np.all(np.isclose(data.channels[-1][:], data.channels[0][:]))


def test_two_vars():
data = wt.open(datasets.wt5.v1p0p1_MoS2_TrEE_movie)
data.norm_for_each("w1", "d2")
assert np.all(data.channels[0][:].max(axis=0) == 1)


if __name__ == "__main__":
test_3D()
test_two_vars()