diff --git a/mast/src/node.rs b/mast/src/node.rs index 07c3aab..f579db4 100644 --- a/mast/src/node.rs +++ b/mast/src/node.rs @@ -36,10 +36,22 @@ enum RefCountDiff { } impl Node { + pub(crate) fn new(key: &[u8], value: &[u8]) -> Self { + Self { + key: key.into(), + value: value.into(), + left: None, + right: None, + + ref_count: 0, + } + } + pub(crate) fn open( table: &'_ impl ReadableTable<&'static [u8], (u64, &'static [u8])>, hash: Hash, ) -> Option { + // TODO: make it Result instead! let existing = table.get(hash.as_bytes().as_slice()).unwrap(); existing.map(|existing| { @@ -53,35 +65,6 @@ impl Node { }) } - pub(crate) fn insert( - table: &mut Table<&[u8], (u64, &[u8])>, - key: &[u8], - value: &[u8], - left: Option, - right: Option, - ) -> Hash { - let node = Self { - key: key.into(), - value: value.into(), - left, - right, - - ref_count: 1, - }; - - let encoded = node.canonical_encode(); - let hash = hash(&encoded); - - table - .insert( - hash.as_bytes().as_slice(), - (node.ref_count, encoded.as_slice()), - ) - .unwrap(); - - hash - } - // === Getters === pub fn key(&self) -> &[u8] { @@ -100,98 +83,57 @@ impl Node { &self.right } - pub(crate) fn ref_count(&self) -> &u64 { - &self.ref_count + pub fn rank(&self) -> Hash { + hash(self.key()) } - // === Public Methods === - - pub fn rank(&self) -> Hash { - hash(&self.key) + pub(crate) fn ref_count(&self) -> &u64 { + &self.ref_count } /// Returns the hash of the node. pub fn hash(&self) -> Hash { - hash(&self.canonical_encode()) + let encoded = self.canonical_encode(); + hash(&encoded) } - /// Set the value and save the updated node. - pub(crate) fn set_value( - &mut self, - table: &mut Table<&[u8], (u64, &[u8])>, - value: &[u8], - ) -> Hash { + // === Private Methods === + + /// Set the value. + pub(crate) fn set_value(&mut self, value: &[u8]) -> &mut Self { self.value = value.into(); - self.save(table) + self } /// Set the left child, save the updated node, and return the new hash. - pub(crate) fn set_left_child( - &mut self, - table: &mut Table<&[u8], (u64, &[u8])>, - child: Option, - ) -> Hash { - self.set_child(table, Branch::Left, child) + pub(crate) fn set_left_child(&mut self, child: Option<&mut Node>) -> &mut Self { + self.set_child(Branch::Left, child) } /// Set the right child, save the updated node, and return the new hash. - pub(crate) fn set_right_child( - &mut self, - table: &mut Table<&[u8], (u64, &[u8])>, - child: Option, - ) -> Hash { - self.set_child(table, Branch::Right, child) + pub(crate) fn set_right_child(&mut self, child: Option<&mut Node>) -> &mut Self { + self.set_child(Branch::Right, child) } - // === Private Methods === - - pub fn decrement_ref_count(&self, table: &mut Table<&[u8], (u64, &[u8])>) { - self.update_ref_count(table, RefCountDiff::Decrement) - } - - fn set_child( - &mut self, - table: &mut Table<&[u8], (u64, &[u8])>, - branch: Branch, - child: Option, - ) -> Hash { + /// Set the child, update its ref_count, save the updated node and return it. + fn set_child(&mut self, branch: Branch, new_child: Option<&mut Node>) -> &mut Self { match branch { - Branch::Left => self.left = child, - Branch::Right => self.right = child, - } - - let encoded = self.canonical_encode(); - let hash = hash(&encoded); - - table - .insert( - hash.as_bytes().as_slice(), - (self.ref_count, encoded.as_slice()), - ) - .unwrap(); + Branch::Left => self.left = new_child.as_ref().map(|n| n.hash()), + Branch::Right => self.right = new_child.as_ref().map(|n| n.hash()), + }; - hash + self } - fn save(&mut self, table: &mut Table<&[u8], (u64, &[u8])>) -> Hash { - let encoded = self.canonical_encode(); - let hash = hash(&encoded); - - table - .insert( - hash.as_bytes().as_slice(), - (self.ref_count, encoded.as_slice()), - ) - .unwrap(); - - hash + pub(crate) fn increment_ref_count(&mut self) -> &mut Self { + self.update_ref_count(RefCountDiff::Increment) } - fn increment_ref_count(&self, table: &mut Table<&[u8], (u64, &[u8])>) { - self.update_ref_count(table, RefCountDiff::Increment) + pub(crate) fn decrement_ref_count(&mut self) -> &mut Self { + self.update_ref_count(RefCountDiff::Decrement) } - fn update_ref_count(&self, table: &mut Table<&[u8], (u64, &[u8])>, diff: RefCountDiff) { + fn update_ref_count(&mut self, diff: RefCountDiff) -> &mut Self { let ref_count = match diff { RefCountDiff::Increment => self.ref_count + 1, RefCountDiff::Decrement => { @@ -203,14 +145,23 @@ impl Node { } }; - let bytes = self.canonical_encode(); - let hash = hash(&bytes); + // We only updaet the ref count, and handle deletion elsewhere. + self.ref_count = ref_count; + self + } - match ref_count { - 0 => table.remove(hash.as_bytes().as_slice()), - _ => table.insert(hash.as_bytes().as_slice(), (ref_count, bytes.as_slice())), - } - .unwrap(); + pub(crate) fn save(&mut self, table: &mut Table<&[u8], (u64, &[u8])>) -> &mut Self { + // TODO: keep data in encoded in a bytes field. + let encoded = self.canonical_encode(); + + table + .insert( + hash(&encoded).as_bytes().as_slice(), + (self.ref_count, encoded.as_slice()), + ) + .unwrap(); + + self } fn canonical_encode(&self) -> Vec { @@ -232,8 +183,11 @@ impl Node { } } -pub(crate) fn rank(key: &[u8]) -> Hash { - hash(key) +pub(crate) fn hash(bytes: &[u8]) -> Hash { + let mut hasher = Hasher::new(); + hasher.update(bytes); + + hasher.finalize() } fn encode(bytes: &[u8], out: &mut Vec) { @@ -255,14 +209,7 @@ fn decode(bytes: &[u8]) -> (&[u8], &[u8]) { (value, rest) } -fn hash(bytes: &[u8]) -> Hash { - let mut hasher = Hasher::new(); - hasher.update(bytes); - - hasher.finalize() -} - -pub fn decode_node(data: (u64, &[u8])) -> Node { +fn decode_node(data: (u64, &[u8])) -> Node { let (ref_count, encoded_node) = data; let (key, rest) = decode(encoded_node); diff --git a/mast/src/operations/insert.rs b/mast/src/operations/insert.rs index 67ff10c..8a689e4 100644 --- a/mast/src/operations/insert.rs +++ b/mast/src/operations/insert.rs @@ -1,7 +1,6 @@ use std::cmp::Ordering; -use crate::node::{rank, Branch, Node}; -use blake3::Hash; +use crate::node::{hash, Branch, Node}; use redb::Table; // Watch this [video](https://youtu.be/NxRXhBur6Xs?si=GNwaUOfuGwr_tBKI&t=1763) for a good explanation of the unzipping algorithm. @@ -82,38 +81,86 @@ use redb::Table; // all then new nodes (in both the upper and lower paths) before comitting the write transaction. pub fn insert( - table: &'_ mut Table<&'static [u8], (u64, &'static [u8])>, - root: Option, + nodes_table: &'_ mut Table<&'static [u8], (u64, &'static [u8])>, + root: Option, key: &[u8], value: &[u8], -) -> Hash { - let mut path = binary_search_path(table, root, key); +) -> Node { + let mut path = binary_search_path(nodes_table, root, key); - let mut unzip_left_root: Option = None; - let mut unzip_right_root: Option = None; + let mut unzip_left_root: Option<&mut Node> = None; + let mut unzip_right_root: Option<&mut Node> = None; + // Unzip the lower path to get left and right children of the inserted node. for (node, branch) in path.unzip_path.iter_mut().rev() { + node.decrement_ref_count().save(nodes_table); + match branch { - Branch::Right => unzip_left_root = Some(node.set_right_child(table, unzip_left_root)), - Branch::Left => unzip_right_root = Some(node.set_left_child(table, unzip_right_root)), + Branch::Right => { + node.set_right_child(unzip_left_root) + .increment_ref_count() + .save(nodes_table); + + unzip_left_root = Some(node); + } + Branch::Left => { + node.set_left_child(unzip_right_root) + .increment_ref_count() + .save(nodes_table); + + unzip_right_root = Some(node); + } } } - let mut root = if let Some(mut existing) = path.existing { - existing.set_value(table, value) + let mut root = path.existing; + + if let Some(mut existing) = root { + if existing.value() == value { + // There is really nothing to update. Skip traversing upwards. + return path.upper_path.pop().map(|(n, _)| n).unwrap_or(existing); + } + + existing.decrement_ref_count().save(nodes_table); + + // Else, update the value and rehashe the node so that we can update the hashes upwards. + existing + .set_value(value) + .increment_ref_count() + .save(nodes_table); + + root = Some(existing) } else { - Node::insert(table, key, value, unzip_left_root, unzip_right_root) + // Insert the new node. + let mut node = Node::new(key, value); + + // TODO: we do hash the node twice here, can we do better? + node.set_left_child(unzip_left_root) + .set_right_child(unzip_right_root) + .increment_ref_count() + .save(nodes_table); + + root = Some(node); }; - for (node, branch) in path.upper_path.iter_mut().rev() { + let mut upper_path = path.upper_path; + + // Propagate the new hashes upwards if there are any nodes in the upper_path. + while let Some((mut node, branch)) = upper_path.pop() { + node.decrement_ref_count().save(nodes_table); + match branch { - Branch::Left => root = node.set_left_child(table, Some(root)), - Branch::Right => root = node.set_right_child(table, Some(root)), - } + Branch::Left => node.set_left_child(root.as_mut()), + Branch::Right => node.set_right_child(root.as_mut()), + }; + + node.increment_ref_count().save(nodes_table); + + root = Some(node); } - // Finally return the new root to be committed. - root + // Finally return the new root to be set to the root. + root.expect("Root should be set by now") } #[derive(Debug)] @@ -130,11 +177,11 @@ struct BinarySearchPath { /// /// If a match was found, the `lower_path` will be empty. fn binary_search_path( - table: &'_ mut Table<&'static [u8], (u64, &'static [u8])>, - root: Option, + table: &Table<&'static [u8], (u64, &'static [u8])>, + root: Option, key: &[u8], ) -> BinarySearchPath { - let rank = rank(key); + let rank = hash(key); let mut result = BinarySearchPath { upper_path: Default::default(), @@ -142,43 +189,182 @@ fn binary_search_path( unzip_path: Default::default(), }; - let mut previous_hash = root; - - while let Some(current_hash) = previous_hash { - let current_node = Node::open(table, current_hash).expect("Node not found!"); + let mut next = root; - // Decrement each node in the binary search path. - // if it doesn't change, we will increment it again later. - // - // It is important then to terminate the loop if we found an exact match, - // as lower nodes shouldn't change then. - current_node.decrement_ref_count(table); - - let path = if current_node.rank().as_bytes() > rank.as_bytes() { + while let Some(current) = next { + let path = if current.rank().as_bytes() > rank.as_bytes() { &mut result.upper_path } else { &mut result.unzip_path }; - match key.cmp(current_node.key()) { + match key.cmp(current.key()) { Ordering::Equal => { // We found exact match. terminate the search. - result.existing = Some(current_node); + result.existing = Some(current); return result; } Ordering::Less => { - previous_hash = *current_node.left(); + next = current.left().and_then(|n| Node::open(table, n)); - path.push((current_node, Branch::Left)); + path.push((current, Branch::Left)); } Ordering::Greater => { - previous_hash = *current_node.right(); + next = current.right().and_then(|n| Node::open(table, n)); - path.push((current_node, Branch::Right)); + path.push((current, Branch::Right)); } }; } result } + +#[cfg(test)] +mod test { + use crate::test::{test_operations, Entry, Operation}; + + #[test] + fn insert_single_entry() { + let case = ["A"]; + + let expected = case.map(|key| Entry { + key: key.as_bytes().to_vec(), + value: [b"v", key.as_bytes()].concat(), + }); + + test_operations( + &expected.clone().map(|e| (e, Operation::Insert)), + &expected, + Some("78fd7507ef338f1a5816ffd702394999680a9694a85f4b8af77795d9fdd5854d"), + ) + } + + #[test] + fn sorted_alphabets() { + let case = [ + "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", + "R", "S", "T", "U", "V", "W", "X", "Y", "Z", + ]; + + let expected = case.map(|key| Entry { + key: key.as_bytes().to_vec(), + value: [b"v", key.as_bytes()].concat(), + }); + + test_operations( + &expected.clone().map(|e| (e, Operation::Insert)), + &expected, + Some("02af3de6ed6368c5abc16f231a17d1140e7bfec483c8d0aa63af4ef744d29bc3"), + ); + } + + #[test] + fn reverse_alphabets() { + let mut case = [ + "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", + "R", "S", "T", "U", "V", "W", "X", "Y", "Z", + ]; + + case.reverse(); + + let expected = case.map(|key| Entry { + key: key.as_bytes().to_vec(), + value: [b"v", key.as_bytes()].concat(), + }); + + test_operations( + &expected.clone().map(|e| (e, Operation::Insert)), + &expected, + Some("02af3de6ed6368c5abc16f231a17d1140e7bfec483c8d0aa63af4ef744d29bc3"), + ) + } + + #[test] + fn unsorted() { + let case = ["D", "N", "P", "X", "A", "G", "C", "M", "H", "I", "J"]; + + let expected = case.map(|key| Entry { + key: key.as_bytes().to_vec(), + value: [b"v", key.as_bytes()].concat(), + }); + + test_operations( + &expected.clone().map(|e| (e, Operation::Insert)), + &expected, + Some("0957cc9b87c11cef6d88a95328cfd9043a3d6a99e9ba35ee5c9c47e53fb6d42b"), + ) + } + + #[test] + fn upsert_at_root() { + let case = ["X", "X"]; + + let mut i = 0; + + let entries = case.map(|key| { + i += 1; + Entry { + key: key.as_bytes().to_vec(), + value: i.to_string().into(), + } + }); + + test_operations( + &entries.clone().map(|e| (e, Operation::Insert)), + &entries[1..], + Some("4538b4de5e58f9be9d54541e69fab8c94c31553a1dec579227ef9b572d1c1dff"), + ) + } + + #[test] + fn upsert_deeper() { + // X has higher rank. + let case = ["X", "F", "F"]; + + let mut i = 0; + + let entries = case.map(|key| { + i += 1; + Entry { + key: key.as_bytes().to_vec(), + value: i.to_string().into(), + } + }); + + let mut expected = entries.to_vec(); + expected.sort_by(|a, b| a.key.cmp(&b.key)); + + test_operations( + &entries.clone().map(|e| (e, Operation::Insert)), + &expected[1..], + Some("c9f7aaefb18ec8569322b9621fc64f430a7389a790e0bf69ec0ad02879d6ce54"), + ) + } + + #[test] + fn upsert_root_with_children() { + // X has higher rank. + let case = ["F", "X", "X"]; + + let mut i = 0; + + let entries = case.map(|key| { + i += 1; + Entry { + key: key.as_bytes().to_vec(), + value: i.to_string().into(), + } + }); + + let mut expected = entries.to_vec(); + expected.remove(1); + + test_operations( + &entries.clone().map(|e| (e, Operation::Insert)), + &expected, + Some("02e26311f2b55bf6d4a7163399f99e17c975891a05af2f1e09bc969f8bf0f95d"), + ) + } +} diff --git a/mast/src/test.rs b/mast/src/test.rs index 2fa7dd0..0088cc5 100644 --- a/mast/src/test.rs +++ b/mast/src/test.rs @@ -7,162 +7,18 @@ use crate::Hash; use redb::backends::InMemoryBackend; use redb::Database; -#[test] -fn cases() { - let sorted_alphabets = [ - "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", - "S", "T", "U", "V", "W", "X", "Y", "Z", - ] - .map(|key| Entry { - key: key.as_bytes().to_vec(), - value: [b"v", key.as_bytes()].concat(), - }); - - let mut reverse_alphabets = sorted_alphabets.clone(); - reverse_alphabets.reverse(); - - let unsorted = ["D", "N", "P", "X", "A", "G", "C", "M", "H", "I", "J"].map(|key| Entry { - key: key.as_bytes().to_vec(), - value: [b"v", key.as_bytes()].concat(), - }); - - let single_entry = ["X"].map(|key| Entry { - key: key.as_bytes().to_vec(), - value: [b"v", key.as_bytes()].concat(), - }); - - let upsert_at_root = ["X", "X"] - .iter() - .enumerate() - .map(|(i, _)| { - ( - Entry { - key: b"X".to_vec(), - value: i.to_string().into(), - }, - Operation::Insert, - ) - }) - .collect::>(); - - // X has higher rank. - let upsert_deeper = ["X", "F", "F"] - .iter() - .enumerate() - .map(|(i, key)| { - ( - Entry { - key: key.as_bytes().to_vec(), - value: i.to_string().into(), - }, - Operation::Insert, - ) - }) - .collect::>(); - - let mut upsert_deeper_expected = upsert_deeper.clone(); - upsert_deeper_expected.remove(upsert_deeper.len() - 2); - - // X has higher rank. - let upsert_root_with_children = ["F", "X", "X"] - .iter() - .enumerate() - .map(|(i, key)| { - ( - Entry { - key: key.as_bytes().to_vec(), - value: i.to_string().into(), - }, - Operation::Insert, - ) - }) - .collect::>(); - - let mut upsert_root_with_children_expected = upsert_root_with_children.clone(); - upsert_root_with_children_expected.remove(upsert_root_with_children.len() - 2); - - let cases = [ - ( - "sorted alphabets", - sorted_alphabets - .clone() - .map(|e| (e, Operation::Insert)) - .to_vec(), - sorted_alphabets.to_vec(), - Some("02af3de6ed6368c5abc16f231a17d1140e7bfec483c8d0aa63af4ef744d29bc3"), - ), - ( - "reversed alphabets", - sorted_alphabets - .clone() - .map(|e| (e, Operation::Insert)) - .to_vec(), - sorted_alphabets.to_vec(), - Some("02af3de6ed6368c5abc16f231a17d1140e7bfec483c8d0aa63af4ef744d29bc3"), - ), - ( - "unsorted alphabets", - unsorted.clone().map(|e| (e, Operation::Insert)).to_vec(), - unsorted.to_vec(), - Some("0957cc9b87c11cef6d88a95328cfd9043a3d6a99e9ba35ee5c9c47e53fb6d42b"), - ), - ( - "Single insert", - single_entry - .clone() - .map(|e| (e, Operation::Insert)) - .to_vec(), - single_entry.to_vec(), - Some("b3e862d316e6f5caca72c8f91b7a15015b4f7f8f970c2731433aad793f7fe3e6"), - ), - ( - "upsert at root without children", - upsert_at_root.clone(), - upsert_at_root[1..] - .iter() - .map(|(e, _)| e.clone()) - .collect::>(), - Some("b1353174e730b9ff6850577357fd9ff608071bbab46ebe72c434133f5d4f0383"), - ), - ( - "upsert deeper", - upsert_deeper.to_vec(), - upsert_deeper_expected - .to_vec() - .iter() - .map(|(e, _)| e.clone()) - .collect::>(), - Some("58272c9e8c9e6b7266e4b60e45d55257b94e85561997f1706e0891ee542a8cd5"), - ), - ( - "upsert at root with children", - upsert_root_with_children.to_vec(), - upsert_root_with_children_expected - .to_vec() - .iter() - .map(|(e, _)| e.clone()) - .collect::>(), - Some("f46daf022dc852cd4e60a98a33de213f593e17bcd234d9abff7a178d8a5d0761"), - ), - ]; - - for case in cases { - test(case.0, &case.1, &case.2, case.3); - } -} - // === Helpers === #[derive(Clone, Debug)] -enum Operation { +pub enum Operation { Insert, Delete, } #[derive(Clone, PartialEq)] -struct Entry { - key: Vec, - value: Vec, +pub struct Entry { + pub(crate) key: Vec, + pub(crate) value: Vec, } impl std::fmt::Debug for Entry { @@ -171,7 +27,7 @@ impl std::fmt::Debug for Entry { } } -fn test(name: &str, input: &[(Entry, Operation)], expected: &[Entry], root_hash: Option<&str>) { +pub fn test_operations(input: &[(Entry, Operation)], expected: &[Entry], root_hash: Option<&str>) { let inmemory = InMemoryBackend::new(); let db = Database::builder() .create_with_backend(inmemory) @@ -184,18 +40,20 @@ fn test(name: &str, input: &[(Entry, Operation)], expected: &[Entry], root_hash: Operation::Insert => treap.insert(&entry.key, &entry.value), Operation::Delete => todo!(), } - println!( - "{:?} {:?}\n{}", - &entry.key, - &entry.value, - into_mermaid_graph(&treap) - ); } + // Uncomment to see the graph (only if values are utf8) + // println!("{}", into_mermaid_graph(&treap)); + let collected = treap .iter() .map(|n| { - assert_eq!(*n.ref_count(), 1_u64, "Node has wrong ref count"); + assert_eq!( + *n.ref_count(), + 1_u64, + "{}", + format!("Node has wrong ref count {:?}", n) + ); Entry { key: n.key().to_vec(), @@ -207,22 +65,13 @@ fn test(name: &str, input: &[(Entry, Operation)], expected: &[Entry], root_hash: let mut sorted = expected.to_vec(); sorted.sort_by(|a, b| a.key.cmp(&b.key)); - // println!("{}", into_mermaid_graph(&treap)); + verify_ranks(&treap); + + assert_eq!(collected, sorted, "{}", format!("Entries do not match")); if root_hash.is_some() { assert_root(&treap, root_hash.unwrap()); - } else { - dbg!(&treap.root_hash()); - - verify_ranks(&treap); } - - assert_eq!( - collected, - sorted, - "{}", - format!("Entries do not match at: \"{}\"", name) - ); } /// Verify that every node has higher rank than its children. diff --git a/mast/src/treap.rs b/mast/src/treap.rs index 917e6d0..f2db89c 100644 --- a/mast/src/treap.rs +++ b/mast/src/treap.rs @@ -65,12 +65,15 @@ impl<'treap> HashTreap<'treap> { let mut roots_table = write_txn.open_table(ROOTS_TABLE).unwrap(); let mut nodes_table = write_txn.open_table(NODES_TABLE).unwrap(); - let root = self.root_hash_inner(&roots_table); + let old_root = self + .root_hash_inner(&roots_table) + .and_then(|hash| Node::open(&nodes_table, hash)); - let new_root = crate::operations::insert::insert(&mut nodes_table, root, key, value); + let new_root = + crate::operations::insert::insert(&mut nodes_table, old_root, key, value); roots_table - .insert(self.name.as_bytes(), new_root.as_bytes().as_slice()) + .insert(self.name.as_bytes(), new_root.hash().as_bytes().as_slice()) .unwrap(); };