From 08aec4443ccfb2dfab4dc193046ccec6551de687 Mon Sep 17 00:00:00 2001 From: Andrey Khmuro Date: Thu, 6 Jul 2023 00:19:03 +0300 Subject: [PATCH] Enhancement of the Partial Merkle Tree (#163) feat: implement additional functionality for the PartialMerkleTree --- src/merkle/index.rs | 17 ++++ src/merkle/partial_mt/mod.rs | 176 +++++++++++++++++++++++++++++---- src/merkle/partial_mt/tests.rs | 172 +++++++++++++++++++++++++++++--- src/merkle/store/mod.rs | 8 +- src/utils/mod.rs | 4 +- 5 files changed, 342 insertions(+), 35 deletions(-) diff --git a/src/merkle/index.rs b/src/merkle/index.rs index f17216f..3a79ac0 100644 --- a/src/merkle/index.rs +++ b/src/merkle/index.rs @@ -1,4 +1,5 @@ use super::{Felt, MerkleError, RpoDigest, StarkField}; +use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use core::fmt::Display; // NODE INDEX @@ -161,6 +162,22 @@ impl Display for NodeIndex { } } +impl Serializable for NodeIndex { + fn write_into(&self, target: &mut W) { + target.write_u8(self.depth); + target.write_u64(self.value); + } +} + +impl Deserializable for NodeIndex { + fn read_from(source: &mut R) -> Result { + let depth = source.read_u8()?; + let value = source.read_u64()?; + NodeIndex::new(depth, value) + .map_err(|_| DeserializationError::InvalidValue("Invalid index".into())) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/merkle/partial_mt/mod.rs b/src/merkle/partial_mt/mod.rs index 3558c9f..ef87516 100644 --- a/src/merkle/partial_mt/mod.rs +++ b/src/merkle/partial_mt/mod.rs @@ -1,7 +1,11 @@ use super::{ - BTreeMap, BTreeSet, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, ValuePath, Vec, ZERO, + BTreeMap, BTreeSet, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, + ValuePath, Vec, Word, ZERO, +}; +use crate::utils::{ + format, string::String, vec, word_to_hex, ByteReader, ByteWriter, Deserializable, + DeserializationError, Serializable, }; -use crate::utils::{format, string::String, word_to_hex}; use core::fmt; #[cfg(test)] @@ -74,6 +78,92 @@ impl PartialMerkleTree { }) } + /// Returns a new [PartialMerkleTree] instantiated with leaves map as specified by the provided + /// entries. + /// + /// # Errors + /// Returns an error if: + /// - If the depth is 0 or is greater than 64. + /// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}. + /// - The provided entries contain an insufficient set of nodes. + pub fn with_leaves(entries: R) -> Result + where + R: IntoIterator, + I: Iterator + ExactSizeIterator, + { + let mut layers: BTreeMap> = BTreeMap::new(); + let mut leaves = BTreeSet::new(); + let mut nodes = BTreeMap::new(); + + // add data to the leaves and nodes maps and also fill layers map, where the key is the + // depth of the node and value is its index. + for (node_index, hash) in entries.into_iter() { + leaves.insert(node_index); + nodes.insert(node_index, hash); + layers + .entry(node_index.depth()) + .and_modify(|layer_vec| layer_vec.push(node_index.value())) + .or_insert(vec![node_index.value()]); + } + + // check if the number of leaves can be accommodated by the tree's depth; we use a min + // depth of 63 because we consider passing in a vector of size 2^64 infeasible. + let max = (1_u64 << 63) as usize; + if layers.len() > max { + return Err(MerkleError::InvalidNumEntries(max, layers.len())); + } + + // Get maximum depth + let max_depth = *layers.keys().next_back().unwrap_or(&0); + + // fill layers without nodes with empty vector + for depth in 0..max_depth { + layers.entry(depth).or_insert(vec![]); + } + + let mut layer_iter = layers.into_values().rev(); + let mut parent_layer = layer_iter.next().unwrap(); + let mut current_layer; + + for depth in (1..max_depth + 1).rev() { + // set current_layer = parent_layer and parent_layer = layer_iter.next() + current_layer = layer_iter.next().unwrap(); + core::mem::swap(&mut current_layer, &mut parent_layer); + + for index_value in current_layer { + // get the parent node index + let parent_node = NodeIndex::new(depth - 1, index_value / 2)?; + + // Check if the parent hash was already calculated. In about half of the cases, we + // don't need to do anything. + if !parent_layer.contains(&parent_node.value()) { + // create current node index + let index = NodeIndex::new(depth, index_value)?; + + // get hash of the current node + let node = nodes.get(&index).ok_or(MerkleError::NodeNotInSet(index))?; + // get hash of the sibling node + let sibling = nodes + .get(&index.sibling()) + .ok_or(MerkleError::NodeNotInSet(index.sibling()))?; + // get parent hash + let parent = Rpo256::merge(&index.build_node(*node, *sibling)); + + // add index value of the calculated node to the parents layer + parent_layer.push(parent_node.value()); + // add index and hash to the nodes map + nodes.insert(parent_node, parent); + } + } + } + + Ok(PartialMerkleTree { + max_depth, + nodes, + leaves, + }) + } + // PUBLIC ACCESSORS // -------------------------------------------------------------------------------------------- @@ -101,7 +191,7 @@ impl PartialMerkleTree { } /// Returns a vector of paths from every leaf to the root. - pub fn paths(&self) -> Vec<(NodeIndex, ValuePath)> { + pub fn to_paths(&self) -> Vec<(NodeIndex, ValuePath)> { let mut paths = Vec::new(); self.leaves.iter().for_each(|&leaf| { paths.push(( @@ -160,6 +250,22 @@ impl PartialMerkleTree { }) } + /// Returns an iterator over the inner nodes of this Merkle tree. + pub fn inner_nodes(&self) -> impl Iterator + '_ { + let inner_nodes = self.nodes.iter().filter(|(index, _)| !self.leaves.contains(index)); + inner_nodes.map(|(index, digest)| { + let left_hash = + self.nodes.get(&index.left_child()).expect("Failed to get left child hash"); + let right_hash = + self.nodes.get(&index.right_child()).expect("Failed to get right child hash"); + InnerNodeInfo { + value: *digest, + left: *left_hash, + right: *right_hash, + } + }) + } + // STATE MUTATORS // -------------------------------------------------------------------------------------------- @@ -235,37 +341,37 @@ impl PartialMerkleTree { /// Updates value of the leaf at the specified index returning the old leaf value. /// + /// By default the specified index is assumed to belong to the deepest layer. If the considered + /// node does not belong to the tree, the first node on the way to the root will be changed. + /// /// This also recomputes all hashes between the leaf and the root, updating the root itself. /// /// # Errors /// Returns an error if: - /// - The depth of the specified node_index is greater than 64 or smaller than 1. - /// - The specified node index is not corresponding to the leaf. - pub fn update_leaf( - &mut self, - node_index: NodeIndex, - value: RpoDigest, - ) -> Result { - // check correctness of the depth and update it - Self::check_depth(node_index.depth())?; - self.update_depth(node_index.depth()); + /// - The specified index is greater than the maximum number of nodes on the deepest layer. + pub fn update_leaf(&mut self, index: u64, value: Word) -> Result { + let mut node_index = NodeIndex::new(self.max_depth(), index)?; - // insert NodeIndex to the leaves Set - self.leaves.insert(node_index); + // proceed to the leaf + for _ in 0..node_index.depth() { + if !self.leaves.contains(&node_index) { + node_index.move_up(); + } + } // add node value to the nodes Map let old_value = self .nodes - .insert(node_index, value) + .insert(node_index, value.into()) .ok_or(MerkleError::NodeNotInSet(node_index))?; // if the old value and new value are the same, there is nothing to update - if value == old_value { + if value == *old_value { return Ok(old_value); } let mut node_index = node_index; - let mut value = value; + let mut value = value.into(); for _ in 0..node_index.depth() { let sibling = self.nodes.get(&node_index.sibling()).expect("sibling should exist"); value = Rpo256::merge(&node_index.build_node(value, *sibling)); @@ -327,3 +433,37 @@ impl PartialMerkleTree { Ok(()) } } + +// SERIALIZATION +// ================================================================================================ + +impl Serializable for PartialMerkleTree { + fn write_into(&self, target: &mut W) { + // write leaf nodes + target.write_u64(self.leaves.len() as u64); + for leaf_index in self.leaves.iter() { + leaf_index.write_into(target); + self.get_node(*leaf_index).expect("Leaf hash not found").write_into(target); + } + } +} + +impl Deserializable for PartialMerkleTree { + fn read_from(source: &mut R) -> Result { + let leaves_len = source.read_u64()? as usize; + let mut leaf_nodes = Vec::with_capacity(leaves_len); + + // add leaf nodes to the vector + for _ in 0..leaves_len { + let index = NodeIndex::read_from(source)?; + let hash = RpoDigest::read_from(source)?; + leaf_nodes.push((index, hash)); + } + + let pmt = PartialMerkleTree::with_leaves(leaf_nodes).map_err(|_| { + DeserializationError::InvalidValue("Invalid data for PartialMerkleTree creation".into()) + })?; + + Ok(pmt) + } +} diff --git a/src/merkle/partial_mt/tests.rs b/src/merkle/partial_mt/tests.rs index ed5281f..4e580d2 100644 --- a/src/merkle/partial_mt/tests.rs +++ b/src/merkle/partial_mt/tests.rs @@ -1,9 +1,9 @@ use super::{ super::{ - digests_to_words, int_to_node, DefaultMerkleStore as MerkleStore, MerkleTree, NodeIndex, - PartialMerkleTree, + digests_to_words, int_to_node, BTreeMap, DefaultMerkleStore as MerkleStore, MerkleTree, + NodeIndex, PartialMerkleTree, }, - RpoDigest, ValuePath, Vec, + Deserializable, InnerNodeInfo, RpoDigest, Serializable, ValuePath, Vec, }; // TEST DATA @@ -13,6 +13,7 @@ const NODE10: NodeIndex = NodeIndex::new_unchecked(1, 0); const NODE11: NodeIndex = NodeIndex::new_unchecked(1, 1); const NODE20: NodeIndex = NodeIndex::new_unchecked(2, 0); +const NODE21: NodeIndex = NodeIndex::new_unchecked(2, 1); const NODE22: NodeIndex = NodeIndex::new_unchecked(2, 2); const NODE23: NodeIndex = NodeIndex::new_unchecked(2, 3); @@ -50,6 +51,43 @@ const VALUES8: [RpoDigest; 8] = [ // NodeIndex(3, 5) will be labeled as `35`. Leaves of the tree are shown as nodes with parenthesis // (33). +/// Checks that creation of the PMT with `with_leaves()` constructor is working correctly. +#[test] +fn with_leaves() { + let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap(); + let expected_root = mt.root(); + + let leaf_nodes_vec = vec![ + (NODE20, mt.get_node(NODE20).unwrap()), + (NODE32, mt.get_node(NODE32).unwrap()), + (NODE33, mt.get_node(NODE33).unwrap()), + (NODE22, mt.get_node(NODE22).unwrap()), + (NODE23, mt.get_node(NODE23).unwrap()), + ]; + + let leaf_nodes: BTreeMap = leaf_nodes_vec.into_iter().collect(); + + let pmt = PartialMerkleTree::with_leaves(leaf_nodes).unwrap(); + + assert_eq!(expected_root, pmt.root()) +} + +/// Checks that `with_leaves()` function returns an error when using incomplete set of nodes. +#[test] +fn err_with_leaves() { + // NODE22 is missing + let leaf_nodes_vec = vec![ + (NODE20, int_to_node(20)), + (NODE32, int_to_node(32)), + (NODE33, int_to_node(33)), + (NODE23, int_to_node(23)), + ]; + + let leaf_nodes: BTreeMap = leaf_nodes_vec.into_iter().collect(); + + assert!(PartialMerkleTree::with_leaves(leaf_nodes).is_err()); +} + /// Checks that root returned by `root()` function is equal to the expected one. #[test] fn get_root() { @@ -61,7 +99,7 @@ fn get_root() { let pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap(); - assert_eq!(pmt.root(), expected_root); + assert_eq!(expected_root, pmt.root()); } /// This test checks correctness of the `add_path()` and `get_path()` functions. First it creates a @@ -121,7 +159,7 @@ fn update_leaf() { let new_value32 = int_to_node(132); let expected_root = ms.set_node(root, NODE32, new_value32).unwrap().root; - pmt.update_leaf(NODE32, new_value32).unwrap(); + pmt.update_leaf(2, *new_value32).unwrap(); let actual_root = pmt.root(); assert_eq!(expected_root, actual_root); @@ -129,7 +167,15 @@ fn update_leaf() { let new_value20 = int_to_node(120); let expected_root = ms.set_node(expected_root, NODE20, new_value20).unwrap().root; - pmt.update_leaf(NODE20, new_value20).unwrap(); + pmt.update_leaf(0, *new_value20).unwrap(); + let actual_root = pmt.root(); + + assert_eq!(expected_root, actual_root); + + let new_value11 = int_to_node(111); + let expected_root = ms.set_node(expected_root, NODE11, new_value11).unwrap().root; + + pmt.update_leaf(6, *new_value11).unwrap(); let actual_root = pmt.root(); assert_eq!(expected_root, actual_root); @@ -177,7 +223,7 @@ fn get_paths() { }) .collect(); - let actual_paths = pmt.paths(); + let actual_paths = pmt.to_paths(); assert_eq!(expected_paths, actual_paths); } @@ -247,6 +293,113 @@ fn leaves() { assert!(expected_leaves.eq(pmt.leaves())); } +/// Checks that nodes of the PMT returned by `inner_nodes()` function are equal to the expected ones. +#[test] +fn test_inner_node_iterator() { + let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap(); + let expected_root = mt.root(); + + let ms = MerkleStore::from(&mt); + + let path33 = ms.get_path(expected_root, NODE33).unwrap(); + let path22 = ms.get_path(expected_root, NODE22).unwrap(); + + let mut pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap(); + + // get actual inner nodes + let actual: Vec = pmt.inner_nodes().collect(); + + let expected_n00 = mt.root(); + let expected_n10 = mt.get_node(NODE10).unwrap(); + let expected_n11 = mt.get_node(NODE11).unwrap(); + let expected_n20 = mt.get_node(NODE20).unwrap(); + let expected_n21 = mt.get_node(NODE21).unwrap(); + let expected_n32 = mt.get_node(NODE32).unwrap(); + let expected_n33 = mt.get_node(NODE33).unwrap(); + + // create vector of the expected inner nodes + let mut expected = vec![ + InnerNodeInfo { + value: expected_n00, + left: expected_n10, + right: expected_n11, + }, + InnerNodeInfo { + value: expected_n10, + left: expected_n20, + right: expected_n21, + }, + InnerNodeInfo { + value: expected_n21, + left: expected_n32, + right: expected_n33, + }, + ]; + + assert_eq!(actual, expected); + + // add another path to the Partial Merkle Tree + pmt.add_path(2, path22.value, path22.path).unwrap(); + + // get new actual inner nodes + let actual: Vec = pmt.inner_nodes().collect(); + + let expected_n22 = mt.get_node(NODE22).unwrap(); + let expected_n23 = mt.get_node(NODE23).unwrap(); + + let info_11 = InnerNodeInfo { + value: expected_n11, + left: expected_n22, + right: expected_n23, + }; + + // add new inner node to the existing vertor + expected.insert(2, info_11); + + assert_eq!(actual, expected); +} + +/// Checks that serialization and deserialization implementations for the PMT are working +/// correctly. +#[test] +fn serialization() { + let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap(); + let expected_root = mt.root(); + + let ms = MerkleStore::from(&mt); + + let path33 = ms.get_path(expected_root, NODE33).unwrap(); + let path22 = ms.get_path(expected_root, NODE22).unwrap(); + + let pmt = PartialMerkleTree::with_paths([ + (3, path33.value, path33.path), + (2, path22.value, path22.path), + ]) + .unwrap(); + + let serialized_pmt = pmt.to_bytes(); + let deserialized_pmt = PartialMerkleTree::read_from_bytes(&serialized_pmt).unwrap(); + + assert_eq!(deserialized_pmt, pmt); +} + +/// Checks that deserialization fails with incorrect data. +#[test] +fn err_deserialization() { + let mut tree_bytes: Vec = vec![5]; + tree_bytes.append(&mut NODE20.to_bytes()); + tree_bytes.append(&mut int_to_node(20).to_bytes()); + + tree_bytes.append(&mut NODE21.to_bytes()); + tree_bytes.append(&mut int_to_node(21).to_bytes()); + + // node with depth 1 could have index 0 or 1, but it has 2 + tree_bytes.append(&mut vec![1, 2]); + tree_bytes.append(&mut int_to_node(11).to_bytes()); + + assert!(PartialMerkleTree::read_from_bytes(&tree_bytes).is_err()); +} + /// Checks that addition of the path with different root will cause an error. #[test] fn err_add_path() { @@ -306,8 +459,5 @@ fn err_update_leaf() { let mut pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap(); - assert!(pmt.update_leaf(NODE22, int_to_node(22)).is_err()); - assert!(pmt.update_leaf(NODE23, int_to_node(23)).is_err()); - assert!(pmt.update_leaf(NODE30, int_to_node(30)).is_err()); - assert!(pmt.update_leaf(NODE31, int_to_node(31)).is_err()); + assert!(pmt.update_leaf(8, *int_to_node(38)).is_err()); } diff --git a/src/merkle/store/mod.rs b/src/merkle/store/mod.rs index fdba5ed..d78be59 100644 --- a/src/merkle/store/mod.rs +++ b/src/merkle/store/mod.rs @@ -438,21 +438,21 @@ impl> From<&TieredSmt> for MerkleStore { impl> From for MerkleStore { fn from(values: T) -> Self { - let nodes = values.into_iter().chain(empty_hashes().into_iter()).collect(); + let nodes = values.into_iter().chain(empty_hashes()).collect(); Self { nodes } } } impl> FromIterator for MerkleStore { fn from_iter>(iter: I) -> Self { - let nodes = combine_nodes_with_empty_hashes(iter.into_iter()).collect(); + let nodes = combine_nodes_with_empty_hashes(iter).collect(); Self { nodes } } } impl> FromIterator<(RpoDigest, StoreNode)> for MerkleStore { fn from_iter>(iter: I) -> Self { - let nodes = iter.into_iter().chain(empty_hashes().into_iter()).collect(); + let nodes = iter.into_iter().chain(empty_hashes()).collect(); Self { nodes } } } @@ -553,5 +553,5 @@ fn combine_nodes_with_empty_hashes( }, ) }) - .chain(empty_hashes().into_iter()) + .chain(empty_hashes()) } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 7804420..8059d26 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -2,10 +2,10 @@ use super::{utils::string::String, Word}; use core::fmt::{self, Write}; #[cfg(not(feature = "std"))] -pub use alloc::format; +pub use alloc::{format, vec}; #[cfg(feature = "std")] -pub use std::format; +pub use std::{format, vec}; mod kv_map;