Skip to content

Commit a0e6fa5

Browse files
refactor vectorizers to work with new pydantic base model
1 parent 3ab7e19 commit a0e6fa5

File tree

15 files changed

+151
-106
lines changed

15 files changed

+151
-106
lines changed

redisvl/extensions/llmcache/semantic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def __init__(
125125

126126
# Create semantic cache schema and index
127127
schema = SemanticCacheIndexSchema.from_params(
128-
name, prefix, vectorizer.dims, vectorizer.dtype
128+
name, prefix, vectorizer.dims, vectorizer.dtype # type: ignore
129129
)
130130
schema = self._modify_schema(schema, filterable_fields)
131131
self._index = SearchIndex(schema=schema)

redisvl/extensions/router.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from pydantic import root_validator
2+
3+
4+
class SemanticRouter(BaseModel):
5+
# existing fields...
6+
vectorizer: Optional[HFTextVectorizer] = None
7+
dtype: str = "float32"
8+
9+
@root_validator
10+
def check_vectorizer_dtype(cls, values):
11+
router_dtype = values.get("dtype")
12+
vectorizer = values.get("vectorizer")
13+
if vectorizer is not None and vectorizer.dtype != router_dtype:
14+
raise ValueError(
15+
f"Mismatched vectorizer dtype: {vectorizer.dtype} does not match router dtype: {router_dtype}"
16+
)
17+
return values

redisvl/extensions/router/semantic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def _initialize_index(
108108
):
109109
"""Initialize the search index and handle Redis connection."""
110110
schema = SemanticRouterIndexSchema.from_params(
111-
self.name, self.vectorizer.dims, self.vectorizer.dtype
111+
self.name, self.vectorizer.dims, self.vectorizer.dtype # type: ignore
112112
)
113113
self._index = SearchIndex(schema=schema)
114114

redisvl/extensions/session_manager/semantic_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494
self.set_distance_threshold(distance_threshold)
9595

9696
schema = SemanticSessionIndexSchema.from_params(
97-
name, prefix, self._vectorizer.dims, vectorizer.dtype
97+
name, prefix, vectorizer.dims, vectorizer.dtype # type: ignore
9898
)
9999

100100
self._index = SearchIndex(schema=schema)

redisvl/utils/vectorize/__init__.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,18 @@ def vectorizer_from_dict(vectorizer: dict) -> BaseVectorizer:
2727
vectorizer_type = Vectorizers(vectorizer["type"])
2828
model = vectorizer["model"]
2929
if vectorizer_type == Vectorizers.cohere:
30-
return CohereTextVectorizer(model)
30+
return CohereTextVectorizer(model=model)
3131
elif vectorizer_type == Vectorizers.openai:
32-
return OpenAITextVectorizer(model)
32+
return OpenAITextVectorizer(model=model)
3333
elif vectorizer_type == Vectorizers.azure_openai:
34-
return AzureOpenAITextVectorizer(model)
34+
return AzureOpenAITextVectorizer(model=model)
3535
elif vectorizer_type == Vectorizers.hf:
36-
return HFTextVectorizer(model)
36+
return HFTextVectorizer(model=model)
3737
elif vectorizer_type == Vectorizers.mistral:
38-
return MistralAITextVectorizer(model)
38+
return MistralAITextVectorizer(model=model)
3939
elif vectorizer_type == Vectorizers.vertexai:
40-
return VertexAITextVectorizer(model)
40+
return VertexAITextVectorizer(model=model)
4141
elif vectorizer_type == Vectorizers.voyageai:
42-
return VoyageAITextVectorizer(model)
42+
return VoyageAITextVectorizer(model=model)
43+
else:
44+
raise ValueError(f"Unsupported vectorizer type: {vectorizer_type}")

redisvl/utils/vectorize/base.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from enum import Enum
33
from typing import Callable, List, Optional
44

5-
from pydantic.v1 import BaseModel, Field, validator
5+
from pydantic import BaseModel, Field, field_validator
66

77
from redisvl.redis.utils import array_to_buffer
88
from redisvl.schema.fields import VectorDataType
@@ -19,16 +19,19 @@ class Vectorizers(Enum):
1919

2020

2121
class BaseVectorizer(BaseModel, ABC):
22+
"""Base vectorizer interface."""
23+
2224
model: str
23-
dims: int
24-
dtype: str = Field(default="float32")
25+
dtype: str = "float32"
26+
dims: Optional[int] = None
2527

2628
@property
2729
def type(self) -> str:
2830
return "base"
2931

30-
@validator("dtype")
31-
def check_dtype(dtype):
32+
@field_validator("dtype")
33+
@classmethod
34+
def check_dtype(cls, dtype):
3235
try:
3336
VectorDataType(dtype.upper())
3437
except ValueError:
@@ -37,7 +40,7 @@ def check_dtype(dtype):
3740
)
3841
return dtype
3942

