|
24 | 24 | from .aggs import A, AggBase
|
25 | 25 | from .connections import get_connection
|
26 | 26 | from .exceptions import IllegalOperation
|
27 |
| -from .query import Bool, Q |
| 27 | +from .query import Bool, Q, Query |
28 | 28 | from .response import Hit, Response
|
29 | 29 | from .utils import AttrDict, DslBase, recursive_to_dict
|
30 | 30 |
|
@@ -319,6 +319,7 @@ def __init__(self, **kwargs):
|
319 | 319 | self.aggs = AggsProxy(self)
|
320 | 320 | self._sort = []
|
321 | 321 | self._collapse = {}
|
| 322 | + self._knn = [] |
322 | 323 | self._source = None
|
323 | 324 | self._highlight = {}
|
324 | 325 | self._highlight_opts = {}
|
@@ -406,6 +407,7 @@ def _clone(self):
|
406 | 407 | s = super()._clone()
|
407 | 408 |
|
408 | 409 | s._response_class = self._response_class
|
| 410 | + s._knn = [knn.copy() for knn in self._knn] |
409 | 411 | s._collapse = self._collapse.copy()
|
410 | 412 | s._sort = self._sort[:]
|
411 | 413 | s._source = copy.copy(self._source) if self._source is not None else None
|
@@ -445,6 +447,10 @@ def update_from_dict(self, d):
|
445 | 447 | self.aggs._params = {
|
446 | 448 | "aggs": {name: A(value) for (name, value) in aggs.items()}
|
447 | 449 | }
|
| 450 | + if "knn" in d: |
| 451 | + self._knn = d.pop("knn") |
| 452 | + if isinstance(self._knn, dict): |
| 453 | + self._knn = [self._knn] |
448 | 454 | if "collapse" in d:
|
449 | 455 | self._collapse = d.pop("collapse")
|
450 | 456 | if "sort" in d:
|
@@ -494,6 +500,64 @@ def script_fields(self, **kwargs):
|
494 | 500 | s._script_fields.update(kwargs)
|
495 | 501 | return s
|
496 | 502 |
|
def knn(
    self,
    field,
    k,
    num_candidates,
    query_vector=None,
    query_vector_builder=None,
    boost=None,
    filter=None,
    similarity=None,
):
    """
    Add a k-nearest neighbor (kNN) search.

    The current object is left untouched: the new kNN clause is appended
    to a clone, and that clone is returned. Calling this method several
    times accumulates several kNN clauses.

    :arg field: the name of the vector field to search against
    :arg k: number of nearest neighbors to return as top hits
    :arg num_candidates: number of nearest neighbor candidates to consider per shard
    :arg query_vector: the vector to search for
    :arg query_vector_builder: A dictionary indicating how to build a query vector
    :arg boost: A floating-point boost factor for kNN scores
    :arg filter: query to filter the documents that can match; a ``Query``
        instance is serialized with ``to_dict()``, anything else is stored as given
    :arg similarity: the minimum similarity required for a document to be considered a match, as a float value

    :raises ValueError: if neither or both of ``query_vector`` and
        ``query_vector_builder`` are given

    Example::

        s = Search()
        s = s.knn(field='embedding', k=5, num_candidates=10, query_vector=vector,
                  filter=Q('term', category='blog'))
    """
    # Validate before cloning so an invalid call fails fast, without
    # building a throwaway copy that carries a half-populated clause.
    if query_vector is None and query_vector_builder is None:
        raise ValueError("one of query_vector and query_vector_builder is required")
    if query_vector is not None and query_vector_builder is not None:
        raise ValueError(
            "only one of query_vector and query_vector_builder must be given"
        )

    knn_clause = {
        "field": field,
        "k": k,
        "num_candidates": num_candidates,
    }
    # Exactly one of the two vector sources is set at this point.
    if query_vector is not None:
        knn_clause["query_vector"] = query_vector
    else:
        knn_clause["query_vector_builder"] = query_vector_builder
    if boost is not None:
        knn_clause["boost"] = boost
    if filter is not None:
        # Query objects are stored in their dict form; raw values pass through.
        knn_clause["filter"] = (
            filter.to_dict() if isinstance(filter, Query) else filter
        )
    if similarity is not None:
        knn_clause["similarity"] = similarity

    s = self._clone()
    s._knn.append(knn_clause)
    return s
| 560 | + |
497 | 561 | def source(self, fields=None, **kwargs):
|
498 | 562 | """
|
499 | 563 | Selectively control how the _source field is returned.
|
@@ -677,6 +741,12 @@ def to_dict(self, count=False, **kwargs):
|
677 | 741 | if self.query:
|
678 | 742 | d["query"] = self.query.to_dict()
|
679 | 743 |
|
| 744 | + if self._knn: |
| 745 | + if len(self._knn) == 1: |
| 746 | + d["knn"] = self._knn[0] |
| 747 | + else: |
| 748 | + d["knn"] = self._knn |
| 749 | + |
680 | 750 | # count request doesn't care for sorting and other things
|
681 | 751 | if not count:
|
682 | 752 | if self.post_filter:
|
|
0 commit comments