Skip to content

Commit 7880460

Browse files
authored
Add query_params to FT.PROFILE (#2198)
* ft.profile query_params * fix pr comments * type hints
1 parent edf1004 commit 7880460

File tree

2 files changed

+43
-11
lines changed

2 files changed

+43
-11
lines changed

redis/commands/search/commands.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import itertools
22
import time
3-
from typing import Dict, Union
3+
from typing import Dict, Optional, Union
44

55
from redis.client import Pipeline
66

@@ -363,7 +363,11 @@ def info(self):
363363
it = map(to_string, res)
364364
return dict(zip(it, it))
365365

366-
def get_params_args(self, query_params: Dict[str, Union[str, int, float]]):
366+
def get_params_args(
367+
self, query_params: Union[Dict[str, Union[str, int, float]], None]
368+
):
369+
if query_params is None:
370+
return []
367371
args = []
368372
if len(query_params) > 0:
369373
args.append("params")
@@ -383,8 +387,7 @@ def _mk_query_args(self, query, query_params: Dict[str, Union[str, int, float]])
383387
raise ValueError(f"Bad query type {type(query)}")
384388

385389
args += query.get_args()
386-
if query_params is not None:
387-
args += self.get_params_args(query_params)
390+
args += self.get_params_args(query_params)
388391

389392
return args, query
390393

@@ -459,8 +462,7 @@ def aggregate(
459462
cmd = [CURSOR_CMD, "READ", self.index_name] + query.build_args()
460463
else:
461464
raise ValueError("Bad query", query)
462-
if query_params is not None:
463-
cmd += self.get_params_args(query_params)
465+
cmd += self.get_params_args(query_params)
464466

465467
raw = self.execute_command(*cmd)
466468
return self._get_aggregate_result(raw, query, has_cursor)
@@ -485,16 +487,22 @@ def _get_aggregate_result(self, raw, query, has_cursor):
485487

486488
return AggregateResult(rows, cursor, schema)
487489

488-
def profile(self, query, limited=False):
490+
def profile(
491+
self,
492+
query: Union[str, Query, AggregateRequest],
493+
limited: bool = False,
494+
query_params: Optional[Dict[str, Union[str, int, float]]] = None,
495+
):
489496
"""
490497
Performs a search or aggregate command and collects performance
491498
information.
492499
493500
### Parameters
494501
495-
**query**: This can be either an `AggregateRequest`, `Query` or
496-
string.
502+
**query**: This can be either an `AggregateRequest`, `Query` or string.
497503
**limited**: If set to True, removes details of reader iterator.
504+
**query_params**: Define one or more value parameters.
505+
Each parameter has a name and a value.
498506
499507
"""
500508
st = time.time()
@@ -509,6 +517,7 @@ def profile(self, query, limited=False):
509517
elif isinstance(query, Query):
510518
cmd[2] = "SEARCH"
511519
cmd += query.get_args()
520+
cmd += self.get_params_args(query_params)
512521
else:
513522
raise ValueError("Must provide AggregateRequest object or " "Query object.")
514523

@@ -907,8 +916,7 @@ async def aggregate(
907916
cmd = [CURSOR_CMD, "READ", self.index_name] + query.build_args()
908917
else:
909918
raise ValueError("Bad query", query)
910-
if query_params is not None:
911-
cmd += self.get_params_args(query_params)
919+
cmd += self.get_params_args(query_params)
912920

913921
raw = await self.execute_command(*cmd)
914922
return self._get_aggregate_result(raw, query, has_cursor)

tests/test_search.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,6 +1519,30 @@ def test_profile_limited(client):
15191519
assert len(res.docs) == 3 # check also the search result
15201520

15211521

1522+
@pytest.mark.redismod
1523+
@skip_ifmodversion_lt("2.4.3", "search")
1524+
def test_profile_query_params(modclient: redis.Redis):
1525+
modclient.flushdb()
1526+
modclient.ft().create_index(
1527+
(
1528+
VectorField(
1529+
"v", "HNSW", {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"}
1530+
),
1531+
)
1532+
)
1533+
modclient.hset("a", "v", "aaaaaaaa")
1534+
modclient.hset("b", "v", "aaaabaaa")
1535+
modclient.hset("c", "v", "aaaaabaa")
1536+
query = "*=>[KNN 2 @v $vec]"
1537+
q = Query(query).return_field("__v_score").sort_by("__v_score", True).dialect(2)
1538+
res, det = modclient.ft().profile(q, query_params={"vec": "aaaaaaaa"})
1539+
assert det["Iterators profile"]["Counter"] == 2.0
1540+
assert det["Iterators profile"]["Type"] == "VECTOR"
1541+
assert res.total == 2
1542+
assert "a" == res.docs[0].id
1543+
assert "0" == res.docs[0].__getattribute__("__v_score")
1544+
1545+
15221546
@pytest.mark.redismod
15231547
@skip_ifmodversion_lt("2.4.3", "search")
15241548
def test_vector_field(modclient):

0 commit comments

Comments
 (0)