From ae807a47aea08b271dc465d5cfa522861122c896 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Wed, 11 Sep 2024 17:49:57 -0600 Subject: [PATCH] feat: implement transactional Smt insertion (#327) * feat(smt): impl constructing leaves that don't yet exist This commit implements 'prospective leaf construction' -- computing sparse Merkle tree leaves for a key-value insertion without actually performing that insertion. For SimpleSmt, this is trivial, since the leaf type is simply the value being inserted. For the full Smt, the new leaf payload depends on the existing payload in that leaf. Since almost all leaves are very small, we can just clone the leaf and modify a copy. This will allow us to perform more general prospective changes on Merkle trees. * feat(smt): export get_value() in the trait * feat(smt): implement generic prospective insertions This commit adds two methods to SparseMerkleTree: compute_mutations() and apply_mutations(), which respectively create and consume a new MutationSet type. This type represents as set of changes to a SparseMerkleTree that haven't happened yet, and can be queried on to ensure a set of insertions result in the correct tree root before finalizing and committing the mutation. This is a direct step towards issue 222, and will directly enable removing Merkle tree clones in miden-node InnerState::apply_block(). As part of this change, SparseMerkleTree now requires its Key to be Ord and its Leaf to be Clone (both bounds which were already met by existing implementations). The Ord bound could instead be changed to Eq + Hash, if MutationSet were changed to use a HashMap instead of a BTreeMap. * chore(smt): refactor empty node construction to helper function --- CHANGELOG.md | 1 + src/main.rs | 49 ++++++++ src/merkle/empty_roots.rs | 13 ++- src/merkle/mod.rs | 4 +- src/merkle/smt/full/leaf.rs | 4 +- src/merkle/smt/full/mod.rs | 90 +++++++++++++-- src/merkle/smt/full/tests.rs | 191 +++++++++++++++++++++++++++++- src/merkle/smt/mod.rs | 218 ++++++++++++++++++++++++++++++++++- src/merkle/smt/simple/mod.rs | 68 +++++++++-- 9 files changed, 610 insertions(+), 28 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 76c65a7..a3b6c21 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ - [BREAKING]: renamed `Mmr::open()` into `Mmr::open_at()` and `Mmr::peaks()` into `Mmr::peaks_at()` (#234). - Added `Mmr::open()` and `Mmr::peaks()` which rely on `Mmr::open_at()` and `Mmr::peaks()` respectively (#234). - Standardised CI and Makefile across Miden repos (#323). +- Added `Smt::compute_mutations()` and `Smt::apply_mutations()` for validation-checked insertions (#327). ## 0.10.0 (2024-08-06) diff --git a/src/main.rs b/src/main.rs index ee6e86c..776ccc2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -35,6 +35,7 @@ pub fn benchmark_smt() { let mut tree = construction(entries, tree_size).unwrap(); insertion(&mut tree, tree_size).unwrap(); + batched_insertion(&mut tree, tree_size).unwrap(); proof_generation(&mut tree, tree_size).unwrap(); } @@ -82,6 +83,54 @@ pub fn insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> { Ok(()) } +pub fn batched_insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> { + println!("Running a batched insertion benchmark:"); + + let new_pairs: Vec<(RpoDigest, Word)> = (0..1000) + .map(|i| { + let key = Rpo256::hash(&rand_value::().to_be_bytes()); + let value = [ONE, ONE, ONE, Felt::new(size + i)]; + (key, value) + }) + .collect(); + + let now = Instant::now(); + let mutations = tree.compute_mutations(new_pairs); + let compute_elapsed = now.elapsed(); + + let now = Instant::now(); + tree.apply_mutations(mutations).unwrap(); + let apply_elapsed = now.elapsed(); + + println!( + "An average batch computation time measured by a 1k-batch into an SMT with {} key-value pairs over {:.3} milliseconds is {:.3} milliseconds", + size, + compute_elapsed.as_secs_f32() * 1000f32, + // Dividing by the number of iterations, 1000, and then multiplying by 1000 to get + // milliseconds, cancels out. + compute_elapsed.as_secs_f32(), + ); + + println!( + "An average batch application time measured by a 1k-batch into an SMT with {} key-value pairs over {:.3} milliseconds is {:.3} milliseconds", + size, + apply_elapsed.as_secs_f32() * 1000f32, + // Dividing by the number of iterations, 1000, and then multiplying by 1000 to get + // milliseconds, cancels out. + apply_elapsed.as_secs_f32(), + ); + + println!( + "An average batch insertion time measured by a 1k-batch into an SMT with {} key-value pairs totals to {:.3} milliseconds", + size, + (compute_elapsed + apply_elapsed).as_secs_f32() * 1000f32, + ); + + println!(); + + Ok(()) +} + /// Runs the proof generation benchmark for the [`Smt`]. pub fn proof_generation(tree: &mut Smt, size: u64) -> Result<(), MerkleError> { println!("Running a proof generation benchmark:"); diff --git a/src/merkle/empty_roots.rs b/src/merkle/empty_roots.rs index 30dd41a..1f54a7a 100644 --- a/src/merkle/empty_roots.rs +++ b/src/merkle/empty_roots.rs @@ -1,6 +1,6 @@ use core::slice; -use super::{Felt, RpoDigest, EMPTY_WORD}; +use super::{smt::InnerNode, Felt, RpoDigest, EMPTY_WORD}; // EMPTY NODES SUBTREES // ================================================================================================ @@ -25,6 +25,17 @@ impl EmptySubtreeRoots { let pos = 255 - tree_depth + node_depth; &EMPTY_SUBTREES[pos as usize] } + + /// Returns a sparse Merkle tree [`InnerNode`] with two empty children. + /// + /// # Note + /// `node_depth` is the depth of the **parent** to have empty children. That is, `node_depth` + /// and the depth of the returned [`InnerNode`] are the same, and thus the empty hashes are for + /// subtrees of depth `node_depth + 1`. + pub(crate) const fn get_inner_node(tree_depth: u8, node_depth: u8) -> InnerNode { + let &child = Self::entry(tree_depth, node_depth + 1); + InnerNode { left: child, right: child } + } } const EMPTY_SUBTREES: [RpoDigest; 256] = [ diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index 8954d4d..a562aa5 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -22,8 +22,8 @@ pub use path::{MerklePath, RootPath, ValuePath}; mod smt; pub use smt::{ - LeafIndex, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH, - SMT_MAX_DEPTH, SMT_MIN_DEPTH, + LeafIndex, MutationSet, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, + SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH, }; mod mmr; diff --git a/src/merkle/smt/full/leaf.rs b/src/merkle/smt/full/leaf.rs index 095a4fb..585fc40 100644 --- a/src/merkle/smt/full/leaf.rs +++ b/src/merkle/smt/full/leaf.rs @@ -350,7 +350,7 @@ impl Deserializable for SmtLeaf { // ================================================================================================ /// Converts a key-value tuple to an iterator of `Felt`s -fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator { +pub(crate) fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator { let key_elements = key.into_iter(); let value_elements = value.into_iter(); @@ -359,7 +359,7 @@ fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator /// Compares two keys, compared element-by-element using their integer representations starting with /// the most significant element. -fn cmp_keys(key_1: RpoDigest, key_2: RpoDigest) -> Ordering { +pub(crate) fn cmp_keys(key_1: RpoDigest, key_2: RpoDigest) -> Ordering { for (v1, v2) in key_1.iter().zip(key_2.iter()).rev() { let v1 = v1.as_int(); let v2 = v2.as_int(); diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 5cd510b..9c64002 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -6,7 +6,7 @@ use alloc::{ use super::{ EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, MerklePath, - NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, + MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, }; mod error; @@ -121,12 +121,7 @@ impl Smt { /// Returns the value associated with `key` pub fn get_value(&self, key: &RpoDigest) -> Word { - let leaf_pos = LeafIndex::::from(*key).value(); - - match self.leaves.get(&leaf_pos) { - Some(leaf) => leaf.get_value(key).unwrap_or_default(), - None => EMPTY_WORD, - } + >::get_value(self, key) } /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle @@ -172,6 +167,47 @@ impl Smt { >::insert(self, key, value) } + /// Computes what changes are necessary to insert the specified key-value pairs into this Merkle + /// tree, allowing for validation before applying those changes. + /// + /// This method returns a [`MutationSet`], which contains all the information for inserting + /// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can + /// be queried with [`MutationSet::root()`]. Once a mutation set is returned, + /// [`Smt::apply_mutations()`] can be called in order to commit these changes to the Merkle + /// tree, or [`drop()`] to discard them. + /// + /// # Example + /// ``` + /// # use miden_crypto::{hash::rpo::RpoDigest, Felt, Word}; + /// # use miden_crypto::merkle::{Smt, EmptySubtreeRoots, SMT_DEPTH}; + /// let mut smt = Smt::new(); + /// let pair = (RpoDigest::default(), Word::default()); + /// let mutations = smt.compute_mutations(vec![pair]); + /// assert_eq!(mutations.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0)); + /// smt.apply_mutations(mutations); + /// assert_eq!(smt.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0)); + /// ``` + pub fn compute_mutations( + &self, + kv_pairs: impl IntoIterator, + ) -> MutationSet { + >::compute_mutations(self, kv_pairs) + } + + /// Apply the prospective mutations computed with [`Smt::compute_mutations()`] to this tree. + /// + /// # Errors + /// If `mutations` was computed on a tree with a different root than this one, returns + /// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash + /// the `mutations` were computed against, and the second item is the actual current root of + /// this tree. + pub fn apply_mutations( + &mut self, + mutations: MutationSet, + ) -> Result<(), MerkleError> { + >::apply_mutations(self, mutations) + } + // HELPERS // -------------------------------------------------------------------------------------------- @@ -226,11 +262,10 @@ impl SparseMerkleTree for Smt { } fn get_inner_node(&self, index: NodeIndex) -> InnerNode { - self.inner_nodes.get(&index).cloned().unwrap_or_else(|| { - let node = EmptySubtreeRoots::entry(SMT_DEPTH, index.depth() + 1); - - InnerNode { left: *node, right: *node } - }) + self.inner_nodes + .get(&index) + .cloned() + .unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth())) } fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) { @@ -250,6 +285,15 @@ impl SparseMerkleTree for Smt { } } + fn get_value(&self, key: &Self::Key) -> Self::Value { + let leaf_pos = LeafIndex::::from(*key).value(); + + match self.leaves.get(&leaf_pos) { + Some(leaf) => leaf.get_value(key).unwrap_or_default(), + None => EMPTY_WORD, + } + } + fn get_leaf(&self, key: &RpoDigest) -> Self::Leaf { let leaf_pos = LeafIndex::::from(*key).value(); @@ -263,6 +307,28 @@ impl SparseMerkleTree for Smt { leaf.hash() } + fn construct_prospective_leaf( + &self, + mut existing_leaf: SmtLeaf, + key: &RpoDigest, + value: &Word, + ) -> SmtLeaf { + debug_assert_eq!(existing_leaf.index(), Self::key_to_leaf_index(key)); + + match existing_leaf { + SmtLeaf::Empty(_) => SmtLeaf::new_single(*key, *value), + _ => { + if *value != EMPTY_WORD { + existing_leaf.insert(*key, *value); + } else { + existing_leaf.remove(*key); + } + + existing_leaf + }, + } + } + fn key_to_leaf_index(key: &RpoDigest) -> LeafIndex { let most_significant_felt = key[3]; LeafIndex::new_max_depth(most_significant_felt.as_int()) diff --git a/src/merkle/smt/full/tests.rs b/src/merkle/smt/full/tests.rs index 66cd203..1613c8f 100644 --- a/src/merkle/smt/full/tests.rs +++ b/src/merkle/smt/full/tests.rs @@ -2,7 +2,7 @@ use alloc::vec::Vec; use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH}; use crate::{ - merkle::{EmptySubtreeRoots, MerkleStore}, + merkle::{smt::SparseMerkleTree, EmptySubtreeRoots, MerkleStore}, utils::{Deserializable, Serializable}, Word, ONE, WORD_SIZE, }; @@ -258,6 +258,195 @@ fn test_smt_removal() { } } +/// This tests that we can correctly calculate prospective leaves -- that is, we can construct +/// correct [`SmtLeaf`] values for a theoretical insertion on a Merkle tree without mutating or +/// cloning the tree. +#[test] +fn test_prospective_hash() { + let mut smt = Smt::default(); + + let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64; + + let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]); + let key_2: RpoDigest = + RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]); + // Sort key_3 before key_1, to test non-append insertion. + let key_3: RpoDigest = + RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(raw)]); + + let value_1 = [ONE; WORD_SIZE]; + let value_2 = [2_u32.into(); WORD_SIZE]; + let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE]; + + // insert key-value 1 + { + let prospective = + smt.construct_prospective_leaf(smt.get_leaf(&key_1), &key_1, &value_1).hash(); + smt.insert(key_1, value_1); + + let leaf = smt.get_leaf(&key_1); + assert_eq!( + prospective, + leaf.hash(), + "prospective hash for leaf {leaf:?} did not match actual hash", + ); + } + + // insert key-value 2 + { + let prospective = + smt.construct_prospective_leaf(smt.get_leaf(&key_2), &key_2, &value_2).hash(); + smt.insert(key_2, value_2); + + let leaf = smt.get_leaf(&key_2); + assert_eq!( + prospective, + leaf.hash(), + "prospective hash for leaf {leaf:?} did not match actual hash", + ); + } + + // insert key-value 3 + { + let prospective = + smt.construct_prospective_leaf(smt.get_leaf(&key_3), &key_3, &value_3).hash(); + smt.insert(key_3, value_3); + + let leaf = smt.get_leaf(&key_3); + assert_eq!( + prospective, + leaf.hash(), + "prospective hash for leaf {leaf:?} did not match actual hash", + ); + } + + // remove key 3 + { + let old_leaf = smt.get_leaf(&key_3); + let old_value_3 = smt.insert(key_3, EMPTY_WORD); + assert_eq!(old_value_3, value_3); + let prospective_leaf = + smt.construct_prospective_leaf(smt.get_leaf(&key_3), &key_3, &old_value_3); + + assert_eq!( + old_leaf.hash(), + prospective_leaf.hash(), + "removing and prospectively re-adding a leaf didn't yield the original leaf:\ + \n original leaf: {old_leaf:?}\ + \n prospective leaf: {prospective_leaf:?}", + ); + } + + // remove key 2 + { + let old_leaf = smt.get_leaf(&key_2); + let old_value_2 = smt.insert(key_2, EMPTY_WORD); + assert_eq!(old_value_2, value_2); + let prospective_leaf = + smt.construct_prospective_leaf(smt.get_leaf(&key_2), &key_2, &old_value_2); + + assert_eq!( + old_leaf.hash(), + prospective_leaf.hash(), + "removing and prospectively re-adding a leaf didn't yield the original leaf:\ + \n original leaf: {old_leaf:?}\ + \n prospective leaf: {prospective_leaf:?}", + ); + } + + // remove key 1 + { + let old_leaf = smt.get_leaf(&key_1); + let old_value_1 = smt.insert(key_1, EMPTY_WORD); + assert_eq!(old_value_1, value_1); + let prospective_leaf = + smt.construct_prospective_leaf(smt.get_leaf(&key_1), &key_1, &old_value_1); + assert_eq!( + old_leaf.hash(), + prospective_leaf.hash(), + "removing and prospectively re-adding a leaf didn't yield the original leaf:\ + \n original leaf: {old_leaf:?}\ + \n prospective leaf: {prospective_leaf:?}", + ); + } +} + +/// This tests that we can perform prospective changes correctly. +#[test] +fn test_prospective_insertion() { + let mut smt = Smt::default(); + + let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64; + + let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]); + let key_2: RpoDigest = + RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]); + // Sort key_3 before key_1, to test non-append insertion. + let key_3: RpoDigest = + RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(raw)]); + + let value_1 = [ONE; WORD_SIZE]; + let value_2 = [2_u32.into(); WORD_SIZE]; + let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE]; + + let root_empty = smt.root(); + + let root_1 = { + smt.insert(key_1, value_1); + smt.root() + }; + + let root_2 = { + smt.insert(key_2, value_2); + smt.root() + }; + + let root_3 = { + smt.insert(key_3, value_3); + smt.root() + }; + + // Test incremental updates. + + let mut smt = Smt::default(); + + let mutations = smt.compute_mutations(vec![(key_1, value_1)]); + assert_eq!(mutations.root(), root_1, "prospective root 1 did not match actual root 1"); + smt.apply_mutations(mutations).unwrap(); + assert_eq!(smt.root(), root_1, "mutations before and after apply did not match"); + + let mutations = smt.compute_mutations(vec![(key_2, value_2)]); + assert_eq!(mutations.root(), root_2, "prospective root 2 did not match actual root 2"); + let mutations = + smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_2, value_2), (key_3, value_3)]); + assert_eq!(mutations.root(), root_3, "mutations before and after apply did not match"); + smt.apply_mutations(mutations).unwrap(); + + // Edge case: multiple values at the same key, where a later pair restores the original value. + let mutations = smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_3, value_3)]); + assert_eq!(mutations.root(), root_3); + smt.apply_mutations(mutations).unwrap(); + assert_eq!(smt.root(), root_3); + + // Test batch updates, and that the order doesn't matter. + let pairs = + vec![(key_3, value_2), (key_2, EMPTY_WORD), (key_1, EMPTY_WORD), (key_3, EMPTY_WORD)]; + let mutations = smt.compute_mutations(pairs); + assert_eq!( + mutations.root(), + root_empty, + "prospective root for batch removal did not match actual root", + ); + smt.apply_mutations(mutations).unwrap(); + assert_eq!(smt.root(), root_empty, "mutations before and after apply did not match"); + + let pairs = vec![(key_3, value_3), (key_1, value_1), (key_2, value_2)]; + let mutations = smt.compute_mutations(pairs); + assert_eq!(mutations.root(), root_3); + smt.apply_mutations(mutations).unwrap(); + assert_eq!(smt.root(), root_3); +} + /// Tests that 2 key-value pairs stored in the same leaf have the same path #[test] fn test_smt_path_to_keys_in_same_leaf_are_equal() { diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index d7d42da..0b7ceb9 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -1,4 +1,4 @@ -use alloc::vec::Vec; +use alloc::{collections::BTreeMap, vec::Vec}; use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex}; use crate::{ @@ -45,11 +45,11 @@ pub const SMT_MAX_DEPTH: u8 = 64; /// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs. pub(crate) trait SparseMerkleTree { /// The type for a key - type Key: Clone; + type Key: Clone + Ord; /// The type for a value type Value: Clone + PartialEq; /// The type for a leaf - type Leaf; + type Leaf: Clone; /// The type for an opening (i.e. a "proof") of a leaf type Opening; @@ -140,6 +140,149 @@ pub(crate) trait SparseMerkleTree { self.set_root(node_hash); } + /// Computes what changes are necessary to insert the specified key-value pairs into this Merkle + /// tree, allowing for validation before applying those changes. + /// + /// This method returns a [`MutationSet`], which contains all the information for inserting + /// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can + /// be queried with [`MutationSet::root()`]. Once a mutation set is returned, + /// [`SparseMerkleTree::apply_mutations()`] can be called in order to commit these changes to + /// the Merkle tree, or [`drop()`] to discard them. + fn compute_mutations( + &self, + kv_pairs: impl IntoIterator, + ) -> MutationSet { + use NodeMutation::*; + + let mut new_root = self.root(); + let mut new_pairs: BTreeMap = Default::default(); + let mut node_mutations: BTreeMap = Default::default(); + + for (key, value) in kv_pairs { + // If the old value and the new value are the same, there is nothing to update. + // For the unusual case that kv_pairs has multiple values at the same key, we'll have + // to check the key-value pairs we've already seen to get the "effective" old value. + let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key)); + if value == old_value { + continue; + } + + let leaf_index = Self::key_to_leaf_index(&key); + let mut node_index = NodeIndex::from(leaf_index); + + // We need the current leaf's hash to calculate the new leaf, but in the rare case that + // `kv_pairs` has multiple pairs that go into the same leaf, then those pairs are also + // part of the "current leaf". + let old_leaf = { + let pairs_at_index = new_pairs + .iter() + .filter(|&(new_key, _)| Self::key_to_leaf_index(new_key) == leaf_index); + + pairs_at_index.fold(self.get_leaf(&key), |acc, (k, v)| { + // Most of the time `pairs_at_index` should only contain a single entry (or + // none at all), as multi-leaves should be really rare. + let existing_leaf = acc.clone(); + self.construct_prospective_leaf(existing_leaf, k, v) + }) + }; + + let new_leaf = self.construct_prospective_leaf(old_leaf, &key, &value); + + let mut new_child_hash = Self::hash_leaf(&new_leaf); + + for node_depth in (0..node_index.depth()).rev() { + // Whether the node we're replacing is the right child or the left child. + let is_right = node_index.is_value_odd(); + node_index.move_up(); + + let old_node = node_mutations + .get(&node_index) + .map(|mutation| match mutation { + Addition(node) => node.clone(), + Removal => EmptySubtreeRoots::get_inner_node(DEPTH, node_depth), + }) + .unwrap_or_else(|| self.get_inner_node(node_index)); + + let new_node = if is_right { + InnerNode { + left: old_node.left, + right: new_child_hash, + } + } else { + InnerNode { + left: new_child_hash, + right: old_node.right, + } + }; + + // The next iteration will operate on this new node's hash. + new_child_hash = new_node.hash(); + + let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth); + let is_removal = new_child_hash == equivalent_empty_hash; + let new_entry = if is_removal { Removal } else { Addition(new_node) }; + node_mutations.insert(node_index, new_entry); + } + + // Once we're at depth 0, the last node we made is the new root. + new_root = new_child_hash; + // And then we're done with this pair; on to the next one. + new_pairs.insert(key, value); + } + + MutationSet { + old_root: self.root(), + new_root, + node_mutations, + new_pairs, + } + } + + /// Apply the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to + /// this tree. + /// + /// # Errors + /// If `mutations` was computed on a tree with a different root than this one, returns + /// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash + /// the `mutations` were computed against, and the second item is the actual current root of + /// this tree. + fn apply_mutations( + &mut self, + mutations: MutationSet, + ) -> Result<(), MerkleError> + where + Self: Sized, + { + use NodeMutation::*; + let MutationSet { + old_root, + node_mutations, + new_pairs, + new_root, + } = mutations; + + // Guard against accidentally trying to apply mutations that were computed against a + // different tree, including a stale version of this tree. + if old_root != self.root() { + return Err(MerkleError::ConflictingRoots(vec![old_root, self.root()])); + } + + for (index, mutation) in node_mutations { + match mutation { + Removal => self.remove_inner_node(index), + Addition(node) => self.insert_inner_node(index, node), + } + } + + for (key, value) in new_pairs { + self.insert_value(key, value); + } + + self.set_root(new_root); + + Ok(()) + } + // REQUIRED METHODS // --------------------------------------------------------------------------------------------- @@ -161,12 +304,34 @@ pub(crate) trait SparseMerkleTree { /// Inserts a leaf node, and returns the value at the key if already exists fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option; + /// Returns the value at the specified key. Recall that by definition, any key that hasn't been + /// updated is associated with [`Self::EMPTY_VALUE`]. + fn get_value(&self, key: &Self::Key) -> Self::Value; + /// Returns the leaf at the specified index. fn get_leaf(&self, key: &Self::Key) -> Self::Leaf; /// Returns the hash of a leaf fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest; + /// Returns what a leaf would look like if a key-value pair were inserted into the tree, without + /// mutating the tree itself. The existing leaf can be empty. + /// + /// To get a prospective leaf based on the current state of the tree, use `self.get_leaf(key)` + /// as the argument for `existing_leaf`. The return value from this function can be chained back + /// into this function as the first argument to continue making prospective changes. + /// + /// # Invariants + /// Because this method is for a prospective key-value insertion into a specific leaf, + /// `existing_leaf` must have the same leaf index as `key` (as determined by + /// [`SparseMerkleTree::key_to_leaf_index()`]), or the result will be meaningless. + fn construct_prospective_leaf( + &self, + existing_leaf: Self::Leaf, + key: &Self::Key, + value: &Self::Value, + ) -> Self::Leaf; + /// Maps a key to a leaf index fn key_to_leaf_index(key: &Self::Key) -> LeafIndex; @@ -244,3 +409,50 @@ impl TryFrom for LeafIndex { Self::new(node_index.value()) } } + +// MUTATIONS +// ================================================================================================ + +/// A change to an inner node of a [`SparseMerkleTree`] that hasn't yet been applied. +/// [`MutationSet`] stores this type in relation to a [`NodeIndex`] to keep track of what changes +/// need to occur at which node indices. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum NodeMutation { + /// Corresponds to [`SparseMerkleTree::remove_inner_node()`]. + Removal, + /// Corresponds to [`SparseMerkleTree::insert_inner_node()`]. + Addition(InnerNode), +} + +/// Represents a group of prospective mutations to a `SparseMerkleTree`, created by +/// `SparseMerkleTree::compute_mutations()`, and that can be applied with +/// `SparseMerkleTree::apply_mutations()`. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct MutationSet { + /// The root of the Merkle tree this MutationSet is for, recorded at the time + /// [`SparseMerkleTree::compute_mutations()`] was called. Exists to guard against applying + /// mutations to the wrong tree or applying stale mutations to a tree that has since changed. + old_root: RpoDigest, + /// The set of nodes that need to be removed or added. The "effective" node at an index is the + /// Merkle tree's existing node at that index, with the [`NodeMutation`] in this map at that + /// index overlayed, if any. Each [`NodeMutation::Addition`] corresponds to a + /// [`SparseMerkleTree::insert_inner_node()`] call, and each [`NodeMutation::Removal`] + /// corresponds to a [`SparseMerkleTree::remove_inner_node()`] call. + node_mutations: BTreeMap, + /// The set of top-level key-value pairs we're prospectively adding to the tree, including + /// adding empty values. The "effective" value for a key is the value in this BTreeMap, falling + /// back to the existing value in the Merkle tree. Each entry corresponds to a + /// [`SparseMerkleTree::insert_value()`] call. + new_pairs: BTreeMap, + /// The calculated root for the Merkle tree, given these mutations. Publicly retrievable with + /// [`MutationSet::root()`]. Corresponds to a [`SparseMerkleTree::set_root()`]. call. + new_root: RpoDigest, +} + +impl MutationSet { + /// Queries the root that was calculated during `SparseMerkleTree::compute_mutations()`. See + /// that method for more information. + pub fn root(&self) -> RpoDigest { + self.new_root + } +} diff --git a/src/merkle/smt/simple/mod.rs b/src/merkle/smt/simple/mod.rs index f1ff0dc..1744430 100644 --- a/src/merkle/smt/simple/mod.rs +++ b/src/merkle/smt/simple/mod.rs @@ -2,8 +2,8 @@ use alloc::collections::{BTreeMap, BTreeSet}; use super::{ super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, - MerklePath, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, SMT_MAX_DEPTH, - SMT_MIN_DEPTH, + MerklePath, MutationSet, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, + SMT_MAX_DEPTH, SMT_MIN_DEPTH, }; #[cfg(test)] @@ -188,6 +188,48 @@ impl SimpleSmt { >::insert(self, key, value) } + /// Computes what changes are necessary to insert the specified key-value pairs into this + /// Merkle tree, allowing for validation before applying those changes. + /// + /// This method returns a [`MutationSet`], which contains all the information for inserting + /// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can + /// be queried with [`MutationSet::root()`]. Once a mutation set is returned, + /// [`SimpleSmt::apply_mutations()`] can be called in order to commit these changes to the + /// Merkle tree, or [`drop()`] to discard them. + + /// # Example + /// ``` + /// # use miden_crypto::{hash::rpo::RpoDigest, Felt, Word}; + /// # use miden_crypto::merkle::{LeafIndex, SimpleSmt, EmptySubtreeRoots, SMT_DEPTH}; + /// let mut smt: SimpleSmt<3> = SimpleSmt::new().unwrap(); + /// let pair = (LeafIndex::default(), Word::default()); + /// let mutations = smt.compute_mutations(vec![pair]); + /// assert_eq!(mutations.root(), *EmptySubtreeRoots::entry(3, 0)); + /// smt.apply_mutations(mutations); + /// assert_eq!(smt.root(), *EmptySubtreeRoots::entry(3, 0)); + /// ``` + pub fn compute_mutations( + &self, + kv_pairs: impl IntoIterator, Word)>, + ) -> MutationSet, Word> { + >::compute_mutations(self, kv_pairs) + } + + /// Apply the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to this + /// tree. + /// + /// # Errors + /// If `mutations` was computed on a tree with a different root than this one, returns + /// [`MerkleError::ConflictingRoots`] with a two-item [`alloc::vec::Vec`]. The first item is the + /// root hash the `mutations` were computed against, and the second item is the actual + /// current root of this tree. + pub fn apply_mutations( + &mut self, + mutations: MutationSet, Word>, + ) -> Result<(), MerkleError> { + >::apply_mutations(self, mutations) + } + /// Inserts a subtree at the specified index. The depth at which the subtree is inserted is /// computed as `DEPTH - SUBTREE_DEPTH`. /// @@ -266,11 +308,10 @@ impl SparseMerkleTree for SimpleSmt { } fn get_inner_node(&self, index: NodeIndex) -> InnerNode { - self.inner_nodes.get(&index).cloned().unwrap_or_else(|| { - let node = EmptySubtreeRoots::entry(DEPTH, index.depth() + 1); - - InnerNode { left: *node, right: *node } - }) + self.inner_nodes + .get(&index) + .cloned() + .unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(DEPTH, index.depth())) } fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) { @@ -289,6 +330,10 @@ impl SparseMerkleTree for SimpleSmt { } } + fn get_value(&self, key: &LeafIndex) -> Word { + self.get_leaf(key) + } + fn get_leaf(&self, key: &LeafIndex) -> Word { let leaf_pos = key.value(); match self.leaves.get(&leaf_pos) { @@ -302,6 +347,15 @@ impl SparseMerkleTree for SimpleSmt { leaf.into() } + fn construct_prospective_leaf( + &self, + _existing_leaf: Word, + _key: &LeafIndex, + value: &Word, + ) -> Word { + *value + } + fn key_to_leaf_index(key: &LeafIndex) -> LeafIndex { *key }