Skip to content

Commit 2721e43

Browse files
refactor vectorizers to work with new pydantic base model
1 parent 86b145c commit 2721e43

File tree

15 files changed

+151
-106
lines changed

15 files changed

+151
-106
lines changed

redisvl/extensions/llmcache/semantic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def __init__(
125125

126126
# Create semantic cache schema and index
127127
schema = SemanticCacheIndexSchema.from_params(
128-
name, prefix, vectorizer.dims, vectorizer.dtype
128+
name, prefix, vectorizer.dims, vectorizer.dtype # type: ignore
129129
)
130130
schema = self._modify_schema(schema, filterable_fields)
131131
self._index = SearchIndex(schema=schema)

redisvl/extensions/router.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from pydantic import root_validator
2+
3+
4+
class SemanticRouter(BaseModel):
5+
# existing fields...
6+
vectorizer: Optional[HFTextVectorizer] = None
7+
dtype: str = "float32"
8+
9+
@root_validator
10+
def check_vectorizer_dtype(cls, values):
11+
router_dtype = values.get("dtype")
12+
vectorizer = values.get("vectorizer")
13+
if vectorizer is not None and vectorizer.dtype != router_dtype:
14+
raise ValueError(
15+
f"Mismatched vectorizer dtype: {vectorizer.dtype} does not match router dtype: {router_dtype}"
16+
)
17+
return values

redisvl/extensions/router/semantic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def _initialize_index(
108108
):
109109
"""Initialize the search index and handle Redis connection."""
110110
schema = SemanticRouterIndexSchema.from_params(
111-
self.name, self.vectorizer.dims, self.vectorizer.dtype
111+
self.name, self.vectorizer.dims, self.vectorizer.dtype # type: ignore
112112
)
113113
self._index = SearchIndex(schema=schema)
114114

redisvl/extensions/session_manager/semantic_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494
self.set_distance_threshold(distance_threshold)
9595

9696
schema = SemanticSessionIndexSchema.from_params(
97-
name, prefix, self._vectorizer.dims, vectorizer.dtype
97+
name, prefix, vectorizer.dims, vectorizer.dtype # type: ignore
9898
)
9999

100100
self._index = SearchIndex(schema=schema)

redisvl/utils/vectorize/__init__.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,18 @@ def vectorizer_from_dict(vectorizer: dict) -> BaseVectorizer:
2727
vectorizer_type = Vectorizers(vectorizer["type"])
2828
model = vectorizer["model"]
2929
if vectorizer_type == Vectorizers.cohere:
30-
return CohereTextVectorizer(model)
30+
return CohereTextVectorizer(model=model)
3131
elif vectorizer_type == Vectorizers.openai:
32-
return OpenAITextVectorizer(model)
32+
return OpenAITextVectorizer(model=model)
3333
elif vectorizer_type == Vectorizers.azure_openai:
34-
return AzureOpenAITextVectorizer(model)
34+
return AzureOpenAITextVectorizer(model=model)
3535
elif vectorizer_type == Vectorizers.hf:
36-
return HFTextVectorizer(model)
36+
return HFTextVectorizer(model=model)
3737
elif vectorizer_type == Vectorizers.mistral:
38-
return MistralAITextVectorizer(model)
38+
return MistralAITextVectorizer(model=model)
3939
elif vectorizer_type == Vectorizers.vertexai:
40-
return VertexAITextVectorizer(model)
40+
return VertexAITextVectorizer(model=model)
4141
elif vectorizer_type == Vectorizers.voyageai:
42-
return VoyageAITextVectorizer(model)
42+
return VoyageAITextVectorizer(model=model)
43+
else:
44+
raise ValueError(f"Unsupported vectorizer type: {vectorizer_type}")

redisvl/utils/vectorize/base.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from enum import Enum
33
from typing import Callable, List, Optional
44

5-
from pydantic.v1 import BaseModel, Field, validator
5+
from pydantic import BaseModel, Field, field_validator
66

77
from redisvl.redis.utils import array_to_buffer
88
from redisvl.schema.fields import VectorDataType
@@ -19,16 +19,19 @@ class Vectorizers(Enum):
1919

2020

2121
class BaseVectorizer(BaseModel, ABC):
22+
"""Base vectorizer interface."""
23+
2224
model: str
23-
dims: int
24-
dtype: str = Field(default="float32")
25+
dtype: str = "float32"
26+
dims: Optional[int] = None
2527

2628
@property
2729
def type(self) -> str:
2830
return "base"
2931

30-
@validator("dtype")
31-
def check_dtype(dtype):
32+
@field_validator("dtype")
33+
@classmethod
34+
def check_dtype(cls, dtype):
3235
try:
3336
VectorDataType(dtype.upper())
3437
except ValueError:
@@ -37,7 +40,7 @@ def check_dtype(dtype):
3740
)
3841
return dtype
3942

