45
45
{"name" : "ReJSON" , "ver" : 20404 }
46
46
]
47
47
REDIS_DEFAULT_ESCAPED_CHARS = re .compile (r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]" )
48
- REDIS_SEARCH_SCHEMA = {
49
- "document_id" : TagField ("$.document_id" , as_name = "document_id" ),
50
- "metadata" : {
51
- # "source_id": TagField("$.metadata.source_id", as_name="source_id"),
52
- "source" : TagField ("$.metadata.source" , as_name = "source" ),
53
- # "author": TextField("$.metadata.author", as_name="author"),
54
- # "created_at": NumericField("$.metadata.created_at", as_name="created_at"),
55
- },
56
- "embedding" : VectorField (
57
- "$.embedding" ,
58
- REDIS_INDEX_TYPE ,
59
- {
60
- "TYPE" : "FLOAT64" ,
61
- "DIM" : VECTOR_DIMENSION ,
62
- "DISTANCE_METRIC" : REDIS_DISTANCE_METRIC ,
63
- },
64
- as_name = "embedding" ,
65
- ),
66
- }
67
48
68
49
# Helper functions
69
50
def unpack_schema (d : dict ):
@@ -82,22 +63,23 @@ async def _check_redis_module_exist(client: redis.Redis, modules: List[dict]):
82
63
error_message = "You must add the RediSearch (>= 2.6) and ReJSON (>= 2.4) modules from Redis Stack. " \
83
64
"Please refer to Redis Stack docs: https://redis.io/docs/stack/"
84
65
logging .error (error_message )
85
- raise ValueError (error_message )
66
+ raise AttributeError (error_message )
86
67
87
68
88
69
89
70
class RedisDataStore (DataStore ):
90
- def __init__ (self , client : redis .Redis ):
71
+ def __init__ (self , client : redis .Redis , redisearch_schema ):
91
72
self .client = client
73
+ self ._schema = redisearch_schema
92
74
# Init default metadata with sentinel values in case the document written has no metadata
93
75
self ._default_metadata = {
94
- field : "_null_" for field in REDIS_SEARCH_SCHEMA ["metadata" ]
76
+ field : "_null_" for field in redisearch_schema ["metadata" ]
95
77
}
96
78
97
79
### Redis Helper Methods ###
98
80
99
81
@classmethod
100
- async def init (cls ):
82
+ async def init (cls , ** kwargs ):
101
83
"""
102
84
Setup the index if it does not exist.
103
85
"""
@@ -112,7 +94,27 @@ async def init(cls):
112
94
raise e
113
95
114
96
await _check_redis_module_exist (client , modules = REDIS_REQUIRED_MODULES )
115
-
97
+
98
+ dim = kwargs .get ("dim" , VECTOR_DIMENSION )
99
+ redisearch_schema = {
100
+ "document_id" : TagField ("$.document_id" , as_name = "document_id" ),
101
+ "metadata" : {
102
+ "source_id" : TagField ("$.metadata.source_id" , as_name = "source_id" ),
103
+ "source" : TagField ("$.metadata.source" , as_name = "source" ),
104
+ "author" : TextField ("$.metadata.author" , as_name = "author" ),
105
+ "created_at" : NumericField ("$.metadata.created_at" , as_name = "created_at" ),
106
+ },
107
+ "embedding" : VectorField (
108
+ "$.embedding" ,
109
+ REDIS_INDEX_TYPE ,
110
+ {
111
+ "TYPE" : "FLOAT64" ,
112
+ "DIM" : dim ,
113
+ "DISTANCE_METRIC" : REDIS_DISTANCE_METRIC ,
114
+ },
115
+ as_name = "embedding" ,
116
+ ),
117
+ }
116
118
try :
117
119
# Check for existence of RediSearch Index
118
120
await client .ft (REDIS_INDEX_NAME ).info ()
@@ -123,11 +125,12 @@ async def init(cls):
123
125
definition = IndexDefinition (
124
126
prefix = [REDIS_DOC_PREFIX ], index_type = IndexType .JSON
125
127
)
126
- fields = list (unpack_schema (REDIS_SEARCH_SCHEMA ))
128
+ fields = list (unpack_schema (redisearch_schema ))
129
+ logging .info (f"Creating index with fields: { fields } " )
127
130
await client .ft (REDIS_INDEX_NAME ).create_index (
128
131
fields = fields , definition = definition
129
132
)
130
- return cls (client )
133
+ return cls (client , redisearch_schema )
131
134
132
135
@staticmethod
133
136
def _redis_key (document_id : str , chunk_id : str ) -> str :
@@ -217,20 +220,21 @@ def _typ_to_str(typ, field, value) -> str: # type: ignore
217
220
218
221
# Build filter
219
222
if query .filter :
223
+ redisearch_schema = self ._schema
220
224
for field , value in query .filter .__dict__ .items ():
221
225
if not value :
222
226
continue
223
- if field in REDIS_SEARCH_SCHEMA :
224
- filter_str += _typ_to_str (REDIS_SEARCH_SCHEMA [field ], field , value )
225
- elif field in REDIS_SEARCH_SCHEMA ["metadata" ]:
227
+ if field in redisearch_schema :
228
+ filter_str += _typ_to_str (redisearch_schema [field ], field , value )
229
+ elif field in redisearch_schema ["metadata" ]:
226
230
if field == "source" : # handle the enum
227
231
value = value .value
228
232
filter_str += _typ_to_str (
229
- REDIS_SEARCH_SCHEMA ["metadata" ][field ], field , value
233
+ redisearch_schema ["metadata" ][field ], field , value
230
234
)
231
235
elif field in ["start_date" , "end_date" ]:
232
236
filter_str += _typ_to_str (
233
- REDIS_SEARCH_SCHEMA ["metadata" ]["created_at" ], field , value
237
+ redisearch_schema ["metadata" ]["created_at" ], field , value
234
238
)
235
239
236
240
# Postprocess filter string
0 commit comments