11import os
22from typing import Any , Callable , Dict , List , Optional
33
4- from pydantic . v1 import PrivateAttr
4+ from pydantic import PrivateAttr
55from tenacity import retry , stop_after_attempt , wait_random_exponential
66from tenacity .retry import retry_if_not_exception_type
77
@@ -51,6 +51,7 @@ def __init__(
5151 model : str = "embed-english-v3.0" ,
5252 api_config : Optional [Dict ] = None ,
5353 dtype : str = "float32" ,
54+ ** kwargs ,
5455 ):
5556 """Initialize the Cohere vectorizer.
5657
@@ -69,24 +70,29 @@ def __init__(
6970 ValueError: If the API key is not provided.
7071 ValueError: If an invalid dtype is provided.
7172 """
72- self ._initialize_client (api_config )
73- super ().__init__ (model = model , dims = self ._set_model_dims (model ), dtype = dtype )
73+ super ().__init__ (model = model , dtype = dtype )
74+ # Init client
75+ self ._initialize_client (api_config , ** kwargs )
76+ # Set model dimensions after init
77+ self .dims = self ._set_model_dims ()
7478
75- def _initialize_client (self , api_config : Optional [Dict ]):
79+ def _initialize_client (self , api_config : Optional [Dict ], ** kwargs ):
7680 """
7781 Setup the Cohere clients using the provided API key or an
7882 environment variable.
7983 """
84+ if api_config is None :
85+ api_config = {}
86+
8087 # Dynamic import of the cohere module
8188 try :
82- from cohere import AsyncClient , Client
89+ from cohere import Client
8390 except ImportError :
8491 raise ImportError (
8592 "Cohere vectorizer requires the cohere library. \
8693 Please install with `pip install cohere`"
8794 )
8895
89- # Fetch the API key from api_config or environment variable
9096 api_key = (
9197 api_config .get ("api_key" ) if api_config else os .getenv ("COHERE_API_KEY" )
9298 )
@@ -95,15 +101,11 @@ def _initialize_client(self, api_config: Optional[Dict]):
95101 "Cohere API key is required. "
96102 "Provide it in api_config or set the COHERE_API_KEY environment variable."
97103 )
98- self ._client = Client (api_key = api_key , client_name = "redisvl" )
104+ self ._client = Client (api_key = api_key , client_name = "redisvl" , ** kwargs )
99105
100- def _set_model_dims (self , model ) -> int :
106+ def _set_model_dims (self ) -> int :
101107 try :
102- embedding = self ._client .embed (
103- texts = ["dimension test" ],
104- model = model ,
105- input_type = "search_document" ,
106- ).embeddings [0 ]
108+ embedding = self .embed ("dimension check" , input_type = "search_document" )
107109 except (KeyError , IndexError ) as ke :
108110 raise ValueError (f"Unexpected response from the Cohere API: { str (ke )} " )
109111 except Exception as e : # pylint: disable=broad-except
0 commit comments