import asyncio
import json
import threading
import time
import warnings
import weakref
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
Tuple,
Union,
)
from redisvl.query.query import VectorQuery
from redisvl.redis.utils import convert_bytes, make_dict
from redisvl.utils.utils import deprecated_argument, deprecated_function, sync_wrapper
if TYPE_CHECKING:
from redis.commands.search.aggregation import AggregateResult
from redis.commands.search.document import Document
from redis.commands.search.result import Result
from redisvl.query.query import BaseQuery
import redis
import redis.asyncio as aredis
from redis.client import NEVER_DECODE
from redis.commands.helpers import get_protocol_version # type: ignore
from redis.commands.search.indexDefinition import IndexDefinition
from redisvl.exceptions import (
QueryValidationError,
RedisModuleVersionError,
RedisSearchError,
RedisVLError,
SchemaValidationError,
)
from redisvl.index.storage import BaseStorage, HashStorage, JsonStorage
from redisvl.query import (
AggregationQuery,
BaseQuery,
BaseVectorQuery,
CountQuery,
FilterQuery,
)
from redisvl.query.filter import FilterExpression
from redisvl.redis.connection import (
RedisConnectionFactory,
convert_index_info_to_schema,
)
from redisvl.schema import IndexSchema, StorageType
from redisvl.schema.fields import (
VECTOR_NORM_MAP,
VectorDistanceMetric,
VectorIndexAlgorithm,
)
from redisvl.utils.log import get_logger
logger = get_logger(__name__)
REQUIRED_MODULES_FOR_INTROSPECTION = [
{"name": "search", "ver": 20810},
{"name": "searchlight", "ver": 20810},
]
SearchParams = Union[
Tuple[
Union[str, BaseQuery],
Union[Dict[str, Union[str, int, float, bytes]], None],
],
Union[str, BaseQuery],
]
def process_results(
results: "Result", query: BaseQuery, schema: IndexSchema
) -> List[Dict[str, Any]]:
"""Convert a list of search Result objects into a list of document
dictionaries.
This function processes results from Redis, handling different storage
types and query types. For JSON storage with empty return fields, it
unpacks the JSON object while retaining the document ID. The 'payload'
field is also removed from all resulting documents for consistency.
    Args:
        results (Result): The search results from Redis.
        query (BaseQuery): The query object used for the search.
        schema (IndexSchema): The schema of the search index, used to
            determine the storage type (json or hash) and field attributes.
    Returns:
        List[Dict[str, Any]]: A list of processed document dictionaries,
            or the total count (an int) for a ``CountQuery``.
"""
# Handle count queries
if isinstance(query, CountQuery):
return results.total
# Determine if unpacking JSON is needed
unpack_json = (
(schema.index.storage_type == StorageType.JSON)
and isinstance(query, FilterQuery)
and not query._return_fields # type: ignore
)
    if isinstance(query, BaseVectorQuery) and query._normalize_vector_distance:
dist_metric = VectorDistanceMetric(
schema.fields[query._vector_field_name].attrs.distance_metric.upper() # type: ignore
)
if dist_metric == VectorDistanceMetric.IP:
warnings.warn(
"Attempting to normalize inner product distance metric. Use cosine distance instead which is normalized inner product by definition."
)
norm_fn = VECTOR_NORM_MAP[dist_metric.value]
else:
norm_fn = None
# Process records
def _process(doc: "Document") -> Dict[str, Any]:
doc_dict = doc.__dict__
# Unpack and Project JSON fields properly
if unpack_json and "json" in doc_dict:
json_data = doc_dict.get("json", {})
if isinstance(json_data, str):
json_data = json.loads(json_data)
if isinstance(json_data, dict):
return {"id": doc_dict.get("id"), **json_data}
raise ValueError(f"Unable to parse json data from Redis {json_data}")
if norm_fn:
# convert float back to string to be consistent
doc_dict[query.DISTANCE_ID] = str( # type: ignore
norm_fn(float(doc_dict[query.DISTANCE_ID])) # type: ignore
)
# Remove 'payload' if present
doc_dict.pop("payload", None)
return doc_dict
return [_process(doc) for doc in results.docs]
def process_aggregate_results(
results: "AggregateResult", query: AggregationQuery, storage_type: StorageType
) -> List[Dict[str, Any]]:
"""Convert an aggregate reslt object into a list of document dictionaries.
This function processes results from Redis, handling different storage
types and query types. For JSON storage with empty return fields, it
unpacks the JSON object while retaining the document ID. The 'payload'
field is also removed from all resulting documents for consistency.
Args:
results (AggregateResult): The aggregate results from Redis.
query (AggregationQuery): The aggregation query object used for the aggregation.
storage_type (StorageType): The storage type of the search
index (json or hash).
Returns:
List[Dict[str, Any]]: A list of processed document dictionaries.
"""
def _process(row):
result = make_dict(convert_bytes(row))
result.pop("__score", None)
return result
return [_process(r) for r in results.rows]
class BaseSearchIndex:
"""Base search engine class"""
_STORAGE_MAP = {
StorageType.HASH: HashStorage,
StorageType.JSON: JsonStorage,
}
schema: IndexSchema
    def __init__(self, *args, **kwargs):
        pass
@property
def _storage(self) -> BaseStorage:
"""The storage type for the index schema."""
return self._STORAGE_MAP[self.schema.index.storage_type](
index_schema=self.schema
)
def _validate_query(self, query: BaseQuery) -> None:
"""Validate a query."""
if isinstance(query, VectorQuery):
field = self.schema.fields[query._vector_field_name]
if query.ef_runtime and field.attrs.algorithm != VectorIndexAlgorithm.HNSW: # type: ignore
raise QueryValidationError(
"Vector field using 'flat' algorithm does not support EF_RUNTIME query parameter."
)
@property
def name(self) -> str:
"""The name of the Redis search index."""
return self.schema.index.name
@property
def prefix(self) -> str:
"""The optional key prefix that comes before a unique key value in
forming a Redis key."""
return self.schema.index.prefix
@property
def key_separator(self) -> str:
"""The optional separator between a defined prefix and key value in
forming a Redis key."""
return self.schema.index.key_separator
@property
def storage_type(self) -> StorageType:
"""The underlying storage type for the search index; either
hash or json."""
return self.schema.index.storage_type
@classmethod
def from_yaml(cls, schema_path: str, **kwargs):
"""Create a SearchIndex from a YAML schema file.
Args:
schema_path (str): Path to the YAML schema file.
Returns:
SearchIndex: A RedisVL SearchIndex object.
.. code-block:: python
from redisvl.index import SearchIndex
index = SearchIndex.from_yaml("schemas/schema.yaml", redis_url="redis://localhost:6379")
"""
schema = IndexSchema.from_yaml(schema_path)
return cls(schema=schema, **kwargs)
@classmethod
def from_dict(cls, schema_dict: Dict[str, Any], **kwargs):
"""Create a SearchIndex from a dictionary.
Args:
schema_dict (Dict[str, Any]): A dictionary containing the schema.
Returns:
SearchIndex: A RedisVL SearchIndex object.
.. code-block:: python
from redisvl.index import SearchIndex
index = SearchIndex.from_dict({
"index": {
"name": "my-index",
"prefix": "rvl",
"storage_type": "hash",
},
"fields": [
{"name": "doc-id", "type": "tag"}
]
}, redis_url="redis://localhost:6379")
"""
schema = IndexSchema.from_dict(schema_dict)
return cls(schema=schema, **kwargs)
def disconnect(self):
"""Disconnect from the Redis database."""
raise NotImplementedError("This method should be implemented by subclasses.")
def key(self, id: str) -> str:
"""Construct a redis key as a combination of an index key prefix (optional)
and specified id.
The id is typically either a unique identifier, or
derived from some domain-specific metadata combination (like a document
id or chunk id).
Args:
id (str): The specified unique identifier for a particular
document indexed in Redis.
Returns:
str: The full Redis key including key prefix and value as a string.
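        A short usage sketch (the id and resulting key are illustrative and
        depend on the schema's prefix and separator):

        .. code-block:: python

            key = index.key("doc1")
            # e.g. "rvl:doc1" when prefix="rvl" and key_separator=":"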
"""
return self._storage._key(
id=id,
prefix=self.schema.index.prefix,
key_separator=self.schema.index.key_separator,
)
class SearchIndex(BaseSearchIndex):
"""A search index class for interacting with Redis as a vector database.
The SearchIndex is instantiated with a reference to a Redis database and an
IndexSchema (YAML path or dictionary object) that describes the various
settings and field configurations.
.. code-block:: python
from redisvl.index import SearchIndex
# initialize the index object with schema from file
index = SearchIndex.from_yaml(
"schemas/schema.yaml",
redis_url="redis://localhost:6379",
validate_on_load=True
)
# create the index
index.create(overwrite=True, drop=False)
# data is an iterable of dictionaries
index.load(data)
# delete index and data
index.delete(drop=True)
"""
@deprecated_argument("connection_args", "Use connection_kwargs instead.")
def __init__(
self,
schema: IndexSchema,
redis_client: Optional[redis.Redis] = None,
redis_url: Optional[str] = None,
connection_kwargs: Optional[Dict[str, Any]] = None,
validate_on_load: bool = False,
**kwargs,
):
"""Initialize the RedisVL search index with a schema, Redis client
(or URL string with other connection args), connection_args, and other
kwargs.
Args:
schema (IndexSchema): Index schema object.
redis_client(Optional[redis.Redis]): An
instantiated redis client.
redis_url (Optional[str]): The URL of the Redis server to
connect to.
connection_kwargs (Dict[str, Any], optional): Redis client connection
args.
validate_on_load (bool, optional): Whether to validate data against schema
when loading. Defaults to False.
"""
if "connection_args" in kwargs:
connection_kwargs = kwargs.pop("connection_args")
if not isinstance(schema, IndexSchema):
raise ValueError("Must provide a valid IndexSchema object")
self.schema = schema
self._validate_on_load = validate_on_load
self._lib_name: Optional[str] = kwargs.pop("lib_name", None)
# Store connection parameters
self.__redis_client = redis_client
self._redis_url = redis_url
self._connection_kwargs = connection_kwargs or {}
self._lock = threading.Lock()
self._validated_client = False
self._owns_redis_client = redis_client is None
if self._owns_redis_client:
weakref.finalize(self, self.disconnect)
def disconnect(self):
"""Disconnect from the Redis database."""
if self._owns_redis_client is False:
logger.info("Index does not own client, not disconnecting")
return
if self.__redis_client:
self.__redis_client.close()
self.__redis_client = None
@classmethod
def from_existing(
cls,
name: str,
redis_client: Optional[redis.Redis] = None,
redis_url: Optional[str] = None,
**kwargs,
):
"""
Initialize from an existing search index in Redis by index name.
Args:
name (str): Name of the search index in Redis.
redis_client(Optional[redis.Redis]): An
instantiated redis client.
redis_url (Optional[str]): The URL of the Redis server to
connect to.
Raises:
ValueError: If redis_url or redis_client is not provided.
RedisModuleVersionError: If required Redis modules are not installed.
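        A minimal usage sketch (the index name is hypothetical and assumed
        to already exist in Redis):

        .. code-block:: python

            from redisvl.index import SearchIndex

            index = SearchIndex.from_existing(
                "my-index", redis_url="redis://localhost:6379"
            )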
"""
try:
if redis_url:
redis_client = RedisConnectionFactory.get_redis_connection(
redis_url=redis_url,
required_modules=REQUIRED_MODULES_FOR_INTROSPECTION,
**kwargs,
)
elif redis_client:
RedisConnectionFactory.validate_sync_redis(
redis_client, required_modules=REQUIRED_MODULES_FOR_INTROSPECTION
)
except RedisModuleVersionError as e:
raise RedisModuleVersionError(
f"Loading from existing index failed. {str(e)}"
)
if not redis_client:
raise ValueError("Must provide either a redis_url or redis_client")
# Fetch index info and convert to schema
index_info = cls._info(name, redis_client)
schema_dict = convert_index_info_to_schema(index_info)
schema = IndexSchema.from_dict(schema_dict)
return cls(schema, redis_client, **kwargs)
@property
def client(self) -> Optional[redis.Redis]:
"""The underlying redis-py client object."""
return self.__redis_client
@property
def _redis_client(self) -> redis.Redis:
"""
Get a Redis client instance.
Lazily creates a Redis client instance if it doesn't exist.
"""
if self.__redis_client is None:
with self._lock:
if self.__redis_client is None:
self.__redis_client = RedisConnectionFactory.get_redis_connection(
redis_url=self._redis_url,
**self._connection_kwargs,
)
if not self._validated_client:
RedisConnectionFactory.validate_sync_redis(
self.__redis_client,
self._lib_name,
)
self._validated_client = True
return self.__redis_client
@deprecated_function("connect", "Pass connection parameters in __init__.")
def connect(self, redis_url: Optional[str] = None, **kwargs):
"""Connect to a Redis instance using the provided `redis_url`, falling
back to the `REDIS_URL` environment variable (if available).
Note: Additional keyword arguments (`**kwargs`) can be used to provide
extra options specific to the Redis connection.
Args:
redis_url (Optional[str], optional): The URL of the Redis server to
connect to.
Raises:
redis.exceptions.ConnectionError: If the connection to the Redis
server fails.
ValueError: If the Redis URL is not provided nor accessible
through the `REDIS_URL` environment variable.
ModuleNotFoundError: If required Redis modules are not installed.
"""
self.__redis_client = RedisConnectionFactory.get_redis_connection(
redis_url=redis_url, **kwargs
)
@deprecated_function("set_client", "Pass connection parameters in __init__.")
def set_client(self, redis_client: redis.Redis, **kwargs):
"""Manually set the Redis client to use with the search index.
This method configures the search index to use a specific Redis or
Async Redis client. It is useful for cases where an external,
custom-configured client is preferred instead of creating a new one.
Args:
redis_client (redis.Redis): A Redis or Async Redis
client instance to be used for the connection.
Raises:
TypeError: If the provided client is not valid.
"""
RedisConnectionFactory.validate_sync_redis(redis_client)
self.__redis_client = redis_client
return self
def create(self, overwrite: bool = False, drop: bool = False) -> None:
"""Create an index in Redis with the current schema and properties.
Args:
overwrite (bool, optional): Whether to overwrite the index if it
already exists. Defaults to False.
drop (bool, optional): Whether to drop all keys associated with the
index in the case of overwriting. Defaults to False.
Raises:
RuntimeError: If the index already exists and 'overwrite' is False.
ValueError: If no fields are defined for the index.
.. code-block:: python
# create an index in Redis; only if one does not exist with given name
index.create()
# overwrite an index in Redis without dropping associated data
index.create(overwrite=True)
# overwrite an index in Redis; drop associated data (clean slate)
index.create(overwrite=True, drop=True)
"""
# Check that fields are defined.
redis_fields = self.schema.redis_fields
if not redis_fields:
raise ValueError("No fields defined for index")
if not isinstance(overwrite, bool):
raise TypeError("overwrite must be of type bool")
if self.exists():
if not overwrite:
logger.info("Index already exists, not overwriting.")
return None
logger.info("Index already exists, overwriting.")
self.delete(drop=drop)
try:
self._redis_client.ft(self.name).create_index( # type: ignore
fields=redis_fields,
definition=IndexDefinition(
prefix=[self.schema.index.prefix], index_type=self._storage.type
),
)
        except Exception:
logger.exception("Error while trying to create the index")
raise
def delete(self, drop: bool = True):
"""Delete the search index while optionally dropping all keys associated
with the index.
Args:
drop (bool, optional): Delete the key / documents pairs in the
index. Defaults to True.
        Raises:
redis.exceptions.ResponseError: If the index does not exist.
"""
try:
self._redis_client.ft(self.schema.index.name).dropindex( # type: ignore
delete_documents=drop
)
except Exception as e:
raise RedisSearchError(f"Error while deleting index: {str(e)}") from e
def clear(self) -> int:
"""Clear all keys in Redis associated with the index, leaving the index
available and in-place for future insertions or updates.
Returns:
int: Count of records deleted from Redis.
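        A short usage sketch:

        .. code-block:: python

            # drop all documents but keep the index definition in place
            deleted_count = index.clear()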
"""
# Track deleted records
total_records_deleted: int = 0
# Paginate using queries and delete in batches
for batch in self.paginate(
FilterQuery(FilterExpression("*"), return_fields=["id"]), page_size=500
):
batch_keys = [record["id"] for record in batch]
record_deleted = self._redis_client.delete(*batch_keys) # type: ignore
total_records_deleted += record_deleted # type: ignore
return total_records_deleted
def drop_keys(self, keys: Union[str, List[str]]) -> int:
"""Remove a specific entry or entries from the index by it's key ID.
Args:
            keys (Union[str, List[str]]): The fully-formed Redis key or keys
                to remove from the index (the key prefix is not applied).
Returns:
int: Count of records deleted from Redis.
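        A short usage sketch (the keys shown are hypothetical, fully-formed
        Redis keys):

        .. code-block:: python

            deleted = index.drop_keys(["rvl:doc1", "rvl:doc2"])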
"""
        if isinstance(keys, list):
return self._redis_client.delete(*keys) # type: ignore
else:
return self._redis_client.delete(keys) # type: ignore
def drop_documents(self, ids: Union[str, List[str]]) -> int:
"""Remove documents from the index by their document IDs.
This method converts document IDs to Redis keys automatically by applying
the index's key prefix and separator configuration.
Args:
ids (Union[str, List[str]]): The document ID or IDs to remove from the index.
Returns:
int: Count of documents deleted from Redis.
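        A short usage sketch (the document IDs are hypothetical; the
        configured prefix and separator are applied automatically):

        .. code-block:: python

            # deletes keys such as "rvl:doc1" and "rvl:doc2"
            deleted = index.drop_documents(["doc1", "doc2"])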
"""
if isinstance(ids, list):
if not ids:
return 0
keys = [self.key(id) for id in ids]
return self._redis_client.delete(*keys) # type: ignore
else:
key = self.key(ids)
return self._redis_client.delete(key) # type: ignore
def expire_keys(
self, keys: Union[str, List[str]], ttl: int
) -> Union[int, List[int]]:
"""Set the expiration time for a specific entry or entries in Redis.
Args:
            keys (Union[str, List[str]]): The fully-formed Redis key or keys
                to set the expiration for.
ttl (int): The time-to-live in seconds.
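        Returns:
            Union[int, List[int]]: The result of each EXPIRE command: 1 if
                the timeout was set, 0 otherwise (e.g. the key does not exist).

        A short usage sketch (the key is hypothetical):

        .. code-block:: python

            # expire a single key after one hour
            index.expire_keys("rvl:doc1", ttl=3600)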
"""
if isinstance(keys, list):
pipe = self._redis_client.pipeline() # type: ignore
for key in keys:
pipe.expire(key, ttl)
return pipe.execute()
else:
return self._redis_client.expire(keys, ttl) # type: ignore
def load(
self,
data: Iterable[Any],
id_field: Optional[str] = None,
keys: Optional[Iterable[str]] = None,
ttl: Optional[int] = None,
preprocess: Optional[Callable] = None,
batch_size: Optional[int] = None,
) -> List[str]:
"""Load objects to the Redis database. Returns the list of keys loaded
to Redis.
RedisVL automatically handles constructing the object keys, batching,
optional preprocessing steps, and setting optional expiration
(TTL policies) on keys.
Args:
data (Iterable[Any]): An iterable of objects to store.
id_field (Optional[str], optional): Specified field used as the id
portion of the redis key (after the prefix) for each
object. Defaults to None.
keys (Optional[Iterable[str]], optional): Optional iterable of keys.
Must match the length of objects if provided. Defaults to None.
ttl (Optional[int], optional): Time-to-live in seconds for each key.
Defaults to None.
preprocess (Optional[Callable], optional): A function to preprocess
objects before storage. Defaults to None.
batch_size (Optional[int], optional): Number of objects to write in
a single Redis pipeline execution. Defaults to class's
default batch size.
Returns:
List[str]: List of keys loaded to Redis.
Raises:
SchemaValidationError: If validation fails when validate_on_load is enabled.
RedisVLError: If there's an error loading data to Redis.
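        A minimal sketch, assuming a schema with a "doc-id" tag field as in
        the ``from_dict`` example above:

        .. code-block:: python

            data = [{"doc-id": "1"}, {"doc-id": "2"}]
            keys = index.load(data, id_field="doc-id", ttl=3600)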
"""
try:
return self._storage.write(
self._redis_client, # type: ignore
objects=data,
id_field=id_field,
keys=keys,
ttl=ttl,
preprocess=preprocess,
batch_size=batch_size,
validate=self._validate_on_load,
)
except SchemaValidationError:
# Pass through validation errors directly
logger.exception("Schema validation error while loading data")
raise
except Exception as e:
# Wrap other errors as general RedisVL errors
logger.exception("Error while loading data to Redis")
raise RedisVLError(f"Failed to load data: {str(e)}") from e
def fetch(self, id: str) -> Optional[Dict[str, Any]]:
"""Fetch an object from Redis by id.
The id is typically either a unique identifier,
or derived from some domain-specific metadata combination
(like a document id or chunk id).
Args:
id (str): The specified unique identifier for a particular
document indexed in Redis.
Returns:
Dict[str, Any]: The fetched object.
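        A short usage sketch (the id is hypothetical):

        .. code-block:: python

            obj = index.fetch("doc1")  # looks up the key index.key("doc1")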
"""
obj = self._storage.get(self._redis_client, [self.key(id)]) # type: ignore
if obj:
return convert_bytes(obj[0])
return None
def _aggregate(self, aggregation_query: AggregationQuery) -> List[Dict[str, Any]]:
"""Execute an aggretation query and processes the results."""
results = self.aggregate(
aggregation_query, query_params=aggregation_query.params # type: ignore[attr-defined]
)
return process_aggregate_results(
results,
query=aggregation_query,
storage_type=self.schema.index.storage_type,
)
def aggregate(self, *args, **kwargs) -> "AggregateResult":
"""Perform an aggregation operation against the index.
Wrapper around the aggregation API that adds the index name
to the query and passes along the rest of the arguments
to the redis-py ft().aggregate() method.
Returns:
Result: Raw Redis aggregation results.
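        A minimal sketch using the redis-py aggregation API (the field name
        is hypothetical):

        .. code-block:: python

            from redis.commands.search import reducers
            from redis.commands.search.aggregation import AggregateRequest

            request = AggregateRequest("*").group_by(
                "@category", reducers.count().alias("count")
            )
            result = index.aggregate(request)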
"""
try:
return self._redis_client.ft(self.schema.index.name).aggregate( # type: ignore
*args, **kwargs
)
except Exception as e:
raise RedisSearchError(f"Error while aggregating: {str(e)}") from e
def batch_search(
self,
queries: List[SearchParams],
batch_size: int = 10,
) -> List["Result"]:
"""Perform a search against the index for multiple queries.
This method takes a list of queries and optionally query params and
returns a list of Result objects for each query. Results are
returned in the same order as the queries.
Args:
            queries (List[SearchParams]): The queries to search for.
            batch_size (int, optional): The number of queries to search for
                at a time. Defaults to 10.
Returns:
List[Result]: The search results for each query.
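        A short sketch (the query strings and tag values are hypothetical):

        .. code-block:: python

            results = index.batch_search(
                [
                    ("@category:{electronics}", None),
                    "@category:{books}",
                ],
                batch_size=2,
            )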
"""
all_parsed = []
search = self._redis_client.ft(self.schema.index.name)
options = {}
if get_protocol_version(self._redis_client) not in ["3", 3]:
options[NEVER_DECODE] = True
for i in range(0, len(queries), batch_size):
batch_queries = queries[i : i + batch_size]
# redis-py doesn't support calling `search` in a pipeline,
# so we need to manually execute each command in a pipeline
# and parse the results
with self._redis_client.pipeline(transaction=False) as pipe:
batch_built_queries = []
for query in batch_queries:
if isinstance(query, tuple):
query_args, q = search._mk_query_args( # type: ignore
query[0], query_params=query[1]
)
else:
query_args, q = search._mk_query_args( # type: ignore
query, query_params=None
)
batch_built_queries.append(q)
pipe.execute_command(
"FT.SEARCH",
*query_args,
**options,
)
st = time.time()
results = pipe.execute()
# We don't know how long each query took, so we'll use the total time
# for all queries in the batch as the duration for each query
duration = (time.time() - st) * 1000.0
                for j, query_results in enumerate(results):
                    _built_query = batch_built_queries[j]
parsed_result = search._parse_search( # type: ignore
query_results,
query=_built_query,
duration=duration,
)
# Return a parsed Result object for each query
all_parsed.append(parsed_result)
return all_parsed
def search(self, *args, **kwargs) -> "Result":
"""Perform a search against the index.
Wrapper around the search API that adds the index name
to the query and passes along the rest of the arguments
to the redis-py ft().search() method.
Returns:
Result: Raw Redis search results.
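        A short sketch (the raw query string and field name are
        hypothetical):

        .. code-block:: python

            result = index.search("@category:{electronics}")
            print(result.total, len(result.docs))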
"""
try:
return self._redis_client.ft(self.schema.index.name).search( # type: ignore
*args, **kwargs
)
except Exception as e:
raise RedisSearchError(f"Error while searching: {str(e)}") from e
def batch_query(
self, queries: Sequence[BaseQuery], batch_size: int = 10
) -> List[List[Dict[str, Any]]]:
"""Execute a batch of queries and process results."""
results = self.batch_search(
[(query.query, query.params) for query in queries], batch_size=batch_size
)
all_parsed = []
for query, batch_results in zip(queries, results):
parsed = process_results(batch_results, query=query, schema=self.schema)
# Create separate lists of parsed results for each query
# passed in to the batch_search method, so that callers can
# access the results for each query individually
all_parsed.append(parsed)
return all_parsed
def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
"""Execute a query and process results."""
try:
self._validate_query(query)
except QueryValidationError as e:
raise QueryValidationError(f"Invalid query: {str(e)}") from e
results = self.search(query.query, query_params=query.params)
return process_results(results, query=query, schema=self.schema)
def query(self, query: Union[BaseQuery, AggregationQuery]) -> List[Dict[str, Any]]:
"""Execute a query on the index.
This method takes a BaseQuery or AggregationQuery object directly, and
handles post-processing of the search.
Args:
            query (Union[BaseQuery, AggregationQuery]): The query to run.
        Returns:
            List[Dict[str, Any]]: A list of search results as dictionaries.
.. code-block:: python
from redisvl.query import VectorQuery
query = VectorQuery(
vector=[0.16, -0.34, 0.98, 0.23],
vector_field_name="embedding",
num_results=3
)
results = index.query(query)
"""
if isinstance(query, AggregationQuery):
return self._aggregate(query)
else:
return self._query(query)
def paginate(self, query: BaseQuery, page_size: int = 30) -> Generator:
"""Execute a given query against the index and return results in
paginated batches.
This method accepts a RedisVL query instance, enabling pagination of
results which allows for subsequent processing over each batch with a
generator.
Args:
query (BaseQuery): The search query to be executed.
page_size (int, optional): The number of results to return in each
batch. Defaults to 30.
Yields:
A generator yielding batches of search results.
Raises:
TypeError: If the page_size argument is not of type int.
ValueError: If the page_size argument is less than or equal to zero.
.. code-block:: python
# Iterate over paginated search results in batches of 10
for result_batch in index.paginate(query, page_size=10):
# Process each batch of results
pass
Note:
The page_size parameter controls the number of items each result
batch contains. Adjust this value based on performance
considerations and the expected volume of search results.
"""
if not isinstance(page_size, int):
raise TypeError("page_size must be an integer")
if page_size <= 0:
raise ValueError("page_size must be greater than 0")
offset = 0
while True:
query.paging(offset, page_size)
results = self._query(query)
if not results:
break
yield results
# Increment the offset for the next batch of pagination
offset += page_size
def listall(self) -> List[str]:
"""List all search indices in Redis database.
Returns:
List[str]: The list of indices in the database.
"""
return convert_bytes(self._redis_client.execute_command("FT._LIST")) # type: ignore
def exists(self) -> bool:
"""Check if the index exists in Redis.
Returns:
bool: True if the index exists, False otherwise.
"""
return self.schema.index.name in self.listall()
@staticmethod
def _info(name: str, redis_client: redis.Redis) -> Dict[str, Any]:
"""Run FT.INFO to fetch information about the index."""
try:
return convert_bytes(redis_client.ft(name).info()) # type: ignore
except Exception as e:
raise RedisSearchError(
f"Error while fetching {name} index info: {str(e)}"
) from e
def info(self, name: Optional[str] = None) -> Dict[str, Any]:
"""Get information about the index.
Args:
name (str, optional): Index name to fetch info about.
Defaults to None.
Returns:
dict: A dictionary containing the information about the index.
"""
index_name = name or self.schema.index.name
return self._info(index_name, self._redis_client) # type: ignore
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.disconnect()
class AsyncSearchIndex(BaseSearchIndex):
"""A search index class for interacting with Redis as a vector database in
async-mode.
The AsyncSearchIndex is instantiated with a reference to a Redis database
and an IndexSchema (YAML path or dictionary object) that describes the
various settings and field configurations.
.. code-block:: python
from redisvl.index import AsyncSearchIndex
# initialize the index object with schema from file
index = AsyncSearchIndex.from_yaml(
"schemas/schema.yaml",