Skip to content

feat: parallel rails execution #1234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into the base branch `develop` from the feature branch.
Open
15 changes: 15 additions & 0 deletions nemoguardrails/colang/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,21 @@ def __init__(self, config: RailsConfig, verbose: bool = False):
import_paths=list(config.imported_paths.values()),
)

if hasattr(self, "_run_flows_in_parallel"):
self.action_dispatcher.register_action(
self._run_flows_in_parallel, name="run_flows_in_parallel"
)

if hasattr(self, "_run_input_rails_in_parallel"):
self.action_dispatcher.register_action(
self._run_input_rails_in_parallel, name="run_input_rails_in_parallel"
)

if hasattr(self, "_run_output_rails_in_parallel"):
self.action_dispatcher.register_action(
self._run_output_rails_in_parallel, name="run_output_rails_in_parallel"
)

# The list of additional parameters that can be passed to the actions.
self.registered_action_params: dict = {}

Expand Down
21 changes: 17 additions & 4 deletions nemoguardrails/colang/v1_0/runtime/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,10 @@ def compute_next_state(state: State, event: dict) -> State:
# Next, we try to start new flows
for flow_config in state.flow_configs.values():
# We don't allow subflow to start on their own
if flow_config.is_subflow:
# Unless there's an explicit start_flow event
if flow_config.is_subflow and (
event["type"] != "start_flow" or flow_config.id != event["flow_id"]
):
continue

# If the flow can't be started multiple times in parallel and
Expand All @@ -468,12 +471,22 @@ def compute_next_state(state: State, event: dict) -> State:
# We try to slide first, just in case a flow starts with sliding logic
start_head = slide(new_state, flow_config, 0)

# If the first element matches the current event, we start a new flow
if _is_match(flow_config.elements[start_head], event):
# If the first element matches the current event,
# or, if the flow is explicitly started by a `start_flow` event,
# we start a new flow
_is_start_match = _is_match(flow_config.elements[start_head], event)
if _is_start_match or (
event["type"] == "start_flow" and flow_config.id == event["flow_id"]
):
flow_uid = new_uuid()
flow_state = FlowState(
uid=flow_uid, flow_id=flow_config.id, head=start_head + 1
uid=flow_uid,
flow_id=flow_config.id,
# When we have a match, we skip the element that was matched and move the head to the next one
head=start_head + (1 if _is_start_match else 0),
)
if params := event.get("params"):
new_state.context_updates.update(params)
new_state.flow_states.append(flow_state)

_slide_with_subflows(new_state, flow_state)
Expand Down
241 changes: 226 additions & 15 deletions nemoguardrails/colang/v1_0/runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import inspect
import logging
import uuid
Expand All @@ -25,10 +25,13 @@
from langchain.chains.base import Chain

