Skip to content

Add sendmmsg #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -697,3 +697,93 @@ impl<'name, 'bufs, 'control> fmt::Debug for MsgHdrMut<'name, 'bufs, 'control> {
"MsgHdrMut".fmt(fmt)
}
}

/// Wraps `mmsghdr` on Unix for a `sendmmsg(2)` system call.
///
/// Also see [`MsgHdr`] for the variant used by `sendmsg(2)`.
#[repr(transparent)]
#[cfg(any(target_os = "linux", target_os = "android",))]
pub struct MMsgHdr<'addr, 'bufs, 'control> {
inner: sys::mmsghdr,
#[allow(clippy::type_complexity)]
_lifetimes: PhantomData<(&'addr SockAddr, &'bufs IoSlice<'bufs>, &'control [u8])>,
}

#[cfg(any(target_os = "linux", target_os = "android",))]
impl<'addr, 'bufs, 'control> MMsgHdr<'addr, 'bufs, 'control> {
/// Create a new `MMsgHdr` from `MsgHdr` and with the `msg_len` set to zero.
pub fn new(msg: MsgHdr<'_, '_, '_>) -> Self {
Self {
inner: sys::mmsghdr {
msg_hdr: msg.inner,
msg_len: 0,
},
_lifetimes: PhantomData,
}
}

/// Number of bytes transmitted.
pub fn transmitted_bytes(&self) -> u32 {
self.inner.msg_len
}
}

#[cfg(any(target_os = "linux", target_os = "android",))]
impl<'addr, 'bufs, 'control> fmt::Debug for MMsgHdr<'addr, 'bufs, 'control> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&format!("MMsgHdr({})", self.transmitted_bytes()))
}
}

/// Wraps `mmsghdr` on Unix for a `recvmmsg(2)` system call.
///
/// Also see [`MsgHdrMut`] for the variant used by `recvmsg(2)`.
#[repr(transparent)]
#[cfg(any(target_os = "linux", target_os = "android",))]
pub struct MMsgHdrMut<'addr, 'bufs, 'control> {
inner: sys::mmsghdr,
#[allow(clippy::type_complexity)]
_lifetimes: PhantomData<(
&'addr mut SockAddr,
&'bufs mut MaybeUninitSlice<'bufs>,
&'control mut [u8],
)>,
}

#[cfg(any(target_os = "linux", target_os = "android",))]
impl<'addr, 'bufs, 'control> MMsgHdrMut<'addr, 'bufs, 'control> {
/// Create a new `MMsgHdrMut` from `MsgHdrMut` and with the `msg_len` set to zero.
pub fn new(msg: MsgHdrMut<'_, '_, '_>) -> Self {
Self {
inner: sys::mmsghdr {
msg_hdr: msg.inner,
msg_len: 0,
},
_lifetimes: PhantomData,
}
}

/// Number of received bytes.
pub fn recieved_bytes(&self) -> u32 {
self.inner.msg_len
}

/// Returns the flags of the message.
pub fn flags(&self) -> RecvFlags {
sys::msghdr_flags(&self.inner.msg_hdr)
}

/// Gets the length of the control buffer.
///
/// Can be used to determine how much, if any, of the control buffer was filled by `recvmsg`.
pub fn control_len(&self) -> usize {
sys::msghdr_control_len(&self.inner.msg_hdr)
}
}

#[cfg(any(target_os = "linux", target_os = "android",))]
impl<'addr, 'bufs, 'control> fmt::Debug for MMsgHdrMut<'addr, 'bufs, 'control> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&format!("MMsgHdrMut({})", self.recieved_bytes()))
}
}
22 changes: 22 additions & 0 deletions src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ use crate::{Domain, Protocol, SockAddr, TcpKeepalive, Type};
#[cfg(not(target_os = "redox"))]
use crate::{MaybeUninitSlice, MsgHdr, RecvFlags};

#[cfg(all(feature = "all", any(target_os = "android", target_os = "linux",)))]
use crate::{MMsgHdr, MMsgHdrMut};

