Skip to content

Commit aeb13af

Browse files
committed
Merge pull request numpy#3107 from MrBago/improve_searchsorted
Improve searchsorted
2 parents 4e16d97 + 4674b9e commit aeb13af

File tree

2 files changed

+108
-41
lines changed

2 files changed

+108
-41
lines changed

numpy/core/src/multiarray/item_selection.c

Lines changed: 47 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1866,10 +1866,9 @@ PyArray_LexSort(PyObject *sort_keys, int axis)
18661866
*
18671867
* For each key use bisection to find the first index i s.t. key <= arr[i].
18681868
* When there is no such index i, set i = len(arr). Return the results in ret.
1869-
* All arrays are assumed contiguous on entry and both arr and key must be of
1870-
* the same comparable type.
1869+
* Both arr and key must be of the same comparable type.
18711870
*
1872-
* @param arr contiguous sorted array to be searched.
1871+
* @param arr 1d, strided, sorted array to be searched.
18731872
* @param key contiguous array of keys.
18741873
* @param ret contiguous array of intp for returned indices.
18751874
* @return void
@@ -1883,15 +1882,16 @@ local_search_left(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *ret)
18831882
char *parr = PyArray_DATA(arr);
18841883
char *pkey = PyArray_DATA(key);
18851884
npy_intp *pret = (npy_intp *)PyArray_DATA(ret);
1886-
int elsize = PyArray_DESCR(arr)->elsize;
1885+
int elsize = PyArray_DESCR(key)->elsize;
1886+
npy_intp arrstride = *PyArray_STRIDES(arr);
18871887
npy_intp i;
18881888

