Skip to content

Commit c180a80

Browse files
authored
Refactor array-loading methods, add tests (#361)
1 parent a868dba commit c180a80

File tree

3 files changed

+146
-25
lines changed

3 files changed

+146
-25
lines changed

Diff for: rtree/index.py

+25-14
Original file line numberDiff line numberDiff line change
@@ -1061,12 +1061,16 @@ def intersection_v(self, mins, maxs):
10611061
"""
10621062
import numpy as np
10631063

1064-
assert mins.shape == maxs.shape
1065-
assert mins.strides == maxs.strides
1064+
# Ensure inputs are 2D float arrays
1065+
mins = np.atleast_2d(mins).astype(np.float64).copy()
1066+
maxs = np.atleast_2d(maxs).astype(np.float64).copy()
10661067

1067-
# Cast
1068-
mins = mins.astype(np.float64)
1069-
maxs = maxs.astype(np.float64)
1068+
if mins.ndim != 2 or maxs.ndim != 2:
1069+
raise ValueError("mins/maxs must have 2 dimensions: (n, d)")
1070+
if mins.shape != maxs.shape:
1071+
raise ValueError("mins and maxs shapes not equal")
1072+
if mins.strides != maxs.strides:
1073+
raise ValueError("mins and maxs strides not equal")
10701074

10711075
# Extract counts
10721076
n, d = mins.shape
@@ -1109,6 +1113,7 @@ def nearest_v(
11091113
self,
11101114
mins,
11111115
maxs,
1116+
*,
11121117
num_results=1,
11131118
max_dists=None,
11141119
strict=False,
@@ -1144,12 +1149,16 @@ def nearest_v(
11441149
"""
11451150
import numpy as np
11461151

1147-
assert mins.shape == maxs.shape
1148-
assert mins.strides == maxs.strides
1152+
# Ensure inputs are 2D float arrays
1153+
mins = np.atleast_2d(mins).astype(np.float64).copy()
1154+
maxs = np.atleast_2d(maxs).astype(np.float64).copy()
11491155

1150-
# Cast
1151-
mins = mins.astype(np.float64)
1152-
maxs = maxs.astype(np.float64)
1156+
if mins.ndim != 2 or maxs.ndim != 2:
1157+
raise ValueError("mins/maxs must have 2 dimensions: (n, d)")
1158+
if mins.shape != maxs.shape:
1159+
raise ValueError("mins and maxs shapes not equal")
1160+
if mins.strides != maxs.strides:
1161+
raise ValueError("mins and maxs strides not equal")
11531162

11541163
# Extract counts
11551164
n, d = mins.shape
@@ -1164,9 +1173,11 @@ def nearest_v(
11641173
offn, offi = 0, 0
11651174

11661175
if max_dists is not None:
1167-
assert len(max_dists) == n
1168-
1169-
dists = max_dists.astype(np.float64).copy()
1176+
dists = np.atleast_1d(max_dists).astype(np.float64).copy()
1177+
if dists.ndim != 1:
1178+
raise ValueError("max_dists must have 1 dimension")
1179+
if len(dists) != n:
1180+
raise ValueError(f"max_dists must have length {n}")
11701181
elif return_max_dists:
11711182
dists = np.zeros(n)
11721183
else:
@@ -1189,7 +1200,7 @@ def nearest_v(
11891200
ctypes.byref(nr),
11901201
)
11911202

1192-
# If we got the expected nuber of results then return
1203+
# If we got the expected number of results then return
11931204
if nr.value == n - offn:
11941205
if return_max_dists:
11951206
return ids[: counts.sum()], counts, dists

Diff for: tests/common.py

+4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
"""Common test functions."""
22

3+
import pytest
4+
35
from rtree.core import rt
46

57
sidx_version_string = rt.SIDX_Version().decode()
68
sidx_version = tuple(map(int, sidx_version_string.split(".", maxsplit=3)[:3]))
9+
10+
skip_sidx_lt_210 = pytest.mark.skipif(sidx_version < (2, 1, 0), reason="SIDX < 2.1.0")

Diff for: tests/test_index.py

+117-11
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from rtree import core, index
1515
from rtree.exceptions import RTreeError
1616

17+
from .common import skip_sidx_lt_210
18+
1719

1820
class IndexTestCase(unittest.TestCase):
1921
def setUp(self) -> None:
@@ -268,6 +270,26 @@ def test_double_insertion(self) -> None:
268270

269271
self.assertEqual([1, 1], list(idx.intersection((0, 0, 5, 5))))
270272

273+
@skip_sidx_lt_210
274+
def test_intersection_v(self) -> None:
275+
mins = np.array([[0, 1]] * 2).T
276+
maxs = np.array([[60, 50]] * 2).T
277+
ret = self.idx.intersection_v(mins, maxs)
278+
assert type(ret) is tuple
279+
ids, counts = ret
280+
assert ids.dtype == np.int64
281+
ids0 = [0, 4, 16, 27, 35, 40, 47, 50, 76, 80]
282+
ids1 = [0, 16, 27, 35, 47, 76]
283+
assert ids.tolist() == ids0 + ids1
284+
assert counts.dtype == np.uint64
285+
assert counts.tolist() == [len(ids0), len(ids1)]
286+
287+
# errors
288+
with pytest.raises(ValueError, match="must have 2 dimensions"):
289+
self.idx.intersection_v(np.ones((2, 3, 4)), 4)
290+
with pytest.raises(ValueError, match="shapes not equal"):
291+
self.idx.intersection_v([0], [10, 12])
292+
271293

272294
class TestIndexIntersectionUnion:
273295
@pytest.fixture(scope="class")
@@ -314,6 +336,17 @@ def test_intersection_interleaved(
314336
else:
315337
assert False
316338

339+
@skip_sidx_lt_210
340+
def test_intersection_v_interleaved(
341+
self, index_a_interleaved: index.Index, index_b_interleaved: index.Index
342+
) -> None:
343+
index_c_interleaved = index_a_interleaved & index_b_interleaved
344+
mins = index_c_interleaved.bounds[0:2]
345+
maxs = index_c_interleaved.bounds[2:4]
346+
idxs, counts = index_c_interleaved.intersection_v(mins, maxs)
347+
assert idxs.tolist() == [0, 1]
348+
assert counts.tolist() == [2]
349+
317350
def test_intersection_uninterleaved(
318351
self, index_a_uninterleaved: index.Index, index_b_uninterleaved: index.Index
319352
) -> None:
@@ -330,6 +363,17 @@ def test_intersection_uninterleaved(
330363
else:
331364
assert False
332365

366+
@skip_sidx_lt_210
367+
def test_intersection_v_uninterleaved(
368+
self, index_a_uninterleaved: index.Index, index_b_uninterleaved: index.Index
369+
) -> None:
370+
index_c_uninterleaved = index_a_uninterleaved & index_b_uninterleaved
371+
mins = index_c_uninterleaved.bounds[0::2]
372+
maxs = index_c_uninterleaved.bounds[1::2]
373+
idxs, counts = index_c_uninterleaved.intersection_v(mins, maxs)
374+
assert idxs.tolist() == [0, 1]
375+
assert counts.tolist() == [2]
376+
333377
def test_intersection_mismatch(
334378
self, index_a_interleaved: index.Index, index_b_uninterleaved: index.Index
335379
) -> None:
@@ -617,6 +661,46 @@ def test_nearest_basic(self) -> None:
617661
hits = sorted(idx.nearest((13, 0, 20, 2), 3))
618662
self.assertEqual(hits, [3, 4, 5])
619663

664+
@skip_sidx_lt_210
665+
def test_nearest_v_basic(self) -> None:
666+
mins = np.array([[0, 5]] * 2).T
667+
maxs = np.array([[10, 15]] * 2).T
668+
ret = self.idx.nearest_v(mins, maxs, num_results=3)
669+
assert type(ret) is tuple
670+
ids, counts = ret
671+
assert ids.dtype == np.int64
672+
ids0 = [76, 48, 19]
673+
ids1 = [76, 47, 48]
674+
assert ids.tolist() == ids0 + ids1
675+
assert counts.dtype == np.uint64
676+
assert counts.tolist() == [3, 3]
677+
678+
ret = self.idx.nearest_v(mins, maxs, num_results=3, return_max_dists=True)
679+
assert type(ret) is tuple
680+
ids, counts, max_dists = ret
681+
assert ids.tolist() == ids0 + ids1
682+
assert counts.tolist() == [3, 3]
683+
assert max_dists.dtype == np.float64
684+
np.testing.assert_allclose(max_dists, [7.54938045, 11.05686397])
685+
686+
ret = self.idx.nearest_v(
687+
mins, maxs, num_results=3, max_dists=[10, 10], return_max_dists=True
688+
)
689+
ids, counts, max_dists = ret
690+
assert ids.tolist() == ids0 + ids1[:2]
691+
assert counts.tolist() == [3, 2]
692+
np.testing.assert_allclose(max_dists, [7.54938045, 3.92672575])
693+
694+
# errors
695+
with pytest.raises(ValueError, match="must have 2 dimensions"):
696+
self.idx.nearest_v(np.ones((2, 3, 4)), 4)
697+
with pytest.raises(ValueError, match="shapes not equal"):
698+
self.idx.nearest_v([0], [10, 12])
699+
with pytest.raises(ValueError, match="max_dists must have 1 dimension"):
700+
self.idx.nearest_v(maxs, mins, max_dists=[[10]])
701+
with pytest.raises(ValueError, match="max_dists must have length 2"):
702+
self.idx.nearest_v(maxs, mins, max_dists=[10])
703+
620704
def test_nearest_equidistant(self) -> None:
621705
"""Test that if records are equidistant, both are returned."""
622706
point = (0, 0)
@@ -677,25 +761,47 @@ def test_deletion(self) -> None:
677761
self.assertEqual(hits, [])
678762

679763

680-
class IndexMoreDimensions(IndexTestCase):
681-
def test_3d(self) -> None:
682-
"""Test we make and query a 3D index"""
764+
class Index3d(IndexTestCase):
765+
"""Test we make and query a 3D index"""
766+
767+
def setUp(self) -> None:
683768
p = index.Property()
684769
p.dimension = 3
685-
idx = index.Index(properties=p, interleaved=False)
686-
idx.insert(1, (0, 0, 60, 60, 22, 22.0))
687-
hits = idx.intersection((-1, 1, 58, 62, 22, 24))
770+
self.idx = index.Index(properties=p, interleaved=False)
771+
self.idx.insert(1, (0, 0, 60, 60, 22, 22.0))
772+
self.coords = (-1, 1, 58, 62, 22, 24)
773+
774+
def test_intersection(self) -> None:
775+
hits = self.idx.intersection(self.coords)
688776
self.assertEqual(list(hits), [1])
689777

690-
def test_4d(self) -> None:
691-
"""Test we make and query a 4D index"""
778+
@skip_sidx_lt_210
779+
def test_intersection_v(self) -> None:
780+
idxs, counts = self.idx.intersection_v(self.coords[0::2], self.coords[1::2])
781+
assert idxs.tolist() == [1]
782+
assert counts.tolist() == [1]
783+
784+
785+
class Index4d(IndexTestCase):
786+
"""Test we make and query a 4D index"""
787+
788+
def setUp(self) -> None:
692789
p = index.Property()
693790
p.dimension = 4
694-
idx = index.Index(properties=p, interleaved=False)
695-
idx.insert(1, (0, 0, 60, 60, 22, 22.0, 128, 142))
696-
hits = idx.intersection((-1, 1, 58, 62, 22, 24, 120, 150))
791+
self.idx = index.Index(properties=p, interleaved=False)
792+
self.idx.insert(1, (0, 0, 60, 60, 22, 22.0, 128, 142))
793+
self.coords = (-1, 1, 58, 62, 22, 24, 120, 150)
794+
795+
def test_intersection(self) -> None:
796+
hits = self.idx.intersection(self.coords)
697797
self.assertEqual(list(hits), [1])
698798

799+
@skip_sidx_lt_210
800+
def test_intersection_v(self) -> None:
801+
idxs, counts = self.idx.intersection_v(self.coords[0::2], self.coords[1::2])
802+
assert idxs.tolist() == [1]
803+
assert counts.tolist() == [1]
804+
699805

700806
class IndexStream(IndexTestCase):
701807
def test_stream_input(self) -> None:

0 commit comments

Comments
 (0)