@@ -31,54 +31,61 @@ def mock_dataset() -> xr.DataArray:
31
31
32
32
33
33
@pytest .mark .parametrize (
34
- ["along_dimension" , "expected_output" , "mimic_fn" , "fn_kwargs" ],
34
+ ["along_dimension" , "expected_output" , "mimic_fn" , "fn_args" , " fn_kwargs" ],
35
35
[
36
36
pytest .param (
37
37
"space" ,
38
- data_in_shape (mock_shape ()).sum (axis = 1 ),
39
- sum ,
38
+ np .zeros (mock_shape ()).sum (axis = 1 ),
39
+ lambda x : 0.0 ,
40
+ tuple (),
40
41
{},
41
- id = "Mimic sum " ,
42
+ id = "Zero everything " ,
42
43
),
43
44
pytest .param (
44
45
"space" ,
45
- np .zeros (mock_shape ()).sum (axis = 1 ),
46
- lambda x : 0.0 ,
46
+ data_in_shape (mock_shape ()).sum (axis = 1 ),
47
+ sum ,
48
+ tuple (),
47
49
{},
48
- id = "Zero everything " ,
50
+ id = "Mimic sum " ,
49
51
),
50
52
pytest .param (
51
53
"time" ,
52
54
data_in_shape (mock_shape ()).prod (axis = 0 ),
53
55
np .prod ,
56
+ tuple (),
54
57
{},
55
58
id = "Mimic prod, on non-space dimensions" ,
56
59
),
57
60
pytest .param (
58
61
"space" ,
59
62
5.0 * data_in_shape (mock_shape ()).sum (axis = 1 ),
60
63
lambda x , ** kwargs : kwargs .get ("multiplier" , 1.0 ) * sum (x ),
64
+ tuple (),
61
65
{"multiplier" : 5.0 },
62
66
id = "Preserve kwargs" ,
63
67
),
64
68
pytest .param (
65
69
"space" ,
66
70
data_in_shape (mock_shape ()).sum (axis = 1 ),
67
71
lambda x , ** kwargs : kwargs .get ("multiplier" , 1.0 ) * sum (x ),
72
+ tuple (),
68
73
{},
69
74
id = "Preserve kwargs [fall back on default]" ,
70
75
),
71
76
pytest .param (
72
77
"space" ,
73
78
5.0 * data_in_shape (mock_shape ()).sum (axis = 1 ),
74
79
lambda x , multiplier = 1.0 : multiplier * sum (x ),
75
- {"multiplier" : 5.0 },
80
+ (5 ,),
81
+ {},
76
82
id = "Preserve args" ,
77
83
),
78
84
pytest .param (
79
85
"space" ,
80
86
data_in_shape (mock_shape ()).sum (axis = 1 ),
81
87
lambda x , multiplier = 1.0 : multiplier * sum (x ),
88
+ tuple (),
82
89
{},
83
90
id = "Preserve args [fall back on default]" ,
84
91
),
@@ -89,6 +96,7 @@ def test_make_broadcastable(
89
96
along_dimension : str ,
90
97
expected_output : xr .DataArray ,
91
98
mimic_fn : Callable [Concatenate [Any , KeywordArgs ], Scalar ],
99
+ fn_args : Any ,
92
100
fn_kwargs : Any ,
93
101
) -> None :
94
102
if isinstance (expected_output , np .ndarray ):
@@ -103,7 +111,10 @@ def test_make_broadcastable(
103
111
decorated_fn = make_broadcastable (mimic_fn )
104
112
105
113
decorated_output = decorated_fn (
106
- mock_dataset , broadcast_dimension = along_dimension , ** fn_kwargs
114
+ mock_dataset ,
115
+ * fn_args ,
116
+ broadcast_dimension = along_dimension , # type: ignore
117
+ ** fn_kwargs ,
107
118
)
108
119
109
120
assert decorated_output .shape == expected_output .shape
@@ -113,7 +124,7 @@ def test_make_broadcastable(
113
124
if along_dimension == "space" :
114
125
decorated_fn_space_only = make_broadcastable_over_space (mimic_fn )
115
126
decorated_output_space = decorated_fn_space_only (
116
- mock_dataset , ** fn_kwargs
127
+ mock_dataset , * fn_args , * *fn_kwargs
117
128
)
118
129
119
130
assert decorated_output_space .shape == expected_output .shape
0 commit comments