From ebf71c2dc7c1981bf1524aa11566c946136b586a Mon Sep 17 00:00:00 2001 From: Andrey Khmuro Date: Fri, 2 Jun 2023 21:57:33 +0300 Subject: [PATCH] refactor: optimize code, remove not momentarily necessary functions --- src/merkle/mod.rs | 5 + src/merkle/partial_mt/mod.rs | 197 +++++++++++++-------------------- src/merkle/partial_mt/tests.rs | 145 +++++++----------------- 3 files changed, 118 insertions(+), 229 deletions(-) diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index 6f7d2ab..c3cbdf6 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -101,3 +101,8 @@ impl std::error::Error for MerkleError {} const fn int_to_node(value: u64) -> Word { [Felt::new(value), ZERO, ZERO, ZERO] } + +#[cfg(test)] +const fn int_to_digest(value: u64) -> RpoDigest { + RpoDigest::new([Felt::new(value), ZERO, ZERO, ZERO]) +} diff --git a/src/merkle/partial_mt/mod.rs b/src/merkle/partial_mt/mod.rs index 8b793b6..be46802 100644 --- a/src/merkle/partial_mt/mod.rs +++ b/src/merkle/partial_mt/mod.rs @@ -1,15 +1,25 @@ use super::{ - BTreeMap, BTreeSet, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, + BTreeMap, BTreeSet, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, ValuePath, Vec, Word, EMPTY_WORD, }; #[cfg(test)] mod tests; +// CONSTANTS +// ================================================================================================ + +/// Index of the root node. +const ROOT_INDEX: NodeIndex = NodeIndex::root(); + +/// An RpoDigest consisting of 4 ZERO elements. +const EMPTY_DIGEST: RpoDigest = RpoDigest::new(EMPTY_WORD); + // PARTIAL MERKLE TREE // ================================================================================================ -/// A partial Merkle tree with NodeIndex keys and 4-element RpoDigest leaf values. +/// A partial Merkle tree with NodeIndex keys and 4-element RpoDigest leaf values. Partial Merkle +/// Tree allows to create Merkle Tree by providing Merkle paths of different lengths. /// /// The root of the tree is recomputed on each new leaf update. pub struct PartialMerkleTree { @@ -28,17 +38,12 @@ impl PartialMerkleTree { // CONSTANTS // -------------------------------------------------------------------------------------------- - /// An RpoDigest consisting of 4 ZERO elements. - pub const EMPTY_DIGEST: RpoDigest = RpoDigest::new(EMPTY_WORD); - /// Minimum supported depth. pub const MIN_DEPTH: u8 = 1; /// Maximum supported depth. pub const MAX_DEPTH: u8 = 64; - pub const ROOT_INDEX: NodeIndex = NodeIndex::new_unchecked(0, 0); - // CONSTRUCTORS // -------------------------------------------------------------------------------------------- @@ -56,7 +61,7 @@ impl PartialMerkleTree { /// Analogous to [Self::add_path]. pub fn with_paths(paths: I) -> Result where - I: IntoIterator, + I: IntoIterator, { // create an empty tree let tree = PartialMerkleTree::new(); @@ -71,8 +76,8 @@ impl PartialMerkleTree { // -------------------------------------------------------------------------------------------- /// Returns the root of this Merkle tree. - pub fn root(&self) -> Word { - *self.nodes.get(&Self::ROOT_INDEX).cloned().unwrap_or(Self::EMPTY_DIGEST) + pub fn root(&self) -> RpoDigest { + self.nodes.get(&ROOT_INDEX).cloned().unwrap_or(EMPTY_DIGEST) } /// Returns the depth of this Merkle tree. @@ -101,38 +106,22 @@ impl PartialMerkleTree { } node_index.move_up() } - // we don't have an error for this case, maybe it makes sense to create a new error, something like - // NoLeafForIndex("There is no leaf for provided index"). But it will be used almost never. - Err(MerkleError::NodeNotInSet(node_index)) - } - - /// Returns a value of the leaf at the specified NodeIndex. - /// - /// # Errors - /// Returns an error if the NodeIndex is not contained in the leaves set. - pub fn get_leaf(&self, index: NodeIndex) -> Result { - if !self.leaves.contains(&index) { - // This error not really suitable in this situation, should I create a new error? - Err(MerkleError::InvalidIndex { - depth: index.depth(), - value: index.value(), - }) - } else { - self.nodes - .get(&index) - .ok_or(MerkleError::NodeNotInSet(index)) - .map(|hash| **hash) - } + Ok(0_u8) } - /// Returns a map of paths from every leaf to the root. - pub fn paths(&self) -> Result, MerkleError> { - let mut paths = BTreeMap::new(); - for leaf_index in self.leaves.iter() { - let index = *leaf_index; - paths.insert(leaf_index, self.get_path(index)?); - } - Ok(paths) + /// Returns a vector of paths from every leaf to the root. + pub fn paths(&self) -> Vec<(NodeIndex, ValuePath)> { + let mut paths = Vec::new(); + self.leaves.iter().for_each(|leaf| { + paths.push(( + *leaf, + ValuePath { + value: *self.get_node(*leaf).expect("Failed to get leaf node"), + path: self.get_path(*leaf).expect("Failed to get path"), + }, + )); + }); + paths } /// Returns a Merkle path from the node at the specified index to the root. @@ -157,11 +146,11 @@ impl PartialMerkleTree { let mut path = Vec::new(); for _ in 0..index.depth() { - let sibling_index = Self::get_sibling_index(&index)?; + let sibling_index = index.sibling(); index.move_up(); - let sibling_hash = - self.nodes.get(&sibling_index).cloned().unwrap_or(Self::EMPTY_DIGEST); - path.push(Word::from(sibling_hash)); + let sibling = + self.nodes.get(&sibling_index).cloned().expect("Sibling node not in the map"); + path.push(Word::from(sibling)); } Ok(MerklePath::new(path)) } @@ -170,28 +159,18 @@ impl PartialMerkleTree { // -------------------------------------------------------------------------------------------- /// Returns an iterator over the leaves of this [PartialMerkleTree]. - pub fn leaves(&self) -> impl Iterator { - self.nodes - .iter() - .filter(|(index, _)| self.leaves.contains(index)) - .map(|(index, hash)| (*index, &(**hash))) - } - - /// Returns an iterator over the inner nodes of this Merkle tree. - pub fn inner_nodes(&self) -> impl Iterator + '_ { - let inner_nodes = self.nodes.iter().filter(|(index, _)| !self.leaves.contains(index)); - inner_nodes.map(|(index, digest)| { - let left_index = NodeIndex::new(index.depth() + 1, index.value() * 2) - .expect("Failure to get left child index"); - let right_index = NodeIndex::new(index.depth() + 1, index.value() * 2 + 1) - .expect("Failure to get right child index"); - let left_hash = self.nodes.get(&left_index).cloned().unwrap_or(Self::EMPTY_DIGEST); - let right_hash = self.nodes.get(&right_index).cloned().unwrap_or(Self::EMPTY_DIGEST); - InnerNodeInfo { - value: **digest, - left: *left_hash, - right: *right_hash, - } + pub fn leaves(&self) -> impl Iterator + '_ { + self.leaves.iter().map(|leaf| { + ( + *leaf, + self.get_node(*leaf).unwrap_or_else(|_| { + panic!( + "Leaf with node index ({}, {}) is not in the nodes map", + leaf.depth(), + leaf.value() + ) + }), + ) }) } @@ -208,55 +187,60 @@ impl PartialMerkleTree { /// different root). pub fn add_path( &mut self, - index_value: NodeIndex, - value: Word, - mut path: MerklePath, + index_value: u64, + value: RpoDigest, + path: MerklePath, ) -> Result<(), MerkleError> { + let index_value = NodeIndex::new(path.len() as u8, index_value)?; + Self::check_depth(index_value.depth())?; self.update_depth(index_value.depth()); - // add node index to the leaves set + // add provided node and its sibling to the leaves set self.leaves.insert(index_value); - let sibling_node_index = Self::get_sibling_index(&index_value)?; + let sibling_node_index = index_value.sibling(); self.leaves.insert(sibling_node_index); - // add first two nodes to the nodes map - self.nodes.insert(index_value, value.into()); + // add provided node and its sibling to the nodes map + self.nodes.insert(index_value, value); self.nodes.insert(sibling_node_index, path[0].into()); - // update the current path - let parity = index_value.value() & 1; - path.insert(parity as usize, value); - // traverse to the root, updating the nodes let mut index_value = index_value; - let root = Rpo256::merge(&[path[0].into(), path[1].into()]); - let root = path.iter().skip(2).copied().fold(root, |root, hash| { + let node = Rpo256::merge(&index_value.build_node(value, path[0].into())); + let root = path.iter().skip(1).copied().fold(node, |node, hash| { index_value.move_up(); // insert calculated node to the nodes map - self.nodes.insert(index_value, root); + self.nodes.insert(index_value, node); - let sibling_node = Self::get_sibling_index_unchecked(&index_value); - // assume for now that all path nodes are leaves and add them to the leaves set - self.leaves.insert(sibling_node); + let sibling_node = index_value.sibling(); + // node became a leaf only if it is a new node (it wasn't in nodes map) + if !self.nodes.contains_key(&sibling_node) { + self.leaves.insert(sibling_node); + } + + // node stops being a leaf if the path contains a node which is a child of this leaf + let mut parent = index_value; + parent.move_up(); + if self.leaves.contains(&parent) { + self.leaves.remove(&parent); + } // insert node from Merkle path to the nodes map self.nodes.insert(sibling_node, hash.into()); - Rpo256::merge(&index_value.build_node(root, hash.into())) + Rpo256::merge(&index_value.build_node(node, hash.into())) }); - let old_root = self.nodes.get(&Self::ROOT_INDEX).cloned().unwrap_or(Self::EMPTY_DIGEST); - // if the path set is empty (the root is all ZEROs), set the root to the root of the added // path; otherwise, the root of the added path must be identical to the current root - if old_root == Self::EMPTY_DIGEST { - self.nodes.insert(Self::ROOT_INDEX, root); - } else if old_root != root { - return Err(MerkleError::ConflictingRoots([*old_root, *root].to_vec())); + if self.root() == EMPTY_DIGEST { + self.nodes.insert(ROOT_INDEX, root); + } else if self.root() != root { + return Err(MerkleError::ConflictingRoots([*self.root(), *root].to_vec())); } - self.update_leaves()?; + // self.update_leaves()?; Ok(()) } @@ -277,7 +261,7 @@ impl PartialMerkleTree { self.leaves.insert(node_index); // add node value to the nodes Map - let old_value = self.nodes.insert(node_index, value).unwrap_or(Self::EMPTY_DIGEST); + let old_value = self.nodes.insert(node_index, value).unwrap_or(EMPTY_DIGEST); // if the old value and new value are the same, there is nothing to update if value == old_value { @@ -333,33 +317,4 @@ impl PartialMerkleTree { } Ok(()) } - - fn get_sibling_index(node_index: &NodeIndex) -> Result { - if node_index.is_value_odd() { - NodeIndex::new(node_index.depth(), node_index.value() - 1) - } else { - NodeIndex::new(node_index.depth(), node_index.value() + 1) - } - } - - fn get_sibling_index_unchecked(node_index: &NodeIndex) -> NodeIndex { - if node_index.is_value_odd() { - NodeIndex::new_unchecked(node_index.depth(), node_index.value() - 1) - } else { - NodeIndex::new_unchecked(node_index.depth(), node_index.value() + 1) - } - } - - // Removes from the leaves set indexes of nodes which have descendants. - fn update_leaves(&mut self) -> Result<(), MerkleError> { - for leaf_node in self.leaves.clone().iter() { - let left_child = NodeIndex::new(leaf_node.depth() + 1, leaf_node.value() * 2)?; - let right_child = NodeIndex::new(leaf_node.depth() + 1, leaf_node.value() * 2 + 1)?; - if self.nodes.contains_key(&left_child) || self.nodes.contains_key(&right_child) { - self.leaves.remove(leaf_node); - } - } - - Ok(()) - } } diff --git a/src/merkle/partial_mt/tests.rs b/src/merkle/partial_mt/tests.rs index 2a57ec8..1efbb6a 100644 --- a/src/merkle/partial_mt/tests.rs +++ b/src/merkle/partial_mt/tests.rs @@ -1,22 +1,14 @@ use crate::hash::rpo::RpoDigest; use super::{ - super::{int_to_node, NodeIndex}, - InnerNodeInfo, MerkleError, PartialMerkleTree, Rpo256, Vec, Word, + super::{int_to_digest, int_to_node, NodeIndex}, + PartialMerkleTree, Rpo256, }; // TEST DATA // ================================================================================================ -const ROOT_NODE: NodeIndex = NodeIndex::new_unchecked(0, 0); - -const NODE10: NodeIndex = NodeIndex::new_unchecked(1, 0); -const NODE11: NodeIndex = NodeIndex::new_unchecked(1, 1); - -const NODE20: NodeIndex = NodeIndex::new_unchecked(2, 0); -const NODE21: NodeIndex = NodeIndex::new_unchecked(2, 1); const NODE22: NodeIndex = NodeIndex::new_unchecked(2, 2); -const NODE23: NodeIndex = NodeIndex::new_unchecked(2, 3); const NODE32: NodeIndex = NodeIndex::new_unchecked(3, 2); const NODE33: NodeIndex = NodeIndex::new_unchecked(3, 3); @@ -29,29 +21,29 @@ const NODE33: NodeIndex = NodeIndex::new_unchecked(3, 3); #[test] fn get_root() { - let leaf0 = int_to_node(0); - let leaf1 = int_to_node(1); - let leaf2 = int_to_node(2); - let leaf3 = int_to_node(3); + let leaf0 = int_to_digest(0); + let leaf1 = int_to_digest(1); + let leaf2 = int_to_digest(2); + let leaf3 = int_to_digest(3); let parent0 = calculate_parent_hash(leaf0, 0, leaf1); let parent1 = calculate_parent_hash(leaf2, 2, leaf3); let root_exp = calculate_parent_hash(parent0, 0, parent1); - let set = super::PartialMerkleTree::with_paths([(NODE20, leaf0, vec![leaf1, parent1].into())]) - .unwrap(); + let set = + super::PartialMerkleTree::with_paths([(0, leaf0, vec![*leaf1, *parent1].into())]).unwrap(); assert_eq!(set.root(), root_exp); } #[test] fn add_and_get_paths() { - let value32 = int_to_node(32).into(); - let value33 = int_to_node(33).into(); - let value20 = int_to_node(20).into(); - let value22 = int_to_node(22).into(); - let value23 = int_to_node(23).into(); + let value32 = int_to_digest(32); + let value33 = int_to_digest(33); + let value20 = int_to_digest(20); + let value22 = int_to_digest(22); + let value23 = int_to_digest(23); let value21 = Rpo256::merge(&[value32, value33]); let value10 = Rpo256::merge(&[value20, value21]); @@ -62,8 +54,8 @@ fn add_and_get_paths() { let path_22 = vec![*value23, *value10]; let pmt = PartialMerkleTree::with_paths([ - (NODE33, *value33, path_33.clone().into()), - (NODE22, *value22, path_22.clone().into()), + (3, value33, path_33.clone().into()), + (2, value22, path_22.clone().into()), ]) .unwrap(); let stored_path_33 = pmt.get_path(NODE33).unwrap(); @@ -76,20 +68,20 @@ fn add_and_get_paths() { #[test] fn get_node() { let path_6 = vec![int_to_node(7), int_to_node(45), int_to_node(123)]; - let hash_6 = int_to_node(6); + let hash_6 = int_to_digest(6); let index = NodeIndex::make(3, 6); - let pmt = PartialMerkleTree::with_paths([(index, hash_6, path_6.into())]).unwrap(); + let pmt = PartialMerkleTree::with_paths([(index.value(), hash_6, path_6.into())]).unwrap(); - assert_eq!(int_to_node(6u64), *pmt.get_node(index).unwrap()); + assert_eq!(int_to_digest(6u64), pmt.get_node(index).unwrap()); } #[test] fn update_leaf() { - let value32 = int_to_node(32).into(); - let value33 = int_to_node(33).into(); - let value20 = int_to_node(20).into(); - let value22 = int_to_node(22).into(); - let value23 = int_to_node(23).into(); + let value32 = int_to_digest(32); + let value33 = int_to_digest(33); + let value20 = int_to_digest(20); + let value22 = int_to_digest(22); + let value23 = int_to_digest(23); let value21 = Rpo256::merge(&[value32, value33]); let value10 = Rpo256::merge(&[value20, value21]); @@ -99,13 +91,11 @@ fn update_leaf() { let path_22 = vec![*value23, *value10]; - let mut pmt = PartialMerkleTree::with_paths([ - (NODE33, *value33, path_33.into()), - (NODE22, *value22, path_22.into()), - ]) - .unwrap(); + let mut pmt = + PartialMerkleTree::with_paths([(3, value33, path_33.into()), (2, value22, path_22.into())]) + .unwrap(); - let new_value32 = int_to_node(132).into(); + let new_value32 = int_to_digest(132); let new_value21 = Rpo256::merge(&[new_value32, value33]); let new_value10 = Rpo256::merge(&[value20, new_value21]); let expected_root = Rpo256::merge(&[new_value10, value11]); @@ -116,83 +106,22 @@ fn update_leaf() { let new_root = pmt.root(); - assert_eq!(new_root, *expected_root); -} - -#[test] -fn test_inner_node_iterator() -> Result<(), MerkleError> { - let value32 = int_to_node(32).into(); - let value33 = int_to_node(33).into(); - let value20 = int_to_node(20).into(); - let value22 = int_to_node(22).into(); - let value23 = int_to_node(23).into(); - - let value21 = Rpo256::merge(&[value32, value33]); - let value10 = Rpo256::merge(&[value20, value21]); - let value11 = Rpo256::merge(&[value22, value23]); - let root = Rpo256::merge(&[value10, value11]); - - let path_33 = vec![*value32, *value20, *value11]; - - let path_22 = vec![*value23, *value10]; - - let pmt = PartialMerkleTree::with_paths([ - (NODE33, *value33, path_33.into()), - (NODE22, *value22, path_22.into()), - ]) - .unwrap(); - - assert_eq!(root, pmt.get_node(ROOT_NODE).unwrap()); - assert_eq!(value10, pmt.get_node(NODE10).unwrap()); - assert_eq!(value11, pmt.get_node(NODE11).unwrap()); - assert_eq!(value20, pmt.get_node(NODE20).unwrap()); - assert_eq!(value21, pmt.get_node(NODE21).unwrap()); - assert_eq!(value22, pmt.get_node(NODE22).unwrap()); - assert_eq!(value23, pmt.get_node(NODE23).unwrap()); - assert_eq!(value32, pmt.get_node(NODE32).unwrap()); - assert_eq!(value33, pmt.get_node(NODE33).unwrap()); - - let nodes: Vec = pmt.inner_nodes().collect(); - let expected = vec![ - InnerNodeInfo { - value: *root, - left: *value10, - right: *value11, - }, - InnerNodeInfo { - value: *value10, - left: *value20, - right: *value21, - }, - InnerNodeInfo { - value: *value11, - left: *value22, - right: *value23, - }, - InnerNodeInfo { - value: *value21, - left: *value32, - right: *value33, - }, - ]; - assert_eq!(nodes, expected); - - Ok(()) + assert_eq!(new_root, expected_root); } #[test] fn check_leaf_depth() { - let value32: RpoDigest = int_to_node(32).into(); - let value33: RpoDigest = int_to_node(33).into(); - let value20: RpoDigest = int_to_node(20).into(); - let value22 = int_to_node(22).into(); - let value23 = int_to_node(23).into(); + let value32 = int_to_digest(32); + let value33 = int_to_digest(33); + let value20 = int_to_digest(20); + let value22 = int_to_digest(22); + let value23 = int_to_digest(23); let value11 = Rpo256::merge(&[value22, value23]); let path_33 = vec![*value32, *value20, *value11]; - let pmt = PartialMerkleTree::with_paths([(NODE33, *value33, path_33.into())]).unwrap(); + let pmt = PartialMerkleTree::with_paths([(3, value33, path_33.into())]).unwrap(); assert_eq!(pmt.get_leaf_depth(0).unwrap(), 2); assert_eq!(pmt.get_leaf_depth(1).unwrap(), 2); @@ -211,11 +140,11 @@ fn check_leaf_depth() { /// - node — current node /// - node_pos — position of the current node /// - sibling — neighboring vertex in the tree -fn calculate_parent_hash(node: Word, node_pos: u64, sibling: Word) -> Word { +fn calculate_parent_hash(node: RpoDigest, node_pos: u64, sibling: RpoDigest) -> RpoDigest { let parity = node_pos & 1; if parity == 0 { - Rpo256::merge(&[node.into(), sibling.into()]).into() + Rpo256::merge(&[node, sibling]) } else { - Rpo256::merge(&[sibling.into(), node.into()]).into() + Rpo256::merge(&[sibling, node]) } }