40-
@validator("dims")
43+
@field_validator("dims")
4144
@classmethod
4245
def check_dims(cls, value):
4346
"""Ensures the dims are a positive integer."""

redisvl/utils/vectorize/text/azureopenai.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from typing import Any, Callable, Dict, List, Optional
33

4-
from pydantic.v1 import PrivateAttr
4+
from pydantic import PrivateAttr
55
from tenacity import retry, stop_after_attempt, wait_random_exponential
66
from tenacity.retry import retry_if_not_exception_type
77

@@ -57,6 +57,7 @@ def __init__(
5757
model: str = "text-embedding-ada-002",
5858
api_config: Optional[Dict] = None,
5959
dtype: str = "float32",
60+
**kwargs,
6061
):
6162
"""Initialize the AzureOpenAI vectorizer.
6263
@@ -76,10 +77,13 @@ def __init__(
7677
ValueError: If the AzureOpenAI API key, version, or endpoint are not provided.
7778
ValueError: If an invalid dtype is provided.
7879
"""
79-
self._initialize_clients(api_config)
80-
super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype)
80+
super().__init__(model=model, dtype=dtype)
81+
# Init client
82+
self._initialize_clients(api_config, **kwargs)
83+
# Set model dimensions
84+
self.dims = self._set_model_dims()
8185

