22
22
is_jax_namespace ,
23
23
is_numpy_namespace ,
24
24
is_pydata_sparse_namespace ,
25
+ is_torch_array ,
25
26
is_torch_namespace ,
26
27
to_device ,
27
28
)
@@ -62,18 +63,28 @@ def _check_ns_shape_dtype(
62
63
msg = f"namespaces do not match: { actual_xp } != f{ desired_xp } "
63
64
assert actual_xp == desired_xp , msg
64
65
65
- if check_shape :
66
- actual_shape = actual .shape
67
- desired_shape = desired .shape
68
- if is_dask_namespace (desired_xp ):
69
- # Dask uses nan instead of None for unknown shapes
70
- if any (math .isnan (i ) for i in cast (tuple [float , ...], actual_shape )):
71
- actual_shape = actual .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
72
- if any (math .isnan (i ) for i in cast (tuple [float , ...], desired_shape )):
73
- desired_shape = desired .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
66
+ # Dask uses nan instead of None for unknown shapes
67
+ actual_shape = cast (tuple [float , ...], actual .shape )
68
+ desired_shape = cast (tuple [float , ...], desired .shape )
69
+ assert None not in actual_shape # Requires explicit support
70
+ assert None not in desired_shape
71
+ if is_dask_namespace (desired_xp ):
72
+ if any (math .isnan (i ) for i in actual_shape ):
73
+ actual_shape = actual .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
74
+ if any (math .isnan (i ) for i in desired_shape ):
75
+ desired_shape = desired .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
74
76
77
+ if check_shape :
75
78
msg = f"shapes do not match: { actual_shape } != f{ desired_shape } "
76
79
assert actual_shape == desired_shape , msg
80
+ else :
81
+ # Ignore shape, but check flattened size. This is normally done by
82
+ # np.testing.assert_array_equal etc even when strict=False, but not for
83
+ # non-materializable arrays.
84
+ actual_size = math .prod (actual_shape ) # pyright: ignore[reportUnknownArgumentType]
85
+ desired_size = math .prod (desired_shape ) # pyright: ignore[reportUnknownArgumentType]
86
+ msg = f"sizes do not match: { actual_size } != f{ desired_size } "
87
+ assert actual_size == desired_size , msg
77
88
78
89
if check_dtype :
79
90
msg = f"dtypes do not match: { actual .dtype } != { desired .dtype } "
@@ -90,6 +101,17 @@ def _check_ns_shape_dtype(
90
101
return desired_xp
91
102
92
103
104
+ def _is_materializable (x : Array ) -> bool :
105
+ """
106
+ Check if the array is materializable, e.g. `as_numpy_array` can be called on it
107
+ and one can assume that `__dlpack__` will succeed (if implemented, and given a
108
+ compatible device).
109
+ """
110
+ # Important: here we assume that we're not tracing -
111
+ # e.g. we're not inside `jax.jit`` nor `cupy.cuda.Stream.begin_capture`.
112
+ return not is_torch_array (x ) or x .device .type != "meta" # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
113
+
114
+
93
115
def as_numpy_array (array : Array , * , xp : ModuleType ) -> np .typing .NDArray [Any ]: # type: ignore[explicit-any]
94
116
"""
95
117
Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.
@@ -100,11 +122,7 @@ def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]:
100
122
return array .todense () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
101
123
102
124
if is_torch_namespace (xp ):
103
- if array .device .type == "meta" : # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
104
- # Can't materialize; generate dummy data instead
105
- array = xp .zeros_like (array , device = "cpu" )
106
- else :
107
- array = to_device (array , "cpu" )
125
+ array = to_device (array , "cpu" )
108
126
if is_array_api_strict_namespace (xp ):
109
127
cpu : Device = xp .Device ("CPU_DEVICE" )
110
128
array = to_device (array , cpu )
@@ -150,6 +168,8 @@ def xp_assert_equal(
150
168
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
151
169
"""
152
170
xp = _check_ns_shape_dtype (actual , desired , check_dtype , check_shape , check_scalar )
171
+ if not _is_materializable (actual ):
172
+ return
153
173
actual_np = as_numpy_array (actual , xp = xp )
154
174
desired_np = as_numpy_array (desired , xp = xp )
155
175
np .testing .assert_array_equal (actual_np , desired_np , err_msg = err_msg )
@@ -185,6 +205,8 @@ def xp_assert_less(
185
205
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
186
206
"""
187
207
xp = _check_ns_shape_dtype (x , y , check_dtype , check_shape , check_scalar )
208
+ if not _is_materializable (x ):
209
+ return
188
210
x_np = as_numpy_array (x , xp = xp )
189
211
y_np = as_numpy_array (y , xp = xp )
190
212
np .testing .assert_array_less (x_np , y_np , err_msg = err_msg )
@@ -233,6 +255,8 @@ def xp_assert_close(
233
255
The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
234
256
"""
235
257
xp = _check_ns_shape_dtype (actual , desired , check_dtype , check_shape , check_scalar )
258
+ if not _is_materializable (actual ):
259
+ return
236
260
237
261
if rtol is None :
238
262
if xp .isdtype (actual .dtype , ("real floating" , "complex floating" )):
0 commit comments