Skip to content

Commit 7537d4e

Browse files
eliothedemanfacebook-github-bot
authored andcommitted
Don't spawn a new actor for every python message (#340)
Summary: Pull Request resolved: #340 Differential Revision: D77171101
1 parent 0843625 commit 7537d4e

File tree

1 file changed

+138
-78
lines changed

1 file changed

+138
-78
lines changed

monarch_hyperactor/src/actor.rs

Lines changed: 138 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,9 @@ use async_trait::async_trait;
1616
use hyperactor::Actor;
1717
use hyperactor::ActorHandle;
1818
use hyperactor::ActorId;
19-
use hyperactor::HandleClient;
2019
use hyperactor::Handler;
2120
use hyperactor::Instance;
2221
use hyperactor::Named;
23-
use hyperactor::forward;
2422
use hyperactor::message::Bind;
2523
use hyperactor::message::Bindings;
2624
use hyperactor::message::IndexedErasedUnbound;
@@ -40,7 +38,10 @@ use serde::Deserialize;
4038
use serde::Serialize;
4139
use serde_bytes::ByteBuf;
4240
use tokio::sync::Mutex;
41+
use tokio::sync::mpsc::UnboundedReceiver;
42+
use tokio::sync::mpsc::UnboundedSender;
4343
use tokio::sync::oneshot;
44+
use tracing::Instrument;
4445

4546
use crate::mailbox::EitherPortRef;
4647
use crate::mailbox::PyMailbox;
@@ -262,6 +263,13 @@ impl PythonActorHandle {
262263
}
263264
}
264265

266+
#[derive(Debug)]
267+
enum PanicWatcher {
268+
ForwardTo(UnboundedReceiver<anyhow::Result<()>>),
269+
HandlerActor(ActorHandle<PythonActorPanicWatcher>),
270+
None,
271+
}
272+
265273
/// An actor for which message handlers are implemented in Python.
266274
#[derive(Debug)]
267275
#[hyperactor::export(
@@ -280,6 +288,8 @@ pub(super) struct PythonActor {
280288
/// Stores a reference to the Python event loop to run Python coroutines on.
281289
/// We give each PythonActor its own even loop in its own thread.
282290
task_locals: pyo3_async_runtimes::TaskLocals,
291+
panic_watcher: PanicWatcher,
292+
panic_sender: UnboundedSender<anyhow::Result<()>>,
283293
}
284294

285295
#[async_trait]
@@ -312,10 +322,29 @@ impl Actor for PythonActor {
312322
});
313323
rx.recv().unwrap()
314324
});
315-
316-
Ok(Self { actor, task_locals })
325+
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
326+
Ok(Self {
327+
actor,
328+
task_locals,
329+
panic_watcher: PanicWatcher::ForwardTo(rx),
330+
panic_sender: tx,
331+
})
317332
})?)
318333
}
334+
335+
async fn init(&mut self, this: &Instance<Self>) -> anyhow::Result<()> {
336+
self.panic_watcher = PanicWatcher::HandlerActor(
337+
match std::mem::replace(&mut self.panic_watcher, PanicWatcher::None) {
338+
PanicWatcher::ForwardTo(chan) => PythonActorPanicWatcher::spawn(this, chan).await?,
339+
PanicWatcher::HandlerActor(actor) => {
340+
tracing::warn!("init called twice");
341+
actor
342+
}
343+
PanicWatcher::None => unreachable!("init called while in an invalid state"),
344+
},
345+
);
346+
Ok(())
347+
}
319348
}
320349

321350
// [Panics in async endpoints]
@@ -365,6 +394,49 @@ impl PanicFlag {
365394
}
366395
}
367396

397+
#[derive(Debug)]
398+
struct PythonActorPanicWatcher {
399+
panic_rx: UnboundedReceiver<anyhow::Result<()>>,
400+
}
401+
402+
#[async_trait]
403+
impl Actor for PythonActorPanicWatcher {
404+
type Params = UnboundedReceiver<anyhow::Result<()>>;
405+
406+
async fn new(panic_rx: UnboundedReceiver<anyhow::Result<()>>) -> Result<Self, anyhow::Error> {
407+
Ok(Self { panic_rx })
408+
}
409+
410+
async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
411+
this.handle().send(HandlePanic {})?;
412+
Ok(())
413+
}
414+
}
415+
416+
#[derive(Debug)]
417+
struct HandlePanic {}
418+
419+
#[async_trait]
420+
impl Handler<HandlePanic> for PythonActorPanicWatcher {
421+
async fn handle(&mut self, this: &Instance<Self>, _message: HandlePanic) -> anyhow::Result<()> {
422+
match self.panic_rx.recv().await {
423+
Some(Ok(_)) => {
424+
// async endpoint executed successfully.
425+
// run again
426+
this.handle().send(HandlePanic {})?;
427+
}
428+
Some(Err(err)) => {
429+
tracing::error!("caught error in async endpoint {}", err);
430+
return Err(err);
431+
}
432+
None => {
433+
tracing::warn!("panic forwarding channel was closed unexpectidly")
434+
}
435+
}
436+
Ok(())
437+
}
438+
}
439+
368440
#[async_trait]
369441
impl Handler<PythonMessage> for PythonActor {
370442
async fn handle(
@@ -400,8 +472,18 @@ impl Handler<PythonMessage> for PythonActor {
400472
})?;
401473

402474
// Spawn a child actor to await the Python handler method.
403-
let handler = AsyncEndpointTask::spawn(this, ()).await?;
404-
handler.run(this, PythonTask::new(future), receiver).await?;
475+
tokio::spawn(
476+
handle_async_endpoint_panic(
477+
self.panic_sender.clone(),
478+
PythonTask::new(future),
479+
receiver,
480+
)
481+
.instrument(
482+
tracing::info_span!("py_panic_handler")
483+
.follows_from(tracing::Span::current().id())
484+
.clone(),
485+
),
486+
);
405487
Ok(())
406488
}
407489
}
@@ -448,8 +530,18 @@ impl Handler<Cast<PythonMessage>> for PythonActor {
448530
})?;
449531

