Skip to content

Commit baed085

Browse files
Add a knn method to elasticsearch_dsl.search.Search (elastic#1691)
* Add a `knn` method to `elasticsearch_dsl.search.Search` * add knn's boost option
1 parent f0c5045 commit baed085

File tree

3 files changed

+152
-1
lines changed

3 files changed

+152
-1
lines changed

docs/search_dsl.rst

+27
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ The ``Search`` object represents the entire search request:
1414

1515
* aggregations
1616

17+
* k-nearest neighbor searches
18+
1719
* sort
1820

1921
* pagination
@@ -352,6 +354,31 @@ As opposed to other methods on the ``Search`` objects, defining aggregations is
352354
done in-place (does not return a copy).
353355

354356

357+
K-Nearest Neighbor Searches
358+
~~~~~~~~~~~~~~~~~~~~~~~~~~~
359+
360+
To issue a kNN search, use the ``.knn()`` method:
361+
362+
.. code:: python
363+
364+
s = Search()
365+
vector = get_embedding("search text")
366+
367+
s = s.knn(
368+
field="embedding",
369+
k=5,
370+
num_candidates=10,
371+
query_vector=vector
372+
)
373+
374+
The ``field``, ``k`` and ``num_candidates`` arguments can be given as
375+
positional or keyword arguments and are required. In addition to these,
376+
``query_vector`` or ``query_vector_builder`` must be given as well.
377+
378+
The ``.knn()`` method can be invoked multiple times to include multiple kNN
379+
searches in the request.
380+
381+
355382
Sorting
356383
~~~~~~~
357384

elasticsearch_dsl/search.py

+71-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from .aggs import A, AggBase
2525
from .connections import get_connection
2626
from .exceptions import IllegalOperation
27-
from .query import Bool, Q
27+
from .query import Bool, Q, Query
2828
from .response import Hit, Response
2929
from .utils import AttrDict, DslBase, recursive_to_dict
3030

@@ -319,6 +319,7 @@ def __init__(self, **kwargs):
319319
self.aggs = AggsProxy(self)
320320
self._sort = []
321321
self._collapse = {}
322+
self._knn = []
322323
self._source = None
323324
self._highlight = {}
324325
self._highlight_opts = {}
@@ -406,6 +407,7 @@ def _clone(self):
406407
s = super()._clone()
407408

408409
s._response_class = self._response_class
410+
s._knn = [knn.copy() for knn in self._knn]
409411
s._collapse = self._collapse.copy()
410412
s._sort = self._sort[:]
411413
s._source = copy.copy(self._source) if self._source is not None else None
@@ -445,6 +447,10 @@ def update_from_dict(self, d):
445447
self.aggs._params = {
446448
"aggs": {name: A(value) for (name, value) in aggs.items()}
447449
}
450+
if "knn" in d:
451+
self._knn = d.pop("knn")
452+
if isinstance(self._knn, dict):
453+
self._knn = [self._knn]
448454
if "collapse" in d:
449455
self._collapse = d.pop("collapse")
450456
if "sort" in d:
@@ -494,6 +500,64 @@ def script_fields(self, **kwargs):
494500
s._script_fields.update(kwargs)
495501
return s
496502

503+
def knn(
504+
self,
505+
field,
506+
k,
507+
num_candidates,
508+
query_vector=None,
509+
query_vector_builder=None,
510+
boost=None,
511+
filter=None,
512+
similarity=None,
513+
):
514+
"""
515+
Add a k-nearest neighbor (kNN) search.
516+
517+
:arg field: the name of the vector field to search against
518+
:arg k: number of nearest neighbors to return as top hits
519+
:arg num_candidates: number of nearest neighbor candidates to consider per shard
520+
:arg query_vector: the vector to search for
521+
:arg query_vector_builder: A dictionary indicating how to build a query vector
522+
:arg boost: A floating-point boost factor for kNN scores
523+
:arg filter: query to filter the documents that can match
524+
:arg similarity: the minimum similarity required for a document to be considered a match, as a float value
525+
526+
Example::
527+
528+
s = Search()
529+
s = s.knn(field='embedding', k=5, num_candidates=10, query_vector=vector,
530+
filter=Q('term', category='blog')))
531+
"""
532+
s = self._clone()
533+
s._knn.append(
534+
{
535+
"field": field,
536+
"k": k,
537+
"num_candidates": num_candidates,
538+
}
539+
)
540+
if query_vector is None and query_vector_builder is None:
541+
raise ValueError("one of query_vector and query_vector_builder is required")
542+
if query_vector is not None and query_vector_builder is not None:
543+
raise ValueError(
544+
"only one of query_vector and query_vector_builder must be given"
545+
)
546+
if query_vector is not None:
547+
s._knn[-1]["query_vector"] = query_vector
548+
if query_vector_builder is not None:
549+
s._knn[-1]["query_vector_builder"] = query_vector_builder
550+
if boost is not None:
551+
s._knn[-1]["boost"] = boost
552+
if filter is not None:
553+
if isinstance(filter, Query):
554+
s._knn[-1]["filter"] = filter.to_dict()
555+
else:
556+
s._knn[-1]["filter"] = filter
557+
if similarity is not None:
558+
s._knn[-1]["similarity"] = similarity
559+
return s
560+
497561
def source(self, fields=None, **kwargs):
498562
"""
499563
Selectively control how the _source field is returned.
@@ -677,6 +741,12 @@ def to_dict(self, count=False, **kwargs):
677741
if self.query:
678742
d["query"] = self.query.to_dict()
679743

744+
if self._knn:
745+
if len(self._knn) == 1:
746+
d["knn"] = self._knn[0]
747+
else:
748+
d["knn"] = self._knn
749+
680750
# count request doesn't care for sorting and other things
681751
if not count:
682752
if self.post_filter:

tests/test_search.py

+54
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,60 @@ class MyDocument(Document):
234234
assert s._doc_type_map == {}
235235

236236

237+
def test_knn():
238+
s = search.Search()
239+
240+
with raises(TypeError):
241+
s.knn()
242+
with raises(TypeError):
243+
s.knn("field")
244+
with raises(TypeError):
245+
s.knn("field", 5)
246+
with raises(ValueError):
247+
s.knn("field", 5, 100)
248+
with raises(ValueError):
249+
s.knn("field", 5, 100, query_vector=[1, 2, 3], query_vector_builder={})
250+
251+
s = s.knn("field", 5, 100, query_vector=[1, 2, 3])
252+
assert {
253+
"knn": {
254+
"field": "field",
255+
"k": 5,
256+
"num_candidates": 100,
257+
"query_vector": [1, 2, 3],
258+
}
259+
} == s.to_dict()
260+
261+
s = s.knn(
262+
k=4,
263+
num_candidates=40,
264+
boost=0.8,
265+
field="name",
266+
query_vector_builder={
267+
"text_embedding": {"model_id": "foo", "model_text": "search text"}
268+
},
269+
)
270+
assert {
271+
"knn": [
272+
{
273+
"field": "field",
274+
"k": 5,
275+
"num_candidates": 100,
276+
"query_vector": [1, 2, 3],
277+
},
278+
{
279+
"field": "name",
280+
"k": 4,
281+
"num_candidates": 40,
282+
"query_vector_builder": {
283+
"text_embedding": {"model_id": "foo", "model_text": "search text"}
284+
},
285+
"boost": 0.8,
286+
},
287+
]
288+
} == s.to_dict()
289+
290+
237291
def test_sort():
238292
s = search.Search()
239293
s = s.sort("fielda", "-fieldb")

0 commit comments

Comments
 (0)