Skip to content

Commit dca2326

Browse files
committed
Merge branch 'main' into 0.6.0
2 parents 95ffe75 + 9f22a9a commit dca2326

File tree

6 files changed

+347
-7
lines changed

6 files changed

+347
-7
lines changed

redisvl/exceptions.py

+6
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,9 @@ def __init__(self, message, index=None):
3030
if index is not None:
3131
message = f"Validation failed for object at index {index}: {message}"
3232
super().__init__(message)
33+
34+
35+
class QueryValidationError(RedisVLError):
36+
"""Error when validating a query."""
37+
38+
pass

redisvl/index/index.py

+26-4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Union,
1919
)
2020

21+
from redisvl.query.query import VectorQuery
2122
from redisvl.redis.utils import convert_bytes, make_dict
2223
from redisvl.utils.utils import deprecated_argument, deprecated_function, sync_wrapper
2324

@@ -34,6 +35,7 @@
3435
from redis.commands.search.indexDefinition import IndexDefinition
3536

3637
from redisvl.exceptions import (
38+
QueryValidationError,
3739
RedisModuleVersionError,
3840
RedisSearchError,
3941
RedisVLError,
@@ -46,16 +48,18 @@
4648
BaseVectorQuery,
4749
CountQuery,
4850
FilterQuery,
49-
HybridQuery,
5051
)
5152
from redisvl.query.filter import FilterExpression
5253
from redisvl.redis.connection import (
5354
RedisConnectionFactory,
5455
convert_index_info_to_schema,
5556
)
56-
from redisvl.redis.utils import convert_bytes
5757
from redisvl.schema import IndexSchema, StorageType
58-
from redisvl.schema.fields import VECTOR_NORM_MAP, VectorDistanceMetric
58+
from redisvl.schema.fields import (
59+
VECTOR_NORM_MAP,
60+
VectorDistanceMetric,
61+
VectorIndexAlgorithm,
62+
)
5963
from redisvl.utils.log import get_logger
6064

6165
logger = get_logger(__name__)
@@ -194,6 +198,15 @@ def _storage(self) -> BaseStorage:
194198
index_schema=self.schema
195199
)
196200