450532
// Spawn a child actor to await the Python handler method.
451-
let handler = AsyncEndpointTask::spawn(this, ()).await?;
452-
handler.run(this, PythonTask::new(future), receiver).await?;
533+
tokio::spawn(
534+
handle_async_endpoint_panic(
535+
self.panic_sender.clone(),
536+
PythonTask::new(future),
537+
receiver,
538+
)
539+
.instrument(
540+
tracing::info_span!("py_panic_handler")
541+
.follows_from(tracing::Span::current().id())
542+
.clone(),
543+
),
544+
);
453545
Ok(())
454546
}
455547
}
@@ -481,77 +573,45 @@ impl fmt::Debug for PythonTask {
481573
}
482574
}
483575

484-
/// An ['Actor'] used to monitor the result of an async endpoint. We use an
485-
/// actor so that:
486-
/// - Actually waiting on the async endpoint can happen concurrently with other endpoints.
487-
/// - Any uncaught errors in the async endpoint will get propagated as a supervision event.
488-
#[derive(Debug)]
489-
struct AsyncEndpointTask {}
490-
491-
/// An invocation of an async endpoint on a [`PythonActor`].
492-
#[derive(Handler, HandleClient, Debug)]
493-
enum AsyncEndpointInvocation {
494-
Run(PythonTask, oneshot::Receiver<PyObject>),
495-
}
496-
497-
#[async_trait]
498-
impl Actor for AsyncEndpointTask {
499-
type Params = ();
500-
501-
async fn new(_params: Self::Params) -> anyhow::Result<Self> {
502-
Ok(Self {})
503-
}
504-
}
505-
506-
#[async_trait]
507-
#[forward(AsyncEndpointInvocation)]
508-
impl AsyncEndpointInvocationHandler for AsyncEndpointTask {
509-
async fn run(
510-
&mut self,
511-
this: &Instance<Self>,
512-
task: PythonTask,
513-
side_channel: oneshot::Receiver<PyObject>,
514-
) -> anyhow::Result<()> {
515-
// Drive our PythonTask to completion, but listen on the side channel
516-
// and raise an error if we hear anything there.
517-
518-
let err_or_never = async {
519-
// The side channel will resolve with a value if a panic occured during
520-
// processing of the async endpoint, see [Panics in async endpoints].
521-
match side_channel.await {
522-
Ok(value) => Python::with_gil(|py| -> Result<(), SerializablePyErr> {
523-
let err: PyErr = value
524-
.downcast_bound::<PyBaseException>(py)
525-
.unwrap()
526-
.clone()
527-
.into();
528-
Err(SerializablePyErr::from(py, &err))
529-
}),
530-
// An Err means that the sender has been dropped without sending.
531-
// That's okay, it just means that the Python task has completed.
532-
// In that case, just never resolve this future. We expect the other
533-
// branch of the select to finish eventually.
534-
Err(_) => pending().await,
535-
}
536-
};
537-
let future = task.take().await;
538-
let result: Result<(), SerializablePyErr> = tokio::select! {
539-
result = future => {
540-
match result {
541-
Ok(_) => Ok(()),
542-
Err(e) => Err(e.into()),
543-
}
544-
},
545-
result = err_or_never => {
546-
result
576+
async fn handle_async_endpoint_panic(
577+
panic_sender: UnboundedSender<anyhow::Result<()>>,
578+
task: PythonTask,
579+
side_channel: oneshot::Receiver<PyObject>,
580+
) {
581+
let err_or_never = async {
582+
// The side channel will resolve with a value if a panic occured during
583+
// processing of the async endpoint, see [Panics in async endpoints].
584+
match side_channel.await {
585+
Ok(value) => Python::with_gil(|py| -> anyhow::Result<()> {
586+
let err: PyErr = value
587+
.downcast_bound::<PyBaseException>(py)
588+
.unwrap()
589+
.clone()
590+
.into();
591+
Err(err.into())
592+
}),
593+
// An Err means that the sender has been dropped without sending.
594+
// That's okay, it just means that the Python task has completed.
595+
// In that case, just never resolve this future. We expect the other
596+
// branch of the select to finish eventually.
597+
Err(_) => pending().await,
598+
}
599+
};
600+
let future = task.take().await;
601+
let result: anyhow::Result<()> = tokio::select! {
602+
result = future => {
603+
match result {
604+
Ok(_) => Ok(()),
605+
Err(e) => Err(e.into()),
547606
}
548-
};
549-
result?;
550-
551-
// Stop this actor now that its job is done.
552-
this.stop()?;
553-
Ok(())
554-
}
607+
},
608+
result = err_or_never => {
609+
result
610+
}
611+
};
612+
panic_sender
613+
.send(result)
614+
.expect("Unable to send panic message");
555615
}
556616

557617
pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {

0 commit comments

Comments
 (0)