diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index 631b960..6f7d2ab 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -39,6 +39,9 @@ pub use store::MerkleStore; mod node; pub use node::InnerNodeInfo; +mod partial_mt; +pub use partial_mt::PartialMerkleTree; + // ERRORS // ================================================================================================ diff --git a/src/merkle/partial_mt/mod.rs b/src/merkle/partial_mt/mod.rs new file mode 100644 index 0000000..ea5fa9b --- /dev/null +++ b/src/merkle/partial_mt/mod.rs @@ -0,0 +1,270 @@ +use super::{ + BTreeMap, BTreeSet, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, + Word, EMPTY_WORD, +}; + +#[cfg(test)] +mod tests; + +// PARTIAL MERKLE TREE +// ================================================================================================ + +/// A partial Merkle tree with NodeIndex keys and 4-element RpoDigest leaf values. +/// +/// The root of the tree is recomputed on each new leaf update. +pub struct PartialMerkleTree { + root: RpoDigest, + max_depth: u8, + nodes: BTreeMap, + leaves: BTreeSet, +} + +impl Default for PartialMerkleTree { + fn default() -> Self { + Self::new() + } +} + +impl PartialMerkleTree { + // CONSTANTS + // -------------------------------------------------------------------------------------------- + + /// An RpoDigest consisting of 4 ZERO elements. + pub const EMPTY_DIGEST: RpoDigest = RpoDigest::new(EMPTY_WORD); + + /// Minimum supported depth. + pub const MIN_DEPTH: u8 = 1; + + /// Maximum supported depth. + pub const MAX_DEPTH: u8 = 64; + + // CONSTRUCTORS + // -------------------------------------------------------------------------------------------- + + /// Returns a new emply [PartialMerkleTree]. + pub fn new() -> Self { + PartialMerkleTree { + root: Self::EMPTY_DIGEST, + max_depth: 0, + nodes: BTreeMap::new(), + leaves: BTreeSet::new(), + } + } + + /// Returns a new [PartialMerkleTree] instantiated with leaves set as specified by the provided + /// entries. + /// + /// # Errors + /// Returns an error if: + /// - If the depth is 0 or is greater than 64. + /// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}. + /// - The provided entries contain multiple values for the same key. + pub fn with_leaves(entries: R) -> Result + where + R: IntoIterator, + I: Iterator + ExactSizeIterator, + { + // create an empty tree + let mut tree = PartialMerkleTree::new(); + + // check if the number of leaves can be accommodated by the tree's depth; we use a min + // depth of 63 because we consider passing in a vector of size 2^64 infeasible. + let entries = entries.into_iter(); + let max = (1_u64 << 63) as usize; + if entries.len() > max { + return Err(MerkleError::InvalidNumEntries(max, entries.len())); + } + + for (node_index, rpo_digest) in entries { + let old_value = tree.update_leaf(node_index, rpo_digest)?; + if old_value != Self::EMPTY_DIGEST { + return Err(MerkleError::DuplicateValuesForIndex(node_index.value())); + } + } + Ok(tree) + } + + // PUBLIC ACCESSORS + // -------------------------------------------------------------------------------------------- + + /// Returns the root of this Merkle tree. + pub fn root(&self) -> Word { + self.root.into() + } + + /// Returns the depth of this Merkle tree. + // TODO: maybe it's better to rename it to the `max_depth` + pub fn depth(&self) -> u8 { + self.max_depth + } + + /// Returns a node at the specified NodeIndex. + /// + /// # Errors + /// Returns an error if the specified NodeIndex is not contained in the nodes map. + pub fn get_node(&self, index: NodeIndex) -> Result { + self.nodes + .get(&index) + .ok_or(MerkleError::NodeNotInSet(index)) + .map(|hash| **hash) + } + + /// Returns a value of the leaf at the specified NodeIndex. + /// + /// # Errors + /// Returns an error if the NodeIndex is not contained in the leaves set. + pub fn get_leaf(&self, index: NodeIndex) -> Result { + if !self.leaves.contains(&index) { + // This error not really suitable in this situation, should I create a new error? + Err(MerkleError::InvalidIndex { + depth: index.depth(), + value: index.value(), + }) + } else { + self.nodes + .get(&index) + .ok_or(MerkleError::NodeNotInSet(index)) + .map(|hash| **hash) + } + } + + /// Returns a map of the all + pub fn paths(&self) -> Result, MerkleError> { + let mut paths = BTreeMap::new(); + for leaf_index in self.leaves.iter() { + let index = *leaf_index; + paths.insert(leaf_index, self.get_path(index)?); + } + Ok(paths) + } + + /// Returns a Merkle path from the node at the specified index to the root. + /// + /// The node itself is not included in the path. + /// + /// # Errors + /// Returns an error if: + /// - the specified index has depth set to 0 or the depth is greater than the depth of this + /// Merkle tree. + /// - the specified index is not contained in the nodes map. + pub fn get_path(&self, mut index: NodeIndex) -> Result { + if index.is_root() { + return Err(MerkleError::DepthTooSmall(index.depth())); + } else if index.depth() > self.depth() { + return Err(MerkleError::DepthTooBig(index.depth() as u64)); + } + + if !self.nodes.contains_key(&index) { + return Err(MerkleError::NodeNotInSet(index)); + } + + let mut path = Vec::new(); + for _ in 0..index.depth() { + let is_right = index.is_value_odd(); + let sibling_index = if is_right { + NodeIndex::new(index.depth(), index.value() - 1)? + } else { + NodeIndex::new(index.depth(), index.value() + 1)? + }; + index.move_up(); + let sibling_hash = + self.nodes.get(&sibling_index).cloned().unwrap_or(Self::EMPTY_DIGEST); + path.push(Word::from(sibling_hash)); + } + Ok(MerklePath::new(path)) + } + + // ITERATORS + // -------------------------------------------------------------------------------------------- + + /// Returns an iterator over the leaves of this [PartialMerkleTree]. + pub fn leaves(&self) -> impl Iterator { + self.nodes + .iter() + .filter(|(index, _)| self.leaves.contains(index)) + .map(|(index, hash)| (*index, &(**hash))) + } + + /// Returns an iterator over the inner nodes of this Merkle tree. + pub fn inner_nodes(&self) -> impl Iterator + '_ { + let inner_nodes = self.nodes.iter().filter(|(index, _)| !self.leaves.contains(index)); + inner_nodes.map(|(index, digest)| { + let left_index = NodeIndex::new(index.depth() + 1, index.value() * 2) + .expect("Failure to get left child index"); + let right_index = NodeIndex::new(index.depth() + 1, index.value() * 2 + 1) + .expect("Failure to get right child index"); + let left_hash = self.nodes.get(&left_index).cloned().unwrap_or(Self::EMPTY_DIGEST); + let right_hash = self.nodes.get(&right_index).cloned().unwrap_or(Self::EMPTY_DIGEST); + InnerNodeInfo { + value: **digest, + left: *left_hash, + right: *right_hash, + } + }) + } + + // STATE MUTATORS + // -------------------------------------------------------------------------------------------- + + /// Updates value of the leaf at the specified index returning the old leaf value. + /// + /// This also recomputes all hashes between the leaf and the root, updating the root itself. + pub fn update_leaf( + &mut self, + node_index: NodeIndex, + value: RpoDigest, + ) -> Result { + // check correctness of the depth and update it + Self::check_depth(node_index.depth())?; + self.update_depth(node_index.depth()); + + // insert NodeIndex to the leaves Set + self.leaves.insert(node_index); + + // add node value to the nodes Map + let old_value = self.nodes.insert(node_index, value).unwrap_or(Self::EMPTY_DIGEST); + + // if the old value and new value are the same, there is nothing to update + if value == old_value { + return Ok(value); + } + + let mut node_index = node_index; + let mut value = value; + for _ in 0..node_index.depth() { + let is_right = node_index.is_value_odd(); + let (left, right) = if is_right { + let left_index = NodeIndex::new(node_index.depth(), node_index.value() - 1)?; + (self.nodes.get(&left_index).cloned().unwrap_or(Self::EMPTY_DIGEST), value) + } else { + let right_index = NodeIndex::new(node_index.depth(), node_index.value() + 1)?; + (value, self.nodes.get(&right_index).cloned().unwrap_or(Self::EMPTY_DIGEST)) + }; + node_index.move_up(); + value = Rpo256::merge(&[left, right]); + self.nodes.insert(node_index, value); + } + + self.root = value; + Ok(old_value) + } + + // HELPER METHODS + // -------------------------------------------------------------------------------------------- + + /// Updates depth value with the maximum of current and provided depth. + fn update_depth(&mut self, new_depth: u8) { + self.max_depth = new_depth.max(self.max_depth); + } + + /// Returns an error if the depth is 0 or is greater than 64. + fn check_depth(depth: u8) -> Result<(), MerkleError> { + // validate the range of the depth. + if depth < Self::MIN_DEPTH { + return Err(MerkleError::DepthTooSmall(depth)); + } else if Self::MAX_DEPTH < depth { + return Err(MerkleError::DepthTooBig(depth as u64)); + } + Ok(()) + } +} diff --git a/src/merkle/partial_mt/tests.rs b/src/merkle/partial_mt/tests.rs new file mode 100644 index 0000000..8cf6c1f --- /dev/null +++ b/src/merkle/partial_mt/tests.rs @@ -0,0 +1,232 @@ +use super::{ + super::{int_to_node, MerkleTree, NodeIndex, RpoDigest}, + BTreeMap, InnerNodeInfo, MerkleError, PartialMerkleTree, Rpo256, Vec, Word, EMPTY_WORD, +}; + +// TEST DATA +// ================================================================================================ + +const NODE10: NodeIndex = NodeIndex::new_unchecked(1, 0); +const NODE11: NodeIndex = NodeIndex::new_unchecked(1, 1); + +const NODE20: NodeIndex = NodeIndex::new_unchecked(2, 0); +const NODE21: NodeIndex = NodeIndex::new_unchecked(2, 1); +const NODE22: NodeIndex = NodeIndex::new_unchecked(2, 2); +const NODE23: NodeIndex = NodeIndex::new_unchecked(2, 3); + +const NODE30: NodeIndex = NodeIndex::new_unchecked(3, 0); +const NODE31: NodeIndex = NodeIndex::new_unchecked(3, 1); +const NODE32: NodeIndex = NodeIndex::new_unchecked(3, 2); +const NODE34: NodeIndex = NodeIndex::new_unchecked(3, 4); +const NODE35: NodeIndex = NodeIndex::new_unchecked(3, 5); +const NODE36: NodeIndex = NodeIndex::new_unchecked(3, 6); +const NODE37: NodeIndex = NodeIndex::new_unchecked(3, 7); + +const KEYS4: [NodeIndex; 4] = [NODE20, NODE21, NODE22, NODE23]; + +const WVALUES4: [Word; 4] = [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)]; +const DVALUES4: [RpoDigest; 4] = [ + RpoDigest::new(int_to_node(1)), + RpoDigest::new(int_to_node(2)), + RpoDigest::new(int_to_node(3)), + RpoDigest::new(int_to_node(4)), +]; + +const ZERO_VALUES8: [Word; 8] = [int_to_node(0); 8]; + +// TESTS +// ================================================================================================ + +#[test] +fn build_partial_tree() { + // insert single value + let mut pmt = PartialMerkleTree::new(); + + let mut values = ZERO_VALUES8.to_vec(); + let key = NODE36; + let new_node = int_to_node(7); + values[key.value() as usize] = new_node; + + let hash0 = Rpo256::merge(&[int_to_node(0).into(), int_to_node(0).into()]); + let hash00 = Rpo256::merge(&[hash0, hash0]); + + pmt.update_leaf(NODE10, hash00).expect("Failed to update leaf"); + pmt.update_leaf(NODE22, hash0).expect("Failed to update leaf"); + let old_value = pmt.update_leaf(key, new_node.into()).expect("Failed to update leaf"); + + let mt2 = MerkleTree::new(values.clone()).unwrap(); + assert_eq!(mt2.root(), pmt.root()); + assert_eq!(mt2.get_path(NODE36).unwrap(), pmt.get_path(NODE36).unwrap()); + assert_eq!(*old_value, EMPTY_WORD); + + // insert second value at distinct leaf branch + let key = NODE32; + let new_node = int_to_node(3); + values[key.value() as usize] = new_node; + pmt.update_leaf(NODE20, hash0).expect("Failed to update leaf"); + let old_value = pmt.update_leaf(key, new_node.into()).expect("Failed to update leaf"); + let mt3 = MerkleTree::new(values).unwrap(); + assert_eq!(mt3.root(), pmt.root()); + assert_eq!(mt3.get_path(NODE32).unwrap(), pmt.get_path(NODE32).unwrap()); + assert_eq!(*old_value, EMPTY_WORD); +} + +#[test] +fn test_depth2_tree() { + let tree = PartialMerkleTree::with_leaves(KEYS4.into_iter().zip(DVALUES4.into_iter())).unwrap(); + + // check internal structure + let (root, node2, node3) = compute_internal_nodes(); + assert_eq!(root, tree.root()); + assert_eq!(node2, tree.get_node(NODE10).unwrap()); + assert_eq!(node3, tree.get_node(NODE11).unwrap()); + + // check get_node() + assert_eq!(WVALUES4[0], tree.get_node(NODE20).unwrap()); + assert_eq!(WVALUES4[1], tree.get_node(NODE21).unwrap()); + assert_eq!(WVALUES4[2], tree.get_node(NODE22).unwrap()); + assert_eq!(WVALUES4[3], tree.get_node(NODE23).unwrap()); + + // check get_path(): depth 2 + assert_eq!(vec![WVALUES4[1], node3], *tree.get_path(NODE20).unwrap()); + assert_eq!(vec![WVALUES4[0], node3], *tree.get_path(NODE21).unwrap()); + assert_eq!(vec![WVALUES4[3], node2], *tree.get_path(NODE22).unwrap()); + assert_eq!(vec![WVALUES4[2], node2], *tree.get_path(NODE23).unwrap()); + + // check get_path(): depth 1 + assert_eq!(vec![node3], *tree.get_path(NODE10).unwrap()); + assert_eq!(vec![node2], *tree.get_path(NODE11).unwrap()); +} + +#[test] +fn test_inner_node_iterator() -> Result<(), MerkleError> { + let tree = PartialMerkleTree::with_leaves(KEYS4.into_iter().zip(DVALUES4.into_iter())).unwrap(); + + // check depth 2 + assert_eq!(WVALUES4[0], tree.get_node(NODE20).unwrap()); + assert_eq!(WVALUES4[1], tree.get_node(NODE21).unwrap()); + assert_eq!(WVALUES4[2], tree.get_node(NODE22).unwrap()); + assert_eq!(WVALUES4[3], tree.get_node(NODE23).unwrap()); + + // get parent nodes + let root = tree.root(); + let l1n0 = tree.get_node(NODE10)?; + let l1n1 = tree.get_node(NODE11)?; + let l2n0 = tree.get_node(NODE20)?; + let l2n1 = tree.get_node(NODE21)?; + let l2n2 = tree.get_node(NODE22)?; + let l2n3 = tree.get_node(NODE23)?; + + let nodes: Vec = tree.inner_nodes().collect(); + let expected = vec![ + InnerNodeInfo { + value: root, + left: l1n0, + right: l1n1, + }, + InnerNodeInfo { + value: l1n0, + left: l2n0, + right: l2n1, + }, + InnerNodeInfo { + value: l1n1, + left: l2n2, + right: l2n3, + }, + ]; + assert_eq!(nodes, expected); + + Ok(()) +} + +#[test] +fn small_tree_opening_is_consistent() { + // ____k____ + // / \ + // _i_ _j_ + // / \ / \ + // e f g h + // / \ / \ / \ / \ + // a b 0 0 c 0 0 d + + let z = Word::from(RpoDigest::default()); + + let a = Word::from(Rpo256::merge(&[z.into(); 2])); + let b = Word::from(Rpo256::merge(&[a.into(); 2])); + let c = Word::from(Rpo256::merge(&[b.into(); 2])); + let d = Word::from(Rpo256::merge(&[c.into(); 2])); + + let e = Word::from(Rpo256::merge(&[a.into(), b.into()])); + let f = Word::from(Rpo256::merge(&[z.into(), z.into()])); + let g = Word::from(Rpo256::merge(&[c.into(), z.into()])); + let h = Word::from(Rpo256::merge(&[z.into(), d.into()])); + + let i = Word::from(Rpo256::merge(&[e.into(), f.into()])); + let j = Word::from(Rpo256::merge(&[g.into(), h.into()])); + + let k = Word::from(Rpo256::merge(&[i.into(), j.into()])); + + // let depth = 3; + // let entries = vec![(0, a), (1, b), (4, c), (7, d)]; + // let tree = SimpleSmt::with_leaves(depth, entries).unwrap(); + let entries = BTreeMap::from([ + (NODE30, a.into()), + (NODE31, b.into()), + (NODE34, c.into()), + (NODE37, d.into()), + (NODE21, f.into()), + ]); + + let tree = PartialMerkleTree::with_leaves(entries).unwrap(); + + assert_eq!(tree.root(), k); + + let cases: Vec<(NodeIndex, Vec)> = vec![ + (NODE30, vec![b, f, j]), + (NODE31, vec![a, f, j]), + (NODE34, vec![z, h, i]), + (NODE37, vec![z, g, i]), + (NODE20, vec![f, j]), + (NODE21, vec![e, j]), + (NODE22, vec![h, i]), + (NODE23, vec![g, i]), + (NODE10, vec![j]), + (NODE11, vec![i]), + ]; + + for (index, path) in cases { + let opening = tree.get_path(index).unwrap(); + + assert_eq!(path, *opening); + } +} + +#[test] +fn fail_on_duplicates() { + let entries = [ + (NODE31, int_to_node(1).into()), + (NODE35, int_to_node(2).into()), + (NODE31, int_to_node(3).into()), + ]; + let smt = PartialMerkleTree::with_leaves(entries); + assert!(smt.is_err()); +} + +#[test] +fn with_no_duplicates_empty_node() { + let entries = [(NODE31, int_to_node(0).into()), (NODE35, int_to_node(2).into())]; + let smt = PartialMerkleTree::with_leaves(entries); + assert!(smt.is_ok()); +} + +// HELPER FUNCTIONS +// -------------------------------------------------------------------------------------------- + +fn compute_internal_nodes() -> (Word, Word, Word) { + let node2 = Rpo256::hash_elements(&[WVALUES4[0], WVALUES4[1]].concat()); + let node3 = Rpo256::hash_elements(&[WVALUES4[2], WVALUES4[3]].concat()); + let root = Rpo256::merge(&[node2, node3]); + + (root.into(), node2.into(), node3.into()) +}