diff --git a/CHANGELOG.md b/CHANGELOG.md index 76c65a7..a3b6c21 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ - [BREAKING]: renamed `Mmr::open()` into `Mmr::open_at()` and `Mmr::peaks()` into `Mmr::peaks_at()` (#234). - Added `Mmr::open()` and `Mmr::peaks()` which rely on `Mmr::open_at()` and `Mmr::peaks()` respectively (#234). - Standardised CI and Makefile across Miden repos (#323). +- Added `Smt::compute_mutations()` and `Smt::apply_mutations()` for validation-checked insertions (#327). ## 0.10.0 (2024-08-06) diff --git a/src/main.rs b/src/main.rs index ee6e86c..776ccc2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -35,6 +35,7 @@ pub fn benchmark_smt() { let mut tree = construction(entries, tree_size).unwrap(); insertion(&mut tree, tree_size).unwrap(); + batched_insertion(&mut tree, tree_size).unwrap(); proof_generation(&mut tree, tree_size).unwrap(); } @@ -82,6 +83,54 @@ pub fn insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> { Ok(()) } +pub fn batched_insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> { + println!("Running a batched insertion benchmark:"); + + let new_pairs: Vec<(RpoDigest, Word)> = (0..1000) + .map(|i| { + let key = Rpo256::hash(&rand_value::().to_be_bytes()); + let value = [ONE, ONE, ONE, Felt::new(size + i)]; + (key, value) + }) + .collect(); + + let now = Instant::now(); + let mutations = tree.compute_mutations(new_pairs); + let compute_elapsed = now.elapsed(); + + let now = Instant::now(); + tree.apply_mutations(mutations).unwrap(); + let apply_elapsed = now.elapsed(); + + println!( + "An average batch computation time measured by a 1k-batch into an SMT with {} key-value pairs over {:.3} milliseconds is {:.3} milliseconds", + size, + compute_elapsed.as_secs_f32() * 1000f32, + // Dividing by the number of iterations, 1000, and then multiplying by 1000 to get + // milliseconds, cancels out. + compute_elapsed.as_secs_f32(), + ); + + println!( + "An average batch application time measured by a 1k-batch into an SMT with {} key-value pairs over {:.3} milliseconds is {:.3} milliseconds", + size, + apply_elapsed.as_secs_f32() * 1000f32, + // Dividing by the number of iterations, 1000, and then multiplying by 1000 to get + // milliseconds, cancels out. + apply_elapsed.as_secs_f32(), + ); + + println!( + "An average batch insertion time measured by a 1k-batch into an SMT with {} key-value pairs totals to {:.3} milliseconds", + size, + (compute_elapsed + apply_elapsed).as_secs_f32() * 1000f32, + ); + + println!(); + + Ok(()) +} + /// Runs the proof generation benchmark for the [`Smt`]. pub fn proof_generation(tree: &mut Smt, size: u64) -> Result<(), MerkleError> { println!("Running a proof generation benchmark:"); diff --git a/src/merkle/empty_roots.rs b/src/merkle/empty_roots.rs index 30dd41a..1f54a7a 100644 --- a/src/merkle/empty_roots.rs +++ b/src/merkle/empty_roots.rs @@ -1,6 +1,6 @@ use core::slice; -use super::{Felt, RpoDigest, EMPTY_WORD}; +use super::{smt::InnerNode, Felt, RpoDigest, EMPTY_WORD}; // EMPTY NODES SUBTREES // ================================================================================================ @@ -25,6 +25,17 @@ impl EmptySubtreeRoots { let pos = 255 - tree_depth + node_depth; &EMPTY_SUBTREES[pos as usize] } + + /// Returns a sparse Merkle tree [`InnerNode`] with two empty children. + /// + /// # Note + /// `node_depth` is the depth of the **parent** to have empty children. That is, `node_depth` + /// and the depth of the returned [`InnerNode`] are the same, and thus the empty hashes are for + /// subtrees of depth `node_depth + 1`. + pub(crate) const fn get_inner_node(tree_depth: u8, node_depth: u8) -> InnerNode { + let &child = Self::entry(tree_depth, node_depth + 1); + InnerNode { left: child, right: child } + } } const EMPTY_SUBTREES: [RpoDigest; 256] = [ diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index 8954d4d..a562aa5 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -22,8 +22,8 @@ pub use path::{MerklePath, RootPath, ValuePath}; mod smt; pub use smt::{ - LeafIndex, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH, - SMT_MAX_DEPTH, SMT_MIN_DEPTH, + LeafIndex, MutationSet, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, + SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH, }; mod mmr; diff --git a/src/merkle/smt/full/leaf.rs b/src/merkle/smt/full/leaf.rs index 095a4fb..585fc40 100644 --- a/src/merkle/smt/full/leaf.rs +++ b/src/merkle/smt/full/leaf.rs @@ -350,7 +350,7 @@ impl Deserializable for SmtLeaf { // ================================================================================================ /// Converts a key-value tuple to an iterator of `Felt`s -fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator { +pub(crate) fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator { let key_elements = key.into_iter(); let value_elements = value.into_iter(); @@ -359,7 +359,7 @@ fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator /// Compares two keys, compared element-by-element using their integer representations starting with /// the most significant element. -fn cmp_keys(key_1: RpoDigest, key_2: RpoDigest) -> Ordering { +pub(crate) fn cmp_keys(key_1: RpoDigest, key_2: RpoDigest) -> Ordering { for (v1, v2) in key_1.iter().zip(key_2.iter()).rev() { let v1 = v1.as_int(); let v2 = v2.as_int(); diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 5cd510b..9c64002 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -6,7 +6,7 @@ use alloc::{ use super::{ EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, MerklePath, - NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, + MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, }; mod error; @@ -121,12 +121,7 @@ impl Smt { /// Returns the value associated with `key` pub fn get_value(&self, key: &RpoDigest) -> Word { - let leaf_pos = LeafIndex::::from(*key).value(); - - match self.leaves.get(&leaf_pos) { - Some(leaf) => leaf.get_value(key).unwrap_or_default(), - None => EMPTY_WORD, - } + >::get_value(self, key) } /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle @@ -172,6 +167,47 @@ impl Smt { >::insert(self, key, value) } + /// Computes what changes are necessary to insert the specified key-value pairs into this Merkle + /// tree, allowing for validation before applying those changes. + /// + /// This method returns a [`MutationSet`], which contains all the information for inserting + /// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can + /// be queried with [`MutationSet::root()`]. Once a mutation set is returned, + /// [`Smt::apply_mutations()`] can be called in order to commit these changes to the Merkle + /// tree, or [`drop()`] to discard them. + /// + /// # Example + /// ``` + /// # use miden_crypto::{hash::rpo::RpoDigest, Felt, Word}; + /// # use miden_crypto::merkle::{Smt, EmptySubtreeRoots, SMT_DEPTH}; + /// let mut smt = Smt::new(); + /// let pair = (RpoDigest::default(), Word::default()); + /// let mutations = smt.compute_mutations(vec![pair]); + /// assert_eq!(mutations.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0)); + /// smt.apply_mutations(mutations); + /// assert_eq!(smt.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0)); + /// ``` + pub fn compute_mutations( + &self, + kv_pairs: impl IntoIterator, + ) -> MutationSet { + >::compute_mutations(self, kv_pairs) + } + + /// Apply the prospective mutations computed with [`Smt::compute_mutations()`] to this tree. + /// + /// # Errors + /// If `mutations` was computed on a tree with a different root than this one, returns + /// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash + /// the `mutations` were computed against, and the second item is the actual current root of + /// this tree. + pub fn apply_mutations( + &mut self, + mutations: MutationSet, + ) -> Result<(), MerkleError> { + >::apply_mutations(self, mutations) + } + // HELPERS // -------------------------------------------------------------------------------------------- @@ -226,11 +262,10 @@ impl SparseMerkleTree for Smt { } fn get_inner_node(&self, index: NodeIndex) -> InnerNode { - self.inner_nodes.get(&index).cloned().unwrap_or_else(|| { - let node = EmptySubtreeRoots::entry(SMT_DEPTH, index.depth() + 1); - - InnerNode { left: *node, right: *node } - }) + self.inner_nodes + .get(&index) + .cloned() + .unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth())) } fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) { @@ -250,6 +285,15 @@ impl SparseMerkleTree for Smt { } } + fn get_value(&self, key: &Self::Key) -> Self::Value { + let leaf_pos = LeafIndex::::from(*key).value(); + + match self.leaves.get(&leaf_pos) { + Some(leaf) => leaf.get_value(key).unwrap_or_default(), + None => EMPTY_WORD, + } + } + fn get_leaf(&self, key: &RpoDigest) -> Self::Leaf { let leaf_pos = LeafIndex::::from(*key).value(); @@ -263,6 +307,28 @@ impl SparseMerkleTree for Smt { leaf.hash() } + fn construct_prospective_leaf( + &self, + mut existing_leaf: SmtLeaf, + key: &RpoDigest, + value: &Word, + ) -> SmtLeaf { + debug_assert_eq!(existing_leaf.index(), Self::key_to_leaf_index(key)); + + match existing_leaf { + SmtLeaf::Empty(_) => SmtLeaf::new_single(*key, *value), + _ => { + if *value != EMPTY_WORD { + existing_leaf.insert(*key, *value); + } else { + existing_leaf.remove(*key); + } + + existing_leaf + }, + } + } + fn key_to_leaf_index(key: &RpoDigest) -> LeafIndex { let most_significant_felt = key[3]; LeafIndex::new_max_depth(most_significant_felt.as_int()) diff --git a/src/merkle/smt/full/tests.rs b/src/merkle/smt/full/tests.rs index 66cd203..1613c8f 100644 --- a/src/merkle/smt/full/tests.rs +++ b/src/merkle/smt/full/tests.rs @@ -2,7 +2,7 @@ use alloc::vec::Vec; use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH}; use crate::{ - merkle::{EmptySubtreeRoots, MerkleStore}, + merkle::{smt::SparseMerkleTree, EmptySubtreeRoots, MerkleStore}, utils::{Deserializable, Serializable}, Word, ONE, WORD_SIZE, }; @@ -258,6 +258,195 @@ fn test_smt_removal() { } } +/// This tests that we can correctly calculate prospective leaves -- that is, we can construct +/// correct [`SmtLeaf`] values for a theoretical insertion on a Merkle tree without mutating or +/// cloning the tree. +#[test] +fn test_prospective_hash() { + let mut smt = Smt::default(); + + let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64; + + let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]); + let key_2: RpoDigest = + RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]); + // Sort key_3 before key_1, to test non-append insertion. + let key_3: RpoDigest = + RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(raw)]); + + let value_1 = [ONE; WORD_SIZE]; + let value_2 = [2_u32.into(); WORD_SIZE]; + let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE]; + + // insert key-value 1 + { + let prospective = + smt.construct_prospective_leaf(smt.get_leaf(&key_1), &key_1, &value_1).hash(); + smt.insert(key_1, value_1); + + let leaf = smt.get_leaf(&key_1); + assert_eq!( + prospective, + leaf.hash(), + "prospective hash for leaf {leaf:?} did not match actual hash", + ); + } + + // insert key-value 2 + { + let prospective = + smt.construct_prospective_leaf(smt.get_leaf(&key_2), &key_2, &value_2).hash(); + smt.insert(key_2, value_2); + + let leaf = smt.get_leaf(&key_2); + assert_eq!( + prospective, + leaf.hash(), + "prospective hash for leaf {leaf:?} did not match actual hash", + ); + } + + // insert key-value 3 + { + let prospective = + smt.construct_prospective_leaf(smt.get_leaf(&key_3), &key_3, &value_3).hash(); + smt.insert(key_3, value_3); + + let leaf = smt.get_leaf(&key_3); + assert_eq!( + prospective, + leaf.hash(), + "prospective hash for leaf {leaf:?} did not match actual hash", + ); + } + + // remove key 3 + { + let old_leaf = smt.get_leaf(&key_3); + let old_value_3 = smt.insert(key_3, EMPTY_WORD); + assert_eq!(old_value_3, value_3); + let prospective_leaf = + smt.construct_prospective_leaf(smt.get_leaf(&key_3), &key_3, &old_value_3); + + assert_eq!( + old_leaf.hash(), + prospective_leaf.hash(), + "removing and prospectively re-adding a leaf didn't yield the original leaf:\ + \n original leaf: {old_leaf:?}\ + \n prospective leaf: {prospective_leaf:?}", + ); + } + + // remove key 2 + { + let old_leaf = smt.get_leaf(&key_2); + let old_value_2 = smt.insert(key_2, EMPTY_WORD); + assert_eq!(old_value_2, value_2); + let prospective_leaf = + smt.construct_prospective_leaf(smt.get_leaf(&key_2), &key_2, &old_value_2); + + assert_eq!( + old_leaf.hash(), + prospective_leaf.hash(), + "removing and prospectively re-adding a leaf didn't yield the original leaf:\ + \n original leaf: {old_leaf:?}\ + \n prospective leaf: {prospective_leaf:?}", + ); + } + + // remove key 1 + { + let old_leaf = smt.get_leaf(&key_1); + let old_value_1 = smt.insert(key_1, EMPTY_WORD); + assert_eq!(old_value_1, value_1); + let prospective_leaf = + smt.construct_prospective_leaf(smt.get_leaf(&key_1), &key_1, &old_value_1); + assert_eq!( + old_leaf.hash(), + prospective_leaf.hash(), + "removing and prospectively re-adding a leaf didn't yield the original leaf:\ + \n original leaf: {old_leaf:?}\ + \n prospective leaf: {prospective_leaf:?}", + ); + } +} + +/// This tests that we can perform prospective changes correctly. +#[test] +fn test_prospective_insertion() { + let mut smt = Smt::default(); + + let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64; + + let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]); + let key_2: RpoDigest = + RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]); + // Sort key_3 before key_1, to test non-append insertion. + let key_3: RpoDigest = + RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(raw)]); + + let value_1 = [ONE; WORD_SIZE]; + let value_2 = [2_u32.into(); WORD_SIZE]; + let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE]; + + let root_empty = smt.root(); + + let root_1 = { + smt.insert(key_1, value_1); + smt.root() + }; + + let root_2 = { + smt.insert(key_2, value_2); + smt.root() + }; + + let root_3 = { + smt.insert(key_3, value_3); + smt.root() + }; + + // Test incremental updates. + + let mut smt = Smt::default(); + + let mutations = smt.compute_mutations(vec![(key_1, value_1)]); + assert_eq!(mutations.root(), root_1, "prospective root 1 did not match actual root 1"); + smt.apply_mutations(mutations).unwrap(); + assert_eq!(smt.root(), root_1, "mutations before and after apply did not match"); + + let mutations = smt.compute_mutations(vec![(key_2, value_2)]); + assert_eq!(mutations.root(), root_2, "prospective root 2 did not match actual root 2"); + let mutations = + smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_2, value_2), (key_3, value_3)]); + assert_eq!(mutations.root(), root_3, "mutations before and after apply did not match"); + smt.apply_mutations(mutations).unwrap(); + + // Edge case: multiple values at the same key, where a later pair restores the original value. + let mutations = smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_3, value_3)]); + assert_eq!(mutations.root(), root_3); + smt.apply_mutations(mutations).unwrap(); + assert_eq!(smt.root(), root_3); + + // Test batch updates, and that the order doesn't matter. + let pairs = + vec![(key_3, value_2), (key_2, EMPTY_WORD), (key_1, EMPTY_WORD), (key_3, EMPTY_WORD)]; + let mutations = smt.compute_mutations(pairs); + assert_eq!( + mutations.root(), + root_empty, + "prospective root for batch removal did not match actual root", + ); + smt.apply_mutations(mutations).unwrap(); + assert_eq!(smt.root(), root_empty, "mutations before and after apply did not match"); + + let pairs = vec![(key_3, value_3), (key_1, value_1), (key_2, value_2)]; + let mutations = smt.compute_mutations(pairs); + assert_eq!(mutations.root(), root_3); + smt.apply_mutations(mutations).unwrap(); + assert_eq!(smt.root(), root_3); +} + /// Tests that 2 key-value pairs stored in the same leaf have the same path #[test] fn test_smt_path_to_keys_in_same_leaf_are_equal() { diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index d7d42da..0b7ceb9 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -1,4 +1,4 @@ -use alloc::vec::Vec; +use alloc::{collections::BTreeMap, vec::Vec}; use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex}; use crate::{ @@ -45,11 +45,11 @@ pub const SMT_MAX_DEPTH: u8 = 64; /// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs. pub(crate) trait SparseMerkleTree { /// The type for a key - type Key: Clone; + type Key: Clone + Ord; /// The type for a value type Value: Clone + PartialEq; /// The type for a leaf - type Leaf; + type Leaf: Clone; /// The type for an opening (i.e. a "proof") of a leaf type Opening; @@ -140,6 +140,149 @@ pub(crate) trait SparseMerkleTree { self.set_root(node_hash); } + /// Computes what changes are necessary to insert the specified key-value pairs into this Merkle + /// tree, allowing for validation before applying those changes. + /// + /// This method returns a [`MutationSet`], which contains all the information for inserting + /// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can + /// be queried with [`MutationSet::root()`]. Once a mutation set is returned, + /// [`SparseMerkleTree::apply_mutations()`] can be called in order to commit these changes to + /// the Merkle tree, or [`drop()`] to discard them. + fn compute_mutations( + &self, + kv_pairs: impl IntoIterator, + ) -> MutationSet { + use NodeMutation::*; + + let mut new_root = self.root(); + let mut new_pairs: BTreeMap = Default::default(); + let mut node_mutations: BTreeMap = Default::default(); + + for (key, value) in kv_pairs { + // If the old value and the new value are the same, there is nothing to update. + // For the unusual case that kv_pairs has multiple values at the same key, we'll have + // to check the key-value pairs we've already seen to get the "effective" old value. + let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key)); + if value == old_value { + continue; + } + + let leaf_index = Self::key_to_leaf_index(&key); + let mut node_index = NodeIndex::from(leaf_index); + + // We need the current leaf's hash to calculate the new leaf, but in the rare case that + // `kv_pairs` has multiple pairs that go into the same leaf, then those pairs are also + // part of the "current leaf". + let old_leaf = { + let pairs_at_index = new_pairs + .iter() + .filter(|&(new_key, _)| Self::key_to_leaf_index(new_key) == leaf_index); + + pairs_at_index.fold(self.get_leaf(&key), |acc, (k, v)| { + // Most of the time `pairs_at_index` should only contain a single entry (or + // none at all), as multi-leaves should be really rare. + let existing_leaf = acc.clone(); + self.construct_prospective_leaf(existing_leaf, k, v) + }) + }; + + let new_leaf = self.construct_prospective_leaf(old_leaf, &key, &value); + + let mut new_child_hash = Self::hash_leaf(&new_leaf); + + for node_depth in (0..node_index.depth()).rev() { + // Whether the node we're replacing is the right child or the left child. + let is_right = node_index.is_value_odd(); + node_index.move_up(); + + let old_node = node_mutations + .get(&node_index) + .map(|mutation| match mutation { + Addition(node) => node.clone(), + Removal => EmptySubtreeRoots::get_inner_node(DEPTH, node_depth), + }) + .unwrap_or_else(|| self.get_inner_node(node_index)); + + let new_node = if is_right { + InnerNode { + left: old_node.left, + right: new_child_hash, + } + } else { + InnerNode { + left: new_child_hash, + right: old_node.right, + } + }; + + // The next iteration will operate on this new node's hash. + new_child_hash = new_node.hash(); + + let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth); + let is_removal = new_child_hash == equivalent_empty_hash; + let new_entry = if is_removal { Removal } else { Addition(new_node) }; + node_mutations.insert(node_index, new_entry); + } + + // Once we're at depth 0, the last node we made is the new root. + new_root = new_child_hash; + // And then we're done with this pair; on to the next one. + new_pairs.insert(key, value); + } + + MutationSet { + old_root: self.root(), + new_root, + node_mutations, + new_pairs, + } + } + + /// Apply the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to + /// this tree. + /// + /// # Errors + /// If `mutations` was computed on a tree with a different root than this one, returns + /// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash + /// the `mutations` were computed against, and the second item is the actual current root of + /// this tree. + fn apply_mutations( + &mut self, + mutations: MutationSet, + ) -> Result<(), MerkleError> + where + Self: Sized, + { + use NodeMutation::*; + let MutationSet { + old_root, + node_mutations, + new_pairs, + new_root, + } = mutations; + + // Guard against accidentally trying to apply mutations that were computed against a + // different tree, including a stale version of this tree. + if old_root != self.root() { + return Err(MerkleError::ConflictingRoots(vec![old_root, self.root()])); + } + + for (index, mutation) in node_mutations { + match mutation { + Removal => self.remove_inner_node(index), + Addition(node) => self.insert_inner_node(index, node), + } + } + + for (key, value) in new_pairs { + self.insert_value(key, value); + } + + self.set_root(new_root); + + Ok(()) + } + // REQUIRED METHODS // --------------------------------------------------------------------------------------------- @@ -161,12 +304,34 @@ pub(crate) trait SparseMerkleTree { /// Inserts a leaf node, and returns the value at the key if already exists fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option; + /// Returns the value at the specified key. Recall that by definition, any key that hasn't been + /// updated is associated with [`Self::EMPTY_VALUE`]. + fn get_value(&self, key: &Self::Key) -> Self::Value; + /// Returns the leaf at the specified index. fn get_leaf(&self, key: &Self::Key) -> Self::Leaf; /// Returns the hash of a leaf fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest; + /// Returns what a leaf would look like if a key-value pair were inserted into the tree, without + /// mutating the tree itself. The existing leaf can be empty. + /// + /// To get a prospective leaf based on the current state of the tree, use `self.get_leaf(key)` + /// as the argument for `existing_leaf`. The return value from this function can be chained back + /// into this function as the first argument to continue making prospective changes. + /// + /// # Invariants + /// Because this method is for a prospective key-value insertion into a specific leaf, + /// `existing_leaf` must have the same leaf index as `key` (as determined by + /// [`SparseMerkleTree::key_to_leaf_index()`]), or the result will be meaningless. + fn construct_prospective_leaf( + &self, + existing_leaf: Self::Leaf, + key: &Self::Key, + value: &Self::Value, + ) -> Self::Leaf; + /// Maps a key to a leaf index fn key_to_leaf_index(key: &Self::Key) -> LeafIndex; @@ -244,3 +409,50 @@ impl TryFrom for LeafIndex { Self::new(node_index.value()) } } + +// MUTATIONS +// ================================================================================================ + +/// A change to an inner node of a [`SparseMerkleTree`] that hasn't yet been applied. +/// [`MutationSet`] stores this type in relation to a [`NodeIndex`] to keep track of what changes +/// need to occur at which node indices. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum NodeMutation { + /// Corresponds to [`SparseMerkleTree::remove_inner_node()`]. + Removal, + /// Corresponds to [`SparseMerkleTree::insert_inner_node()`]. + Addition(InnerNode), +} + +/// Represents a group of prospective mutations to a `SparseMerkleTree`, created by +/// `SparseMerkleTree::compute_mutations()`, and that can be applied with +/// `SparseMerkleTree::apply_mutations()`. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct MutationSet { + /// The root of the Merkle tree this MutationSet is for, recorded at the time + /// [`SparseMerkleTree::compute_mutations()`] was called. Exists to guard against applying + /// mutations to the wrong tree or applying stale mutations to a tree that has since changed. + old_root: RpoDigest, + /// The set of nodes that need to be removed or added. The "effective" node at an index is the + /// Merkle tree's existing node at that index, with the [`NodeMutation`] in this map at that + /// index overlayed, if any. Each [`NodeMutation::Addition`] corresponds to a + /// [`SparseMerkleTree::insert_inner_node()`] call, and each [`NodeMutation::Removal`] + /// corresponds to a [`SparseMerkleTree::remove_inner_node()`] call. + node_mutations: BTreeMap, + /// The set of top-level key-value pairs we're prospectively adding to the tree, including + /// adding empty values. The "effective" value for a key is the value in this BTreeMap, falling + /// back to the existing value in the Merkle tree. Each entry corresponds to a + /// [`SparseMerkleTree::insert_value()`] call. + new_pairs: BTreeMap, + /// The calculated root for the Merkle tree, given these mutations. Publicly retrievable with + /// [`MutationSet::root()`]. Corresponds to a [`SparseMerkleTree::set_root()`]. call. + new_root: RpoDigest, +} + +impl MutationSet { + /// Queries the root that was calculated during `SparseMerkleTree::compute_mutations()`. See + /// that method for more information. + pub fn root(&self) -> RpoDigest { + self.new_root + } +} diff --git a/src/merkle/smt/simple/mod.rs b/src/merkle/smt/simple/mod.rs index f1ff0dc..1744430 100644 --- a/src/merkle/smt/simple/mod.rs +++ b/src/merkle/smt/simple/mod.rs @@ -2,8 +2,8 @@ use alloc::collections::{BTreeMap, BTreeSet}; use super::{ super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, - MerklePath, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, SMT_MAX_DEPTH, - SMT_MIN_DEPTH, + MerklePath, MutationSet, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, + SMT_MAX_DEPTH, SMT_MIN_DEPTH, }; #[cfg(test)] @@ -188,6 +188,48 @@ impl SimpleSmt { >::insert(self, key, value) } + /// Computes what changes are necessary to insert the specified key-value pairs into this + /// Merkle tree, allowing for validation before applying those changes. + /// + /// This method returns a [`MutationSet`], which contains all the information for inserting + /// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can + /// be queried with [`MutationSet::root()`]. Once a mutation set is returned, + /// [`SimpleSmt::apply_mutations()`] can be called in order to commit these changes to the + /// Merkle tree, or [`drop()`] to discard them. + + /// # Example + /// ``` + /// # use miden_crypto::{hash::rpo::RpoDigest, Felt, Word}; + /// # use miden_crypto::merkle::{LeafIndex, SimpleSmt, EmptySubtreeRoots, SMT_DEPTH}; + /// let mut smt: SimpleSmt<3> = SimpleSmt::new().unwrap(); + /// let pair = (LeafIndex::default(), Word::default()); + /// let mutations = smt.compute_mutations(vec![pair]); + /// assert_eq!(mutations.root(), *EmptySubtreeRoots::entry(3, 0)); + /// smt.apply_mutations(mutations); + /// assert_eq!(smt.root(), *EmptySubtreeRoots::entry(3, 0)); + /// ``` + pub fn compute_mutations( + &self, + kv_pairs: impl IntoIterator, Word)>, + ) -> MutationSet, Word> { + >::compute_mutations(self, kv_pairs) + } + + /// Apply the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to this + /// tree. + /// + /// # Errors + /// If `mutations` was computed on a tree with a different root than this one, returns + /// [`MerkleError::ConflictingRoots`] with a two-item [`alloc::vec::Vec`]. The first item is the + /// root hash the `mutations` were computed against, and the second item is the actual + /// current root of this tree. + pub fn apply_mutations( + &mut self, + mutations: MutationSet, Word>, + ) -> Result<(), MerkleError> { + >::apply_mutations(self, mutations) + } + /// Inserts a subtree at the specified index. The depth at which the subtree is inserted is /// computed as `DEPTH - SUBTREE_DEPTH`. /// @@ -266,11 +308,10 @@ impl SparseMerkleTree for SimpleSmt { } fn get_inner_node(&self, index: NodeIndex) -> InnerNode { - self.inner_nodes.get(&index).cloned().unwrap_or_else(|| { - let node = EmptySubtreeRoots::entry(DEPTH, index.depth() + 1); - - InnerNode { left: *node, right: *node } - }) + self.inner_nodes + .get(&index) + .cloned() + .unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(DEPTH, index.depth())) } fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) { @@ -289,6 +330,10 @@ impl SparseMerkleTree for SimpleSmt { } } + fn get_value(&self, key: &LeafIndex) -> Word { + self.get_leaf(key) + } + fn get_leaf(&self, key: &LeafIndex) -> Word { let leaf_pos = key.value(); match self.leaves.get(&leaf_pos) { @@ -302,6 +347,15 @@ impl SparseMerkleTree for SimpleSmt { leaf.into() } + fn construct_prospective_leaf( + &self, + _existing_leaf: Word, + _key: &LeafIndex, + value: &Word, + ) -> Word { + *value + } + fn key_to_leaf_index(key: &LeafIndex) -> LeafIndex { *key }