40-
@validator("dims")
43+
@field_validator("dims")
4144
@classmethod
4245
def check_dims(cls, value):
4346
"""Ensures the dims are a positive integer."""

redisvl/utils/vectorize/text/azureopenai.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from typing import Any, Callable, Dict, List, Optional
33

4-
from pydantic.v1 import PrivateAttr
4+
from pydantic import PrivateAttr
55
from tenacity import retry, stop_after_attempt, wait_random_exponential
66
from tenacity.retry import retry_if_not_exception_type
77

@@ -56,6 +56,7 @@ def __init__(
5656
model: str = "text-embedding-ada-002",
5757
api_config: Optional[Dict] = None,
5858
dtype: str = "float32",
59+
**kwargs,
5960
):
6061
"""Initialize the AzureOpenAI vectorizer.
6162
@@ -75,10 +76,13 @@ def __init__(
7576
ValueError: If the AzureOpenAI API key, version, or endpoint are not provided.
7677
ValueError: If an invalid dtype is provided.
7778
"""
78-
self._initialize_clients(api_config)
79-
super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype)
79+
super().__init__(model=model, dtype=dtype)
80+
# Init client
81+
self._initialize_clients(api_config, **kwargs)
82+
# Set model dimensions
83+
self.dims = self._set_model_dims()
8084