18891889
for (i = 0; i < nkeys; ++i) {
18901890
npy_intp imin = 0;
18911891
npy_intp imax = nelts;
18921892
while (imin < imax) {
18931893
npy_intp imid = imin + ((imax - imin) >> 1);
1894-
if (compare(parr + elsize*imid, pkey, key) < 0) {
1894+
if (compare(parr + arrstride*imid, pkey, key) < 0) {
18951895
imin = imid + 1;
18961896
}
18971897
else {
@@ -1909,10 +1909,9 @@ local_search_left(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *ret)
19091909
*
19101910
* For each key use bisection to find the first index i s.t. key < arr[i].
19111911
* When there is no such index i, set i = len(arr). Return the results in ret.
1912-
* All arrays are assumed contiguous on entry and both arr and key must be of
1913-
* the same comparable type.
1912+
* Both arr and key must be of the same comparable type.
19141913
*
1915-
* @param arr contiguous sorted array to be searched.
1914+
* @param arr 1d, strided, sorted array to be searched.
19161915
* @param key contiguous array of keys.
19171916
* @param ret contiguous array of intp for returned indices.
19181917
* @return void
@@ -1926,15 +1925,16 @@ local_search_right(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *ret)
19261925
char *parr = PyArray_DATA(arr);
19271926
char *pkey = PyArray_DATA(key);
19281927
npy_intp *pret = (npy_intp *)PyArray_DATA(ret);
1929-
int elsize = PyArray_DESCR(arr)->elsize;
1928+
int elsize = PyArray_DESCR(key)->elsize;
1929+
npy_intp arrstride = *PyArray_STRIDES(arr);
19301930
npy_intp i;
19311931

19321932
for(i = 0; i < nkeys; ++i) {
19331933
npy_intp imin = 0;
19341934
npy_intp imax = nelts;
19351935
while (imin < imax) {
19361936
npy_intp imid = imin + ((imax - imin) >> 1);
1937-
if (compare(parr + elsize*imid, pkey, key) <= 0) {
1937+
if (compare(parr + arrstride*imid, pkey, key) <= 0) {
19381938
imin = imid + 1;
19391939
}
19401940
else {
@@ -1951,11 +1951,11 @@ local_search_right(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *ret)
19511951
*
19521952
* For each key use bisection to find the first index i s.t. key <= arr[i].
19531953
* When there is no such index i, set i = len(arr). Return the results in ret.
1954-
* All arrays are assumed contiguous on entry and both arr and key must be of
1955-
* the same comparable type.
1954+
* Both arr and key must be of the same comparable type.
19561955
*
1957-
* @param arr contiguous sorted array to be searched.
1956+
* @param arr 1d, strided array to be searched.
19581957
* @param key contiguous array of keys.
1958+
* @param sorter 1d, strided array of intp that sorts arr.
19591959
* @param ret contiguous array of intp for returned indices.
19601960
* @return int
19611961
*/
@@ -1968,22 +1968,24 @@ local_argsearch_left(PyArrayObject *arr, PyArrayObject *key,
19681968
npy_intp nkeys = PyArray_SIZE(key);
19691969
char *parr = PyArray_DATA(arr);
19701970
char *pkey = PyArray_DATA(key);
1971-
npy_intp *psorter = (npy_intp *)PyArray_DATA(sorter);
1971+
char *psorter = PyArray_DATA(sorter);
19721972
npy_intp *pret = (npy_intp *)PyArray_DATA(ret);
1973-
int elsize = PyArray_DESCR(arr)->elsize;
1973+
int elsize = PyArray_DESCR(key)->elsize;
1974+
npy_intp arrstride = *PyArray_STRIDES(arr);
1975+
npy_intp sorterstride = *PyArray_STRIDES(sorter);
19741976
npy_intp i;
19751977

19761978
for (i = 0; i < nkeys; ++i) {
19771979
npy_intp imin = 0;
19781980
npy_intp imax = nelts;
19791981
while (imin < imax) {
19801982
npy_intp imid = imin + ((imax - imin) >> 1);
1981-
npy_intp indx = psorter[imid];
1983+
npy_intp indx = *(npy_intp *)(psorter + sorterstride * imid);
19821984

19831985
if (indx < 0 || indx >= nelts) {
19841986
return -1;
19851987
}
1986-
if (compare(parr + elsize*indx, pkey, key) < 0) {
1988+
if (compare(parr + arrstride*indx, pkey, key) < 0) {
19871989
imin = imid + 1;
19881990
}
19891991
else {
@@ -2002,11 +2004,11 @@ local_argsearch_left(PyArrayObject *arr, PyArrayObject *key,
20022004
*
20032005
* For each key use bisection to find the first index i s.t. key < arr[i].
20042006
* When there is no such index i, set i = len(arr). Return the results in ret.
2005-
* All arrays are assumed contiguous on entry and both arr and key must be of
2006-
* the same comparable type.
2007+
* Both arr and key must be of the same comparable type.
20072008
*
2008-
* @param arr contiguous sorted array to be searched.
2009+
* @param arr 1d, strided array to be searched.
20092010
* @param key contiguous array of keys.
2011+
* @param sorter 1d, strided array of intp that sorts arr.
20102012
* @param ret contiguous array of intp for returned indices.
20112013
* @return int
20122014
*/
@@ -2019,22 +2021,24 @@ local_argsearch_right(PyArrayObject *arr, PyArrayObject *key,
20192021
npy_intp nkeys = PyArray_SIZE(key);
20202022
char *parr = PyArray_DATA(arr);
20212023
char *pkey = PyArray_DATA(key);
2022-
npy_intp *psorter = (npy_intp *)PyArray_DATA(sorter);
2024+
char *psorter = PyArray_DATA(sorter);
20232025
npy_intp *pret = (npy_intp *)PyArray_DATA(ret);
2024-
int elsize = PyArray_DESCR(arr)->elsize;
2026+
int elsize = PyArray_DESCR(key)->elsize;
2027+
npy_intp arrstride = *PyArray_STRIDES(arr);
2028+
npy_intp sorterstride = *PyArray_STRIDES(sorter);
20252029
npy_intp i;
20262030

20272031
for(i = 0; i < nkeys; ++i) {
20282032
npy_intp imin = 0;
20292033
npy_intp imax = nelts;
20302034
while (imin < imax) {
20312035
npy_intp imid = imin + ((imax - imin) >> 1);
2032-
npy_intp indx = psorter[imid];
2036+
npy_intp indx = *(npy_intp *)(psorter + sorterstride * imid);
20332037

20342038
if (indx < 0 || indx >= nelts) {
20352039
return -1;
20362040
}
2037-
if (compare(parr + elsize*indx, pkey, key) <= 0) {
2041+
if (compare(parr + arrstride*indx, pkey, key) <= 0) {
20382042
imin = imid + 1;
20392043
}
20402044
else {
@@ -2087,6 +2091,7 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2,
20872091
PyArrayObject *sorter = NULL;
20882092
PyArrayObject *ret = NULL;
20892093
PyArray_Descr *dtype;
2094+
int ap1_flags = NPY_ARRAY_NOTSWAPPED | NPY_ARRAY_ALIGNED;
20902095
NPY_BEGIN_THREADS_DEF;
20912096

20922097
/* Find common type */
@@ -2095,23 +2100,28 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2,
20952100
return NULL;
20962101
}
20972102

2098-
/* need ap1 as contiguous array and of right type */
2103+
/* need ap2 as contiguous array and of right type */
20992104
Py_INCREF(dtype);
2100-
ap1 = (PyArrayObject *)PyArray_CheckFromAny((PyObject *)op1, dtype,
2101-
1, 1,
2102-
NPY_ARRAY_DEFAULT | NPY_ARRAY_NOTSWAPPED,
2105+
ap2 = (PyArrayObject *)PyArray_CheckFromAny(op2, dtype,
2106+
0, 0,
2107+
NPY_ARRAY_CARRAY_RO | NPY_ARRAY_NOTSWAPPED,
21032108
NULL);
2104-
if (ap1 == NULL) {
2109+
if (ap2 == NULL) {
21052110
Py_DECREF(dtype);
21062111
return NULL;
21072112
}
21082113

2109-
/* need ap2 as contiguous array and of right type */
2110-
ap2 = (PyArrayObject *)PyArray_CheckFromAny(op2, dtype,
2111-
0, 0,
2112-
NPY_ARRAY_DEFAULT | NPY_ARRAY_NOTSWAPPED,
2113-
NULL);
2114-
if (ap2 == NULL) {
2114+
/*
2115+
* If the needle (ap2) is larger than the haystack (op1) we copy the
2116+
* haystack to a contiguous array for improved cache utilization.
2117+
*/
2118+
if (PyArray_SIZE(ap2) > PyArray_SIZE(op1)) {
2119+
ap1_flags |= NPY_ARRAY_CARRAY_RO;
2120+
}
2121+
2122+
ap1 = (PyArrayObject *)PyArray_CheckFromAny((PyObject *)op1, dtype,
2123+
1, 1, ap1_flags, NULL);
2124+
if (ap1 == NULL) {
21152125
goto fail;
21162126
}
21172127
/* check that comparison function exists */
@@ -2125,7 +2135,7 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2,
21252135
/* need ap3 as an aligned, native byte order array (it may be strided) */
21262136
ap3 = (PyArrayObject *)PyArray_CheckFromAny(perm, NULL,
21272137
1, 1,
2128-
NPY_ARRAY_DEFAULT | NPY_ARRAY_NOTSWAPPED,
2138+
NPY_ARRAY_ALIGNED | NPY_ARRAY_NOTSWAPPED,
21292139
NULL);
21302140
if (ap3 == NULL) {
21312141
PyErr_SetString(PyExc_TypeError,
@@ -2140,7 +2150,7 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2,
21402150
/* convert to known integer size */
21412151
sorter = (PyArrayObject *)PyArray_FromArray(ap3,
21422152
PyArray_DescrFromType(NPY_INTP),
2143-
NPY_ARRAY_DEFAULT | NPY_ARRAY_NOTSWAPPED);
2153+
NPY_ARRAY_ALIGNED | NPY_ARRAY_NOTSWAPPED);
21442154
if (sorter == NULL) {
21452155
PyErr_SetString(PyExc_ValueError,
21462156
"could not parse sorter argument");

numpy/core/tests/test_multiarray.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -870,7 +870,6 @@ def test_searchsorted(self):
870870
# order.
871871

872872
# check double
873-
a = np.array([np.nan, 1, 0])
874873
a = np.array([0, 1, np.nan])
875874
msg = "Test real searchsorted with nans, side='l'"
876875
b = a.searchsorted(a, side='l')
@@ -897,6 +896,41 @@ def test_searchsorted(self):
897896
b = a.searchsorted(np.array(128,dtype='>i4'))
898897
assert_equal(b, 1, msg)
899898

899+
# Check 0 elements
900+
a = np.ones(0)
901+
b = a.searchsorted([0, 1, 2], 'l')
902+
assert_equal(b, [0, 0, 0])
903+
b = a.searchsorted([0, 1, 2], 'r')
904+
assert_equal(b, [0, 0, 0])
905+
a = np.ones(1)
906+
# Check 1 element
907+
b = a.searchsorted([0, 1, 2], 'l')
908+
assert_equal(b, [0, 0, 1])
909+
b = a.searchsorted([0, 1, 2], 'r')
910+
assert_equal(b, [0, 1, 1])
911+
# Check all elements equal
912+
a = np.ones(2)
913+
b = a.searchsorted([0, 1, 2], 'l')
914+
assert_equal(b, [0, 0, 2])
915+
b = a.searchsorted([0, 1, 2], 'r')
916+
assert_equal(b, [0, 2, 2])
917+
918+
# Build an unaligned copy of a sorted array
919+
a = np.arange(10)
920+
aligned = np.empty(a.itemsize * a.size + 1, 'uint8')
921+
unaligned = aligned[1:].view(a.dtype)
922+
unaligned[:] = a
923+
# Test searching unaligned array
924+
b = unaligned.searchsorted(a, 'l')
925+
assert_equal(b, a)
926+
b = unaligned.searchsorted(a, 'r')
927+
assert_equal(b, a + 1)
928+
# Test searching for unaligned keys
929+
b = a.searchsorted(unaligned, 'l')
930+
assert_equal(b, a)
931+
b = a.searchsorted(unaligned, 'r')
932+
assert_equal(b, a + 1)
933+
900934
def test_searchsorted_unicode(self):
901935
# Test searchsorted on unicode strings.
902936

@@ -935,8 +969,9 @@ def test_searchsorted_with_sorter(self):
935969
# bounds check
936970
assert_raises(ValueError, np.searchsorted, a, 4, sorter=[0,1,2,3,5])
937971
assert_raises(ValueError, np.searchsorted, a, 0, sorter=[-1,0,1,2,3])
972+
assert_raises(ValueError, np.searchsorted, a, 0, sorter=[4,0,-1,2,3])
938973

939-
a = np.random.rand(100)
974+
a = np.random.rand(300)
940975
s = a.argsort()
941976
b = np.sort(a)
942977
k = np.linspace(0, 1, 20)
@@ -945,8 +980,30 @@ def test_searchsorted_with_sorter(self):
945980
a = np.array([0, 1, 2, 3, 5]*20)
946981
s = a.argsort()
947982
k = [0, 1, 2, 3, 5]
948-
assert_equal(a.searchsorted(k, side='l', sorter=s), [0, 20, 40, 60, 80])
949-
assert_equal(a.searchsorted(k, side='r', sorter=s), [20, 40, 60, 80, 100])
983+
expected = [0, 20, 40, 60, 80]
984+
assert_equal(a.searchsorted(k, side='l', sorter=s), expected)
985+
expected = [20, 40, 60, 80, 100]
986+
assert_equal(a.searchsorted(k, side='r', sorter=s), expected)
987+
988+
# Build the keys, the sorter and an unaligned buffer
989+
keys = np.arange(10)
990+
a = keys.copy()
991+
np.random.shuffle(s)
992+
s = a.argsort()
993+
aligned = np.empty(a.itemsize * a.size + 1, 'uint8')
994+
unaligned = aligned[1:].view(a.dtype)
995+
# Test searching unaligned array
996+
unaligned[:] = a
997+
b = unaligned.searchsorted(keys, 'l', s)
998+
assert_equal(b, keys)
999+
b = unaligned.searchsorted(keys, 'r', s)
1000+
assert_equal(b, keys + 1)
1001+
# Test searching for unaligned keys
1002+
unaligned[:] = keys
1003+
b = a.searchsorted(unaligned, 'l', s)
1004+
assert_equal(b, keys)
1005+
b = a.searchsorted(unaligned, 'r', s)
1006+
assert_equal(b, keys + 1)
9501007

9511008

9521009
def test_partition(self):

0 commit comments

Comments
 (0)