@@ -16,11 +16,9 @@ use async_trait::async_trait;
16
16
use hyperactor:: Actor ;
17
17
use hyperactor:: ActorHandle ;
18
18
use hyperactor:: ActorId ;
19
- use hyperactor:: HandleClient ;
20
19
use hyperactor:: Handler ;
21
20
use hyperactor:: Instance ;
22
21
use hyperactor:: Named ;
23
- use hyperactor:: forward;
24
22
use hyperactor:: message:: Bind ;
25
23
use hyperactor:: message:: Bindings ;
26
24
use hyperactor:: message:: IndexedErasedUnbound ;
@@ -40,7 +38,10 @@ use serde::Deserialize;
40
38
use serde:: Serialize ;
41
39
use serde_bytes:: ByteBuf ;
42
40
use tokio:: sync:: Mutex ;
41
+ use tokio:: sync:: mpsc:: UnboundedReceiver ;
42
+ use tokio:: sync:: mpsc:: UnboundedSender ;
43
43
use tokio:: sync:: oneshot;
44
+ use tracing:: Instrument ;
44
45
45
46
use crate :: mailbox:: EitherPortRef ;
46
47
use crate :: mailbox:: PyMailbox ;
@@ -262,6 +263,13 @@ impl PythonActorHandle {
262
263
}
263
264
}
264
265
266
+ #[ derive( Debug ) ]
267
+ enum PanicWatcher {
268
+ ForwardTo ( UnboundedReceiver < anyhow:: Result < ( ) > > ) ,
269
+ HandlerActor ( ActorHandle < PythonActorPanicWatcher > ) ,
270
+ None ,
271
+ }
272
+
265
273
/// An actor for which message handlers are implemented in Python.
266
274
#[ derive( Debug ) ]
267
275
#[ hyperactor:: export(
@@ -280,6 +288,8 @@ pub(super) struct PythonActor {
280
288
/// Stores a reference to the Python event loop to run Python coroutines on.
281
289
/// We give each PythonActor its own even loop in its own thread.
282
290
task_locals : pyo3_async_runtimes:: TaskLocals ,
291
+ panic_watcher : PanicWatcher ,
292
+ panic_sender : UnboundedSender < anyhow:: Result < ( ) > > ,
283
293
}
284
294
285
295
#[ async_trait]
@@ -312,10 +322,29 @@ impl Actor for PythonActor {
312
322
} ) ;
313
323
rx. recv ( ) . unwrap ( )
314
324
} ) ;
315
-
316
- Ok ( Self { actor, task_locals } )
325
+ let ( tx, rx) = tokio:: sync:: mpsc:: unbounded_channel ( ) ;
326
+ Ok ( Self {
327
+ actor,
328
+ task_locals,
329
+ panic_watcher : PanicWatcher :: ForwardTo ( rx) ,
330
+ panic_sender : tx,
331
+ } )
317
332
} ) ?)
318
333
}
334
+
335
+ async fn init ( & mut self , this : & Instance < Self > ) -> anyhow:: Result < ( ) > {
336
+ self . panic_watcher = PanicWatcher :: HandlerActor (
337
+ match std:: mem:: replace ( & mut self . panic_watcher , PanicWatcher :: None ) {
338
+ PanicWatcher :: ForwardTo ( chan) => PythonActorPanicWatcher :: spawn ( this, chan) . await ?,
339
+ PanicWatcher :: HandlerActor ( actor) => {
340
+ tracing:: warn!( "init called twice" ) ;
341
+ actor
342
+ }
343
+ PanicWatcher :: None => unreachable ! ( "init called while in an invalid state" ) ,
344
+ } ,
345
+ ) ;
346
+ Ok ( ( ) )
347
+ }
319
348
}
320
349
321
350
// [Panics in async endpoints]
@@ -365,6 +394,49 @@ impl PanicFlag {
365
394
}
366
395
}
367
396
397
+ #[ derive( Debug ) ]
398
+ struct PythonActorPanicWatcher {
399
+ panic_rx : UnboundedReceiver < anyhow:: Result < ( ) > > ,
400
+ }
401
+
402
+ #[ async_trait]
403
+ impl Actor for PythonActorPanicWatcher {
404
+ type Params = UnboundedReceiver < anyhow:: Result < ( ) > > ;
405
+
406
+ async fn new ( panic_rx : UnboundedReceiver < anyhow:: Result < ( ) > > ) -> Result < Self , anyhow:: Error > {
407
+ Ok ( Self { panic_rx } )
408
+ }
409
+
410
+ async fn init ( & mut self , this : & Instance < Self > ) -> Result < ( ) , anyhow:: Error > {
411
+ this. handle ( ) . send ( HandlePanic { } ) ?;
412
+ Ok ( ( ) )
413
+ }
414
+ }
415
+
416
+ #[ derive( Debug ) ]
417
+ struct HandlePanic { }
418
+
419
+ #[ async_trait]
420
+ impl Handler < HandlePanic > for PythonActorPanicWatcher {
421
+ async fn handle ( & mut self , this : & Instance < Self > , _message : HandlePanic ) -> anyhow:: Result < ( ) > {
422
+ match self . panic_rx . recv ( ) . await {
423
+ Some ( Ok ( _) ) => {
424
+ // async endpoint executed successfully.
425
+ // run again
426
+ this. handle ( ) . send ( HandlePanic { } ) ?;
427
+ }
428
+ Some ( Err ( err) ) => {
429
+ tracing:: error!( "caught error in async endpoint {}" , err) ;
430
+ return Err ( err) ;
431
+ }
432
+ None => {
433
+ tracing:: warn!( "panic forwarding channel was closed unexpectidly" )
434
+ }
435
+ }
436
+ Ok ( ( ) )
437
+ }
438
+ }
439
+
368
440
#[ async_trait]
369
441
impl Handler < PythonMessage > for PythonActor {
370
442
async fn handle (
@@ -400,8 +472,18 @@ impl Handler<PythonMessage> for PythonActor {
400
472
} ) ?;
401
473
402
474
// Spawn a child actor to await the Python handler method.
403
- let handler = AsyncEndpointTask :: spawn ( this, ( ) ) . await ?;
404
- handler. run ( this, PythonTask :: new ( future) , receiver) . await ?;
475
+ tokio:: spawn (
476
+ handle_async_endpoint_panic (
477
+ self . panic_sender . clone ( ) ,
478
+ PythonTask :: new ( future) ,
479
+ receiver,
480
+ )
481
+ . instrument (
482
+ tracing:: info_span!( "py_panic_handler" )
483
+ . follows_from ( tracing:: Span :: current ( ) . id ( ) )
484
+ . clone ( ) ,
485
+ ) ,
486
+ ) ;
405
487
Ok ( ( ) )
406
488
}
407
489
}
@@ -448,8 +530,18 @@ impl Handler<Cast<PythonMessage>> for PythonActor {
448
530
} ) ?;
449
531
450
532
// Spawn a child actor to await the Python handler method.
451
- let handler = AsyncEndpointTask :: spawn ( this, ( ) ) . await ?;
452
- handler. run ( this, PythonTask :: new ( future) , receiver) . await ?;
533
+ tokio:: spawn (
534
+ handle_async_endpoint_panic (
535
+ self . panic_sender . clone ( ) ,
536
+ PythonTask :: new ( future) ,
537
+ receiver,
538
+ )
539
+ . instrument (
540
+ tracing:: info_span!( "py_panic_handler" )
541
+ . follows_from ( tracing:: Span :: current ( ) . id ( ) )
542
+ . clone ( ) ,
543
+ ) ,
544
+ ) ;
453
545
Ok ( ( ) )
454
546
}
455
547
}
@@ -481,77 +573,45 @@ impl fmt::Debug for PythonTask {
481
573
}
482
574
}
483
575
484
- /// An ['Actor'] used to monitor the result of an async endpoint. We use an
485
- /// actor so that:
486
- /// - Actually waiting on the async endpoint can happen concurrently with other endpoints.
487
- /// - Any uncaught errors in the async endpoint will get propagated as a supervision event.
488
- #[ derive( Debug ) ]
489
- struct AsyncEndpointTask { }
490
-
491
- /// An invocation of an async endpoint on a [`PythonActor`].
492
- #[ derive( Handler , HandleClient , Debug ) ]
493
- enum AsyncEndpointInvocation {
494
- Run ( PythonTask , oneshot:: Receiver < PyObject > ) ,
495
- }
496
-
497
- #[ async_trait]
498
- impl Actor for AsyncEndpointTask {
499
- type Params = ( ) ;
500
-
501
- async fn new ( _params : Self :: Params ) -> anyhow:: Result < Self > {
502
- Ok ( Self { } )
503
- }
504
- }
505
-
506
- #[ async_trait]
507
- #[ forward( AsyncEndpointInvocation ) ]
508
- impl AsyncEndpointInvocationHandler for AsyncEndpointTask {
509
- async fn run (
510
- & mut self ,
511
- this : & Instance < Self > ,
512
- task : PythonTask ,
513
- side_channel : oneshot:: Receiver < PyObject > ,
514
- ) -> anyhow:: Result < ( ) > {
515
- // Drive our PythonTask to completion, but listen on the side channel
516
- // and raise an error if we hear anything there.
517
-
518
- let err_or_never = async {
519
- // The side channel will resolve with a value if a panic occured during
520
- // processing of the async endpoint, see [Panics in async endpoints].
521
- match side_channel. await {
522
- Ok ( value) => Python :: with_gil ( |py| -> Result < ( ) , SerializablePyErr > {
523
- let err: PyErr = value
524
- . downcast_bound :: < PyBaseException > ( py)
525
- . unwrap ( )
526
- . clone ( )
527
- . into ( ) ;
528
- Err ( SerializablePyErr :: from ( py, & err) )
529
- } ) ,
530
- // An Err means that the sender has been dropped without sending.
531
- // That's okay, it just means that the Python task has completed.
532
- // In that case, just never resolve this future. We expect the other
533
- // branch of the select to finish eventually.
534
- Err ( _) => pending ( ) . await ,
535
- }
536
- } ;
537
- let future = task. take ( ) . await ;
538
- let result: Result < ( ) , SerializablePyErr > = tokio:: select! {
539
- result = future => {
540
- match result {
541
- Ok ( _) => Ok ( ( ) ) ,
542
- Err ( e) => Err ( e. into( ) ) ,
543
- }
544
- } ,
545
- result = err_or_never => {
546
- result
576
+ async fn handle_async_endpoint_panic (
577
+ panic_sender : UnboundedSender < anyhow:: Result < ( ) > > ,
578
+ task : PythonTask ,
579
+ side_channel : oneshot:: Receiver < PyObject > ,
580
+ ) {
581
+ let err_or_never = async {
582
+ // The side channel will resolve with a value if a panic occured during
583
+ // processing of the async endpoint, see [Panics in async endpoints].
584
+ match side_channel. await {
585
+ Ok ( value) => Python :: with_gil ( |py| -> anyhow:: Result < ( ) > {
586
+ let err: PyErr = value
587
+ . downcast_bound :: < PyBaseException > ( py)
588
+ . unwrap ( )
589
+ . clone ( )
590
+ . into ( ) ;
591
+ Err ( err. into ( ) )
592
+ } ) ,
593
+ // An Err means that the sender has been dropped without sending.
594
+ // That's okay, it just means that the Python task has completed.
595
+ // In that case, just never resolve this future. We expect the other
596
+ // branch of the select to finish eventually.
597
+ Err ( _) => pending ( ) . await ,
598
+ }
599
+ } ;
600
+ let future = task. take ( ) . await ;
601
+ let result: anyhow:: Result < ( ) > = tokio:: select! {
602
+ result = future => {
603
+ match result {
604
+ Ok ( _) => Ok ( ( ) ) ,
605
+ Err ( e) => Err ( e. into( ) ) ,
547
606
}
548
- } ;
549
- result?;
550
-
551
- // Stop this actor now that its job is done.
552
- this. stop ( ) ?;
553
- Ok ( ( ) )
554
- }
607
+ } ,
608
+ result = err_or_never => {
609
+ result
610
+ }
611
+ } ;
612
+ panic_sender
613
+ . send ( result)
614
+ . expect ( "Unable to send panic message" ) ;
555
615
}
556
616
557
617
pub fn register_python_bindings ( hyperactor_mod : & Bound < ' _ , PyModule > ) -> PyResult < ( ) > {
0 commit comments