/// Owned wrapper around a system socket.
///
/// This type simply wraps an instance of a file descriptor (`c_int`) on Unix
Expand Down Expand Up @@ -634,6 +637,18 @@ impl Socket {
sys::recvmsg(self.as_raw(), msg, flags)
}

/// Receive multiple messages on the socket using a single system call.
#[doc = man_links!(unix: recvmmsg(2))]
#[cfg(all(feature = "all", any(target_os = "android", target_os = "linux",)))]
pub fn recvmmsg(
&self,
msgvec: &mut [MMsgHdrMut<'_, '_, '_>],
flags: c_int,
timeout: Option<Duration>,
) -> io::Result<usize> {
sys::recvmmsg(self.as_raw(), msgvec, flags, timeout)
}

/// Sends data on the socket to a connected peer.
///
/// This is typically used on TCP sockets or datagram sockets which have
Expand Down Expand Up @@ -735,6 +750,13 @@ impl Socket {
pub fn sendmsg(&self, msg: &MsgHdr<'_, '_, '_>, flags: sys::c_int) -> io::Result<usize> {
sys::sendmsg(self.as_raw(), msg, flags)
}

/// Send multiple messages on the socket using a single system call.
#[doc = man_links!(unix: sendmmsg(2))]
#[cfg(all(feature = "all", any(target_os = "linux", target_os = "android",)))]
pub fn sendmmsg(&self, msgvec: &mut [MMsgHdr<'_, '_, '_>], flags: c_int) -> io::Result<usize> {
sys::sendmmsg(self.as_raw(), msgvec, flags)
}
}

/// Set `SOCK_CLOEXEC` and `NO_HANDLE_INHERIT` on the `ty`pe on platforms that
Expand Down
77 changes: 77 additions & 0 deletions src/sys/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ use crate::{Domain, Protocol, SockAddr, TcpKeepalive, Type};
#[cfg(not(target_os = "redox"))]
use crate::{MsgHdr, MsgHdrMut, RecvFlags};

#[cfg(all(feature = "all", any(target_os = "linux", target_os = "android",)))]
use crate::{MMsgHdr, MMsgHdrMut};

// Used in `MMsgHdr`.
#[cfg(any(target_os = "linux", target_os = "android",))]
pub(crate) use libc::mmsghdr;

pub(crate) use libc::c_int;

// Used in `Domain`.
Expand Down Expand Up @@ -1076,6 +1083,35 @@ pub(crate) fn recvmsg(
syscall!(recvmsg(fd, &mut msg.inner, flags)).map(|n| n as usize)
}

#[cfg(all(feature = "all", any(target_os = "linux", target_os = "android",)))]
/// This emits all the messages in a single syscall
pub(crate) fn recvmmsg(
fd: Socket,
msgvec: &mut [MMsgHdrMut<'_, '_, '_>],
flags: c_int,
timeout: Option<Duration>,
) -> io::Result<usize> {
if cfg!(target_env = "musl") {
debug_assert!(flags >= 0, "socket flags must be non-negative");
}

let mut timeout = timeout.map(into_timespec);
let timeout_ptr = timeout
.as_mut()
.map(|t| t as *mut _)
.unwrap_or(ptr::null_mut());

syscall!(recvmmsg(
fd,
// SAFETY: `MMsgHdrMut` is `#[repr(transparent)]` and wraps a `libc::mmsghdr`
msgvec.as_mut_ptr() as *mut mmsghdr,
msgvec.len() as _,
flags as _,
timeout_ptr
))
.map(|n| n as usize)
}

pub(crate) fn send(fd: Socket, buf: &[u8], flags: c_int) -> io::Result<usize> {
syscall!(send(
fd,
Expand Down Expand Up @@ -1120,6 +1156,27 @@ pub(crate) fn sendmsg(fd: Socket, msg: &MsgHdr<'_, '_, '_>, flags: c_int) -> io:
syscall!(sendmsg(fd, &msg.inner, flags)).map(|n| n as usize)
}

#[cfg(all(feature = "all", any(target_os = "linux", target_os = "android",)))]
/// This transmits all the messages in a single syscall
pub(crate) fn sendmmsg(
fd: Socket,
msgvec: &mut [MMsgHdr<'_, '_, '_>],
flags: c_int,
) -> io::Result<usize> {
if cfg!(target_env = "musl") {
debug_assert!(flags >= 0, "socket flags must be non-negative");
}

syscall!(sendmmsg(
fd,
// SAFETY: `MMsgHdr` is `#[repr(transparent)]` and wraps a `libc::mmsghdr`
msgvec.as_mut_ptr() as *mut mmsghdr,
msgvec.len() as _,
flags as _
))
.map(|n| n as usize)
}

/// Wrapper around `getsockopt` to deal with platform specific timeouts.
pub(crate) fn timeout_opt(fd: Socket, opt: c_int, val: c_int) -> io::Result<Option<Duration>> {
unsafe { getsockopt(fd, opt, val).map(from_timeval) }
Expand Down Expand Up @@ -1161,6 +1218,26 @@ fn into_timeval(duration: Option<Duration>) -> libc::timeval {
}
}

#[cfg(all(feature = "all", any(target_os = "linux", target_os = "android",)))]
fn into_timespec(duration: Duration) -> libc::timespec {
// https://github.com/rust-lang/libc/issues/1848
#[cfg_attr(target_env = "musl", allow(deprecated))]
libc::timespec {
tv_sec: min(duration.as_secs(), libc::time_t::MAX as u64) as libc::time_t,
#[cfg(any(
all(target_arch = "x86_64", target_pointer_width = "32"),
target_pointer_width = "64"
))]
tv_nsec: duration.subsec_nanos() as i64,

#[cfg(not(any(
all(target_arch = "x86_64", target_pointer_width = "32"),
target_pointer_width = "64"
)))]
tv_nsec: duration.subsec_nanos().clamp(0, i32::MAX as u32) as i32,
}
}

