66"""
77
88from enum import Enum
9- from typing import Any , Dict , Optional , Tuple , Type , Union
9+ from typing import Any , Dict , Literal , Optional , Tuple , Type , Union
1010
11- from pydantic . v1 import BaseModel , Field , validator
11+ from pydantic import BaseModel , Field , field_validator
1212from redis .commands .search .field import Field as RedisField
1313from redis .commands .search .field import GeoField as RedisGeoField
1414from redis .commands .search .field import NumericField as RedisNumericField
1515from redis .commands .search .field import TagField as RedisTagField
1616from redis .commands .search .field import TextField as RedisTextField
1717from redis .commands .search .field import VectorField as RedisVectorField
1818
19- ### Attribute Enums ###
20-
2119
2220class VectorDistanceMetric (str , Enum ):
2321 COSINE = "COSINE"
@@ -99,7 +97,7 @@ class BaseVectorFieldAttributes(BaseModel):
9997 initial_cap : Optional [int ] = None
10098 """Initial vector capacity in the index affecting memory allocation size of the index"""
10199
102- @validator ("algorithm" , "datatype" , "distance_metric" , pre = True )
100+ @field_validator ("algorithm" , "datatype" , "distance_metric" , mode = "before" )
103101 @classmethod
104102 def uppercase_strings (cls , v ):
105103 """Validate that provided values are cast to uppercase"""
@@ -121,9 +119,7 @@ def field_data(self) -> Dict[str, Any]:
121119class FlatVectorFieldAttributes (BaseVectorFieldAttributes ):
122120 """FLAT vector field attributes"""
123121
124- algorithm : VectorIndexAlgorithm = Field (
125- default = VectorIndexAlgorithm .FLAT , const = True
126- )
122+ algorithm : Literal [VectorIndexAlgorithm .FLAT ] = VectorIndexAlgorithm .FLAT
127123 """The indexing algorithm for the vector field"""
128124 block_size : Optional [int ] = None
129125 """Block size to hold amount of vectors in a contiguous array. This is useful when the index is dynamic with respect to addition and deletion"""
@@ -132,9 +128,7 @@ class FlatVectorFieldAttributes(BaseVectorFieldAttributes):
132128class HNSWVectorFieldAttributes (BaseVectorFieldAttributes ):
133129 """HNSW vector field attributes"""
134130
135- algorithm : VectorIndexAlgorithm = Field (
136- default = VectorIndexAlgorithm .HNSW , const = True
137- )
131+ algorithm : Literal [VectorIndexAlgorithm .HNSW ] = VectorIndexAlgorithm .HNSW
138132 """The indexing algorithm for the vector field"""
139133 m : int = Field (default = 16 )
140134 """Number of max outgoing edges for each graph node in each layer"""
@@ -173,7 +167,7 @@ def as_redis_field(self) -> RedisField:
173167class TextField (BaseField ):
174168 """Text field supporting a full text search index"""
175169
176- type : str = Field ( default = "text" , const = True )
170+ type : Literal [ "text" ] = "text"
177171 attrs : TextFieldAttributes = Field (default_factory = TextFieldAttributes )
178172
179173 def as_redis_field (self ) -> RedisField :
@@ -191,7 +185,7 @@ def as_redis_field(self) -> RedisField:
191185class TagField (BaseField ):
192186 """Tag field for simple boolean-style filtering"""
193187
194- type : str = Field ( default = "tag" , const = True )
188+ type : Literal [ "tag" ] = "tag"
195189 attrs : TagFieldAttributes = Field (default_factory = TagFieldAttributes )
196190
197191 def as_redis_field (self ) -> RedisField :
@@ -208,7 +202,7 @@ def as_redis_field(self) -> RedisField:
208202class NumericField (BaseField ):
209203 """Numeric field for numeric range filtering"""
210204
211- type : str = Field ( default = "numeric" , const = True )
205+ type : Literal [ "numeric" ] = "numeric"
212206 attrs : NumericFieldAttributes = Field (default_factory = NumericFieldAttributes )
213207
214208 def as_redis_field (self ) -> RedisField :
@@ -223,7 +217,7 @@ def as_redis_field(self) -> RedisField:
223217class GeoField (BaseField ):
224218 """Geo field with a geo-spatial index for location based search"""
225219
226- type : str = Field ( default = "geo" , const = True )
220+ type : Literal [ "geo" ] = "geo"
227221 attrs : GeoFieldAttributes = Field (default_factory = GeoFieldAttributes )
228222
229223 def as_redis_field (self ) -> RedisField :
@@ -238,7 +232,7 @@ def as_redis_field(self) -> RedisField:
238232class FlatVectorField (BaseField ):
239233 "Vector field with a FLAT index (brute force nearest neighbors search)"
240234
241- type : str = Field ( default = "vector" , const = True )
235+ type : Literal [ "vector" ] = "vector"
242236 attrs : FlatVectorFieldAttributes
243237
244238 def as_redis_field (self ) -> RedisField :
@@ -253,7 +247,7 @@ def as_redis_field(self) -> RedisField:
253247class HNSWVectorField (BaseField ):
254248 """Vector field with an HNSW index (approximate nearest neighbors search)"""
255249
256- type : str = Field ( default = "vector" , const = True )
250+ type : Literal [ "vector" ] = "vector"
257251 attrs : HNSWVectorFieldAttributes
258252
259253 def as_redis_field (self ) -> RedisField :
@@ -271,20 +265,21 @@ def as_redis_field(self) -> RedisField:
271265 return RedisVectorField (name , self .attrs .algorithm , field_data , as_name = as_name )
272266
273267
274- class FieldFactory :
275- """Factory class to create fields from client data and kwargs."""
268+ FIELD_TYPE_MAP = {
269+ "tag" : TagField ,
270+ "text" : TextField ,
271+ "numeric" : NumericField ,
272+ "geo" : GeoField ,
273+ }
276274
277- FIELD_TYPE_MAP = {
278- "tag" : TagField ,
279- "text" : TextField ,
280- "numeric" : NumericField ,
281- "geo" : GeoField ,
282- }
275+ VECTOR_FIELD_TYPE_MAP = {
276+ "flat" : FlatVectorField ,
277+ "hnsw" : HNSWVectorField ,
278+ }
283279
284- VECTOR_FIELD_TYPE_MAP = {
285- "flat" : FlatVectorField ,
286- "hnsw" : HNSWVectorField ,
287- }
280+
281+ class FieldFactory :
282+ """Factory class to create fields from client data and kwargs."""
288283
289284 @classmethod
290285 def pick_vector_field_type (cls , attrs : Dict [str , Any ]) -> Type [BaseField ]:
@@ -296,10 +291,10 @@ def pick_vector_field_type(cls, attrs: Dict[str, Any]) -> Type[BaseField]:
296291 raise ValueError ("Must provide dims param for the vector field." )
297292
298293 algorithm = attrs ["algorithm" ].lower ()
299- if algorithm not in cls . VECTOR_FIELD_TYPE_MAP :
294+ if algorithm not in VECTOR_FIELD_TYPE_MAP :
300295 raise ValueError (f"Unknown vector field algorithm: { algorithm } " )
301296
302- return cls . VECTOR_FIELD_TYPE_MAP [algorithm ] # type: ignore
297+ return VECTOR_FIELD_TYPE_MAP [algorithm ] # type: ignore
303298
304299 @classmethod
305300 def create_field (
@@ -314,8 +309,14 @@ def create_field(
314309 if type == "vector" :
315310 field_class = cls .pick_vector_field_type (attrs )
316311 else :
317- if type not in cls . FIELD_TYPE_MAP :
312+ if type not in FIELD_TYPE_MAP :
318313 raise ValueError (f"Unknown field type: { type } " )
319- field_class = cls . FIELD_TYPE_MAP [type ] # type: ignore
314+ field_class = FIELD_TYPE_MAP [type ] # type: ignore
320315
321- return field_class (name = name , path = path , attrs = attrs ) # type: ignore
316+ return field_class .model_validate (
317+ {
318+ "name" : name ,
319+ "path" : path ,
320+ "attrs" : attrs ,
321+ }
322+ )
0 commit comments