Skip to content

Don't spawn a new actor for every python message #340

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 138 additions & 78 deletions monarch_hyperactor/src/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@ use async_trait::async_trait;
use hyperactor::Actor;
use hyperactor::ActorHandle;
use hyperactor::ActorId;
use hyperactor::HandleClient;
use hyperactor::Handler;
use hyperactor::Instance;
use hyperactor::Named;
use hyperactor::forward;
use hyperactor::message::Bind;
use hyperactor::message::Bindings;
use hyperactor::message::IndexedErasedUnbound;
Expand All @@ -40,7 +38,10 @@ use serde::Deserialize;
use serde::Serialize;
use serde_bytes::ByteBuf;
use tokio::sync::Mutex;
use tokio::sync::mpsc::UnboundedReceiver;
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::oneshot;
use tracing::Instrument;

use crate::mailbox::EitherPortRef;
use crate::mailbox::PyMailbox;
Expand Down Expand Up @@ -262,6 +263,13 @@ impl PythonActorHandle {
}
}

#[derive(Debug)]
enum PanicWatcher {
ForwardTo(UnboundedReceiver<anyhow::Result<()>>),
HandlerActor(ActorHandle<PythonActorPanicWatcher>),
None,
}

/// An actor for which message handlers are implemented in Python.
#[derive(Debug)]
#[hyperactor::export(
Expand All @@ -280,6 +288,8 @@ pub(super) struct PythonActor {
/// Stores a reference to the Python event loop to run Python coroutines on.
/// We give each PythonActor its own even loop in its own thread.
task_locals: pyo3_async_runtimes::TaskLocals,
panic_watcher: PanicWatcher,
panic_sender: UnboundedSender<anyhow::Result<()>>,
}

#[async_trait]
Expand Down Expand Up @@ -312,10 +322,29 @@ impl Actor for PythonActor {
});
rx.recv().unwrap()
});

Ok(Self { actor, task_locals })
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
Ok(Self {
actor,
task_locals,
panic_watcher: PanicWatcher::ForwardTo(rx),
panic_sender: tx,
})
})?)
}

async fn init(&mut self, this: &Instance<Self>) -> anyhow::Result<()> {
self.panic_watcher = PanicWatcher::HandlerActor(
match std::mem::replace(&mut self.panic_watcher, PanicWatcher::None) {
PanicWatcher::ForwardTo(chan) => PythonActorPanicWatcher::spawn(this, chan).await?,
PanicWatcher::HandlerActor(actor) => {
tracing::warn!("init called twice");
actor
}
PanicWatcher::None => unreachable!("init called while in an invalid state"),
},
);
Ok(())
}
}

// [Panics in async endpoints]
Expand Down Expand Up @@ -365,6 +394,49 @@ impl PanicFlag {
}
}

#[derive(Debug)]
struct PythonActorPanicWatcher {
panic_rx: UnboundedReceiver<anyhow::Result<()>>,
}

#[async_trait]
impl Actor for PythonActorPanicWatcher {
type Params = UnboundedReceiver<anyhow::Result<()>>;

async fn new(panic_rx: UnboundedReceiver<anyhow::Result<()>>) -> Result<Self, anyhow::Error> {
Ok(Self { panic_rx })
}

async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
this.handle().send(HandlePanic {})?;
Ok(())
}
}

#[derive(Debug)]
struct HandlePanic {}

#[async_trait]
impl Handler<HandlePanic> for PythonActorPanicWatcher {
async fn handle(&mut self, this: &Instance<Self>, _message: HandlePanic) -> anyhow::Result<()> {
match self.panic_rx.recv().await {
Some(Ok(_)) => {
// async endpoint executed successfully.
// run again
this.handle().send(HandlePanic {})?;
}
Some(Err(err)) => {
tracing::error!("caught error in async endpoint {}", err);
return Err(err);
}
None => {
tracing::warn!("panic forwarding channel was closed unexpectidly")
}
}
Ok(())
}
}

