
Commit 8fdb27d

feat: support parallel rails execution (#1234)
1 parent: cd14a07

12 files changed: +602 -52 lines

12 files changed

+602
-52
lines changed

nemoguardrails/colang/runtime.py

Lines changed: 15 additions & 0 deletions

@@ -43,6 +43,21 @@ def __init__(self, config: RailsConfig, verbose: bool = False):
                 name="run_output_rails_in_parallel_streaming",
             )
 
+        if hasattr(self, "_run_flows_in_parallel"):
+            self.action_dispatcher.register_action(
+                self._run_flows_in_parallel, name="run_flows_in_parallel"
+            )
+
+        if hasattr(self, "_run_input_rails_in_parallel"):
+            self.action_dispatcher.register_action(
+                self._run_input_rails_in_parallel, name="run_input_rails_in_parallel"
+            )
+
+        if hasattr(self, "_run_output_rails_in_parallel"):
+            self.action_dispatcher.register_action(
+                self._run_output_rails_in_parallel, name="run_output_rails_in_parallel"
+            )
+
         # The list of additional parameters that can be passed to the actions.
         self.registered_action_params: dict = {}

nemoguardrails/colang/v1_0/runtime/flows.py

Lines changed: 17 additions & 4 deletions

@@ -455,7 +455,10 @@ def compute_next_state(state: State, event: dict) -> State:
     # Next, we try to start new flows
     for flow_config in state.flow_configs.values():
         # We don't allow subflow to start on their own
-        if flow_config.is_subflow:
+        # Unless there's an explicit start_flow event
+        if flow_config.is_subflow and (
+            event["type"] != "start_flow" or flow_config.id != event["flow_id"]
+        ):
             continue
 
         # If the flow can't be started multiple times in parallel and

@@ -468,12 +471,22 @@ def compute_next_state(state: State, event: dict) -> State:
         # We try to slide first, just in case a flow starts with sliding logic
         start_head = slide(new_state, flow_config, 0)
 
-        # If the first element matches the current event, we start a new flow
-        if _is_match(flow_config.elements[start_head], event):
+        # If the first element matches the current event,
+        # or, if the flow is explicitly started by a `start_flow` event,
+        # we start a new flow
+        _is_start_match = _is_match(flow_config.elements[start_head], event)
+        if _is_start_match or (
+            event["type"] == "start_flow" and flow_config.id == event["flow_id"]
+        ):
             flow_uid = new_uuid()
             flow_state = FlowState(
-                uid=flow_uid, flow_id=flow_config.id, head=start_head + 1
+                uid=flow_uid,
+                flow_id=flow_config.id,
+                # When we have a match, we skip the matched element and move the head to the next one
+                head=start_head + (1 if _is_start_match else 0),
             )
+            if params := event.get("params"):
+                new_state.context_updates.update(params)
             new_state.flow_states.append(flow_state)
 
             _slide_with_subflows(new_state, flow_state)
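
Two behaviors fall out of this change: a subflow can now be started explicitly by a `start_flow` event whose `flow_id` matches it, and because no flow element is consumed by such an explicit start, the head stays on element 0 (`start_head + 0`). Any `params` carried by the event are merged into the context. A sketch of the event shape the new branch accepts (flow id and params are illustrative):

    # A start_flow event that explicitly starts the named (sub)flow.
    # "params" is optional; when present, compute_next_state merges it
    # into new_state.context_updates (the `event.get("params")` walrus above).
    start_flow_event = {
        "type": "start_flow",
        "flow_id": "self check input",  # illustrative flow id
        "params": {"threshold": 0.8},   # illustrative parameters
    }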

nemoguardrails/colang/v1_0/runtime/runtime.py

Lines changed: 224 additions & 14 deletions
@@ -15,6 +15,7 @@
 import asyncio
 import inspect
 import logging
+import uuid
 from textwrap import indent
 from time import time
 from typing import Any, Dict, List, Optional, Tuple
@@ -24,6 +25,7 @@
 from langchain.chains.base import Chain
 
 from nemoguardrails.actions.actions import ActionResult
+from nemoguardrails.actions.core import create_event
 from nemoguardrails.actions.output_mapping import is_output_blocked
 from nemoguardrails.colang import parse_colang_file
 from nemoguardrails.colang.runtime import Runtime
@@ -169,7 +171,7 @@ async def generate_events(
                 next_events = await self._process_start_action(events)
 
             # If we need to start a flow, we parse the content and register it.
-            elif last_event["type"] == "start_flow":
+            elif last_event["type"] == "start_flow" and last_event.get("flow_body"):
                 next_events = await self._process_start_flow(
                     events, processing_log=processing_log
                 )
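
The added `flow_body` check splits `start_flow` handling into two paths: events carrying a `flow_body` are parsed and registered as new dynamic flows by `_process_start_flow`, while bare `start_flow` events fall through to state computation, where the patched `compute_next_state` (flows.py above) starts the named, already-registered flow. A sketch of the two event variants (field values illustrative):

    # Variant 1: carries Colang source; handled by _process_start_flow.
    define_and_start = {
        "type": "start_flow",
        "flow_body": "define flow greeting\n  bot express greeting",  # illustrative
    }

    # Variant 2: no flow_body; starts an existing flow via compute_next_state.
    start_existing = {"type": "start_flow", "flow_id": "self check input"}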
@@ -189,18 +191,30 @@ async def generate_events(
             new_events.extend(next_events)
 
             for event in next_events:
-                processing_log.append(
-                    {"type": "event", "timestamp": time(), "data": event}
-                )
+                if event["type"] != "EventHistoryUpdate":
+                    processing_log.append(
+                        {"type": "event", "timestamp": time(), "data": event}
+                    )
 
             # If the next event is a listen, we stop the processing.
             if next_events[-1]["type"] == "Listen":
                 break
 
             # As a safety measure, we stop the processing if we have too many events.
-            if len(new_events) > 100:
+            if len(new_events) > 300:
                 raise Exception("Too many events.")
 
+        # Unpack and inline the events from any EventHistoryUpdate events
+        temp_events = []
+        for event in new_events:
+            if event["type"] == "EventHistoryUpdate":
+                temp_events.extend(
+                    [e for e in event["data"]["events"] if e["type"] != "Listen"]
+                )
+            else:
+                temp_events.append(event)
+        new_events = temp_events
+
         return new_events
 
     async def _compute_next_steps(
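
The `EventHistoryUpdate` handling works in two halves: parallel flows return their events packed inside a single `EventHistoryUpdate` event (see `_run_flows_in_parallel` below), and the unpacking loop above splices those events back into `new_events` while dropping the intermediate `Listen` events that each sub-loop produced. A runnable sketch of just the unpacking step, on plain dicts (event payloads illustrative):

    new_events = [
        {"type": "UtteranceUserActionFinished", "final_transcript": "hi"},
        {
            "type": "EventHistoryUpdate",
            "data": {
                "events": [
                    {"type": "StartInputRail", "flow_id": "check input"},  # illustrative
                    {"type": "Listen"},  # dropped during unpacking
                    {"type": "InputRailFinished", "flow_id": "check input"},
                ]
            },
        },
    ]

    temp_events = []
    for event in new_events:
        if event["type"] == "EventHistoryUpdate":
            # Splice in the packed events, minus the Listen markers.
            temp_events.extend(
                e for e in event["data"]["events"] if e["type"] != "Listen"
            )
        else:
            temp_events.append(event)
    # temp_events now holds the rail events inline, without the Listen.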
@@ -261,6 +275,210 @@ def _internal_error_action_result(message: str):
             ]
         )
 
+    async def _run_flows_in_parallel(
+        self,
+        flows: List[str],
+        events: List[dict],
+        pre_events: Optional[List[dict]] = None,
+        post_events: Optional[List[dict]] = None,
+    ) -> ActionResult:
+        """
+        Run flows in parallel.
+
+        Running flows in parallel is done by triggering a separate event loop with a
+        `start_flow` event for each flow, in the context of the current event loop.
+
+        Args:
+            flows (List[str]): The list of flow names to run in parallel.
+            events (List[dict]): The current events.
+            pre_events (List[dict], optional): Events to be added before starting each flow.
+            post_events (List[dict], optional): Events to be added after finishing each flow.
+        """
+
+        if pre_events is not None and len(pre_events) != len(flows):
+            raise ValueError("Number of pre-events must match number of flows.")
+        if post_events is not None and len(post_events) != len(flows):
+            raise ValueError("Number of post-events must match number of flows.")
+
+        unique_flow_ids = {}  # Keep track of the order of unique flow IDs
+        task_results: Dict[str, List] = {}  # Store results keyed by flow_id
+        task_processing_logs: dict = {}  # Store the processing log of each flow
+
+        # Wrapper function to help reverse map the task result to the flow ID
+        async def task_call_helper(flow_uid, post_event, func, *args, **kwargs):
+            result = await func(*args, **kwargs)
+            if post_event:
+                result.append(post_event)
+                args[1].append(
+                    {"type": "event", "timestamp": time(), "data": post_event}
+                )
+            return flow_uid, result
+
+        # Create a task for each flow but don't await them yet
+        tasks = []
+        for index, flow_name in enumerate(flows):
+            # Copy the events to avoid modifying the original list
+            _events = events.copy()
+
+            flow_params = _get_flow_params(flow_name)
+            flow_id = _normalize_flow_id(flow_name)
+
+            if flow_params:
+                _events.append(
+                    {"type": "start_flow", "flow_id": flow_id, "params": flow_params}
+                )
+            else:
+                _events.append({"type": "start_flow", "flow_id": flow_id})
+
+            # Generate a unique flow ID
+            flow_uid = f"{flow_id}:{str(uuid.uuid4())}"
+
+            # Initialize task results and processing logs for this flow
+            task_results[flow_uid] = []
+            task_processing_logs[flow_uid] = []
+
+            # Add pre-event if provided
+            if pre_events:
+                task_results[flow_uid].append(pre_events[index])
+                task_processing_logs[flow_uid].append(
+                    {"type": "event", "timestamp": time(), "data": pre_events[index]}
+                )
+
+            task = asyncio.create_task(
+                task_call_helper(
+                    flow_uid,
+                    post_events[index] if post_events else None,
+                    self.generate_events,
+                    _events,
+                    task_processing_logs[flow_uid],
+                )
+            )
+            tasks.append(task)
+            unique_flow_ids[flow_uid] = task
+
+        stopped_task_results: List[dict] = []
+
+        # Process tasks as they complete using as_completed
+        try:
+            for future in asyncio.as_completed(tasks):
+                try:
+                    (flow_id, result) = await future
+
+                    # Check if this rail requested to stop
+                    has_stop = any(
+                        event["type"] == "BotIntent" and event["intent"] == "stop"
+                        for event in result
+                    )
+
+                    # If this flow had a stop event
+                    if has_stop:
+                        stopped_task_results = task_results[flow_id] + result
+
+                        # Cancel all remaining tasks
+                        for pending_task in tasks:
+                            # Don't include results and processing logs for cancelled or stopped tasks
+                            if (
+                                pending_task != unique_flow_ids[flow_id]
+                                and not pending_task.done()
+                            ):
+                                # Cancel the task if it is not done
+                                pending_task.cancel()
+                                # Find the flow_uid for this task and remove it from the dict
+                                for k, v in list(unique_flow_ids.items()):
+                                    if v == pending_task:
+                                        del unique_flow_ids[k]
+                                        break
+                        del unique_flow_ids[flow_id]
+                        break
+                    else:
+                        # Store the result for this specific flow
+                        task_results[flow_id].extend(result)
+
+                except asyncio.exceptions.CancelledError:
+                    pass
+
+        except Exception as e:
+            log.error(f"Error in parallel rail execution: {str(e)}")
+            raise
+
+        context_updates: dict = {}
+        processing_log = processing_log_var.get()
+
+        finished_task_processing_logs: List[dict] = []  # Collect all logs in order
+        finished_task_results: List[dict] = []  # Collect all results in order
+
+        # Compose results in original flow order of all completed tasks
+        for flow_id in unique_flow_ids:
+            result = task_results[flow_id]
+
+            # Extract context updates
+            for event in result:
+                if event["type"] == "ContextUpdate":
+                    context_updates = {**context_updates, **event["data"]}
+
+            finished_task_results.extend(result)
+            finished_task_processing_logs.extend(task_processing_logs[flow_id])
+
+        if processing_log:
+            for plog in finished_task_processing_logs:
+                # Filter out "Listen" and "start_flow" events from the task processing log
+                if plog["type"] == "event" and (
+                    plog["data"]["type"] == "Listen"
+                    or plog["data"]["type"] == "start_flow"
+                ):
+                    continue
+                processing_log.append(plog)
+
+        # We pack all events into a single event to add to the event history.
+        history_events = new_event_dict(
+            "EventHistoryUpdate",
+            data={"events": finished_task_results},
+        )
+
+        return ActionResult(
+            events=[history_events] + stopped_task_results,
+            context_updates=context_updates,
+        )
+
+    async def _run_input_rails_in_parallel(
+        self, flows: List[str], events: List[dict]
+    ) -> ActionResult:
+        """Run the input rails in parallel."""
+        pre_events = [
+            (await create_event({"_type": "StartInputRail", "flow_id": flow})).events[0]
+            for flow in flows
+        ]
+        post_events = [
+            (await create_event({"_type": "InputRailFinished", "flow_id": flow})).events[0]
+            for flow in flows
+        ]
+
+        return await self._run_flows_in_parallel(
+            flows=flows, events=events, pre_events=pre_events, post_events=post_events
+        )
+
+    async def _run_output_rails_in_parallel(
+        self, flows: List[str], events: List[dict]
+    ) -> ActionResult:
+        """Run the output rails in parallel."""
+        pre_events = [
+            (await create_event({"_type": "StartOutputRail", "flow_id": flow})).events[0]
+            for flow in flows
+        ]
+        post_events = [
+            (await create_event({"_type": "OutputRailFinished", "flow_id": flow})).events[0]
+            for flow in flows
+        ]
+
+        return await self._run_flows_in_parallel(
+            flows=flows, events=events, pre_events=pre_events, post_events=post_events
+        )
+
     async def _run_output_rails_in_parallel_streaming(
         self, flows_with_params: Dict[str, dict], events: List[dict]
     ) -> ActionResult:
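
The completion loop above implements a first-stop-wins policy: all rails run concurrently, and the first finished rail whose result contains a `BotIntent` event with intent `stop` causes every still-pending sibling task to be cancelled, while results of tasks that completed normally are later re-assembled in the original flow order via `unique_flow_ids`. A self-contained sketch of that cancellation pattern, with the guardrails bookkeeping stripped out (rail names and delays are illustrative):

    import asyncio


    async def rail(name, delay, stops):
        # Stand-in for one guardrail flow run via generate_events.
        await asyncio.sleep(delay)
        return name, ([{"type": "BotIntent", "intent": "stop"}] if stops else [])


    async def main():
        tasks = [
            asyncio.create_task(rail("fast-allow", 0.1, stops=False)),
            asyncio.create_task(rail("slow-block", 0.2, stops=True)),
            asyncio.create_task(rail("very-slow", 5.0, stops=False)),
        ]
        for future in asyncio.as_completed(tasks):
            try:
                name, events = await future
            except asyncio.CancelledError:
                continue
            if any(e["type"] == "BotIntent" and e["intent"] == "stop" for e in events):
                for t in tasks:
                    if not t.done():
                        t.cancel()  # the 5-second rail never finishes
                print(f"{name} requested stop; remaining rails cancelled")
                break


    asyncio.run(main())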
@@ -472,15 +690,7 @@ async def _process_start_action(self, events: List[dict]) -> List[dict]:
         next_steps = []
 
         if context_updates:
-            # We check if at least one key changed
-            changes = False
-            for k, v in context_updates.items():
-                if context.get(k) != v:
-                    changes = True
-                    break
-
-            if changes:
-                next_steps.append(new_event_dict("ContextUpdate", data=context_updates))
+            next_steps.append(new_event_dict("ContextUpdate", data=context_updates))
 
         next_steps.append(
             new_event_dict(

nemoguardrails/logging/processing_log.py

Lines changed: 6 additions & 1 deletion

@@ -36,7 +36,12 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
     generation_log = GenerationLog()
 
     # The list of actions to ignore during the processing.
-    ignored_actions = ["create_event"]
+    ignored_actions = [
+        "create_event",
+        "run_input_rails_in_parallel",
+        "run_output_rails_in_parallel",
+        "run_flows_in_parallel",
+    ]
     ignored_flows = [
         "process user input",
         "run input rails",

nemoguardrails/rails/llm/config.py

Lines changed: 5 additions & 0 deletions

@@ -425,6 +425,11 @@ class CoreConfig(BaseModel):
 class InputRails(BaseModel):
     """Configuration of input rails."""
 
+    parallel: Optional[bool] = Field(
+        default=False,
+        description="If True, the input rails are executed in parallel.",
+    )
+
     flows: List[str] = Field(
         default_factory=list,
         description="The names of all the flows that implement input rails.",
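
Since `parallel` is a plain Pydantic field on `InputRails`, the flag can be exercised directly in Python; presumably it is normally set from the `rails.input` section of a project's `config.yml`. A minimal sketch (flow names illustrative):

    from nemoguardrails.rails.llm.config import InputRails

    # Enable parallel execution for two illustrative input rails.
    input_rails = InputRails(
        parallel=True,
        flows=["self check input", "check jailbreak"],
    )
    print(input_rails.parallel)  # True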
