|
| 1 | +import os |
| 2 | +from typing import Optional |
| 3 | + |
| 4 | +from redisvl.extensions.cache.embeddings import EmbeddingsCache |
1 | 5 | from redisvl.utils.vectorize.base import BaseVectorizer, Vectorizers
|
2 | 6 | from redisvl.utils.vectorize.text.azureopenai import AzureOpenAITextVectorizer
|
3 | 7 | from redisvl.utils.vectorize.text.bedrock import BedrockTextVectorizer
|
|
23 | 27 | ]
|
24 | 28 |
|
25 | 29 |
|
26 |
def vectorizer_from_dict(
    vectorizer: dict,
    cache: Optional[dict] = None,
    cache_folder: Optional[str] = os.getenv("SENTENCE_TRANSFORMERS_HOME"),
) -> BaseVectorizer:
    """Build a text vectorizer from a plain-dict description.

    Args:
        vectorizer: Mapping with a ``"type"`` entry (converted with
            ``Vectorizers(...)``) and a ``"model"`` entry that is passed to
            the vectorizer as its ``model`` argument.
        cache: Optional keyword arguments for ``EmbeddingsCache``. When
            non-empty, an ``EmbeddingsCache`` is constructed from them and
            passed to the vectorizer as its ``cache`` argument. ``None`` or
            an empty dict means no cache is attached.
        cache_folder: Defaults to the ``SENTENCE_TRANSFORMERS_HOME``
            environment variable, read once when this module is imported.
            NOTE(review): this parameter is not referenced anywhere in the
            function body -- confirm whether it was meant to be forwarded
            to ``HFTextVectorizer``.

    Returns:
        An instance of the vectorizer class matching ``vectorizer["type"]``.

    Raises:
        KeyError: If ``vectorizer`` lacks the ``"type"`` or ``"model"`` key.
        ValueError: If the vectorizer type is not one of the types handled
            below.
    """
    vectorizer_type = Vectorizers(vectorizer["type"])
    model = vectorizer["model"]

    # Keyword arguments shared by every vectorizer constructor.
    args = {"model": model}
    if cache:
        # Only build a cache when the caller supplied settings for it, so
        # the default call path creates no EmbeddingsCache at all.
        args["cache"] = EmbeddingsCache(**cache)

    if vectorizer_type == Vectorizers.cohere:
        return CohereTextVectorizer(**args)
    elif vectorizer_type == Vectorizers.openai:
        return OpenAITextVectorizer(**args)
    elif vectorizer_type == Vectorizers.azure_openai:
        return AzureOpenAITextVectorizer(**args)
    elif vectorizer_type == Vectorizers.hf:
        return HFTextVectorizer(**args)
    elif vectorizer_type == Vectorizers.mistral:
        return MistralAITextVectorizer(**args)
    elif vectorizer_type == Vectorizers.vertexai:
        return VertexAITextVectorizer(**args)
    elif vectorizer_type == Vectorizers.voyageai:
        return VoyageAITextVectorizer(**args)
    else:
        raise ValueError(f"Unsupported vectorizer type: {vectorizer_type}")
|
0 commit comments