diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs
index d894e9f..6c666c0 100644
--- a/src/merkle/mod.rs
+++ b/src/merkle/mod.rs
@@ -27,6 +27,9 @@ pub use path_set::MerklePathSet;
 mod simple_smt;
 pub use simple_smt::SimpleSmt;
 
+mod tiered_smt;
+pub use tiered_smt::TieredSmt;
+
 mod mmr;
 pub use mmr::{Mmr, MmrPeaks, MmrProof};
 
diff --git a/src/merkle/tiered_smt/mod.rs b/src/merkle/tiered_smt/mod.rs
new file mode 100644
index 0000000..a8c9255
--- /dev/null
+++ b/src/merkle/tiered_smt/mod.rs
@@ -0,0 +1,276 @@
+use super::{
+    BTreeMap, EmptySubtreeRoots, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, StarkField,
+    Vec, Word, EMPTY_WORD,
+};
+
+#[cfg(test)]
+mod tests;
+
+// TIERED SPARSE MERKLE TREE
+// ================================================================================================
+
+/// A sparse Merkle tree of depth 64 whose key-value pairs live at depth 16, 32, or 48.
+///
+/// A pair is placed at the shallowest of these depths at which the most significant bits of the
+/// last element of its key do not collide with the key of another leaf. The bottom tier
+/// (depth 64) is not implemented yet.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct TieredSmt {
+    /// Current root of the tree.
+    root: RpoDigest,
+    /// Hashes of the nodes written so far; missing entries are roots of empty subtrees.
+    nodes: BTreeMap<NodeIndex, RpoDigest>,
+    /// Keys of the leaves located at depths 16, 32, and 48, by leaf index.
+    upper_leaves: BTreeMap<NodeIndex, RpoDigest>,
+    /// Reserved for the bottom tier; not used yet.
+    // NOTE(review): generic arguments were reconstructed from usage - confirm this field.
+    bottom_leaves: BTreeMap<u64, Vec<RpoDigest>>,
+    /// Values of all inserted keys.
+    values: BTreeMap<RpoDigest, Word>,
+}
+
+impl TieredSmt {
+    // CONSTANTS
+    // --------------------------------------------------------------------------------------------
+
+    /// Maximum depth of the tree.
+    const MAX_DEPTH: u8 = 64;
+
+    // CONSTRUCTORS
+    // --------------------------------------------------------------------------------------------
+
+    /// Returns a new, empty [TieredSmt].
+    pub fn new() -> Self {
+        Self {
+            root: EmptySubtreeRoots::empty_hashes(Self::MAX_DEPTH)[0],
+            nodes: BTreeMap::new(),
+            upper_leaves: BTreeMap::new(),
+            bottom_leaves: BTreeMap::new(),
+            values: BTreeMap::new(),
+        }
+    }
+
+    // PUBLIC ACCESSORS
+    // --------------------------------------------------------------------------------------------
+
+    /// Returns the root of this tree.
+    pub const fn root(&self) -> RpoDigest {
+        self.root
+    }
+
+    /// Returns the node at the specified index.
+    ///
+    /// Nodes which were never written resolve to the root of an empty subtree at that depth.
+    ///
+    /// # Errors
+    /// Returns an error if the index is the root index or its depth is greater than 64.
+    pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
+        if index.is_root() {
+            return Err(MerkleError::DepthTooSmall(index.depth()));
+        } else if index.depth() > Self::MAX_DEPTH {
+            return Err(MerkleError::DepthTooBig(index.depth() as u64));
+        } else if !self.is_node_available(index) {
+            todo!()
+        }
+
+        Ok(self.get_branch_node(&index))
+    }
+
+    /// Returns the Merkle path from the node at the specified index to the root.
+    ///
+    /// The path lists sibling hashes, starting with the sibling of the specified node.
+    ///
+    /// # Errors
+    /// Returns an error if the index is the root index or its depth is greater than 64.
+    pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
+        if index.is_root() {
+            return Err(MerkleError::DepthTooSmall(index.depth()));
+        } else if index.depth() > Self::MAX_DEPTH {
+            return Err(MerkleError::DepthTooBig(index.depth() as u64));
+        } else if !self.is_node_available(index) {
+            todo!()
+        }
+
+        let mut path = Vec::with_capacity(index.depth() as usize);
+        for _ in 0..index.depth() {
+            let node = self.get_branch_node(&index.sibling());
+            path.push(node.into());
+            index.move_up();
+        }
+        Ok(path.into())
+    }
+
+    /// Returns the value associated with the specified key, or [EMPTY_WORD] if the key has not
+    /// been inserted.
+    pub fn get_value(&self, key: RpoDigest) -> Result<Word, MerkleError> {
+        Ok(self.values.get(&key).copied().unwrap_or(EMPTY_WORD))
+    }
+
+    // STATE MUTATORS
+    // --------------------------------------------------------------------------------------------
+
+    /// Inserts the specified value under the specified key and returns the value previously
+    /// associated with the key, or [EMPTY_WORD] if there was none.
+    ///
+    /// If the slot selected for the key is occupied by a leaf with a different key, that leaf is
+    /// first moved down to the tier at which the two keys diverge.
+    ///
+    /// # Panics
+    /// Panics if the insertion requires the bottom tier, which is not implemented yet.
+    pub fn insert(&mut self, key: RpoDigest, value: Word) -> Result<Word, MerkleError> {
+        let (mut index, leaf_key) = self.get_insert_location(&key);
+
+        if let Some(other_key) = leaf_key {
+            if other_key != key {
+                let common_prefix_len = get_common_prefix_length(&key, &other_key);
+                let depth = common_prefix_len + 16;
+
+                let other_index = key_to_index(&other_key, depth);
+                self.move_leaf_node(other_key, index, other_index);
+
+                index = key_to_index(&key, depth);
+            }
+        }
+
+        let old_value = self.values.insert(key, value).unwrap_or(EMPTY_WORD);
+        if value != old_value {
+            self.upper_leaves.insert(index, key);
+            let new_node = build_leaf_node(key, value, index.depth().into());
+            self.root = self.update_path(index, new_node);
+        }
+
+        Ok(old_value)
+    }
+
+    // HELPER METHODS
+    // --------------------------------------------------------------------------------------------
+
+    /// Returns true if the node at the specified index can be retrieved from this tree.
+    ///
+    /// Currently a placeholder which reports every node as available.
+    fn is_node_available(&self, index: NodeIndex) -> bool {
+        match index.depth() {
+            32 => true,
+            48 => true,
+            _ => true,
+        }
+    }
+
+    /// Returns the hash stored for the specified index, or the root of an empty subtree at the
+    /// index's depth if nothing is stored.
+    fn get_branch_node(&self, index: &NodeIndex) -> RpoDigest {
+        match self.nodes.get(index) {
+            Some(node) => *node,
+            None => EmptySubtreeRoots::empty_hashes(Self::MAX_DEPTH)[index.depth() as usize],
+        }
+    }
+
+    /// Returns the index at which the specified key should be inserted, together with the key
+    /// of the leaf currently located at that index, if any.
+    ///
+    /// # Panics
+    /// Panics if the key cannot be placed at depth 16, 32, or 48.
+    fn get_insert_location(&self, key: &RpoDigest) -> (NodeIndex, Option<RpoDigest>) {
+        let mse = Word::from(key)[3].as_int();
+        for depth in (16..64).step_by(16) {
+            let index = NodeIndex::new(depth, mse >> (Self::MAX_DEPTH - depth)).unwrap();
+            if let Some(leaf_key) = self.upper_leaves.get(&index) {
+                return (index, Some(*leaf_key));
+            } else if self.nodes.contains_key(&index) {
+                continue;
+            } else {
+                return (index, None);
+            }
+        }
+
+        // TODO: handle bottom tier
+        unimplemented!()
+    }
+
+    /// Moves the leaf with the specified key from `old_index` to `new_index` and recomputes the
+    /// nodes on the path from the new location to the root.
+    ///
+    /// # Panics
+    /// Panics if there is no leaf at `old_index` or no value is stored for the key.
+    fn move_leaf_node(&mut self, key: RpoDigest, old_index: NodeIndex, new_index: NodeIndex) {
+        self.upper_leaves.remove(&old_index).unwrap();
+        self.upper_leaves.insert(new_index, key);
+        let value = *self.values.get(&key).unwrap();
+        let new_node = build_leaf_node(key, value, new_index.depth().into());
+        // keep the root in sync with the nodes even if the caller does not touch the tree
+        // afterwards (e.g., when an empty value is inserted under a new key)
+        self.root = self.update_path(new_index, new_node);
+    }
+
+    /// Stores the specified node at the specified index, recomputes all of its ancestors, and
+    /// returns the resulting root.
+    fn update_path(&mut self, mut index: NodeIndex, mut node: RpoDigest) -> RpoDigest {
+        for _ in 0..index.depth() {
+            self.nodes.insert(index, node);
+            let sibling = self.get_branch_node(&index.sibling());
+            node = Rpo256::merge(&index.build_node(node, sibling));
+            index.move_up();
+        }
+        node
+    }
+}
+
+impl Default for TieredSmt {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+// HELPER FUNCTIONS
+// ================================================================================================
+
+/// Returns the specified key with the `depth` most significant bits of its last element cleared.
+fn get_remaining_path(key: RpoDigest, depth: u32) -> RpoDigest {
+    let mut key = Word::from(key);
+    let remaining = (key[3].as_int() << depth) >> depth;
+    key[3] = remaining.into();
+    key.into()
+}
+
+/// Computes the hash of a leaf located at the specified depth by merging the remaining key path
+/// with the value in the domain defined by the depth.
+fn build_leaf_node(key: RpoDigest, value: Word, depth: u32) -> RpoDigest {
+    let remaining_path = get_remaining_path(key, depth);
+    Rpo256::merge_in_domain(&[remaining_path, value.into()], depth.into())
+}
+
+/// Returns the length of the common prefix of the last elements of the two keys, rounded down
+/// to a multiple of 16 bits.
+fn get_common_prefix_length(key1: &RpoDigest, key2: &RpoDigest) -> u8 {
+    let e1 = Word::from(key1)[3].as_int();
+    let e2 = Word::from(key2)[3].as_int();
+
+    if e1 == e2 {
+        64
+    } else if e1 >> 16 == e2 >> 16 {
+        48
+    } else if e1 >> 32 == e2 >> 32 {
+        32
+    } else if e1 >> 48 == e2 >> 48 {
+        16
+    } else {
+        0
+    }
+}
+
+/// Returns the index of the node at the specified depth on the path of the specified key, i.e.,
+/// the `depth` most significant bits of the key's last element.
+///
+/// # Panics
+/// Panics if the depth is not 16, 32, or 48.
+fn key_to_index(key: &RpoDigest, depth: u8) -> NodeIndex {
+    let mse = Word::from(key)[3].as_int();
+    let value = match depth {
+        // keep the `depth` most significant bits, consistent with `get_insert_location()`
+        16 | 32 | 48 => mse >> ((TieredSmt::MAX_DEPTH - depth) as u32),
+        _ => unreachable!("invalid depth: {depth}"),
+    };
+
+    // TODO: use unchecked version?
+    NodeIndex::new(depth, value).unwrap()
+}
diff --git a/src/merkle/tiered_smt/tests.rs b/src/merkle/tiered_smt/tests.rs
new file mode 100644
index 0000000..121871d
--- /dev/null
+++ b/src/merkle/tiered_smt/tests.rs
@@ -0,0 +1,225 @@
+use super::{
+    super::{super::ONE, Felt, MerkleStore, WORD_SIZE},
+    get_remaining_path, EmptySubtreeRoots, NodeIndex, Rpo256, RpoDigest, TieredSmt, Word,
+};
+
+#[test]
+fn tsmt_insert_one() {
+    let mut smt = TieredSmt::new();
+    let mut store = MerkleStore::default();
+
+    let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
+    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
+    let value = [ONE; WORD_SIZE];
+
+    // since the tree is empty, the first node will be inserted at depth 16 and the index will be
+    // 16 most significant bits of the key
+    let index = NodeIndex::make(16, raw >> 48);
+    let leaf_node = compute_leaf_node(key, value, 16);
+    let tree_root = store.set_node(smt.root().into(), index, leaf_node.into()).unwrap().root;
+
+    smt.insert(key, value).unwrap();
+
+    assert_eq!(smt.root(), tree_root.into());
+
+    // make sure the value was inserted, and the node is at the expected index
+    assert_eq!(smt.get_value(key).unwrap(), value);
+    assert_eq!(smt.get_node(index).unwrap(), leaf_node);
+
+    // make sure the paths we get from the store and the tree match
+    let expected_path = store.get_path(tree_root, index).unwrap();
+    assert_eq!(smt.get_path(index).unwrap(), expected_path.path);
+}
+
+#[test]
+fn tsmt_insert_two() {
+    let mut smt = TieredSmt::new();
+    let mut store = MerkleStore::default();
+
+    // --- insert the first value ---------------------------------------------
+    let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
+    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
+    let val_a = [ONE; WORD_SIZE];
+    smt.insert(key_a, val_a).unwrap();
+
+    // --- insert the second value --------------------------------------------
+    // the key for this value has the same 16-bit prefix as the key for the first value,
+    // thus, on insertions, both values should be pushed to depth 32 tier
+    let raw_b = 0b_10101010_10101010_10011111_11111111_10010110_10010011_11100000_00000000_u64;
+    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
+    let val_b = [Felt::new(2); WORD_SIZE];
+    smt.insert(key_b, val_b).unwrap();
+
+    // --- build Merkle store with equivalent data ----------------------------
+    let mut tree_root = get_init_root();
+    let index_a = NodeIndex::make(32, raw_a >> 32);
+    let leaf_node_a = compute_leaf_node(key_a, val_a, 32);
+    tree_root = store.set_node(tree_root, index_a, leaf_node_a.into()).unwrap().root;
+
+    let index_b = NodeIndex::make(32, raw_b >> 32);
+    let leaf_node_b = compute_leaf_node(key_b, val_b, 32);
+    tree_root = store.set_node(tree_root, index_b, leaf_node_b.into()).unwrap().root;
+
+    // --- verify that data is consistent between store and tree --------------
+
+    assert_eq!(smt.root(), tree_root.into());
+
+    assert_eq!(smt.get_value(key_a).unwrap(), val_a);
+    assert_eq!(smt.get_node(index_a).unwrap(), leaf_node_a);
+    let expected_path = store.get_path(tree_root, index_a).unwrap().path;
+    assert_eq!(smt.get_path(index_a).unwrap(), expected_path);
+
+    assert_eq!(smt.get_value(key_b).unwrap(), val_b);
+    assert_eq!(smt.get_node(index_b).unwrap(), leaf_node_b);
+    let expected_path = store.get_path(tree_root, index_b).unwrap().path;
+    assert_eq!(smt.get_path(index_b).unwrap(), expected_path);
+}
+
+#[test]
+fn tsmt_insert_two_32() {
+    let mut smt = TieredSmt::new();
+    let mut store = MerkleStore::default();
+
+    // --- insert the first value ---------------------------------------------
+    let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
+    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
+    let val_a = [ONE; WORD_SIZE];
+    smt.insert(key_a, val_a).unwrap();
+
+    // --- insert the second value --------------------------------------------
+    // the key for this value has the same 32-bit prefix as the key for the first value,
+    // thus, on insertion, both values should be pushed to depth 48 tier
+    let raw_b = 0b_10101010_10101010_00011111_11111111_00010110_10010011_11100000_00000000_u64;
+    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
+    let val_b = [Felt::new(2); WORD_SIZE];
+    smt.insert(key_b, val_b).unwrap();
+
+    // --- build Merkle store with equivalent data ----------------------------
+    let mut tree_root = get_init_root();
+    let index_a = NodeIndex::make(48, raw_a >> 16);
+    let leaf_node_a = compute_leaf_node(key_a, val_a, 48);
+    tree_root = store.set_node(tree_root, index_a, leaf_node_a.into()).unwrap().root;
+
+    let index_b = NodeIndex::make(48, raw_b >> 16);
+    let leaf_node_b = compute_leaf_node(key_b, val_b, 48);
+    tree_root = store.set_node(tree_root, index_b, leaf_node_b.into()).unwrap().root;
+
+    // --- verify that data is consistent between store and tree --------------
+
+    assert_eq!(smt.root(), tree_root.into());
+
+    assert_eq!(smt.get_value(key_a).unwrap(), val_a);
+    assert_eq!(smt.get_node(index_a).unwrap(), leaf_node_a);
+    let expected_path = store.get_path(tree_root, index_a).unwrap().path;
+    assert_eq!(smt.get_path(index_a).unwrap(), expected_path);
+
+    assert_eq!(smt.get_value(key_b).unwrap(), val_b);
+    assert_eq!(smt.get_node(index_b).unwrap(), leaf_node_b);
+    let expected_path = store.get_path(tree_root, index_b).unwrap().path;
+    assert_eq!(smt.get_path(index_b).unwrap(), expected_path);
+}
+
+#[test]
+fn tsmt_insert_three() {
+    let mut smt = TieredSmt::new();
+    let mut store = MerkleStore::default();
+
+    // --- insert the first value ---------------------------------------------
+    let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
+    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
+    let val_a = [ONE; WORD_SIZE];
+    smt.insert(key_a, val_a).unwrap();
+
+    // --- insert the second value --------------------------------------------
+    // the key for this value has the same 16-bit prefix as the key for the first value,
+    // thus, on insertions, both values should be pushed to depth 32 tier
+    let raw_b = 0b_10101010_10101010_10011111_11111111_10010110_10010011_11100000_00000000_u64;
+    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
+    let val_b = [Felt::new(2); WORD_SIZE];
+    smt.insert(key_b, val_b).unwrap();
+
+    // --- insert the third value ---------------------------------------------
+    // the key for this value has the same 16-bit prefix as the keys for the first two
+    // values; thus, on insertions, it will be inserted into depth 32 tier, but will not
+    // affect locations of the other two values
+    let raw_c = 0b_10101010_10101010_11011111_11111111_10010110_10010011_11100000_00000000_u64;
+    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
+    let val_c = [Felt::new(3); WORD_SIZE];
+    smt.insert(key_c, val_c).unwrap();
+
+    // --- build Merkle store with equivalent data ----------------------------
+    let mut tree_root = get_init_root();
+    let index_a = NodeIndex::make(32, raw_a >> 32);
+    let leaf_node_a = compute_leaf_node(key_a, val_a, 32);
+    tree_root = store.set_node(tree_root, index_a, leaf_node_a.into()).unwrap().root;
+
+    let index_b = NodeIndex::make(32, raw_b >> 32);
+    let leaf_node_b = compute_leaf_node(key_b, val_b, 32);
+    tree_root = store.set_node(tree_root, index_b, leaf_node_b.into()).unwrap().root;
+
+    let index_c = NodeIndex::make(32, raw_c >> 32);
+    let leaf_node_c = compute_leaf_node(key_c, val_c, 32);
+    tree_root = store.set_node(tree_root, index_c, leaf_node_c.into()).unwrap().root;
+
+    // --- verify that data is consistent between store and tree --------------
+
+    assert_eq!(smt.root(), tree_root.into());
+
+    assert_eq!(smt.get_value(key_a).unwrap(), val_a);
+    assert_eq!(smt.get_node(index_a).unwrap(), leaf_node_a);
+    let expected_path = store.get_path(tree_root, index_a).unwrap().path;
+    assert_eq!(smt.get_path(index_a).unwrap(), expected_path);
+
+    assert_eq!(smt.get_value(key_b).unwrap(), val_b);
+    assert_eq!(smt.get_node(index_b).unwrap(), leaf_node_b);
+    let expected_path = store.get_path(tree_root, index_b).unwrap().path;
+    assert_eq!(smt.get_path(index_b).unwrap(), expected_path);
+
+    assert_eq!(smt.get_value(key_c).unwrap(), val_c);
+    assert_eq!(smt.get_node(index_c).unwrap(), leaf_node_c);
+    let expected_path = store.get_path(tree_root, index_c).unwrap().path;
+    assert_eq!(smt.get_path(index_c).unwrap(), expected_path);
+}
+
+#[test]
+fn tsmt_update() {
+    let mut smt = TieredSmt::new();
+    let mut store = MerkleStore::default();
+
+    // --- insert a value into the tree ---------------------------------------
+    let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
+    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
+    let value_a = [ONE; WORD_SIZE];
+    smt.insert(key, value_a).unwrap();
+
+    // --- update value ---------------------------------------
+    let value_b = [Felt::new(2); WORD_SIZE];
+    smt.insert(key, value_b).unwrap();
+
+    // --- verify consistency -------------------------------------------------
+    let mut tree_root = get_init_root();
+    let index = NodeIndex::make(16, raw >> 48);
+    let leaf_node = compute_leaf_node(key, value_b, 16);
+    tree_root = store.set_node(tree_root, index, leaf_node.into()).unwrap().root;
+
+    assert_eq!(smt.root(), tree_root.into());
+
+    assert_eq!(smt.get_value(key).unwrap(), value_b);
+    assert_eq!(smt.get_node(index).unwrap(), leaf_node);
+    let expected_path = store.get_path(tree_root, index).unwrap().path;
+    assert_eq!(smt.get_path(index).unwrap(), expected_path);
+}
+
+// HELPER FUNCTIONS
+// ================================================================================================
+
+/// Returns the root of an empty tree of depth 64.
+fn get_init_root() -> Word {
+    EmptySubtreeRoots::empty_hashes(64)[0].into()
+}
+
+/// Computes the expected hash of a leaf with the specified key and value at the specified depth.
+fn compute_leaf_node(key: RpoDigest, value: Word, depth: u8) -> RpoDigest {
+    let remaining_path = get_remaining_path(key, depth as u32);
+    Rpo256::merge_in_domain(&[remaining_path, value.into()], depth.into())
+}