@@ -8,10 +8,20 @@ use std::{
88 sync:: atomic:: AtomicBool ,
99} ;
1010
11+ //// Active message protocol.
12+ /// Active message protocol is a mechanism for sending and receiving messages
13+ /// between processes in a distributed system.
14+ /// It allows a process to send a message to another process, which can then
15+ /// handle the message and perform some action based on its contents.
16+ /// Active messages are typically used in high-performance computing (HPC)
17+ /// applications, where low-latency communication is critical.
1118#[ derive( Debug , PartialEq , Eq ) ]
1219pub enum AmDataType {
20+ /// Eager message
1321 Eager ,
22+ /// Data message
1423 Data ,
24+ /// Rendezvous message
1525 Rndv ,
1626}
1727
@@ -88,6 +98,7 @@ impl RawMsg {
8898 }
8999}
90100
101+ /// Active message message.
91102pub struct AmMsg < ' a > {
92103 worker : & ' a Worker ,
93104 msg : RawMsg ,
@@ -98,35 +109,44 @@ impl<'a> AmMsg<'a> {
98109 AmMsg { worker, msg }
99110 }
100111
112+ /// Get the message ID.
101113 #[ inline]
102114 pub fn id ( & self ) -> u16 {
103115 self . msg . id
104116 }
105117
118+ /// Get the message header.
106119 #[ inline]
107120 pub fn header ( & self ) -> & [ u8 ] {
108121 self . msg . header . as_ref ( )
109122 }
110123
124+ /// Get the message header length.
111125 #[ inline]
112126 pub fn contains_data ( & self ) -> bool {
113127 self . data_type ( ) . is_some ( )
114128 }
115129
130+ /// Get the message data type.
116131 pub fn data_type ( & self ) -> Option < AmDataType > {
117132 self . msg . data . as_ref ( ) . map ( |data| data. data_type ( ) )
118133 }
119134
135+ /// Get the message data.
136+ /// Returns `None` if the message doesn't contain data.
120137 #[ inline]
121138 pub fn get_data ( & self ) -> Option < & [ u8 ] > {
122139 self . msg . data . as_ref ( ) . and_then ( |data| data. data ( ) )
123140 }
124141
142+ /// Get the message data length.
143+ /// Returns `0` if the message doesn't contain data.
125144 #[ inline]
126145 pub fn data_len ( & self ) -> usize {
127146 self . msg . data . as_ref ( ) . map_or ( 0 , |data| data. len ( ) )
128147 }
129148
149+ /// Receive the message data.
130150 pub async fn recv_data ( & mut self ) -> Result < Vec < u8 > , Error > {
131151 match self . msg . data . take ( ) {
132152 None => Ok ( Vec :: new ( ) ) ,
@@ -144,6 +164,12 @@ impl<'a> AmMsg<'a> {
144164 }
145165 }
146166
167+ /// Receive the message data.
168+ /// Returns `0` if the message doesn't contain data.
169+ /// Returns the number of bytes received.
170+ /// # Safety
171+ /// User needs to ensure that the buffer is large enough to hold the data.
172+ /// Otherwise, it will cause memory corruption.
147173 pub async fn recv_data_single ( & mut self , buf : & mut [ u8 ] ) -> Result < usize , Error > {
148174 if !self . contains_data ( ) {
149175 Ok ( 0 )
@@ -153,6 +179,7 @@ impl<'a> AmMsg<'a> {
153179 }
154180 }
155181
182+ /// Receive the message data.
156183 pub async fn recv_data_vectored ( & mut self , iov : & [ IoSliceMut < ' _ > ] ) -> Result < usize , Error > {
157184 let data = self . msg . data . take ( ) ;
158185 if let Some ( data) = data {
@@ -192,7 +219,7 @@ impl<'a> AmMsg<'a> {
192219 unsafe extern "C" fn callback (
193220 request : * mut c_void ,
194221 status : ucs_status_t ,
195- _length : u64 ,
222+ _length : usize ,
196223 _data : * mut c_void ,
197224 ) {
198225 // todo: handle error & fix real data length
@@ -255,6 +282,7 @@ impl<'a> AmMsg<'a> {
255282 }
256283 }
257284
285+ /// Check if the message needs a reply.
258286 #[ inline]
259287 pub fn need_reply ( & self ) -> bool {
260288 self . msg . attr & ( ucp_am_recv_attr_t:: UCP_AM_RECV_ATTR_FIELD_REPLY_EP as u64 ) != 0
@@ -309,6 +337,7 @@ impl<'a> Drop for AmMsg<'a> {
309337 }
310338}
311339
340+ /// Active message stream.
312341#[ derive( Clone ) ]
313342pub struct AmStream < ' a > {
314343 worker : & ' a Worker ,
@@ -383,9 +412,9 @@ impl Worker {
383412 unsafe extern "C" fn callback (
384413 arg : * mut c_void ,
385414 header : * const c_void ,
386- header_len : u64 ,
415+ header_len : usize ,
387416 data : * mut c_void ,
388- data_len : u64 ,
417+ data_len : usize ,
389418 param : * const ucp_am_recv_param_t ,
390419 ) -> ucs_status_t {
391420 let handler = & * ( arg as * const AmStreamInner ) ;
@@ -442,7 +471,9 @@ impl Worker {
442471 }
443472}
444473
474+ /// Active message endpoint.
445475impl Endpoint {
476+ /// Send active message.
446477 pub async fn am_send (
447478 & self ,
448479 id : u32 ,
@@ -456,6 +487,7 @@ impl Endpoint {
456487 . await
457488 }
458489
490+ /// Send active message.
459491 pub async fn am_send_vectorized (
460492 & self ,
461493 id : u32 ,
@@ -469,8 +501,11 @@ impl Endpoint {
469501 }
470502}
471503
504+ /// Active message protocol
472505pub enum AmProto {
506+ /// Eager protocol
473507 Eager ,
508+ /// Rendezvous protocol
474509 Rndv ,
475510}
476511
@@ -601,7 +636,7 @@ mod tests {
601636 header. as_slice( ) ,
602637 data. as_slice( ) ,
603638 true ,
604- Some ( AmProto :: Eager ) ,
639+ Some ( AmProto :: Rndv ) ,
605640 )
606641 . await ;
607642 assert!( result. is_ok( ) ) ;
@@ -627,7 +662,10 @@ mod tests {
627662 tokio:: join!(
628663 async {
629664 // send reply
630- let result = unsafe { msg. reply( 12 , & header, & data, false , None ) . await } ;
665+ let result = unsafe {
666+ msg. reply( 12 , & header, & data, false , Some ( AmProto :: Rndv ) )
667+ . await
668+ } ;
631669 assert!( result. is_ok( ) ) ;
632670 } ,
633671 async {
0 commit comments