81-
def _initialize_clients(self, api_config: Optional[Dict]):
85+
def _initialize_clients(self, api_config: Optional[Dict], **kwargs):
8286
"""
8387
Setup the OpenAI clients using the provided API key or an
8488
environment variable.
@@ -140,21 +144,19 @@ def _initialize_clients(self, api_config: Optional[Dict]):
140144
api_version=api_version,
141145
azure_endpoint=azure_endpoint,
142146
**api_config,
147+
**kwargs,
143148
)
144149
self._aclient = AsyncAzureOpenAI(
145150
api_key=api_key,
146151
api_version=api_version,
147152
azure_endpoint=azure_endpoint,
148153
**api_config,
154+
**kwargs,
149155
)
150156

151-
def _set_model_dims(self, model) -> int:
157+
def _set_model_dims(self) -> int:
152158
try:
153-
embedding = (
154-
self._client.embeddings.create(input=["dimension test"], model=model)
155-
.data[0]
156-
.embedding
157-
)
159+
embedding = self.embed("dimension check")
158160
except (KeyError, IndexError) as ke:
159161
raise ValueError(f"Unexpected response from the AzureOpenAI API: {str(ke)}")
160162
except Exception as e: # pylint: disable=broad-except

redisvl/utils/vectorize/text/bedrock.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
from typing import Any, Callable, Dict, List, Optional
44

5-
from pydantic.v1 import PrivateAttr
5+
from pydantic import PrivateAttr
66
from tenacity import retry, stop_after_attempt, wait_random_exponential
77
from tenacity.retry import retry_if_not_exception_type
88

@@ -50,6 +50,7 @@ def __init__(
5050
model: str = "amazon.titan-embed-text-v2:0",
5151
api_config: Optional[Dict[str, str]] = None,
5252
dtype: str = "float32",
53+
**kwargs,
5354
) -> None:
5455
"""Initialize the AWS Bedrock Vectorizer.
5556
@@ -67,6 +68,17 @@ def __init__(
6768
ImportError: If boto3 is not installed.
6869
ValueError: If an invalid dtype is provided.
6970
"""
71+
super().__init__(model=model, dtype=dtype)
72+
# Init client
73+
self._initialize_client(api_config, **kwargs)
74+
# Set model dimensions after init
75+
self.dims = self._set_model_dims()
76+
77+
def _initialize_client(self, api_config: Optional[Dict], **kwargs):
78+
"""
79+
Setup the Bedrock client using the provided API keys or
80+
environment variables.
81+
"""
7082
try:
7183
import boto3 # type: ignore
7284
except ImportError:
@@ -97,21 +109,18 @@ def __init__(
97109
aws_access_key_id=aws_access_key_id,
98110
aws_secret_access_key=aws_secret_access_key,
99111
region_name=aws_region,
112+
**kwargs,
100113
)
101114

102-
super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype)
103-
104-
def _set_model_dims(self, model: str) -> int:
105-
"""Initialize model and determine embedding dimensions."""
115+
def _set_model_dims(self) -> int:
106116
try:
107-
response = self._client.invoke_model(
108-
modelId=model, body=json.dumps({"inputText": "dimension test"})
109-
)
110-
response_body = json.loads(response["body"].read())
111-
embedding = response_body["embedding"]
112-
return len(embedding)
113-
except Exception as e:
114-
raise ValueError(f"Error initializing Bedrock model: {str(e)}")
117+
embedding = self.embed("dimension check")
118+
except (KeyError, IndexError) as ke:
119+
raise ValueError(f"Unexpected response from the OpenAI API: {str(ke)}")
120+
except Exception as e: # pylint: disable=broad-except
121+
# fall back (TODO get more specific)
122+
raise ValueError(f"Error setting embedding model dimensions: {str(e)}")
123+
return len(embedding)
115124

116125
@retry(
117126
wait=wait_random_exponential(min=1, max=60),

redisvl/utils/vectorize/text/cohere.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from typing import Any, Callable, Dict, List, Optional
33

4-
from pydantic.v1 import PrivateAttr
4+
from pydantic import PrivateAttr
55
from tenacity import retry, stop_after_attempt, wait_random_exponential
66
from tenacity.retry import retry_if_not_exception_type
77

@@ -51,6 +51,7 @@ def __init__(
5151
model: str = "embed-english-v3.0",
5252
api_config: Optional[Dict] = None,
5353
dtype: str = "float32",
54+
**kwargs,
5455
):
5556
"""Initialize the Cohere vectorizer.
5657
@@ -69,24 +70,29 @@ def __init__(
6970
ValueError: If the API key is not provided.
7071
ValueError: If an invalid dtype is provided.
7172
"""
72-
self._initialize_client(api_config)
73-
super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype)
73+
super().__init__(model=model, dtype=dtype)
74+
# Init client
75+
self._initialize_client(api_config, **kwargs)
76+
# Set model dimensions after init
77+
self.dims = self._set_model_dims()
7478

75-
def _initialize_client(self, api_config: Optional[Dict]):
79+
def _initialize_client(self, api_config: Optional[Dict], **kwargs):
7680
"""
7781
Setup the Cohere clients using the provided API key or an
7882
environment variable.
7983
"""
84+
if api_config is None:
85+
api_config = {}
86+
8087
# Dynamic import of the cohere module
8188
try:
82-
from cohere import AsyncClient, Client
89+
from cohere import Client
8390
except ImportError:
8491
raise ImportError(
8592
"Cohere vectorizer requires the cohere library. \
8693
Please install with `pip install cohere`"
8794
)
8895

89-
# Fetch the API key from api_config or environment variable
9096
api_key = (
9197
api_config.get("api_key") if api_config else os.getenv("COHERE_API_KEY")
9298
)
@@ -95,15 +101,11 @@ def _initialize_client(self, api_config: Optional[Dict]):
95101
"Cohere API key is required. "
96102
"Provide it in api_config or set the COHERE_API_KEY environment variable."
97103
)
98-
self._client = Client(api_key=api_key, client_name="redisvl")
104+
self._client = Client(api_key=api_key, client_name="redisvl", **kwargs)
99105

100-
def _set_model_dims(self, model) -> int:
106+
def _set_model_dims(self) -> int:
101107
try:
102-
embedding = self._client.embed(
103-
texts=["dimension test"],
104-
model=model,
105-
input_type="search_document",
106-
).embeddings[0]
108+
embedding = self.embed("dimension check", input_type="search_document")
107109
except (KeyError, IndexError) as ke:
108110
raise ValueError(f"Unexpected response from the Cohere API: {str(ke)}")
109111
except Exception as e: # pylint: disable=broad-except

redisvl/utils/vectorize/text/custom.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, Callable, List, Optional
22

3-
from pydantic.v1 import PrivateAttr
3+
from pydantic import PrivateAttr
44

55
from redisvl.utils.vectorize.base import BaseVectorizer
66

@@ -113,17 +113,16 @@ def __init__(
113113
Raises:
114114
ValueError: if embedding validation fails.
115115
"""
116+
super().__init__(model=self.type, dtype=dtype)
117+
116118
# Store user-provided callables
117119
self._embed = embed
118120
self._embed_many = embed_many
119121
self._aembed = aembed
120122
self._aembed_many = aembed_many
121123

122-
# Manually validate sync methods to discover dimension
123-
dims = self._validate_sync_callables()
124-
125-
# Initialize the base class now that we know the dimension
126-
super().__init__(model=self.type, dims=dims, dtype=dtype)
124+
# Set dims
125+
self.dims = self._validate_sync_callables()
127126

128127
@property
129128
def type(self) -> str:

0 commit comments

Comments
 (0)