Skip to content

Commit 92ba767

Browse files
committed
Allow make_broadcastable to preserve old function call behaviour
1 parent a939e29 commit 92ba767

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

movement/utils/broadcasting.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,9 @@ def make_broadcastable( # noqa: C901
174174
where the ``self`` argument is present only if ``f`` was a class method.
175175
``fr`` applies ``f`` along the ``broadcast_dimension`` of ``data``.
176176
The ``\*args`` and ``\*\*kwargs`` match those passed to ``f``, and retain
177-
the same interpretations and effects on the result.
177+
the same interpretations and effects on the result. If ``data`` provided to
178+
``fr`` is not an ``xarray.DataArray``, it will fall back on the behaviour
179+
of ``f`` (and ignore the ``broadcast_dimension`` argument).
178180
179181
See the docstring of ``make_broadcastable_inner`` in the source code for a
180182
more explicit explanation of the returned decorator.
@@ -268,6 +270,9 @@ def inner_clsmethod( # type: ignore[valid-type]
268270
broadcast_dimension: str = "space",
269271
**kwargs: KeywordArgs.kwargs,
270272
) -> xr.DataArray:
273+
# Preserve original functionality
274+
if not isinstance(data, xr.DataArray):
275+
return f(self, data, *args, **kwargs)
271276
return apply_along_da_axis(
272277
lambda input_1D: f(self, input_1D, *args, **kwargs),
273278
data,
@@ -297,6 +302,9 @@ def inner( # type: ignore[valid-type]
297302
broadcast_dimension: str = "space",
298303
**kwargs: KeywordArgs.kwargs,
299304
) -> xr.DataArray:
305+
# Preserve original functionality
306+
if not isinstance(data, xr.DataArray):
307+
return f(data, *args, **kwargs)
300308
return apply_along_da_axis(
301309
lambda input_1D: f(input_1D, *args, **kwargs),
302310
data,

tests/test_unit/test_make_broadcastable.py

+22
Original file line numberDiff line numberDiff line change
@@ -241,3 +241,25 @@ def two_to_some(xy_pair) -> np.ndarray:
241241
else:
242242
assert d in mock_dataset.dims
243243
assert len(output[d]) == len(mock_dataset[d])
244+
245+
246+
def test_retain_underlying_function() -> None:
247+
value_for_arg = 5.0
248+
value_for_kwarg = 7.0
249+
value_for_simple_input = [0.0, 1.0, 2.0]
250+
251+
def simple_function(input_1D, arg, kwarg=3.0):
252+
return arg * sum(input_1D) + kwarg
253+
254+
@make_broadcastable()
255+
def simple_function_broadcastable(input_1D, arg, kwarg=3.0):
256+
return simple_function(input_1D, arg, kwarg=kwarg)
257+
258+
result_from_broadcastable = simple_function_broadcastable(
259+
value_for_simple_input, value_for_arg, kwarg=value_for_kwarg
260+
)
261+
result_from_original = simple_function(
262+
value_for_simple_input, value_for_arg, kwarg=value_for_kwarg
263+
)
264+
assert isinstance(result_from_broadcastable, float)
265+
assert np.isclose(result_from_broadcastable, result_from_original)

0 commit comments

Comments
 (0)