10
10
11
11
12
12
class BaseQuery (RedisQuery ):
13
- """Base query class used to subclass many query types."""
13
+ """
14
+ Base query class used to subclass many query types.
15
+
16
+ NOTE: In the base class, the `_query_string` field is set once on
17
+ initialization, and afterward, redis-py expects to be able to read it. By
18
+ contrast, our query subclasses allow users to call methods that alter the
19
+ query string at runtime. To avoid having to rebuild `_query_string` every
20
+ time one of these methods is called, we lazily build the query string when a
21
+ user calls `query()` or accesses the property `_query_string`, when the
22
+ underlying `_built_query_string` field is None. Any method that alters the query
23
+ string should set `_built_query_string` to None so that the next time the query
24
+ string is accessed, it is rebuilt.
25
+ """
14
26
15
27
_params : Dict [str , Any ] = {}
16
28
_filter_expression : Union [str , FilterExpression ] = FilterExpression ("*" )
29
+ _built_query_string : Optional [str ] = None
17
30
18
31
def __init__ (self , query_string : str = "*" ):
19
32
"""
@@ -22,8 +35,15 @@ def __init__(self, query_string: str = "*"):
22
35
Args:
23
36
query_string (str, optional): The query string to use. Defaults to '*'.
24
37
"""
38
+ # The parent class expects a query string, so we pass it in, but we'll
39
+ # actually manage building it dynamically.
25
40
super ().__init__ (query_string )
26
41
42
+ # This is a private field that we use to track whether the query string
43
+ # has been built, and we set it to None here to indicate that the field
44
+ # has not been built yet.
45
+ self ._built_query_string = None
46
+
27
47
def __str__ (self ) -> str :
28
48
"""Return the string representation of the query."""
29
49
return " " .join ([str (x ) for x in self .get_args ()])
@@ -54,8 +74,8 @@ def set_filter(
54
74
"filter_expression must be of type FilterExpression or string or None"
55
75
)
56
76
57
- # Reset the query string
58
- self ._query_string = self . _build_query_string ()
77
+ # Invalidate the query string
78
+ self ._built_query_string = None
59
79
60
80
@property
61
81
def filter (self ) -> Union [str , FilterExpression ]:
@@ -72,6 +92,18 @@ def params(self) -> Dict[str, Any]:
72
92
"""Return the query parameters."""
73
93
return self ._params
74
94
95
+ @property
96
+ def _query_string (self ) -> str :
97
+ """Maintains compatibility with parent class while providing lazy loading."""
98
+ if self ._built_query_string is None :
99
+ self ._built_query_string = self ._build_query_string ()
100
+ return self ._built_query_string
101
+
102
+ @_query_string .setter
103
+ def _query_string (self , value : Optional [str ]):
104
+ """Setter for _query_string to maintain compatibility with parent class."""
105
+ self ._built_query_string = value
106
+
75
107
76
108
class FilterQuery (BaseQuery ):
77
109
def __init__ (
@@ -107,9 +139,9 @@ def __init__(
107
139
108
140
self ._num_results = num_results
109
141
110
- # Initialize the base query with the full query string constructed from the filter expression
111
- query_string = self . _build_query_string ( )
112
- super (). __init__ ( query_string )
142
+ # Initialize the base query with the query string from the property
143
+ super (). __init__ ( "*" )
144
+ self . _built_query_string = None # Ensure it's invalidated after initialization
113
145
114
146
# Handle query settings
115
147
if return_fields :
@@ -161,9 +193,9 @@ def __init__(
161
193
if params :
162
194
self ._params = params
163
195
164
- # Initialize the base query with the full query string constructed from the filter expression
165
- query_string = self . _build_query_string ( )
166
- super (). __init__ ( query_string )
196
+ # Initialize the base query with the query string from the property
197
+ super (). __init__ ( "*" )
198
+ self . _built_query_string = None
167
199
168
200
# Query specific modifications
169
201
self .no_content ().paging (0 , 0 ).dialect (dialect )
@@ -178,6 +210,8 @@ def _build_query_string(self) -> str:
178
210
class BaseVectorQuery :
179
211
DISTANCE_ID : str = "vector_distance"
180
212
VECTOR_PARAM : str = "vector"
213
+ EF_RUNTIME : str = "EF_RUNTIME"
214
+ EF_RUNTIME_PARAM : str = "EF"
181
215
182
216
_normalize_vector_distance : bool = False
183
217
@@ -204,6 +238,7 @@ def __init__(
204
238
in_order : bool = False ,
205
239
hybrid_policy : Optional [str ] = None ,
206
240
batch_size : Optional [int ] = None ,
241
+ ef_runtime : Optional [int ] = None ,
207
242
normalize_vector_distance : bool = False ,
208
243
):
209
244
"""A query for running a vector search along with an optional filter
@@ -241,6 +276,9 @@ def __init__(
241
276
of vectors to fetch in each batch. Larger values may improve performance
242
277
at the cost of memory usage. Only applies when hybrid_policy="BATCHES".
243
278
Defaults to None, which lets Redis auto-select an appropriate batch size.
279
+ ef_runtime (Optional[int]): Controls the size of the dynamic candidate list for HNSW
280
+ algorithm at query time. Higher values improve recall at the expense of
281
+ slower search performance. Defaults to None, which uses the index-defined value.
244
282
normalize_vector_distance (bool): Redis supports 3 distance metrics: L2 (euclidean),
245
283
IP (inner product), and COSINE. By default, L2 distance returns an unbounded value.
246
284
COSINE distance returns a value between 0 and 2. IP returns a value determined by
@@ -260,11 +298,13 @@ def __init__(
260
298
self ._num_results = num_results
261
299
self ._hybrid_policy : Optional [HybridPolicy ] = None
262
300
self ._batch_size : Optional [int ] = None
301
+ self ._ef_runtime : Optional [int ] = None
263
302
self ._normalize_vector_distance = normalize_vector_distance
264
303
self .set_filter (filter_expression )
265
- query_string = self ._build_query_string ()
266
304
267
- super ().__init__ (query_string )
305
+ # Initialize the base query
306
+ super ().__init__ ("*" )
307
+ self ._built_query_string = None
268
308
269
309
# Handle query modifiers
270
310
if return_fields :
@@ -289,6 +329,9 @@ def __init__(
289
329
if batch_size is not None :
290
330
self .set_batch_size (batch_size )
291
331
332
+ if ef_runtime is not None :
333
+ self .set_ef_runtime (ef_runtime )
334
+
292
335
def _build_query_string (self ) -> str :
293
336
"""Build the full query string for vector search with optional filtering."""
294
337
filter_expression = self ._filter_expression
@@ -308,6 +351,10 @@ def _build_query_string(self) -> str:
308
351
if self ._hybrid_policy == HybridPolicy .BATCHES and self ._batch_size :
309
352
knn_query += f" BATCH_SIZE { self ._batch_size } "
310
353
354
+ # Add EF_RUNTIME parameter if specified
355
+ if self ._ef_runtime :
356
+ knn_query += f" { self .EF_RUNTIME } ${ self .EF_RUNTIME_PARAM } "
357
+
311
358
# Add distance field alias
312
359
knn_query += f" AS { self .DISTANCE_ID } "
313
360
@@ -330,8 +377,8 @@ def set_hybrid_policy(self, hybrid_policy: str):
330
377
f"hybrid_policy must be one of { ', ' .join ([p .value for p in HybridPolicy ])} "
331
378
)
332
379
333
- # Reset the query string
334
- self ._query_string = self . _build_query_string ()
380
+ # Invalidate the query string
381
+ self ._built_query_string = None
335
382
336
383
def set_batch_size (self , batch_size : int ):
337
384
"""Set the batch size for the query.
@@ -349,8 +396,28 @@ def set_batch_size(self, batch_size: int):
349
396
raise ValueError ("batch_size must be positive" )
350
397
self ._batch_size = batch_size
351
398
352
- # Reset the query string
353
- self ._query_string = self ._build_query_string ()
399
+ # Invalidate the query string
400
+ self ._built_query_string = None
401
+
402
+ def set_ef_runtime (self , ef_runtime : int ):
403
+ """Set the EF_RUNTIME parameter for the query.
404
+
405
+ Args:
406
+ ef_runtime (int): The EF_RUNTIME value to use for HNSW algorithm.
407
+ Higher values improve recall at the expense of slower search.
408
+
409
+ Raises:
410
+ TypeError: If ef_runtime is not an integer
411
+ ValueError: If ef_runtime is not positive
412
+ """
413
+ if not isinstance (ef_runtime , int ):
414
+ raise TypeError ("ef_runtime must be an integer" )
415
+ if ef_runtime <= 0 :
416
+ raise ValueError ("ef_runtime must be positive" )
417
+ self ._ef_runtime = ef_runtime
418
+
419
+ # Invalidate the query string
420
+ self ._built_query_string = None
354
421
355
422
@property
356
423
def hybrid_policy (self ) -> Optional [str ]:
@@ -370,6 +437,15 @@ def batch_size(self) -> Optional[int]:
370
437
"""
371
438
return self ._batch_size
372
439
440
+ @property
441
+ def ef_runtime (self ) -> Optional [int ]:
442
+ """Return the EF_RUNTIME parameter for the query.
443
+
444
+ Returns:
445
+ Optional[int]: The EF_RUNTIME value for the query.
446
+ """
447
+ return self ._ef_runtime
448
+
373
449
@property
374
450
def params (self ) -> Dict [str , Any ]:
375
451
"""Return the parameters for the query.
@@ -382,7 +458,11 @@ def params(self) -> Dict[str, Any]:
382
458
else :
383
459
vector = array_to_buffer (self ._vector , dtype = self ._dtype )
384
460
385
- params = {self .VECTOR_PARAM : vector }
461
+ params : Dict [str , Any ] = {self .VECTOR_PARAM : vector }
462
+
463
+ # Add EF_RUNTIME parameter if specified
464
+ if self ._ef_runtime is not None :
465
+ params [self .EF_RUNTIME_PARAM ] = self ._ef_runtime
386
466
387
467
return params
388
468
@@ -475,6 +555,11 @@ def __init__(
475
555
self ._epsilon : Optional [float ] = None
476
556
self ._hybrid_policy : Optional [HybridPolicy ] = None
477
557
self ._batch_size : Optional [int ] = None
558
+ self ._normalize_vector_distance = normalize_vector_distance
559
+
560
+ # Initialize the base query
561
+ super ().__init__ ("*" )
562
+ self ._built_query_string = None
478
563
479
564
if epsilon is not None :
480
565
self .set_epsilon (epsilon )
@@ -485,12 +570,8 @@ def __init__(
485
570
if batch_size is not None :
486
571
self .set_batch_size (batch_size )
487
572
488
- self ._normalize_vector_distance = normalize_vector_distance
489
573
self .set_distance_threshold (distance_threshold )
490
574
self .set_filter (filter_expression )
491
- query_string = self ._build_query_string ()
492
-
493
- super ().__init__ (query_string )
494
575
495
576
# Handle query modifiers
496
577
if return_fields :
@@ -533,8 +614,8 @@ def set_distance_threshold(self, distance_threshold: float):
533
614
distance_threshold = denorm_cosine_distance (distance_threshold )
534
615
self ._distance_threshold = distance_threshold
535
616
536
- # Reset the query string
537
- self ._query_string = self . _build_query_string ()
617
+ # Invalidate the query string
618
+ self ._built_query_string = None
538
619
539
620
def set_epsilon (self , epsilon : float ):
540
621
"""Set the epsilon parameter for the range query.
@@ -553,8 +634,8 @@ def set_epsilon(self, epsilon: float):
553
634
raise ValueError ("epsilon must be non-negative" )
554
635
self ._epsilon = epsilon
555
636
556
- # Reset the query string
557
- self ._query_string = self . _build_query_string ()
637
+ # Invalidate the query string
638
+ self ._built_query_string = None
558
639
559
640
def set_hybrid_policy (self , hybrid_policy : str ):
560
641
"""Set the hybrid policy for the query.
@@ -573,8 +654,8 @@ def set_hybrid_policy(self, hybrid_policy: str):
573
654
f"hybrid_policy must be one of { ', ' .join ([p .value for p in HybridPolicy ])} "
574
655
)
575
656
576
- # Reset the query string
577
- self ._query_string = self . _build_query_string ()
657
+ # Invalidate the query string
658
+ self ._built_query_string = None
578
659
579
660
def set_batch_size (self , batch_size : int ):
580
661
"""Set the batch size for the query.
@@ -592,8 +673,8 @@ def set_batch_size(self, batch_size: int):
592
673
raise ValueError ("batch_size must be positive" )
593
674
self ._batch_size = batch_size
594
675
595
- # Reset the query string
596
- self ._query_string = self . _build_query_string ()
676
+ # Invalidate the query string
677
+ self ._built_query_string = None
597
678
598
679
def _build_query_string (self ) -> str :
599
680
"""Build the full query string for vector range queries with optional filtering"""
@@ -663,20 +744,22 @@ def params(self) -> Dict[str, Any]:
663
744
Dict[str, Any]: The parameters for the query.
664
745
"""
665
746
if isinstance (self ._vector , bytes ):
666
- vector_param = self ._vector
747
+ vector = self ._vector
667
748
else :
668
- vector_param = array_to_buffer (self ._vector , dtype = self ._dtype )
749
+ vector = array_to_buffer (self ._vector , dtype = self ._dtype )
669
750
670
751
params = {
671
- self .VECTOR_PARAM : vector_param ,
752
+ self .VECTOR_PARAM : vector ,
672
753
self .DISTANCE_THRESHOLD_PARAM : self ._distance_threshold ,
673
754
}
674
755
675
756
# Add hybrid policy and batch size as query parameters (not in query string)
676
- if self ._hybrid_policy :
757
+ if self ._hybrid_policy is not None :
677
758
params [self .HYBRID_POLICY_PARAM ] = self ._hybrid_policy .value
678
-
679
- if self ._hybrid_policy == HybridPolicy .BATCHES and self ._batch_size :
759
+ if (
760
+ self ._hybrid_policy == HybridPolicy .BATCHES
761
+ and self ._batch_size is not None
762
+ ):
680
763
params [self .BATCH_SIZE_PARAM ] = self ._batch_size
681
764
682
765
return params
@@ -763,7 +846,7 @@ def __init__(
763
846
TypeError: If stopwords is not a valid iterable set of strings.
764
847
"""
765
848
self ._text = text
766
- self ._text_field = text_field_name
849
+ self ._text_field_name = text_field_name
767
850
self ._num_results = num_results
768
851
769
852
self ._set_stopwords (stopwords )
@@ -772,9 +855,9 @@ def __init__(
772
855
if params :
773
856
self ._params = params
774
857
775
- # initialize the base query with the full query string and filter expression
776
- query_string = self . _build_query_string ( )
777
- super (). __init__ ( query_string )
858
+ # Initialize the base query
859
+ super (). __init__ ( "*" )
860
+ self . _built_query_string = None
778
861
779
862
# handle query settings
780
863
self .scorer (text_scorer )
@@ -860,7 +943,9 @@ def _build_query_string(self) -> str:
860
943
else :
861
944
filter_expression = ""
862
945
863
- text = f"@{ self ._text_field } :({ self ._tokenize_and_escape_query (self ._text )} )"
946
+ text = (
947
+ f"@{ self ._text_field_name } :({ self ._tokenize_and_escape_query (self ._text )} )"
948
+ )
864
949
if filter_expression and filter_expression != "*" :
865
950
text += f" AND { filter_expression } "
866
951
return text
0 commit comments