from nemoguardrails.actions.actions import ActionResult
from nemoguardrails.actions.core import create_event
from nemoguardrails.colang import parse_colang_file
from nemoguardrails.colang.runtime import Runtime
from nemoguardrails.colang.v1_0.runtime.flows import (
FlowConfig,
_get_flow_params,
_normalize_flow_id,
compute_context,
compute_next_steps,
)
Expand Down Expand Up @@ -167,7 +170,7 @@ async def generate_events(
next_events = await self._process_start_action(events)

# If we need to start a flow, we parse the content and register it.
elif last_event["type"] == "start_flow":
elif last_event["type"] == "start_flow" and last_event.get("flow_body"):
next_events = await self._process_start_flow(
events, processing_log=processing_log
)
Expand All @@ -187,18 +190,30 @@ async def generate_events(
new_events.extend(next_events)

for event in next_events:
processing_log.append(
{"type": "event", "timestamp": time(), "data": event}
)
if event["type"] != "EventHistoryUpdate":
processing_log.append(
{"type": "event", "timestamp": time(), "data": event}
)

# If the next event is a listen, we stop the processing.
if next_events[-1]["type"] == "Listen":
break

# As a safety measure, we stop the processing if we have too many events.
if len(new_events) > 100:
if len(new_events) > 300:
raise Exception("Too many events.")

# Unpack and insert events in event history update event if available
temp_events = []
for event in new_events:
if event["type"] == "EventHistoryUpdate":
temp_events.extend(
[e for e in event["data"]["events"] if e["type"] != "Listen"]
)
else:
temp_events.append(event)
new_events = temp_events

return new_events

async def _compute_next_steps(
Expand Down Expand Up @@ -259,6 +274,210 @@ def _internal_error_action_result(message: str):
]
)

async def _run_flows_in_parallel(
    self,
    flows: List[str],
    events: List[dict],
    pre_events: Optional[List[dict]] = None,
    post_events: Optional[List[dict]] = None,
) -> ActionResult:
    """
    Run flows in parallel.

    Running flows in parallel is done by triggering a separate event loop with a
    `start_flow` event for each flow, in the context of the current event loop.

    If any flow produces a `BotIntent` event with intent "stop", the remaining
    unfinished flows are cancelled and only the stopping flow's events are
    returned as regular events; the events of all flows that completed normally
    are packed into a single `EventHistoryUpdate` event.

    Args:
        flows (List[str]): The list of flow names to run in parallel. A name may
            carry inline parameters, which are split off via `_get_flow_params`.
        events (List[dict]): The current events.
        pre_events (List[dict], optional): Events to be added before starting each flow.
        post_events (List[dict], optional): Events to be added after finishing each flow.

    Returns:
        ActionResult: the packed event history plus any stop events, and the
        merged context updates from all completed flows.

    Raises:
        ValueError: if `pre_events`/`post_events` length does not match `flows`.
    """

    # Pre/post event lists are positional: entry i belongs to flows[i].
    if pre_events is not None and len(pre_events) != len(flows):
        raise ValueError("Number of pre-events must match number of flows.")
    if post_events is not None and len(post_events) != len(flows):
        raise ValueError("Number of post-events must match number of flows.")

    # NOTE: despite the name, keys in all three dicts are *unique* flow UIDs
    # ("<flow_id>:<uuid4>"), not plain flow ids. Insertion order of
    # `unique_flow_ids` preserves the original order of `flows`.
    unique_flow_ids = {}  # Keep track of unique flow IDs order
    task_results: Dict[str, List] = {}  # Store results keyed by flow_id
    task_processing_logs: dict = {}  # Store resulting processing logs for each flow

    # Wrapper function to help reverse map the task result to the flow ID
    async def task_call_helper(flow_uid, post_event, func, *args, **kwargs):
        result = await func(*args, **kwargs)
        if post_event:
            result.append(post_event)
            # args[1] aliases task_processing_logs[flow_uid] (the second
            # positional arg passed to `func` below), so the post event is
            # also recorded in this flow's processing log.
            args[1].append(
                {"type": "event", "timestamp": time(), "data": post_event}
            )
        return flow_uid, result

    # Create a task for each flow but don't await them yet
    tasks = []
    for index, flow_name in enumerate(flows):
        # Copy the events to avoid modifying the original list
        _events = events.copy()

        flow_params = _get_flow_params(flow_name)
        flow_id = _normalize_flow_id(flow_name)

        # Each task gets its own event list ending with the `start_flow`
        # event that kicks off the target flow.
        if flow_params:
            _events.append(
                {"type": "start_flow", "flow_id": flow_id, "params": flow_params}
            )
        else:
            _events.append({"type": "start_flow", "flow_id": flow_id})

        # Generate a unique flow ID (the same flow may appear more than once).
        flow_uid = f"{flow_id}:{str(uuid.uuid4())}"

        # Initialize task results and processing logs for this flow
        task_results[flow_uid] = []
        task_processing_logs[flow_uid] = []

        # Add pre-event if provided
        if pre_events:
            task_results[flow_uid].append(pre_events[index])
            task_processing_logs[flow_uid].append(
                {"type": "event", "timestamp": time(), "data": pre_events[index]}
            )

        task = asyncio.create_task(
            task_call_helper(
                flow_uid,
                post_events[index] if post_events else None,
                self.generate_events,
                _events,
                task_processing_logs[flow_uid],
            )
        )
        tasks.append(task)
        unique_flow_ids[flow_uid] = task

    # Events of the flow (if any) that requested a stop; returned unpacked so
    # the outer runtime reacts to them immediately.
    stopped_task_results: List[dict] = []

    # Process tasks as they complete using as_completed
    try:
        for future in asyncio.as_completed(tasks):
            try:
                # NOTE: `flow_id` here is actually the unique flow UID
                # returned by task_call_helper, shadowing the loop variable
                # used when creating the tasks above.
                (flow_id, result) = await future

                # Check if this rail requested to stop
                has_stop = any(
                    event["type"] == "BotIntent" and event["intent"] == "stop"
                    for event in result
                )

                # If this flow had a stop event
                if has_stop:
                    # Prepend any pre-event already stored for this flow.
                    stopped_task_results = task_results[flow_id] + result

                    # Cancel all remaining tasks
                    for pending_task in tasks:
                        # Don't include results and processing logs for cancelled or stopped tasks
                        if (
                            pending_task != unique_flow_ids[flow_id]
                            and not pending_task.done()
                        ):
                            # Cancel the task if it is not done
                            pending_task.cancel()
                            # Find the flow_uid for this task and remove it from the dict
                            for k, v in list(unique_flow_ids.items()):
                                if v == pending_task:
                                    del unique_flow_ids[k]
                                    break
                    # The stopping flow is also removed: its events are
                    # returned via `stopped_task_results`, not the history.
                    del unique_flow_ids[flow_id]
                    break
                else:
                    # Store the result for this specific flow
                    task_results[flow_id].extend(result)

            except asyncio.exceptions.CancelledError:
                # Cancelled flows are intentionally dropped without results.
                pass

    except Exception as e:
        log.error(f"Error in parallel rail execution: {str(e)}")
        raise

    context_updates: dict = {}
    processing_log = processing_log_var.get()

    finished_task_processing_logs: List[dict] = []  # Collect all results in order
    finished_task_results: List[dict] = []  # Collect all results in order

    # Compose results in original flow order of all completed tasks
    # (`unique_flow_ids` now only contains flows that finished normally).
    for flow_id in unique_flow_ids:
        result = task_results[flow_id]

        # Extract context updates; later flows override earlier ones on key clash.
        for event in result:
            if event["type"] == "ContextUpdate":
                context_updates = {**context_updates, **event["data"]}

        finished_task_results.extend(result)
        finished_task_processing_logs.extend(task_processing_logs[flow_id])

    if processing_log:
        for plog in finished_task_processing_logs:
            # Filter out "Listen" and "start_flow" events from task processing log
            if plog["type"] == "event" and (
                plog["data"]["type"] == "Listen"
                or plog["data"]["type"] == "start_flow"
            ):
                continue
            processing_log.append(plog)

    # We pack all events into a single event to add it to the event history.
    history_events = new_event_dict(
        "EventHistoryUpdate",
        data={"events": finished_task_results},
    )

    return ActionResult(
        events=[history_events] + stopped_task_results,
        context_updates=context_updates,
    )

async def _run_input_rails_in_parallel(
    self, flows: List[str], events: List[dict]
) -> ActionResult:
    """Execute the configured input rail flows concurrently.

    Each flow is bracketed by marker events: a `StartInputRail` event is
    recorded before the flow runs and an `InputRailFinished` event after it
    completes.

    Args:
        flows: Names of the input rail flows to run in parallel.
        events: The current event history.

    Returns:
        The combined ActionResult produced by `_run_flows_in_parallel`.
    """
    # Build the "started" marker events, one per flow, in flow order.
    pre_events = []
    for flow in flows:
        created = await create_event({"_type": "StartInputRail", "flow_id": flow})
        pre_events.append(created.events[0])

    # Build the matching "finished" marker events, one per flow.
    post_events = []
    for flow in flows:
        created = await create_event({"_type": "InputRailFinished", "flow_id": flow})
        post_events.append(created.events[0])

    return await self._run_flows_in_parallel(
        flows=flows, events=events, pre_events=pre_events, post_events=post_events
    )

async def _run_output_rails_in_parallel(
    self, flows: List[str], events: List[dict]
) -> ActionResult:
    """Execute the configured output rail flows concurrently.

    Each flow is bracketed by marker events: a `StartOutputRail` event is
    recorded before the flow runs and an `OutputRailFinished` event after it
    completes.

    Args:
        flows: Names of the output rail flows to run in parallel.
        events: The current event history.

    Returns:
        The combined ActionResult produced by `_run_flows_in_parallel`.
    """
    # Build the "started" marker events, one per flow, in flow order.
    pre_events = []
    for flow in flows:
        created = await create_event({"_type": "StartOutputRail", "flow_id": flow})
        pre_events.append(created.events[0])

    # Build the matching "finished" marker events, one per flow.
    post_events = []
    for flow in flows:
        created = await create_event({"_type": "OutputRailFinished", "flow_id": flow})
        post_events.append(created.events[0])

    return await self._run_flows_in_parallel(
        flows=flows, events=events, pre_events=pre_events, post_events=post_events
    )

async def _process_start_action(self, events: List[dict]) -> List[dict]:
"""
Start the specified action, wait for it to finish, and post back the result.
Expand Down Expand Up @@ -387,15 +606,7 @@ async def _process_start_action(self, events: List[dict]) -> List[dict]:
next_steps = []

if context_updates:
# We check if at least one key changed
changes = False
for k, v in context_updates.items():
if context.get(k) != v:
changes = True
break

if changes:
next_steps.append(new_event_dict("ContextUpdate", data=context_updates))
next_steps.append(new_event_dict("ContextUpdate", data=context_updates))

next_steps.append(
new_event_dict(
Expand Down
7 changes: 6 additions & 1 deletion nemoguardrails/logging/processing_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
generation_log = GenerationLog()

# The list of actions to ignore during the processing.
ignored_actions = ["create_event"]
ignored_actions = [
"create_event",
"run_input_rails_in_parallel",
"run_output_rails_in_parallel",
"run_flows_in_parallel",
]
ignored_flows = [
"process user input",
"run input rails",
Expand Down
10 changes: 10 additions & 0 deletions nemoguardrails/rails/llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,11 @@ class CoreConfig(BaseModel):
class InputRails(BaseModel):
"""Configuration of input rails."""

parallel: Optional[bool] = Field(
default=False,
description="If True, the input rails are executed in parallel.",
)

flows: List[str] = Field(
default_factory=list,
description="The names of all the flows that implement input rails.",
Expand Down Expand Up @@ -454,6 +459,11 @@ class OutputRailsStreamingConfig(BaseModel):
class OutputRails(BaseModel):
"""Configuration of output rails."""

parallel: Optional[bool] = Field(
default=False,
description="If True, the output rails are executed in parallel.",
)

flows: List[str] = Field(
default_factory=list,
description="The names of all the flows that implement output rails.",
Expand Down
Loading