201+
def _validate_query(self, query: BaseQuery) -> None:
202+
"""Validate a query."""
203+
if isinstance(query, VectorQuery):
204+
field = self.schema.fields[query._vector_field_name]
205+
if query.ef_runtime and field.attrs.algorithm != VectorIndexAlgorithm.HNSW: # type: ignore
206+
raise QueryValidationError(
207+
"Vector field using 'flat' algorithm does not support EF_RUNTIME query parameter."
208+
)
209+
197210
@property
198211
def name(self) -> str:
199212
"""The name of the Redis search index."""
@@ -837,6 +850,10 @@ def batch_query(
837850

838851
def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
839852
"""Execute a query and process results."""
853+
try:
854+
self._validate_query(query)
855+
except QueryValidationError as e:
856+
raise QueryValidationError(f"Invalid query: {str(e)}") from e
840857
results = self.search(query.query, query_params=query.params)
841858
return process_results(results, query=query, schema=self.schema)
842859

@@ -1401,7 +1418,8 @@ async def _aggregate(
14011418
) -> List[Dict[str, Any]]:
14021419
"""Execute an aggregation query and processes the results."""
14031420
results = await self.aggregate(
1404-
aggregation_query, query_params=aggregation_query.params # type: ignore[attr-defined]
1421+
aggregation_query,
1422+
query_params=aggregation_query.params, # type: ignore[attr-defined]
14051423
)
14061424
return process_aggregate_results(
14071425
results,
@@ -1529,6 +1547,10 @@ async def batch_query(
15291547

15301548
async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
15311549
"""Asynchronously execute a query and process results."""
1550+
try:
1551+
self._validate_query(query)
1552+
except QueryValidationError as e:
1553+
raise QueryValidationError(f"Invalid query: {str(e)}") from e
15321554
results = await self.search(query.query, query_params=query.params)
15331555
return process_results(results, query=query, schema=self.schema)
15341556

tests/conftest.py

+210
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import pytest
55
from testcontainers.compose import DockerCompose
66

7+
from redisvl.index.index import AsyncSearchIndex, SearchIndex
78
from redisvl.redis.connection import RedisConnectionFactory
9+
from redisvl.redis.utils import array_to_buffer
810
from redisvl.utils.vectorize import HFTextVectorizer
911

1012

@@ -191,3 +193,211 @@ def pytest_collection_modifyitems(
191193
for item in items:
192194
if item.get_closest_marker("requires_api_keys"):
193195
item.add_marker(skip_api)
196+
197+
198+
@pytest.fixture
199+
def flat_index(sample_data, redis_url):
200+
"""
201+
A fixture that uses the "flat" algorithm for its vector field.
202+
"""
203+
# construct a search index from the schema
204+
index = SearchIndex.from_dict(
205+
{
206+
"index": {
207+
"name": "user_index",
208+
"prefix": "v1",
209+
"storage_type": "hash",
210+
},
211+
"fields": [
212+
{"name": "description", "type": "text"},
213+
{"name": "credit_score", "type": "tag"},
214+
{"name": "job", "type": "text"},
215+
{"name": "age", "type": "numeric"},
216+
{"name": "last_updated", "type": "numeric"},
217+
{"name": "location", "type": "geo"},
218+
{
219+
"name": "user_embedding",
220+
"type": "vector",
221+
"attrs": {
222+
"dims": 3,
223+
"distance_metric": "cosine",
224+
"algorithm": "flat",
225+
"datatype": "float32",
226+
},
227+
},
228+
],
229+
},
230+
redis_url=redis_url,
231+
)
232+
233+
# create the index (no data yet)
234+
index.create(overwrite=True)
235+
236+
# Prepare and load the data
237+
def hash_preprocess(item: dict) -> dict:
238+
return {
239+
**item,
240+
"user_embedding": array_to_buffer(item["user_embedding"], "float32"),
241+
}
242+
243+
index.load(sample_data, preprocess=hash_preprocess)
244+
245+
# run the test
246+
yield index
247+
248+
# clean up
249+
index.delete(drop=True)
250+
251+
252+
@pytest.fixture
253+
async def async_flat_index(sample_data, redis_url):
254+
"""
255+
A fixture that uses the "flat" algorithm for its vector field.
256+
"""
257+
# construct a search index from the schema
258+
index = AsyncSearchIndex.from_dict(
259+
{
260+
"index": {
261+
"name": "user_index",
262+
"prefix": "v1",
263+
"storage_type": "hash",
264+
},
265+
"fields": [
266+
{"name": "description", "type": "text"},
267+
{"name": "credit_score", "type": "tag"},
268+
{"name": "job", "type": "text"},
269+
{"name": "age", "type": "numeric"},
270+
{"name": "last_updated", "type": "numeric"},
271+
{"name": "location", "type": "geo"},
272+
{
273+
"name": "user_embedding",
274+
"type": "vector",
275+
"attrs": {
276+
"dims": 3,
277+
"distance_metric": "cosine",
278+
"algorithm": "flat",
279+
"datatype": "float32",
280+
},
281+
},
282+
],
283+
},
284+
redis_url=redis_url,
285+
)
286+
287+
# create the index (no data yet)
288+
await index.create(overwrite=True)
289+
290+
# Prepare and load the data
291+
def hash_preprocess(item: dict) -> dict:
292+
return {
293+
**item,
294+
"user_embedding": array_to_buffer(item["user_embedding"], "float32"),
295+
}
296+
297+
await index.load(sample_data, preprocess=hash_preprocess)
298+
299+
# run the test
300+
yield index
301+
302+
# clean up
303+
await index.delete(drop=True)
304+
305+
306+
@pytest.fixture
307+
async def async_hnsw_index(sample_data, redis_url):
308+
"""
309+
A fixture that uses the "hnsw" algorithm for its vector field.
310+
"""
311+
index = AsyncSearchIndex.from_dict(
312+
{
313+
"index": {
314+
"name": "user_index",
315+
"prefix": "v1",
316+
"storage_type": "hash",
317+
},
318+
"fields": [
319+
{"name": "description", "type": "text"},
320+
{"name": "credit_score", "type": "tag"},
321+
{"name": "job", "type": "text"},
322+
{"name": "age", "type": "numeric"},
323+
{"name": "last_updated", "type": "numeric"},
324+
{"name": "location", "type": "geo"},
325+
{
326+
"name": "user_embedding",
327+
"type": "vector",
328+
"attrs": {
329+
"dims": 3,
330+
"distance_metric": "cosine",
331+
"algorithm": "hnsw",
332+
"datatype": "float32",
333+
},
334+
},
335+
],
336+
},
337+
redis_url=redis_url,
338+
)
339+
340+
# create the index (no data yet)
341+
await index.create(overwrite=True)
342+
343+
# Prepare and load the data
344+
def hash_preprocess(item: dict) -> dict:
345+
return {
346+
**item,
347+
"user_embedding": array_to_buffer(item["user_embedding"], "float32"),
348+
}
349+
350+
await index.load(sample_data, preprocess=hash_preprocess)
351+
352+
# run the test
353+
yield index
354+
355+
356+
@pytest.fixture
357+
def hnsw_index(sample_data, redis_url):
358+
"""
359+
A fixture that uses the "hnsw" algorithm for its vector field.
360+
"""
361+
index = SearchIndex.from_dict(
362+
{
363+
"index": {
364+
"name": "user_index",
365+
"prefix": "v1",
366+
"storage_type": "hash",
367+
},
368+
"fields": [
369+
{"name": "description", "type": "text"},
370+
{"name": "credit_score", "type": "tag"},
371+
{"name": "job", "type": "text"},
372+
{"name": "age", "type": "numeric"},
373+
{"name": "last_updated", "type": "numeric"},
374+
{"name": "location", "type": "geo"},
375+
{
376+
"name": "user_embedding",
377+
"type": "vector",
378+
"attrs": {
379+
"dims": 3,
380+
"distance_metric": "cosine",
381+
"algorithm": "hnsw",
382+
"datatype": "float32",
383+
},
384+
},
385+
],
386+
},
387+
redis_url=redis_url,
388+
)
389+
390+
# create the index (no data yet)
391+
index.create(overwrite=True)
392+
393+
# Prepare and load the data
394+
def hash_preprocess(item: dict) -> dict:
395+
return {
396+
**item,
397+
"user_embedding": array_to_buffer(item["user_embedding"], "float32"),
398+
}
399+
400+
index.load(sample_data, preprocess=hash_preprocess)
401+
402+
# run the test
403+
yield index

tests/integration/test_async_search_index.py

+45-1
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,18 @@
55
from redis import Redis as SyncRedis
66
from redis.asyncio import Redis as AsyncRedis
77

8-
from redisvl.exceptions import RedisModuleVersionError, RedisSearchError, RedisVLError
8+
from redisvl.exceptions import (
9+
QueryValidationError,
10+
RedisModuleVersionError,
11+
RedisSearchError,
12+
RedisVLError,
13+
)
914
from redisvl.index import AsyncSearchIndex
1015
from redisvl.query import VectorQuery
1116
from redisvl.query.query import FilterQuery
1217
from redisvl.redis.utils import convert_bytes
1318
from redisvl.schema import IndexSchema, StorageType
19+
from redisvl.schema.fields import VectorIndexAlgorithm
1420

1521
fields = [{"name": "test", "type": "tag"}]
1622

@@ -614,3 +620,41 @@ async def test_async_search_index_expire_keys(async_index):
614620
ttl = await client.ttl(key)
615621
assert ttl > 0
616622
assert ttl <= 30
623+
624+
625+
@pytest.mark.asyncio
626+
async def test_search_index_validates_query_with_flat_algorithm(
627+
async_flat_index, sample_data
628+
):
629+
assert (
630+
async_flat_index.schema.fields["user_embedding"].attrs.algorithm
631+
== VectorIndexAlgorithm.FLAT
632+
)
633+
query = VectorQuery(
634+
[0.1, 0.1, 0.5],
635+
"user_embedding",
636+
return_fields=["user", "credit_score", "age", "job", "location"],
637+
num_results=7,
638+
ef_runtime=100,
639+
)
640+
with pytest.raises(QueryValidationError):
641+
await async_flat_index.query(query)
642+
643+
644+
@pytest.mark.asyncio
645+
async def test_search_index_validates_query_with_hnsw_algorithm(
646+
async_hnsw_index, sample_data
647+
):
648+
assert (
649+
async_hnsw_index.schema.fields["user_embedding"].attrs.algorithm
650+
== VectorIndexAlgorithm.HNSW
651+
)
652+
query = VectorQuery(
653+
[0.1, 0.1, 0.5],
654+
"user_embedding",
655+
return_fields=["user", "credit_score", "age", "job", "location"],
656+
num_results=7,
657+
ef_runtime=100,
658+
)
659+
# Should not raise
660+
await async_hnsw_index.query(query)

0 commit comments

Comments
 (0)