Skip to content

Commit 7e16e32

Browse files
feat: Safer poll timeout (#1876)
1 parent a001869 commit 7e16e32

File tree

3 files changed

+243
-12
lines changed

3 files changed

+243
-12
lines changed

changelog/1876.changed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`poll` now takes `PollTimeout` replacing `libc::c_int`.

src/poll.rs

Lines changed: 239 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
//! Wait for events to trigger on specific file descriptors
22
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd};
3+
use std::time::Duration;
34

45
use crate::errno::Errno;
56
use crate::Result;
6-
77
/// This is a wrapper around `libc::pollfd`.
88
///
99
/// It's meant to be used as an argument to the [`poll`](fn.poll.html) and
@@ -27,13 +27,13 @@ impl<'fd> PollFd<'fd> {
2727
/// ```no_run
2828
/// # use std::os::unix::io::{AsFd, AsRawFd, FromRawFd};
2929
/// # use nix::{
30-
/// # poll::{PollFd, PollFlags, poll},
30+
/// # poll::{PollTimeout, PollFd, PollFlags, poll},
3131
/// # unistd::{pipe, read}
3232
/// # };
3333
/// let (r, w) = pipe().unwrap();
3434
/// let pfd = PollFd::new(r.as_fd(), PollFlags::POLLIN);
3535
/// let mut fds = [pfd];
36-
/// poll(&mut fds, -1).unwrap();
36+
/// poll(&mut fds, PollTimeout::NONE).unwrap();
3737
/// let mut buf = [0u8; 80];
3838
/// read(r.as_raw_fd(), &mut buf[..]);
3939
/// ```
@@ -175,6 +175,229 @@ libc_bitflags! {
175175
}
176176
}
177177

