Skip to content

Commit aa50f4c

Browse files
committed
Coerce numpy.generic instances to dtypes in SparseNdarray constructor.
This protects against callers using the former instead of the latter, as is done in various numpy functions that accept dtype=.
1 parent 538d459 commit aa50f4c

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

src/delayedarray/SparseNdarray.py

+7
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,13 @@ def __init__(
132132
if is_masked is None:
133133
is_masked = False
134134

135+
# Sometimes people put the numpy data class instead of the dtype.
136+
# It's a common enough mistake that we ought to catch it here.
137+
if not isinstance(dtype, numpy.dtype):
138+
dtype = numpy.dtype(dtype)
139+
if not isinstance(index_dtype, numpy.dtype):
140+
index_dtype = numpy.dtype(index_dtype)
141+
135142
self._dtype = dtype
136143
self._index_dtype = index_dtype
137144
self._is_masked = is_masked

tests/test_SparseNdarray.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def test_SparseNdarray_check(mask_rate):
117117
contents = mock_SparseNdarray_contents(test_shape, mask_rate=mask_rate)
118118
y = delayedarray.SparseNdarray(test_shape, contents)
119119
assert y.shape == test_shape
120-
assert y.dtype == numpy.float64
120+
assert y.dtype is numpy.dtype("float64")
121121
assert repr(y).find("SparseNdarray") > 0
122122
assert delayedarray.is_sparse(y)
123123
assert delayedarray.is_masked(y) == (mask_rate > 0)
@@ -172,9 +172,13 @@ def shorten(con, depth):
172172

173173
empty = delayedarray.SparseNdarray(test_shape, None, dtype=numpy.dtype("int32"), index_dtype=numpy.dtype("int32"))
174174
assert empty.shape == test_shape
175-
assert empty.dtype == numpy.int32
175+
assert empty.dtype is numpy.dtype("int32")
176176
assert not empty.is_masked
177177

178+
empty = delayedarray.SparseNdarray(test_shape, None, dtype=numpy.float32, index_dtype=numpy.int32) # generics converted to dtypes
179+
assert empty.dtype is numpy.dtype("float32")
180+
assert empty.index_dtype is numpy.dtype("int32")
181+
178182

179183
#######################################################
180184
#######################################################

0 commit comments

Comments
 (0)