diff --git a/crates/rlp/src/decode.rs b/crates/rlp/src/decode.rs index e668fce..7ee1a94 100644 --- a/crates/rlp/src/decode.rs +++ b/crates/rlp/src/decode.rs @@ -1,12 +1,93 @@ -use crate::{Error, Header, Result}; -use bytes::{Bytes, BytesMut}; +use crate::{header::advance_unchecked, Error, Header, Result}; +use bytes::{Buf, Bytes, BytesMut}; use core::marker::{PhantomData, PhantomPinned}; +/// The expected type of an RLP header during deserialization. This is used by +/// the [`Decodable`] trait to enforce header correctness during decoding. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Expectation { + /// Expect a list. + List, + /// Expect a bytestring. + Bytestring, + /// No expectation. The type has no header, or the header type is data-dependent. + None, +} + +impl Expectation { + /// Checks if the header matches the expectation. + pub fn check(&self, header: &Header) -> Result<()> { + match self { + Self::List => { + if !header.list { + return Err(Error::UnexpectedString); + } + } + Self::Bytestring => { + if header.list { + return Err(Error::UnexpectedList); + } + } + _ => {} + } + Ok(()) + } +} + /// A type that can be decoded from an RLP blob. pub trait Decodable: Sized { - /// Decodes the blob into the appropriate type. `buf` must be advanced past - /// the decoded object. - fn decode(buf: &mut &[u8]) -> Result; + /// Returns the expected header type for this type. Used by + /// [`Decodable::decode`] to check header correctness during decoding. If + /// the RLP type is unknown or data-dependent, or if the data is a + /// single-byte type return [`Expectation::None`]. + fn expected() -> Expectation; + + /// Decode the fields of this type from the blob. + /// + /// After this function returns the `buf` MUST be empty. + fn decode_fields(buf: &mut &[u8]) -> Result; + + /// Decodes the blob into the appropriate type. + fn decode(buf: &mut &[u8]) -> Result { + let header = Header::decode(buf)?; + if header.payload_length > buf.len() { + return Err(Error::InputTooShort); + } + Self::expected().check(&header)?; + let t = Self::decode_fields(buf)?; + + Ok(t) + } + + /// Decode the blob into the appropriate type, ensuring no trailing bytes + /// remain. + fn decode_exact(buf: &mut &[u8]) -> Result { + let copy = &mut &**buf; + + // Determine what the appropriate region of the header is + let header = Header::decode(copy)?; + let inner_deser_len = header.length_with_payload(); + + // Check that the buffer is exact size + if inner_deser_len > buf.len() { + return Err(Error::InputTooShort); + } + if inner_deser_len < buf.len() { + return Err(Error::UnexpectedLength); + } + + // Deserialize using only the appropriate region of the buffer + let inner_deser = &mut &buf[..inner_deser_len]; + let t = Self::decode(inner_deser)?; + if inner_deser.len() != 0 { + // decoding failed to consume the buffer + return Err(Error::UnexpectedLength); + } + + // SAFETY: checked above + unsafe { advance_unchecked(buf, inner_deser_len) }; + Ok(t) + } } /// An active RLP decoder, with a specific slice of a payload. @@ -34,58 +115,178 @@ impl<'a> Rlp<'a> { } impl Decodable for PhantomData { + #[inline] + fn expected() -> Expectation { + Expectation::None + } + + #[inline] + fn decode_fields(_buf: &mut &[u8]) -> Result { + Ok(Self) + } + + #[inline] fn decode(_buf: &mut &[u8]) -> Result { Ok(Self) } + + #[inline] + fn decode_exact(_buf: &mut &[u8]) -> Result { + Ok(Self) + } } impl Decodable for PhantomPinned { + #[inline] + fn expected() -> Expectation { + Expectation::None + } + + #[inline] + fn decode_fields(_buf: &mut &[u8]) -> Result { + Ok(Self) + } + + #[inline] fn decode(_buf: &mut &[u8]) -> Result { Ok(Self) } + + #[inline] + fn decode_exact(_buf: &mut &[u8]) -> Result { + Ok(Self) + } } impl Decodable for bool { #[inline] - fn decode(buf: &mut &[u8]) -> Result { + fn expected() -> Expectation { + Expectation::None + } + + #[inline] + fn decode_fields(buf: &mut &[u8]) -> Result { Ok(match u8::decode(buf)? { 0 => false, 1 => true, _ => return Err(Error::Custom("invalid bool value, must be 0 or 1")), }) } + + #[inline] + fn decode(buf: &mut &[u8]) -> Result { + Self::decode_fields(buf) + } + + #[inline] + fn decode_exact(buf: &mut &[u8]) -> Result { + if buf.len() != 1 { + return Err(Error::UnexpectedLength); + } + Self::decode_fields(buf) + } } impl Decodable for [u8; N] { #[inline] - fn decode(from: &mut &[u8]) -> Result { - let bytes = Header::decode_bytes(from, false)?; - Self::try_from(bytes).map_err(|_| Error::UnexpectedLength) + fn expected() -> Expectation { + Expectation::Bytestring + } + + #[inline] + fn decode_fields(buf: &mut &[u8]) -> Result { + let mut arr = [0; N]; + arr.copy_from_slice(*buf); + *buf = &[]; + Ok(arr) + } + + fn decode(buf: &mut &[u8]) -> Result { + let header = Header::decode(buf)?; + if header.payload_length != N { + return Err(Error::UnexpectedLength); + } + if buf.len() < N { + return Err(Error::InputTooShort); + } + Self::expected().check(&header)?; + let t = Self::decode_fields(buf)?; + + Ok(t) } } -macro_rules! decode_integer { +macro_rules! uint_impl { ($($t:ty),+ $(,)?) => {$( impl Decodable for $t { + #[inline] + fn expected() -> Expectation { + Expectation::None + } + + #[inline] + fn decode_fields(buf: &mut &[u8]) -> Result { + let first = buf.first().copied().ok_or(Error::InputTooShort)?; + match first { + 0 => return Err(Error::LeadingZero), + 1..crate::EMPTY_STRING_CODE => { + buf.advance(1); + return Ok(first as $t) + }, + crate::EMPTY_STRING_CODE => { + buf.advance(1); + return Ok(0) + }, + _ => { + let bytes = Header::decode_bytes(buf, false)?; + static_left_pad(bytes).map(<$t>::from_be_bytes) + } + } + + } + #[inline] fn decode(buf: &mut &[u8]) -> Result { - let bytes = Header::decode_bytes(buf, false)?; - static_left_pad(bytes).map(<$t>::from_be_bytes) + Self::decode_fields(buf) + } + + #[inline] + fn decode_exact(buf: &mut &[u8]) -> Result { + let res = Self::decode_fields(buf); + if !buf.is_empty() { + return Err(Error::UnexpectedLength); + } + res } } )+}; } -decode_integer!(u8, u16, u32, u64, usize, u128); +uint_impl!(u8, u16, u32, u64, usize, u128); impl Decodable for Bytes { #[inline] - fn decode(buf: &mut &[u8]) -> Result { - Header::decode_bytes(buf, false).map(|x| Self::from(x.to_vec())) + fn expected() -> Expectation { + Expectation::Bytestring + } + + #[inline] + fn decode_fields(buf: &mut &[u8]) -> Result { + Ok(buf.copy_to_bytes(buf.len())) } } impl Decodable for BytesMut { + #[inline] + fn expected() -> Expectation { + Expectation::Bytestring + } + + #[inline] + fn decode_fields(buf: &mut &[u8]) -> Result { + Ok(buf.copy_to_bytes(buf.len()).into()) + } + #[inline] fn decode(buf: &mut &[u8]) -> Result { Header::decode_bytes(buf, false).map(Self::from) @@ -94,19 +295,49 @@ impl Decodable for BytesMut { impl Decodable for alloc::string::String { #[inline] + fn expected() -> Expectation { + Expectation::Bytestring + } + + #[inline] + fn decode_fields(buf: &mut &[u8]) -> Result { + let res = core::str::from_utf8(buf) + .map_err(|_| Error::Custom("invalid utf8 string")) + .map(Into::into); + *buf = &[]; + res + } + fn decode(buf: &mut &[u8]) -> Result { - Header::decode_str(buf).map(Into::into) + let header = Header::decode(buf)?; + if header.payload_length == 0 { + if buf.is_empty() { + return Ok(Self::new()); + } else { + return Err(Error::UnexpectedLength); + } + } + if header.payload_length > buf.len() { + return Err(Error::InputTooShort); + } + Self::expected().check(&header)?; + let t = Self::decode_fields(buf)?; + + Ok(t) } } impl Decodable for alloc::vec::Vec { #[inline] - fn decode(buf: &mut &[u8]) -> Result { - let mut bytes = Header::decode_bytes(buf, true)?; + fn expected() -> Expectation { + Expectation::List + } + + #[inline] + fn decode_fields(buf: &mut &[u8]) -> Result { let mut vec = Self::new(); - let payload_view = &mut bytes; - while !payload_view.is_empty() { - vec.push(T::decode(payload_view)?); + while !buf.is_empty() { + vec.push(T::decode(buf)?); } Ok(vec) } @@ -116,6 +347,16 @@ macro_rules! wrap_impl { ($($(#[$attr:meta])* [$($gen:tt)*] <$t:ty>::$new:ident($t2:ty)),+ $(,)?) => {$( $(#[$attr])* impl<$($gen)*> Decodable for $t { + #[inline] + fn expected() -> Expectation { + <$t2 as Decodable>::expected() + } + + #[inline] + fn decode_fields(buf: &mut &[u8]) -> Result { + <$t2 as Decodable>::decode_fields(buf).map(<$t>::$new) + } + #[inline] fn decode(buf: &mut &[u8]) -> Result { <$t2 as Decodable>::decode(buf).map(<$t>::$new) @@ -136,6 +377,16 @@ impl Decodable for alloc::borrow::Cow<'_, T> where T::Owned: Decodable, { + #[inline] + fn expected() -> Expectation { + T::Owned::expected() + } + + #[inline] + fn decode_fields(buf: &mut &[u8]) -> Result { + T::Owned::decode_fields(buf).map(Self::Owned) + } + #[inline] fn decode(buf: &mut &[u8]) -> Result { T::Owned::decode(buf).map(Self::Owned) @@ -148,31 +399,48 @@ mod std_impl { use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; impl Decodable for IpAddr { - fn decode(buf: &mut &[u8]) -> Result { - let bytes = Header::decode_bytes(buf, false)?; - match bytes.len() { - 4 => Ok(Self::V4(Ipv4Addr::from(slice_to_array::<4>(bytes).expect("infallible")))), - 16 => { - Ok(Self::V6(Ipv6Addr::from(slice_to_array::<16>(bytes).expect("infallible")))) - } - _ => Err(Error::UnexpectedLength), + #[inline] + fn expected() -> Expectation { + Expectation::Bytestring + } + + #[inline] + fn decode_fields(buf: &mut &[u8]) -> Result { + if buf.len() == 4 { + Ipv4Addr::decode_fields(buf).map(Self::V4) + } else if buf.len() == 16 { + Ipv6Addr::decode_fields(buf).map(Self::V6) + } else { + Err(Error::UnexpectedLength) } } } impl Decodable for Ipv4Addr { #[inline] - fn decode(buf: &mut &[u8]) -> Result { - let bytes = Header::decode_bytes(buf, false)?; - slice_to_array::<4>(bytes).map(Self::from) + fn expected() -> Expectation { + Expectation::Bytestring + } + + #[inline] + fn decode_fields(buf: &mut &[u8]) -> Result { + let res = slice_to_array::<4>(buf).map(Self::from); + buf.advance(4); + res } } impl Decodable for Ipv6Addr { #[inline] - fn decode(buf: &mut &[u8]) -> Result { - let bytes = Header::decode_bytes(buf, false)?; - slice_to_array::<16>(bytes).map(Self::from) + fn expected() -> Expectation { + Expectation::Bytestring + } + + #[inline] + fn decode_fields(buf: &mut &[u8]) -> Result { + let res = slice_to_array::<16>(buf).map(Self::from); + buf.advance(16); + res } } } @@ -184,16 +452,7 @@ mod std_impl { /// Returns an error if the encoding is invalid or if data remains after decoding the RLP item. #[inline] pub fn decode_exact(bytes: impl AsRef<[u8]>) -> Result { - let mut buf = bytes.as_ref(); - let out = T::decode(&mut buf)?; - - // check if there are any remaining bytes after decoding - if !buf.is_empty() { - // TODO: introduce a new variant TrailingBytes to better distinguish this error - return Err(Error::UnexpectedLength); - } - - Ok(out) + T::decode_exact(&mut bytes.as_ref()) } /// Left-pads a slice to a statically known size array. @@ -387,6 +646,8 @@ mod tests { ); } + decode_exact::(vec![0x80]).unwrap(); + check_decode_exact::("".into()); check_decode_exact::("test1234".into()); check_decode_exact::>(vec![]); diff --git a/crates/rlp/src/encode.rs b/crates/rlp/src/encode.rs index 0fcb0a5..7bfe38a 100644 --- a/crates/rlp/src/encode.rs +++ b/crates/rlp/src/encode.rs @@ -13,19 +13,36 @@ use arrayvec::ArrayVec; /// A type that can be encoded via RLP. pub trait Encodable { + /// Encode the inner fields into the `out` buffer, without a header. + /// + /// This is a low-level function that should generally not be called + /// directly. Use [`encode`] instead. + fn encode_fields(&self, out: &mut dyn BufMut); + + /// Returns the length of the encoded fields in bytes. + fn encoded_fields_length(&self) -> usize; + + /// Returns `true` if the type is a string. + fn is_string(&self) -> bool; + + /// Creates a header for the encoding. For types for which a header is + /// unnecessary, + fn header(&self) -> Option
{ + Some(Header { list: !self.is_string(), payload_length: self.encoded_fields_length() }) + } + /// Encodes the type into the `out` buffer. - fn encode(&self, out: &mut dyn BufMut); + fn encode(&self, out: &mut dyn BufMut) { + self.header().map(|h| h.encode(out)); + self.encode_fields(out); + } /// Returns the length of the encoding of this type in bytes. - /// - /// The default implementation computes this by encoding the type. - /// When possible, we recommender implementers override this with a - /// specialized implementation. #[inline] fn length(&self) -> usize { - let mut out = Vec::new(); - self.encode(&mut out); - out.len() + self.header() + .map(|h| h.length_with_payload()) + .unwrap_or_else(|| self.encoded_fields_length()) } } @@ -75,52 +92,123 @@ pub(crate) use to_be_bytes_trimmed; impl Encodable for [u8] { #[inline] - fn length(&self) -> usize { - let mut len = self.len(); - if len != 1 || self[0] >= EMPTY_STRING_CODE { - len += length_of_length(len); - } - len + fn encode_fields(&self, out: &mut dyn BufMut) { + out.put_slice(self); } #[inline] - fn encode(&self, out: &mut dyn BufMut) { - if self.len() != 1 || self[0] >= EMPTY_STRING_CODE { - Header { list: false, payload_length: self.len() }.encode(out); + fn encoded_fields_length(&self) -> usize { + self.len() + } + + #[inline] + fn is_string(&self) -> bool { + true + } + + #[inline] + fn header(&self) -> Option
{ + (self.len() != 1 || self[0] >= EMPTY_STRING_CODE) + .then(|| Header { list: false, payload_length: self.encoded_fields_length() }) + } + + #[inline] + fn length(&self) -> usize { + const ESC: usize = EMPTY_STRING_CODE as usize; + match self.len() { + ..ESC => 1, + ESC.. => { + self.header().expect("encoding rules enforce header presence").length_with_payload() + } } - out.put_slice(self); } } impl Encodable for PhantomData { #[inline] - fn length(&self) -> usize { + fn encode_fields(&self, _out: &mut dyn BufMut) {} + + #[inline] + fn encoded_fields_length(&self) -> usize { 0 } + #[inline] + fn is_string(&self) -> bool { + true + } + + #[inline] + fn header(&self) -> Option
{ + None + } + #[inline] fn encode(&self, _out: &mut dyn BufMut) {} + + #[inline] + fn length(&self) -> usize { + 0 + } } impl Encodable for PhantomPinned { #[inline] - fn length(&self) -> usize { + fn encode_fields(&self, _out: &mut dyn BufMut) {} + + #[inline] + fn encoded_fields_length(&self) -> usize { 0 } + #[inline] + fn is_string(&self) -> bool { + true + } + + #[inline] + fn header(&self) -> Option
{ + None + } + #[inline] fn encode(&self, _out: &mut dyn BufMut) {} + + #[inline] + fn length(&self) -> usize { + 0 + } } impl Encodable for [u8; N] { + #[inline] + fn encode_fields(&self, out: &mut dyn BufMut) { + self.as_slice().encode_fields(out); + } + + #[inline] + fn encoded_fields_length(&self) -> usize { + self.as_slice().encoded_fields_length() + } + + #[inline] + fn is_string(&self) -> bool { + true + } + + #[inline] + fn header(&self) -> Option
{ + self.as_slice().header() + } + #[inline] fn length(&self) -> usize { - self[..].length() + self.as_slice().length() } #[inline] fn encode(&self, out: &mut dyn BufMut) { - self[..].encode(out); + self.as_slice().encode(out); } } @@ -130,17 +218,57 @@ unsafe impl MaxEncodedLenAssoc for [u8; N] { impl Encodable for str { #[inline] - fn length(&self) -> usize { - self.as_bytes().length() + fn encode_fields(&self, out: &mut dyn BufMut) { + self.as_bytes().encode_fields(out); + } + + #[inline] + fn encoded_fields_length(&self) -> usize { + self.as_bytes().encoded_fields_length() + } + + #[inline] + fn is_string(&self) -> bool { + true + } + + #[inline] + fn header(&self) -> Option
{ + self.as_bytes().header() } #[inline] fn encode(&self, out: &mut dyn BufMut) { - self.as_bytes().encode(out) + self.as_bytes().encode(out); + } + + #[inline] + fn length(&self) -> usize { + self.as_bytes().length() } } impl Encodable for bool { + #[inline] + fn encode_fields(&self, out: &mut dyn BufMut) { + out.put_u8(if *self { 1 } else { EMPTY_STRING_CODE }); + } + + #[inline] + fn encoded_fields_length(&self) -> usize { + 1 + } + + #[inline] + fn is_string(&self) -> bool { + true + } + + #[inline] + fn header(&self) -> Option
{ + None + } + #[inline] fn length(&self) -> usize { // a `bool` is always `< EMPTY_STRING_CODE` @@ -149,8 +277,7 @@ impl Encodable for bool { #[inline] fn encode(&self, out: &mut dyn BufMut) { - // inlined `(*self as u8).encode(out)` - out.put_u8(if *self { 1 } else { EMPTY_STRING_CODE }); + self.encode_fields(out); } } @@ -159,28 +286,53 @@ impl_max_encoded_len!(bool, ::LEN); macro_rules! uint_impl { ($($t:ty),+ $(,)?) => {$( impl Encodable for $t { + #[inline] - fn length(&self) -> usize { - let x = *self; - if x < EMPTY_STRING_CODE as $t { - 1 - } else { - 1 + (<$t>::BITS as usize / 8) - (x.leading_zeros() as usize / 8) + fn encode_fields(&self, out: &mut dyn BufMut) { + const ESC: $t = EMPTY_STRING_CODE as $t; + match self { + 0 => out.put_u8(EMPTY_STRING_CODE), + 1..ESC => out.put_u8(*self as u8), + ESC.. => { + let be; + let be = to_be_bytes_trimmed!(be, *self); + out.put_slice(be); + } } } #[inline] - fn encode(&self, out: &mut dyn BufMut) { - let x = *self; - if x == 0 { - out.put_u8(EMPTY_STRING_CODE); - } else if x < EMPTY_STRING_CODE as $t { - out.put_u8(x as u8); - } else { - let be; - let be = to_be_bytes_trimmed!(be, x); - out.put_u8(EMPTY_STRING_CODE + be.len() as u8); - out.put_slice(be); + fn encoded_fields_length(&self) -> usize { + const ESC: $t = EMPTY_STRING_CODE as $t; + match self { + 0..ESC => 1, + ESC.. => (<$t>::BITS as usize / 8) - (self.leading_zeros() as usize / 8) + } + } + + + #[inline] + fn is_string(&self) -> bool { + true + } + + #[inline] + fn header(&self) -> Option
{ + const ESC: $t = EMPTY_STRING_CODE as $t; + match self { + 0..ESC => None, + ESC.. => { + Some(Header { list: false, payload_length: self.encoded_fields_length() }) + } + } + } + + #[inline] + fn length(&self) -> usize { + const ESC: $t = EMPTY_STRING_CODE as $t; + match self { + 0..ESC => 1, + ESC.. => self.header().expect("header presence enforced by encoding rules").length_with_payload() } } } @@ -196,13 +348,22 @@ uint_impl!(u8, u16, u32, u64, usize, u128); impl Encodable for Vec { #[inline] - fn length(&self) -> usize { - list_length(self) + fn encode_fields(&self, out: &mut dyn BufMut) { + self.iter().for_each(|t| t.encode(out)); + } + + fn encoded_fields_length(&self) -> usize { + self.iter().map(Encodable::length).sum() + } + + fn is_string(&self) -> bool { + false } #[inline] fn encode(&self, out: &mut dyn BufMut) { - encode_list(self, out) + self.header().expect("lists always have headers").encode(out); + self.encode_fields(out); } } @@ -210,6 +371,27 @@ macro_rules! deref_impl { ($($(#[$attr:meta])* [$($gen:tt)*] $t:ty),+ $(,)?) => {$( $(#[$attr])* impl<$($gen)*> Encodable for $t { + + #[inline] + fn encode_fields(&self, out: &mut dyn BufMut) { + (**self).encode_fields(out) + } + + #[inline] + fn encoded_fields_length(&self) -> usize { + (**self).encoded_fields_length() + } + + #[inline] + fn is_string(&self) -> bool { + (**self).is_string() + } + + #[inline] + fn header(&self) -> Option
{ + (**self).header() + } + #[inline] fn length(&self) -> usize { (**self).length() @@ -243,6 +425,34 @@ mod std_support { use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; impl Encodable for IpAddr { + fn encode_fields(&self, out: &mut dyn BufMut) { + match self { + Self::V4(ip) => ip.encode_fields(out), + Self::V6(ip) => ip.encode_fields(out), + } + } + + fn encoded_fields_length(&self) -> usize { + match self { + Self::V4(ip) => ip.encoded_fields_length(), + Self::V6(ip) => ip.encoded_fields_length(), + } + } + + fn is_string(&self) -> bool { + match self { + Self::V4(ip) => ip.is_string(), + Self::V6(ip) => ip.is_string(), + } + } + + fn header(&self) -> Option
{ + match self { + Self::V4(ip) => ip.header(), + Self::V6(ip) => ip.header(), + } + } + #[inline] fn length(&self) -> usize { match self { @@ -261,6 +471,26 @@ mod std_support { } impl Encodable for Ipv4Addr { + #[inline] + fn encode_fields(&self, out: &mut dyn BufMut) { + self.octets().encode_fields(out) + } + + #[inline] + fn encoded_fields_length(&self) -> usize { + self.octets().encoded_fields_length() + } + + #[inline] + fn is_string(&self) -> bool { + self.octets().is_string() + } + + #[inline] + fn header(&self) -> Option
{ + self.octets().header() + } + #[inline] fn length(&self) -> usize { self.octets().length() @@ -273,6 +503,26 @@ mod std_support { } impl Encodable for Ipv6Addr { + #[inline] + fn encode_fields(&self, out: &mut dyn BufMut) { + self.octets().encode_fields(out) + } + + #[inline] + fn encoded_fields_length(&self) -> usize { + self.octets().encoded_fields_length() + } + + #[inline] + fn is_string(&self) -> bool { + self.octets().is_string() + } + + #[inline] + fn header(&self) -> Option
{ + self.octets().header() + } + #[inline] fn length(&self) -> usize { self.octets().length() @@ -468,7 +718,10 @@ mod tests { macro_rules! uint_rlp_test { ($fixtures:expr) => { for (input, output) in $fixtures { - assert_eq!(encode(input), output, "encode({input})"); + let encoded = encode(input); + assert_eq!(input.length(), encoded.len(), "length({input})"); + assert_eq!(encoded, output, "encode({input})"); + #[cfg(feature = "arrayvec")] assert_eq!(&encode_fixed_size(&input)[..], output, "encode_fixed_size({input})"); } diff --git a/crates/rlp/src/header.rs b/crates/rlp/src/header.rs index 65f2c97..ebb8777 100644 --- a/crates/rlp/src/header.rs +++ b/crates/rlp/src/header.rs @@ -186,7 +186,7 @@ fn get_next_byte(buf: &[u8]) -> Result { /// Same as `let (bytes, rest) = buf.split_at(cnt); *buf = rest; bytes`. #[inline(always)] -unsafe fn advance_unchecked<'a>(buf: &mut &'a [u8], cnt: usize) -> &'a [u8] { +pub(crate) unsafe fn advance_unchecked<'a>(buf: &mut &'a [u8], cnt: usize) -> &'a [u8] { if buf.remaining() < cnt { unreachable_unchecked() } diff --git a/crates/rlp/src/lib.rs b/crates/rlp/src/lib.rs index 32cc0b9..10b4901 100644 --- a/crates/rlp/src/lib.rs +++ b/crates/rlp/src/lib.rs @@ -11,7 +11,7 @@ extern crate alloc; mod decode; -pub use decode::{decode_exact, Decodable, Rlp}; +pub use decode::{decode_exact, Decodable, Expectation, Rlp}; mod error; pub use error::{Error, Result}; diff --git a/examples/enum.rs b/examples/enum.rs index 4227543..aabaf0b 100644 --- a/examples/enum.rs +++ b/examples/enum.rs @@ -1,38 +1,68 @@ //! This example demonstrates how to encode and decode an enum using //! `alloy_rlp`. -use alloy_rlp::{encode, encode_list, Decodable, Encodable, Error, Header}; +use alloy_rlp::{encode, Decodable, Encodable, Error, Expectation}; use bytes::BufMut; #[derive(Debug, PartialEq)] enum FooBar { + // Treated as an individual value Foo(u64), + // Treated as a list of values Bar(u16, u64), } impl Encodable for FooBar { - fn encode(&self, out: &mut dyn BufMut) { + fn encode_fields(&self, out: &mut dyn BufMut) { match self { + // This is an individual value, so we just pass through Self::Foo(x) => { - let enc: [&dyn Encodable; 2] = [&0u8, x]; - encode_list::<_, dyn Encodable>(&enc, out); + x.encode_fields(out); } + // This is a list of values, so we need to encode each entry as its + // own item Self::Bar(x, y) => { - let enc: [&dyn Encodable; 3] = [&1u8, x, y]; - encode_list::<_, dyn Encodable>(&enc, out); + x.encode(out); + y.encode(out); } } } + + fn encoded_fields_length(&self) -> usize { + match self { + // This is an individual value, so we just pass through + Self::Foo(x) => x.encoded_fields_length(), + // This is a list of values, so we need to sum the lengths of the fields + Self::Bar(x, y) => x.length() + y.length(), + } + } + + fn is_string(&self) -> bool { + match self { + // This is an individual value, so we just pass through + Self::Foo(inner) => inner.is_string(), + // This is a list because it's treated as a list of values + Self::Bar(_, _) => false, + } + } } impl Decodable for FooBar { - fn decode(data: &mut &[u8]) -> Result { - let mut payload = Header::decode_bytes(data, true)?; - match u8::decode(&mut payload)? { - 0 => Ok(Self::Foo(u64::decode(&mut payload)?)), - 1 => Ok(Self::Bar(u16::decode(&mut payload)?, u64::decode(&mut payload)?)), - _ => Err(Error::Custom("unknown type")), + fn expected() -> Expectation { + // No expectation, as the type is data-dependent + Expectation::None + } + + fn decode_fields(data: &mut &[u8]) -> Result { + // Simple strategy: if we correctly decode a `u64`, then we know it's + // a `Foo` variant. Otherwise, we assume it's a `Bar` variant. + let copy = &mut *data; + if let Ok(val) = u64::decode_fields(copy) { + *data = *copy; + return Ok(Self::Foo(val)); } + + Ok(Self::Bar(u16::decode_fields(data)?, u64::decode_fields(data)?)) } }