diff --git a/src/merkle/simple_smt/mod.rs b/src/merkle/simple_smt/mod.rs index 9699ea9..85fec0e 100644 --- a/src/merkle/simple_smt/mod.rs +++ b/src/merkle/simple_smt/mod.rs @@ -1,5 +1,6 @@ use super::{ - BTreeMap, EmptySubtreeRoots, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, Word, + BTreeMap, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, + RpoDigest, Vec, Word, }; #[cfg(test)] @@ -15,14 +16,20 @@ pub struct SimpleSmt { depth: u8, root: Word, leaves: BTreeMap, - pub(crate) branches: BTreeMap, + branches: BTreeMap, empty_hashes: Vec, } #[derive(Debug, Default, Clone, PartialEq, Eq)] -pub(crate) struct BranchNode { - pub(crate) left: RpoDigest, - pub(crate) right: RpoDigest, +struct BranchNode { + left: RpoDigest, + right: RpoDigest, +} + +impl BranchNode { + fn parent(&self) -> RpoDigest { + Rpo256::merge(&[self.left, self.right]) + } } impl SimpleSmt { @@ -171,6 +178,15 @@ impl SimpleSmt { self.get_path(NodeIndex::new(self.depth(), key)) } + /// Iterator over the inner nodes of the [SimpleSmt]. + pub fn inner_nodes(&self) -> impl Iterator + '_ { + self.branches.values().map(|e| InnerNodeInfo { + value: e.parent().into(), + left: e.left.into(), + right: e.right.into(), + }) + } + // STATE MUTATORS // -------------------------------------------------------------------------------------------- diff --git a/src/merkle/simple_smt/tests.rs b/src/merkle/simple_smt/tests.rs index 6abd343..f7b8635 100644 --- a/src/merkle/simple_smt/tests.rs +++ b/src/merkle/simple_smt/tests.rs @@ -1,5 +1,5 @@ use super::{ - super::{int_to_node, MerkleTree, RpoDigest, SimpleSmt}, + super::{int_to_node, InnerNodeInfo, MerkleError, MerkleTree, RpoDigest, SimpleSmt}, NodeIndex, Rpo256, Vec, Word, }; use proptest::prelude::*; @@ -138,6 +138,51 @@ fn get_path() { assert_eq!(vec![node2], *tree.get_path(NodeIndex::new(1, 1)).unwrap()); } +#[test] +fn test_parent_node_iterator() -> Result<(), MerkleError> { + let tree = SimpleSmt::new(2) + .unwrap() + .with_leaves(KEYS4.into_iter().zip(VALUES4.into_iter())) + .unwrap(); + + // check depth 2 + assert_eq!(VALUES4[0], tree.get_node(&NodeIndex::new(2, 0)).unwrap()); + assert_eq!(VALUES4[1], tree.get_node(&NodeIndex::new(2, 1)).unwrap()); + assert_eq!(VALUES4[2], tree.get_node(&NodeIndex::new(2, 2)).unwrap()); + assert_eq!(VALUES4[3], tree.get_node(&NodeIndex::new(2, 3)).unwrap()); + + // get parent nodes + let root = tree.root(); + let l1n0 = tree.get_node(&NodeIndex::new(1, 0))?; + let l1n1 = tree.get_node(&NodeIndex::new(1, 1))?; + let l2n0 = tree.get_node(&NodeIndex::new(2, 0))?; + let l2n1 = tree.get_node(&NodeIndex::new(2, 1))?; + let l2n2 = tree.get_node(&NodeIndex::new(2, 2))?; + let l2n3 = tree.get_node(&NodeIndex::new(2, 3))?; + + let nodes: Vec = tree.inner_nodes().collect(); + let expected = vec![ + InnerNodeInfo { + value: root.into(), + left: l1n0.into(), + right: l1n1.into(), + }, + InnerNodeInfo { + value: l1n0.into(), + left: l2n0.into(), + right: l2n1.into(), + }, + InnerNodeInfo { + value: l1n1.into(), + left: l2n2.into(), + right: l2n3.into(), + }, + ]; + assert_eq!(nodes, expected); + + Ok(()) +} + #[test] fn update_leaf() { let mut tree = SimpleSmt::new(3) diff --git a/src/merkle/store/mod.rs b/src/merkle/store/mod.rs index 557942c..7d47912 100644 --- a/src/merkle/store/mod.rs +++ b/src/merkle/store/mod.rs @@ -281,13 +281,12 @@ impl MerkleStore { I: Iterator + ExactSizeIterator, { let smt = SimpleSmt::new(depth)?.with_leaves(entries)?; - for branch in smt.branches.values() { - let parent = Rpo256::merge(&[branch.left, branch.right]); + for node in smt.inner_nodes() { self.nodes.insert( - parent, + node.value.into(), Node { - left: branch.left, - right: branch.right, + left: node.left.into(), + right: node.right.into(), }, ); }