feat: implement parallel streaming output rails execution #1263
@@ -66,7 +66,7 @@
 from nemoguardrails.logging.verbose import set_verbose
 from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop
 from nemoguardrails.rails.llm.buffer import get_buffer_strategy
-from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, Model, RailsConfig
+from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, RailsConfig
 from nemoguardrails.rails.llm.options import (
     GenerationLog,
     GenerationOptions,
@@ -1351,6 +1351,32 @@ def _get_latest_user_message(
         return message
     return {}

+
+def _prepare_context_for_parallel_rails(
+    chunk_str: str,
+    prompt: Optional[str] = None,
+    messages: Optional[List[dict]] = None,
+) -> dict:
+    """Prepare context for parallel rails execution."""
+    context_message = _get_last_context_message(messages)
+    user_message = prompt or _get_latest_user_message(messages)
+
+    context = {
+        "user_message": user_message,
+        "bot_message": chunk_str,
+    }
+
+    if context_message:
+        context.update(context_message["content"])
+
+    return context
+
+
+def _create_events_for_chunk(chunk_str: str, context: dict) -> List[dict]:
+    """Create events for running output rails on a chunk."""
+    return [
+        {"type": "ContextUpdate", "data": context},
+        {"type": "BotMessage", "text": chunk_str},
+    ]
+
+
 def _prepare_params(
     flow_id: str,
     action_name: str,

Copilot review comment on _prepare_context_for_parallel_rails: [nitpick] Defining helper functions (…)

Copilot review comment on lines +1369 to +1370 (the context.update call): Updating the context dict with a string (…) A suggested change was attached; Pouyanpi marked this conversation as resolved.
@@ -1404,6 +1430,8 @@ def _prepare_params(
             _get_action_details_from_flow_id, flows=self.config.flows
         )

+        parallel_mode = getattr(self.config.rails.output, "parallel", False)
+
        async for chunk_batch in buffer_strategy(streaming_handler):
            user_output_chunks = chunk_batch.user_output_chunks
            # format processing_context for output rails processing (needs full context)
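For reference, a tiny sketch of the getattr fallback introduced above: the streaming path only reads config.rails.output.parallel and treats a missing attribute as False, so parallel execution stays opt-in. The SimpleNamespace objects are stand-ins for the real config classes:

    from types import SimpleNamespace

    # stand-in for self.config.rails.output without a "parallel" field
    output_without_flag = SimpleNamespace()
    assert getattr(output_without_flag, "parallel", False) is False

    # stand-in with the flag enabled
    output_with_flag = SimpleNamespace(parallel=True)
    assert getattr(output_with_flag, "parallel", False) is True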
@@ -1427,48 +1455,118 @@ def _prepare_params(
             for chunk in user_output_chunks:
                 yield chunk

-            for flow_id in output_rails_flows_id:
-                action_name, action_params = get_action_details(flow_id)
-
-                params = _prepare_params(
-                    flow_id=flow_id,
-                    action_name=action_name,
-                    bot_response_chunk=bot_response_chunk,
-                    prompt=prompt,
-                    messages=messages,
-                    action_params=action_params,
-                )
-
-                result = await self.runtime.action_dispatcher.execute_action(
-                    action_name, params
-                )
-                self.explain_info = self._ensure_explain_info()
-
-                action_func = self.runtime.action_dispatcher.get_action(action_name)
-
-                # Use the mapping to decide if the result indicates blocked content.
-                if is_output_blocked(result, action_func):
-                    reason = f"Blocked by {flow_id} rails."
-
-                    # return the error as a plain JSON string (not in SSE format)
-                    # NOTE: When integrating with the OpenAI Python client, the server code should:
-                    # 1. detect this JSON error object in the stream
-                    # 2. terminate the stream
-                    # 3. format the error following OpenAI's SSE format
-                    # the OpenAI client will then properly raise an APIError with this error message
-
-                    error_data = {
-                        "error": {
-                            "message": reason,
-                            "type": "guardrails_violation",
-                            "param": flow_id,
-                            "code": "content_blocked",
-                        }
-                    }
-
-                    # return as plain JSON: the server should detect this JSON and convert it to an HTTP error
-                    yield json.dumps(error_data)
-                    return
+            if parallel_mode:
+                try:
+                    context = _prepare_context_for_parallel_rails(
+                        bot_response_chunk, prompt, messages
+                    )
+                    events = _create_events_for_chunk(bot_response_chunk, context)
+
+                    flows_with_params = {}
+                    for flow_id in output_rails_flows_id:
+                        action_name, action_params = get_action_details(flow_id)
+                        params = _prepare_params(
+                            flow_id=flow_id,
+                            action_name=action_name,
+                            bot_response_chunk=bot_response_chunk,
+                            prompt=prompt,
+                            messages=messages,
+                            action_params=action_params,
+                        )
+                        flows_with_params[flow_id] = {
+                            "action_name": action_name,
+                            "params": params,
+                        }
+
+                    result_tuple = await self.runtime.action_dispatcher.execute_action(
+                        "run_output_rails_in_parallel_streaming",
+                        {
+                            "flows_with_params": flows_with_params,
+                            "events": events,
+                        },
+                    )
+
+                    # ActionDispatcher.execute_action always returns (result, status)
+                    result, status = result_tuple
+
+                    if status != "success":
+                        log.error(
+                            f"Parallel rails execution failed with status: {status}"
+                        )
+                        # continue processing the chunk even if rails fail
+                        pass
+                    else:
+                        # if there are any stop events, content was blocked
+                        if result.events:
+                            # extract the blocked flow from the first stop event
+                            blocked_flow = result.events[0].get(
+                                "flow_id", "output rails"
+                            )
+
+                            reason = f"Blocked by {blocked_flow} rails."
+                            error_data = {
+                                "error": {
+                                    "message": reason,
+                                    "type": "guardrails_violation",
+                                    "param": blocked_flow,
+                                    "code": "content_blocked",
+                                }
+                            }
+                            yield json.dumps(error_data)
+                            return
+
+                except Exception as e:
+                    log.error(f"Error in parallel rail execution: {e}")
+                    # don't block the stream for rail execution errors
+                    # continue processing the chunk
+                    pass
+
+                # update explain info for parallel mode
+                self.explain_info = self._ensure_explain_info()
+
+            else:
+                for flow_id in output_rails_flows_id:
+                    action_name, action_params = get_action_details(flow_id)
+
+                    params = _prepare_params(
+                        flow_id=flow_id,
+                        action_name=action_name,
+                        bot_response_chunk=bot_response_chunk,
+                        prompt=prompt,
+                        messages=messages,
+                        action_params=action_params,
+                    )
+
+                    result = await self.runtime.action_dispatcher.execute_action(
+                        action_name, params
+                    )
+                    self.explain_info = self._ensure_explain_info()
+
+                    action_func = self.runtime.action_dispatcher.get_action(action_name)
+
+                    # Use the mapping to decide if the result indicates blocked content.
+                    if is_output_blocked(result, action_func):
+                        reason = f"Blocked by {flow_id} rails."
+
+                        # return the error as a plain JSON string (not in SSE format)
+                        # NOTE: When integrating with the OpenAI Python client, the server code should:
+                        # 1. detect this JSON error object in the stream
+                        # 2. terminate the stream
+                        # 3. format the error following OpenAI's SSE format
+                        # the OpenAI client will then properly raise an APIError with this error message
+
+                        error_data = {
+                            "error": {
+                                "message": reason,
+                                "type": "guardrails_violation",
+                                "param": flow_id,
+                                "code": "content_blocked",
+                            }
+                        }
+
+                        # return as plain JSON: the server should detect this JSON and convert it to an HTTP error
+                        yield json.dumps(error_data)
+                        return

             if not stream_first:
                 # yield the individual chunks directly from the buffer strategy
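The NOTE comments in both branches describe a contract with the serving layer: on a violation, the generator yields a plain JSON error object instead of raising. A minimal sketch of a consumer honoring that contract; relay_with_guardrails and GuardrailsBlockedError are illustrative names, not part of this PR:

    import json

    class GuardrailsBlockedError(RuntimeError):
        """Hypothetical stand-in for the HTTP/SSE error a real server would emit."""

    async def relay_with_guardrails(stream):
        """Hypothetical server-side relay implementing the NOTE's three steps:
        detect the plain JSON error object, terminate the stream, and surface
        the violation (real code would format it per OpenAI's SSE error shape)."""
        async for chunk in stream:
            payload = None
            if isinstance(chunk, str) and chunk.lstrip().startswith("{"):
                try:
                    payload = json.loads(chunk)
                except ValueError:
                    payload = None
            if isinstance(payload, dict) and "error" in payload:
                raise GuardrailsBlockedError(payload["error"]["message"])
            yield chunk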
Review comment:

If we try to cancel a task that's already done, will that throw an exception? A task could finish in between the pending_task.done() and pending_task.cancel() calls. We need to be able to cancel a task that's already done without any exceptions being thrown.

Reply:

Calling cancel() on a task that is already done or cancelled will not throw an exception. It simply returns False and has no effect, so there is no need to guard cancel() with a try/except for this reason. But this is a good point that we might want to document: it shows a race condition which is the expected behavior, though we can log a warning (see the sketch after this thread for an example).

Based on our assumptions:
- for content-modifying rails, one should use sequential execution;
- for read-only validation rails, parallel execution is fine.
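A runnable sketch of the behavior described in the reply, including the suggested warning log. The names are illustrative; the PR's actual task-handling code is in a different file than the hunks shown above:

    import asyncio
    import logging

    log = logging.getLogger(__name__)

    async def main():
        pending_task = asyncio.create_task(asyncio.sleep(0))
        await asyncio.sleep(0.01)  # the task finishes here, as in the race above

        # cancel() on a finished task raises nothing and simply returns False
        if not pending_task.cancel():
            # expected race: the task completed between done() and cancel();
            # a warning keeps the race observable, as suggested in the reply
            log.warning("Task %r finished before it could be cancelled.", pending_task)

    asyncio.run(main())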