Skip to content

Commit a47ee75

Browse files
committed
add explicit shape comparison
1 parent bbe197a commit a47ee75

19 files changed

+169
-88
lines changed

dpnp/dpnp_iface_linearalgebra.py

+25-11
Original file line numberDiff line numberDiff line change
@@ -69,16 +69,31 @@
6969

7070

7171
# TODO: implement a specific scalar-array kernel
72-
def _call_multiply(a, b, out=None):
73-
"""Call multiply function for special cases of scalar-array dots."""
72+
def _call_multiply(a, b, out=None, outer_calc=False):
73+
"""
74+
Call multiply function for special cases of scalar-array dots.
75+
76+
if `sc` is an scalar and `a` is an array of type float32, we have
77+
dpnp.multiply(a, sc).dtype == dpnp.float32 and
78+
numpy.multiply(a, sc).dtype == dpnp.float32.
79+
80+
However, for scalar-array dots such as dot function we have
81+
dpnp.dot(a, sc).dtype == dpnp.float32 while
82+
numpy.dot(a, sc).dtype == dpnp.float64.
83+
84+
We need to adjust the behavior of the multiply function when it is
85+
being used for special cases of scalar-array dots.
86+
87+
"""
7488

7589
sc, arr = (a, b) if dpnp.isscalar(a) else (b, a)
7690
sc_dtype = map_dtype_to_device(type(sc), arr.sycl_device)
7791
res_dtype = dpnp.result_type(sc_dtype, arr)
92+
multiply_func = dpnp.multiply.outer if outer_calc else dpnp.multiply
7893
if out is not None and out.dtype == arr.dtype:
79-
res = dpnp.multiply(a, b, out=out)
94+
res = multiply_func(a, b, out=out)
8095
else:
81-
res = dpnp.multiply(a, b, dtype=res_dtype)
96+
res = multiply_func(a, b, dtype=res_dtype)
8297
return dpnp.get_result_array(res, out, casting="no")
8398

8499

@@ -1109,16 +1124,15 @@ def outer(a, b, out=None):
11091124

11101125
dpnp.check_supported_arrays_type(a, b, scalar_type=True, all_scalars=False)
11111126
if dpnp.isscalar(a):
1112-
x1 = a
11131127
x2 = dpnp.ravel(b)[None, :]
1128+
result = _call_multiply(a, x2, out=out, outer_calc=True)
11141129
elif dpnp.isscalar(b):
11151130
x1 = dpnp.ravel(a)[:, None]
1116-
x2 = b
1131+
result = _call_multiply(x1, b, out=out, outer_calc=True)
11171132
else:
1118-
x1 = dpnp.ravel(a)
1119-
x2 = dpnp.ravel(b)
1133+
result = dpnp.multiply.outer(dpnp.ravel(a), dpnp.ravel(b), out=out)
11201134

1121-
return dpnp.multiply.outer(x1, x2, out=out)
1135+
return result
11221136

11231137

11241138
def tensordot(a, b, axes=2):
@@ -1288,13 +1302,13 @@ def vdot(a, b):
12881302
if b.size != 1:
12891303
raise ValueError("The second array should be of size one.")
12901304
a_conj = numpy.conj(a)
1291-
return _call_multiply(a_conj, b)
1305+
return dpnp.squeeze(_call_multiply(a_conj, b))
12921306

12931307
if dpnp.isscalar(b):
12941308
if a.size != 1:
12951309
raise ValueError("The first array should be of size one.")
12961310
a_conj = dpnp.conj(a)
1297-
return _call_multiply(a_conj, b)
1311+
return dpnp.squeeze(_call_multiply(a_conj, b))
12981312

