Skip to content

Commit cb25e43

Browse files
author
Shayne Fletcher
committed
[hyperactor_mesh]: comm: handle undeliverable messages in casts
Pull Request resolved: #361 provide `CommActor` with a custom `handle_undeliverable_message` implementation that routes delivery failures back to the original cast sender. ghstack-source-id: 293073439 Differential Revision: [D77398378](https://our.internmc.facebook.com/intern/diff/D77398378/)
1 parent 3ad9b5d commit cb25e43

File tree

6 files changed

+104
-22
lines changed

6 files changed

+104
-22
lines changed

hyperactor/src/mailbox.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,11 @@ impl MessageEnvelope {
248248
&self.dest
249249
}
250250

251+
/// The message headers.
252+
pub fn headers(&self) -> &Attrs {
253+
&self.headers
254+
}
255+
251256
/// Tells whether this is a signal message.
252257
pub fn is_signal(&self) -> bool {
253258
self.dest.index() == Signal::port()
@@ -2131,8 +2136,15 @@ impl MailboxSender for WeakMailboxRouter {
21312136
}
21322137
}
21332138

2134-
/// A serializable [`MailboxRouter`]. It keeps a serializable address book so that
2135-
/// the mailbox sender can be recovered.
2139+
/// A dynamic mailbox router that supports remote delivery.
2140+
///
2141+
/// `DialMailboxRouter` maintains a runtime address book mapping
2142+
/// references to `ChannelAddr`s. It holds a cache of active
2143+
/// connections and forwards messages to the appropriate
2144+
/// `MailboxClient`.
2145+
///
2146+
/// Messages sent to unknown destinations are routed to the `default`
2147+
/// sender, if present.
21362148
#[derive(Debug, Clone)]
21372149
pub struct DialMailboxRouter {
21382150
address_book: Arc<RwLock<BTreeMap<Reference, ChannelAddr>>>,

hyperactor_mesh/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ tracing-subscriber = { version = "0.3.19", features = ["chrono", "env-filter", "
5454
dir-diff = "0.3"
5555
maplit = "1.0"
5656
timed_test = { version = "0.0.0", path = "../timed_test" }
57-
tracing-test = { version = "0.2.3", features = ["no-env-filter"] }
5857

5958
[lints]
6059
rust = { unexpected_cfgs = { check-cfg = ["cfg(fbcode_build)"], level = "warn" } }

hyperactor_mesh/src/actor_mesh.rs

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,12 @@ impl<'a, A: RemoteActor> RootActorMesh<'a, A> {
232232
for ref slice in sel {
233233
for rank in slice.iter() {
234234
let mut headers = Attrs::new();
235-
set_cast_info_on_headers(&mut headers, rank, self.shape().clone());
235+
set_cast_info_on_headers(
236+
&mut headers,
237+
rank,
238+
self.shape().clone(),
239+
self.proc_mesh.client().actor_id().clone(),
240+
);
236241
self.ranks[rank]
237242
.send_with_headers(self.proc_mesh.client(), headers, message.clone())
238243
.map_err(|err| CastError::MailboxSenderError(rank, err))?;
@@ -782,17 +787,17 @@ mod tests {
782787
let actor_mesh: RootActorMesh<TestActor> = mesh.spawn("test", &()).await.unwrap();
783788
let actor_ref = actor_mesh.get(0).unwrap();
784789
let mut headers = Attrs::new();
785-
set_cast_info_on_headers(&mut headers, 0, Shape::unity());
790+
set_cast_info_on_headers(&mut headers, 0, Shape::unity(), mesh.client().actor_id().clone());
786791
actor_ref.send_with_headers(mesh.client(), headers.clone(), GetRank(true, reply_port.clone())).unwrap();
787792
assert_eq!(0, reply_port_receiver.recv().await.unwrap());
788793

789-
set_cast_info_on_headers(&mut headers, 1, Shape::unity());
794+
set_cast_info_on_headers(&mut headers, 1, Shape::unity(), mesh.client().actor_id().clone());
790795
actor_ref.port()
791796
.send_with_headers(mesh.client(), headers.clone(), GetRank(true, reply_port.clone()))
792797
.unwrap();
793798
assert_eq!(1, reply_port_receiver.recv().await.unwrap());
794799

795-
set_cast_info_on_headers(&mut headers, 2, Shape::unity());
800+
set_cast_info_on_headers(&mut headers, 2, Shape::unity(), mesh.client().actor_id().clone());
796801
actor_ref.actor_id()
797802
.port_id(GetRank::port())
798803
.send_with_headers(
@@ -888,12 +893,8 @@ mod tests {
888893
);
889894
}
890895

891-
// The intent is to emulate the behaviors of the Python
892-
// interaction of T225230867 "process hangs when i send
893-
// messages to a dead actor".
894-
#[tracing_test::traced_test]
895896
#[tokio::test]
896-
async fn test_behaviors_on_actor_error() {
897+
async fn test_cast_failure() {
897898
use crate::alloc::ProcStopReason;
898899
use crate::proc_mesh::ProcEvent;
899900
use crate::sel;
@@ -908,6 +909,10 @@ mod tests {
908909

909910
let stop = alloc.stopper();
910911
let mut mesh = ProcMesh::allocate(alloc).await.unwrap();
912+
let mut undeliverable_rx = mesh
913+
.client_undeliverable_receiver()
914+
.take()
915+
.expect("client_undeliverable_receiver should be available");
911916
let mut events = mesh.events().unwrap();
912917

913918
let actor_mesh = mesh
@@ -930,14 +935,22 @@ mod tests {
930935
ProcEvent::Crashed(0, reason) if reason.contains("intentional error!")
931936
);
932937

933-
// Uncomment this to cause an infinite hang.
934-
/*
935-
let (reply_handle, mut reply_receiver) = actor_mesh.open_port();
936-
actor_mesh
937-
.cast(sel!(*), GetRank(false, reply_handle.bind()))
938-
.unwrap();
939-
let rank = reply_receiver.recv().await.unwrap();
940-
*/
938+
// Cast the message.
939+
let (reply_handle, _) = actor_mesh.open_port();
940+
actor_mesh
941+
.cast(sel!(*), GetRank(false, reply_handle.bind()))
942+
.unwrap();
943+
944+
// The message will be returned.
945+
let Undeliverable(msg) = undeliverable_rx.recv().await.unwrap();
946+
assert_eq!(
947+
msg.sender(),
948+
&ActorId(
949+
ProcId(actor_mesh.world_id().clone(), 0),
950+
"comm".to_owned(),
951+
0
952+
)
953+
);
941954

942955
// Stop the mesh.
943956
stop();

hyperactor_mesh/src/comm/mod.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
use crate::comm::multicast::CAST_ORIGINATING_SENDER;
910
pub mod multicast;
1011

1112
use std::cmp::Ordering;
@@ -20,9 +21,13 @@ use hyperactor::ActorRef;
2021
use hyperactor::Handler;
2122
use hyperactor::Instance;
2223
use hyperactor::Named;
24+
use hyperactor::PortRef;
2325
use hyperactor::WorldId;
2426
use hyperactor::attrs::Attrs;
2527
use hyperactor::data::Serialized;
28+
use hyperactor::mailbox::DeliveryError;
29+
use hyperactor::mailbox::Undeliverable;
30+
use hyperactor::mailbox::UndeliverableMessageError;
2631
use hyperactor::reference::UnboundPort;
2732
use ndslice::Slice;
2833
use ndslice::selection::routing::RoutingFrame;
@@ -156,6 +161,46 @@ impl Actor for CommActor {
156161
mode: Default::default(),
157162
})
158163
}
164+
165+
// This is an override of the default actor behavior.
166+
async fn handle_undeliverable_message(
167+
&mut self,
168+
this: &Instance<Self>,
169+
undelivered: hyperactor::mailbox::Undeliverable<hyperactor::mailbox::MessageEnvelope>,
170+
) -> Result<(), anyhow::Error> {
171+
let Undeliverable(mut message_envelope) = undelivered;
172+
173+
// 1. Case delivery failure at a "forwarding" step.
174+
if let Ok(ForwardMessage { message, .. }) =
175+
message_envelope.deserialized::<ForwardMessage>()
176+
{
177+
let sender = message.sender();
178+
let return_port = PortRef::attest_message_port(sender);
179+
return_port
180+
.send(this, Undeliverable(message_envelope.clone()))
181+
.map_err(|err| {
182+
message_envelope
183+
.try_set_error(DeliveryError::BrokenLink(format!("send failure: {err}")));
184+
UndeliverableMessageError::return_failure(&message_envelope)
185+
})?;
186+
return Ok(());
187+
}
188+
189+
// 2. Case delivery failure at a "deliver here" step.
190+
if let Some(sender) = message_envelope.headers().get(CAST_ORIGINATING_SENDER) {
191+
let return_port = PortRef::attest_message_port(sender);
192+
return_port
193+
.send(this, Undeliverable(message_envelope.clone()))
194+
.map_err(|err| {
195+
message_envelope
196+
.try_set_error(DeliveryError::BrokenLink(format!("send failure: {err}")));
197+
UndeliverableMessageError::return_failure(&message_envelope)
198+
})?;
199+
return Ok(());
200+
}
201+
202+
unreachable!()
203+
}
159204
}
160205

161206
impl CommActor {
@@ -203,6 +248,7 @@ impl CommActor {
203248
&mut headers,
204249
mode.self_rank(this.self_id()),
205250
message.shape().clone(),
251+
message.sender().clone(),
206252
);
207253
// TODO(pzhang) split reply ports so children can reply to this comm
208254
// actor instead of parent.

hyperactor_mesh/src/comm/multicast.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ impl CastMessageEnvelope {
9191
}
9292
}
9393

94+
pub(crate) fn sender(&self) -> &ActorId {
95+
&self.sender
96+
}
97+
9498
pub(crate) fn dest_port(&self) -> &DestinationPort {
9599
&self.dest_port
96100
}
@@ -179,11 +183,14 @@ declare_attrs! {
179183
/// Used inside headers to store the shape of the
180184
/// actor mesh that a message was cast to.
181185
attr CAST_SHAPE: Shape;
186+
/// Used inside headers to store the originating sender of a cast.
187+
pub attr CAST_ORIGINATING_SENDER: ActorId;
182188
}
183189

184-
pub fn set_cast_info_on_headers(headers: &mut Attrs, rank: usize, shape: Shape) {
190+
pub fn set_cast_info_on_headers(headers: &mut Attrs, rank: usize, shape: Shape, sender: ActorId) {
185191
headers.set(CAST_RANK, rank);
186192
headers.set(CAST_SHAPE, shape);
193+
headers.set(CAST_ORIGINATING_SENDER, sender);
187194
}
188195

189196
pub fn get_cast_info_from_headers(headers: &Attrs) -> Option<(usize, Shape)> {

monarch_hyperactor/src/mailbox.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,12 @@ impl PyMailbox {
133133
) -> PyResult<()> {
134134
let port_id = dest.inner.port_id(PythonMessage::port());
135135
let mut headers = Attrs::new();
136-
set_cast_info_on_headers(&mut headers, rank, shape.inner.clone());
136+
set_cast_info_on_headers(
137+
&mut headers,
138+
rank,
139+
shape.inner.clone(),
140+
self.inner.actor_id().clone(),
141+
);
137142
let message = Serialized::serialize(message).map_err(|err| {
138143
PyRuntimeError::new_err(format!(
139144
"failed to serialize message ({:?}) to Serialized: {}",

0 commit comments

Comments
 (0)