#[cfg(all(
feature = "all",
not(any(target_os = "haiku", target_os = "openbsd", target_os = "vita"))
Expand Down
41 changes: 41 additions & 0 deletions tests/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,47 @@ fn sendmsg() {
assert_eq!(received, DATA.len());
}

#[test]
#[cfg(all(feature = "all", any(target_os = "linux", target_os = "android")))]
fn send_and_recv_batched_msgs() {
let (socket_a, socket_b) = udp_pair_unconnected();

const DATA: &[u8] = b"Hello, World!";

let addr_b = socket_b.local_addr().unwrap();
let mut batched_msgs = Vec::new();
let mut recv_batched_msgs = Vec::new();
for _ in 0..10 {
let bufs = &[IoSlice::new(DATA)];
batched_msgs.push(socket2::MMsgHdr::new(
socket2::MsgHdr::new().with_addr(&addr_b).with_buffers(bufs),
));

let mut buf = [MaybeUninit::new(0u8); DATA.len()];
let recv_bufs = MaybeUninitSlice::new(buf.as_mut_slice());
recv_batched_msgs.push(socket2::MMsgHdrMut::new(
socket2::MsgHdrMut::new().with_buffers(&mut [recv_bufs]),
));
}

let sent = socket_a.sendmmsg(batched_msgs.as_mut_slice(), 0).unwrap();

let mut sent_data = 0;
// Calculate transmitted length
for msg in batched_msgs.iter().take(sent) {
sent_data += msg.transmitted_bytes()
}
assert!(sent_data as usize == 10 * DATA.len());

let recvd = socket_b
.recvmmsg(recv_batched_msgs.as_mut_slice(), 0, None)
.unwrap();

assert!(recvd == sent);

assert!(recv_batched_msgs[0].recieved_bytes() == DATA.len().try_into().unwrap());
}

#[test]
#[cfg(not(any(target_os = "redox", target_os = "vita")))]
fn recv_vectored_truncated() {
Expand Down