82-
def _initialize_clients(self, api_config: Optional[Dict]):
86+
def _initialize_clients(self, api_config: Optional[Dict], **kwargs):
8387
"""
8488
Setup the OpenAI clients using the provided API key or an
8589
environment variable.
@@ -141,21 +145,19 @@ def _initialize_clients(self, api_config: Optional[Dict]):
141145
api_version=api_version,
142146
azure_endpoint=azure_endpoint,
143147
**api_config,
148+
**kwargs,
144149
)
145150
self._aclient = AsyncAzureOpenAI(
146151
api_key=api_key,
147152
api_version=api_version,
148153
azure_endpoint=azure_endpoint,
149154
**api_config,
155+
**kwargs,
150156
)
151157

152-
def _set_model_dims(self, model) -> int:
158+
def _set_model_dims(self) -> int:
153159
try:
154-
embedding = (
155-
self._client.embeddings.create(input=["dimension test"], model=model)
156-
.data[0]
157-
.embedding
158-
)
160+
embedding = self.embed("dimension check")
159161
except (KeyError, IndexError) as ke:
160162
raise ValueError(f"Unexpected response from the AzureOpenAI API: {str(ke)}")
161163
except Exception as e: # pylint: disable=broad-except

redisvl/utils/vectorize/text/bedrock.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
from typing import Any, Callable, Dict, List, Optional
44

5-
from pydantic.v1 import PrivateAttr
5+
from pydantic import PrivateAttr
66
from tenacity import retry, stop_after_attempt, wait_random_exponential
77
from tenacity.retry import retry_if_not_exception_type
88

@@ -51,6 +51,7 @@ def __init__(
5151
model: str = "amazon.titan-embed-text-v2:0",
5252
api_config: Optional[Dict[str, str]] = None,
5353
dtype: str = "float32",
54+
**kwargs,
5455
) -> None:
5556
"""Initialize the AWS Bedrock Vectorizer.
5657
@@ -68,6 +69,17 @@ def __init__(
6869
ImportError: If boto3 is not installed.
6970
ValueError: If an invalid dtype is provided.
7071
"""
72+
super().__init__(model=model, dtype=dtype)
73+
# Init client
74+
self._initialize_client(api_config, **kwargs)
75+
# Set model dimensions after init
76+
self.dims = self._set_model_dims()
77+
78+
def _initialize_client(self, api_config: Optional[Dict], **kwargs):
79+
"""
80+
Setup the Bedrock client using the provided API keys or
81+
environment variables.
82+
"""
7183
try:
7284
import boto3 # type: ignore
7385
except ImportError:
@@ -98,21 +110,18 @@ def __init__(
98110
aws_access_key_id=aws_access_key_id,
99111
aws_secret_access_key=aws_secret_access_key,
100112
region_name=aws_region,
113+
**kwargs,
101114
)
102115

103-
super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype)
104-
105-
def _set_model_dims(self, model: str) -> int:
106-
"""Initialize model and determine embedding dimensions."""
116+
def _set_model_dims(self) -> int:
107117
try:
108-
response = self._client.invoke_model(
109-
modelId=model, body=json.dumps({"inputText": "dimension test"})
110-
)
111-
response_body = json.loads(response["body"].read())
112-
embedding = response_body["embedding"]
113-
return len(embedding)
114-
except Exception as e:
115-
raise ValueError(f"Error initializing Bedrock model: {str(e)}")
118+
embedding = self.embed("dimension check")
119+
except (KeyError, IndexError) as ke:
120+
raise ValueError(f"Unexpected response from the OpenAI API: {str(ke)}")
121+
except Exception as e: # pylint: disable=broad-except
122+
# fall back (TODO get more specific)
123+
raise ValueError(f"Error setting embedding model dimensions: {str(e)}")
124+
return len(embedding)
116125

117126
@retry(
118127
wait=wait_random_exponential(min=1, max=60),

redisvl/utils/vectorize/text/cohere.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from typing import Any, Callable, Dict, List, Optional
33

4-
from pydantic.v1 import PrivateAttr
4+
from pydantic import PrivateAttr
55
from tenacity import retry, stop_after_attempt, wait_random_exponential
66
from tenacity.retry import retry_if_not_exception_type
77

@@ -52,6 +52,7 @@ def __init__(
5252
model: str = "embed-english-v3.0",
5353
api_config: Optional[Dict] = None,
5454
dtype: str = "float32",
55+
**kwargs,
5556
):
5657
"""Initialize the Cohere vectorizer.
5758
@@ -70,24 +71,29 @@ def __init__(
7071
ValueError: If the API key is not provided.
7172
ValueError: If an invalid dtype is provided.
7273
"""
73-
self._initialize_client(api_config)
74-
super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype)
74+
super().__init__(model=model, dtype=dtype)
75+
# Init client
76+
self._initialize_client(api_config, **kwargs)
77+
# Set model dimensions after init
78+
self.dims = self._set_model_dims()
7579

76-
def _initialize_client(self, api_config: Optional[Dict]):
80+
def _initialize_client(self, api_config: Optional[Dict], **kwargs):
7781
"""
7882
Setup the Cohere clients using the provided API key or an
7983
environment variable.
8084
"""
85+
if api_config is None:
86+
api_config = {}
87+
8188
# Dynamic import of the cohere module
8289
try:
83-
from cohere import AsyncClient, Client
90+
from cohere import Client
8491
except ImportError:
8592
raise ImportError(
8693
"Cohere vectorizer requires the cohere library. \
8794
Please install with `pip install cohere`"
8895
)
8996

90-
# Fetch the API key from api_config or environment variable
9197
api_key = (
9298
api_config.get("api_key") if api_config else os.getenv("COHERE_API_KEY")
9399
)
@@ -96,15 +102,11 @@ def _initialize_client(self, api_config: Optional[Dict]):
96102
"Cohere API key is required. "
97103
"Provide it in api_config or set the COHERE_API_KEY environment variable."
98104
)
99-
self._client = Client(api_key=api_key, client_name="redisvl")
105+
self._client = Client(api_key=api_key, client_name="redisvl", **kwargs)
100106

101-
def _set_model_dims(self, model) -> int:
107+
def _set_model_dims(self) -> int:
102108
try:
103-
embedding = self._client.embed(
104-
texts=["dimension test"],
105-
model=model,
106-
input_type="search_document",
107-
).embeddings[0]
109+
embedding = self.embed("dimension check", input_type="search_document")
108110
except (KeyError, IndexError) as ke:
109111
raise ValueError(f"Unexpected response from the Cohere API: {str(ke)}")
110112
except Exception as e: # pylint: disable=broad-except

redisvl/utils/vectorize/text/custom.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, Callable, List, Optional
22

3-
from pydantic.v1 import PrivateAttr
3+
from pydantic import PrivateAttr
44

55
from redisvl.utils.utils import deprecated_argument
66
from redisvl.utils.vectorize.base import BaseVectorizer
@@ -114,17 +114,16 @@ def __init__(
114114
Raises:
115115
ValueError: if embedding validation fails.
116116
"""
117+
super().__init__(model=self.type, dtype=dtype)
118+
117119
# Store user-provided callables
118120
self._embed = embed
119121
self._embed_many = embed_many
120122
self._aembed = aembed
121123
self._aembed_many = aembed_many
122124

123-
# Manually validate sync methods to discover dimension
124-
dims = self._validate_sync_callables()
125-
126-
# Initialize the base class now that we know the dimension
127-
super().__init__(model=self.type, dims=dims, dtype=dtype)
125+
# Set dims
126+
self.dims = self._validate_sync_callables()
128127

129128
@property
130129
def type(self) -> str:

0 commit comments

Comments
 (0)