12991313
if a.ndim == 1 and b.ndim == 1:
13001314
return dpnp_dot(a, b, out=None, conjugate=True)

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1108,7 +1108,7 @@ def dpnp_multiplication(
11081108
result = dpnp.moveaxis(result, (-2, -1), axes_res)
11091109
elif len(axes_res) == 1:
11101110
result = dpnp.moveaxis(result, (-1,), axes_res)
1111-
return dpnp.ascontiguousarray(result)
1111+
return result
11121112

11131113
return dpnp.asarray(result, order=order)
11141114

dpnp/tests/helper.py

+28-12
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,28 @@
1010
from . import config
1111

1212

13+
def _assert_dtype(a_dt, b_dt, check_only_type_kind=False):
14+
if check_only_type_kind:
15+
assert a_dt.kind == b_dt.kind, f"{a_dt.kind} != {b_dt.kind}"
16+
else:
17+
assert a_dt == b_dt, f"{a_dt} != {b_dt}"
18+
19+
20+
def _assert_shape(a, b):
21+
if hasattr(b, "shape"):
22+
assert a.shape == b.shape, f"{a.shape} != {b.shape}"
23+
else:
24+
# numpy output is scalar, then dpnp is 0-D array
25+
assert a.shape == (), f"{a.shape} != ()"
26+
27+
1328
def assert_dtype_allclose(
1429
dpnp_arr,
1530
numpy_arr,
1631
check_type=True,
1732
check_only_type_kind=False,
1833
factor=8,
19-
relative_factor=None,
34+
check_shape=True,
2035
):
2136
"""
2237
Assert DPNP and NumPy array based on maximum dtype resolution of input arrays
@@ -37,10 +52,13 @@ def assert_dtype_allclose(
3752
for all data types supported by DPNP when set to True.
3853
It is effective only when 'check_type' is also set to True.
3954
The parameter `factor` scales the resolution used for comparing the arrays.
55+
The parameter `check_shape`, when True (default), asserts the shape of input arrays is the same.
4056
4157
"""
4258

43-
list_64bit_types = [numpy.float64, numpy.complex128]
59+
if check_shape:
60+
_assert_shape(dpnp_arr, numpy_arr)
61+
4462
is_inexact = lambda x: hasattr(x, "dtype") and dpnp.issubdtype(
4563
x.dtype, dpnp.inexact
4664
)
@@ -57,34 +75,32 @@ def assert_dtype_allclose(
5775
else -dpnp.inf
5876
)
5977
tol = factor * max(tol_dpnp, tol_numpy)
60-
assert_allclose(dpnp_arr.asnumpy(), numpy_arr, atol=tol, rtol=tol)
78+
assert_allclose(dpnp_arr, numpy_arr, atol=tol, rtol=tol, strict=False)
6179
if check_type:
80+
list_64bit_types = [numpy.float64, numpy.complex128]
6281
numpy_arr_dtype = numpy_arr.dtype
6382
dpnp_arr_dtype = dpnp_arr.dtype
6483
dpnp_arr_dev = dpnp_arr.sycl_device
6584

6685
if check_only_type_kind:
67-
assert dpnp_arr_dtype.kind == numpy_arr_dtype.kind
86+
_assert_dtype(dpnp_arr_dtype, numpy_arr_dtype, True)
6887
else:
6988
is_np_arr_f2 = numpy_arr_dtype == numpy.float16
7089

7190
if is_np_arr_f2:
7291
if has_support_aspect16(dpnp_arr_dev):
73-
assert dpnp_arr_dtype == numpy_arr_dtype
92+
_assert_dtype(dpnp_arr_dtype, numpy_arr_dtype)
7493
elif (
7594
numpy_arr_dtype not in list_64bit_types
7695
or has_support_aspect64(dpnp_arr_dev)
7796
):
78-
assert dpnp_arr_dtype == numpy_arr_dtype
97+
_assert_dtype(dpnp_arr_dtype, numpy_arr_dtype)
7998
else:
80-
assert dpnp_arr_dtype.kind == numpy_arr_dtype.kind
99+
_assert_dtype(dpnp_arr_dtype, numpy_arr_dtype, True)
81100
else:
82-
assert_array_equal(dpnp_arr.asnumpy(), numpy_arr)
101+
assert_array_equal(dpnp_arr, numpy_arr, strict=False)
83102
if check_type and hasattr(numpy_arr, "dtype"):
84-
if check_only_type_kind:
85-
assert dpnp_arr.dtype.kind == numpy_arr.dtype.kind
86-
else:
87-
assert dpnp_arr.dtype == numpy_arr.dtype
103+
_assert_dtype(dpnp_arr.dtype, numpy_arr.dtype, check_only_type_kind)
88104

89105

90106
def generate_random_numpy_array(

dpnp/tests/test_arraycreation.py

-4
Original file line numberDiff line numberDiff line change
@@ -952,15 +952,13 @@ def test_ascontiguousarray1(data):
952952
result = dpnp.ascontiguousarray(data)
953953
expected = numpy.ascontiguousarray(data)
954954
assert_dtype_allclose(result, expected)
955-
assert result.shape == expected.shape
956955

957956

958957
@pytest.mark.parametrize("data", [(), 1, (2, 3), [4]])
959958
def test_ascontiguousarray2(data):
960959
result = dpnp.ascontiguousarray(dpnp.array(data))
961960
expected = numpy.ascontiguousarray(numpy.array(data))
962961
assert_dtype_allclose(result, expected)
963-
assert result.shape == expected.shape
964962

965963

966964
@pytest.mark.parametrize(
@@ -970,15 +968,13 @@ def test_asfortranarray1(data):
970968
result = dpnp.asfortranarray(data)
971969
expected = numpy.asfortranarray(data)
972970
assert_dtype_allclose(result, expected)
973-
assert result.shape == expected.shape
974971

975972

976973
@pytest.mark.parametrize("data", [(), 1, (2, 3), [4]])
977974
def test_asfortranarray2(data):
978975
result = dpnp.asfortranarray(dpnp.array(data))
979976
expected = numpy.asfortranarray(numpy.array(data))
980977
assert_dtype_allclose(result, expected)
981-
assert result.shape == expected.shape
982978

983979

984980
def test_meshgrid_raise_error():

dpnp/tests/test_arraypad.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ def test_basic(self, mode):
4242
result = dpnp.pad(a_dp, (25, 20), mode=mode)
4343
if mode == "empty":
4444
# omit uninitialized "empty" boundary from the comparison
45-
assert result.shape == expected.shape
4645
assert_equal(result[25:-20], expected[25:-20])
4746
else:
4847
assert_array_equal(result, expected)
@@ -70,7 +69,6 @@ def test_non_contiguous_array(self, mode):
7069
result = dpnp.pad(a_dp, (2, 3), mode=mode)
7170
if mode == "empty":
7271
# omit uninitialized "empty" boundary from the comparison
73-
assert result.shape == expected.shape
7472
assert_equal(result[2:-3, 2:-3], expected[2:-3, 2:-3])
7573
else:
7674
assert_array_equal(result, expected)
@@ -287,10 +285,10 @@ def test_linear_ramp_end_values(self):
287285
"""Ensure that end values are exact."""
288286
a_dp = dpnp.ones(10).reshape(2, 5)
289287
a = dpnp.pad(a_dp, (223, 123), mode="linear_ramp")
290-
assert_equal(a[:, 0], 0.0)
291-
assert_equal(a[:, -1], 0.0)
292-
assert_equal(a[0, :], 0.0)
293-
assert_equal(a[-1, :], 0.0)
288+
assert_equal(a[:, 0], 0.0, strict=False)
289+
assert_equal(a[:, -1], 0.0, strict=False)
290+
assert_equal(a[0, :], 0.0, strict=False)
291+
assert_equal(a[-1, :], 0.0, strict=False)
294292

295293
@pytest.mark.parametrize(
296294
"dtype", [numpy.uint32, numpy.uint64] + get_all_dtypes(no_none=True)
@@ -426,7 +424,6 @@ def test_empty(self):
426424
expected = numpy.pad(a_np, [(2, 3), (3, 1)], "empty")
427425
result = dpnp.pad(a_dp, [(2, 3), (3, 1)], "empty")
428426
# omit uninitialized "empty" boundary from the comparison
429-
assert result.shape == expected.shape
430427
assert_equal(result[2:-3, 3:-1], expected[2:-3, 3:-1])
431428

432429
# Check how padding behaves on arrays with an empty dimension.

dpnp/tests/test_fill.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_fill_strided_array():
3838
expected = dpnp.tile(dpnp.asarray([0, 1], dtype=a.dtype), 50)
3939

4040
b.fill(1)
41-
assert_array_equal(b, 1)
41+
assert_array_equal(b, 1, strict=False)
4242
assert_array_equal(a, expected)
4343

4444

@@ -51,7 +51,7 @@ def test_fill_strided_2d_array(order):
5151
expected[::-2, ::2] = 1
5252

5353
b.fill(1)
54-
assert_array_equal(b, 1)
54+
assert_array_equal(b, 1, strict=False)
5555
assert_array_equal(a, expected)
5656

5757

@@ -60,27 +60,27 @@ def test_fill_memset(order):
6060
a = dpnp.ones((10, 10), dtype="i4", order=order)
6161
a.fill(0)
6262

63-
assert_array_equal(a, 0)
63+
assert_array_equal(a, 0, strict=False)
6464

6565

6666
def test_fill_float_complex_to_int():
6767
a = dpnp.ones((10, 10), dtype="i4")
6868

6969
a.fill(complex(2, 0))
70-
assert_array_equal(a, 2)
70+
assert_array_equal(a, 2, strict=False)
7171

7272
a.fill(float(3))
73-
assert_array_equal(a, 3)
73+
assert_array_equal(a, 3, strict=False)
7474

7575

7676
def test_fill_complex_to_float():
7777
a = dpnp.ones((10, 10), dtype="f4")
7878

7979
a.fill(complex(2, 0))
80-
assert_array_equal(a, 2)
80+
assert_array_equal(a, 2, strict=False)
8181

8282

8383
def test_fill_bool():
8484
a = dpnp.full(5, fill_value=7, dtype="i4")
8585
a.fill(True)
86-
assert_array_equal(a, 1)
86+
assert_array_equal(a, 1, strict=False)

dpnp/tests/test_histogram.py

+32-11
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,8 @@ def test_bins(self, bins):
733733
expected_hist, expected_edges = numpy.histogramdd(v, bins)
734734
result_hist, result_edges = dpnp.histogramdd(iv, bins_dpnp)
735735
assert_allclose(result_hist, expected_hist)
736-
assert_allclose(result_edges, expected_edges)
736+
for x, y in zip(result_edges, expected_edges):
737+
assert_allclose(x, y)
737738

738739
def test_no_side_effects(self):
739740
v = dpnp.array([[1.3, 2.5, 2.3]])
@@ -752,7 +753,8 @@ def test_01d(self, data):
752753
result_hist, result_edges = dpnp.histogramdd(ia)
753754

754755
assert_allclose(result_hist, expected_hist)
755-
assert_allclose(result_edges, expected_edges)
756+
for x, y in zip(result_edges, expected_edges):
757+
assert_allclose(x, y)
756758

757759
def test_3d(self):
758760
a = dpnp.ones((10, 10, 10))
@@ -822,15 +824,21 @@ def test_nan_values(self):
822824
)
823825
result_hist, result_edges = dpnp.histogramdd(ione_nan, bins=[[0, 1]])
824826
assert_allclose(result_hist, expected_hist)
825-
assert_allclose(result_edges, expected_edges)
827+
# dpnp returns both result_hist and result_edges as float64 while
828+
# numpy returns result_hist as float64 but result_edges as int64
829+
for x, y in zip(result_edges, expected_edges):
830+
assert_allclose(x, y, strict=False)
826831

827832
# NaN is not counted
828833
expected_hist, expected_edges = numpy.histogramdd(
829834
all_nan, bins=[[0, 1]]
830835
)
831836
result_hist, result_edges = dpnp.histogramdd(iall_nan, bins=[[0, 1]])
832837
assert_allclose(result_hist, expected_hist)
833-
assert_allclose(result_edges, expected_edges)
838+
# dpnp returns both result_hist and result_edges as float64 while
839+
# numpy returns result_hist as float64 but result_edges as int64
840+
for x, y in zip(result_edges, expected_edges):
841+
assert_allclose(x, y, strict=False)
834842

835843
def test_bins_another_sycl_queue(self):
836844
v = dpnp.arange(7, 12, sycl_queue=dpctl.SyclQueue())
@@ -866,7 +874,10 @@ def test_different_bins_amount(self, bins_count):
866874
expected_hist, expected_edges = numpy.histogramdd(v, bins=[bins_count])
867875
result_hist, result_edges = dpnp.histogramdd(iv, bins=[bins_count])
868876
assert_array_equal(result_hist, expected_hist)
869-
assert_allclose(result_edges, expected_edges)
877+
# for float32 input, dpnp returns both result_hist and result_edges
878+
# as float64 while numpy returns result_hist as float64 but
879+
for x, y in zip(result_edges, expected_edges):
880+
assert_allclose(x, y, strict=False)
870881

871882

872883
class TestHistogram2d:
@@ -1045,8 +1056,10 @@ def test_nan_values(self):
10451056
ione_nan, ione_nan, bins=[[0, 1]] * 2
10461057
)
10471058
assert_allclose(result_hist, expected_hist)
1048-
assert_allclose(result_edges_x, expected_edges_x)
1049-
assert_allclose(result_edges_y, expected_edges_y)
1059+
# dpnp returns both result_hist and result_edges as float64 while
1060+
# numpy returns result_hist as float64 but result_edges as int64
1061+
assert_allclose(result_edges_x, expected_edges_x, strict=False)
1062+
assert_allclose(result_edges_y, expected_edges_y, strict=False)
10501063

10511064
# NaN is not counted
10521065
expected_hist, expected_edges_x, expected_edges_y = numpy.histogram2d(
@@ -1056,8 +1069,10 @@ def test_nan_values(self):
10561069
iall_nan, iall_nan, bins=[[0, 1]] * 2
10571070
)
10581071
assert_allclose(result_hist, expected_hist)
1059-
assert_allclose(result_edges_x, expected_edges_x)
1060-
assert_allclose(result_edges_y, expected_edges_y)
1072+
# dpnp returns both result_hist and result_edges as float64 while
1073+
# numpy returns result_hist as float64 but result_edges as int64
1074+
assert_allclose(result_edges_x, expected_edges_x, strict=False)
1075+
assert_allclose(result_edges_y, expected_edges_y, strict=False)
10611076

10621077
def test_bins_another_sycl_queue(self):
10631078
x = y = dpnp.arange(7, 12, sycl_queue=dpctl.SyclQueue())
@@ -1107,5 +1122,11 @@ def test_different_bins_amount(self, bins_count):
11071122
ix, iy, bins=bins_count
11081123
)
11091124
assert_array_equal(result_hist, expected_hist)
1110-
assert_allclose(result_edges_x, expected_edges_x, rtol=1e-6)
1111-
assert_allclose(result_edges_y, expected_edges_y, rtol=1e-6)
1125+
# dpnp returns both result_hist and result_edges as float64 while
1126+
# numpy returns result_hist as float64 but result_edges as float32
1127+
assert_allclose(
1128+
result_edges_x, expected_edges_x, rtol=1e-6, strict=False
1129+
)
1130+
assert_allclose(
1131+
result_edges_y, expected_edges_y, rtol=1e-6, strict=False
1132+
)

dpnp/tests/test_indexing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def test_indexing_array_negative_strides(self):
292292

293293
slices = (slice(None), dpnp.array([0, 1, 2, 3]))
294294
arr[slices] = 10
295-
assert_array_equal(arr, 10.0)
295+
assert_equal(arr, 10.0, strict=False)
296296

297297

298298
class TestIx:

0 commit comments

Comments
 (0)