Skip to content

Commit 5d2bccd

Browse files
Corrected typing hints for the FunctionScore query (#1960)
Fixes #1957
1 parent 5a2cc86 commit 5d2bccd

File tree

6 files changed

+37
-77
lines changed

6 files changed

+37
-77
lines changed

elasticsearch_dsl/query.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -612,9 +612,9 @@ class FunctionScore(Query):
612612

613613
name = "function_score"
614614
_param_defs = {
615+
"functions": {"type": "score_function", "multi": True},
615616
"query": {"type": "query"},
616617
"filter": {"type": "query"},
617-
"functions": {"type": "score_function", "multi": True},
618618
}
619619

620620
def __init__(
@@ -623,11 +623,7 @@ def __init__(
623623
boost_mode: Union[
624624
Literal["multiply", "replace", "sum", "avg", "max", "min"], "DefaultType"
625625
] = DEFAULT,
626-
functions: Union[
627-
Sequence["types.FunctionScoreContainer"],
628-
Sequence[Dict[str, Any]],
629-
"DefaultType",
630-
] = DEFAULT,
626+
functions: Union[Sequence[ScoreFunction], "DefaultType"] = DEFAULT,
631627
max_boost: Union[float, "DefaultType"] = DEFAULT,
632628
min_score: Union[float, "DefaultType"] = DEFAULT,
633629
query: Union[Query, "DefaultType"] = DEFAULT,

elasticsearch_dsl/types.py

+1-69
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from elastic_transport.client_utils import DEFAULT, DefaultType
2121

22-
from elasticsearch_dsl import Query, function
22+
from elasticsearch_dsl import Query
2323
from elasticsearch_dsl.document_base import InstrumentedField
2424
from elasticsearch_dsl.utils import AttrDict
2525

@@ -688,74 +688,6 @@ def __init__(
688688
super().__init__(kwargs)
689689

690690

691-
class FunctionScoreContainer(AttrDict[Any]):
692-
"""
693-
:arg exp: Function that scores a document with a exponential decay,
694-
depending on the distance of a numeric field value of the document
695-
from an origin.
696-
:arg gauss: Function that scores a document with a normal decay,
697-
depending on the distance of a numeric field value of the document
698-
from an origin.
699-
:arg linear: Function that scores a document with a linear decay,
700-
depending on the distance of a numeric field value of the document
701-
from an origin.
702-
:arg field_value_factor: Function allows you to use a field from a
703-
document to influence the score. It’s similar to using the
704-
script_score function, however, it avoids the overhead of
705-
scripting.
706-
:arg random_score: Generates scores that are uniformly distributed
707-
from 0 up to but not including 1. In case you want scores to be
708-
reproducible, it is possible to provide a `seed` and `field`.
709-
:arg script_score: Enables you to wrap another query and customize the
710-
scoring of it optionally with a computation derived from other
711-
numeric field values in the doc using a script expression.
712-
:arg filter:
713-
:arg weight:
714-
"""
715-
716-
exp: Union[function.DecayFunction, DefaultType]
717-
gauss: Union[function.DecayFunction, DefaultType]
718-
linear: Union[function.DecayFunction, DefaultType]
719-
field_value_factor: Union[function.FieldValueFactorScore, DefaultType]
720-
random_score: Union[function.RandomScore, DefaultType]
721-
script_score: Union[function.ScriptScore, DefaultType]
722-
filter: Union[Query, DefaultType]
723-
weight: Union[float, DefaultType]
724-
725-
def __init__(
726-
self,
727-
*,
728-
exp: Union[function.DecayFunction, DefaultType] = DEFAULT,
729-
gauss: Union[function.DecayFunction, DefaultType] = DEFAULT,
730-
linear: Union[function.DecayFunction, DefaultType] = DEFAULT,
731-
field_value_factor: Union[
732-
function.FieldValueFactorScore, DefaultType
733-
] = DEFAULT,
734-
random_score: Union[function.RandomScore, DefaultType] = DEFAULT,
735-
script_score: Union[function.ScriptScore, DefaultType] = DEFAULT,
736-
filter: Union[Query, DefaultType] = DEFAULT,
737-
weight: Union[float, DefaultType] = DEFAULT,
738-
**kwargs: Any,
739-
):
740-
if exp is not DEFAULT:
741-
kwargs["exp"] = exp
742-
if gauss is not DEFAULT:
743-
kwargs["gauss"] = gauss
744-
if linear is not DEFAULT:
745-
kwargs["linear"] = linear
746-
if field_value_factor is not DEFAULT:
747-
kwargs["field_value_factor"] = field_value_factor
748-
if random_score is not DEFAULT:
749-
kwargs["random_score"] = random_score
750-
if script_score is not DEFAULT:
751-
kwargs["script_score"] = script_score
752-
if filter is not DEFAULT:
753-
kwargs["filter"] = filter
754-
if weight is not DEFAULT:
755-
kwargs["weight"] = weight
756-
super().__init__(kwargs)
757-
758-
759691
class FuzzyQuery(AttrDict[Any]):
760692
"""
761693
:arg value: (required) Term you wish to find in the provided field.

tests/test_query.py

+27
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,33 @@ def test_function_score_to_dict() -> None:
562562
assert d == q.to_dict()
563563

564564

565+
def test_function_score_class_based_to_dict() -> None:
566+
q = query.FunctionScore(
567+
query=query.Match(title="python"),
568+
functions=[
569+
function.RandomScore(),
570+
function.FieldValueFactor(
571+
field="comment_count",
572+
filter=query.Term(tags="python"),
573+
),
574+
],
575+
)
576+
577+
d = {
578+
"function_score": {
579+
"query": {"match": {"title": "python"}},
580+
"functions": [
581+
{"random_score": {}},
582+
{
583+
"filter": {"term": {"tags": "python"}},
584+
"field_value_factor": {"field": "comment_count"},
585+
},
586+
],
587+
}
588+
}
589+
assert d == q.to_dict()
590+
591+
565592
def test_function_score_with_single_function() -> None:
566593
d = {
567594
"function_score": {

utils/generator.py

+6
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,12 @@ def get_python_type(self, schema_type, for_response=False):
200200
):
201201
# QueryContainer maps to the DSL's Query class
202202
return "Query", {"type": "query"}
203+
elif (
204+
type_name["namespace"] == "_types.query_dsl"
205+
and type_name["name"] == "FunctionScoreContainer"
206+
):
207+
# FunctionScoreContainer maps to the DSL's ScoreFunction class
208+
return "ScoreFunction", {"type": "score_function"}
203209
elif (
204210
type_name["namespace"] == "_types.aggregations"
205211
and type_name["name"] == "Buckets"

utils/templates/query.py.tpl

-1
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,6 @@ class {{ k.name }}({{ parent }}):
174174
shortcut property. Until the code generator can support shortcut
175175
properties directly that solution is added here #}
176176
"filter": {"type": "query"},
177-
"functions": {"type": "score_function", "multi": True},
178177
{% endif %}
179178
}
180179
{% endif %}

utils/templates/types.py.tpl

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ from typing import Any, Dict, Literal, Mapping, Sequence, Union
2020
from elastic_transport.client_utils import DEFAULT, DefaultType
2121

2222
from elasticsearch_dsl.document_base import InstrumentedField
23-
from elasticsearch_dsl import function, Query
23+
from elasticsearch_dsl import Query
2424
from elasticsearch_dsl.utils import AttrDict
2525

2626
PipeSeparatedFlags = str

0 commit comments

Comments
 (0)