@@ -15,6 +15,7 @@
 import asyncio
 import inspect
 import logging
+import uuid
 from textwrap import indent
 from time import time
 from typing import Any, Dict, List, Optional, Tuple
@@ -24,6 +25,7 @@
 from langchain.chains.base import Chain
 
 from nemoguardrails.actions.actions import ActionResult
+from nemoguardrails.actions.core import create_event
 from nemoguardrails.actions.output_mapping import is_output_blocked
 from nemoguardrails.colang import parse_colang_file
 from nemoguardrails.colang.runtime import Runtime
@@ -169,7 +171,7 @@ async def generate_events(
                 next_events = await self._process_start_action(events)
 
             # If we need to start a flow, we parse the content and register it.
-            elif last_event["type"] == "start_flow":
+            elif last_event["type"] == "start_flow" and last_event.get("flow_body"):
                 next_events = await self._process_start_flow(
                     events, processing_log=processing_log
                 )
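Note: with the added `flow_body` guard, a bare `start_flow` event no longer triggers flow parsing and registration. A minimal sketch with hypothetical event dicts:

    e1 = {"type": "start_flow", "flow_id": "greet"}  # no body: no longer parsed
    e2 = {"type": "start_flow", "flow_id": "greet", "flow_body": "define flow greet\n  ..."}
    print(bool(e1.get("flow_body")), bool(e2.get("flow_body")))  # False True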
@@ -189,18 +191,30 @@ async def generate_events(
             new_events.extend(next_events)
 
             for event in next_events:
-                processing_log.append(
-                    {"type": "event", "timestamp": time(), "data": event}
-                )
+                if event["type"] != "EventHistoryUpdate":
+                    processing_log.append(
+                        {"type": "event", "timestamp": time(), "data": event}
+                    )
 
             # If the next event is a listen, we stop the processing.
             if next_events[-1]["type"] == "Listen":
                 break
 
             # As a safety measure, we stop the processing if we have too many events.
-            if len(new_events) > 100:
+            if len(new_events) > 300:
                 raise Exception("Too many events.")
 
+        # Unpack and insert events in event history update event if available
+        temp_events = []
+        for event in new_events:
+            if event["type"] == "EventHistoryUpdate":
+                temp_events.extend(
+                    [e for e in event["data"]["events"] if e["type"] != "Listen"]
+                )
+            else:
+                temp_events.append(event)
+        new_events = temp_events
+
         return new_events
 
     async def _compute_next_steps(
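For reference, a self-contained sketch of the unpacking step above; the event payloads are hypothetical:

    new_events = [
        {"type": "UtteranceUserActionFinished", "final_transcript": "hi"},
        {
            "type": "EventHistoryUpdate",
            "data": {
                "events": [
                    {"type": "StartInputRail", "flow_id": "self check input"},
                    {"type": "Listen"},  # Listen events are dropped when unpacking
                ]
            },
        },
    ]
    temp_events = []
    for event in new_events:
        if event["type"] == "EventHistoryUpdate":
            # Flatten the packed history into the main event stream
            temp_events.extend(
                e for e in event["data"]["events"] if e["type"] != "Listen"
            )
        else:
            temp_events.append(event)
    new_events = temp_events
    print([e["type"] for e in new_events])
    # ['UtteranceUserActionFinished', 'StartInputRail']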
@@ -261,6 +275,210 @@ def _internal_error_action_result(message: str):
             ]
         )
 
+    async def _run_flows_in_parallel(
+        self,
+        flows: List[str],
+        events: List[dict],
+        pre_events: Optional[List[dict]] = None,
+        post_events: Optional[List[dict]] = None,
+    ) -> ActionResult:
+        """
+        Run flows in parallel.
+
+        Running flows in parallel is done by triggering a separate event loop with
+        a `start_flow` event for each flow, in the context of the current event loop.
+
+        Args:
+            flows (List[str]): The list of flow names to run in parallel.
+            events (List[dict]): The current events.
+            pre_events (List[dict], optional): Events to be added before starting each flow.
+            post_events (List[dict], optional): Events to be added after finishing each flow.
+        """
+
+        if pre_events is not None and len(pre_events) != len(flows):
+            raise ValueError("Number of pre-events must match number of flows.")
+        if post_events is not None and len(post_events) != len(flows):
+            raise ValueError("Number of post-events must match number of flows.")
+
+        unique_flow_ids = {}  # Keep track of unique flow IDs order
+        task_results: Dict[str, List] = {}  # Store results keyed by flow_id
+        task_processing_logs: dict = {}  # Store resulting processing logs for each flow
+
+        # Wrapper function to help reverse map the task result to the flow ID
+        async def task_call_helper(flow_uid, post_event, func, *args, **kwargs):
+            result = await func(*args, **kwargs)
+            if post_event:
+                result.append(post_event)
+                args[1].append(
+                    {"type": "event", "timestamp": time(), "data": post_event}
+                )
+            return flow_uid, result
+
+        # Create a task for each flow but don't await them yet
+        tasks = []
+        for index, flow_name in enumerate(flows):
+            # Copy the events to avoid modifying the original list
+            _events = events.copy()
+
+            flow_params = _get_flow_params(flow_name)
+            flow_id = _normalize_flow_id(flow_name)
+
+            if flow_params:
+                _events.append(
+                    {"type": "start_flow", "flow_id": flow_id, "params": flow_params}
+                )
+            else:
+                _events.append({"type": "start_flow", "flow_id": flow_id})
+
+            # Generate a unique flow ID
+            flow_uid = f"{flow_id}:{str(uuid.uuid4())}"
+
+            # Initialize task results and processing logs for this flow
+            task_results[flow_uid] = []
+            task_processing_logs[flow_uid] = []
+
+            # Add pre-event if provided
+            if pre_events:
+                task_results[flow_uid].append(pre_events[index])
+                task_processing_logs[flow_uid].append(
+                    {"type": "event", "timestamp": time(), "data": pre_events[index]}
+                )
+
+            task = asyncio.create_task(
+                task_call_helper(
+                    flow_uid,
+                    post_events[index] if post_events else None,
+                    self.generate_events,
+                    _events,
+                    task_processing_logs[flow_uid],
+                )
+            )
+            tasks.append(task)
+            unique_flow_ids[flow_uid] = task
+
+        stopped_task_results: List[dict] = []
+
+        # Process tasks as they complete using as_completed
+        try:
+            for future in asyncio.as_completed(tasks):
+                try:
+                    (flow_id, result) = await future
+
+                    # Check if this rail requested to stop
+                    has_stop = any(
+                        event["type"] == "BotIntent" and event["intent"] == "stop"
+                        for event in result
+                    )
+
+                    # If this flow had a stop event
+                    if has_stop:
+                        stopped_task_results = task_results[flow_id] + result
+
+                        # Cancel all remaining tasks
+                        for pending_task in tasks:
+                            # Don't include results and processing logs for cancelled or stopped tasks
+                            if (
+                                pending_task != unique_flow_ids[flow_id]
+                                and not pending_task.done()
+                            ):
+                                # Cancel the task if it is not done
+                                pending_task.cancel()
+                                # Find the flow_uid for this task and remove it from the dict
+                                for k, v in list(unique_flow_ids.items()):
+                                    if v == pending_task:
+                                        del unique_flow_ids[k]
+                                        break
+                        del unique_flow_ids[flow_id]
+                        break
+                    else:
+                        # Store the result for this specific flow
+                        task_results[flow_id].extend(result)
+
+                except asyncio.exceptions.CancelledError:
+                    pass
+
+        except Exception as e:
+            log.error(f"Error in parallel rail execution: {str(e)}")
+            raise
+
+        context_updates: dict = {}
+        processing_log = processing_log_var.get()
+
+        finished_task_processing_logs: List[dict] = []  # Collect all results in order
+        finished_task_results: List[dict] = []  # Collect all results in order
+
+        # Compose results in original flow order of all completed tasks
+        for flow_id in unique_flow_ids:
+            result = task_results[flow_id]
+
+            # Extract context updates
+            for event in result:
+                if event["type"] == "ContextUpdate":
+                    context_updates = {**context_updates, **event["data"]}
+
+            finished_task_results.extend(result)
+            finished_task_processing_logs.extend(task_processing_logs[flow_id])
+
+        if processing_log:
+            for plog in finished_task_processing_logs:
+                # Filter out "Listen" and "start_flow" events from task processing log
+                if plog["type"] == "event" and (
+                    plog["data"]["type"] == "Listen"
+                    or plog["data"]["type"] == "start_flow"
+                ):
+                    continue
+                processing_log.append(plog)
+
+        # We pack all events into a single event to add it to the event history.
+        history_events = new_event_dict(
+            "EventHistoryUpdate",
+            data={"events": finished_task_results},
+        )
+
+        return ActionResult(
+            events=[history_events] + stopped_task_results,
+            context_updates=context_updates,
+        )
+
+    async def _run_input_rails_in_parallel(
+        self, flows: List[str], events: List[dict]
+    ) -> ActionResult:
+        """Run the input rails in parallel."""
+        pre_events = [
+            (await create_event({"_type": "StartInputRail", "flow_id": flow})).events[0]
+            for flow in flows
+        ]
+        post_events = [
+            (
+                await create_event({"_type": "InputRailFinished", "flow_id": flow})
+            ).events[0]
+            for flow in flows
+        ]
+
+        return await self._run_flows_in_parallel(
+            flows=flows, events=events, pre_events=pre_events, post_events=post_events
+        )
+
+    async def _run_output_rails_in_parallel(
+        self, flows: List[str], events: List[dict]
+    ) -> ActionResult:
+        """Run the output rails in parallel."""
+        pre_events = [
+            (await create_event({"_type": "StartOutputRail", "flow_id": flow})).events[
+                0
+            ]
+            for flow in flows
+        ]
+        post_events = [
+            (
+                await create_event({"_type": "OutputRailFinished", "flow_id": flow})
+            ).events[0]
+            for flow in flows
+        ]
+
+        return await self._run_flows_in_parallel(
+            flows=flows, events=events, pre_events=pre_events, post_events=post_events
+        )
+
     async def _run_output_rails_in_parallel_streaming(
         self, flows_with_params: Dict[str, dict], events: List[dict]
     ) -> ActionResult:
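As a rough illustration of the cancel-on-stop strategy in `_run_flows_in_parallel`, here is a self-contained toy; the rail names, delays, and event shapes are made up:

    import asyncio

    async def run_rail(name, delay, stop=False):
        await asyncio.sleep(delay)
        return name, [{"type": "BotIntent", "intent": "stop" if stop else "continue"}]

    async def main():
        tasks = [
            asyncio.create_task(run_rail("check jailbreak", 0.1, stop=True)),
            asyncio.create_task(run_rail("check topic", 0.5)),
        ]
        # Consume results in completion order, not submission order
        for future in asyncio.as_completed(tasks):
            try:
                name, result = await future
                if any(
                    e["type"] == "BotIntent" and e["intent"] == "stop" for e in result
                ):
                    # One rail asked to stop: cancel everything still running.
                    for t in tasks:
                        if not t.done():
                            t.cancel()
                    print(f"stopped by: {name}")
                    break
            except asyncio.CancelledError:
                pass

    asyncio.run(main())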
@@ -472,15 +690,7 @@ async def _process_start_action(self, events: List[dict]) -> List[dict]:
         next_steps = []
 
         if context_updates:
-            # We check if at least one key changed
-            changes = False
-            for k, v in context_updates.items():
-                if context.get(k) != v:
-                    changes = True
-                    break
-
-            if changes:
-                next_steps.append(new_event_dict("ContextUpdate", data=context_updates))
+            next_steps.append(new_event_dict("ContextUpdate", data=context_updates))
 
         next_steps.append(
             new_event_dict(
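The removed guard emitted a `ContextUpdate` event only when at least one key actually changed; after this change the event is appended whenever `context_updates` is non-empty. A small sketch of the old check, with made-up context values:

    context = {"user_name": "alice"}
    context_updates = {"user_name": "alice"}  # same value as already stored

    changed = any(context.get(k) != v for k, v in context_updates.items())
    print(changed)  # False: old code emitted no event here; new code always does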