Skip to content

Commit 4e88981

Browse files
pydantic facelift to v2
1 parent f41581a commit 4e88981

File tree

12 files changed

+73
-71
lines changed

12 files changed

+73
-71
lines changed

redisvl/extensions/llmcache/schema.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, Dict, List, Optional
22

3-
from pydantic.v1 import BaseModel, Field, root_validator, validator
3+
from pydantic import BaseModel, Field, field_validator, model_validator
44

55
from redisvl.extensions.constants import (
66
CACHE_VECTOR_FIELD_NAME,
@@ -34,22 +34,23 @@ class CacheEntry(BaseModel):
3434
filters: Optional[Dict[str, Any]] = Field(default=None)
3535
"""Optional filter data stored on the cache entry for customizing retrieval"""
3636

37-
@root_validator(pre=True)
37+
@model_validator(mode="before")
3838
@classmethod
3939
def generate_id(cls, values):
4040
# Ensure entry_id is set
4141
if not values.get("entry_id"):
4242
values["entry_id"] = hashify(values["prompt"], values.get("filters"))
4343
return values
4444

45-
@validator("metadata")
45+
@field_validator("metadata")
46+
@classmethod
4647
def non_empty_metadata(cls, v):
4748
if v is not None and not isinstance(v, dict):
4849
raise TypeError("Metadata must be a dictionary.")
4950
return v
5051

5152
def to_dict(self, dtype: str) -> Dict:
52-
data = self.dict(exclude_none=True)
53+
data = self.model_dump(exclude_none=True)
5354
data["prompt_vector"] = array_to_buffer(self.prompt_vector, dtype)
5455
if self.metadata is not None:
5556
data["metadata"] = serialize(self.metadata)
@@ -79,18 +80,18 @@ class CacheHit(BaseModel):
7980
filters: Optional[Dict[str, Any]] = Field(default=None)
8081
"""Optional filter data stored on the cache entry for customizing retrieval"""
8182

82-
@root_validator(pre=True)
83+
@model_validator(mode="before")
8384
@classmethod
8485
def validate_cache_hit(cls, values):
8586
# Deserialize metadata if necessary
8687
if "metadata" in values and isinstance(values["metadata"], str):
8788
values["metadata"] = deserialize(values["metadata"])
8889

8990
# Separate filters from other fields
90-
known_fields = set(cls.__fields__.keys())
91+
known_fields = set(cls.model_fields.keys())
9192
filters = {k: v for k, v in values.items() if k not in known_fields}
9293

93-
# Add filters to values
94+
# Add filters to valuesgiy s
9495
if filters:
9596
values["filters"] = filters
9697

@@ -101,7 +102,7 @@ def validate_cache_hit(cls, values):
101102
return values
102103

103104
def to_dict(self) -> Dict:
104-
data = self.dict(exclude_none=True)
105+
data = self.model_dump(exclude_none=True)
105106
if self.filters:
106107
data.update(self.filters)
107108
del data["filters"]

redisvl/extensions/router/schema.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from enum import Enum
22
from typing import Dict, List, Optional
33

4-
from pydantic.v1 import BaseModel, Field, validator
4+
from pydantic import BaseModel, Field, field_validator
55

66
from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
77
from redisvl.schema import IndexSchema
@@ -19,21 +19,24 @@ class Route(BaseModel):
1919
distance_threshold: float = Field(default=0.5)
2020
"""Distance threshold for matching the route."""
2121

22-
@validator("name")
22+
@field_validator("name")
23+
@classmethod
2324
def name_must_not_be_empty(cls, v):
2425
if not v or not v.strip():
2526
raise ValueError("Route name must not be empty")
2627
return v
2728

28-
@validator("references")
29+
@field_validator("references")
30+
@classmethod
2931
def references_must_not_be_empty(cls, v):
3032
if not v:
3133
raise ValueError("References must not be empty")
3234
if any(not ref.strip() for ref in v):
3335
raise ValueError("All references must be non-empty strings")
3436
return v
3537

36-
@validator("distance_threshold")
38+
@field_validator("distance_threshold")
39+
@classmethod
3740
def distance_threshold_must_be_positive(cls, v):
3841
if v is not None and v <= 0:
3942
raise ValueError("Route distance threshold must be greater than zero")
@@ -79,7 +82,8 @@ class RoutingConfig(BaseModel):
7982
description="Global distance threshold is deprecated all distance_thresholds now apply at route level.",
8083
)
8184

82-
@validator("max_k")
85+
@field_validator("max_k")
86+
@classmethod
8387
def max_k_must_be_positive(cls, v):
8488
if v <= 0:
8589
raise ValueError("max_k must be a positive integer")

redisvl/extensions/router/semantic.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import redis.commands.search.reducers as reducers
55
import yaml
6-
from pydantic.v1 import BaseModel, Field, PrivateAttr
6+
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
77
from redis import Redis
88
from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer
99
from redis.exceptions import ResponseError
@@ -44,8 +44,7 @@ class SemanticRouter(BaseModel):
4444

