Skip to content

Commit c925b04

Browse files
committed
feat: add partner chat provider names
1 parent de8125d commit c925b04

File tree

2 files changed

+69
-5
lines changed

2 files changed

+69
-5
lines changed

nemoguardrails/llm/providers/providers.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import importlib
2626
import logging
2727
import warnings
28-
from typing import Dict, List, Type
28+
from typing import Dict, List, Set, Type
2929

3030
from langchain.chat_models.base import BaseChatModel
3131
from langchain_community import llms
@@ -82,6 +82,16 @@ def _discover_langchain_community_llm_providers():
8282
return type_to_cls_dict
8383

8484

85+
# this is needed as we perform the mapping in langchain_initializer.py
86+
_CUSTOM_CHAT_PROVIDERS = {"nim"}
87+
88+
89+
def _discover_langchain_partner_chat_providers() -> Set[str]:
90+
from langchain.chat_models.base import _SUPPORTED_PROVIDERS
91+
92+
return _SUPPORTED_PROVIDERS | _CUSTOM_CHAT_PROVIDERS
93+
94+
8595
def _discover_langchain_community_chat_providers():
8696
"""Creates a mapping from provider name to chat model class.
8797
The provider name is defined as the last segment of the module path.
@@ -159,6 +169,17 @@ def get_community_chat_provider_names() -> List[str]:
159169
return list(sorted(list(_chat_providers.keys())))
160170

161171

172+
def _get_all_chat_provider_names() -> List[str]:
173+
"""Consolidates all chat provider names."""
174+
175+
return list(_chat_providers.keys() | _discover_langchain_partner_chat_providers())
176+
177+
178+
def get_chat_provider_names() -> List[str]:
179+
"""Returns the list of supported chat providers."""
180+
return list(sorted(_get_all_chat_provider_names()))
181+
182+
162183
def _get_text_completion_provider(provider_name: str) -> Type[BaseLLM]:
163184
if provider_name not in _llm_providers:
164185
raise RuntimeError(f"Could not find LLM provider '{provider_name}'")

tests/llm_providers/test_version_compatibility.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@
2424
_chat_providers,
2525
_discover_langchain_community_chat_providers,
2626
_discover_langchain_community_llm_providers,
27+
_discover_langchain_partner_chat_providers,
2728
_llm_providers,
28-
_parse_version,
29+
get_chat_provider_names,
2930
get_community_chat_provider_names,
3031
get_llm_provider_names,
3132
)
@@ -133,7 +134,7 @@
133134
"yi",
134135
"you",
135136
]
136-
_CHAT_PROVIDERS_NAMES = [
137+
_COMMUNITY_CHAT_PROVIDERS_NAMES = [
137138
"azure_openai",
138139
"bedrock",
139140
"anthropic",
@@ -195,6 +196,26 @@
195196
"llamacpp",
196197
"yi",
197198
]
199+
200+
_PARTNER_CHAT_PROVIDERS_NAMES = {
201+
"anthropic",
202+
"azure_openai",
203+
"bedrock",
204+
"bedrock_converse",
205+
"cohere",
206+
"deepseek",
207+
"fireworks",
208+
"google_anthropic_vertex",
209+
"google_genai",
210+
"google_vertexai",
211+
"groq",
212+
"huggingface",
213+
"mistralai",
214+
"nim",
215+
"ollama",
216+
"openai",
217+
"together",
218+
}
198219
# at some point we might care about certain providers
199220
CRITICAL_LLM_PROVIDERS = [
200221
"openai",
@@ -316,18 +337,40 @@ def test_provider_imports():
316337

317338

318339
def test_discover_langchain_community_chat_providers():
319-
"""Test that the function correctly discovers LangChain chat providers."""
340+
"""Test that the function correctly discovers LangChain community chat providers."""
341+
320342
providers = _discover_langchain_community_chat_providers()
321343
chat_provider_names = get_community_chat_provider_names()
322344
assert set(chat_provider_names) == set(
323345
providers.keys()
324346
), "it seems that we are registering a provider that is not in the LC community chat provider"
325-
assert _CHAT_PROVIDERS_NAMES == list(providers.keys()), (
347+
assert _COMMUNITY_CHAT_PROVIDERS_NAMES == list(providers.keys()), (
326348
"LangChain chat community providers may have changed. "
327349
"please investigate and update the test if necessary."
328350
)
329351

330352

353+
def test_dicsover_partner_chat_providers():
354+
"""Test that the function correctly discovers LangChain partner chat providers."""
355+
356+
partner_chat_providers = _discover_langchain_partner_chat_providers()
357+
assert _PARTNER_CHAT_PROVIDERS_NAMES.issubset(partner_chat_providers), (
358+
"LangChain partner chat providers may have changed. Update "
359+
"_PARTNER_CHAT_PROVIDERS_NAMES to include all expected providers."
360+
)
361+
chat_providers = get_chat_provider_names()
362+
363+
assert partner_chat_providers.issubset(
364+
chat_providers
365+
), "partner chat providers are not a subset of the list of chat providers"
366+
367+
if not partner_chat_providers == _PARTNER_CHAT_PROVIDERS_NAMES:
368+
warnings.warn(
369+
"LangChain partner chat providers may have changed. Update "
370+
"_PARTNER_CHAT_PROVIDERS_NAMES to include all expected providers."
371+
)
372+
373+
331374
def test_discover_langchain_community_llm_providers():
332375
providers = _discover_langchain_community_llm_providers()
333376
llm_provider_names = get_llm_provider_names()

0 commit comments

Comments
 (0)