14
14
from rtree import core , index
15
15
from rtree .exceptions import RTreeError
16
16
17
+ from .common import skip_sidx_lt_210
18
+
17
19
18
20
class IndexTestCase (unittest .TestCase ):
19
21
def setUp (self ) -> None :
@@ -268,6 +270,26 @@ def test_double_insertion(self) -> None:
268
270
269
271
self .assertEqual ([1 , 1 ], list (idx .intersection ((0 , 0 , 5 , 5 ))))
270
272
273
+ @skip_sidx_lt_210
274
+ def test_intersection_v (self ) -> None :
275
+ mins = np .array ([[0 , 1 ]] * 2 ).T
276
+ maxs = np .array ([[60 , 50 ]] * 2 ).T
277
+ ret = self .idx .intersection_v (mins , maxs )
278
+ assert type (ret ) is tuple
279
+ ids , counts = ret
280
+ assert ids .dtype == np .int64
281
+ ids0 = [0 , 4 , 16 , 27 , 35 , 40 , 47 , 50 , 76 , 80 ]
282
+ ids1 = [0 , 16 , 27 , 35 , 47 , 76 ]
283
+ assert ids .tolist () == ids0 + ids1
284
+ assert counts .dtype == np .uint64
285
+ assert counts .tolist () == [len (ids0 ), len (ids1 )]
286
+
287
+ # errors
288
+ with pytest .raises (ValueError , match = "must have 2 dimensions" ):
289
+ self .idx .intersection_v (np .ones ((2 , 3 , 4 )), 4 )
290
+ with pytest .raises (ValueError , match = "shapes not equal" ):
291
+ self .idx .intersection_v ([0 ], [10 , 12 ])
292
+
271
293
272
294
class TestIndexIntersectionUnion :
273
295
@pytest .fixture (scope = "class" )
@@ -314,6 +336,17 @@ def test_intersection_interleaved(
314
336
else :
315
337
assert False
316
338
339
+ @skip_sidx_lt_210
340
+ def test_intersection_v_interleaved (
341
+ self , index_a_interleaved : index .Index , index_b_interleaved : index .Index
342
+ ) -> None :
343
+ index_c_interleaved = index_a_interleaved & index_b_interleaved
344
+ mins = index_c_interleaved .bounds [0 :2 ]
345
+ maxs = index_c_interleaved .bounds [2 :4 ]
346
+ idxs , counts = index_c_interleaved .intersection_v (mins , maxs )
347
+ assert idxs .tolist () == [0 , 1 ]
348
+ assert counts .tolist () == [2 ]
349
+
317
350
def test_intersection_uninterleaved (
318
351
self , index_a_uninterleaved : index .Index , index_b_uninterleaved : index .Index
319
352
) -> None :
@@ -330,6 +363,17 @@ def test_intersection_uninterleaved(
330
363
else :
331
364
assert False
332
365
366
+ @skip_sidx_lt_210
367
+ def test_intersection_v_uninterleaved (
368
+ self , index_a_uninterleaved : index .Index , index_b_uninterleaved : index .Index
369
+ ) -> None :
370
+ index_c_uninterleaved = index_a_uninterleaved & index_b_uninterleaved
371
+ mins = index_c_uninterleaved .bounds [0 ::2 ]
372
+ maxs = index_c_uninterleaved .bounds [1 ::2 ]
373
+ idxs , counts = index_c_uninterleaved .intersection_v (mins , maxs )
374
+ assert idxs .tolist () == [0 , 1 ]
375
+ assert counts .tolist () == [2 ]
376
+
333
377
def test_intersection_mismatch (
334
378
self , index_a_interleaved : index .Index , index_b_uninterleaved : index .Index
335
379
) -> None :
@@ -617,6 +661,46 @@ def test_nearest_basic(self) -> None:
617
661
hits = sorted (idx .nearest ((13 , 0 , 20 , 2 ), 3 ))
618
662
self .assertEqual (hits , [3 , 4 , 5 ])
619
663
664
+ @skip_sidx_lt_210
665
+ def test_nearest_v_basic (self ) -> None :
666
+ mins = np .array ([[0 , 5 ]] * 2 ).T
667
+ maxs = np .array ([[10 , 15 ]] * 2 ).T
668
+ ret = self .idx .nearest_v (mins , maxs , num_results = 3 )
669
+ assert type (ret ) is tuple
670
+ ids , counts = ret
671
+ assert ids .dtype == np .int64
672
+ ids0 = [76 , 48 , 19 ]
673
+ ids1 = [76 , 47 , 48 ]
674
+ assert ids .tolist () == ids0 + ids1
675
+ assert counts .dtype == np .uint64
676
+ assert counts .tolist () == [3 , 3 ]
677
+
678
+ ret = self .idx .nearest_v (mins , maxs , num_results = 3 , return_max_dists = True )
679
+ assert type (ret ) is tuple
680
+ ids , counts , max_dists = ret
681
+ assert ids .tolist () == ids0 + ids1
682
+ assert counts .tolist () == [3 , 3 ]
683
+ assert max_dists .dtype == np .float64
684
+ np .testing .assert_allclose (max_dists , [7.54938045 , 11.05686397 ])
685
+
686
+ ret = self .idx .nearest_v (
687
+ mins , maxs , num_results = 3 , max_dists = [10 , 10 ], return_max_dists = True
688
+ )
689
+ ids , counts , max_dists = ret
690
+ assert ids .tolist () == ids0 + ids1 [:2 ]
691
+ assert counts .tolist () == [3 , 2 ]
692
+ np .testing .assert_allclose (max_dists , [7.54938045 , 3.92672575 ])
693
+
694
+ # errors
695
+ with pytest .raises (ValueError , match = "must have 2 dimensions" ):
696
+ self .idx .nearest_v (np .ones ((2 , 3 , 4 )), 4 )
697
+ with pytest .raises (ValueError , match = "shapes not equal" ):
698
+ self .idx .nearest_v ([0 ], [10 , 12 ])
699
+ with pytest .raises (ValueError , match = "max_dists must have 1 dimension" ):
700
+ self .idx .nearest_v (maxs , mins , max_dists = [[10 ]])
701
+ with pytest .raises (ValueError , match = "max_dists must have length 2" ):
702
+ self .idx .nearest_v (maxs , mins , max_dists = [10 ])
703
+
620
704
def test_nearest_equidistant (self ) -> None :
621
705
"""Test that if records are equidistant, both are returned."""
622
706
point = (0 , 0 )
@@ -677,25 +761,47 @@ def test_deletion(self) -> None:
677
761
self .assertEqual (hits , [])
678
762
679
763
680
- class IndexMoreDimensions (IndexTestCase ):
681
- def test_3d (self ) -> None :
682
- """Test we make and query a 3D index"""
764
+ class Index3d (IndexTestCase ):
765
+ """Test we make and query a 3D index"""
766
+
767
+ def setUp (self ) -> None :
683
768
p = index .Property ()
684
769
p .dimension = 3
685
- idx = index .Index (properties = p , interleaved = False )
686
- idx .insert (1 , (0 , 0 , 60 , 60 , 22 , 22.0 ))
687
- hits = idx .intersection ((- 1 , 1 , 58 , 62 , 22 , 24 ))
770
+ self .idx = index .Index (properties = p , interleaved = False )
771
+ self .idx .insert (1 , (0 , 0 , 60 , 60 , 22 , 22.0 ))
772
+ self .coords = (- 1 , 1 , 58 , 62 , 22 , 24 )
773
+
774
+ def test_intersection (self ) -> None :
775
+ hits = self .idx .intersection (self .coords )
688
776
self .assertEqual (list (hits ), [1 ])
689
777
690
- def test_4d (self ) -> None :
691
- """Test we make and query a 4D index"""
778
+ @skip_sidx_lt_210
779
+ def test_intersection_v (self ) -> None :
780
+ idxs , counts = self .idx .intersection_v (self .coords [0 ::2 ], self .coords [1 ::2 ])
781
+ assert idxs .tolist () == [1 ]
782
+ assert counts .tolist () == [1 ]
783
+
784
+
785
+ class Index4d (IndexTestCase ):
786
+ """Test we make and query a 4D index"""
787
+
788
+ def setUp (self ) -> None :
692
789
p = index .Property ()
693
790
p .dimension = 4
694
- idx = index .Index (properties = p , interleaved = False )
695
- idx .insert (1 , (0 , 0 , 60 , 60 , 22 , 22.0 , 128 , 142 ))
696
- hits = idx .intersection ((- 1 , 1 , 58 , 62 , 22 , 24 , 120 , 150 ))
791
+ self .idx = index .Index (properties = p , interleaved = False )
792
+ self .idx .insert (1 , (0 , 0 , 60 , 60 , 22 , 22.0 , 128 , 142 ))
793
+ self .coords = (- 1 , 1 , 58 , 62 , 22 , 24 , 120 , 150 )
794
+
795
+ def test_intersection (self ) -> None :
796
+ hits = self .idx .intersection (self .coords )
697
797
self .assertEqual (list (hits ), [1 ])
698
798
799
+ @skip_sidx_lt_210
800
+ def test_intersection_v (self ) -> None :
801
+ idxs , counts = self .idx .intersection_v (self .coords [0 ::2 ], self .coords [1 ::2 ])
802
+ assert idxs .tolist () == [1 ]
803
+ assert counts .tolist () == [1 ]
804
+
699
805
700
806
class IndexStream (IndexTestCase ):
701
807
def test_stream_input (self ) -> None :
0 commit comments