4545
_index: SearchIndex = PrivateAttr()
4646

47-
class Config:
48-
arbitrary_types_allowed = True
47+
model_config = ConfigDict(arbitrary_types_allowed=True)
4948

5049
@deprecated_argument("dtype", "vectorizer")
5150
def __init__(

redisvl/extensions/session_manager/schema.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Dict, List, Optional
22

3-
from pydantic.v1 import BaseModel, Field, root_validator
3+
from pydantic import BaseModel, ConfigDict, Field, model_validator
44

55
from redisvl.extensions.constants import (
66
CONTENT_FIELD_NAME,
@@ -33,11 +33,9 @@ class ChatMessage(BaseModel):
3333
"""An optional identifier for a tool call associated with the message."""
3434
vector_field: Optional[List[float]] = Field(default=None)
3535
"""The vector representation of the message content."""
36+
model_config = ConfigDict(arbitrary_types_allowed=True)
3637

37-
class Config:
38-
arbitrary_types_allowed = True
39-
40-
@root_validator(pre=True)
38+
@model_validator(mode="before")
4139
@classmethod
4240
def generate_id(cls, values):
4341
if TIMESTAMP_FIELD_NAME not in values:
@@ -49,7 +47,7 @@ def generate_id(cls, values):
4947
return values
5048

5149
def to_dict(self, dtype: Optional[str] = None) -> Dict:
52-
data = self.dict(exclude_none=True)
50+
data = self.model_dump(exclude_none=True)
5351

5452
# handle optional fields
5553
if SESSION_VECTOR_FIELD_NAME in data:

redisvl/index/storage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
from typing import Any, Callable, Dict, Iterable, List, Optional
33

4-
from pydantic.v1 import BaseModel
4+
from pydantic import BaseModel
55
from redis import Redis
66
from redis.asyncio import Redis as AsyncRedis
77
from redis.commands.search.indexDefinition import IndexType

redisvl/schema/fields.py

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,16 @@
66
"""
77

88
from enum import Enum
9-
from typing import Any, Dict, Optional, Tuple, Type, Union
9+
from typing import Any, Dict, Literal, Optional, Tuple, Type, Union
1010

11-
from pydantic.v1 import BaseModel, Field, validator
11+
from pydantic import BaseModel, Field, field_validator
1212
from redis.commands.search.field import Field as RedisField
1313
from redis.commands.search.field import GeoField as RedisGeoField
1414
from redis.commands.search.field import NumericField as RedisNumericField
1515
from redis.commands.search.field import TagField as RedisTagField
1616
from redis.commands.search.field import TextField as RedisTextField
1717
from redis.commands.search.field import VectorField as RedisVectorField
1818

19-
### Attribute Enums ###
20-
2119

2220
class VectorDistanceMetric(str, Enum):
2321
COSINE = "COSINE"
@@ -97,7 +95,7 @@ class BaseVectorFieldAttributes(BaseModel):
9795
initial_cap: Optional[int] = None
9896
"""Initial vector capacity in the index affecting memory allocation size of the index"""
9997

100-
@validator("algorithm", "datatype", "distance_metric", pre=True)
98+
@field_validator("algorithm", "datatype", "distance_metric", mode="before")
10199
@classmethod
102100
def uppercase_strings(cls, v):
103101
"""Validate that provided values are cast to uppercase"""
@@ -119,9 +117,7 @@ def field_data(self) -> Dict[str, Any]:
119117
class FlatVectorFieldAttributes(BaseVectorFieldAttributes):
120118
"""FLAT vector field attributes"""
121119

122-
algorithm: VectorIndexAlgorithm = Field(
123-
default=VectorIndexAlgorithm.FLAT, const=True
124-
)
120+
algorithm: Literal[VectorIndexAlgorithm.FLAT] = VectorIndexAlgorithm.FLAT
125121
"""The indexing algorithm for the vector field"""
126122
block_size: Optional[int] = None
127123
"""Block size to hold amount of vectors in a contiguous array. This is useful when the index is dynamic with respect to addition and deletion"""
@@ -130,9 +126,7 @@ class FlatVectorFieldAttributes(BaseVectorFieldAttributes):
130126
class HNSWVectorFieldAttributes(BaseVectorFieldAttributes):
131127
"""HNSW vector field attributes"""
132128

133-
algorithm: VectorIndexAlgorithm = Field(
134-
default=VectorIndexAlgorithm.HNSW, const=True
135-
)
129+
algorithm: Literal[VectorIndexAlgorithm.HNSW] = VectorIndexAlgorithm.HNSW
136130
"""The indexing algorithm for the vector field"""
137131
m: int = Field(default=16)
138132
"""Number of max outgoing edges for each graph node in each layer"""
@@ -171,7 +165,7 @@ def as_redis_field(self) -> RedisField:
171165
class TextField(BaseField):
172166
"""Text field supporting a full text search index"""
173167

174-
type: str = Field(default="text", const=True)
168+
type: Literal["text"] = "text"
175169
attrs: TextFieldAttributes = Field(default_factory=TextFieldAttributes)
176170

177171
def as_redis_field(self) -> RedisField:
@@ -189,7 +183,7 @@ def as_redis_field(self) -> RedisField:
189183
class TagField(BaseField):
190184
"""Tag field for simple boolean-style filtering"""
191185

192-
type: str = Field(default="tag", const=True)
186+
type: Literal["tag"] = "tag"
193187
attrs: TagFieldAttributes = Field(default_factory=TagFieldAttributes)
194188

195189
def as_redis_field(self) -> RedisField:
@@ -206,7 +200,7 @@ def as_redis_field(self) -> RedisField:
206200
class NumericField(BaseField):
207201
"""Numeric field for numeric range filtering"""
208202

209-
type: str = Field(default="numeric", const=True)
203+
type: Literal["numeric"] = "numeric"
210204
attrs: NumericFieldAttributes = Field(default_factory=NumericFieldAttributes)
211205

212206
def as_redis_field(self) -> RedisField:
@@ -221,7 +215,7 @@ def as_redis_field(self) -> RedisField:
221215
class GeoField(BaseField):
222216
"""Geo field with a geo-spatial index for location based search"""
223217

224-
type: str = Field(default="geo", const=True)
218+
type: Literal["geo"] = "geo"
225219
attrs: GeoFieldAttributes = Field(default_factory=GeoFieldAttributes)
226220

227221
def as_redis_field(self) -> RedisField:
@@ -236,7 +230,7 @@ def as_redis_field(self) -> RedisField:
236230
class FlatVectorField(BaseField):
237231
"Vector field with a FLAT index (brute force nearest neighbors search)"
238232

239-
type: str = Field(default="vector", const=True)
233+
type: Literal["vector"] = "vector"
240234
attrs: FlatVectorFieldAttributes
241235

242236
def as_redis_field(self) -> RedisField:
@@ -251,7 +245,7 @@ def as_redis_field(self) -> RedisField:
251245
class HNSWVectorField(BaseField):
252246
"""Vector field with an HNSW index (approximate nearest neighbors search)"""
253247

254-
type: str = Field(default="vector", const=True)
248+
type: Literal["vector"] = "vector"
255249
attrs: HNSWVectorFieldAttributes
256250

257251
def as_redis_field(self) -> RedisField:
@@ -269,20 +263,21 @@ def as_redis_field(self) -> RedisField:
269263
return RedisVectorField(name, self.attrs.algorithm, field_data, as_name=as_name)
270264

271265

272-
class FieldFactory:
273-
"""Factory class to create fields from client data and kwargs."""
266+
FIELD_TYPE_MAP = {
267+
"tag": TagField,
268+
"text": TextField,
269+
"numeric": NumericField,
270+
"geo": GeoField,
271+
}
274272

275-
FIELD_TYPE_MAP = {
276-
"tag": TagField,
277-
"text": TextField,
278-
"numeric": NumericField,
279-
"geo": GeoField,
280-
}
273+
VECTOR_FIELD_TYPE_MAP = {
274+
"flat": FlatVectorField,
275+
"hnsw": HNSWVectorField,
276+
}
281277

282-
VECTOR_FIELD_TYPE_MAP = {
283-
"flat": FlatVectorField,
284-
"hnsw": HNSWVectorField,
285-
}
278+
279+
class FieldFactory:
280+
"""Factory class to create fields from client data and kwargs."""
286281

287282
@classmethod
288283
def pick_vector_field_type(cls, attrs: Dict[str, Any]) -> Type[BaseField]:
@@ -294,10 +289,10 @@ def pick_vector_field_type(cls, attrs: Dict[str, Any]) -> Type[BaseField]:
294289
raise ValueError("Must provide dims param for the vector field.")
295290

296291
algorithm = attrs["algorithm"].lower()
297-
if algorithm not in cls.VECTOR_FIELD_TYPE_MAP:
292+
if algorithm not in VECTOR_FIELD_TYPE_MAP:
298293
raise ValueError(f"Unknown vector field algorithm: {algorithm}")
299294

300-
return cls.VECTOR_FIELD_TYPE_MAP[algorithm] # type: ignore
295+
return VECTOR_FIELD_TYPE_MAP[algorithm] # type: ignore
301296

302297
@classmethod
303298
def create_field(
@@ -312,8 +307,14 @@ def create_field(
312307
if type == "vector":
313308
field_class = cls.pick_vector_field_type(attrs)
314309
else:
315-
if type not in cls.FIELD_TYPE_MAP:
310+
if type not in FIELD_TYPE_MAP:
316311
raise ValueError(f"Unknown field type: {type}")
317-
field_class = cls.FIELD_TYPE_MAP[type] # type: ignore
312+
field_class = FIELD_TYPE_MAP[type] # type: ignore
318313

319-
return field_class(name=name, path=path, attrs=attrs) # type: ignore
314+
return field_class.model_validate(
315+
{
316+
"name": name,
317+
"path": path,
318+
"attrs": attrs,
319+
}
320+
)

0 commit comments

Comments
 (0)