Skip to content

Commit 734bba1

Browse files
authored
Merge pull request #1558 from weaviate/202502/nvidia-patches
Add `nvidia-reranker` factory function
2 parents 84571eb + d570dc9 commit 734bba1

File tree

2 files changed

+80
-3
lines changed

2 files changed

+80
-3
lines changed

test/collection/test_config.py

+42
Original file line numberDiff line numberDiff line change
@@ -1063,6 +1063,48 @@ def test_config_with_generative(
10631063
"reranker-cohere": {},
10641064
},
10651065
),
1066+
(
1067+
Configure.Reranker.voyageai(),
1068+
{
1069+
"reranker-voyageai": {},
1070+
},
1071+
),
1072+
(
1073+
Configure.Reranker.voyageai(model="rerank-lite-1"),
1074+
{
1075+
"reranker-voyageai": {"model": "rerank-lite-1"},
1076+
},
1077+
),
1078+
(
1079+
Configure.Reranker.jinaai(),
1080+
{
1081+
"reranker-jinaai": {},
1082+
},
1083+
),
1084+
(
1085+
Configure.Reranker.jinaai(model="jina-reranker-v2-base-multilingual"),
1086+
{
1087+
"reranker-jinaai": {"model": "jina-reranker-v2-base-multilingual"},
1088+
},
1089+
),
1090+
(
1091+
Configure.Reranker.nvidia(),
1092+
{
1093+
"reranker-nvidia": {},
1094+
},
1095+
),
1096+
(
1097+
Configure.Reranker.nvidia(
1098+
model="nvidia/llama-3.2-nv-rerankqa-1b-v2",
1099+
base_url="https://integrate.api.nvidia.com/v1",
1100+
),
1101+
{
1102+
"reranker-nvidia": {
1103+
"model": "nvidia/llama-3.2-nv-rerankqa-1b-v2",
1104+
"baseURL": "https://integrate.api.nvidia.com/v1",
1105+
},
1106+
},
1107+
),
10661108
(
10671109
Configure.Reranker.transformers(),
10681110
{

weaviate/collections/classes/config.py

+38-3
Original file line numberDiff line numberDiff line change
@@ -215,13 +215,16 @@ class Rerankers(str, BaseEnum):
215215
Weaviate module backed by VoyageAI reranking models.
216216
`JINAAI`
217217
Weaviate module backed by JinaAI reranking models.
218+
`NVIDIA`
219+
Weaviate module backed by NVIDIA reranking models.
218220
"""
219221

220222
NONE = "none"
221223
COHERE = "reranker-cohere"
222224
TRANSFORMERS = "reranker-transformers"
223225
VOYAGEAI = "reranker-voyageai"
224226
JINAAI = "reranker-jinaai"
227+
NVIDIA = "reranker-nvidia"
225228

226229

227230
class StopwordsPreset(str, BaseEnum):
@@ -650,6 +653,20 @@ class _RerankerVoyageAIConfig(_RerankerProvider):
650653
model: Optional[Union[RerankerVoyageAIModel, str]] = Field(default=None)
651654

652655

656+
class _RerankerNvidiaConfig(_RerankerProvider):
657+
reranker: Union[Rerankers, _EnumLikeStr] = Field(
658+
default=Rerankers.NVIDIA, frozen=True, exclude=True
659+
)
660+
model: Optional[str] = Field(default=None)
661+
baseURL: Optional[AnyHttpUrl]
662+
663+
def _to_dict(self) -> Dict[str, Any]:
664+
ret_dict = super()._to_dict()
665+
if self.baseURL is not None:
666+
ret_dict["baseURL"] = self.baseURL.unicode_string()
667+
return ret_dict
668+
669+
653670
class _Generative:
654671
"""Use this factory class to create the correct object for the `generative_config` argument in the `collections.create()` method.
655672
@@ -1120,7 +1137,7 @@ def cohere(
11201137
) -> _RerankerProvider:
11211138
"""Create a `_RerankerCohereConfig` object for use when reranking using the `reranker-cohere` module.
11221139
1123-
See the [documentation](https://weaviate.io/developers/weaviate/modules/retriever-vectorizer-modules/reranker-cohere)
1140+
See the [documentation](https://weaviate.io/developers/weaviate/model-providers/cohere/reranker)
11241141
for detailed usage.
11251142
11261143
Arguments:
@@ -1135,7 +1152,7 @@ def jinaai(
11351152
) -> _RerankerProvider:
11361153
"""Create a `_RerankerJinaAIConfig` object for use when reranking using the `reranker-jinaai` module.
11371154
1138-
See the [documentation](https://weaviate.io/developers/weaviate/modules/retriever-vectorizer-modules/reranker-jinaai)
1155+
See the [documentation](https://weaviate.io/developers/weaviate/model-providers/jinaai/reranker)
11391156
for detailed usage.
11401157
11411158
Arguments:
@@ -1150,7 +1167,7 @@ def voyageai(
11501167
) -> _RerankerProvider:
11511168
"""Create a `_RerankerVoyageAIConfig` object for use when reranking using the `reranker-voyageai` module.
11521169
1153-
See the [documentation](https://weaviate.io/developers/weaviate/modules/retriever-vectorizer-modules/reranker-voyageai)
1170+
See the [documentation](https://weaviate.io/developers/weaviate/model-providers/voyageai/reranker)
11541171
for detailed usage.
11551172
11561173
Arguments:
@@ -1159,6 +1176,24 @@ def voyageai(
11591176
"""
11601177
return _RerankerVoyageAIConfig(model=model)
11611178

1179+
@staticmethod
1180+
def nvidia(
1181+
model: Optional[str] = None,
1182+
base_url: Optional[AnyHttpUrl] = None,
1183+
) -> _RerankerProvider:
1184+
"""Create a `_RerankerNvidiaConfig` object for use when reranking using the `reranker-nvidia` module.
1185+
1186+
See the [documentation](https://weaviate.io/developers/weaviate/model-providers/nvidia/reranker)
1187+
for detailed usage.
1188+
1189+
Arguments:
1190+
`model`
1191+
The model to use. Defaults to `None`, which uses the server-defined default
1192+
`baseurl`
1193+
The base URL to send the reranker requests to. Defaults to `None`, which uses the server-defined default.
1194+
"""
1195+
return _RerankerNvidiaConfig(model=model, baseURL=base_url)
1196+
11621197

11631198
class _CollectionConfigCreateBase(_ConfigCreateModel):
11641199
description: Optional[str] = Field(default=None)

0 commit comments

Comments
 (0)