Skip to content

Commit e7dbdbe

Browse files
committed
Initial basic cmsg support for unix
Fixes #313
1 parent f9c1aef commit e7dbdbe

File tree

6 files changed

+540
-21
lines changed

6 files changed

+540
-21
lines changed

src/cmsg.rs

Lines changed: 365 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,365 @@
1+
use crate::sys;
2+
use std::borrow::Borrow;
3+
use std::convert::TryInto as _;
4+
use std::io::IoSlice;
5+
use std::iter::FromIterator;
6+
7+
#[derive(Debug, Clone)]
8+
struct MsgHdrWalker<B> {
9+
buffer: B,
10+
position: Option<usize>,
11+
}
12+
13+
impl<B: AsRef<[u8]>> MsgHdrWalker<B> {
14+
fn next_ptr(&mut self) -> Option<*const libc::cmsghdr> {
15+
// Build a msghdr so we can use the functionality in libc.
16+
let mut msghdr: libc::msghdr = unsafe { std::mem::zeroed() };
17+
let buffer = self.buffer.as_ref();
18+
// SAFETY: We're giving msghdr a mutable pointer to comply with the C
19+
// API. We'll only allow mutation of `cmsghdr`, however if `B` is
20+
// AsMut<[u8]>.
21+
msghdr.msg_control = buffer.as_ptr() as *mut _;
22+
msghdr.msg_controllen = buffer.len().try_into().expect("buffer is too long");
23+
24+
let nxt_hdr = if let Some(position) = self.position {
25+
if position >= buffer.len() {
26+
return None;
27+
}
28+
let cur_hdr = &buffer[position] as *const u8 as *const _;
29+
// Safety: msghdr is a valid pointer and cur_hdr is not null.
30+
unsafe { libc::CMSG_NXTHDR(&msghdr, cur_hdr) }
31+
} else {
32+
// Safety: msghdr is a valid pointer.
33+
unsafe { libc::CMSG_FIRSTHDR(&msghdr) }
34+
};
35+
36+
if nxt_hdr.is_null() {
37+
self.position = Some(buffer.len());
38+
return None;
39+
}
40+
41+
// SAFETY: nxt_hdr always points to data within the buffer, they must be
42+
// part of the same allocation.
43+
let distance = unsafe { (nxt_hdr as *const u8).offset_from(buffer.as_ptr()) };
44+
// nxt_hdr is always ahead of the buffer and not null if we're here,
45+
// meaning the distance is always positive.
46+
self.position = Some(distance.try_into().unwrap());
47+
Some(nxt_hdr)
48+
}
49+
50+
fn next(&mut self) -> Option<(&libc::cmsghdr, &[u8])> {
51+
self.next_ptr().map(|cmsghdr| {
52+
// SAFETY: cmsghdr is a valid pointer given to us by `next_ptr`.
53+
let data = unsafe { libc::CMSG_DATA(cmsghdr) };
54+
let cmsghdr = unsafe { &*cmsghdr };
55+
// SAFETY: data points to buffer and is controlled by control
56+
// message length.
57+
let data = unsafe {
58+
std::slice::from_raw_parts(
59+
data,
60+
(cmsghdr.cmsg_len as usize)
61+
.saturating_sub(std::mem::size_of::<libc::cmsghdr>()),
62+
)
63+
};
64+
(cmsghdr, data)
65+
})
66+
}
67+
}
68+
69+
impl<B: AsRef<[u8]> + AsMut<[u8]>> MsgHdrWalker<B> {
70+
fn next_mut(&mut self) -> Option<(&mut libc::cmsghdr, &mut [u8])> {
71+
match self.next_ptr() {
72+
Some(cmsghdr) => {
73+
// SAFETY: cmsghdr is a valid pointer given to us by `next_ptr`.
74+
let data = unsafe { libc::CMSG_DATA(cmsghdr) };
75+
// SAFETY: The mutable pointer is safe because we're not going to
76+
// vend any concurrent access to the same memory region and B is
77+
// AsMut<[u8]> guaranteeing we have exclusive access to the buffer.
78+
let cmsghdr = cmsghdr as *mut libc::cmsghdr;
79+
let cmsghdr = unsafe { &mut *cmsghdr };
80+
81+
// We'll always yield the entirety of the rest of the buffer.
82+
let distance = unsafe { data.offset_from(self.buffer.as_ref().as_ptr()) };
83+
// The data pointer is always part of the buffer, can't be before
84+
// it.
85+
let distance: usize = distance.try_into().unwrap();
86+
Some((cmsghdr, &mut self.buffer.as_mut()[distance..]))
87+
}
88+
None => None,
89+
}
90+
}
91+
}
92+
93+
/// A wrapper around a buffer that can be used to write ancillary control
94+
/// messages.
95+
#[derive(Debug)]
96+
pub struct CmsgWriter<B> {
97+
walker: MsgHdrWalker<B>,
98+
last_push: usize,
99+
}
100+
101+
impl<B: AsMut<[u8]> + AsRef<[u8]>> CmsgWriter<B> {
102+
/// Creates a new [`CmsgBuffer`] backed by the bytes in `buffer`.
103+
pub fn new(buffer: B) -> Self {
104+
Self {
105+
walker: MsgHdrWalker {
106+
buffer,
107+
position: None,
108+
},
109+
last_push: 0,
110+
}
111+
}
112+
113+
/// Pushes a new control message `m` to the buffer.
114+
///
115+
/// # Panics
116+
///
117+
/// Panics if the contained buffer does not have enough space to fit `m`.
118+
pub fn push(&mut self, m: &Cmsg) {
119+
let (cmsg_level, cmsg_type, size) = m.level_type_size();
120+
let (nxt_hdr, data) = self
121+
.walker
122+
.next_mut()
123+
.unwrap_or_else(|| panic!("can't fit message {:?}", m));
124+
// Safety: All values are passed by copy.
125+
let cmsg_len = unsafe { libc::CMSG_LEN(size) }.try_into().unwrap();
126+
*nxt_hdr = libc::cmsghdr {
127+
cmsg_len,
128+
cmsg_level,
129+
cmsg_type,
130+
};
131+
m.write(&mut data[..size as usize]);
132+
// Always store the space required for the last push because the walker
133+
// maintains its position cursor at the currently written option, we
134+
// must always add the space for the last control message when returning
135+
// the consolidated buffer.
136+
self.last_push = unsafe { libc::CMSG_SPACE(size) } as usize;
137+
}
138+
}
139+
140+
impl<B: AsMut<[u8]> + AsRef<[u8]>> Extend<Cmsg> for CmsgWriter<B> {
141+
fn extend<T: IntoIterator<Item = Cmsg>>(&mut self, iter: T) {
142+
for cmsg in iter {
143+
self.push(&cmsg)
144+
}
145+
}
146+
}
147+
148+
impl<C: Borrow<Cmsg>> FromIterator<C> for CmsgWriter<Vec<u8>> {
149+
fn from_iter<T: IntoIterator<Item = C>>(iter: T) -> Self {
150+
let mut buff = CmsgWriter::new(vec![]);
151+
for cmsg in iter {
152+
let cmsg = cmsg.borrow();
153+
buff.walker
154+
.buffer
155+
.resize(buff.walker.buffer.len() + cmsg.space(), 0);
156+
buff.push(&cmsg)
157+
}
158+
buff
159+
}
160+
}
161+
162+
impl<B: AsRef<[u8]>> CmsgWriter<B> {
163+
pub(crate) fn io_slice(&self) -> IoSlice<'_> {
164+
IoSlice::new(self.buffer())
165+
}
166+
167+
pub(crate) fn buffer(&self) -> &[u8] {
168+
if let Some(position) = self.walker.position {
169+
&self.walker.buffer.as_ref()[..position + self.last_push]
170+
} else {
171+
&[]
172+
}
173+
}
174+
}
175+
176+
/// An iterator over received control messages.
177+
#[derive(Debug, Clone)]
178+
pub struct CmsgIter<'a> {
179+
walker: MsgHdrWalker<&'a [u8]>,
180+
}
181+
182+
impl<'a> CmsgIter<'a> {
183+
pub(crate) fn new(buffer: &'a [u8]) -> Self {
184+
Self {
185+
walker: MsgHdrWalker {
186+
buffer,
187+
position: None,
188+
},
189+
}
190+
}
191+
}
192+
193+
impl<'a> Iterator for CmsgIter<'a> {
194+
type Item = Cmsg;
195+
196+
fn next(&mut self) -> Option<Self::Item> {
197+
self.walker.next().map(
198+
|(
199+
libc::cmsghdr {
200+
cmsg_len: _,
201+
cmsg_level,
202+
cmsg_type,
203+
},
204+
data,
205+
)| Cmsg::from_raw(*cmsg_level, *cmsg_type, data),
206+
)
207+
}
208+
}
209+
210+
/// An unknown control message.
211+
#[derive(Debug, Eq, PartialEq)]
212+
pub struct UnknownCmsg {
213+
cmsg_level: libc::c_int,
214+
cmsg_type: libc::c_int,
215+
}
216+
217+
/// Control messages.
218+
#[derive(Debug, Eq, PartialEq)]
219+
pub enum Cmsg {
220+
/// The `IP_TTL` control message.
221+
IpTtl(u8),
222+
/// The `IPV6_PKTINFO` control message.
223+
Ipv6PktInfo {
224+
/// The address the packet is destined to/received from. Equivalent to
225+
/// `in6_pktinfo.ipi6_addr`.
226+
addr: std::net::Ipv6Addr,
227+
/// The interface index the packet is destined to/received from.
228+
/// Equivalent to `in6_pktinfo.ipi6_ifindex`.
229+
ifindex: u32,
230+
},
231+
/// An unrecognized control message.
232+
Unknown(UnknownCmsg),
233+
}
234+
235+
impl Cmsg {
236+
/// Returns the amount of buffer space required to hold this option.
237+
pub fn space(&self) -> usize {
238+
let (_, _, size) = self.level_type_size();
239+
// Safety: All values are passed by copy.
240+
let size = unsafe { libc::CMSG_SPACE(size) };
241+
size as usize
242+
}
243+
244+
fn level_type_size(&self) -> (libc::c_int, libc::c_int, libc::c_uint) {
245+
match self {
246+
Cmsg::IpTtl(_) => (
247+
libc::IPPROTO_IP,
248+
libc::IP_TTL,
249+
// TTL is encoded as a u32.
250+
std::mem::size_of::<u32>() as libc::c_uint,
251+
),
252+
Cmsg::Ipv6PktInfo { .. } => (
253+
libc::IPPROTO_IPV6,
254+
libc::IPV6_PKTINFO,
255+
std::mem::size_of::<libc::in6_pktinfo>() as libc::c_uint,
256+
),
257+
Cmsg::Unknown(UnknownCmsg {
258+
cmsg_level,
259+
cmsg_type,
260+
}) => (*cmsg_level, *cmsg_type, 0),
261+
}
262+
}
263+
264+
fn write(&self, buffer: &mut [u8]) {
265+
match self {
266+
Cmsg::IpTtl(ttl) => {
267+
let value: u32 = (*ttl).into();
268+
let value = value.to_ne_bytes();
269+
(&mut buffer[..value.len()]).copy_from_slice(&value[..]);
270+
}
271+
Cmsg::Ipv6PktInfo { addr, ifindex } => {
272+
let pktinfo = libc::in6_pktinfo {
273+
ipi6_addr: sys::to_in6_addr(addr),
274+
ipi6_ifindex: *ifindex,
275+
};
276+
let size = std::mem::size_of::<libc::in6_pktinfo>();
277+
assert_eq!(buffer.len(), size);
278+
// Safety: `pktinfo` is valid for reads for its size in bytes.
279+
// `buffer` is valid for write for the same length, as
280+
// guaranteed by the assertion above. Copy unit is byte, so
281+
// alignment is okay. The two regions do not overlap.
282+
unsafe {
283+
std::ptr::copy_nonoverlapping(
284+
&pktinfo as *const libc::in6_pktinfo as *const _,
285+
buffer.as_mut_ptr(),
286+
size,
287+
)
288+
}
289+
}
290+
Cmsg::Unknown(_) => {
291+
// NOTE: We don't actually allow users of the public API
292+
// serialize unknown control messages, but we use this code path
293+
// for testing.
294+
}
295+
}
296+
}
297+
298+
fn from_raw(cmsg_level: libc::c_int, cmsg_type: libc::c_int, bytes: &[u8]) -> Self {
299+
match (cmsg_level, cmsg_type) {
300+
(libc::IPPROTO_IP, libc::IP_TTL) => {
301+
assert!(bytes.len() >= std::mem::size_of::<u32>(), "{:?}", bytes);
302+
Cmsg::IpTtl(bytes[0])
303+
}
304+
(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => {
305+
let mut pktinfo = unsafe { std::mem::zeroed::<libc::in6_pktinfo>() };
306+
let size = std::mem::size_of::<libc::in6_pktinfo>();
307+
assert!(bytes.len() >= size, "{:?}", bytes);
308+
// Safety: `pktinfo` is valid for writes for its size in bytes.
309+
// `buffer` is valid for read for the same length, as
310+
// guaranteed by the assertion above. Copy unit is byte, so
311+
// alignment is okay. The two regions do not overlap.
312+
unsafe {
313+
std::ptr::copy_nonoverlapping(
314+
bytes.as_ptr(),
315+
&mut pktinfo as *mut libc::in6_pktinfo as *mut _,
316+
size,
317+
)
318+
}
319+
Cmsg::Ipv6PktInfo {
320+
addr: sys::from_in6_addr(pktinfo.ipi6_addr),
321+
ifindex: pktinfo.ipi6_ifindex,
322+
}
323+
}
324+
(cmsg_level, cmsg_type) => Cmsg::Unknown(UnknownCmsg {
325+
cmsg_level,
326+
cmsg_type,
327+
}),
328+
}
329+
}
330+
}
331+
332+
#[cfg(test)]
333+
mod tests {
334+
use super::*;
335+
336+
#[test]
337+
fn ser_deser() {
338+
let cmsgs = [
339+
Cmsg::IpTtl(2),
340+
Cmsg::Ipv6PktInfo {
341+
addr: std::net::Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8),
342+
ifindex: 13,
343+
},
344+
Cmsg::Unknown(UnknownCmsg {
345+
cmsg_level: 12345678,
346+
cmsg_type: 87654321,
347+
}),
348+
];
349+
let buffer: CmsgWriter<_> = cmsgs.iter().collect();
350+
let deser = CmsgIter::new(buffer.buffer()).collect::<Vec<_>>();
351+
assert_eq!(&cmsgs[..], &deser[..]);
352+
}
353+
354+
#[test]
355+
#[should_panic]
356+
fn ser_insufficient_space_panics() {
357+
let mut buffer = CmsgWriter::new([0; 3]);
358+
buffer.push(&Cmsg::IpTtl(2));
359+
}
360+
361+
#[test]
362+
fn empty_deser() {
363+
assert_eq!(CmsgIter::new(&[]).next(), None);
364+
}
365+
}

src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ macro_rules! from {
115115
};
116116
}
117117

118+
#[cfg(unix)]
119+
mod cmsg;
118120
mod sockaddr;
119121
mod socket;
120122
mod sockref;
@@ -141,6 +143,9 @@ pub use sockref::SockRef;
141143
)))]
142144
pub use socket::InterfaceIndexOrAddress;
143145

146+
#[cfg(unix)]
147+
pub use cmsg::{Cmsg, CmsgIter, CmsgWriter};
148+
144149
/// Specification of the communication domain for a socket.
145150
///
146151
/// This is a newtype wrapper around an integer which provides a nicer API in

0 commit comments

Comments
 (0)