Skip to content

Commit fc56bb2

Browse files
authored
fix: use better types in HAMT hashing (#1865)
Using a u8 here means we can statically guarantee that no bitfield access will panic due to a bounds error.
1 parent f3b1c72 commit fc56bb2

File tree

3 files changed

+21
-19
lines changed

3 files changed

+21
-19
lines changed

ipld/hamt/src/bitfield.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,20 +69,20 @@ impl Default for Bitfield {
6969
}
7070

7171
impl Bitfield {
72-
pub fn clear_bit(&mut self, idx: u32) {
72+
pub fn clear_bit(&mut self, idx: u8) {
7373
let ai = idx / 64;
7474
let bi = idx % 64;
7575
self.0[ai as usize] &= u64::MAX - (1 << bi);
7676
}
7777

78-
pub fn test_bit(&self, idx: u32) -> bool {
78+
pub fn test_bit(&self, idx: u8) -> bool {
7979
let ai = idx / 64;
8080
let bi = idx % 64;
8181

8282
self.0[ai as usize] & (1 << bi) != 0
8383
}
8484

85-
pub fn set_bit(&mut self, idx: u32) {
85+
pub fn set_bit(&mut self, idx: u8) {
8686
let ai = idx / 64;
8787
let bi = idx % 64;
8888

@@ -106,14 +106,14 @@ impl Bitfield {
106106
Bitfield([0, 0, 0, 0])
107107
}
108108

109-
pub fn set_bits_le(self, bit: u32) -> Self {
109+
pub fn set_bits_le(self, bit: u8) -> Self {
110110
if bit == 0 {
111111
return self;
112112
}
113113
self.set_bits_leq(bit - 1)
114114
}
115115

116-
pub fn set_bits_leq(mut self, bit: u32) -> Self {
116+
pub fn set_bits_leq(mut self, bit: u8) -> Self {
117117
if bit < 64 {
118118
self.0[0] = set_bits_leq(self.0[0], bit);
119119
} else if bit < 128 {
@@ -135,7 +135,7 @@ impl Bitfield {
135135
}
136136

137137
#[inline]
138-
fn set_bits_leq(v: u64, bit: u32) -> u64 {
138+
fn set_bits_leq(v: u64, bit: u8) -> u64 {
139139
(v as u128 | ((1u128 << (1 + bit)) - 1)) as u64
140140
}
141141

ipld/hamt/src/hash_bits.rs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ pub struct HashBits<'a> {
1616
pub consumed: u32,
1717
}
1818

19+
// n must be less than 8
1920
#[inline]
20-
pub(crate) fn mkmask(n: u32) -> u32 {
21-
((1u64 << n) - 1) as u32
21+
pub(crate) fn mkmask(n: u32) -> u8 {
22+
((1u16 << n) - 1) as u8
2223
}
2324

2425
impl<'a> HashBits<'a> {
@@ -36,7 +37,7 @@ impl<'a> HashBits<'a> {
3637

3738
/// Returns next `i` bits of the hash and returns the value as an integer and returns
3839
/// Error when maximum depth is reached
39-
pub fn next(&mut self, i: u32) -> Result<u32, Error> {
40+
pub fn next(&mut self, i: u32) -> Result<u8, Error> {
4041
if i > 8 || i == 0 {
4142
return Err(Error::InvalidHashBitLen);
4243
}
@@ -49,11 +50,12 @@ impl<'a> HashBits<'a> {
4950
Ok(self.next_bits(std::cmp::min(i, maxi)))
5051
}
5152

52-
fn next_bits(&mut self, i: u32) -> u32 {
53+
// `i` must be between 1 and 8, inclusive.
54+
fn next_bits(&mut self, i: u32) -> u8 {
5355
let curbi = self.consumed / 8;
5456
let leftb = 8 - (self.consumed % 8);
5557

56-
let curb = self.b[curbi as usize] as u32;
58+
let curb = self.b[curbi as usize];
5759
match i.cmp(&leftb) {
5860
Ordering::Equal => {
5961
// bits to consume is equal to the bits remaining in the currently indexed byte
@@ -71,11 +73,11 @@ impl<'a> HashBits<'a> {
7173
}
7274
Ordering::Greater => {
7375
// Consumes remaining bits and remaining bits from a recursive call
74-
let mut out = (mkmask(leftb) & curb) as u64;
76+
let mut out = mkmask(leftb) & curb;
7577
out <<= i - leftb;
7678
self.consumed += leftb;
77-
out += self.next_bits(i - leftb) as u64;
78-
out as u32
79+
out += self.next_bits(i - leftb);
80+
out
7981
}
8082
}
8183
}

ipld/hamt/src/node.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -478,18 +478,18 @@ where
478478
Ok(())
479479
}
480480

481-
fn rm_child(&mut self, i: usize, idx: u32) -> Pointer<K, V, H, Ver> {
481+
fn rm_child(&mut self, i: usize, idx: u8) -> Pointer<K, V, H, Ver> {
482482
self.bitfield.clear_bit(idx);
483483
self.pointers.remove(i)
484484
}
485485

486-
fn insert_child(&mut self, idx: u32, key: K, value: V) {
486+
fn insert_child(&mut self, idx: u8, key: K, value: V) {
487487
let i = self.index_for_bit_pos(idx);
488488
self.bitfield.set_bit(idx);
489489
self.pointers.insert(i, Pointer::from_key_value(key, value))
490490
}
491491

492-
fn insert_child_dirty(&mut self, idx: u32, node: Box<Node<K, V, H, Ver>>) {
492+
fn insert_child_dirty(&mut self, idx: u8, node: Box<Node<K, V, H, Ver>>) {
493493
let i = self.index_for_bit_pos(idx);
494494
self.bitfield.set_bit(idx);
495495
self.pointers.insert(i, Pointer::Dirty(node))
@@ -517,9 +517,9 @@ where
517517
}
518518

519519
impl<K, V, H, Ver> Node<K, V, H, Ver> {
520-
pub(crate) fn index_for_bit_pos(&self, bp: u32) -> usize {
520+
pub(crate) fn index_for_bit_pos(&self, bp: u8) -> usize {
521521
let mask = Bitfield::zero().set_bits_le(bp);
522-
assert_eq!(mask.count_ones(), bp as usize);
522+
debug_assert_eq!(mask.count_ones(), bp as usize);
523523
mask.and(&self.bitfield).count_ones()
524524
}
525525
}

0 commit comments

Comments
 (0)