Skip to content

Commit 01dcf27

Browse files
committed
Avoid cast to float when operating on uint64 indices in a SparseNdarray.
1 parent aa50f4c commit 01dcf27

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

src/delayedarray/SparseNdarray.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1207,7 +1207,7 @@ def _extract_sparse_vector_to_dense(indices: numpy.ndarray, values: numpy.ndarra
12071207
pass
12081208
elif subset_summary.consecutive:
12091209
start_pos = 0
1210-
first = subset_summary.first_index
1210+
first = indices.dtype.type(subset_summary.first_index) # avoid casting of uint64s to floats during subtraction.
12111211
if subset_summary.search_first:
12121212
start_pos = bisect_left(indices, first)
12131213

@@ -1276,7 +1276,7 @@ def _extract_sparse_vector_to_sparse(indices: numpy.ndarray, values: numpy.ndarr
12761276

12771277
elif subset_summary.consecutive:
12781278
start_pos = 0
1279-
first = subset_summary.first_index
1279+
first = indices.dtype.type(subset_summary.first_index) # avoid casting of uint64's to floats during subtraction.
12801280
if subset_summary.search_first:
12811281
start_pos = bisect_left(indices, first)
12821282

tests/test_SparseNdarray.py

+13
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,19 @@ def test_SparseNdarray_empty(mask_rate):
365365
assert spout.shape == (3, 4, 5)
366366

367367

368+
def test_SparseNdarray_u64_index():
369+
test_shape = (120, 50)
370+
y = simulate_SparseNdarray(test_shape, mask_rate=0, index_dtype=numpy.uint64)
371+
ref = convert_SparseNdarray_to_numpy(y)
372+
373+
slices = (slice(70, 120),slice(10, 40))
374+
sliced = y[slices]
375+
assert_identical_ndarrays(convert_SparseNdarray_to_numpy(sliced), ref[slices])
376+
377+
dout = delayedarray.extract_dense_array(y, slices2ranges(slices, test_shape))
378+
assert_identical_ndarrays(dout, ref[slices])
379+
380+
368381
#######################################################
369382
#######################################################
370383

tests/utils.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def mock_SparseNdarray_contents(
6868
lower: float = -1,
6969
upper: float = 1,
7070
dtype: numpy.dtype = numpy.dtype("float64"),
71+
index_dtype: numpy.dtype = numpy.dtype("int32"),
7172
mask_rate: float = 0
7273
):
7374
if len(shape) == 1:
@@ -78,7 +79,7 @@ def mock_SparseNdarray_contents(
7879
new_indices.append(i)
7980
new_values.append(random.uniform(lower, upper))
8081

81-
new_indices = numpy.array(new_indices, dtype=numpy.int32)
82+
new_indices = numpy.array(new_indices, dtype=index_dtype)
8283
new_values = numpy.array(new_values, dtype=dtype)
8384
if mask_rate:
8485
new_mask = numpy.random.rand(len(new_values)) < mask_rate
@@ -101,6 +102,7 @@ def mock_SparseNdarray_contents(
101102
lower=lower,
102103
upper=upper,
103104
dtype=dtype,
105+
index_dtype=index_dtype,
104106
mask_rate=mask_rate,
105107
)
106108
)

0 commit comments

Comments
 (0)