feat: implement parallel streaming output rails execution #1263

Merged: 2 commits, Jul 21, 2025

6 changes: 6 additions & 0 deletions nemoguardrails/colang/runtime.py
@@ -37,6 +37,12 @@ def __init__(self, config: RailsConfig, verbose: bool = False):
import_paths=list(config.imported_paths.values()),
)

if hasattr(self, "_run_output_rails_in_parallel_streaming"):
self.action_dispatcher.register_action(
self._run_output_rails_in_parallel_streaming,
name="run_output_rails_in_parallel_streaming",
)

# The list of additional parameters that can be passed to the actions.
self.registered_action_params: dict = {}

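The hasattr guard above means the base Runtime registers the action only when a subclass (here, the Colang v1.0 runtime) actually defines the method. A minimal sketch of the pattern, with hypothetical names:

import asyncio


class BaseRuntime:
    def __init__(self):
        self.actions: dict = {}
        # Register optional hooks only if a subclass provides them.
        if hasattr(self, "_optional_hook"):
            self.actions["optional_hook"] = self._optional_hook


class RuntimeV1(BaseRuntime):
    async def _optional_hook(self) -> str:
        return "v1-specific behavior"


assert "optional_hook" in RuntimeV1().actions
assert "optional_hook" not in BaseRuntime().actions
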
94 changes: 90 additions & 4 deletions nemoguardrails/colang/v1_0/runtime/runtime.py
@@ -12,10 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import inspect
import logging
import uuid
from textwrap import indent
from time import time
from typing import Any, Dict, List, Optional, Tuple
@@ -25,10 +24,13 @@
from langchain.chains.base import Chain

from nemoguardrails.actions.actions import ActionResult
from nemoguardrails.actions.output_mapping import is_output_blocked
from nemoguardrails.colang import parse_colang_file
from nemoguardrails.colang.runtime import Runtime
from nemoguardrails.colang.v1_0.runtime.flows import (
FlowConfig,
_get_flow_params,
_normalize_flow_id,
compute_context,
compute_next_steps,
)
@@ -259,6 +261,89 @@ def _internal_error_action_result(message: str):
]
)

async def _run_output_rails_in_parallel_streaming(
self, flows_with_params: Dict[str, dict], events: List[dict]
) -> ActionResult:
"""Run the output rails in parallel for streaming chunks.

This is a streamlined version that avoids the full flow state management
which can cause issues with hide_prev_turn logic during streaming.

Args:
flows_with_params: Dictionary mapping flow_id to {"action_name": str, "params": dict}
events: The events list for context
"""
tasks = []

async def run_single_rail(flow_id: str, action_info: dict) -> tuple:
"""Run a single rail flow and return (flow_id, result)"""

try:
action_name = action_info["action_name"]
params = action_info["params"]

result_tuple = await self.action_dispatcher.execute_action(
action_name, params
)
result, status = result_tuple

if status != "success":
log.error(f"Action {action_name} failed with status: {status}")
return flow_id, False # Allow on failure

action_func = self.action_dispatcher.get_action(action_name)

# use the mapping to decide if the result indicates blocked content.
# True means blocked, False means allowed
result = is_output_blocked(result, action_func)

return flow_id, result

except Exception as e:
log.error(f"Error executing rail {flow_id}: {e}")
return flow_id, False # Allow on error

# create tasks for all flows
for flow_id, action_info in flows_with_params.items():
task = asyncio.create_task(run_single_rail(flow_id, action_info))
tasks.append(task)

stopped_events = []

try:
for future in asyncio.as_completed(tasks):
try:
flow_id, is_blocked = await future

# check if this rail blocked the content
if is_blocked:
# create stop events
stopped_events = [
{
"type": "BotIntent",
"intent": "stop",
"flow_id": flow_id,
}
]

# cancel remaining tasks
for pending_task in tasks:
if not pending_task.done():

Collaborator commented:

If we try and cancel a task that's already done, will that throw an exception? A task could finish in between the pending_task.done() and pending_task.cancel() calls. We need to be able to cancel a task that's already done and no Exceptions to be thrown.

Collaborator (Author) replied:

> If we try and cancel a task that's already done, will that throw an exception?

Calling cancel() on a task that is already done or cancelled will not throw an exception. It simply returns False and has no effect, so there is no need to guard cancel() with a try/except. But this is a good point that we might want to document.

> A task could finish in between the pending_task.done() and pending_task.cancel() calls

That is a race condition, but it is the expected behavior; we could log a warning. For example:

  1. rails A, B, C are running in parallel
  2. rail A completes and says "blocked"
  3. while we're processing A's result, rail B also completes and says "blocked"
  4. We cancel rail C (which is still pending)
  5. We only report that A blocked, not B

Based on our assumptions:

  • not a bug: multiple rails might detect violations, but we only need to know that at least one violation occurred
  • expected behavior: the first violation detected stops the stream
  • correct outcome: content is blocked regardless of which rail detected it first

So for content-modifying rails one should use sequential execution; for read-only validation rails, parallel execution is fine.
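A minimal, self-contained sketch of both points (hypothetical rails, not the actual NeMo Guardrails actions): cancel() on a finished task is a harmless no-op that returns False, and the first "blocked" result to complete stops the rest.

import asyncio

async def rail(name: str, delay: float, blocked: bool) -> tuple:
    # stand-in for a single output rail check
    await asyncio.sleep(delay)
    return name, blocked

async def main():
    tasks = [
        asyncio.create_task(rail("A", 0.01, True)),
        asyncio.create_task(rail("B", 0.02, True)),
        asyncio.create_task(rail("C", 1.00, False)),
    ]
    for future in asyncio.as_completed(tasks):
        name, is_blocked = await future
        if is_blocked:
            for t in tasks:
                if not t.done():
                    t.cancel()  # no-op (returns False) if t finished meanwhile
            print(f"blocked by {name}")  # only the first detection is reported
            break
    # cancelling an already-done task raises nothing and returns False
    assert tasks[0].cancel() is False

asyncio.run(main())
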

pending_task.cancel()
break

except asyncio.CancelledError:
pass
except Exception as e:
log.error(f"Error in parallel rail task: {e}")
continue

except Exception as e:
log.error(f"Error in parallel rail execution: {e}")
return ActionResult(events=[])

return ActionResult(events=stopped_events)

async def _process_start_action(self, events: List[dict]) -> List[dict]:
"""
Start the specified action, wait for it to finish, and post back the result.
@@ -458,8 +543,9 @@ async def _get_action_resp(
)

resp = await resp.json()
result, status = resp.get("result", result), resp.get(
"status", status
result, status = (
resp.get("result", result),
resp.get("status", status),
)
except Exception as e:
log.info(f"Exception {e} while making request to {action_name}")
5 changes: 5 additions & 0 deletions nemoguardrails/rails/llm/config.py
@@ -455,6 +455,11 @@ class OutputRailsStreamingConfig(BaseModel):
class OutputRails(BaseModel):
"""Configuration of output rails."""

parallel: Optional[bool] = Field(
default=False,
description="If True, the output rails are executed in parallel.",
)

flows: List[str] = Field(
default_factory=list,
description="The names of all the flows that implement output rails.",
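For illustration, a config enabling the new flag might look like the sketch below (the flow name is an example, not prescribed by this PR):

from nemoguardrails import RailsConfig

config = RailsConfig.from_content(
    yaml_content="""
models:
  - type: main
    engine: openai
    model: gpt-4o

rails:
  output:
    parallel: true
    flows:
      - self check output
"""
)
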
172 changes: 135 additions & 37 deletions nemoguardrails/rails/llm/llmrails.py
Expand Up @@ -66,7 +66,7 @@
from nemoguardrails.logging.verbose import set_verbose
from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop
from nemoguardrails.rails.llm.buffer import get_buffer_strategy
from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, Model, RailsConfig
from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, RailsConfig
from nemoguardrails.rails.llm.options import (
GenerationLog,
GenerationOptions,
@@ -1351,6 +1351,32 @@ def _get_latest_user_message(
return message
return {}

def _prepare_context_for_parallel_rails(
Copilot AI commented (Jul 9, 2025):

[nitpick] Defining helper functions (_prepare_context_for_parallel_rails, _create_events_for_chunk) inside a method can make the code harder to navigate; consider extracting them to module-level helpers.

chunk_str: str,
prompt: Optional[str] = None,
messages: Optional[List[dict]] = None,
) -> dict:
"""Prepare context for parallel rails execution."""
context_message = _get_last_context_message(messages)
user_message = prompt or _get_latest_user_message(messages)

context = {
"user_message": user_message,
"bot_message": chunk_str,
}

if context_message:
context.update(context_message["content"])

Copilot AI commented on lines +1369 to +1370 (Jul 9, 2025):

Updating the context dict with a string (context_message["content"]) will iterate over its characters; instead assign the content under a specific key or ensure the updated value is a mapping.

Suggested change:
-        context.update(context_message["content"])
+        if isinstance(context_message.get("content"), dict):
+            context.update(context_message["content"])
+        else:
+            log.warning("context_message['content'] is not a dictionary and will be ignored.")
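A quick plain-Python illustration of the failure mode flagged above:

ctx = {"user_message": "hi"}
ctx.update({"relevant_chunks": "..."})  # a mapping merges as expected

try:
    ctx.update("some string content")  # non-mapping: treated as a sequence of pairs
except ValueError as e:
    print(e)  # dictionary update sequence element #0 has length 1; 2 is required
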

return context

def _create_events_for_chunk(chunk_str: str, context: dict) -> List[dict]:
"""Create events for running output rails on a chunk."""
return [
{"type": "ContextUpdate", "data": context},
{"type": "BotMessage", "text": chunk_str},
]

def _prepare_params(
flow_id: str,
action_name: str,
@@ -1404,6 +1430,8 @@ def _prepare_params(
_get_action_details_from_flow_id, flows=self.config.flows
)

parallel_mode = getattr(self.config.rails.output, "parallel", False)

async for chunk_batch in buffer_strategy(streaming_handler):
user_output_chunks = chunk_batch.user_output_chunks
# format processing_context for output rails processing (needs full context)
@@ -1427,48 +1455,118 @@
for chunk in user_output_chunks:
yield chunk

for flow_id in output_rails_flows_id:
action_name, action_params = get_action_details(flow_id)
if parallel_mode:
try:
context = _prepare_context_for_parallel_rails(
bot_response_chunk, prompt, messages
)
events = _create_events_for_chunk(bot_response_chunk, context)

flows_with_params = {}
for flow_id in output_rails_flows_id:
action_name, action_params = get_action_details(flow_id)
params = _prepare_params(
flow_id=flow_id,
action_name=action_name,
bot_response_chunk=bot_response_chunk,
prompt=prompt,
messages=messages,
action_params=action_params,
)
flows_with_params[flow_id] = {
"action_name": action_name,
"params": params,
}

result_tuple = await self.runtime.action_dispatcher.execute_action(
"run_output_rails_in_parallel_streaming",
{
"flows_with_params": flows_with_params,
"events": events,
},
)

params = _prepare_params(
flow_id=flow_id,
action_name=action_name,
bot_response_chunk=bot_response_chunk,
prompt=prompt,
messages=messages,
action_params=action_params,
)
# ActionDispatcher.execute_action always returns (result, status)
result, status = result_tuple

result = await self.runtime.action_dispatcher.execute_action(
action_name, params
)
if status != "success":
log.error(
f"Parallel rails execution failed with status: {status}"
)
# continue processing the chunk even if rails fail
pass
else:
# if there are any stop events, content was blocked
if result.events:
# extract the blocked flow from the first stop event
blocked_flow = result.events[0].get(
"flow_id", "output rails"
)

reason = f"Blocked by {blocked_flow} rails."
error_data = {
"error": {
"message": reason,
"type": "guardrails_violation",
"param": blocked_flow,
"code": "content_blocked",
}
}
yield json.dumps(error_data)
return

except Exception as e:
log.error(f"Error in parallel rail execution: {e}")
# don't block the stream for rail execution errors
# continue processing the chunk
pass

# update explain info for parallel mode
self.explain_info = self._ensure_explain_info()

action_func = self.runtime.action_dispatcher.get_action(action_name)

# Use the mapping to decide if the result indicates blocked content.
if is_output_blocked(result, action_func):
reason = f"Blocked by {flow_id} rails."

# return the error as a plain JSON string (not in SSE format)
# NOTE: When integrating with the OpenAI Python client, the server code should:
# 1. detect this JSON error object in the stream
# 2. terminate the stream
# 3. format the error following OpenAI's SSE format
# the OpenAI client will then properly raise an APIError with this error message

error_data = {
"error": {
"message": reason,
"type": "guardrails_violation",
"param": flow_id,
"code": "content_blocked",
else:
for flow_id in output_rails_flows_id:
action_name, action_params = get_action_details(flow_id)

params = _prepare_params(
flow_id=flow_id,
action_name=action_name,
bot_response_chunk=bot_response_chunk,
prompt=prompt,
messages=messages,
action_params=action_params,
)

result = await self.runtime.action_dispatcher.execute_action(
action_name, params
)
self.explain_info = self._ensure_explain_info()

action_func = self.runtime.action_dispatcher.get_action(action_name)

# Use the mapping to decide if the result indicates blocked content.
if is_output_blocked(result, action_func):
reason = f"Blocked by {flow_id} rails."

# return the error as a plain JSON string (not in SSE format)
# NOTE: When integrating with the OpenAI Python client, the server code should:
# 1. detect this JSON error object in the stream
# 2. terminate the stream
# 3. format the error following OpenAI's SSE format
# the OpenAI client will then properly raise an APIError with this error message

error_data = {
"error": {
"message": reason,
"type": "guardrails_violation",
"param": flow_id,
"code": "content_blocked",
}
}
}

# return as plain JSON: the server should detect this JSON and convert it to an HTTP error
yield json.dumps(error_data)
return
# return as plain JSON: the server should detect this JSON and convert it to an HTTP error
yield json.dumps(error_data)
return

if not stream_first:
# yield the individual chunks directly from the buffer strategy
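As a rough sketch of the server-side handling described in the NOTE above (this wrapper is an assumption for illustration, not part of the PR): detect the plain-JSON error object in the stream and re-emit it in OpenAI's SSE format before terminating.

import json

async def to_sse(chunks):
    """Wrap a guardrails text stream as SSE, surfacing blocked-content errors."""
    async for chunk in chunks:
        obj = None
        try:
            obj = json.loads(chunk)
        except (json.JSONDecodeError, TypeError):
            pass
        if isinstance(obj, dict) and "error" in obj:
            # 1. detected the JSON error object -> 2. terminate -> 3. SSE format
            yield f"data: {json.dumps(obj)}\n\n"
            return
        yield f"data: {json.dumps({'choices': [{'delta': {'content': chunk}}]})}\n\n"
    yield "data: [DONE]\n\n"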