Skip to content

Commit 2c88cc3

Browse files
committed
Revert "Revert "Merge NetlinkPayload::{Ack,Error}""
This reverts commit 16300f5. Signed-off-by: Gris Ge <[email protected]>
1 parent 0e486c3 commit 2c88cc3

File tree

3 files changed

+104
-22
lines changed

3 files changed

+104
-22
lines changed

src/error.rs

+69-10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// SPDX-License-Identifier: MIT
22

3-
use std::{fmt, io, mem::size_of};
3+
use std::{fmt, io, mem::size_of, num::NonZeroI32};
44

55
use byteorder::{ByteOrder, NativeEndian};
66
use netlink_packet_utils::DecodeError;
@@ -46,10 +46,14 @@ impl<T: AsRef<[u8]>> ErrorBuffer<T> {
4646
}
4747
}
4848

49-
/// Return the error code
50-
pub fn code(&self) -> i32 {
49+
/// Return the error code.
50+
///
51+
/// Returns `None` when there is no error to report (the message is an ACK),
52+
/// or a `Some(e)` if there is a non-zero error code `e` to report (the
53+
/// message is a NACK).
54+
pub fn code(&self) -> Option<NonZeroI32> {
5155
let data = self.buffer.as_ref();
52-
NativeEndian::read_i32(&data[CODE])
56+
NonZeroI32::new(NativeEndian::read_i32(&data[CODE]))
5357
}
5458
}
5559

@@ -77,22 +81,36 @@ impl<T: AsRef<[u8]> + AsMut<[u8]>> ErrorBuffer<T> {
7781
}
7882
}
7983

84+
/// An `NLMSG_ERROR` message.
85+
///
86+
/// Per [RFC 3549 section 2.3.2.2], this message carries the return code for a
87+
/// request which will indicate either success (an ACK) or failure (a NACK).
88+
///
89+
/// [RFC 3549 section 2.3.2.2]: https://datatracker.ietf.org/doc/html/rfc3549#section-2.3.2.2
8090
#[derive(Debug, Default, Clone, PartialEq, Eq)]
8191
#[non_exhaustive]
8292
pub struct ErrorMessage {
83-
pub code: i32,
93+
/// The error code.
94+
///
95+
/// Holds `None` when there is no error to report (the message is an ACK),
96+
/// or a `Some(e)` if there is a non-zero error code `e` to report (the
97+
/// message is a NACK).
98+
///
99+
/// See [Netlink message types] for details.
100+
///
101+
/// [Netlink message types]: https://kernel.org/doc/html/next/userspace-api/netlink/intro.html#netlink-message-types
102+
pub code: Option<NonZeroI32>,
103+
/// The original request's header.
84104
pub header: Vec<u8>,
85105
}
86106

87-
pub type AckMessage = ErrorMessage;
88-
89107
impl Emitable for ErrorMessage {
90108
fn buffer_len(&self) -> usize {
91109
size_of::<i32>() + self.header.len()
92110
}
93111
fn emit(&self, buffer: &mut [u8]) {
94112
let mut buffer = ErrorBuffer::new(buffer);
95-
buffer.set_code(self.code);
113+
buffer.set_code(self.raw_code());
96114
buffer.payload_mut().copy_from_slice(&self.header)
97115
}
98116
}
@@ -119,13 +137,18 @@ impl<'buffer, T: AsRef<[u8]> + 'buffer> Parseable<ErrorBuffer<&'buffer T>>
119137
}
120138

121139
impl ErrorMessage {
140+
/// Returns the raw error code.
141+
pub fn raw_code(&self) -> i32 {
142+
self.code.map_or(0, NonZeroI32::get)
143+
}
144+
122145
/// According to [`netlink(7)`](https://linux.die.net/man/7/netlink)
123146
/// the `NLMSG_ERROR` return Negative errno or 0 for acknowledgements.
124147
///
125148
/// convert into [`std::io::Error`](https://doc.rust-lang.org/std/io/struct.Error.html)
126149
/// using the absolute value from errno code
127150
pub fn to_io(&self) -> io::Error {
128-
io::Error::from_raw_os_error(self.code.abs())
151+
io::Error::from_raw_os_error(self.raw_code().abs())
129152
}
130153
}
131154

