|
24 | 24 | _chat_providers,
|
25 | 25 | _discover_langchain_community_chat_providers,
|
26 | 26 | _discover_langchain_community_llm_providers,
|
| 27 | + _discover_langchain_partner_chat_providers, |
27 | 28 | _llm_providers,
|
28 |
| - _parse_version, |
| 29 | + get_chat_provider_names, |
29 | 30 | get_community_chat_provider_names,
|
30 | 31 | get_llm_provider_names,
|
31 | 32 | )
|
|
133 | 134 | "yi",
|
134 | 135 | "you",
|
135 | 136 | ]
|
136 |
| -_CHAT_PROVIDERS_NAMES = [ |
| 137 | +_COMMUNITY_CHAT_PROVIDERS_NAMES = [ |
137 | 138 | "azure_openai",
|
138 | 139 | "bedrock",
|
139 | 140 | "anthropic",
|
|
195 | 196 | "llamacpp",
|
196 | 197 | "yi",
|
197 | 198 | ]
|
| 199 | + |
| 200 | +_PARTNER_CHAT_PROVIDERS_NAMES = { |
| 201 | + "anthropic", |
| 202 | + "azure_openai", |
| 203 | + "bedrock", |
| 204 | + "bedrock_converse", |
| 205 | + "cohere", |
| 206 | + "deepseek", |
| 207 | + "fireworks", |
| 208 | + "google_anthropic_vertex", |
| 209 | + "google_genai", |
| 210 | + "google_vertexai", |
| 211 | + "groq", |
| 212 | + "huggingface", |
| 213 | + "mistralai", |
| 214 | + "nim", |
| 215 | + "ollama", |
| 216 | + "openai", |
| 217 | + "together", |
| 218 | +} |
198 | 219 | # at some point we might care about certain providers
|
199 | 220 | CRITICAL_LLM_PROVIDERS = [
|
200 | 221 | "openai",
|
@@ -316,18 +337,40 @@ def test_provider_imports():
|
316 | 337 |
|
317 | 338 |
|
318 | 339 | def test_discover_langchain_community_chat_providers():
|
319 |
| - """Test that the function correctly discovers LangChain chat providers.""" |
| 340 | + """Test that the function correctly discovers LangChain community chat providers.""" |
| 341 | + |
320 | 342 | providers = _discover_langchain_community_chat_providers()
|
321 | 343 | chat_provider_names = get_community_chat_provider_names()
|
322 | 344 | assert set(chat_provider_names) == set(
|
323 | 345 | providers.keys()
|
324 | 346 | ), "it seems that we are registering a provider that is not in the LC community chat provider"
|
325 |
| - assert _CHAT_PROVIDERS_NAMES == list(providers.keys()), ( |
| 347 | + assert _COMMUNITY_CHAT_PROVIDERS_NAMES == list(providers.keys()), ( |
326 | 348 | "LangChain chat community providers may have changed. "
|
327 | 349 | "please investigate and update the test if necessary."
|
328 | 350 | )
|
329 | 351 |
|
330 | 352 |
|
| 353 | +def test_dicsover_partner_chat_providers(): |
| 354 | + """Test that the function correctly discovers LangChain partner chat providers.""" |
| 355 | + |
| 356 | + partner_chat_providers = _discover_langchain_partner_chat_providers() |
| 357 | + assert _PARTNER_CHAT_PROVIDERS_NAMES.issubset(partner_chat_providers), ( |
| 358 | + "LangChain partner chat providers may have changed. Update " |
| 359 | + "_PARTNER_CHAT_PROVIDERS_NAMES to include all expected providers." |
| 360 | + ) |
| 361 | + chat_providers = get_chat_provider_names() |
| 362 | + |
| 363 | + assert partner_chat_providers.issubset( |
| 364 | + chat_providers |
| 365 | + ), "partner chat providers are not a subset of the list of chat providers" |
| 366 | + |
| 367 | + if not partner_chat_providers == _PARTNER_CHAT_PROVIDERS_NAMES: |
| 368 | + warnings.warn( |
| 369 | + "LangChain partner chat providers may have changed. Update " |
| 370 | + "_PARTNER_CHAT_PROVIDERS_NAMES to include all expected providers." |
| 371 | + ) |
| 372 | + |
| 373 | + |
331 | 374 | def test_discover_langchain_community_llm_providers():
|
332 | 375 | providers = _discover_langchain_community_llm_providers()
|
333 | 376 | llm_provider_names = get_llm_provider_names()
|
|
0 commit comments