Browse Source

fix: node type check in inner_nodes() iterator of TSMT

al-gkr-basic-workflow
Bobbin Threadbare 1 year ago
parent
commit
6810b5e3ab
5 changed files with 209 additions and 142 deletions
  1. +1
    -2
      src/merkle/partial_mt/mod.rs
  2. +112
    -79
      src/merkle/tiered_smt/mod.rs
  3. +56
    -38
      src/merkle/tiered_smt/nodes.rs
  4. +18
    -1
      src/merkle/tiered_smt/tests.rs
  5. +22
    -22
      src/merkle/tiered_smt/values.rs

+ 1
- 2
src/merkle/partial_mt/mod.rs

@ -118,7 +118,7 @@ impl PartialMerkleTree {
// fill layers without nodes with empty vector // fill layers without nodes with empty vector
for depth in 0..max_depth { for depth in 0..max_depth {
layers.entry(depth).or_insert(vec![]);
layers.entry(depth).or_default();
} }
let mut layer_iter = layers.into_values().rev(); let mut layer_iter = layers.into_values().rev();
@ -370,7 +370,6 @@ impl PartialMerkleTree {
return Ok(old_value); return Ok(old_value);
} }
let mut node_index = node_index;
let mut value = value.into(); let mut value = value.into();
for _ in 0..node_index.depth() { for _ in 0..node_index.depth() {
let sibling = self.nodes.get(&node_index.sibling()).expect("sibling should exist"); let sibling = self.nodes.get(&node_index.sibling()).expect("sibling should exist");

+ 112
- 79
src/merkle/tiered_smt/mod.rs

@ -2,7 +2,7 @@ use super::{
BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex,
Rpo256, RpoDigest, StarkField, Vec, Word, Rpo256, RpoDigest, StarkField, Vec, Word,
}; };
use core::cmp;
use core::{cmp, ops::Deref};
mod nodes; mod nodes;
use nodes::NodeStore; use nodes::NodeStore;
@ -148,32 +148,36 @@ impl TieredSmt {
return self.remove_leaf_node(key); return self.remove_leaf_node(key);
} }
// insert the value into the value store, and if nothing has changed, return
let (old_value, is_update) = match self.values.insert(key, value) {
Some(old_value) => {
if old_value == value {
return old_value;
}
(old_value, true)
// insert the value into the value store, and if the key was already in the store, update
// it with the new value
if let Some(old_value) = self.values.insert(key, value) {
if old_value != value {
// if the new value is different from the old value, determine the location of
// the leaf node for this key, build the node, and update the root
let (index, leaf_exists) = self.nodes.get_leaf_index(&key);
debug_assert!(leaf_exists);
let node = self.build_leaf_node(index, key, value);
self.root = self.nodes.update_leaf_node(index, node);
} }
None => (Self::EMPTY_VALUE, false),
return old_value;
}; };
// determine the index for the value node; this index could have 3 different meanings:
// - it points to a root of an empty subtree (excluding depth = 64); in this case, we can
// replace the node with the value node immediately.
// - it points to a node at the bottom tier (i.e., depth = 64); in this case, we need to
// process bottom-tier insertion which will be handled by insert_leaf_node().
// - it points to an existing leaf node; this node could be a node with the same key or a
// different key with a common prefix; in the latter case, we'll need to move the leaf
// to a lower tier
let (index, leaf_exists) = self.nodes.get_insert_location(&key);
debug_assert!(!is_update || leaf_exists);
// if the returned index points to a leaf, and this leaf is for a different key (i.e., we
// are not updating a value for an existing key), we need to replace this leaf with a tree
// containing leaves for both the old and the new key-value pairs
if leaf_exists && !is_update {
// determine the location for the leaf node; this index could have 3 different meanings:
// - it points to a root of an empty subtree or an empty node at depth 64; in this case,
// we can replace the node with the value node immediately.
// - it points to an existing leaf at the bottom tier (i.e., depth = 64); in this case,
// we need to process update the bottom leaf.
// - it points to an existing leaf node for a different key with the same prefix (same
// key case was handled above); in this case, we need to move the leaf to a lower tier
let (index, leaf_exists) = self.nodes.get_leaf_index(&key);
self.root = if leaf_exists && index.depth() == Self::MAX_DEPTH {
// returned index points to a leaf at the bottom tier
let node = self.build_leaf_node(index, key, value);
self.nodes.update_leaf_node(index, node)
} else if leaf_exists {
// returned index pointes to a leaf for a different key with the same prefix
// get the key-value pair for the key with the same prefix; since the key-value // get the key-value pair for the key with the same prefix; since the key-value
// pair has already been inserted into the value store, we need to filter it out // pair has already been inserted into the value store, we need to filter it out
// when looking for the other key-value pair // when looking for the other key-value pair
@ -183,12 +187,12 @@ impl TieredSmt {
.expect("other key-value pair not found"); .expect("other key-value pair not found");
// determine how far down the tree should we move the leaves // determine how far down the tree should we move the leaves
let common_prefix_len = get_common_prefix_tier(&key, other_key);
let common_prefix_len = get_common_prefix_tier_depth(&key, other_key);
let depth = cmp::min(common_prefix_len + Self::TIER_SIZE, Self::MAX_DEPTH); let depth = cmp::min(common_prefix_len + Self::TIER_SIZE, Self::MAX_DEPTH);
// compute node locations for new and existing key-value paris // compute node locations for new and existing key-value paris
let new_index = key_to_index(&key, depth);
let other_index = key_to_index(other_key, depth);
let new_index = LeafNodeIndex::from_key(&key, depth);
let other_index = LeafNodeIndex::from_key(other_key, depth);
// compute node values for the new and existing key-value pairs // compute node values for the new and existing key-value pairs
let new_node = self.build_leaf_node(new_index, key, value); let new_node = self.build_leaf_node(new_index, key, value);
@ -196,19 +200,17 @@ impl TieredSmt {
// replace the leaf located at index with a subtree containing nodes for new and // replace the leaf located at index with a subtree containing nodes for new and
// existing key-value paris // existing key-value paris
self.root = self.nodes.replace_leaf_with_subtree(
self.nodes.replace_leaf_with_subtree(
index, index,
[(new_index, new_node), (other_index, other_node)], [(new_index, new_node), (other_index, other_node)],
);
)
} else { } else {
// if the returned index points to an empty subtree, or a leaf with the same key (i.e.,
// we are performing an update), or a leaf is at the bottom tier, compute its node
// value and do a simple insert
// returned index points to an empty subtree or an empty leaf at the bottom tier
let node = self.build_leaf_node(index, key, value); let node = self.build_leaf_node(index, key, value);
self.root = self.nodes.insert_leaf_node(index, node);
}
self.nodes.insert_leaf_node(index, node)
};
old_value
Self::EMPTY_VALUE
} }
// ITERATORS // ITERATORS
@ -235,7 +237,7 @@ impl TieredSmt {
self.nodes.upper_leaves().map(|(index, node)| { self.nodes.upper_leaves().map(|(index, node)| {
let key_prefix = index_to_prefix(index); let key_prefix = index_to_prefix(index);
let (key, value) = self.values.get_first(key_prefix).expect("upper leaf not found"); let (key, value) = self.values.get_first(key_prefix).expect("upper leaf not found");
debug_assert_eq!(key_to_index(key, index.depth()), *index);
debug_assert_eq!(*index, LeafNodeIndex::from_key(key, index.depth()).into());
(*node, *key, *value) (*node, *key, *value)
}) })
} }
@ -269,8 +271,8 @@ impl TieredSmt {
}; };
// determine the location of the leaf holding the key-value pair to be removed // determine the location of the leaf holding the key-value pair to be removed
let (index, leaf_exists) = self.nodes.get_insert_location(&key);
debug_assert!(index.depth() == Self::MAX_DEPTH || leaf_exists);
let (index, leaf_exists) = self.nodes.get_leaf_index(&key);
debug_assert!(leaf_exists);
// if the leaf is at the bottom tier and after removing the key-value pair from it, the // if the leaf is at the bottom tier and after removing the key-value pair from it, the
// leaf is still not empty, just recompute its hash and update the leaf node. // leaf is still not empty, just recompute its hash and update the leaf node.
@ -286,7 +288,7 @@ impl TieredSmt {
// higher tier, we need to move the sibling to a higher tier // higher tier, we need to move the sibling to a higher tier
if let Some((sib_key, sib_val, new_sib_index)) = self.values.get_lone_sibling(index) { if let Some((sib_key, sib_val, new_sib_index)) = self.values.get_lone_sibling(index) {
// determine the current index of the sibling node // determine the current index of the sibling node
let sib_index = key_to_index(sib_key, index.depth());
let sib_index = LeafNodeIndex::from_key(sib_key, index.depth());
debug_assert!(sib_index.depth() > new_sib_index.depth()); debug_assert!(sib_index.depth() > new_sib_index.depth());
// compute node value for the new location of the sibling leaf and replace the subtree // compute node value for the new location of the sibling leaf and replace the subtree
@ -309,9 +311,8 @@ impl TieredSmt {
/// the value store, however, for depths 16, 32, and 48, the node is computed directly from /// the value store, however, for depths 16, 32, and 48, the node is computed directly from
/// the passed-in values (for depth 64, the value store is queried to get all the key-value /// the passed-in values (for depth 64, the value store is queried to get all the key-value
/// pairs located at the specified index). /// pairs located at the specified index).
fn build_leaf_node(&self, index: NodeIndex, key: RpoDigest, value: Word) -> RpoDigest {
fn build_leaf_node(&self, index: LeafNodeIndex, key: RpoDigest, value: Word) -> RpoDigest {
let depth = index.depth(); let depth = index.depth();
debug_assert!(Self::TIER_DEPTHS.contains(&depth));
// insert the key into index-key map and compute the new value of the node // insert the key into index-key map and compute the new value of the node
if index.depth() == Self::MAX_DEPTH { if index.depth() == Self::MAX_DEPTH {
@ -337,6 +338,71 @@ impl Default for TieredSmt {
} }
} }
// LEAF NODE INDEX
// ================================================================================================
/// A wrapper around [NodeIndex] to provide type-safe references to nodes at depths 16, 32, 48, and
/// 64.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub struct LeafNodeIndex(NodeIndex);
impl LeafNodeIndex {
/// Returns a new [LeafNodeIndex] instantiated from the provided [NodeIndex].
///
/// In debug mode, panics if index depth is not 16, 32, 48, or 64.
pub fn new(index: NodeIndex) -> Self {
// check if the depth is 16, 32, 48, or 64; this works because for a valid depth,
// depth - 16, can be 0, 16, 32, or 48 - i.e., the value is either 0 or any of the 4th
// or 5th bits are set. We can test for this by computing a bitwise AND with a value
// which has all but the 4th and 5th bits set (which is !48).
debug_assert_eq!(((index.depth() - 16) & !48), 0, "invalid tier depth {}", index.depth());
Self(index)
}
/// Returns a new [LeafNodeIndex] instantiated from the specified key inserted at the specified
/// depth.
///
/// The value for the key is computed by taking n most significant bits from the most significant
/// element of the key, where n is the specified depth.
pub fn from_key(key: &RpoDigest, depth: u8) -> Self {
let mse = get_key_prefix(key);
Self::new(NodeIndex::new_unchecked(depth, mse >> (TieredSmt::MAX_DEPTH - depth)))
}
/// Returns a new [LeafNodeIndex] instantiated for testing purposes.
#[cfg(test)]
pub fn make(depth: u8, value: u64) -> Self {
Self::new(NodeIndex::make(depth, value))
}
/// Traverses towards the root until the specified depth is reached.
///
/// The new depth must be a valid tier depth - i.e., 16, 32, 48, or 64.
pub fn move_up_to(&mut self, depth: u8) {
debug_assert_eq!(((depth - 16) & !48), 0, "invalid tier depth: {depth}");
self.0.move_up_to(depth);
}
}
impl Deref for LeafNodeIndex {
type Target = NodeIndex;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<NodeIndex> for LeafNodeIndex {
fn from(value: NodeIndex) -> Self {
Self::new(value)
}
}
impl From<LeafNodeIndex> for NodeIndex {
fn from(value: LeafNodeIndex) -> Self {
value.0
}
}
// HELPER FUNCTIONS // HELPER FUNCTIONS
// ================================================================================================ // ================================================================================================
@ -351,19 +417,6 @@ fn index_to_prefix(index: &NodeIndex) -> u64 {
index.value() << (TieredSmt::MAX_DEPTH - index.depth()) index.value() << (TieredSmt::MAX_DEPTH - index.depth())
} }
/// Returns index for the specified key inserted at the specified depth.
///
/// The value for the key is computed by taking n most significant bits from the most significant
/// element of the key, where n is the specified depth.
fn key_to_index(key: &RpoDigest, depth: u8) -> NodeIndex {
let mse = get_key_prefix(key);
let value = match depth {
16 | 32 | 48 | 64 => mse >> ((TieredSmt::MAX_DEPTH - depth) as u32),
_ => unreachable!("invalid depth: {depth}"),
};
NodeIndex::new_unchecked(depth, value)
}
/// Returns tiered common prefix length between the most significant elements of the provided keys. /// Returns tiered common prefix length between the most significant elements of the provided keys.
/// ///
/// Specifically: /// Specifically:
@ -372,36 +425,13 @@ fn key_to_index(key: &RpoDigest, depth: u8) -> NodeIndex {
/// - returns 32 if the common prefix is between 32 and 47 bits. /// - returns 32 if the common prefix is between 32 and 47 bits.
/// - returns 16 if the common prefix is between 16 and 31 bits. /// - returns 16 if the common prefix is between 16 and 31 bits.
/// - returns 0 if the common prefix is fewer than 16 bits. /// - returns 0 if the common prefix is fewer than 16 bits.
fn get_common_prefix_tier(key1: &RpoDigest, key2: &RpoDigest) -> u8 {
fn get_common_prefix_tier_depth(key1: &RpoDigest, key2: &RpoDigest) -> u8 {
let e1 = get_key_prefix(key1); let e1 = get_key_prefix(key1);
let e2 = get_key_prefix(key2); let e2 = get_key_prefix(key2);
let ex = (e1 ^ e2).leading_zeros() as u8; let ex = (e1 ^ e2).leading_zeros() as u8;
(ex / 16) * 16 (ex / 16) * 16
} }
/// Returns a tier for the specified index.
///
/// The tiers are defined as follows:
/// - Tier 0: depth 0 through 16 (inclusive).
/// - Tier 1: depth 17 through 32 (inclusive).
/// - Tier 2: depth 33 through 48 (inclusive).
/// - Tier 3: depth 49 through 64 (inclusive).
const fn get_index_tier(index: &NodeIndex) -> usize {
debug_assert!(index.depth() <= TieredSmt::MAX_DEPTH, "invalid depth");
match index.depth() {
0..=16 => 0,
17..=32 => 1,
33..=48 => 2,
_ => 3,
}
}
/// Returns true if the specified index is an index for an leaf node (i.e., the depth is 16, 32,
/// 48, or 64).
const fn is_leaf_node(index: &NodeIndex) -> bool {
matches!(index.depth(), 16 | 32 | 48 | 64)
}
/// Computes node value for leaves at tiers 16, 32, or 48. /// Computes node value for leaves at tiers 16, 32, or 48.
/// ///
/// Node value is computed as: hash(key || value, domain = depth). /// Node value is computed as: hash(key || value, domain = depth).
@ -413,7 +443,10 @@ pub fn hash_upper_leaf(key: RpoDigest, value: Word, depth: u8) -> RpoDigest {
/// Computes node value for leaves at the bottom tier (depth 64). /// Computes node value for leaves at the bottom tier (depth 64).
/// ///
/// Node value is computed as: hash([key_0, value_0, ..., key_n, value_n, domain=64]).
/// Node value is computed as: hash([key_0, value_0, ..., key_n, value_n], domain=64).
///
/// TODO: when hashing in domain is implemented for `hash_elements()`, combine this function with
/// `hash_upper_leaf()` function.
pub fn hash_bottom_leaf(values: &[(RpoDigest, Word)]) -> RpoDigest { pub fn hash_bottom_leaf(values: &[(RpoDigest, Word)]) -> RpoDigest {
let mut elements = Vec::with_capacity(values.len() * 8); let mut elements = Vec::with_capacity(values.len() * 8);
for (key, val) in values.iter() { for (key, val) in values.iter() {

+ 56
- 38
src/merkle/tiered_smt/nodes.rs

@ -1,6 +1,6 @@
use super::{ use super::{
get_index_tier, get_key_prefix, is_leaf_node, BTreeMap, BTreeSet, EmptySubtreeRoots,
InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec,
BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, LeafNodeIndex, MerkleError, MerklePath,
NodeIndex, Rpo256, RpoDigest, Vec,
}; };
// CONSTANTS // CONSTANTS
@ -21,7 +21,8 @@ const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;
/// A store of nodes for a Tiered Sparse Merkle tree. /// A store of nodes for a Tiered Sparse Merkle tree.
/// ///
/// The store contains information about all nodes as well as information about which of the nodes /// The store contains information about all nodes as well as information about which of the nodes
/// represent leaf nodes in a Tiered Sparse Merkle tree.
/// represent leaf nodes in a Tiered Sparse Merkle tree. In the current implementation, [BTreeSet]s
/// are used to determine the position of the leaves in the tree.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct NodeStore { pub struct NodeStore {
nodes: BTreeMap<NodeIndex, RpoDigest>, nodes: BTreeMap<NodeIndex, RpoDigest>,
@ -88,14 +89,13 @@ impl NodeStore {
/// Returns an index at which a leaf node for the specified key should be inserted. /// Returns an index at which a leaf node for the specified key should be inserted.
/// ///
/// The second value in the returned tuple is set to true if the node at the returned index /// The second value in the returned tuple is set to true if the node at the returned index
/// is already a leaf node, excluding leaves at the bottom tier (i.e., if the leaf is at the
/// bottom tier, false is returned).
pub fn get_insert_location(&self, key: &RpoDigest) -> (NodeIndex, bool) {
/// is already a leaf node.
pub fn get_leaf_index(&self, key: &RpoDigest) -> (LeafNodeIndex, bool) {
// traverse the tree from the root down checking nodes at tiers 16, 32, and 48. Return if // traverse the tree from the root down checking nodes at tiers 16, 32, and 48. Return if
// a node at any of the tiers is either a leaf or a root of an empty subtree. // a node at any of the tiers is either a leaf or a root of an empty subtree.
let mse = get_key_prefix(key);
for depth in (TIER_DEPTHS[0]..MAX_DEPTH).step_by(TIER_SIZE as usize) {
let index = NodeIndex::new_unchecked(depth, mse >> (MAX_DEPTH - depth));
const NUM_UPPER_TIERS: usize = TIER_DEPTHS.len() - 1;
for &tier_depth in TIER_DEPTHS[..NUM_UPPER_TIERS].iter() {
let index = LeafNodeIndex::from_key(key, tier_depth);
if self.upper_leaves.contains(&index) { if self.upper_leaves.contains(&index) {
return (index, true); return (index, true);
} else if !self.nodes.contains_key(&index) { } else if !self.nodes.contains_key(&index) {
@ -105,8 +105,8 @@ impl NodeStore {
// if we got here, that means all of the nodes checked so far are internal nodes, and // if we got here, that means all of the nodes checked so far are internal nodes, and
// the new node would need to be inserted in the bottom tier. // the new node would need to be inserted in the bottom tier.
let index = NodeIndex::new_unchecked(MAX_DEPTH, mse);
(index, false)
let index = LeafNodeIndex::from_key(key, MAX_DEPTH);
(index, self.bottom_leaves.contains(&index.value()))
} }
// ITERATORS // ITERATORS
@ -118,7 +118,7 @@ impl NodeStore {
/// The iterator order is unspecified. /// The iterator order is unspecified.
pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ { pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
self.nodes.iter().filter_map(|(index, node)| { self.nodes.iter().filter_map(|(index, node)| {
if !is_leaf_node(index) {
if self.is_internal_node(index) {
Some(InnerNodeInfo { Some(InnerNodeInfo {
value: *node, value: *node,
left: self.get_node_unchecked(&index.left_child()), left: self.get_node_unchecked(&index.left_child()),
@ -152,20 +152,26 @@ impl NodeStore {
/// at the specified indexes. Recomputes and returns the new root. /// at the specified indexes. Recomputes and returns the new root.
pub fn replace_leaf_with_subtree( pub fn replace_leaf_with_subtree(
&mut self, &mut self,
leaf_index: NodeIndex,
subtree_leaves: [(NodeIndex, RpoDigest); 2],
leaf_index: LeafNodeIndex,
subtree_leaves: [(LeafNodeIndex, RpoDigest); 2],
) -> RpoDigest { ) -> RpoDigest {
debug_assert!(is_leaf_node(&leaf_index));
debug_assert!(is_leaf_node(&subtree_leaves[0].0));
debug_assert!(is_leaf_node(&subtree_leaves[1].0));
debug_assert!(self.is_non_empty_leaf(&leaf_index));
debug_assert!(!is_empty_root(&subtree_leaves[0].1)); debug_assert!(!is_empty_root(&subtree_leaves[0].1));
debug_assert!(!is_empty_root(&subtree_leaves[1].1)); debug_assert!(!is_empty_root(&subtree_leaves[1].1));
debug_assert_eq!(subtree_leaves[0].0.depth(), subtree_leaves[1].0.depth()); debug_assert_eq!(subtree_leaves[0].0.depth(), subtree_leaves[1].0.depth());
debug_assert!(leaf_index.depth() < subtree_leaves[0].0.depth()); debug_assert!(leaf_index.depth() < subtree_leaves[0].0.depth());
self.upper_leaves.remove(&leaf_index); self.upper_leaves.remove(&leaf_index);
self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1);
self.insert_leaf_node(subtree_leaves[1].0, subtree_leaves[1].1)
if subtree_leaves[0].0 == subtree_leaves[1].0 {
// if the subtree is for a single node at depth 64, we only need to insert one node
debug_assert_eq!(subtree_leaves[0].0.depth(), MAX_DEPTH);
debug_assert_eq!(subtree_leaves[0].1, subtree_leaves[1].1);
self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1)
} else {
self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1);
self.insert_leaf_node(subtree_leaves[1].0, subtree_leaves[1].1)
}
} }
/// Replaces a subtree containing the retained and the removed leaf nodes, with a leaf node /// Replaces a subtree containing the retained and the removed leaf nodes, with a leaf node
@ -175,14 +181,14 @@ impl NodeStore {
/// moving the node at the `retained_leaf` index up to the tier specified by `new_depth`. /// moving the node at the `retained_leaf` index up to the tier specified by `new_depth`.
pub fn replace_subtree_with_leaf( pub fn replace_subtree_with_leaf(
&mut self, &mut self,
removed_leaf: NodeIndex,
retained_leaf: NodeIndex,
removed_leaf: LeafNodeIndex,
retained_leaf: LeafNodeIndex,
new_depth: u8, new_depth: u8,
node: RpoDigest, node: RpoDigest,
) -> RpoDigest { ) -> RpoDigest {
debug_assert!(!is_empty_root(&node)); debug_assert!(!is_empty_root(&node));
debug_assert!(self.is_leaf(&removed_leaf));
debug_assert!(self.is_leaf(&retained_leaf));
debug_assert!(self.is_non_empty_leaf(&removed_leaf));
debug_assert!(self.is_non_empty_leaf(&retained_leaf));
debug_assert_eq!(removed_leaf.depth(), retained_leaf.depth()); debug_assert_eq!(removed_leaf.depth(), retained_leaf.depth());
debug_assert!(removed_leaf.depth() > new_depth); debug_assert!(removed_leaf.depth() > new_depth);
@ -202,7 +208,6 @@ impl NodeStore {
// compute the index of the common root for retained and removed leaves // compute the index of the common root for retained and removed leaves
let mut new_index = retained_leaf; let mut new_index = retained_leaf;
new_index.move_up_to(new_depth); new_index.move_up_to(new_depth);
debug_assert!(is_leaf_node(&new_index));
// insert the node at the root index // insert the node at the root index
self.insert_leaf_node(new_index, node) self.insert_leaf_node(new_index, node)
@ -211,19 +216,21 @@ impl NodeStore {
/// Inserts the specified node at the specified index; recomputes and returns the new root /// Inserts the specified node at the specified index; recomputes and returns the new root
/// of the Tiered Sparse Merkle tree. /// of the Tiered Sparse Merkle tree.
/// ///
/// This method assumes that node is a non-empty value.
pub fn insert_leaf_node(&mut self, mut index: NodeIndex, mut node: RpoDigest) -> RpoDigest {
debug_assert!(is_leaf_node(&index));
/// This method assumes that the provided node is a non-empty value, and that there is no node
/// at the specified index.
pub fn insert_leaf_node(&mut self, index: LeafNodeIndex, mut node: RpoDigest) -> RpoDigest {
debug_assert!(!is_empty_root(&node)); debug_assert!(!is_empty_root(&node));
debug_assert_eq!(self.nodes.get(&index), None);
// mark the node as the leaf // mark the node as the leaf
if index.depth() == MAX_DEPTH { if index.depth() == MAX_DEPTH {
self.bottom_leaves.insert(index.value()); self.bottom_leaves.insert(index.value());
} else { } else {
self.upper_leaves.insert(index);
self.upper_leaves.insert(index.into());
}; };
// insert the node and update the path from the node to the root // insert the node and update the path from the node to the root
let mut index: NodeIndex = index.into();
for _ in 0..index.depth() { for _ in 0..index.depth() {
self.nodes.insert(index, node); self.nodes.insert(index, node);
let sibling = self.get_node_unchecked(&index.sibling()); let sibling = self.get_node_unchecked(&index.sibling());
@ -240,8 +247,8 @@ impl NodeStore {
/// returns the new root of the Tiered Sparse Merkle tree. /// returns the new root of the Tiered Sparse Merkle tree.
/// ///
/// This method can accept `node` as either an empty or a non-empty value. /// This method can accept `node` as either an empty or a non-empty value.
pub fn update_leaf_node(&mut self, mut index: NodeIndex, mut node: RpoDigest) -> RpoDigest {
debug_assert!(self.is_leaf(&index));
pub fn update_leaf_node(&mut self, index: LeafNodeIndex, mut node: RpoDigest) -> RpoDigest {
debug_assert!(self.is_non_empty_leaf(&index));
// if the value we are updating the node to is a root of an empty tree, clear the leaf // if the value we are updating the node to is a root of an empty tree, clear the leaf
// flag for this node // flag for this node
@ -256,6 +263,7 @@ impl NodeStore {
} }
// update the path from the node to the root // update the path from the node to the root
let mut index: NodeIndex = index.into();
for _ in 0..index.depth() { for _ in 0..index.depth() {
if node == EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize] { if node == EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize] {
self.nodes.remove(&index); self.nodes.remove(&index);
@ -275,8 +283,8 @@ impl NodeStore {
/// Replaces the leaf node at the specified index with a root of an empty subtree; recomputes /// Replaces the leaf node at the specified index with a root of an empty subtree; recomputes
/// and returns the new root of the Tiered Sparse Merkle tree. /// and returns the new root of the Tiered Sparse Merkle tree.
pub fn clear_leaf_node(&mut self, index: NodeIndex) -> RpoDigest {
debug_assert!(self.is_leaf(&index));
pub fn clear_leaf_node(&mut self, index: LeafNodeIndex) -> RpoDigest {
debug_assert!(self.is_non_empty_leaf(&index));
let node = EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize]; let node = EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize];
self.update_leaf_node(index, node) self.update_leaf_node(index, node)
} }
@ -285,8 +293,7 @@ impl NodeStore {
// -------------------------------------------------------------------------------------------- // --------------------------------------------------------------------------------------------
/// Returns true if the node at the specified index is a leaf node. /// Returns true if the node at the specified index is a leaf node.
fn is_leaf(&self, index: &NodeIndex) -> bool {
debug_assert!(is_leaf_node(index));
fn is_non_empty_leaf(&self, index: &LeafNodeIndex) -> bool {
if index.depth() == MAX_DEPTH { if index.depth() == MAX_DEPTH {
self.bottom_leaves.contains(&index.value()) self.bottom_leaves.contains(&index.value())
} else { } else {
@ -294,6 +301,16 @@ impl NodeStore {
} }
} }
/// Returns true if the node at the specified index is an internal node - i.e., there is
/// no leaf at that node and the node does not belong to the bottom tier.
fn is_internal_node(&self, index: &NodeIndex) -> bool {
if index.depth() == MAX_DEPTH {
false
} else {
!self.upper_leaves.contains(index)
}
}
/// Checks if the specified index is valid in the context of this Merkle tree. /// Checks if the specified index is valid in the context of this Merkle tree.
/// ///
/// # Errors /// # Errors
@ -309,7 +326,7 @@ impl NodeStore {
} else { } else {
// make sure that there are no leaf nodes in the ancestors of the index; since leaf // make sure that there are no leaf nodes in the ancestors of the index; since leaf
// nodes can live at specific depth, we just need to check these depths. // nodes can live at specific depth, we just need to check these depths.
let tier = get_index_tier(&index);
let tier = ((index.depth() - 1) / TIER_SIZE) as usize;
let mut tier_index = index; let mut tier_index = index;
for &depth in TIER_DEPTHS[..tier].iter().rev() { for &depth in TIER_DEPTHS[..tier].iter().rev() {
tier_index.move_up_to(depth); tier_index.move_up_to(depth);
@ -335,12 +352,13 @@ impl NodeStore {
} }
/// Removes a sequence of nodes starting at the specified index and traversing the /// Removes a sequence of nodes starting at the specified index and traversing the
/// tree up to the specified depth.
/// tree up to the specified depth. The node at the `end_depth` is also removed.
/// ///
/// This method does not update any other nodes and does not recompute the tree root. /// This method does not update any other nodes and does not recompute the tree root.
fn remove_branch(&mut self, mut index: NodeIndex, end_depth: u8) {
fn remove_branch(&mut self, index: LeafNodeIndex, end_depth: u8) {
let mut index: NodeIndex = index.into();
assert!(index.depth() > end_depth); assert!(index.depth() > end_depth);
for _ in 0..(index.depth() - end_depth) {
for _ in 0..(index.depth() - end_depth + 1) {
self.nodes.remove(&index); self.nodes.remove(&index);
index.move_up() index.move_up()
} }

+ 18
- 1
src/merkle/tiered_smt/tests.rs

@ -509,9 +509,26 @@ fn tsmt_bottom_tier() {
actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node))); actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));
// make sure leaves are returned correctly // make sure leaves are returned correctly
let mut leaves = smt.bottom_leaves();
let smt_clone = smt.clone();
let mut leaves = smt_clone.bottom_leaves();
assert_eq!(leaves.next(), Some((leaf_node, vec![(key_b, val_b), (key_a, val_a)]))); assert_eq!(leaves.next(), Some((leaf_node, vec![(key_b, val_b), (key_a, val_a)])));
assert_eq!(leaves.next(), None); assert_eq!(leaves.next(), None);
// --- update a leaf at the bottom tier -------------------------------------------------------
let val_a2 = [Felt::new(3); WORD_SIZE];
assert_eq!(smt.insert(key_a, val_a2), val_a);
let leaf_node = build_bottom_leaf_node(&[key_b, key_a], &[val_b, val_a2]);
store.set_node(tree_root, index, leaf_node).unwrap();
let expected_nodes = get_non_empty_nodes(&store);
let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));
let mut leaves = smt.bottom_leaves();
assert_eq!(leaves.next(), Some((leaf_node, vec![(key_b, val_b), (key_a, val_a2)])));
assert_eq!(leaves.next(), None);
} }
#[test] #[test]

+ 22
- 22
src/merkle/tiered_smt/values.rs

@ -1,4 +1,4 @@
use super::{get_key_prefix, is_leaf_node, BTreeMap, NodeIndex, RpoDigest, StarkField, Vec, Word};
use super::{get_key_prefix, BTreeMap, LeafNodeIndex, RpoDigest, StarkField, Vec, Word};
use crate::utils::vec; use crate::utils::vec;
use core::{ use core::{
cmp::{Ord, Ordering}, cmp::{Ord, Ordering},
@ -23,7 +23,8 @@ const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;
/// the values are the corresponding key-value pairs (or a list of key-value pairs if more that /// the values are the corresponding key-value pairs (or a list of key-value pairs if more that
/// a single key-value pair shares the same 64-bit prefix). /// a single key-value pair shares the same 64-bit prefix).
/// ///
/// The store supports lookup by the full key as well as by the 64-bit key prefix.
/// The store supports lookup by the full key (i.e. [RpoDigest]) as well as by the 64-bit key
/// prefix.
#[derive(Debug, Default, Clone, PartialEq, Eq)] #[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ValueStore { pub struct ValueStore {
values: BTreeMap<u64, StoreEntry>, values: BTreeMap<u64, StoreEntry>,
@ -76,26 +77,29 @@ impl ValueStore {
/// ///
/// This method assumes that the key-value pair for the specified index has already been /// This method assumes that the key-value pair for the specified index has already been
/// removed from the store. /// removed from the store.
pub fn get_lone_sibling(&self, index: NodeIndex) -> Option<(&RpoDigest, &Word, NodeIndex)> {
debug_assert!(is_leaf_node(&index));
pub fn get_lone_sibling(
&self,
index: LeafNodeIndex,
) -> Option<(&RpoDigest, &Word, LeafNodeIndex)> {
// iterate over tiers from top to bottom, looking at the tiers which are strictly above // iterate over tiers from top to bottom, looking at the tiers which are strictly above
// the depth of the index. This implies that only tiers at depth 32 and 48 will be // the depth of the index. This implies that only tiers at depth 32 and 48 will be
// considered. For each tier, check if the parent of the index at the higher tier // considered. For each tier, check if the parent of the index at the higher tier
// contains a single node.
for &tier in TIER_DEPTHS.iter().filter(|&t| index.depth() > *t) {
// contains a single node. The fist tier (depth 16) is excluded because we cannot move
// nodes at depth 16 to a higher tier. This implies that nodes at the first tier will
// never have "lone siblings".
for &tier_depth in TIER_DEPTHS.iter().filter(|&t| index.depth() > *t) {
// compute the index of the root at a higher tier // compute the index of the root at a higher tier
let mut parent_index = index; let mut parent_index = index;
parent_index.move_up_to(tier);
parent_index.move_up_to(tier_depth);
// find the lone sibling, if any; we need to handle the "last node" at a given tier // find the lone sibling, if any; we need to handle the "last node" at a given tier
// separately specify the bounds for the search correctly. // separately specify the bounds for the search correctly.
let start_prefix = parent_index.value() << (MAX_DEPTH - tier);
let sibling = if start_prefix.leading_ones() as u8 == tier {
let start_prefix = parent_index.value() << (MAX_DEPTH - tier_depth);
let sibling = if start_prefix.leading_ones() as u8 == tier_depth {
let mut iter = self.range(start_prefix..); let mut iter = self.range(start_prefix..);
iter.next().filter(|_| iter.next().is_none()) iter.next().filter(|_| iter.next().is_none())
} else { } else {
let end_prefix = (parent_index.value() + 1) << (MAX_DEPTH - tier);
let end_prefix = (parent_index.value() + 1) << (MAX_DEPTH - tier_depth);
let mut iter = self.range(start_prefix..end_prefix); let mut iter = self.range(start_prefix..end_prefix);
iter.next().filter(|_| iter.next().is_none()) iter.next().filter(|_| iter.next().is_none())
}; };
@ -346,12 +350,8 @@ fn cmp_digests(d1: &RpoDigest, d2: &RpoDigest) -> Ordering {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{RpoDigest, ValueStore};
use crate::{
merkle::{tiered_smt::values::StoreEntry, NodeIndex},
Felt, ONE, WORD_SIZE, ZERO,
};
use super::{LeafNodeIndex, RpoDigest, StoreEntry, ValueStore};
use crate::{Felt, ONE, WORD_SIZE, ZERO};
#[test] #[test]
fn test_insert() { fn test_insert() {
@ -569,17 +569,17 @@ mod tests {
store.insert(key_b, value_b); store.insert(key_b, value_b);
// check sibling node for `a` // check sibling node for `a`
let index = NodeIndex::make(32, 0b_10101010_10101010_00011111_11111110);
let parent_index = NodeIndex::make(16, 0b_10101010_10101010);
let index = LeafNodeIndex::make(32, 0b_10101010_10101010_00011111_11111110);
let parent_index = LeafNodeIndex::make(16, 0b_10101010_10101010);
assert_eq!(store.get_lone_sibling(index), Some((&key_a, &value_a, parent_index))); assert_eq!(store.get_lone_sibling(index), Some((&key_a, &value_a, parent_index)));
// check sibling node for `b` // check sibling node for `b`
let index = NodeIndex::make(32, 0b_11111111_11111111_00011111_11111111);
let parent_index = NodeIndex::make(16, 0b_11111111_11111111);
let index = LeafNodeIndex::make(32, 0b_11111111_11111111_00011111_11111111);
let parent_index = LeafNodeIndex::make(16, 0b_11111111_11111111);
assert_eq!(store.get_lone_sibling(index), Some((&key_b, &value_b, parent_index))); assert_eq!(store.get_lone_sibling(index), Some((&key_b, &value_b, parent_index)));
// check some other sibling for some other index // check some other sibling for some other index
let index = NodeIndex::make(32, 0b_11101010_10101010);
let index = LeafNodeIndex::make(32, 0b_11101010_10101010);
assert_eq!(store.get_lone_sibling(index), None); assert_eq!(store.get_lone_sibling(index), None);
} }
} }

Loading…
Cancel
Save