Skip to content

Commit 876e940

Browse files
authored
Add dpnp.ndarray.__contains__ (#2534)
The PR implements `__contains__` method which needs to define the `in` operator for a `dpnp.ndarray`.
1 parent 307fc66 commit 876e940

File tree

3 files changed

+37
-1
lines changed

3 files changed

+37
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1414
* Added a new backend routine `syrk` from oneMKL to perform symmetric rank-k update which is used for a specialized matrix multiplication where the result is a symmetric matrix [2509](https://github.com/IntelPython/dpnp/pull/2509)
1515
* Added `timeout-minutes` property to GitHub jobs [#2526](https://github.com/IntelPython/dpnp/pull/2526)
1616
* Added implementation of `dpnp.ndarray.data` and `dpnp.ndarray.data.ptr` attributes [#2521](https://github.com/IntelPython/dpnp/pull/2521)
17+
* Added `dpnp.ndarray.__contains__` method [#2534](https://github.com/IntelPython/dpnp/pull/2534)
1718

1819
### Changed
1920

dpnp/dpnp_array.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,9 @@ def __bool__(self):
242242
def __complex__(self):
243243
return self._array_obj.__complex__()
244244

245-
# '__contains__',
245+
def __contains__(self, value, /):
246+
r"""Return :math:`\text{value in self}`."""
247+
return (self == value).any()
246248

247249
def __copy__(self):
248250
"""

dpnp/tests/test_ndarray.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,39 @@ def test_attributes(self):
7474
assert_equal(self.two.itemsize, self.two.dtype.itemsize)
7575

7676

77+
@testing.parameterize(*testing.product({"xp": [dpnp, numpy]}))
78+
class TestContains:
79+
def test_basic(self):
80+
a = self.xp.arange(10).reshape((2, 5))
81+
assert 4 in a
82+
assert 20 not in a
83+
84+
def test_broadcast(self):
85+
xp = self.xp
86+
a = xp.arange(6).reshape((2, 3))
87+
assert 4 in a
88+
assert xp.array([0, 1, 2]) in a
89+
assert xp.array([5, 3, 4]) not in a
90+
91+
def test_broadcast_error(self):
92+
a = self.xp.arange(10).reshape((2, 5))
93+
with pytest.raises(
94+
ValueError,
95+
match="operands could not be broadcast together with shapes",
96+
):
97+
self.xp.array([1, 2]) in a
98+
99+
def test_strides(self):
100+
xp = self.xp
101+
a = xp.arange(10).reshape((2, 5))
102+
a = a[:, ::2]
103+
assert 4 in a
104+
assert 8 not in a
105+
assert xp.full(a.shape[-1], fill_value=2) in a
106+
assert xp.full_like(a, fill_value=7) in a
107+
assert xp.full_like(a, fill_value=6) not in a
108+
109+
77110
class TestView:
78111
def test_none_dtype(self):
79112
a = numpy.ones((1, 2, 4), dtype=numpy.int32)

0 commit comments

Comments
 (0)