
Commit 11aa195

detect async
1 parent 24e9bc1 commit 11aa195

File tree

1 file changed: +13 -7 lines changed


src/cleanlab_codex/validator.py

Lines changed: 13 additions & 7 deletions
@@ -119,8 +119,13 @@ def __init__(
             raise ValueError(error_msg)
 
     def validate(self, query: str, context: str, response: str, prompt=None, form_prompt=None):
-        return asyncio.run(self.validate_async(query, context, response, prompt, form_prompt))
-
+        try:
+            return asyncio.run(self.validate_async(query, context, response, prompt, form_prompt))
+        except RuntimeError:
+            # If inside an event loop, use `asyncio.create_task()`
+            loop = asyncio.get_event_loop()
+            return loop.create_task(self.validate_async(query, context, response, prompt, form_prompt))
+
     async def validate_async(
         self,
         query: str,
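
The hunk above makes validate() callable both from plain synchronous code and from code that is already running inside an event loop. Below is a minimal, self-contained sketch of that fallback pattern, using a made-up coroutine rather than the Validator's own; it builds the coroutine once so the failed asyncio.run() call does not leave an un-awaited coroutine behind.

import asyncio

async def _work() -> str:
    # Hypothetical stand-in for Validator.validate_async().
    return "result"

def work():
    coro = _work()
    try:
        # No event loop running: drive the coroutine to completion here.
        return asyncio.run(coro)
    except RuntimeError:
        # Already inside a running loop: schedule the coroutine on it and
        # hand back the Task; the caller is responsible for awaiting it.
        return asyncio.get_event_loop().create_task(coro)

print(work())                    # sync context: prints "result"

async def main():
    print(await work())          # async context: work() returns a Task to await

asyncio.run(main())

Note the two call sites get different things back: a finished result when no loop is running, an asyncio.Task otherwise.
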
@@ -142,10 +147,11 @@ async def validate_async(
             - 'is_bad_response': True if the response is flagged as potentially bad (when True, a lookup in Codex is performed), False otherwise.
             - Additional keys: Various keys from a [`ThresholdedTrustworthyRAGScore`](/cleanlab_codex/types/validator/#class-thresholdedtrustworthyragscore) dictionary, with raw scores from [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag) for each evaluation metric. `is_bad` indicating whether the score is below the threshold.
         """
-        expert_task = asyncio.create_task(self.remediate_async(query))
-        scores, is_bad_response = await self.detect(query, context, response, prompt, form_prompt)
+        detect_task = self.detect(query, context, response, prompt, form_prompt)
+        remediate_task = self.remediate_async(query)
+        scores, is_bad_response = await detect_task
         if is_bad_response:
-            expert_answer, maybe_entry = await expert_task
+            expert_answer, maybe_entry = await remediate_task
             if expert_answer == None:
                 self._project._sdk_client.projects.entries.add_question(
                     self._project._id, question=query).model_dump()
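
The hunk above swaps an eagerly scheduled asyncio.create_task(...) for bare coroutines: detect is awaited first, and the remediation lookup is awaited only when the response is flagged as bad. The following sketch illustrates the behavioural difference between the two patterns, with toy coroutines standing in for detect/remediate_async.

import asyncio

async def slow_lookup() -> str:
    # Toy stand-in for remediate_async(); hypothetical.
    await asyncio.sleep(0.1)
    return "expert answer"

async def eager() -> None:
    # Old pattern: asyncio.create_task() schedules the coroutine right away,
    # so the lookup overlaps with whatever is awaited in the meantime.
    task = asyncio.create_task(slow_lookup())
    await asyncio.sleep(0.2)      # other work; the lookup runs concurrently
    print(task.done())            # True: the lookup already finished
    print(await task)

async def lazy() -> None:
    # New pattern: a bare coroutine does nothing until it is awaited; if it is
    # never awaited (the "good response" branch), the lookup is skipped entirely
    # and Python emits a "coroutine was never awaited" RuntimeWarning on cleanup.
    coro = slow_lookup()
    await asyncio.sleep(0.2)      # other work; the lookup has NOT started yet
    print(await coro)             # the lookup only runs here

asyncio.run(eager())
asyncio.run(lazy())
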
@@ -158,7 +164,7 @@ async def validate_async(
             **scores,
         }
 
-    def detect(
+    async def detect(
         self,
         query: str,
         context: str,
@@ -180,7 +186,7 @@ def detect(
             - bool: True if the response is determined to be bad based on the evaluation scores
               and configured thresholds, False otherwise.
         """
-        scores = self._tlm_rag.score(
+        scores = await self._tlm_rag.score(
             response=response,
             query=query,
             context=context,
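
With the last two hunks, detect() and the TrustworthyRAG scoring call are awaited, so validate_async() is fully asynchronous while validate() remains the synchronous entry point. A hedged usage sketch of the two entry points as they appear in this diff; how the Validator is constructed is not shown here, so the instance is stubbed out.

import asyncio
from cleanlab_codex.validator import Validator

# Construction details are not part of this diff; assume a configured instance.
validator: Validator = ...

def check_sync(query: str, context: str, response: str):
    # Synchronous caller: validate() drives the event loop itself via asyncio.run().
    return validator.validate(query, context, response)

async def check_async(query: str, context: str, response: str):
    # Asynchronous caller: await validate_async() directly instead of relying on
    # the RuntimeError fallback in validate(), which hands back a Task rather
    # than the finished result dict.
    return await validator.validate_async(query, context, response)
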

0 commit comments