diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index c3cbdf6..6f7d2ab 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -101,8 +101,3 @@ impl std::error::Error for MerkleError {} const fn int_to_node(value: u64) -> Word { [Felt::new(value), ZERO, ZERO, ZERO] } - -#[cfg(test)] -const fn int_to_digest(value: u64) -> RpoDigest { - RpoDigest::new([Felt::new(value), ZERO, ZERO, ZERO]) -} diff --git a/src/merkle/partial_mt/mod.rs b/src/merkle/partial_mt/mod.rs index 26eb8a3..7115557 100644 --- a/src/merkle/partial_mt/mod.rs +++ b/src/merkle/partial_mt/mod.rs @@ -22,6 +22,7 @@ const EMPTY_DIGEST: RpoDigest = RpoDigest::new(EMPTY_WORD); /// Tree allows to create Merkle Tree by providing Merkle paths of different lengths. /// /// The root of the tree is recomputed on each new leaf update. +#[derive(Debug, Clone, PartialEq, Eq)] pub struct PartialMerkleTree { max_depth: u8, nodes: BTreeMap, @@ -112,12 +113,12 @@ impl PartialMerkleTree { /// Returns a vector of paths from every leaf to the root. pub fn paths(&self) -> Vec<(NodeIndex, ValuePath)> { let mut paths = Vec::new(); - self.leaves.iter().for_each(|leaf| { + self.leaves.iter().for_each(|&leaf| { paths.push(( - *leaf, + leaf, ValuePath { - value: *self.get_node(*leaf).expect("Failed to get leaf node"), - path: self.get_path(*leaf).expect("Failed to get path"), + value: *self.get_node(leaf).expect("Failed to get leaf node"), + path: self.get_path(leaf).expect("Failed to get path"), }, )); }); @@ -160,10 +161,10 @@ impl PartialMerkleTree { /// Returns an iterator over the leaves of this [PartialMerkleTree]. pub fn leaves(&self) -> impl Iterator + '_ { - self.leaves.iter().map(|leaf| { + self.leaves.iter().map(|&leaf| { ( - *leaf, - self.get_node(*leaf).unwrap_or_else(|_| { + leaf, + self.get_node(leaf).unwrap_or_else(|_| { panic!( "Leaf with node index ({}, {}) is not in the nodes map", leaf.depth(), @@ -214,19 +215,25 @@ impl PartialMerkleTree { self.nodes.insert(index_value, node); // if the calculated node was a leaf, remove it from leaves set. - if self.leaves.contains(&index_value) { - self.leaves.remove(&index_value); - } + self.leaves.remove(&index_value); let sibling_node = index_value.sibling(); - // node became a leaf only if it is a new node (it wasn't in nodes map) - if !self.nodes.contains_key(&sibling_node) { + + // Insert node from Merkle path to the nodes map. This sibling node becomes a leaf only + // if it is a new node (it wasn't in nodes map). + // Node can be in 3 states: internal node, leaf of the tree and not a node at all. + // - Internal node can only stay in this state -- addition of a new path can't make it + // a leaf or remove it from the tree. + // - Leaf node can stay in the same state (remain a leaf) or can become an internal + // node. In the first case we don't need to do anything, and the second case is handled + // in the line 219. + // - New node can be a calculated node or a "sibling" node from a Merkle Path: + // --- Calculated node, obviously, never can be a leaf. + // --- Sibling node can be only a leaf, because otherwise it is not a new node. + if self.nodes.insert(sibling_node, hash.into()).is_none() { self.leaves.insert(sibling_node); } - // insert node from Merkle path to the nodes map - self.nodes.insert(sibling_node, hash.into()); - Rpo256::merge(&index_value.build_node(node, hash.into())) }); @@ -238,8 +245,6 @@ impl PartialMerkleTree { return Err(MerkleError::ConflictingRoots([*self.root(), *root].to_vec())); } - // self.update_leaves()?; - Ok(()) } @@ -250,7 +255,7 @@ impl PartialMerkleTree { &mut self, node_index: NodeIndex, value: RpoDigest, - ) -> Result { + ) -> Result, MerkleError> { // check correctness of the depth and update it Self::check_depth(node_index.depth())?; self.update_depth(node_index.depth()); @@ -259,38 +264,19 @@ impl PartialMerkleTree { self.leaves.insert(node_index); // add node value to the nodes Map - let old_value = self.nodes.insert(node_index, value).unwrap_or(EMPTY_DIGEST); + let old_value = self.nodes.insert(node_index, value); // if the old value and new value are the same, there is nothing to update - if value == old_value { - return Ok(value); + if old_value.is_some() && value == old_value.unwrap() { + return Ok(old_value); } let mut node_index = node_index; let mut value = value; for _ in 0..node_index.depth() { - let is_right = node_index.is_value_odd(); - let (left, right) = if is_right { - let left_index = NodeIndex::new(node_index.depth(), node_index.value() - 1)?; - ( - self.nodes - .get(&left_index) - .cloned() - .ok_or(MerkleError::NodeNotInSet(left_index))?, - value, - ) - } else { - let right_index = NodeIndex::new(node_index.depth(), node_index.value() + 1)?; - ( - value, - self.nodes - .get(&right_index) - .cloned() - .ok_or(MerkleError::NodeNotInSet(right_index))?, - ) - }; + let sibling = self.nodes.get(&node_index.sibling()).expect("sibling should exist"); + value = Rpo256::merge(&node_index.build_node(value, *sibling)); node_index.move_up(); - value = Rpo256::merge(&[left, right]); self.nodes.insert(node_index, value); } diff --git a/src/merkle/partial_mt/tests.rs b/src/merkle/partial_mt/tests.rs index 35d41ef..c612f66 100644 --- a/src/merkle/partial_mt/tests.rs +++ b/src/merkle/partial_mt/tests.rs @@ -1,18 +1,29 @@ -use crate::hash::rpo::RpoDigest; - use super::{ - super::{int_to_digest, int_to_node, NodeIndex}, - PartialMerkleTree, Rpo256, + super::{int_to_node, MerkleStore, MerkleTree, NodeIndex, PartialMerkleTree}, + Word, }; // TEST DATA // ================================================================================================ +const NODE10: NodeIndex = NodeIndex::new_unchecked(1, 0); + const NODE22: NodeIndex = NodeIndex::new_unchecked(2, 2); const NODE32: NodeIndex = NodeIndex::new_unchecked(3, 2); const NODE33: NodeIndex = NodeIndex::new_unchecked(3, 3); +const VALUES8: [Word; 8] = [ + int_to_node(1), + int_to_node(2), + int_to_node(3), + int_to_node(4), + int_to_node(5), + int_to_node(6), + int_to_node(7), + int_to_node(8), +]; + // TESTS // ================================================================================================ @@ -21,107 +32,92 @@ const NODE33: NodeIndex = NodeIndex::new_unchecked(3, 3); #[test] fn get_root() { - let leaf0 = int_to_digest(0); - let leaf1 = int_to_digest(1); - let leaf2 = int_to_digest(2); - let leaf3 = int_to_digest(3); + let mt = MerkleTree::new(VALUES8.to_vec()).unwrap(); + let expected_root = mt.root(); - let parent0 = calculate_parent_hash(leaf0, 0, leaf1); - let parent1 = calculate_parent_hash(leaf2, 2, leaf3); + let mut store = MerkleStore::new(); + let ms = MerkleStore::extend(&mut store, mt.inner_nodes()); - let root_exp = calculate_parent_hash(parent0, 0, parent1); + let path33 = ms.get_path(expected_root, NODE33).unwrap(); - let set = - super::PartialMerkleTree::with_paths([(0, leaf0, vec![*leaf1, *parent1].into())]).unwrap(); + let pmt = PartialMerkleTree::with_paths([(3_u64, path33.value.into(), path33.path)]).unwrap(); - assert_eq!(set.root(), root_exp); + assert_eq!(pmt.root(), expected_root.into()); } #[test] fn add_and_get_paths() { - let value32 = int_to_digest(32); - let value33 = int_to_digest(33); - let value20 = int_to_digest(20); - let value22 = int_to_digest(22); - let value23 = int_to_digest(23); - - let value21 = Rpo256::merge(&[value32, value33]); - let value10 = Rpo256::merge(&[value20, value21]); - let value11 = Rpo256::merge(&[value22, value23]); + let mt = MerkleTree::new(VALUES8.to_vec()).unwrap(); + let expected_root = mt.root(); - let path_33 = vec![*value32, *value20, *value11]; + let mut store = MerkleStore::new(); + let ms = MerkleStore::extend(&mut store, mt.inner_nodes()); - let path_22 = vec![*value23, *value10]; + let expected_path33 = ms.get_path(expected_root, NODE33).unwrap(); + let expected_path22 = ms.get_path(expected_root, NODE22).unwrap(); let pmt = PartialMerkleTree::with_paths([ - (3, value33, path_33.clone().into()), - (2, value22, path_22.clone().into()), + (3_u64, expected_path33.value.into(), expected_path33.path.clone()), + (2, expected_path22.value.into(), expected_path22.path.clone()), ]) .unwrap(); - let stored_path_33 = pmt.get_path(NODE33).unwrap(); - let stored_path_22 = pmt.get_path(NODE22).unwrap(); - assert_eq!(path_33, *stored_path_33); - assert_eq!(path_22, *stored_path_22); + let path33 = pmt.get_path(NODE33).unwrap(); + let path22 = pmt.get_path(NODE22).unwrap(); + + assert_eq!(expected_path33.path, path33); + assert_eq!(expected_path22.path, path22); } #[test] fn get_node() { - let path_6 = vec![int_to_node(7), int_to_node(45), int_to_node(123)]; - let hash_6 = int_to_digest(6); - let index = NodeIndex::make(3, 6); - let pmt = PartialMerkleTree::with_paths([(index.value(), hash_6, path_6.into())]).unwrap(); + let mt = MerkleTree::new(VALUES8.to_vec()).unwrap(); + let expected_root = mt.root(); + + let mut store = MerkleStore::new(); + let ms = MerkleStore::extend(&mut store, mt.inner_nodes()); + + let path33 = ms.get_path(expected_root, NODE33).unwrap(); - assert_eq!(int_to_digest(6u64), pmt.get_node(index).unwrap()); + let pmt = PartialMerkleTree::with_paths([(3_u64, path33.value.into(), path33.path)]).unwrap(); + + assert_eq!(ms.get_node(expected_root, NODE32).unwrap(), *pmt.get_node(NODE32).unwrap()); + assert_eq!(ms.get_node(expected_root, NODE10).unwrap(), *pmt.get_node(NODE10).unwrap()); } #[test] fn update_leaf() { - let value32 = int_to_digest(32); - let value33 = int_to_digest(33); - let value20 = int_to_digest(20); - let value22 = int_to_digest(22); - let value23 = int_to_digest(23); - - let value21 = Rpo256::merge(&[value32, value33]); - let value10 = Rpo256::merge(&[value20, value21]); - let value11 = Rpo256::merge(&[value22, value23]); + let mut mt = MerkleTree::new(VALUES8.to_vec()).unwrap(); + let root = mt.root(); - let path_33 = vec![*value32, *value20, *value11]; - - let path_22 = vec![*value23, *value10]; + let mut store = MerkleStore::new(); + let ms = MerkleStore::extend(&mut store, mt.inner_nodes()); + let path33 = ms.get_path(root, NODE33).unwrap(); let mut pmt = - PartialMerkleTree::with_paths([(3, value33, path_33.into()), (2, value22, path_22.into())]) - .unwrap(); - - let new_value32 = int_to_digest(132); - let new_value21 = Rpo256::merge(&[new_value32, value33]); - let new_value10 = Rpo256::merge(&[value20, new_value21]); - let expected_root = Rpo256::merge(&[new_value10, value11]); + PartialMerkleTree::with_paths([(3_u64, path33.value.into(), path33.path)]).unwrap(); - let old_leaf = pmt.update_leaf(NODE32, new_value32).unwrap(); + let new_value32 = int_to_node(132); + mt.update_leaf(2_u64, new_value32).unwrap(); + let expected_root = mt.root(); - assert_eq!(value32, old_leaf); + pmt.update_leaf(NODE32, new_value32.into()).unwrap(); + let actual_root = pmt.root(); - let new_root = pmt.root(); - - assert_eq!(new_root, expected_root); + assert_eq!(expected_root, *actual_root); } #[test] fn check_leaf_depth() { - let value32 = int_to_digest(32); - let value33 = int_to_digest(33); - let value20 = int_to_digest(20); - let value22 = int_to_digest(22); - let value23 = int_to_digest(23); + let mt = MerkleTree::new(VALUES8.to_vec()).unwrap(); + let expected_root = mt.root(); - let value11 = Rpo256::merge(&[value22, value23]); + let mut store = MerkleStore::new(); + let ms = MerkleStore::extend(&mut store, mt.inner_nodes()); - let path_33 = vec![*value32, *value20, *value11]; + let path33 = ms.get_path(expected_root, NODE33).unwrap(); - let pmt = PartialMerkleTree::with_paths([(3, value33, path_33.into())]).unwrap(); + let pmt = PartialMerkleTree::with_paths([(3_u64, path33.value.into(), path33.path)]).unwrap(); assert_eq!(pmt.get_leaf_depth(0).unwrap(), 2); assert_eq!(pmt.get_leaf_depth(1).unwrap(), 2); @@ -131,23 +127,8 @@ fn check_leaf_depth() { assert_eq!(pmt.get_leaf_depth(5).unwrap(), 1); assert_eq!(pmt.get_leaf_depth(6).unwrap(), 1); assert_eq!(pmt.get_leaf_depth(7).unwrap(), 1); + assert!(pmt.get_leaf_depth(8).is_err()); } // TODO: add test for add_path function and check correctness of leaf determination (requires // inner_nodes iter) - -// HELPER FUNCTIONS -// -------------------------------------------------------------------------------------------- - -/// Calculates the hash of the parent node by two sibling ones -/// - node — current node -/// - node_pos — position of the current node -/// - sibling — neighboring vertex in the tree -fn calculate_parent_hash(node: RpoDigest, node_pos: u64, sibling: RpoDigest) -> RpoDigest { - let parity = node_pos & 1; - if parity == 0 { - Rpo256::merge(&[node, sibling]) - } else { - Rpo256::merge(&[sibling, node]) - } -}