diff --git a/src/merkle/partial_mt/mod.rs b/src/merkle/partial_mt/mod.rs index ef87516..10f7231 100644 --- a/src/merkle/partial_mt/mod.rs +++ b/src/merkle/partial_mt/mod.rs @@ -118,7 +118,7 @@ impl PartialMerkleTree { // fill layers without nodes with empty vector for depth in 0..max_depth { - layers.entry(depth).or_insert(vec![]); + layers.entry(depth).or_default(); } let mut layer_iter = layers.into_values().rev(); @@ -370,7 +370,6 @@ impl PartialMerkleTree { return Ok(old_value); } - let mut node_index = node_index; let mut value = value.into(); for _ in 0..node_index.depth() { let sibling = self.nodes.get(&node_index.sibling()).expect("sibling should exist"); diff --git a/src/merkle/tiered_smt/mod.rs b/src/merkle/tiered_smt/mod.rs index b8dd52f..5379d20 100644 --- a/src/merkle/tiered_smt/mod.rs +++ b/src/merkle/tiered_smt/mod.rs @@ -2,7 +2,7 @@ use super::{ BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, StarkField, Vec, Word, }; -use core::cmp; +use core::{cmp, ops::Deref}; mod nodes; use nodes::NodeStore; @@ -148,32 +148,36 @@ impl TieredSmt { return self.remove_leaf_node(key); } - // insert the value into the value store, and if nothing has changed, return - let (old_value, is_update) = match self.values.insert(key, value) { - Some(old_value) => { - if old_value == value { - return old_value; - } - (old_value, true) + // insert the value into the value store, and if the key was already in the store, update + // it with the new value + if let Some(old_value) = self.values.insert(key, value) { + if old_value != value { + // if the new value is different from the old value, determine the location of + // the leaf node for this key, build the node, and update the root + let (index, leaf_exists) = self.nodes.get_leaf_index(&key); + debug_assert!(leaf_exists); + let node = self.build_leaf_node(index, key, value); + self.root = self.nodes.update_leaf_node(index, node); } - None => (Self::EMPTY_VALUE, false), + return old_value; }; - // determine the index for the value node; this index could have 3 different meanings: - // - it points to a root of an empty subtree (excluding depth = 64); in this case, we can - // replace the node with the value node immediately. - // - it points to a node at the bottom tier (i.e., depth = 64); in this case, we need to - // process bottom-tier insertion which will be handled by insert_leaf_node(). - // - it points to an existing leaf node; this node could be a node with the same key or a - // different key with a common prefix; in the latter case, we'll need to move the leaf - // to a lower tier - let (index, leaf_exists) = self.nodes.get_insert_location(&key); - debug_assert!(!is_update || leaf_exists); - - // if the returned index points to a leaf, and this leaf is for a different key (i.e., we - // are not updating a value for an existing key), we need to replace this leaf with a tree - // containing leaves for both the old and the new key-value pairs - if leaf_exists && !is_update { + // determine the location for the leaf node; this index could have 3 different meanings: + // - it points to a root of an empty subtree or an empty node at depth 64; in this case, + // we can replace the node with the value node immediately. + // - it points to an existing leaf at the bottom tier (i.e., depth = 64); in this case, + // we need to process update the bottom leaf. + // - it points to an existing leaf node for a different key with the same prefix (same + // key case was handled above); in this case, we need to move the leaf to a lower tier + let (index, leaf_exists) = self.nodes.get_leaf_index(&key); + + self.root = if leaf_exists && index.depth() == Self::MAX_DEPTH { + // returned index points to a leaf at the bottom tier + let node = self.build_leaf_node(index, key, value); + self.nodes.update_leaf_node(index, node) + } else if leaf_exists { + // returned index pointes to a leaf for a different key with the same prefix + // get the key-value pair for the key with the same prefix; since the key-value // pair has already been inserted into the value store, we need to filter it out // when looking for the other key-value pair @@ -183,12 +187,12 @@ impl TieredSmt { .expect("other key-value pair not found"); // determine how far down the tree should we move the leaves - let common_prefix_len = get_common_prefix_tier(&key, other_key); + let common_prefix_len = get_common_prefix_tier_depth(&key, other_key); let depth = cmp::min(common_prefix_len + Self::TIER_SIZE, Self::MAX_DEPTH); // compute node locations for new and existing key-value paris - let new_index = key_to_index(&key, depth); - let other_index = key_to_index(other_key, depth); + let new_index = LeafNodeIndex::from_key(&key, depth); + let other_index = LeafNodeIndex::from_key(other_key, depth); // compute node values for the new and existing key-value pairs let new_node = self.build_leaf_node(new_index, key, value); @@ -196,19 +200,17 @@ impl TieredSmt { // replace the leaf located at index with a subtree containing nodes for new and // existing key-value paris - self.root = self.nodes.replace_leaf_with_subtree( + self.nodes.replace_leaf_with_subtree( index, [(new_index, new_node), (other_index, other_node)], - ); + ) } else { - // if the returned index points to an empty subtree, or a leaf with the same key (i.e., - // we are performing an update), or a leaf is at the bottom tier, compute its node - // value and do a simple insert + // returned index points to an empty subtree or an empty leaf at the bottom tier let node = self.build_leaf_node(index, key, value); - self.root = self.nodes.insert_leaf_node(index, node); - } + self.nodes.insert_leaf_node(index, node) + }; - old_value + Self::EMPTY_VALUE } // ITERATORS @@ -235,7 +237,7 @@ impl TieredSmt { self.nodes.upper_leaves().map(|(index, node)| { let key_prefix = index_to_prefix(index); let (key, value) = self.values.get_first(key_prefix).expect("upper leaf not found"); - debug_assert_eq!(key_to_index(key, index.depth()), *index); + debug_assert_eq!(*index, LeafNodeIndex::from_key(key, index.depth()).into()); (*node, *key, *value) }) } @@ -269,8 +271,8 @@ impl TieredSmt { }; // determine the location of the leaf holding the key-value pair to be removed - let (index, leaf_exists) = self.nodes.get_insert_location(&key); - debug_assert!(index.depth() == Self::MAX_DEPTH || leaf_exists); + let (index, leaf_exists) = self.nodes.get_leaf_index(&key); + debug_assert!(leaf_exists); // if the leaf is at the bottom tier and after removing the key-value pair from it, the // leaf is still not empty, just recompute its hash and update the leaf node. @@ -286,7 +288,7 @@ impl TieredSmt { // higher tier, we need to move the sibling to a higher tier if let Some((sib_key, sib_val, new_sib_index)) = self.values.get_lone_sibling(index) { // determine the current index of the sibling node - let sib_index = key_to_index(sib_key, index.depth()); + let sib_index = LeafNodeIndex::from_key(sib_key, index.depth()); debug_assert!(sib_index.depth() > new_sib_index.depth()); // compute node value for the new location of the sibling leaf and replace the subtree @@ -309,9 +311,8 @@ impl TieredSmt { /// the value store, however, for depths 16, 32, and 48, the node is computed directly from /// the passed-in values (for depth 64, the value store is queried to get all the key-value /// pairs located at the specified index). - fn build_leaf_node(&self, index: NodeIndex, key: RpoDigest, value: Word) -> RpoDigest { + fn build_leaf_node(&self, index: LeafNodeIndex, key: RpoDigest, value: Word) -> RpoDigest { let depth = index.depth(); - debug_assert!(Self::TIER_DEPTHS.contains(&depth)); // insert the key into index-key map and compute the new value of the node if index.depth() == Self::MAX_DEPTH { @@ -337,6 +338,71 @@ impl Default for TieredSmt { } } +// LEAF NODE INDEX +// ================================================================================================ +/// A wrapper around [NodeIndex] to provide type-safe references to nodes at depths 16, 32, 48, and +/// 64. +#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] +pub struct LeafNodeIndex(NodeIndex); + +impl LeafNodeIndex { + /// Returns a new [LeafNodeIndex] instantiated from the provided [NodeIndex]. + /// + /// In debug mode, panics if index depth is not 16, 32, 48, or 64. + pub fn new(index: NodeIndex) -> Self { + // check if the depth is 16, 32, 48, or 64; this works because for a valid depth, + // depth - 16, can be 0, 16, 32, or 48 - i.e., the value is either 0 or any of the 4th + // or 5th bits are set. We can test for this by computing a bitwise AND with a value + // which has all but the 4th and 5th bits set (which is !48). + debug_assert_eq!(((index.depth() - 16) & !48), 0, "invalid tier depth {}", index.depth()); + Self(index) + } + + /// Returns a new [LeafNodeIndex] instantiated from the specified key inserted at the specified + /// depth. + /// + /// The value for the key is computed by taking n most significant bits from the most significant + /// element of the key, where n is the specified depth. + pub fn from_key(key: &RpoDigest, depth: u8) -> Self { + let mse = get_key_prefix(key); + Self::new(NodeIndex::new_unchecked(depth, mse >> (TieredSmt::MAX_DEPTH - depth))) + } + + /// Returns a new [LeafNodeIndex] instantiated for testing purposes. + #[cfg(test)] + pub fn make(depth: u8, value: u64) -> Self { + Self::new(NodeIndex::make(depth, value)) + } + + /// Traverses towards the root until the specified depth is reached. + /// + /// The new depth must be a valid tier depth - i.e., 16, 32, 48, or 64. + pub fn move_up_to(&mut self, depth: u8) { + debug_assert_eq!(((depth - 16) & !48), 0, "invalid tier depth: {depth}"); + self.0.move_up_to(depth); + } +} + +impl Deref for LeafNodeIndex { + type Target = NodeIndex; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From for LeafNodeIndex { + fn from(value: NodeIndex) -> Self { + Self::new(value) + } +} + +impl From for NodeIndex { + fn from(value: LeafNodeIndex) -> Self { + value.0 + } +} + // HELPER FUNCTIONS // ================================================================================================ @@ -351,19 +417,6 @@ fn index_to_prefix(index: &NodeIndex) -> u64 { index.value() << (TieredSmt::MAX_DEPTH - index.depth()) } -/// Returns index for the specified key inserted at the specified depth. -/// -/// The value for the key is computed by taking n most significant bits from the most significant -/// element of the key, where n is the specified depth. -fn key_to_index(key: &RpoDigest, depth: u8) -> NodeIndex { - let mse = get_key_prefix(key); - let value = match depth { - 16 | 32 | 48 | 64 => mse >> ((TieredSmt::MAX_DEPTH - depth) as u32), - _ => unreachable!("invalid depth: {depth}"), - }; - NodeIndex::new_unchecked(depth, value) -} - /// Returns tiered common prefix length between the most significant elements of the provided keys. /// /// Specifically: @@ -372,36 +425,13 @@ fn key_to_index(key: &RpoDigest, depth: u8) -> NodeIndex { /// - returns 32 if the common prefix is between 32 and 47 bits. /// - returns 16 if the common prefix is between 16 and 31 bits. /// - returns 0 if the common prefix is fewer than 16 bits. -fn get_common_prefix_tier(key1: &RpoDigest, key2: &RpoDigest) -> u8 { +fn get_common_prefix_tier_depth(key1: &RpoDigest, key2: &RpoDigest) -> u8 { let e1 = get_key_prefix(key1); let e2 = get_key_prefix(key2); let ex = (e1 ^ e2).leading_zeros() as u8; (ex / 16) * 16 } -/// Returns a tier for the specified index. -/// -/// The tiers are defined as follows: -/// - Tier 0: depth 0 through 16 (inclusive). -/// - Tier 1: depth 17 through 32 (inclusive). -/// - Tier 2: depth 33 through 48 (inclusive). -/// - Tier 3: depth 49 through 64 (inclusive). -const fn get_index_tier(index: &NodeIndex) -> usize { - debug_assert!(index.depth() <= TieredSmt::MAX_DEPTH, "invalid depth"); - match index.depth() { - 0..=16 => 0, - 17..=32 => 1, - 33..=48 => 2, - _ => 3, - } -} - -/// Returns true if the specified index is an index for an leaf node (i.e., the depth is 16, 32, -/// 48, or 64). -const fn is_leaf_node(index: &NodeIndex) -> bool { - matches!(index.depth(), 16 | 32 | 48 | 64) -} - /// Computes node value for leaves at tiers 16, 32, or 48. /// /// Node value is computed as: hash(key || value, domain = depth). @@ -413,7 +443,10 @@ pub fn hash_upper_leaf(key: RpoDigest, value: Word, depth: u8) -> RpoDigest { /// Computes node value for leaves at the bottom tier (depth 64). /// -/// Node value is computed as: hash([key_0, value_0, ..., key_n, value_n, domain=64]). +/// Node value is computed as: hash([key_0, value_0, ..., key_n, value_n], domain=64). +/// +/// TODO: when hashing in domain is implemented for `hash_elements()`, combine this function with +/// `hash_upper_leaf()` function. pub fn hash_bottom_leaf(values: &[(RpoDigest, Word)]) -> RpoDigest { let mut elements = Vec::with_capacity(values.len() * 8); for (key, val) in values.iter() { diff --git a/src/merkle/tiered_smt/nodes.rs b/src/merkle/tiered_smt/nodes.rs index 42bad5e..7135c6c 100644 --- a/src/merkle/tiered_smt/nodes.rs +++ b/src/merkle/tiered_smt/nodes.rs @@ -1,6 +1,6 @@ use super::{ - get_index_tier, get_key_prefix, is_leaf_node, BTreeMap, BTreeSet, EmptySubtreeRoots, - InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, + BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, LeafNodeIndex, MerkleError, MerklePath, + NodeIndex, Rpo256, RpoDigest, Vec, }; // CONSTANTS @@ -21,7 +21,8 @@ const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH; /// A store of nodes for a Tiered Sparse Merkle tree. /// /// The store contains information about all nodes as well as information about which of the nodes -/// represent leaf nodes in a Tiered Sparse Merkle tree. +/// represent leaf nodes in a Tiered Sparse Merkle tree. In the current implementation, [BTreeSet]s +/// are used to determine the position of the leaves in the tree. #[derive(Debug, Clone, PartialEq, Eq)] pub struct NodeStore { nodes: BTreeMap, @@ -88,14 +89,13 @@ impl NodeStore { /// Returns an index at which a leaf node for the specified key should be inserted. /// /// The second value in the returned tuple is set to true if the node at the returned index - /// is already a leaf node, excluding leaves at the bottom tier (i.e., if the leaf is at the - /// bottom tier, false is returned). - pub fn get_insert_location(&self, key: &RpoDigest) -> (NodeIndex, bool) { + /// is already a leaf node. + pub fn get_leaf_index(&self, key: &RpoDigest) -> (LeafNodeIndex, bool) { // traverse the tree from the root down checking nodes at tiers 16, 32, and 48. Return if // a node at any of the tiers is either a leaf or a root of an empty subtree. - let mse = get_key_prefix(key); - for depth in (TIER_DEPTHS[0]..MAX_DEPTH).step_by(TIER_SIZE as usize) { - let index = NodeIndex::new_unchecked(depth, mse >> (MAX_DEPTH - depth)); + const NUM_UPPER_TIERS: usize = TIER_DEPTHS.len() - 1; + for &tier_depth in TIER_DEPTHS[..NUM_UPPER_TIERS].iter() { + let index = LeafNodeIndex::from_key(key, tier_depth); if self.upper_leaves.contains(&index) { return (index, true); } else if !self.nodes.contains_key(&index) { @@ -105,8 +105,8 @@ impl NodeStore { // if we got here, that means all of the nodes checked so far are internal nodes, and // the new node would need to be inserted in the bottom tier. - let index = NodeIndex::new_unchecked(MAX_DEPTH, mse); - (index, false) + let index = LeafNodeIndex::from_key(key, MAX_DEPTH); + (index, self.bottom_leaves.contains(&index.value())) } // ITERATORS @@ -118,7 +118,7 @@ impl NodeStore { /// The iterator order is unspecified. pub fn inner_nodes(&self) -> impl Iterator + '_ { self.nodes.iter().filter_map(|(index, node)| { - if !is_leaf_node(index) { + if self.is_internal_node(index) { Some(InnerNodeInfo { value: *node, left: self.get_node_unchecked(&index.left_child()), @@ -152,20 +152,26 @@ impl NodeStore { /// at the specified indexes. Recomputes and returns the new root. pub fn replace_leaf_with_subtree( &mut self, - leaf_index: NodeIndex, - subtree_leaves: [(NodeIndex, RpoDigest); 2], + leaf_index: LeafNodeIndex, + subtree_leaves: [(LeafNodeIndex, RpoDigest); 2], ) -> RpoDigest { - debug_assert!(is_leaf_node(&leaf_index)); - debug_assert!(is_leaf_node(&subtree_leaves[0].0)); - debug_assert!(is_leaf_node(&subtree_leaves[1].0)); + debug_assert!(self.is_non_empty_leaf(&leaf_index)); debug_assert!(!is_empty_root(&subtree_leaves[0].1)); debug_assert!(!is_empty_root(&subtree_leaves[1].1)); debug_assert_eq!(subtree_leaves[0].0.depth(), subtree_leaves[1].0.depth()); debug_assert!(leaf_index.depth() < subtree_leaves[0].0.depth()); self.upper_leaves.remove(&leaf_index); - self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1); - self.insert_leaf_node(subtree_leaves[1].0, subtree_leaves[1].1) + + if subtree_leaves[0].0 == subtree_leaves[1].0 { + // if the subtree is for a single node at depth 64, we only need to insert one node + debug_assert_eq!(subtree_leaves[0].0.depth(), MAX_DEPTH); + debug_assert_eq!(subtree_leaves[0].1, subtree_leaves[1].1); + self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1) + } else { + self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1); + self.insert_leaf_node(subtree_leaves[1].0, subtree_leaves[1].1) + } } /// Replaces a subtree containing the retained and the removed leaf nodes, with a leaf node @@ -175,14 +181,14 @@ impl NodeStore { /// moving the node at the `retained_leaf` index up to the tier specified by `new_depth`. pub fn replace_subtree_with_leaf( &mut self, - removed_leaf: NodeIndex, - retained_leaf: NodeIndex, + removed_leaf: LeafNodeIndex, + retained_leaf: LeafNodeIndex, new_depth: u8, node: RpoDigest, ) -> RpoDigest { debug_assert!(!is_empty_root(&node)); - debug_assert!(self.is_leaf(&removed_leaf)); - debug_assert!(self.is_leaf(&retained_leaf)); + debug_assert!(self.is_non_empty_leaf(&removed_leaf)); + debug_assert!(self.is_non_empty_leaf(&retained_leaf)); debug_assert_eq!(removed_leaf.depth(), retained_leaf.depth()); debug_assert!(removed_leaf.depth() > new_depth); @@ -202,7 +208,6 @@ impl NodeStore { // compute the index of the common root for retained and removed leaves let mut new_index = retained_leaf; new_index.move_up_to(new_depth); - debug_assert!(is_leaf_node(&new_index)); // insert the node at the root index self.insert_leaf_node(new_index, node) @@ -211,19 +216,21 @@ impl NodeStore { /// Inserts the specified node at the specified index; recomputes and returns the new root /// of the Tiered Sparse Merkle tree. /// - /// This method assumes that node is a non-empty value. - pub fn insert_leaf_node(&mut self, mut index: NodeIndex, mut node: RpoDigest) -> RpoDigest { - debug_assert!(is_leaf_node(&index)); + /// This method assumes that the provided node is a non-empty value, and that there is no node + /// at the specified index. + pub fn insert_leaf_node(&mut self, index: LeafNodeIndex, mut node: RpoDigest) -> RpoDigest { debug_assert!(!is_empty_root(&node)); + debug_assert_eq!(self.nodes.get(&index), None); // mark the node as the leaf if index.depth() == MAX_DEPTH { self.bottom_leaves.insert(index.value()); } else { - self.upper_leaves.insert(index); + self.upper_leaves.insert(index.into()); }; // insert the node and update the path from the node to the root + let mut index: NodeIndex = index.into(); for _ in 0..index.depth() { self.nodes.insert(index, node); let sibling = self.get_node_unchecked(&index.sibling()); @@ -240,8 +247,8 @@ impl NodeStore { /// returns the new root of the Tiered Sparse Merkle tree. /// /// This method can accept `node` as either an empty or a non-empty value. - pub fn update_leaf_node(&mut self, mut index: NodeIndex, mut node: RpoDigest) -> RpoDigest { - debug_assert!(self.is_leaf(&index)); + pub fn update_leaf_node(&mut self, index: LeafNodeIndex, mut node: RpoDigest) -> RpoDigest { + debug_assert!(self.is_non_empty_leaf(&index)); // if the value we are updating the node to is a root of an empty tree, clear the leaf // flag for this node @@ -256,6 +263,7 @@ impl NodeStore { } // update the path from the node to the root + let mut index: NodeIndex = index.into(); for _ in 0..index.depth() { if node == EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize] { self.nodes.remove(&index); @@ -275,8 +283,8 @@ impl NodeStore { /// Replaces the leaf node at the specified index with a root of an empty subtree; recomputes /// and returns the new root of the Tiered Sparse Merkle tree. - pub fn clear_leaf_node(&mut self, index: NodeIndex) -> RpoDigest { - debug_assert!(self.is_leaf(&index)); + pub fn clear_leaf_node(&mut self, index: LeafNodeIndex) -> RpoDigest { + debug_assert!(self.is_non_empty_leaf(&index)); let node = EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize]; self.update_leaf_node(index, node) } @@ -285,8 +293,7 @@ impl NodeStore { // -------------------------------------------------------------------------------------------- /// Returns true if the node at the specified index is a leaf node. - fn is_leaf(&self, index: &NodeIndex) -> bool { - debug_assert!(is_leaf_node(index)); + fn is_non_empty_leaf(&self, index: &LeafNodeIndex) -> bool { if index.depth() == MAX_DEPTH { self.bottom_leaves.contains(&index.value()) } else { @@ -294,6 +301,16 @@ impl NodeStore { } } + /// Returns true if the node at the specified index is an internal node - i.e., there is + /// no leaf at that node and the node does not belong to the bottom tier. + fn is_internal_node(&self, index: &NodeIndex) -> bool { + if index.depth() == MAX_DEPTH { + false + } else { + !self.upper_leaves.contains(index) + } + } + /// Checks if the specified index is valid in the context of this Merkle tree. /// /// # Errors @@ -309,7 +326,7 @@ impl NodeStore { } else { // make sure that there are no leaf nodes in the ancestors of the index; since leaf // nodes can live at specific depth, we just need to check these depths. - let tier = get_index_tier(&index); + let tier = ((index.depth() - 1) / TIER_SIZE) as usize; let mut tier_index = index; for &depth in TIER_DEPTHS[..tier].iter().rev() { tier_index.move_up_to(depth); @@ -335,12 +352,13 @@ impl NodeStore { } /// Removes a sequence of nodes starting at the specified index and traversing the - /// tree up to the specified depth. + /// tree up to the specified depth. The node at the `end_depth` is also removed. /// /// This method does not update any other nodes and does not recompute the tree root. - fn remove_branch(&mut self, mut index: NodeIndex, end_depth: u8) { + fn remove_branch(&mut self, index: LeafNodeIndex, end_depth: u8) { + let mut index: NodeIndex = index.into(); assert!(index.depth() > end_depth); - for _ in 0..(index.depth() - end_depth) { + for _ in 0..(index.depth() - end_depth + 1) { self.nodes.remove(&index); index.move_up() } diff --git a/src/merkle/tiered_smt/tests.rs b/src/merkle/tiered_smt/tests.rs index 845e76e..e459c90 100644 --- a/src/merkle/tiered_smt/tests.rs +++ b/src/merkle/tiered_smt/tests.rs @@ -509,9 +509,26 @@ fn tsmt_bottom_tier() { actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node))); // make sure leaves are returned correctly - let mut leaves = smt.bottom_leaves(); + let smt_clone = smt.clone(); + let mut leaves = smt_clone.bottom_leaves(); assert_eq!(leaves.next(), Some((leaf_node, vec![(key_b, val_b), (key_a, val_a)]))); assert_eq!(leaves.next(), None); + + // --- update a leaf at the bottom tier ------------------------------------------------------- + + let val_a2 = [Felt::new(3); WORD_SIZE]; + assert_eq!(smt.insert(key_a, val_a2), val_a); + + let leaf_node = build_bottom_leaf_node(&[key_b, key_a], &[val_b, val_a2]); + store.set_node(tree_root, index, leaf_node).unwrap(); + + let expected_nodes = get_non_empty_nodes(&store); + let actual_nodes = smt.inner_nodes().collect::>(); + actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node))); + + let mut leaves = smt.bottom_leaves(); + assert_eq!(leaves.next(), Some((leaf_node, vec![(key_b, val_b), (key_a, val_a2)]))); + assert_eq!(leaves.next(), None); } #[test] diff --git a/src/merkle/tiered_smt/values.rs b/src/merkle/tiered_smt/values.rs index eca8e5d..ec2a465 100644 --- a/src/merkle/tiered_smt/values.rs +++ b/src/merkle/tiered_smt/values.rs @@ -1,4 +1,4 @@ -use super::{get_key_prefix, is_leaf_node, BTreeMap, NodeIndex, RpoDigest, StarkField, Vec, Word}; +use super::{get_key_prefix, BTreeMap, LeafNodeIndex, RpoDigest, StarkField, Vec, Word}; use crate::utils::vec; use core::{ cmp::{Ord, Ordering}, @@ -23,7 +23,8 @@ const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH; /// the values are the corresponding key-value pairs (or a list of key-value pairs if more that /// a single key-value pair shares the same 64-bit prefix). /// -/// The store supports lookup by the full key as well as by the 64-bit key prefix. +/// The store supports lookup by the full key (i.e. [RpoDigest]) as well as by the 64-bit key +/// prefix. #[derive(Debug, Default, Clone, PartialEq, Eq)] pub struct ValueStore { values: BTreeMap, @@ -76,26 +77,29 @@ impl ValueStore { /// /// This method assumes that the key-value pair for the specified index has already been /// removed from the store. - pub fn get_lone_sibling(&self, index: NodeIndex) -> Option<(&RpoDigest, &Word, NodeIndex)> { - debug_assert!(is_leaf_node(&index)); - + pub fn get_lone_sibling( + &self, + index: LeafNodeIndex, + ) -> Option<(&RpoDigest, &Word, LeafNodeIndex)> { // iterate over tiers from top to bottom, looking at the tiers which are strictly above // the depth of the index. This implies that only tiers at depth 32 and 48 will be // considered. For each tier, check if the parent of the index at the higher tier - // contains a single node. - for &tier in TIER_DEPTHS.iter().filter(|&t| index.depth() > *t) { + // contains a single node. The fist tier (depth 16) is excluded because we cannot move + // nodes at depth 16 to a higher tier. This implies that nodes at the first tier will + // never have "lone siblings". + for &tier_depth in TIER_DEPTHS.iter().filter(|&t| index.depth() > *t) { // compute the index of the root at a higher tier let mut parent_index = index; - parent_index.move_up_to(tier); + parent_index.move_up_to(tier_depth); // find the lone sibling, if any; we need to handle the "last node" at a given tier // separately specify the bounds for the search correctly. - let start_prefix = parent_index.value() << (MAX_DEPTH - tier); - let sibling = if start_prefix.leading_ones() as u8 == tier { + let start_prefix = parent_index.value() << (MAX_DEPTH - tier_depth); + let sibling = if start_prefix.leading_ones() as u8 == tier_depth { let mut iter = self.range(start_prefix..); iter.next().filter(|_| iter.next().is_none()) } else { - let end_prefix = (parent_index.value() + 1) << (MAX_DEPTH - tier); + let end_prefix = (parent_index.value() + 1) << (MAX_DEPTH - tier_depth); let mut iter = self.range(start_prefix..end_prefix); iter.next().filter(|_| iter.next().is_none()) }; @@ -346,12 +350,8 @@ fn cmp_digests(d1: &RpoDigest, d2: &RpoDigest) -> Ordering { #[cfg(test)] mod tests { - - use super::{RpoDigest, ValueStore}; - use crate::{ - merkle::{tiered_smt::values::StoreEntry, NodeIndex}, - Felt, ONE, WORD_SIZE, ZERO, - }; + use super::{LeafNodeIndex, RpoDigest, StoreEntry, ValueStore}; + use crate::{Felt, ONE, WORD_SIZE, ZERO}; #[test] fn test_insert() { @@ -569,17 +569,17 @@ mod tests { store.insert(key_b, value_b); // check sibling node for `a` - let index = NodeIndex::make(32, 0b_10101010_10101010_00011111_11111110); - let parent_index = NodeIndex::make(16, 0b_10101010_10101010); + let index = LeafNodeIndex::make(32, 0b_10101010_10101010_00011111_11111110); + let parent_index = LeafNodeIndex::make(16, 0b_10101010_10101010); assert_eq!(store.get_lone_sibling(index), Some((&key_a, &value_a, parent_index))); // check sibling node for `b` - let index = NodeIndex::make(32, 0b_11111111_11111111_00011111_11111111); - let parent_index = NodeIndex::make(16, 0b_11111111_11111111); + let index = LeafNodeIndex::make(32, 0b_11111111_11111111_00011111_11111111); + let parent_index = LeafNodeIndex::make(16, 0b_11111111_11111111); assert_eq!(store.get_lone_sibling(index), Some((&key_b, &value_b, parent_index))); // check some other sibling for some other index - let index = NodeIndex::make(32, 0b_11101010_10101010); + let index = LeafNodeIndex::make(32, 0b_11101010_10101010); assert_eq!(store.get_lone_sibling(index), None); } }