1
1
import itertools
2
2
import time
3
- from typing import Dict , Union
3
+ from typing import Dict , Optional , Union
4
4
5
5
from redis .client import Pipeline
6
6
@@ -363,7 +363,11 @@ def info(self):
363
363
it = map (to_string , res )
364
364
return dict (zip (it , it ))
365
365
366
- def get_params_args (self , query_params : Dict [str , Union [str , int , float ]]):
366
+ def get_params_args (
367
+ self , query_params : Union [Dict [str , Union [str , int , float ]], None ]
368
+ ):
369
+ if query_params is None :
370
+ return []
367
371
args = []
368
372
if len (query_params ) > 0 :
369
373
args .append ("params" )
@@ -383,8 +387,7 @@ def _mk_query_args(self, query, query_params: Dict[str, Union[str, int, float]])
383
387
raise ValueError (f"Bad query type { type (query )} " )
384
388
385
389
args += query .get_args ()
386
- if query_params is not None :
387
- args += self .get_params_args (query_params )
390
+ args += self .get_params_args (query_params )
388
391
389
392
return args , query
390
393
@@ -459,8 +462,7 @@ def aggregate(
459
462
cmd = [CURSOR_CMD , "READ" , self .index_name ] + query .build_args ()
460
463
else :
461
464
raise ValueError ("Bad query" , query )
462
- if query_params is not None :
463
- cmd += self .get_params_args (query_params )
465
+ cmd += self .get_params_args (query_params )
464
466
465
467
raw = self .execute_command (* cmd )
466
468
return self ._get_aggregate_result (raw , query , has_cursor )
@@ -485,16 +487,22 @@ def _get_aggregate_result(self, raw, query, has_cursor):
485
487
486
488
return AggregateResult (rows , cursor , schema )
487
489
488
- def profile (self , query , limited = False ):
490
+ def profile (
491
+ self ,
492
+ query : Union [str , Query , AggregateRequest ],
493
+ limited : bool = False ,
494
+ query_params : Optional [Dict [str , Union [str , int , float ]]] = None ,
495
+ ):
489
496
"""
490
497
Performs a search or aggregate command and collects performance
491
498
information.
492
499
493
500
### Parameters
494
501
495
- **query**: This can be either an `AggregateRequest`, `Query` or
496
- string.
502
+ **query**: This can be either an `AggregateRequest`, `Query` or string.
497
503
**limited**: If set to True, removes details of reader iterator.
504
+ **query_params**: Define one or more value parameters.
505
+ Each parameter has a name and a value.
498
506
499
507
"""
500
508
st = time .time ()
@@ -509,6 +517,7 @@ def profile(self, query, limited=False):
509
517
elif isinstance (query , Query ):
510
518
cmd [2 ] = "SEARCH"
511
519
cmd += query .get_args ()
520
+ cmd += self .get_params_args (query_params )
512
521
else :
513
522
raise ValueError ("Must provide AggregateRequest object or " "Query object." )
514
523
@@ -907,8 +916,7 @@ async def aggregate(
907
916
cmd = [CURSOR_CMD , "READ" , self .index_name ] + query .build_args ()
908
917
else :
909
918
raise ValueError ("Bad query" , query )
910
- if query_params is not None :
911
- cmd += self .get_params_args (query_params )
919
+ cmd += self .get_params_args (query_params )
912
920
913
921
raw = await self .execute_command (* cmd )
914
922
return self ._get_aggregate_result (raw , query , has_cursor )
0 commit comments