178+
/// Timeout argument for [`poll`].
179+
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
180+
pub struct PollTimeout(i32);
181+
182+
impl PollTimeout {
183+
/// Blocks indefinitely.
184+
///
185+
/// > Specifying a negative value in timeout means an infinite timeout.
186+
pub const NONE: Self = Self(-1);
187+
/// Returns immediately.
188+
///
189+
/// > Specifying a timeout of zero causes poll() to return immediately, even if no file
190+
/// > descriptors are ready.
191+
pub const ZERO: Self = Self(0);
192+
/// Blocks for at most [`std::i32::MAX`] milliseconds.
193+
pub const MAX: Self = Self(i32::MAX);
194+
/// Returns if `self` equals [`PollTimeout::NONE`].
195+
pub fn is_none(&self) -> bool {
196+
// > Specifying a negative value in timeout means an infinite timeout.
197+
*self <= Self::NONE
198+
}
199+
/// Returns if `self` does not equal [`PollTimeout::NONE`].
200+
pub fn is_some(&self) -> bool {
201+
!self.is_none()
202+
}
203+
/// Returns the timeout in milliseconds if there is some, otherwise returns `None`.
204+
pub fn as_millis(&self) -> Option<u32> {
205+
self.is_some().then_some(u32::try_from(self.0).unwrap())
206+
}
207+
/// Returns the timeout as a `Duration` if there is some, otherwise returns `None`.
208+
pub fn timeout(&self) -> Option<Duration> {
209+
self.as_millis()
210+
.map(|x| Duration::from_millis(u64::from(x)))
211+
}
212+
}
213+
214+
/// Error type for integer conversions into `PollTimeout`.
215+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
216+
pub enum PollTimeoutTryFromError {
217+
/// Passing a value less than -1 is invalid on some systems, see
218+
/// <https://man.freebsd.org/cgi/man.cgi?poll#end>.
219+
TooNegative,
220+
/// Passing a value greater than `i32::MAX` is invalid.
221+
TooPositive,
222+
}
223+
224+
impl std::fmt::Display for PollTimeoutTryFromError {
225+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226+
match self {
227+
Self::TooNegative => write!(f, "Passed a negative timeout less than -1."),
228+
Self::TooPositive => write!(f, "Passed a positive timeout greater than `i32::MAX` milliseconds.")
229+
}
230+
}
231+
}
232+
233+
impl std::error::Error for PollTimeoutTryFromError {}
234+
235+
impl<T: Into<PollTimeout>> From<Option<T>> for PollTimeout {
236+
fn from(x: Option<T>) -> Self {
237+
x.map_or(Self::NONE, |x| x.into())
238+
}
239+
}
240+
impl TryFrom<Duration> for PollTimeout {
241+
type Error = PollTimeoutTryFromError;
242+
fn try_from(x: Duration) -> std::result::Result<Self, Self::Error> {
243+
Ok(Self(
244+
i32::try_from(x.as_millis())
245+
.map_err(|_| PollTimeoutTryFromError::TooPositive)?,
246+
))
247+
}
248+
}
249+
impl TryFrom<u128> for PollTimeout {
250+
type Error = PollTimeoutTryFromError;
251+
fn try_from(x: u128) -> std::result::Result<Self, Self::Error> {
252+
Ok(Self(
253+
i32::try_from(x)
254+
.map_err(|_| PollTimeoutTryFromError::TooPositive)?,
255+
))
256+
}
257+
}
258+
impl TryFrom<u64> for PollTimeout {
259+
type Error = PollTimeoutTryFromError;
260+
fn try_from(x: u64) -> std::result::Result<Self, Self::Error> {
261+
Ok(Self(
262+
i32::try_from(x)
263+
.map_err(|_| PollTimeoutTryFromError::TooPositive)?,
264+
))
265+
}
266+
}
267+
impl TryFrom<u32> for PollTimeout {
268+
type Error = PollTimeoutTryFromError;
269+
fn try_from(x: u32) -> std::result::Result<Self, Self::Error> {
270+
Ok(Self(
271+
i32::try_from(x)
272+
.map_err(|_| PollTimeoutTryFromError::TooPositive)?,
273+
))
274+
}
275+
}
276+
impl From<u16> for PollTimeout {
277+
fn from(x: u16) -> Self {
278+
Self(i32::from(x))
279+
}
280+
}
281+
impl From<u8> for PollTimeout {
282+
fn from(x: u8) -> Self {
283+
Self(i32::from(x))
284+
}
285+
}
286+
impl TryFrom<i128> for PollTimeout {
287+
type Error = PollTimeoutTryFromError;
288+
fn try_from(x: i128) -> std::result::Result<Self, Self::Error> {
289+
match x {
290+
..=-2 => Err(PollTimeoutTryFromError::TooNegative),
291+
-1.. => Ok(Self(
292+
i32::try_from(x)
293+
.map_err(|_| PollTimeoutTryFromError::TooPositive)?,
294+
)),
295+
}
296+
}
297+
}
298+
impl TryFrom<i64> for PollTimeout {
299+
type Error = PollTimeoutTryFromError;
300+
fn try_from(x: i64) -> std::result::Result<Self, Self::Error> {
301+
match x {
302+
..=-2 => Err(PollTimeoutTryFromError::TooNegative),
303+
-1.. => Ok(Self(
304+
i32::try_from(x)
305+
.map_err(|_| PollTimeoutTryFromError::TooPositive)?,
306+
)),
307+
}
308+
}
309+
}
310+
impl TryFrom<i32> for PollTimeout {
311+
type Error = PollTimeoutTryFromError;
312+
fn try_from(x: i32) -> std::result::Result<Self, Self::Error> {
313+
match x {
314+
..=-2 => Err(PollTimeoutTryFromError::TooNegative),
315+
-1.. => Ok(Self(x)),
316+
}
317+
}
318+
}
319+
impl TryFrom<i16> for PollTimeout {
320+
type Error = PollTimeoutTryFromError;
321+
fn try_from(x: i16) -> std::result::Result<Self, Self::Error> {
322+
match x {
323+
..=-2 => Err(PollTimeoutTryFromError::TooNegative),
324+
-1.. => Ok(Self(i32::from(x))),
325+
}
326+
}
327+
}
328+
impl TryFrom<i8> for PollTimeout {
329+
type Error = PollTimeoutTryFromError;
330+
fn try_from(x: i8) -> std::result::Result<Self, Self::Error> {
331+
match x {
332+
..=-2 => Err(PollTimeoutTryFromError::TooNegative),
333+
-1.. => Ok(Self(i32::from(x))),
334+
}
335+
}
336+
}
337+
impl TryFrom<PollTimeout> for Duration {
338+
type Error = ();
339+
fn try_from(x: PollTimeout) -> std::result::Result<Self, ()> {
340+
x.timeout().ok_or(())
341+
}
342+
}
343+
impl TryFrom<PollTimeout> for u128 {
344+
type Error = <Self as TryFrom<i32>>::Error;
345+
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
346+
Self::try_from(x.0)
347+
}
348+
}
349+
impl TryFrom<PollTimeout> for u64 {
350+
type Error = <Self as TryFrom<i32>>::Error;
351+
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
352+
Self::try_from(x.0)
353+
}
354+
}
355+
impl TryFrom<PollTimeout> for u32 {
356+
type Error = <Self as TryFrom<i32>>::Error;
357+
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
358+
Self::try_from(x.0)
359+
}
360+
}
361+
impl TryFrom<PollTimeout> for u16 {
362+
type Error = <Self as TryFrom<i32>>::Error;
363+
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
364+
Self::try_from(x.0)
365+
}
366+
}
367+
impl TryFrom<PollTimeout> for u8 {
368+
type Error = <Self as TryFrom<i32>>::Error;
369+
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
370+
Self::try_from(x.0)
371+
}
372+
}
373+
impl From<PollTimeout> for i128 {
374+
fn from(x: PollTimeout) -> Self {
375+
Self::from(x.0)
376+
}
377+
}
378+
impl From<PollTimeout> for i64 {
379+
fn from(x: PollTimeout) -> Self {
380+
Self::from(x.0)
381+
}
382+
}
383+
impl From<PollTimeout> for i32 {
384+
fn from(x: PollTimeout) -> Self {
385+
x.0
386+
}
387+
}
388+
impl TryFrom<PollTimeout> for i16 {
389+
type Error = <Self as TryFrom<i32>>::Error;
390+
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
391+
Self::try_from(x.0)
392+
}
393+
}
394+
impl TryFrom<PollTimeout> for i8 {
395+
type Error = <Self as TryFrom<i32>>::Error;
396+
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
397+
Self::try_from(x.0)
398+
}
399+
}
400+
178401
/// `poll` waits for one of a set of file descriptors to become ready to perform I/O.
179402
/// ([`poll(2)`](https://pubs.opengroup.org/onlinepubs/9699919799/functions/poll.html))
180403
///
@@ -191,13 +414,20 @@ libc_bitflags! {
191414
///
192415
/// Note that the timeout interval will be rounded up to the system clock
193416
/// granularity, and kernel scheduling delays mean that the blocking
194-
/// interval may overrun by a small amount. Specifying a negative value
195-
/// in timeout means an infinite timeout. Specifying a timeout of zero
196-
/// causes `poll()` to return immediately, even if no file descriptors are
197-
/// ready.
198-
pub fn poll(fds: &mut [PollFd], timeout: libc::c_int) -> Result<libc::c_int> {
417+
/// interval may overrun by a small amount. Specifying a [`PollTimeout::NONE`]
418+
/// in timeout means an infinite timeout. Specifying a timeout of
419+
/// [`PollTimeout::ZERO`] causes `poll()` to return immediately, even if no file
420+
/// descriptors are ready.
421+
pub fn poll<T: Into<PollTimeout>>(
422+
fds: &mut [PollFd],
423+
timeout: T,
424+
) -> Result<libc::c_int> {
199425
let res = unsafe {
200-
libc::poll(fds.as_mut_ptr().cast(), fds.len() as libc::nfds_t, timeout)
426+
libc::poll(
427+
fds.as_mut_ptr().cast(),
428+
fds.len() as libc::nfds_t,
429+
i32::from(timeout.into()),
430+
)
201431
};
202432

203433
Errno::result(res)

test/test_poll.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use nix::{
22
errno::Errno,
3-
poll::{poll, PollFd, PollFlags},
3+
poll::{poll, PollFd, PollFlags, PollTimeout},
44
unistd::{pipe, write},
55
};
66
use std::os::unix::io::{AsFd, BorrowedFd};
@@ -23,14 +23,14 @@ fn test_poll() {
2323
let mut fds = [PollFd::new(r.as_fd(), PollFlags::POLLIN)];
2424

2525
// Poll an idle pipe. Should timeout
26-
let nfds = loop_while_eintr!(poll(&mut fds, 100));
26+
let nfds = loop_while_eintr!(poll(&mut fds, PollTimeout::from(100u8)));
2727
assert_eq!(nfds, 0);
2828
assert!(!fds[0].revents().unwrap().contains(PollFlags::POLLIN));
2929

3030
write(&w, b".").unwrap();
3131

3232
// Poll a readable pipe. Should return an event.
33-
let nfds = poll(&mut fds, 100).unwrap();
33+
let nfds = poll(&mut fds, PollTimeout::from(100u8)).unwrap();
3434
assert_eq!(nfds, 1);
3535
assert!(fds[0].revents().unwrap().contains(PollFlags::POLLIN));
3636
}

0 commit comments

Comments
 (0)