Graceful shutdown for proc_mesh #352

Open · wants to merge 1 commit into main
10 changes: 4 additions & 6 deletions hyperactor/src/proc.rs
@@ -517,11 +517,9 @@ impl Proc {
             .next()
     }
 
-    // Iterating over a proc's root actors signaling each to stop.
-    // Return the root actor IDs and status observers.
-    async fn destroy(
-        &mut self,
-    ) -> Result<HashMap<ActorId, watch::Receiver<ActorStatus>>, anyhow::Error> {
+    /// Iterates over a proc's root actors, signaling each to stop.
+    /// Returns the root actor IDs and status observers.
+    pub fn destroy(&self) -> Result<HashMap<ActorId, watch::Receiver<ActorStatus>>, anyhow::Error> {
         tracing::debug!("{}: proc stopping", self.proc_id());
 
         let mut statuses = HashMap::new();
@@ -558,7 +556,7 @@ impl Proc {
         timeout: Duration,
         skip_waiting: Option<&ActorId>,
     ) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error> {
-        let mut statuses = self.destroy().await?;
+        let mut statuses = self.destroy()?;
         let waits: Vec<_> = statuses
             .iter_mut()
             .filter(|(actor_id, _)| Some(*actor_id) != skip_waiting)
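With `destroy()` now synchronous, `destroy_and_wait` collects the returned status observers and awaits them separately. For illustration, a minimal sketch of consuming those `watch::Receiver<ActorStatus>` handles; the `ActorId` and `ActorStatus` definitions here are stand-ins, and the terminal `Stopped` variant is an assumption, not necessarily the crate's actual name:

```rust
use std::collections::HashMap;
use std::time::Duration;

use tokio::sync::watch;

// Stand-ins for hyperactor's ActorId/ActorStatus (illustrative only).
type ActorId = String;

#[allow(dead_code)]
#[derive(Clone, Debug, PartialEq)]
enum ActorStatus {
    Running,
    Stopped,
}

// Wait up to `timeout` for each observer to report Stopped; partition
// the actors into those that stopped in time and those that did not.
async fn wait_stopped(
    statuses: HashMap<ActorId, watch::Receiver<ActorStatus>>,
    timeout: Duration,
) -> (Vec<ActorId>, Vec<ActorId>) {
    let (mut stopped, mut aborted) = (Vec::new(), Vec::new());
    for (actor_id, mut rx) in statuses {
        let wait = rx.wait_for(|s| *s == ActorStatus::Stopped);
        match tokio::time::timeout(timeout, wait).await {
            Ok(Ok(_)) => stopped.push(actor_id),
            _ => aborted.push(actor_id), // timed out, or the sender was dropped
        }
    }
    (stopped, aborted)
}
```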
38 changes: 38 additions & 0 deletions hyperactor_mesh/src/proc_mesh/mod.rs
@@ -414,6 +414,8 @@ impl ProcMesh {
     /// An event stream of proc events. Each ProcMesh can produce only one such
     /// stream, returning None after the first call.
     pub fn events(&mut self) -> Option<ProcEvents> {
+        let (stop_alloc_tx, stop_alloc_rx) = tokio::sync::mpsc::unbounded_channel::<()>();
+
         self.event_state.take().map(|event_state| ProcEvents {
             event_state,
             ranks: self
@@ -422,6 +424,8 @@
                 .enumerate()
                 .map(|(rank, (proc_id, _))| (proc_id.clone(), rank))
                 .collect(),
+            stop_alloc_tx,
+            stop_alloc_rx,
         })
     }
     pub fn shape(&self) -> &Shape {
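As the doc comment above notes, `events()` is single-shot: it moves the state out with `Option::take`, so every later call returns `None`. The same one-shot accessor pattern in miniature, with illustrative types:

```rust
struct Stream;

struct Mesh {
    event_state: Option<Stream>,
}

impl Mesh {
    // First call yields Some(Stream); every later call yields None,
    // because take() replaces the stored value with None.
    fn events(&mut self) -> Option<Stream> {
        self.event_state.take()
    }
}

fn main() {
    let mut mesh = Mesh { event_state: Some(Stream) };
    assert!(mesh.events().is_some());
    assert!(mesh.events().is_none());
}
```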
@@ -457,6 +461,8 @@ impl fmt::Display for ProcEvent {
 pub struct ProcEvents {
     event_state: EventState,
     ranks: HashMap<ProcId, usize>,
+    stop_alloc_tx: tokio::sync::mpsc::UnboundedSender<()>,
+    stop_alloc_rx: tokio::sync::mpsc::UnboundedReceiver<()>,
 }
 
 impl ProcEvents {
@@ -491,9 +497,19 @@ impl ProcEvents {
                     };
                     break Some(ProcEvent::Crashed(*rank, actor_status.to_string()))
                 }
+                Some(_) = self.stop_alloc_rx.recv() => {
+                    if let Err(err) = self.event_state.alloc.stop_and_wait().await {
+                        tracing::error!("failed to stop alloc: {}", err);
+                    }
+                    break None;
+                }
             }
         }
     }
+
+    pub fn stop_alloc_tx(&self) -> &tokio::sync::mpsc::UnboundedSender<()> {
+        &self.stop_alloc_tx
+    }
 }
 
 /// Spawns from shared ([`Arc`]) proc meshes, providing [`ActorMesh`]es with
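The new `Some(_) = self.stop_alloc_rx.recv()` arm folds a shutdown channel into the same `select!` loop that drains proc events, the standard tokio shape for a cancellable event stream. A stripped-down sketch of that pattern with illustrative names (not this crate's types):

```rust
use tokio::sync::mpsc;

// Yield events until the stream closes or a stop request arrives;
// `None` means the loop should end either way.
async fn next_event(
    events_rx: &mut mpsc::UnboundedReceiver<String>,
    stop_rx: &mut mpsc::UnboundedReceiver<()>,
) -> Option<String> {
    tokio::select! {
        event = events_rx.recv() => event, // None once all senders drop
        Some(_) = stop_rx.recv() => None,  // graceful-stop request
    }
}
```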
@@ -697,4 +713,26 @@ mod tests {
 
         assert!(events.next().await.is_none());
     }
+
+    #[tracing_test::traced_test]
+    #[tokio::test]
+    async fn test_proc_mesh_stop() {
+        let alloc_spec = AllocSpec {
+            shape: shape! { replica = 4 },
+            constraints: Default::default(),
+        };
+        let alloc = LocalAllocator.allocate(alloc_spec).await.unwrap();
+        let mut proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
+
+        let _ = proc_mesh.spawn::<TestActor>("foo", &()).await.unwrap();
+        let _ = proc_mesh.spawn::<TestActor>("bar", &()).await.unwrap();
+
+        let mut proc_state = proc_mesh.events().unwrap();
+        let stop_sender = proc_state.stop_alloc_tx();
+        stop_sender.send(()).unwrap();
+
+        while (proc_state.next().await).is_some() {}
+
+        assert!(logs_contain("4 actors stopped"));
+    }
 }
40 changes: 35 additions & 5 deletions monarch_hyperactor/src/proc_mesh.rs
@@ -34,6 +34,8 @@ use crate::shape::PyShape;
 pub struct PyProcMesh {
     pub inner: Arc<ProcMesh>,
     keepalive: Keepalive,
+    stop_signal_sender: tokio::sync::mpsc::UnboundedSender<()>,
+    stop_observer: tokio::sync::watch::Receiver<bool>,
 }
 fn allocate_proc_mesh<'py>(py: Python<'py>, alloc: &PyAlloc) -> PyResult<Bound<'py, PyAny>> {
     let alloc = match alloc.take() {
@@ -75,19 +77,25 @@ impl PyProcMesh {
     /// Create a new [`PyProcMesh`] with a monitor that crashes the
     /// process on any proc failure.
     fn monitored(mut proc_mesh: ProcMesh, world_id: WorldId) -> Self {
-        let monitor = tokio::spawn(Self::monitor_proc_mesh(
-            proc_mesh.events().unwrap(),
-            world_id,
-        ));
+        let events = proc_mesh.events().unwrap();
+        let stop_signal_sender = events.stop_alloc_tx().clone();
+        let (stopped_sender, stop_observer) = tokio::sync::watch::channel(false);
+        let monitor = tokio::spawn(Self::monitor_proc_mesh(events, world_id, stopped_sender));
         Self {
             inner: Arc::new(proc_mesh),
             keepalive: Keepalive::new(monitor),
+            stop_signal_sender,
+            stop_observer,
         }
     }
 
     /// Monitor the proc mesh for crashes. If a proc crashes, we print the reason
     /// to stderr and exit with code 1.
-    async fn monitor_proc_mesh(mut events: ProcEvents, world_id: WorldId) {
+    async fn monitor_proc_mesh(
+        mut events: ProcEvents,
+        world_id: WorldId,
+        stopped_sender: tokio::sync::watch::Sender<bool>,
+    ) {
         while let Some(event) = events.next().await {
             match event {
                 // A graceful stop should not be cause for alarm, but
@@ -99,6 +107,7 @@
                 }
             }
         }
+        let _ = stopped_sender.send(true);
     }
 }
 
@@ -173,6 +182,27 @@ impl PyProcMesh {
         }
     }
 
+    fn stop(&mut self) -> PyResult<()> {
+        self.stop_signal_sender.send(()).map_err(|err| {
+            PyException::new_err(format!("Failed to send stop signal to alloc: {}", err))
+        })?;
+        self.inner.client_proc().destroy().map_err(|err| {
+            PyException::new_err(format!("Failed to destroy client proc: {}", err))
+        })?;
+        Ok(())
+    }
+
+    fn wait_for_stop<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
+        let mut stop_observer = self.stop_observer.clone();
+        pyo3_async_runtimes::tokio::future_into_py(py, async move {
+            stop_observer
+                .wait_for(|stopped| *stopped)
+                .await
+                .map_err(|err| PyException::new_err(format!("Failed to wait for stop: {}", err)))?;
+            Ok(())
+        })
+    }
+
     fn __repr__(&self) -> PyResult<String> {
         Ok(format!("<ProcMesh {}>", self.inner))
     }
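The `stopped_sender` added in the monitor task turns the watch channel into a completion latch: it starts at `false`, flips to `true` exactly once when the event loop drains, and `wait_for_stop` clones the receiver and awaits the flip. A self-contained sketch of that pattern, with illustrative names:

```rust
use tokio::sync::watch;

#[tokio::main]
async fn main() {
    let (stopped_tx, stopped_rx) = watch::channel(false);

    // Background task: do some work, then flip the latch once.
    tokio::spawn(async move {
        // ... drain an event stream here ...
        let _ = stopped_tx.send(true);
    });

    // Each clone of the receiver can independently await completion.
    let mut observer = stopped_rx.clone();
    observer
        .wait_for(|stopped| *stopped)
        .await
        .expect("sender dropped before signaling");
    println!("monitor task finished");
}
```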
16 changes: 16 additions & 0 deletions python/monarch/_rust_bindings/monarch_hyperactor/proc_mesh.pyi
@@ -74,4 +74,20 @@ class ProcMesh:
         """
         ...
 
+    def stop(self) -> None:
+        """
+        Signals the `alloc` to stop all of its `Proc`s, and stops the
+        `client`'s `Proc`. Returns immediately after the signal is sent.
+
+        Call `await wait_for_stop()` to wait until all the `Proc`s have
+        finished stopping.
+        """
+        ...
+
+    async def wait_for_stop(self) -> None:
+        """
+        Wait for all `Proc`s in the `alloc`, and the `client`'s `Proc`, to stop.
+        """
+        ...
+
     def __repr__(self) -> str: ...
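Taken together, the two stubs imply a two-phase shutdown from Python. A hypothetical caller, assuming a `ProcMesh` instance already obtained from the library:

```python
# Hypothetical graceful shutdown of a ProcMesh from Python.
async def shutdown(mesh: ProcMesh) -> None:
    mesh.stop()                 # non-blocking: just sends the stop signal
    await mesh.wait_for_stop()  # resolves once every Proc has stopped
```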
7 changes: 7 additions & 0 deletions python/monarch/proc_mesh.py
@@ -14,6 +14,7 @@
     Any,
     cast,
     Dict,
+    Generator,
     List,
     Optional,
     Sequence,
@@ -214,6 +215,12 @@ async def sync_workspace(self) -> None:
         )
         await self._rsync_mesh_client.sync_workspace()
 
+    def stop(self) -> None:
+        self._proc_mesh.stop()
+
+    def __await__(self) -> Generator[None, None, None]:
+        return self._proc_mesh.wait_for_stop().__await__()
+
 
 async def local_proc_mesh_nonblocking(
     *, gpus: Optional[int] = None, hosts: int = 1
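Defining `__await__` this way makes the wrapper object itself awaitable, so `await mesh` is shorthand for awaiting `wait_for_stop()`. A minimal standalone sketch of the same delegation pattern, using an `asyncio.Event` in place of the Rust binding:

```python
import asyncio


class Stoppable:
    """Awaitable wrapper: `await obj` resolves once stop() has been called."""

    def __init__(self) -> None:
        self._stopped = asyncio.Event()

    def stop(self) -> None:
        self._stopped.set()

    def __await__(self):
        # Delegate to the coroutine's generator, just as the mesh wrapper
        # delegates to wait_for_stop().__await__().
        return self._stopped.wait().__await__()


async def main() -> None:
    obj = Stoppable()
    obj.stop()
    await obj  # returns immediately since stop() already ran


asyncio.run(main())
```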