Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
62f785b
Optimise FixedVector Decode
michaelsproul Aug 28, 2025
b729e99
Add encoding benchmarks
michaelsproul Sep 4, 2025
2f52ad0
u8 encoding benchmark
michaelsproul Sep 4, 2025
56fd2f8
Update benchmarks to include encoding, optimise VariableList
michaelsproul Sep 4, 2025
31aee18
Remove path patch
michaelsproul Sep 4, 2025
b042216
Remove unnecessary pub
michaelsproul Sep 4, 2025
37e9308
Fix trailing bytes bug and add more tests
michaelsproul Sep 8, 2025
2eecf7a
Remove junk
michaelsproul Sep 8, 2025
902121a
Merge remote-tracking branch 'origin/main' into optimise-decode
michaelsproul Sep 8, 2025
6bca389
More tests for oversize FixedVector
michaelsproul Sep 8, 2025
7f2505b
Remove temp ByteVector from benches
michaelsproul Sep 8, 2025
da45efa
Fix tests
michaelsproul Sep 8, 2025
d87e180
Merge remote-tracking branch 'origin/main' into optimise-decode
michaelsproul Sep 8, 2025
7c89206
Merge branch 'main' into optimise-decode
michaelsproul Oct 22, 2025
8784c47
Merge remote-tracking branch 'origin/main' into optimise-decode
michaelsproul Oct 28, 2025
f022441
Update comments
michaelsproul Oct 28, 2025
29d5205
Fix Clippy
michaelsproul Oct 28, 2025
7635d55
Test bool on unsafe codepath
michaelsproul Oct 28, 2025
50e6061
Add test demonstrating bool UB
michaelsproul Oct 29, 2025
72b60cc
Fix UB by using TypeId
michaelsproul Oct 29, 2025
c2f7140
Fix test name
michaelsproul Nov 10, 2025
c19c971
Merge branch 'main' into optimise-decode
michaelsproul Nov 10, 2025
59e79b5
Merge branch 'main' into optimise-decode
michaelsproul Nov 11, 2025
dad52cb
Merge branch 'main' into optimise-decode
michaelsproul Nov 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 86 additions & 9 deletions src/fixed_vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use crate::tree_hash::vec_tree_hash_root;
use crate::Error;
use serde::Deserialize;
use serde_derive::Serialize;
use std::any::TypeId;
use std::marker::PhantomData;
use std::mem;
use std::ops::{Deref, DerefMut, Index, IndexMut};
use std::slice::SliceIndex;
use tree_hash::Hash256;
Expand Down Expand Up @@ -283,7 +285,7 @@ impl<T, N: Unsigned> ssz::TryFromIter<T> for FixedVector<T, N> {

impl<T, N: Unsigned> ssz::Decode for FixedVector<T, N>
where
T: ssz::Decode,
T: ssz::Decode + 'static,
{
fn is_ssz_fixed_len() -> bool {
T::is_ssz_fixed_len()
Expand All @@ -305,6 +307,24 @@ where
len: 0,
expected: 1,
})
} else if TypeId::of::<T>() == TypeId::of::<u8>() {
if bytes.len() != fixed_len {
return Err(ssz::DecodeError::BytesInvalid(format!(
"FixedVector of {} items has {} items",
fixed_len,
bytes.len(),
)));
}

// Safety: We've verified T is u8, so Vec<T> *is* Vec<u8>.
let vec_u8 = bytes.to_vec();
let vec_t = unsafe { mem::transmute::<Vec<u8>, Vec<T>>(vec_u8) };
Self::new(vec_t).map_err(|e| {
ssz::DecodeError::BytesInvalid(format!(
"Wrong number of FixedVector elements: {:?}",
e
))
})
} else if T::is_ssz_fixed_len() {
let num_items = bytes
.len()
Expand All @@ -314,17 +334,24 @@ where
if num_items != fixed_len {
return Err(ssz::DecodeError::BytesInvalid(format!(
"FixedVector of {} items has {} items",
num_items, fixed_len
fixed_len, num_items
)));
}

let vec = bytes.chunks(T::ssz_fixed_len()).try_fold(
Vec::with_capacity(num_items),
|mut vec, chunk| {
vec.push(T::from_ssz_bytes(chunk)?);
Ok(vec)
},
)?;
// Check that we have a whole number of items and that it is safe to use chunks_exact
if !bytes.len().is_multiple_of(T::ssz_fixed_len()) {
return Err(ssz::DecodeError::BytesInvalid(format!(
"FixedVector of {} items has {} bytes",
num_items,
bytes.len()
)));
}

let mut vec = Vec::with_capacity(num_items);
for chunk in bytes.chunks_exact(T::ssz_fixed_len()) {
vec.push(T::from_ssz_bytes(chunk)?);
}

Self::new(vec).map_err(|e| {
ssz::DecodeError::BytesInvalid(format!(
"Wrong number of FixedVector elements: {:?}",
Expand Down Expand Up @@ -479,6 +506,56 @@ mod test {
ssz_round_trip::<FixedVector<u16, U8>>(vec![0; 8].try_into().unwrap());
}

// Round-trips 1024-byte vectors so that the specialised `u8` decode path, which contains
// unsafe code and therefore NEEDS coverage, is exercised.
#[test]
fn ssz_round_trip_u8_len_1024() {
    for fill in [42u8, 0u8] {
        let vector: FixedVector<u8, U1024> = vec![fill; 1024].try_into().unwrap();
        ssz_round_trip::<FixedVector<u8, U1024>>(vector);
    }
}

// `bool` has the same size and alignment as `u8` (asserted below), but it must not be routed
// through the unsafe `u8` codepath because not every `u8` value is a valid `bool`.
#[test]
fn ssz_round_trip_bool_len_1024() {
    // Confirm the layout equivalence that makes the distinction worth testing.
    assert_eq!(mem::size_of::<bool>(), 1);
    assert_eq!(mem::align_of::<bool>(), 1);

    for flag in [true, false] {
        let vector: FixedVector<bool, U1024> = vec![flag; 1024].try_into().unwrap();
        ssz_round_trip::<FixedVector<bool, U1024>>(vector);
    }
}

// Decoding the bytes of a `u8` vector as a vector of bools must fail; if the decoder were not
// careful here it could trigger UB.
// NOTE(review): despite the `len_1024` suffix, this test uses an 8-element vector (`U8`).
#[test]
fn ssz_u8_to_bool_len_1024() {
    let encoded = FixedVector::<u8, U8>::new(vec![0, 1, 2, 3, 4, 5, 6, 7])
        .unwrap()
        .as_ssz_bytes();
    let decoded = FixedVector::<bool, U8>::from_ssz_bytes(&encoded);
    // Only the failure matters; the concrete error value is not inspected.
    decoded.unwrap_err();
}

// A 1025-byte input is one byte longer than a `FixedVector<u8, U1024>` and must be rejected.
#[test]
fn ssz_u8_len_1024_too_long() {
    let oversized = vec![42u8; 1025];
    let err = FixedVector::<u8, U1024>::from_ssz_bytes(&oversized).unwrap_err();
    let expected =
        ssz::DecodeError::BytesInvalid("FixedVector of 1024 items has 1025 items".into());
    assert_eq!(err, expected);
}

// 8 * 1025 bytes encode 1025 `u64` items, one more than the vector's length of 1024.
#[test]
fn ssz_u64_len_1024_too_long() {
    let oversized = vec![42u8; 8 * 1025];
    let err = FixedVector::<u64, U1024>::from_ssz_bytes(&oversized).unwrap_err();
    let expected =
        ssz::DecodeError::BytesInvalid("FixedVector of 1024 items has 1025 items".into());
    assert_eq!(err, expected);
}

// Decoding an input with invalid trailing bytes MUST fail.
#[test]
fn ssz_bytes_u32_trailing() {
    // Two complete little-endian `u32` items (8 bytes) followed by one stray byte.
    let encoded: [u8; 9] = [1, 0, 0, 0, 2, 0, 0, 0, 1];
    let err = FixedVector::<u32, U2>::from_ssz_bytes(&encoded).unwrap_err();
    let expected = ssz::DecodeError::BytesInvalid("FixedVector of 2 items has 9 bytes".into());
    assert_eq!(err, expected);
}

#[test]
fn tree_hash_u8() {
let fixed: FixedVector<u8, U0> = FixedVector::try_from(vec![]).unwrap();
Expand Down
122 changes: 105 additions & 17 deletions src/variable_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use crate::tree_hash::vec_tree_hash_root;
use crate::Error;
use serde::Deserialize;
use serde_derive::Serialize;
use std::any::TypeId;
use std::marker::PhantomData;
use std::mem;
use std::ops::{Deref, DerefMut, Index, IndexMut};
use std::slice::SliceIndex;
use tree_hash::Hash256;
Expand Down Expand Up @@ -288,7 +290,7 @@ impl<T, N: Unsigned> ssz::TryFromIter<T> for VariableList<T, N> {

impl<T, N> ssz::Decode for VariableList<T, N>
where
T: ssz::Decode,
T: ssz::Decode + 'static,
N: Unsigned,
{
fn is_ssz_fixed_len() -> bool {
Expand All @@ -302,6 +304,26 @@ where
return Ok(Self::default());
}

if TypeId::of::<T>() == TypeId::of::<u8>() {
if bytes.len() > max_len {
return Err(ssz::DecodeError::BytesInvalid(format!(
"VariableList of {} items exceeds maximum of {}",
bytes.len(),
max_len
)));
}

// Safety: We've verified T is u8, so Vec<T> *is* Vec<u8>.
let vec_u8 = bytes.to_vec();
let vec_t = unsafe { mem::transmute::<Vec<u8>, Vec<T>>(vec_u8) };
return Self::new(vec_t).map_err(|e| {
ssz::DecodeError::BytesInvalid(format!(
"Wrong number of VariableList elements: {:?}",
e
))
});
}

if T::is_ssz_fixed_len() {
let num_items = bytes
.len()
Expand All @@ -315,20 +337,28 @@ where
)));
}

bytes.chunks(T::ssz_fixed_len()).try_fold(
Vec::with_capacity(num_items),
|mut vec, chunk| {
vec.push(T::from_ssz_bytes(chunk)?);
Ok(vec)
},
)
// Check that we have a whole number of items and that it is safe to use chunks_exact
if !bytes.len().is_multiple_of(T::ssz_fixed_len()) {
return Err(ssz::DecodeError::BytesInvalid(format!(
"VariableList of {} items has {} bytes",
num_items,
bytes.len()
)));
}

let mut vec = Vec::with_capacity(num_items);
for chunk in bytes.chunks_exact(T::ssz_fixed_len()) {
vec.push(T::from_ssz_bytes(chunk)?);
}
Self::new(vec).map_err(|e| {
ssz::DecodeError::BytesInvalid(format!(
"Wrong number of VariableList elements: {:?}",
e
))
})
} else {
ssz::decode_list_of_variable_length_items(bytes, Some(max_len))
}?
.try_into()
.map_err(|e| {
ssz::DecodeError::BytesInvalid(format!("VariableList::try_from failed: {e:?}"))
})
}
}
}

Expand Down Expand Up @@ -452,17 +482,60 @@ mod test {
assert_eq!(<VariableList<u16, U2> as Encode>::ssz_fixed_len(), 4);
}

fn round_trip<T: Encode + Decode + std::fmt::Debug + PartialEq>(item: T) {
fn ssz_round_trip<T: Encode + Decode + std::fmt::Debug + PartialEq>(item: T) {
let encoded = &item.as_ssz_bytes();
assert_eq!(item.ssz_bytes_len(), encoded.len());
assert_eq!(T::from_ssz_bytes(encoded), Ok(item));
}

#[test]
fn u16_len_8() {
round_trip::<VariableList<u16, U8>>(vec![42; 8].try_into().unwrap());
round_trip::<VariableList<u16, U8>>(vec![0; 8].try_into().unwrap());
round_trip::<VariableList<u16, U8>>(vec![].try_into().unwrap());
ssz_round_trip::<VariableList<u16, U8>>(vec![42; 8].try_into().unwrap());
ssz_round_trip::<VariableList<u16, U8>>(vec![0; 8].try_into().unwrap());
ssz_round_trip::<VariableList<u16, U8>>(vec![].try_into().unwrap());
}

// Round-trips full-length (1024-byte) lists through the `u8`-specific decode path.
#[test]
fn ssz_round_trip_u8_len_1024() {
    for fill in [42u8, 0u8] {
        let list: VariableList<u8, U1024> = vec![fill; 1024].try_into().unwrap();
        ssz_round_trip::<VariableList<u8, U1024>>(list);
    }
}

// `bool` has the same size and alignment as `u8` (asserted below), but it must not be routed
// through the unsafe `u8` codepath because not every `u8` value is a valid `bool`.
#[test]
fn ssz_round_trip_bool_len_1024() {
    // Confirm the layout equivalence that makes the distinction worth testing.
    assert_eq!(mem::size_of::<bool>(), 1);
    assert_eq!(mem::align_of::<bool>(), 1);

    for flag in [true, false] {
        let list: VariableList<bool, U1024> = vec![flag; 1024].try_into().unwrap();
        ssz_round_trip::<VariableList<bool, U1024>>(list);
    }
}

// Decoding the bytes of a `u8` list as a list of bools must fail; if the decoder were not
// careful here it could trigger UB.
// NOTE(review): despite the `len_1024` suffix, this test uses an 8-element list (`U8`).
#[test]
fn ssz_u8_to_bool_len_1024() {
    let encoded = VariableList::<u8, U8>::new(vec![0, 1, 2, 3, 4, 5, 6, 7])
        .unwrap()
        .as_ssz_bytes();
    let decoded = VariableList::<bool, U8>::from_ssz_bytes(&encoded);
    // Only the failure matters; the concrete error value is not inspected.
    decoded.unwrap_err();
}

// A 1025-byte input exceeds the 1024-item maximum of `VariableList<u8, U1024>`.
#[test]
fn ssz_u8_len_1024_too_long() {
    let oversized = vec![42u8; 1025];
    let err = VariableList::<u8, U1024>::from_ssz_bytes(&oversized).unwrap_err();
    let expected = ssz::DecodeError::BytesInvalid(
        "VariableList of 1025 items exceeds maximum of 1024".into(),
    );
    assert_eq!(err, expected);
}

// 8 * 1025 bytes encode 1025 `u64` items, one more than the list's maximum of 1024.
#[test]
fn ssz_u64_len_1024_too_long() {
    let oversized = vec![42u8; 8 * 1025];
    let err = VariableList::<u64, U1024>::from_ssz_bytes(&oversized).unwrap_err();
    let expected = ssz::DecodeError::BytesInvalid(
        "VariableList of 1025 items exceeds maximum of 1024".into(),
    );
    assert_eq!(err, expected);
}

#[test]
Expand All @@ -473,6 +546,21 @@ mod test {
assert_eq!(VariableList::from_ssz_bytes(&[]).unwrap(), empty_list);
}

// Inputs whose length is not a whole number of `u32` items must be rejected, whether the
// stray bytes follow one complete item or two.
#[test]
fn ssz_bytes_u32_trailing() {
    // One complete `u32` (4 bytes) followed by two stray bytes.
    let one_and_a_bit: [u8; 6] = [1, 0, 0, 0, 2, 0];
    let err = VariableList::<u32, U2>::from_ssz_bytes(&one_and_a_bit).unwrap_err();
    let expected = ssz::DecodeError::BytesInvalid("VariableList of 1 items has 6 bytes".into());
    assert_eq!(err, expected);

    // Two complete `u32` items (8 bytes) followed by one stray byte.
    let two_and_a_bit: [u8; 9] = [1, 0, 0, 0, 2, 0, 0, 0, 3];
    let err = VariableList::<u32, U2>::from_ssz_bytes(&two_and_a_bit).unwrap_err();
    let expected = ssz::DecodeError::BytesInvalid("VariableList of 2 items has 9 bytes".into());
    assert_eq!(err, expected);
}

fn root_with_length(bytes: &[u8], len: usize) -> Hash256 {
let root = merkle_root(bytes, 0);
tree_hash::mix_in_length(&root, len)
Expand Down
Loading