@@ -149,7 +172,7 @@ mod tests {
149172
fn into_io_error() {
150173
let io_err = io::Error::from_raw_os_error(95);
151174
let err_msg = ErrorMessage {
152-
code: -95,
175+
code: NonZeroI32::new(-95),
153176
header: vec![],
154177
};
155178

@@ -158,4 +181,40 @@ mod tests {
158181
assert_eq!(err_msg.to_string(), io_err.to_string());
159182
assert_eq!(to_io.raw_os_error(), io_err.raw_os_error());
160183
}
184+
185+
#[test]
186+
fn parse_ack() {
187+
let bytes = vec![0, 0, 0, 0];
188+
let msg = ErrorBuffer::new_checked(&bytes)
189+
.and_then(|buf| ErrorMessage::parse(&buf))
190+
.expect("failed to parse NLMSG_ERROR");
191+
assert_eq!(
192+
ErrorMessage {
193+
code: None,
194+
header: Vec::new()
195+
},
196+
msg
197+
);
198+
assert_eq!(msg.raw_code(), 0);
199+
}
200+
201+
#[test]
202+
fn parse_nack() {
203+
// SAFETY: value is non-zero.
204+
const ERROR_CODE: NonZeroI32 =
205+
unsafe { NonZeroI32::new_unchecked(-1234) };
206+
let mut bytes = vec![0, 0, 0, 0];
207+
NativeEndian::write_i32(&mut bytes, ERROR_CODE.get());
208+
let msg = ErrorBuffer::new_checked(&bytes)
209+
.and_then(|buf| ErrorMessage::parse(&buf))
210+
.expect("failed to parse NLMSG_ERROR");
211+
assert_eq!(
212+
ErrorMessage {
213+
code: Some(ERROR_CODE),
214+
header: Vec::new()
215+
},
216+
msg
217+
);
218+
assert_eq!(msg.raw_code(), ERROR_CODE.get());
219+
}
161220
}

src/message.rs

+33-9
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use netlink_packet_utils::DecodeError;
77

88
use crate::{
99
payload::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN},
10-
AckMessage, DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage,
10+
DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage,
1111
NetlinkBuffer, NetlinkDeserializable, NetlinkHeader, NetlinkPayload,
1212
NetlinkSerializable, Parseable,
1313
};
@@ -101,11 +101,7 @@ where
101101
let msg = ErrorBuffer::new_checked(&bytes)
102102
.and_then(|buf| ErrorMessage::parse(&buf))
103103
.context("failed to parse NLMSG_ERROR")?;
104-
if msg.code >= 0 {
105-
Ack(msg as AckMessage)
106-
} else {
107-
Error(msg)
108-
}
104+
Error(msg)
109105
}
110106
NLMSG_NOOP => Noop,
111107
NLMSG_DONE => {
@@ -138,7 +134,6 @@ where
138134
Done(ref msg) => msg.buffer_len(),
139135
Overrun(ref bytes) => bytes.len(),
140136
Error(ref msg) => msg.buffer_len(),
141-
Ack(ref msg) => msg.buffer_len(),
142137
InnerMessage(ref msg) => msg.buffer_len(),
143138
};
144139

@@ -157,7 +152,6 @@ where
157152
Done(ref msg) => msg.emit(buffer),
158153
Overrun(ref bytes) => buffer.copy_from_slice(bytes),
159154
Error(ref msg) => msg.emit(buffer),
160-
Ack(ref msg) => msg.emit(buffer),
161155
InnerMessage(ref msg) => msg.serialize(buffer),
162156
}
163157
}
@@ -179,7 +173,7 @@ where
179173
mod tests {
180174
use super::*;
181175

182-
use std::{convert::Infallible, mem::size_of};
176+
use std::{convert::Infallible, mem::size_of, num::NonZeroI32};
183177

184178
#[derive(Clone, Debug, Default, PartialEq)]
185179
struct FakeNetlinkInnerMessage;
@@ -240,4 +234,34 @@ mod tests {
240234
let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
241235
assert_eq!(got, want);
242236
}
237+
238+
#[test]
239+
fn test_error() {
240+
// SAFETY: value is non-zero.
241+
const ERROR_CODE: NonZeroI32 =
242+
unsafe { NonZeroI32::new_unchecked(-8765) };
243+
244+
let header = NetlinkHeader::default();
245+
let error_msg = ErrorMessage {
246+
code: Some(ERROR_CODE),
247+
header: vec![],
248+
};
249+
let mut want = NetlinkMessage::new(
250+
header,
251+
NetlinkPayload::<FakeNetlinkInnerMessage>::Error(error_msg.clone()),
252+
);
253+
want.finalize();
254+
255+
let len = want.buffer_len();
256+
assert_eq!(len, header.buffer_len() + error_msg.buffer_len());
257+
258+
let mut buf = vec![1; len];
259+
want.emit(&mut buf);
260+
261+
let error_buf = ErrorBuffer::new(&buf[header.buffer_len()..]);
262+
assert_eq!(error_buf.code(), error_msg.code);
263+
264+
let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
265+
assert_eq!(got, want);
266+
}
243267
}

src/payload.rs

+2-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
use std::fmt::Debug;
44

5-
use crate::{AckMessage, DoneMessage, ErrorMessage, NetlinkSerializable};
5+
use crate::{DoneMessage, ErrorMessage, NetlinkSerializable};
66

77
/// The message is ignored.
88
pub const NLMSG_NOOP: u16 = 1;
@@ -20,7 +20,6 @@ pub const NLMSG_ALIGNTO: u16 = 4;
2020
pub enum NetlinkPayload<I> {
2121
Done(DoneMessage),
2222
Error(ErrorMessage),
23-
Ack(AckMessage),
2423
Noop,
2524
Overrun(Vec<u8>),
2625
InnerMessage(I),
@@ -33,7 +32,7 @@ where
3332
pub fn message_type(&self) -> u16 {
3433
match self {
3534
NetlinkPayload::Done(_) => NLMSG_DONE,
36-
NetlinkPayload::Error(_) | NetlinkPayload::Ack(_) => NLMSG_ERROR,
35+
NetlinkPayload::Error(_) => NLMSG_ERROR,
3736
NetlinkPayload::Noop => NLMSG_NOOP,
3837
NetlinkPayload::Overrun(_) => NLMSG_OVERRUN,
3938
NetlinkPayload::InnerMessage(message) => message.message_type(),

0 commit comments

Comments
 (0)