11import os
22from typing import Any , Callable , Dict , List , Optional
33
4- from pydantic . v1 import PrivateAttr
4+ from pydantic import PrivateAttr
55from tenacity import retry , stop_after_attempt , wait_random_exponential
66from tenacity .retry import retry_if_not_exception_type
77
@@ -52,6 +52,7 @@ def __init__(
5252 model : str = "embed-english-v3.0" ,
5353 api_config : Optional [Dict ] = None ,
5454 dtype : str = "float32" ,
55+ ** kwargs ,
5556 ):
5657 """Initialize the Cohere vectorizer.
5758
@@ -70,24 +71,29 @@ def __init__(
7071 ValueError: If the API key is not provided.
7172 ValueError: If an invalid dtype is provided.
7273 """
73- self ._initialize_client (api_config )
74- super ().__init__ (model = model , dims = self ._set_model_dims (model ), dtype = dtype )
74+ super ().__init__ (model = model , dtype = dtype )
75+ # Init client
76+ self ._initialize_client (api_config , ** kwargs )
77+ # Set model dimensions after init
78+ self .dims = self ._set_model_dims ()
7579
76- def _initialize_client (self , api_config : Optional [Dict ]):
80+ def _initialize_client (self , api_config : Optional [Dict ], ** kwargs ):
7781 """
7882 Setup the Cohere clients using the provided API key or an
7983 environment variable.
8084 """
85+ if api_config is None :
86+ api_config = {}
87+
8188 # Dynamic import of the cohere module
8289 try :
83- from cohere import AsyncClient , Client
90+ from cohere import Client
8491 except ImportError :
8592 raise ImportError (
8693 "Cohere vectorizer requires the cohere library. \
8794 Please install with `pip install cohere`"
8895 )
8996
90- # Fetch the API key from api_config or environment variable
9197 api_key = (
9298 api_config .get ("api_key" ) if api_config else os .getenv ("COHERE_API_KEY" )
9399 )
@@ -96,15 +102,11 @@ def _initialize_client(self, api_config: Optional[Dict]):
96102 "Cohere API key is required. "
97103 "Provide it in api_config or set the COHERE_API_KEY environment variable."
98104 )
99- self ._client = Client (api_key = api_key , client_name = "redisvl" )
105+ self ._client = Client (api_key = api_key , client_name = "redisvl" , ** kwargs )
100106
101- def _set_model_dims (self , model ) -> int :
107+ def _set_model_dims (self ) -> int :
102108 try :
103- embedding = self ._client .embed (
104- texts = ["dimension test" ],
105- model = model ,
106- input_type = "search_document" ,
107- ).embeddings [0 ]
109+ embedding = self .embed ("dimension check" , input_type = "search_document" )
108110 except (KeyError , IndexError ) as ke :
109111 raise ValueError (f"Unexpected response from the Cohere API: { str (ke )} " )
110112 except Exception as e : # pylint: disable=broad-except
0 commit comments