#[async_trait]
impl Handler<PythonMessage> for PythonActor {
async fn handle(
Expand Down Expand Up @@ -400,8 +472,18 @@ impl Handler<PythonMessage> for PythonActor {
})?;

// Spawn a child actor to await the Python handler method.
let handler = AsyncEndpointTask::spawn(this, ()).await?;
handler.run(this, PythonTask::new(future), receiver).await?;
tokio::spawn(
handle_async_endpoint_panic(
self.panic_sender.clone(),
PythonTask::new(future),
receiver,
)
.instrument(
tracing::info_span!("py_panic_handler")
.follows_from(tracing::Span::current().id())
.clone(),
),
);
Ok(())
}
}
Expand Down Expand Up @@ -448,8 +530,18 @@ impl Handler<Cast<PythonMessage>> for PythonActor {
})?;

// Spawn a child actor to await the Python handler method.
let handler = AsyncEndpointTask::spawn(this, ()).await?;
handler.run(this, PythonTask::new(future), receiver).await?;
tokio::spawn(
handle_async_endpoint_panic(
self.panic_sender.clone(),
PythonTask::new(future),
receiver,
)
.instrument(
tracing::info_span!("py_panic_handler")
.follows_from(tracing::Span::current().id())
.clone(),
),
);
Ok(())
}
}
Expand Down Expand Up @@ -481,77 +573,45 @@ impl fmt::Debug for PythonTask {
}
}

/// An ['Actor'] used to monitor the result of an async endpoint. We use an
/// actor so that:
/// - Actually waiting on the async endpoint can happen concurrently with other endpoints.
/// - Any uncaught errors in the async endpoint will get propagated as a supervision event.
#[derive(Debug)]
struct AsyncEndpointTask {}

/// An invocation of an async endpoint on a [`PythonActor`].
#[derive(Handler, HandleClient, Debug)]
enum AsyncEndpointInvocation {
Run(PythonTask, oneshot::Receiver<PyObject>),
}

#[async_trait]
impl Actor for AsyncEndpointTask {
type Params = ();

async fn new(_params: Self::Params) -> anyhow::Result<Self> {
Ok(Self {})
}
}

#[async_trait]
#[forward(AsyncEndpointInvocation)]
impl AsyncEndpointInvocationHandler for AsyncEndpointTask {
async fn run(
&mut self,
this: &Instance<Self>,
task: PythonTask,
side_channel: oneshot::Receiver<PyObject>,
) -> anyhow::Result<()> {
// Drive our PythonTask to completion, but listen on the side channel
// and raise an error if we hear anything there.

let err_or_never = async {
// The side channel will resolve with a value if a panic occured during
// processing of the async endpoint, see [Panics in async endpoints].
match side_channel.await {
Ok(value) => Python::with_gil(|py| -> Result<(), SerializablePyErr> {
let err: PyErr = value
.downcast_bound::<PyBaseException>(py)
.unwrap()
.clone()
.into();
Err(SerializablePyErr::from(py, &err))
}),
// An Err means that the sender has been dropped without sending.
// That's okay, it just means that the Python task has completed.
// In that case, just never resolve this future. We expect the other
// branch of the select to finish eventually.
Err(_) => pending().await,
}
};
let future = task.take().await;
let result: Result<(), SerializablePyErr> = tokio::select! {
result = future => {
match result {
Ok(_) => Ok(()),
Err(e) => Err(e.into()),
}
},
result = err_or_never => {
result
async fn handle_async_endpoint_panic(
panic_sender: UnboundedSender<anyhow::Result<()>>,
task: PythonTask,
side_channel: oneshot::Receiver<PyObject>,
) {
let err_or_never = async {
// The side channel will resolve with a value if a panic occured during
// processing of the async endpoint, see [Panics in async endpoints].
match side_channel.await {
Ok(value) => Python::with_gil(|py| -> anyhow::Result<()> {
let err: PyErr = value
.downcast_bound::<PyBaseException>(py)
.unwrap()
.clone()
.into();
Err(err.into())
}),
// An Err means that the sender has been dropped without sending.
// That's okay, it just means that the Python task has completed.
// In that case, just never resolve this future. We expect the other
// branch of the select to finish eventually.
Err(_) => pending().await,
}
};
let future = task.take().await;
let result: anyhow::Result<()> = tokio::select! {
result = future => {
match result {
Ok(_) => Ok(()),
Err(e) => Err(e.into()),
}
};
result?;

// Stop this actor now that its job is done.
this.stop()?;
Ok(())
}
},
result = err_or_never => {
result
}
};
panic_sender
.send(result)
.expect("Unable to send panic message");
}

pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
Expand Down
Loading