Skip to content

Commit 5e52adf

Browse files
committed
Ignore mypy typehint bugs (it can't cope with additional kwargs being passed)
1 parent 1e7afa1 commit 5e52adf

File tree

2 files changed

+25
-16
lines changed

2 files changed

+25
-16
lines changed

movement/utils/broadcasting.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,16 @@ def make_broadcastable(
4141
4242
"""
4343

44-
def f_along_axis(
45-
data: ArrayLike, axis: int, *f_args, **f_kwargs
46-
) -> np.ndarray:
47-
return np.apply_along_axis(f, axis, data, *f_args, **f_kwargs)
48-
4944
def inner(
5045
data: DataArray,
5146
*args: KeywordArgs.args,
5247
broadcast_dimension: str = "space",
5348
**kwargs: KeywordArgs.kwargs,
5449
) -> DataArray:
55-
return data.reduce(f_along_axis, broadcast_dimension, *args, **kwargs)
50+
def f_along_axis(data: ArrayLike, axis: int) -> np.ndarray:
51+
return np.apply_along_axis(f, axis, data, *args, **kwargs)
52+
53+
return data.reduce(f_along_axis, broadcast_dimension)
5654

5755
return inner
5856

tests/test_unit/test_make_broadcastable.py

+21-10
Original file line numberDiff line numberDiff line change
@@ -31,54 +31,61 @@ def mock_dataset() -> xr.DataArray:
3131

3232

3333
@pytest.mark.parametrize(
34-
["along_dimension", "expected_output", "mimic_fn", "fn_kwargs"],
34+
["along_dimension", "expected_output", "mimic_fn", "fn_args", "fn_kwargs"],
3535
[
3636
pytest.param(
3737
"space",
38-
data_in_shape(mock_shape()).sum(axis=1),
39-
sum,
38+
np.zeros(mock_shape()).sum(axis=1),
39+
lambda x: 0.0,
40+
tuple(),
4041
{},
41-
id="Mimic sum",
42+
id="Zero everything",
4243
),
4344
pytest.param(
4445
"space",
45-
np.zeros(mock_shape()).sum(axis=1),
46-
lambda x: 0.0,
46+
data_in_shape(mock_shape()).sum(axis=1),
47+
sum,
48+
tuple(),
4749
{},
48-
id="Zero everything",
50+
id="Mimic sum",
4951
),
5052
pytest.param(
5153
"time",
5254
data_in_shape(mock_shape()).prod(axis=0),
5355
np.prod,
56+
tuple(),
5457
{},
5558
id="Mimic prod, on non-space dimensions",
5659
),
5760
pytest.param(
5861
"space",
5962
5.0 * data_in_shape(mock_shape()).sum(axis=1),
6063
lambda x, **kwargs: kwargs.get("multiplier", 1.0) * sum(x),
64+
tuple(),
6165
{"multiplier": 5.0},
6266
id="Preserve kwargs",
6367
),
6468
pytest.param(
6569
"space",
6670
data_in_shape(mock_shape()).sum(axis=1),
6771
lambda x, **kwargs: kwargs.get("multiplier", 1.0) * sum(x),
72+
tuple(),
6873
{},
6974
id="Preserve kwargs [fall back on default]",
7075
),
7176
pytest.param(
7277
"space",
7378
5.0 * data_in_shape(mock_shape()).sum(axis=1),
7479
lambda x, multiplier=1.0: multiplier * sum(x),
75-
{"multiplier": 5.0},
80+
(5,),
81+
{},
7682
id="Preserve args",
7783
),
7884
pytest.param(
7985
"space",
8086
data_in_shape(mock_shape()).sum(axis=1),
8187
lambda x, multiplier=1.0: multiplier * sum(x),
88+
tuple(),
8289
{},
8390
id="Preserve args [fall back on default]",
8491
),
@@ -89,6 +96,7 @@ def test_make_broadcastable(
8996
along_dimension: str,
9097
expected_output: xr.DataArray,
9198
mimic_fn: Callable[Concatenate[Any, KeywordArgs], Scalar],
99+
fn_args: Any,
92100
fn_kwargs: Any,
93101
) -> None:
94102
if isinstance(expected_output, np.ndarray):
@@ -103,7 +111,10 @@ def test_make_broadcastable(
103111
decorated_fn = make_broadcastable(mimic_fn)
104112

105113
decorated_output = decorated_fn(
106-
mock_dataset, broadcast_dimension=along_dimension, **fn_kwargs
114+
mock_dataset,
115+
*fn_args,
116+
broadcast_dimension=along_dimension, # type: ignore
117+
**fn_kwargs,
107118
)
108119

109120
assert decorated_output.shape == expected_output.shape
@@ -113,7 +124,7 @@ def test_make_broadcastable(
113124
if along_dimension == "space":
114125
decorated_fn_space_only = make_broadcastable_over_space(mimic_fn)
115126
decorated_output_space = decorated_fn_space_only(
116-
mock_dataset, **fn_kwargs
127+
mock_dataset, *fn_args, **fn_kwargs
117128
)
118129

119130
assert decorated_output_space.shape == expected_output.shape

0 commit comments

Comments
 (0)