diff --git a/src/merkle/delta.rs b/src/merkle/delta.rs new file mode 100644 index 0000000..71b822a --- /dev/null +++ b/src/merkle/delta.rs @@ -0,0 +1,153 @@ +use super::{ + BTreeMap, KvMap, MerkleError, MerkleStore, NodeIndex, RpoDigest, StoreNode, Vec, Word, +}; +use crate::utils::collections::Diff; + +#[cfg(test)] +use super::{empty_roots::EMPTY_WORD, Felt, SimpleSmt}; + +// MERKLE STORE DELTA +// ================================================================================================ + +/// [MerkleStoreDelta] stores a vector of ([RpoDigest], [MerkleTreeDelta]) tuples where the +/// [RpoDigest] represents the root of the Merkle tree and [MerkleTreeDelta] represents the +/// differences between the initial and final Merkle tree states. +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub struct MerkleStoreDelta(pub Vec<(RpoDigest, MerkleTreeDelta)>); + +// MERKLE TREE DELTA +// ================================================================================================ + +/// [MerkleDelta] stores the differences between the initial and final Merkle tree states. +/// +/// The differences are represented as follows: +/// - depth: the depth of the merkle tree. +/// - cleared_slots: indexes of slots where values were set to [ZERO; 4]. +/// - updated_slots: index-value pairs of slots where values were set to non [ZERO; 4] values. +#[cfg(not(test))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MerkleTreeDelta { + depth: u8, + cleared_slots: Vec, + updated_slots: Vec<(u64, Word)>, +} + +impl MerkleTreeDelta { + // CONSTRUCTOR + // -------------------------------------------------------------------------------------------- + pub fn new(depth: u8) -> Self { + Self { + depth, + cleared_slots: Vec::new(), + updated_slots: Vec::new(), + } + } + + // ACCESSORS + // -------------------------------------------------------------------------------------------- + /// Returns the depth of the Merkle tree the [MerkleDelta] is associated with. + pub fn depth(&self) -> u8 { + self.depth + } + + /// Returns the indexes of slots where values were set to [ZERO; 4]. + pub fn cleared_slots(&self) -> &[u64] { + &self.cleared_slots + } + + /// Returns the index-value pairs of slots where values were set to non [ZERO; 4] values. + pub fn updated_slots(&self) -> &[(u64, Word)] { + &self.updated_slots + } + + // MODIFIERS + // -------------------------------------------------------------------------------------------- + /// Adds a slot index to the list of cleared slots. + pub fn add_cleared_slot(&mut self, index: u64) { + self.cleared_slots.push(index); + } + + /// Adds a slot index and a value to the list of updated slots. + pub fn add_updated_slot(&mut self, index: u64, value: Word) { + self.updated_slots.push((index, value)); + } +} + +/// Extracts a [MerkleDelta] object by comparing the leaves of two Merkle trees specifies by +/// their roots and depth. +pub fn merkle_tree_delta>( + tree_root_1: RpoDigest, + tree_root_2: RpoDigest, + depth: u8, + merkle_store: &MerkleStore, +) -> Result { + if tree_root_1 == tree_root_2 { + return Ok(MerkleTreeDelta::new(depth)); + } + + let tree_1_leaves: BTreeMap = + merkle_store.non_empty_leaves(tree_root_1, depth).collect(); + let tree_2_leaves: BTreeMap = + merkle_store.non_empty_leaves(tree_root_2, depth).collect(); + let diff = tree_1_leaves.diff(&tree_2_leaves); + + // TODO: Refactor this diff implementation to prevent allocation of both BTree and Vec. + Ok(MerkleTreeDelta { + depth, + cleared_slots: diff.removed.into_iter().map(|index| index.value()).collect(), + updated_slots: diff + .updated + .into_iter() + .map(|(index, leaf)| (index.value(), *leaf)) + .collect(), + }) +} + +// INTERNALS +// -------------------------------------------------------------------------------------------- +#[cfg(test)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MerkleTreeDelta { + pub depth: u8, + pub cleared_slots: Vec, + pub updated_slots: Vec<(u64, Word)>, +} + +// MERKLE DELTA +// ================================================================================================ +#[test] +fn test_compute_merkle_delta() { + let entries = vec![ + (10, [Felt::new(0), Felt::new(1), Felt::new(2), Felt::new(3)]), + (15, [Felt::new(4), Felt::new(5), Felt::new(6), Felt::new(7)]), + (20, [Felt::new(8), Felt::new(9), Felt::new(10), Felt::new(11)]), + (31, [Felt::new(12), Felt::new(13), Felt::new(14), Felt::new(15)]), + ]; + let simple_smt = SimpleSmt::with_leaves(30, entries.clone()).unwrap(); + let mut store: MerkleStore = (&simple_smt).into(); + let root = simple_smt.root(); + + // add a new node + let new_value = [Felt::new(16), Felt::new(17), Felt::new(18), Felt::new(19)]; + let new_index = NodeIndex::new(simple_smt.depth(), 32).unwrap(); + let root = store.set_node(root, new_index, new_value.into()).unwrap().root; + + // update an existing node + let update_value = [Felt::new(20), Felt::new(21), Felt::new(22), Felt::new(23)]; + let update_idx = NodeIndex::new(simple_smt.depth(), entries[0].0).unwrap(); + let root = store.set_node(root, update_idx, update_value.into()).unwrap().root; + + // remove a node + let remove_idx = NodeIndex::new(simple_smt.depth(), entries[1].0).unwrap(); + let root = store.set_node(root, remove_idx, EMPTY_WORD.into()).unwrap().root; + + let merkle_delta = + merkle_tree_delta(simple_smt.root(), root, simple_smt.depth(), &store).unwrap(); + let expected_merkle_delta = MerkleTreeDelta { + depth: simple_smt.depth(), + cleared_slots: vec![remove_idx.value()], + updated_slots: vec![(update_idx.value(), update_value), (new_index.value(), new_value)], + }; + + assert_eq!(merkle_delta, expected_merkle_delta); +} diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index 3e1c9d9..c49c004 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -1,6 +1,6 @@ use super::{ hash::rpo::{Rpo256, RpoDigest}, - utils::collections::{vec, BTreeMap, BTreeSet, KvMap, RecordingMap, Vec}, + utils::collections::{vec, BTreeMap, BTreeSet, KvMap, RecordingMap, TryApplyDiff, Vec}, Felt, StarkField, Word, WORD_SIZE, ZERO, }; use core::fmt; @@ -11,6 +11,9 @@ use core::fmt; mod empty_roots; pub use empty_roots::EmptySubtreeRoots; +mod delta; +pub use delta::{merkle_tree_delta, MerkleStoreDelta, MerkleTreeDelta}; + mod index; pub use index::NodeIndex; diff --git a/src/merkle/simple_smt/mod.rs b/src/merkle/simple_smt/mod.rs index c8da302..542ab51 100644 --- a/src/merkle/simple_smt/mod.rs +++ b/src/merkle/simple_smt/mod.rs @@ -1,6 +1,6 @@ use super::{ - BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, - Rpo256, RpoDigest, Vec, Word, + BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, MerkleTreeDelta, + NodeIndex, Rpo256, RpoDigest, StoreNode, TryApplyDiff, Vec, Word, }; #[cfg(test)] @@ -275,3 +275,29 @@ impl BranchNode { Rpo256::merge(&[self.left, self.right]) } } + +// TRY APPLY DIFF +// ================================================================================================ +impl TryApplyDiff for SimpleSmt { + type Error = MerkleError; + type DiffType = MerkleTreeDelta; + + fn try_apply(&mut self, diff: MerkleTreeDelta) -> Result<(), MerkleError> { + if diff.depth() != self.depth() { + return Err(MerkleError::InvalidDepth { + expected: self.depth(), + provided: diff.depth(), + }); + } + + for slot in diff.cleared_slots() { + self.update_leaf(*slot, Self::EMPTY_VALUE)?; + } + + for (slot, value) in diff.updated_slots() { + self.update_leaf(*slot, *value)?; + } + + Ok(()) + } +} diff --git a/src/merkle/store/mod.rs b/src/merkle/store/mod.rs index f77d558..f250485 100644 --- a/src/merkle/store/mod.rs +++ b/src/merkle/store/mod.rs @@ -1,12 +1,9 @@ use super::{ - mmr::Mmr, BTreeMap, EmptySubtreeRoots, InnerNodeInfo, KvMap, MerkleError, MerklePath, - MerklePathSet, MerkleTree, NodeIndex, RecordingMap, RootPath, Rpo256, RpoDigest, SimpleSmt, - TieredSmt, ValuePath, Vec, -}; -use crate::utils::{ - collections::{ApplyDiff, Diff, KvMapDiff}, - ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, + empty_roots::EMPTY_WORD, mmr::Mmr, BTreeMap, EmptySubtreeRoots, InnerNodeInfo, KvMap, + MerkleError, MerklePath, MerklePathSet, MerkleStoreDelta, MerkleTree, NodeIndex, RecordingMap, + RootPath, Rpo256, RpoDigest, SimpleSmt, TieredSmt, TryApplyDiff, ValuePath, Vec, }; +use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use core::borrow::Borrow; #[cfg(test)] @@ -280,6 +277,37 @@ impl> MerkleStore { }) } + /// Iterator over the non-empty leaves of the Merkle tree associated with the specified `root` + /// and `max_depth`. + pub fn non_empty_leaves( + &self, + root: RpoDigest, + max_depth: u8, + ) -> impl Iterator + '_ { + let empty_roots = EmptySubtreeRoots::empty_hashes(max_depth); + let mut stack = Vec::new(); + stack.push((NodeIndex::new_unchecked(0, 0), root)); + + core::iter::from_fn(move || { + while let Some((index, node_hash)) = stack.pop() { + if index.depth() == max_depth { + return Some((index, node_hash)); + } + + if let Some(node) = self.nodes.get(&node_hash) { + if !empty_roots.contains(&node.left) { + stack.push((index.left_child(), node.left)); + } + if !empty_roots.contains(&node.right) { + stack.push((index.right_child(), node.right)); + } + } + } + + None + }) + } + // STATE MUTATORS // -------------------------------------------------------------------------------------------- @@ -462,7 +490,6 @@ impl> FromIterator<(RpoDigest, StoreNode)> for Me // ITERATORS // ================================================================================================ - impl> Extend for MerkleStore { fn extend>(&mut self, iter: I) { self.nodes.extend(iter.into_iter().map(|info| { @@ -479,19 +506,34 @@ impl> Extend for MerkleStore { // DiffT & ApplyDiffT TRAIT IMPLEMENTATION // ================================================================================================ -impl> Diff for MerkleStore { - type DiffType = KvMapDiff; - - fn diff(&self, other: &Self) -> Self::DiffType { - self.nodes.diff(&other.nodes) - } -} - -impl> ApplyDiff for MerkleStore { - type DiffType = KvMapDiff; +impl> TryApplyDiff for MerkleStore { + type Error = MerkleError; + type DiffType = MerkleStoreDelta; + + fn try_apply(&mut self, diff: Self::DiffType) -> Result<(), MerkleError> { + for (root, delta) in diff.0 { + let mut root = root; + for cleared_slot in delta.cleared_slots() { + root = self + .set_node( + root, + NodeIndex::new(delta.depth(), *cleared_slot)?, + EMPTY_WORD.into(), + )? + .root; + } + for (updated_slot, updated_value) in delta.updated_slots() { + root = self + .set_node( + root, + NodeIndex::new(delta.depth(), *updated_slot)?, + (*updated_value).into(), + )? + .root; + } + } - fn apply(&mut self, diff: Self::DiffType) { - self.nodes.apply(diff); + Ok(()) } } diff --git a/src/merkle/store/tests.rs b/src/merkle/store/tests.rs index c6f346f..5e5bce7 100644 --- a/src/merkle/store/tests.rs +++ b/src/merkle/store/tests.rs @@ -847,7 +847,7 @@ fn test_recorder() { // construct the proof let rec_map = recorder.into_inner(); - let proof = rec_map.into_proof(); + let (_, proof) = rec_map.finalize(); let merkle_store: MerkleStore = proof.into(); // make sure the proof contains all nodes from both trees diff --git a/src/utils/diff.rs b/src/utils/diff.rs index 48c80b6..97fc32f 100644 --- a/src/utils/diff.rs +++ b/src/utils/diff.rs @@ -1,16 +1,31 @@ /// A trait for computing the difference between two objects. pub trait Diff { + /// The type that describes the difference between two objects. type DiffType; - /// Returns a `Self::DiffType` object that represents the difference between this object and + /// Returns a [Self::DiffType] object that represents the difference between this object and /// other. fn diff(&self, other: &Self) -> Self::DiffType; } /// A trait for applying the difference between two objects. pub trait ApplyDiff { + /// The type that describes the difference between two objects. type DiffType; - /// Applies the provided changes described by [DiffType] to the object implementing this trait. + /// Applies the provided changes described by [Self::DiffType] to the object implementing this trait. fn apply(&mut self, diff: Self::DiffType); } + +/// A trait for applying the difference between two objects with the possibility of failure. +pub trait TryApplyDiff { + /// The type that describes the difference between two objects. + type DiffType; + + /// An error type that can be returned if the changes cannot be applied. + type Error; + + /// Applies the provided changes described by [Self::DiffType] to the object implementing this trait. + /// Returns an error if the changes cannot be applied. + fn try_apply(&mut self, diff: Self::DiffType) -> Result<(), Self::Error>; +} diff --git a/src/utils/kv_map.rs b/src/utils/kv_map.rs index 063a0a0..3c92b56 100644 --- a/src/utils/kv_map.rs +++ b/src/utils/kv_map.rs @@ -97,10 +97,12 @@ impl RecordingMap { // FINALIZER // -------------------------------------------------------------------------------------------- - /// Consumes the [RecordingMap] and returns a [BTreeMap] containing the key-value pairs from - /// the initial data set that were read during recording. - pub fn into_proof(self) -> BTreeMap { - self.trace.take() + /// Consumes the [RecordingMap] and returns a ([BTreeMap], [BTreeMap]) tuple. The first + /// element of the tuple is a map that represents the state of the map at the time `.finalize()` + /// is called. The second element contains the key-value pairs from the initial data set that + /// were read during recording. + pub fn finalize(self) -> (BTreeMap, BTreeMap) { + (self.data, self.trace.take()) } // TEST HELPERS @@ -217,8 +219,8 @@ impl IntoIterator for RecordingMap { /// - `removed` - a set of keys that were removed from the second map compared to the first map. #[derive(Debug, Clone)] pub struct KvMapDiff { - updated: BTreeMap, - removed: BTreeSet, + pub updated: BTreeMap, + pub removed: BTreeSet, } impl KvMapDiff { @@ -296,7 +298,7 @@ mod tests { } // convert the map into a proof - let proof = map.into_proof(); + let (_, proof) = map.finalize(); // check that the proof contains the expected values for (key, value) in ITEMS.iter() { @@ -319,7 +321,7 @@ mod tests { } // convert the map into a proof - let proof = map.into_proof(); + let (_, proof) = map.finalize(); // check that the proof contains the expected values for (key, _) in ITEMS.iter() { @@ -383,7 +385,7 @@ mod tests { // Note: The length reported by the proof will be different to the length originally // reported by the map. - let proof = map.into_proof(); + let (_, proof) = map.finalize(); // length of the proof should be equal to get_items + 1. The extra item is the original // value at key = 4u64 @@ -458,7 +460,7 @@ mod tests { assert_eq!(map.updates_len(), 2); // convert the map into a proof - let proof = map.into_proof(); + let (_, proof) = map.finalize(); // check that the proof contains the expected values for (key, value) in ITEMS.iter() {