Skip to content

Commit 538d459

Browse files
authored
Support dtype= and copy= arguments in __array__(), as required by numpy 2. (#64)
1 parent 622de03 commit 538d459

File tree

4 files changed

+60
-10
lines changed

4 files changed

+60
-10
lines changed

src/delayedarray/DelayedArray.py

+24-3
Original file line numberDiff line numberDiff line change
@@ -156,15 +156,36 @@ def __repr__(self) -> str:
156156
return preamble + "\n" + converted
157157

158158
# For NumPy:
159-
def __array__(self) -> ndarray:
159+
def __array__(self, dtype: Optional[numpy.dtype] = None, copy: bool = True) -> ndarray:
160160
"""Convert a ``DelayedArray`` to a NumPy array, to be used by
161161
:py:meth:`~numpy.array`.
162162
163+
Args:
164+
dtype:
165+
The desired NumPy type of the output array. If None, the
166+
type of the seed is used.
167+
168+
copy:
169+
Currently ignored. The output is never a reference to the
170+
underlying seed, even if the seed is another NumPy array.
171+
163172
Returns:
164173
NumPy array of the same type as :py:attr:`~dtype` and shape as
165-
:py:attr:`~shape`.
174+
:py:attr:`~shape`.
166175
"""
167-
return to_dense_array(self._seed)
176+
if dtype is None or dtype == self.dtype:
177+
return to_dense_array(self._seed)
178+
else:
179+
# Filling it chunk by chunk rather than doing a big coercion,
180+
# to avoid creating an unnecessary intermediate full matrix.
181+
output = numpy.ndarray(self.shape, dtype=dtype)
182+
if is_masked(self._seed):
183+
output = numpy.ma.array(output, mask=False)
184+
def fill_output(job, part):
185+
subsets = (*(slice(s, e) for s, e in job),)
186+
output[subsets] = part
187+
apply_over_blocks(self._seed, fill_output)
188+
return output
168189

169190
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> "DelayedArray":
170191
"""Interface with NumPy array methods. This is used to implement

src/delayedarray/SparseNdarray.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -210,15 +210,24 @@ def __repr__(self) -> str:
210210

211211

212212
# For NumPy:
213-
def __array__(self) -> numpy.ndarray:
213+
def __array__(self, dtype: Optional[numpy.dtype] = None, copy: bool = True) -> numpy.ndarray:
214214
"""Convert a ``SparseNdarray`` to a NumPy array.
215215
216+
Args:
217+
dtype:
218+
The desired NumPy type of the output array. If None, the
219+
type of the seed is used.
220+
221+
copy:
222+
Currently ignored. The output is never a reference to the
223+
underlying seed, even if the seed is another NumPy array.
224+
216225
Returns:
217226
Dense array of the same type as :py:attr:`~dtype` and shape as
218227
:py:attr:`~shape`.
219228
"""
220229
indices = _spawn_indices(self._shape)
221-
return _extract_dense_array_from_SparseNdarray(self, indices)
230+
return _extract_dense_array_from_SparseNdarray(self, indices, dtype=dtype)
222231

223232
# Assorted dunder methods.
224233
def __add__(self, other) -> Union["SparseNdarray", numpy.ndarray]:
@@ -1231,18 +1240,20 @@ def _recursive_extract_dense_array(contents: numpy.ndarray, subset: Tuple[Sequen
12311240
pos += 1
12321241

12331242

1234-
def _extract_dense_array_from_SparseNdarray(x: SparseNdarray, subset: Tuple[Sequence[int], ...]) -> numpy.ndarray:
1243+
def _extract_dense_array_from_SparseNdarray(x: SparseNdarray, subset: Tuple[Sequence[int], ...], dtype: Optional[numpy.dtype] = None) -> numpy.ndarray:
12351244
idims = [len(y) for y in subset]
12361245
subset_summary = _characterize_indices(subset[0], x._shape[0])
12371246

12381247
# We reverse the dimensions so that we use F-contiguous storage. This also
12391248
# makes it slightly easier to do the recursion as we can just index by
12401249
# the first dimension to obtain a subarray at each recursive step.
1241-
output = numpy.zeros((*reversed(idims),), dtype=x._dtype)
1250+
if dtype is None:
1251+
dtype = x._dtype
1252+
output = numpy.zeros((*reversed(idims),), dtype=dtype)
1253+
if x._is_masked:
1254+
output = numpy.ma.MaskedArray(output, mask=False)
12421255

12431256
if x._contents is not None:
1244-
if x._is_masked:
1245-
output = numpy.ma.MaskedArray(output, mask=False)
12461257
ndim = len(x._shape)
12471258
if ndim > 1:
12481259
_recursive_extract_dense_array(x._contents, subset, subset_summary=subset_summary, output=output, dim=ndim-1)

tests/test_DelayedArray.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@ def test_DelayedArray_dense():
1818

1919
dump = numpy.array(x)
2020
assert isinstance(dump, numpy.ndarray)
21+
assert dump.dtype == x.dtype
22+
assert (dump == raw).all()
23+
24+
dump = numpy.array(x, dtype=numpy.float64)
25+
assert isinstance(dump, numpy.ndarray)
26+
assert dump.dtype == numpy.float64
2127
assert (dump == raw).all()
2228

2329

@@ -69,6 +75,16 @@ def test_DelayedArray_masked():
6975
x = delayedarray.wrap(y)
7076
assert delayedarray.is_masked(x)
7177

78+
dump = numpy.array(x)
79+
assert isinstance(dump, numpy.ndarray)
80+
assert dump.dtype == x.dtype
81+
assert (dump == numpy.array(y)).all()
82+
83+
dump = numpy.array(x, dtype=numpy.float32)
84+
assert isinstance(dump, numpy.ndarray)
85+
assert dump.dtype == numpy.float32
86+
assert (dump == numpy.array(y, dtype=numpy.float32)).all()
87+
7288

7389
#######################################################
7490
#######################################################
@@ -356,4 +372,4 @@ def test_SparseNdarray_all_sparse(mask_rate, buffer_size):
356372

357373
# Zero-length array is respected.
358374
y = delayedarray.wrap(delayedarray.SparseNdarray((0,), None, dtype=numpy.int32, index_dtype=numpy.int32)) * 50
359-
assert y.all()
375+
assert y.all()

tests/test_SparseNdarray.py

+2
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,8 @@ def test_SparseNdarray_extract_dense_array_3d(mask_rate):
188188
# Full extraction.
189189
output = delayedarray.to_dense_array(y)
190190
assert_identical_ndarrays(output, convert_SparseNdarray_to_numpy(y))
191+
assert_identical_ndarrays(numpy.array(output), numpy.array(y))
192+
assert_identical_ndarrays(numpy.array(output, dtype=numpy.int32), numpy.array(y, dtype=numpy.int32))
191193

192194
# Sliced extraction.
193195
slices = (slice(2, 15, 3), slice(0, 20, 2), slice(4, 8))

0 commit comments

Comments
 (0)