Skip to content

Commit a1ed87f

Browse files
Add support for EF_RUNTIME param (#317)
Add support for EF_RUNTIME param at query time for HNSW vector queries. --------- Co-authored-by: Tyler Hutcherson <[email protected]>
1 parent ab3e711 commit a1ed87f

File tree

6 files changed

+321
-44
lines changed

6 files changed

+321
-44
lines changed

.cursor/rules/redisvl.mdc

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
---
2+
description:
3+
globs:
4+
alwaysApply: true
5+
---
6+
7+
# Rules for working on RedisVL
8+
- Do not change this line of code unless explicitly asked. It's already correct:
9+
```
10+
token.strip().strip(",").replace("“", "").replace("”", "").lower()
11+
```

redisvl/query/query.py

+124-39
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,23 @@
1010

1111

1212
class BaseQuery(RedisQuery):
13-
"""Base query class used to subclass many query types."""
13+
"""
14+
Base query class used to subclass many query types.
15+
16+
NOTE: In the base class, the `_query_string` field is set once on
17+
initialization, and afterward, redis-py expects to be able to read it. By
18+
contrast, our query subclasses allow users to call methods that alter the
19+
query string at runtime. To avoid having to rebuild `_query_string` every
20+
time one of these methods is called, we lazily build the query string when a
21+
user calls `query()` or accesses the property `_query_string`, when the
22+
underlying `_built_query_string` field is None. Any method that alters the query
23+
string should set `_built_query_string` to None so that the next time the query
24+
string is accessed, it is rebuilt.
25+
"""
1426

1527
_params: Dict[str, Any] = {}
1628
_filter_expression: Union[str, FilterExpression] = FilterExpression("*")
29+
_built_query_string: Optional[str] = None
1730

1831
def __init__(self, query_string: str = "*"):
1932
"""
@@ -22,8 +35,15 @@ def __init__(self, query_string: str = "*"):
2235
Args:
2336
query_string (str, optional): The query string to use. Defaults to '*'.
2437
"""
38+
# The parent class expects a query string, so we pass it in, but we'll
39+
# actually manage building it dynamically.
2540
super().__init__(query_string)
2641

42+
# This is a private field that we use to track whether the query string
43+
# has been built, and we set it to None here to indicate that the field
44+
# has not been built yet.
45+
self._built_query_string = None
46+
2747
def __str__(self) -> str:
2848
"""Return the string representation of the query."""
2949
return " ".join([str(x) for x in self.get_args()])
@@ -54,8 +74,8 @@ def set_filter(
5474
"filter_expression must be of type FilterExpression or string or None"
5575
)
5676

57-
# Reset the query string
58-
self._query_string = self._build_query_string()
77+
# Invalidate the query string
78+
self._built_query_string = None
5979

6080
@property
6181
def filter(self) -> Union[str, FilterExpression]:
@@ -72,6 +92,18 @@ def params(self) -> Dict[str, Any]:
7292
"""Return the query parameters."""
7393
return self._params
7494

95+
@property
96+
def _query_string(self) -> str:
97+
"""Maintains compatibility with parent class while providing lazy loading."""
98+
if self._built_query_string is None:
99+
self._built_query_string = self._build_query_string()
100+
return self._built_query_string
101+
102+
@_query_string.setter
103+
def _query_string(self, value: Optional[str]):
104+
"""Setter for _query_string to maintain compatibility with parent class."""
105+
self._built_query_string = value
106+
75107

76108
class FilterQuery(BaseQuery):
77109
def __init__(
@@ -107,9 +139,9 @@ def __init__(
107139

108140
self._num_results = num_results
109141

110-
# Initialize the base query with the full query string constructed from the filter expression
111-
query_string = self._build_query_string()
112-
super().__init__(query_string)
142+
# Initialize the base query with the query string from the property
143+
super().__init__("*")
144+
self._built_query_string = None # Ensure it's invalidated after initialization
113145

114146
# Handle query settings
115147
if return_fields:
@@ -161,9 +193,9 @@ def __init__(
161193
if params:
162194
self._params = params
163195

164-
# Initialize the base query with the full query string constructed from the filter expression
165-
query_string = self._build_query_string()
166-
super().__init__(query_string)
196+
# Initialize the base query with the query string from the property
197+
super().__init__("*")
198+
self._built_query_string = None
167199

168200
# Query specific modifications
169201
self.no_content().paging(0, 0).dialect(dialect)
@@ -178,6 +210,8 @@ def _build_query_string(self) -> str:
178210
class BaseVectorQuery:
179211
DISTANCE_ID: str = "vector_distance"
180212
VECTOR_PARAM: str = "vector"
213+
EF_RUNTIME: str = "EF_RUNTIME"
214+
EF_RUNTIME_PARAM: str = "EF"
181215

182216
_normalize_vector_distance: bool = False
183217

@@ -204,6 +238,7 @@ def __init__(
204238
in_order: bool = False,
205239
hybrid_policy: Optional[str] = None,
206240
batch_size: Optional[int] = None,
241+
ef_runtime: Optional[int] = None,
207242
normalize_vector_distance: bool = False,
208243
):
209244
"""A query for running a vector search along with an optional filter
@@ -241,6 +276,9 @@ def __init__(
241276
of vectors to fetch in each batch. Larger values may improve performance
242277
at the cost of memory usage. Only applies when hybrid_policy="BATCHES".
243278
Defaults to None, which lets Redis auto-select an appropriate batch size.
279+
ef_runtime (Optional[int]): Controls the size of the dynamic candidate list for HNSW
280+
algorithm at query time. Higher values improve recall at the expense of
281+
slower search performance. Defaults to None, which uses the index-defined value.
244282
normalize_vector_distance (bool): Redis supports 3 distance metrics: L2 (euclidean),
245283
IP (inner product), and COSINE. By default, L2 distance returns an unbounded value.
246284
COSINE distance returns a value between 0 and 2. IP returns a value determined by
@@ -260,11 +298,13 @@ def __init__(
260298
self._num_results = num_results
261299
self._hybrid_policy: Optional[HybridPolicy] = None
262300
self._batch_size: Optional[int] = None
301+
self._ef_runtime: Optional[int] = None
263302
self._normalize_vector_distance = normalize_vector_distance
264303
self.set_filter(filter_expression)
265-
query_string = self._build_query_string()
266304

267-
super().__init__(query_string)
305+
# Initialize the base query
306+
super().__init__("*")
307+
self._built_query_string = None
268308

269309
# Handle query modifiers
270310
if return_fields:
@@ -289,6 +329,9 @@ def __init__(
289329
if batch_size is not None:
290330
self.set_batch_size(batch_size)
291331

332+
if ef_runtime is not None:
333+
self.set_ef_runtime(ef_runtime)
334+
292335
def _build_query_string(self) -> str:
293336
"""Build the full query string for vector search with optional filtering."""
294337
filter_expression = self._filter_expression
@@ -308,6 +351,10 @@ def _build_query_string(self) -> str:
308351
if self._hybrid_policy == HybridPolicy.BATCHES and self._batch_size:
309352
knn_query += f" BATCH_SIZE {self._batch_size}"
310353

354+
# Add EF_RUNTIME parameter if specified
355+
if self._ef_runtime:
356+
knn_query += f" {self.EF_RUNTIME} ${self.EF_RUNTIME_PARAM}"
357+
311358
# Add distance field alias
312359
knn_query += f" AS {self.DISTANCE_ID}"
313360

@@ -330,8 +377,8 @@ def set_hybrid_policy(self, hybrid_policy: str):
330377
f"hybrid_policy must be one of {', '.join([p.value for p in HybridPolicy])}"
331378
)
332379

333-
# Reset the query string
334-
self._query_string = self._build_query_string()
380+
# Invalidate the query string
381+
self._built_query_string = None
335382

336383
def set_batch_size(self, batch_size: int):
337384
"""Set the batch size for the query.
@@ -349,8 +396,28 @@ def set_batch_size(self, batch_size: int):
349396
raise ValueError("batch_size must be positive")
350397
self._batch_size = batch_size
351398

352-
# Reset the query string
353-
self._query_string = self._build_query_string()
399+
# Invalidate the query string
400+
self._built_query_string = None
401+
402+
def set_ef_runtime(self, ef_runtime: int):
403+
"""Set the EF_RUNTIME parameter for the query.
404+
405+
Args:
406+
ef_runtime (int): The EF_RUNTIME value to use for HNSW algorithm.
407+
Higher values improve recall at the expense of slower search.
408+
409+
Raises:
410+
TypeError: If ef_runtime is not an integer
411+
ValueError: If ef_runtime is not positive
412+
"""
413+
if not isinstance(ef_runtime, int):
414+
raise TypeError("ef_runtime must be an integer")
415+
if ef_runtime <= 0:
416+
raise ValueError("ef_runtime must be positive")
417+
self._ef_runtime = ef_runtime
418+
419+
# Invalidate the query string
420+
self._built_query_string = None
354421

355422
@property
356423
def hybrid_policy(self) -> Optional[str]:
@@ -370,6 +437,15 @@ def batch_size(self) -> Optional[int]:
370437
"""
371438
return self._batch_size
372439

440+
@property
441+
def ef_runtime(self) -> Optional[int]:
442+
"""Return the EF_RUNTIME parameter for the query.
443+
444+
Returns:
445+
Optional[int]: The EF_RUNTIME value for the query.
446+
"""
447+
return self._ef_runtime
448+
373449
@property
374450
def params(self) -> Dict[str, Any]:
375451
"""Return the parameters for the query.
@@ -382,7 +458,11 @@ def params(self) -> Dict[str, Any]:
382458
else:
383459
vector = array_to_buffer(self._vector, dtype=self._dtype)
384460

385-
params = {self.VECTOR_PARAM: vector}
461+
params: Dict[str, Any] = {self.VECTOR_PARAM: vector}
462+
463+
# Add EF_RUNTIME parameter if specified
464+
if self._ef_runtime is not None:
465+
params[self.EF_RUNTIME_PARAM] = self._ef_runtime
386466

387467
return params
388468

@@ -475,6 +555,11 @@ def __init__(
475555
self._epsilon: Optional[float] = None
476556
self._hybrid_policy: Optional[HybridPolicy] = None
477557
self._batch_size: Optional[int] = None
558+
self._normalize_vector_distance = normalize_vector_distance
559+
560+
# Initialize the base query
561+
super().__init__("*")
562+
self._built_query_string = None
478563

479564
if epsilon is not None:
480565
self.set_epsilon(epsilon)
@@ -485,12 +570,8 @@ def __init__(
485570
if batch_size is not None:
486571
self.set_batch_size(batch_size)
487572

488-
self._normalize_vector_distance = normalize_vector_distance
489573
self.set_distance_threshold(distance_threshold)
490574
self.set_filter(filter_expression)
491-
query_string = self._build_query_string()
492-
493-
super().__init__(query_string)
494575

495576
# Handle query modifiers
496577
if return_fields:
@@ -533,8 +614,8 @@ def set_distance_threshold(self, distance_threshold: float):
533614
distance_threshold = denorm_cosine_distance(distance_threshold)
534615
self._distance_threshold = distance_threshold
535616

536-
# Reset the query string
537-
self._query_string = self._build_query_string()
617+
# Invalidate the query string
618+
self._built_query_string = None
538619

539620
def set_epsilon(self, epsilon: float):
540621
"""Set the epsilon parameter for the range query.
@@ -553,8 +634,8 @@ def set_epsilon(self, epsilon: float):
553634
raise ValueError("epsilon must be non-negative")
554635
self._epsilon = epsilon
555636

556-
# Reset the query string
557-
self._query_string = self._build_query_string()
637+
# Invalidate the query string
638+
self._built_query_string = None
558639

559640
def set_hybrid_policy(self, hybrid_policy: str):
560641
"""Set the hybrid policy for the query.
@@ -573,8 +654,8 @@ def set_hybrid_policy(self, hybrid_policy: str):
573654
f"hybrid_policy must be one of {', '.join([p.value for p in HybridPolicy])}"
574655
)
575656

576-
# Reset the query string
577-
self._query_string = self._build_query_string()
657+
# Invalidate the query string
658+
self._built_query_string = None
578659

579660
def set_batch_size(self, batch_size: int):
580661
"""Set the batch size for the query.
@@ -592,8 +673,8 @@ def set_batch_size(self, batch_size: int):
592673
raise ValueError("batch_size must be positive")
593674
self._batch_size = batch_size
594675

595-
# Reset the query string
596-
self._query_string = self._build_query_string()
676+
# Invalidate the query string
677+
self._built_query_string = None
597678

598679
def _build_query_string(self) -> str:
599680
"""Build the full query string for vector range queries with optional filtering"""
@@ -663,20 +744,22 @@ def params(self) -> Dict[str, Any]:
663744
Dict[str, Any]: The parameters for the query.
664745
"""
665746
if isinstance(self._vector, bytes):
666-
vector_param = self._vector
747+
vector = self._vector
667748
else:
668-
vector_param = array_to_buffer(self._vector, dtype=self._dtype)
749+
vector = array_to_buffer(self._vector, dtype=self._dtype)
669750

670751
params = {
671-
self.VECTOR_PARAM: vector_param,
752+
self.VECTOR_PARAM: vector,
672753
self.DISTANCE_THRESHOLD_PARAM: self._distance_threshold,
673754
}
674755

675756
# Add hybrid policy and batch size as query parameters (not in query string)
676-
if self._hybrid_policy:
757+
if self._hybrid_policy is not None:
677758
params[self.HYBRID_POLICY_PARAM] = self._hybrid_policy.value
678-
679-
if self._hybrid_policy == HybridPolicy.BATCHES and self._batch_size:
759+
if (
760+
self._hybrid_policy == HybridPolicy.BATCHES
761+
and self._batch_size is not None
762+
):
680763
params[self.BATCH_SIZE_PARAM] = self._batch_size
681764

682765
return params
@@ -763,7 +846,7 @@ def __init__(
763846
TypeError: If stopwords is not a valid iterable set of strings.
764847
"""
765848
self._text = text
766-
self._text_field = text_field_name
849+
self._text_field_name = text_field_name
767850
self._num_results = num_results
768851

769852
self._set_stopwords(stopwords)
@@ -772,9 +855,9 @@ def __init__(
772855
if params:
773856
self._params = params
774857

775-
# initialize the base query with the full query string and filter expression
776-
query_string = self._build_query_string()
777-
super().__init__(query_string)
858+
# Initialize the base query
859+
super().__init__("*")
860+
self._built_query_string = None
778861

779862
# handle query settings
780863
self.scorer(text_scorer)
@@ -860,7 +943,9 @@ def _build_query_string(self) -> str:
860943
else:
861944
filter_expression = ""
862945

863-
text = f"@{self._text_field}:({self._tokenize_and_escape_query(self._text)})"
946+
text = (
947+
f"@{self._text_field_name}:({self._tokenize_and_escape_query(self._text)})"
948+
)
864949
if filter_expression and filter_expression != "*":
865950
text += f" AND {filter_expression}"
866951
return text

0 commit comments

Comments
 (0)