@@ -234,6 +234,35 @@ def run_mteb_rerank(cross_encoder, tasks, languages):
234
234
return main_score
235
235
236
236
237
+ def mteb_test_rerank_models_hf (hf_runner , model_name , hf_model_callback = None ):
238
+ with hf_runner (model_name , is_cross_encoder = True ,
239
+ dtype = "float32" ) as hf_model :
240
+
241
+ original_predict = hf_model .predict
242
+
243
+ def _predict (
244
+ sentences : list [tuple [str , str ,
245
+ Optional [str ]]], # query, corpus, prompt
246
+ * args ,
247
+ ** kwargs ,
248
+ ):
249
+ # vllm and st both remove the prompt, fair comparison.
250
+ prompts = [(s [0 ], s [1 ]) for s in sentences ]
251
+ return original_predict (prompts , * args , ** kwargs , batch_size = 8 )
252
+
253
+ hf_model .predict = _predict
254
+ hf_model .original_predict = original_predict
255
+
256
+ if hf_model_callback is not None :
257
+ hf_model_callback (hf_model )
258
+
259
+ st_main_score = run_mteb_rerank (hf_model ,
260
+ tasks = MTEB_RERANK_TASKS ,
261
+ languages = MTEB_RERANK_LANGS )
262
+ st_dtype = next (hf_model .model .model .parameters ()).dtype
263
+ return st_main_score , st_dtype
264
+
265
+
237
266
def mteb_test_rerank_models (hf_runner ,
238
267
vllm_runner ,
239
268
model_info : RerankModelInfo ,
@@ -261,31 +290,8 @@ def mteb_test_rerank_models(hf_runner,
261
290
languages = MTEB_RERANK_LANGS )
262
291
vllm_dtype = vllm_model .model .llm_engine .model_config .dtype
263
292
264
- with hf_runner (model_info .name , is_cross_encoder = True ,
265
- dtype = "float32" ) as hf_model :
266
-
267
- original_predict = hf_model .predict
268
-
269
- def _predict (
270
- sentences : list [tuple [str , str ,
271
- Optional [str ]]], # query, corpus, prompt
272
- * args ,
273
- ** kwargs ,
274
- ):
275
- # vllm and st both remove the prompt, fair comparison.
276
- prompts = [(s [0 ], s [1 ]) for s in sentences ]
277
- return original_predict (prompts , * args , ** kwargs , batch_size = 8 )
278
-
279
- hf_model .predict = _predict
280
- hf_model .original_predict = original_predict
281
-
282
- if hf_model_callback is not None :
283
- hf_model_callback (hf_model )
284
-
285
- st_main_score = run_mteb_rerank (hf_model ,
286
- tasks = MTEB_RERANK_TASKS ,
287
- languages = MTEB_RERANK_LANGS )
288
- st_dtype = next (hf_model .model .model .parameters ()).dtype
293
+ st_main_score , st_dtype = mteb_test_rerank_models_hf (
294
+ hf_runner , model_info .name , hf_model_callback )
289
295
290
296
print ("VLLM:" , vllm_dtype , vllm_main_score )
291
297
print ("SentenceTransformers:" , st_dtype , st_main_score )
0 commit comments