66"""
77
88from enum import Enum
9- from typing import Any , Dict , Optional , Tuple , Type , Union
9+ from typing import Any , Dict , Literal , Optional , Tuple , Type , Union
1010
11- from pydantic . v1 import BaseModel , Field , validator
11+ from pydantic import BaseModel , Field , field_validator
1212from redis .commands .search .field import Field as RedisField
1313from redis .commands .search .field import GeoField as RedisGeoField
1414from redis .commands .search .field import NumericField as RedisNumericField
1515from redis .commands .search .field import TagField as RedisTagField
1616from redis .commands .search .field import TextField as RedisTextField
1717from redis .commands .search .field import VectorField as RedisVectorField
1818
19- ### Attribute Enums ###
20-
2119
2220class VectorDistanceMetric (str , Enum ):
2321 COSINE = "COSINE"
@@ -97,7 +95,7 @@ class BaseVectorFieldAttributes(BaseModel):
9795 initial_cap : Optional [int ] = None
9896 """Initial vector capacity in the index affecting memory allocation size of the index"""
9997
100- @validator ("algorithm" , "datatype" , "distance_metric" , pre = True )
98+ @field_validator ("algorithm" , "datatype" , "distance_metric" , mode = "before" )
10199 @classmethod
102100 def uppercase_strings (cls , v ):
103101 """Validate that provided values are cast to uppercase"""
@@ -119,9 +117,7 @@ def field_data(self) -> Dict[str, Any]:
119117class FlatVectorFieldAttributes (BaseVectorFieldAttributes ):
120118 """FLAT vector field attributes"""
121119
122- algorithm : VectorIndexAlgorithm = Field (
123- default = VectorIndexAlgorithm .FLAT , const = True
124- )
120+ algorithm : Literal [VectorIndexAlgorithm .FLAT ] = VectorIndexAlgorithm .FLAT
125121 """The indexing algorithm for the vector field"""
126122 block_size : Optional [int ] = None
127123 """Block size to hold amount of vectors in a contiguous array. This is useful when the index is dynamic with respect to addition and deletion"""
@@ -130,9 +126,7 @@ class FlatVectorFieldAttributes(BaseVectorFieldAttributes):
130126class HNSWVectorFieldAttributes (BaseVectorFieldAttributes ):
131127 """HNSW vector field attributes"""
132128
133- algorithm : VectorIndexAlgorithm = Field (
134- default = VectorIndexAlgorithm .HNSW , const = True
135- )
129+ algorithm : Literal [VectorIndexAlgorithm .HNSW ] = VectorIndexAlgorithm .HNSW
136130 """The indexing algorithm for the vector field"""
137131 m : int = Field (default = 16 )
138132 """Number of max outgoing edges for each graph node in each layer"""
@@ -171,7 +165,7 @@ def as_redis_field(self) -> RedisField:
171165class TextField (BaseField ):
172166 """Text field supporting a full text search index"""
173167
174- type : str = Field ( default = "text" , const = True )
168+ type : Literal [ "text" ] = "text"
175169 attrs : TextFieldAttributes = Field (default_factory = TextFieldAttributes )
176170
177171 def as_redis_field (self ) -> RedisField :
@@ -189,7 +183,7 @@ def as_redis_field(self) -> RedisField:
189183class TagField (BaseField ):
190184 """Tag field for simple boolean-style filtering"""
191185
192- type : str = Field ( default = "tag" , const = True )
186+ type : Literal [ "tag" ] = "tag"
193187 attrs : TagFieldAttributes = Field (default_factory = TagFieldAttributes )
194188
195189 def as_redis_field (self ) -> RedisField :
@@ -206,7 +200,7 @@ def as_redis_field(self) -> RedisField:
206200class NumericField (BaseField ):
207201 """Numeric field for numeric range filtering"""
208202
209- type : str = Field ( default = "numeric" , const = True )
203+ type : Literal [ "numeric" ] = "numeric"
210204 attrs : NumericFieldAttributes = Field (default_factory = NumericFieldAttributes )
211205
212206 def as_redis_field (self ) -> RedisField :
@@ -221,7 +215,7 @@ def as_redis_field(self) -> RedisField:
221215class GeoField (BaseField ):
222216 """Geo field with a geo-spatial index for location based search"""
223217
224- type : str = Field ( default = "geo" , const = True )
218+ type : Literal [ "geo" ] = "geo"
225219 attrs : GeoFieldAttributes = Field (default_factory = GeoFieldAttributes )
226220
227221 def as_redis_field (self ) -> RedisField :
@@ -236,7 +230,7 @@ def as_redis_field(self) -> RedisField:
236230class FlatVectorField (BaseField ):
237231 "Vector field with a FLAT index (brute force nearest neighbors search)"
238232
239- type : str = Field ( default = "vector" , const = True )
233+ type : Literal [ "vector" ] = "vector"
240234 attrs : FlatVectorFieldAttributes
241235
242236 def as_redis_field (self ) -> RedisField :
@@ -251,7 +245,7 @@ def as_redis_field(self) -> RedisField:
251245class HNSWVectorField (BaseField ):
252246 """Vector field with an HNSW index (approximate nearest neighbors search)"""
253247
254- type : str = Field ( default = "vector" , const = True )
248+ type : Literal [ "vector" ] = "vector"
255249 attrs : HNSWVectorFieldAttributes
256250
257251 def as_redis_field (self ) -> RedisField :
@@ -269,20 +263,21 @@ def as_redis_field(self) -> RedisField:
269263 return RedisVectorField (name , self .attrs .algorithm , field_data , as_name = as_name )
270264
271265
272- class FieldFactory :
273- """Factory class to create fields from client data and kwargs."""
266+ FIELD_TYPE_MAP = {
267+ "tag" : TagField ,
268+ "text" : TextField ,
269+ "numeric" : NumericField ,
270+ "geo" : GeoField ,
271+ }
274272
275- FIELD_TYPE_MAP = {
276- "tag" : TagField ,
277- "text" : TextField ,
278- "numeric" : NumericField ,
279- "geo" : GeoField ,
280- }
273+ VECTOR_FIELD_TYPE_MAP = {
274+ "flat" : FlatVectorField ,
275+ "hnsw" : HNSWVectorField ,
276+ }
281277
282- VECTOR_FIELD_TYPE_MAP = {
283- "flat" : FlatVectorField ,
284- "hnsw" : HNSWVectorField ,
285- }
278+
279+ class FieldFactory :
280+ """Factory class to create fields from client data and kwargs."""
286281
287282 @classmethod
288283 def pick_vector_field_type (cls , attrs : Dict [str , Any ]) -> Type [BaseField ]:
@@ -294,10 +289,10 @@ def pick_vector_field_type(cls, attrs: Dict[str, Any]) -> Type[BaseField]:
294289 raise ValueError ("Must provide dims param for the vector field." )
295290
296291 algorithm = attrs ["algorithm" ].lower ()
297- if algorithm not in cls . VECTOR_FIELD_TYPE_MAP :
292+ if algorithm not in VECTOR_FIELD_TYPE_MAP :
298293 raise ValueError (f"Unknown vector field algorithm: { algorithm } " )
299294
300- return cls . VECTOR_FIELD_TYPE_MAP [algorithm ] # type: ignore
295+ return VECTOR_FIELD_TYPE_MAP [algorithm ] # type: ignore
301296
302297 @classmethod
303298 def create_field (
@@ -312,8 +307,14 @@ def create_field(
312307 if type == "vector" :
313308 field_class = cls .pick_vector_field_type (attrs )
314309 else :
315- if type not in cls . FIELD_TYPE_MAP :
310+ if type not in FIELD_TYPE_MAP :
316311 raise ValueError (f"Unknown field type: { type } " )
317- field_class = cls . FIELD_TYPE_MAP [type ] # type: ignore
312+ field_class = FIELD_TYPE_MAP [type ] # type: ignore
318313
319- return field_class (name = name , path = path , attrs = attrs ) # type: ignore
314+ return field_class .model_validate (
315+ {
316+ "name" : name ,
317+ "path" : path ,
318+ "attrs" : attrs ,
319+ }
320+ )
0 commit comments