@@ -26,16 +26,21 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Literal

+from pydantic import BaseModel
from tqdm import tqdm

from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available
+from lighteval.utils.utils import as_list


logging.getLogger("openai").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
logger = logging.getLogger(__name__)


+DEFAULT_FORMAT = {"type": "text"}
+
+
class JudgeLM:
    """
    A class representing a judge for evaluating answers using either the OpenAI or Transformers library.
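Here `DEFAULT_FORMAT` preserves the old plain-text behaviour as a fallback, and the new `pydantic` import lets callers describe a structured judge output instead. A minimal sketch of such a schema (the class name and fields below are illustrative, not part of the patch):

```python
from pydantic import BaseModel


# Illustrative schema for a structured judge verdict; any BaseModel
# subclass whose fields match what the judge template asks for works.
class JudgeVerdict(BaseModel):
    score: int
    reasoning: str
```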
@@ -76,6 +81,7 @@ def __init__(
        judge_backend: Literal["litellm", "openai", "transformers", "tgi", "vllm"],
        url: str | None = None,
        api_key: str | None = None,
+        response_format: BaseModel | None = None,
    ):
        self.model = model
        self.template = templates
@@ -91,6 +97,8 @@ def __init__(
        self.api_key = api_key
        self.backend = judge_backend

+        self.response_format = response_format if response_format is not None else DEFAULT_FORMAT
+
    def __lazy_load_client(self):
        match self.backend:
            # Whether we use openai or TGI models, we go through the openai API
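With `self.response_format` defaulting to `DEFAULT_FORMAT`, structured output stays opt-in. A hedged sketch of how a caller might enable it (constructor arguments not visible in this diff are omitted; `JudgeVerdict` is the illustrative schema above):

```python
# Sketch only: enabling structured output for the judge.
judge = JudgeLM(
    model="gpt-4o-mini",           # illustrative model name
    templates=my_template_fn,      # assumed Callable building judge prompts
    judge_backend="openai",
    response_format=JudgeVerdict,  # Pydantic schema from the sketch above
)
```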
@@ -232,7 +240,7 @@ def __call_api(prompt):

    def __call_api_parallel(self, prompts):
        results = []
-        with ThreadPoolExecutor(100) as executor:
+        with ThreadPoolExecutor(10) as executor:
            for entry in tqdm(executor.map(self.__call_api, prompts), total=len(prompts)):
                results.append(entry)

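One property worth noting: `executor.map` yields results in input order regardless of which worker finishes first, so shrinking the pool to 10 only bounds the number of concurrent API calls without reordering results. A small self-contained illustration:

```python
import time
from concurrent.futures import ThreadPoolExecutor


def slow_double(x: int) -> int:
    time.sleep(0.01 * (5 - x))  # later items finish sooner
    return 2 * x


with ThreadPoolExecutor(max_workers=10) as executor:
    # map preserves input order even though workers finish out of order
    assert list(executor.map(slow_double, range(5))) == [0, 2, 4, 6, 8]
```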
@@ -244,16 +252,34 @@ def __call_api_parallel(self, prompts):
    def __call_api(self, prompt):
        for _ in range(self.API_MAX_RETRY):
            try:
-                response = self.client.chat.completions.create(
+                # Base model: use the structured-output parse endpoint
+                response = self.client.beta.chat.completions.parse(
                    model=self.model,
-                    messages=prompt,
-                    response_format={"type": "text"},
-                    max_tokens=512,
+                    messages=as_list(prompt),
+                    response_format=self.response_format,
+                    max_tokens=4096,
+                    temperature=0.0,
                    n=1,
                )
-                text = response.choices[0].message.content
-                return text
+                answer = response.choices[0].message.parsed
+                return answer
+            except TypeError:
+                try:
+                    # Fine-tuned model: fall back to the plain completions endpoint
+                    response = self.client.chat.completions.create(
+                        model=self.model,
+                        messages=as_list(prompt),
+                        response_format=self.response_format,
+                        max_tokens=512,
+                        n=1,
+                    )
+                    text = response.choices[0].message.content
+                    return text
+                except Exception as e:
+                    logger.warning(f"{type(e), e}")
+                    time.sleep(self.API_RETRY_SLEEP)
            except Exception as e:
                logger.warning(f"{type(e), e}")
                time.sleep(self.API_RETRY_SLEEP)
+
        raise Exception("Failed to get response from the API")
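For reference, the two call paths differ in what they return: `client.beta.chat.completions.parse` validates the reply against a Pydantic model and exposes it as `message.parsed`, while plain `create` returns raw text in `message.content`. A self-contained sketch of the same try-structured-then-fall-back shape, with illustrative names throughout:

```python
import logging
import time

from openai import OpenAI
from pydantic import BaseModel

logger = logging.getLogger(__name__)


class Verdict(BaseModel):  # illustrative schema
    score: int
    reasoning: str


def ask_judge(client: OpenAI, messages: list[dict], retries: int = 3, sleep_s: float = 10.0):
    """Try the structured-output endpoint first; fall back to raw text."""
    for _ in range(retries):
        try:
            # Structured path: the SDK validates the reply against Verdict
            # and exposes the parsed instance on message.parsed.
            response = client.beta.chat.completions.parse(
                model="gpt-4o-mini",  # illustrative model name
                messages=messages,
                response_format=Verdict,
            )
            return response.choices[0].message.parsed
        except TypeError:
            # Plain path: no schema enforcement, a raw string comes back.
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=messages,
            )
            return response.choices[0].message.content
        except Exception as e:
            logger.warning("judge call failed: %s", e)
            time.sleep(sleep_s)
    raise RuntimeError("Failed to get response from the API")
```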