diff --git a/README.md b/README.md index f50fbc1..36f7bd5 100644 --- a/README.md +++ b/README.md @@ -25,5 +25,21 @@ Both of these features imply the use of [alloc](https://doc.rust-lang.org/alloc/ To compile with `no_std`, disable default features via `--no-default-features` flag. +## Testing + +You can use cargo defaults to test the library: + +```shell +cargo test +``` + +However, some of the functions are heavy and might take a while for the tests to complete. In order to test in release mode, we have to replicate the same test conditions of the development mode so all debug assertions can be verified. + +We do that by enabling some special [flags](https://doc.rust-lang.org/cargo/reference/profiles.html) for the compilation. + +```shell +RUSTFLAGS="-C debug-assertions -C overflow-checks -C debuginfo=2" cargo test --release +``` + ## License This project is [MIT licensed](./LICENSE). diff --git a/src/lib.rs b/src/lib.rs index 8cf4e3c..a68c2bf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,3 +38,32 @@ pub const ZERO: Felt = Felt::ZERO; /// Field element representing ONE in the Miden base filed. pub const ONE: Felt = Felt::ONE; + +// TESTS +// ================================================================================================ + +#[test] +#[should_panic] +fn debug_assert_is_checked() { + // enforce the release checks to always have `RUSTFLAGS="-C debug-assertions". + // + // some upstream tests are performed with `debug_assert`, and we want to assert its correctness + // downstream. + // + // for reference, check + // https://github.com/0xPolygonMiden/miden-vm/issues/433 + debug_assert!(false); +} + +#[test] +#[should_panic] +#[allow(arithmetic_overflow)] +fn overflow_panics_for_test() { + // overflows might be disabled if tests are performed in release mode. these are critical, + // mandatory checks as overflows might be attack vectors. + // + // to enable overflow checks in release mode, ensure `RUSTFLAGS="-C overflow-checks"` + let a = 1_u64; + let b = 64; + assert_ne!(a << b, 0); +} diff --git a/src/merkle/index.rs b/src/merkle/index.rs new file mode 100644 index 0000000..270bbbe --- /dev/null +++ b/src/merkle/index.rs @@ -0,0 +1,114 @@ +use super::RpoDigest; + +// NODE INDEX +// ================================================================================================ + +/// A Merkle tree address to an arbitrary node. +#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] +pub struct NodeIndex { + depth: u8, + value: u64, +} + +impl NodeIndex { + // CONSTRUCTORS + // -------------------------------------------------------------------------------------------- + + /// Creates a new node index. + pub const fn new(depth: u8, value: u64) -> Self { + Self { depth, value } + } + + /// Creates a new node index pointing to the root of the tree. + pub const fn root() -> Self { + Self { depth: 0, value: 0 } + } + + /// Mutates the instance and returns it, replacing the depth. + pub const fn with_depth(mut self, depth: u8) -> Self { + self.depth = depth; + self + } + + /// Computes the value of the sibling of the current node. + pub fn sibling(mut self) -> Self { + self.value ^= 1; + self + } + + // PROVIDERS + // -------------------------------------------------------------------------------------------- + + /// Builds a node to be used as input of a hash function when computing a Merkle path. + /// + /// Will evaluate the parity of the current instance to define the result. + pub const fn build_node(&self, slf: RpoDigest, sibling: RpoDigest) -> [RpoDigest; 2] { + if self.is_value_odd() { + [sibling, slf] + } else { + [slf, sibling] + } + } + + /// Returns the scalar representation of the depth/value pair. + /// + /// It is computed as `2^depth + value`. + pub const fn to_scalar_index(&self) -> u64 { + (1 << self.depth as u64) + self.value + } + + /// Returns the depth of the current instance. + pub const fn depth(&self) -> u8 { + self.depth + } + + /// Returns the value of the current depth. + pub const fn value(&self) -> u64 { + self.value + } + + /// Returns true if the current value fits the current depth for a binary tree. + pub const fn is_valid(&self) -> bool { + self.value < (1 << self.depth as u64) + } + + /// Returns true if the current instance points to a right sibling node. + pub const fn is_value_odd(&self) -> bool { + (self.value & 1) == 1 + } + + /// Returns `true` if the depth is `0`. + pub const fn is_root(&self) -> bool { + self.depth == 0 + } + + // STATE MUTATORS + // -------------------------------------------------------------------------------------------- + + /// Traverse one level towards the root, decrementing the depth by `1`. + pub fn move_up(&mut self) -> &mut Self { + self.depth = self.depth.saturating_sub(1); + self.value >>= 1; + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proptest::prelude::*; + + proptest! { + #[test] + fn arbitrary_index_wont_panic_on_move_up( + depth in prop::num::u8::ANY, + value in prop::num::u64::ANY, + count in prop::num::u8::ANY, + ) { + let mut index = NodeIndex::new(depth, value); + for _ in 0..count { + index.move_up(); + } + } + } +} diff --git a/src/merkle/merkle_tree.rs b/src/merkle/merkle_tree.rs index 7adbccb..01c5072 100644 --- a/src/merkle/merkle_tree.rs +++ b/src/merkle/merkle_tree.rs @@ -1,4 +1,4 @@ -use super::{Felt, MerkleError, MerklePath, Rpo256, RpoDigest, Vec, Word}; +use super::{Felt, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, Word}; use crate::{utils::uninit_vector, FieldElement}; use core::slice; use winter_math::log2; @@ -22,7 +22,7 @@ impl MerkleTree { pub fn new(leaves: Vec) -> Result { let n = leaves.len(); if n <= 1 { - return Err(MerkleError::DepthTooSmall(n as u32)); + return Err(MerkleError::DepthTooSmall(n as u8)); } else if !n.is_power_of_two() { return Err(MerkleError::NumLeavesNotPowerOfTwo(n)); } @@ -35,12 +35,14 @@ impl MerkleTree { nodes[n..].copy_from_slice(&leaves); // re-interpret nodes as an array of two nodes fused together - let two_nodes = - unsafe { slice::from_raw_parts(nodes.as_ptr() as *const [RpoDigest; 2], n) }; + // Safety: `nodes` will never move here as it is not bound to an external lifetime (i.e. + // `self`). + let ptr = nodes.as_ptr() as *const [RpoDigest; 2]; + let pairs = unsafe { slice::from_raw_parts(ptr, n) }; // calculate all internal tree nodes for i in (1..n).rev() { - nodes[i] = Rpo256::merge(&two_nodes[i]).into(); + nodes[i] = Rpo256::merge(&pairs[i]).into(); } Ok(Self { nodes }) @@ -57,53 +59,53 @@ impl MerkleTree { /// Returns the depth of this Merkle tree. /// /// Merkle tree of depth 1 has two leaves, depth 2 has four leaves etc. - pub fn depth(&self) -> u32 { - log2(self.nodes.len() / 2) + pub fn depth(&self) -> u8 { + log2(self.nodes.len() / 2) as u8 } - /// Returns a node at the specified depth and index. + /// Returns a node at the specified depth and index value. /// /// # Errors /// Returns an error if: /// * The specified depth is greater than the depth of the tree. /// * The specified index not valid for the specified depth. - pub fn get_node(&self, depth: u32, index: u64) -> Result { - if depth == 0 { - return Err(MerkleError::DepthTooSmall(depth)); - } else if depth > self.depth() { - return Err(MerkleError::DepthTooBig(depth)); - } - if index >= 2u64.pow(depth) { - return Err(MerkleError::InvalidIndex(depth, index)); + pub fn get_node(&self, index: NodeIndex) -> Result { + if index.is_root() { + return Err(MerkleError::DepthTooSmall(index.depth())); + } else if index.depth() > self.depth() { + return Err(MerkleError::DepthTooBig(index.depth())); + } else if !index.is_valid() { + return Err(MerkleError::InvalidIndex(index)); } - let pos = 2_usize.pow(depth) + (index as usize); + let pos = index.to_scalar_index() as usize; Ok(self.nodes[pos]) } - /// Returns a Merkle path to the node at the specified depth and index. The node itself is - /// not included in the path. + /// Returns a Merkle path to the node at the specified depth and index value. The node itself + /// is not included in the path. /// /// # Errors /// Returns an error if: /// * The specified depth is greater than the depth of the tree. - /// * The specified index not valid for the specified depth. - pub fn get_path(&self, depth: u32, index: u64) -> Result { - if depth == 0 { - return Err(MerkleError::DepthTooSmall(depth)); - } else if depth > self.depth() { - return Err(MerkleError::DepthTooBig(depth)); + /// * The specified value not valid for the specified depth. + pub fn get_path(&self, mut index: NodeIndex) -> Result { + if index.is_root() { + return Err(MerkleError::DepthTooSmall(index.depth())); + } else if index.depth() > self.depth() { + return Err(MerkleError::DepthTooBig(index.depth())); + } else if !index.is_valid() { + return Err(MerkleError::InvalidIndex(index)); } - if index >= 2u64.pow(depth) { - return Err(MerkleError::InvalidIndex(depth, index)); - } - - let mut path = Vec::with_capacity(depth as usize); - let mut pos = 2_usize.pow(depth) + (index as usize); - while pos > 1 { - path.push(self.nodes[pos ^ 1]); - pos >>= 1; + // TODO should we create a helper in `NodeIndex` that will encapsulate traversal to root so + // we always use inlined `for` instead of `while`? the reason to use `for` is because its + // easier for the compiler to vectorize. + let mut path = Vec::with_capacity(index.depth() as usize); + for _ in 0..index.depth() { + let sibling = index.sibling().to_scalar_index() as usize; + path.push(self.nodes[sibling]); + index.move_up(); } Ok(path.into()) @@ -112,23 +114,38 @@ impl MerkleTree { /// Replaces the leaf at the specified index with the provided value. /// /// # Errors - /// Returns an error if the specified index is not a valid leaf index for this tree. - pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<(), MerkleError> { + /// Returns an error if the specified index value is not a valid leaf value for this tree. + pub fn update_leaf<'a>(&'a mut self, index_value: u64, value: Word) -> Result<(), MerkleError> { let depth = self.depth(); - if index >= 2u64.pow(depth) { - return Err(MerkleError::InvalidIndex(depth, index)); + let mut index = NodeIndex::new(depth, index_value); + if !index.is_valid() { + return Err(MerkleError::InvalidIndex(index)); } - let mut index = 2usize.pow(depth) + index as usize; - self.nodes[index] = value; - + // we don't need to copy the pairs into a new address as we are logically guaranteed to not + // overlap write instructions. however, it's important to bind the lifetime of pairs to + // `self.nodes` so the compiler will never move one without moving the other. + debug_assert_eq!(self.nodes.len() & 1, 0); let n = self.nodes.len() / 2; - let two_nodes = - unsafe { slice::from_raw_parts(self.nodes.as_ptr() as *const [RpoDigest; 2], n) }; - for _ in 0..depth { - index /= 2; - self.nodes[index] = Rpo256::merge(&two_nodes[index]).into(); + // Safety: the length of nodes is guaranteed to contain pairs of words; hence, pairs of + // digests. we explicitly bind the lifetime here so we add an extra layer of guarantee that + // `self.nodes` will be moved only if `pairs` is moved as well. also, the algorithm is + // logically guaranteed to not overlap write positions as the write index is always half + // the index from which we read the digest input. + let ptr = self.nodes.as_ptr() as *const [RpoDigest; 2]; + let pairs: &'a [[RpoDigest; 2]] = unsafe { slice::from_raw_parts(ptr, n) }; + + // update the current node + let pos = index.to_scalar_index() as usize; + self.nodes[pos] = value; + + // traverse to the root, updating each node with the merged values of its parents + for _ in 0..index.depth() { + index.move_up(); + let pos = index.to_scalar_index() as usize; + let value = Rpo256::merge(&pairs[pos]).into(); + self.nodes[pos] = value; } Ok(()) @@ -140,10 +157,10 @@ impl MerkleTree { #[cfg(test)] mod tests { - use super::{ - super::{int_to_node, Rpo256}, - Word, - }; + use super::*; + use crate::merkle::int_to_node; + use core::mem::size_of; + use proptest::prelude::*; const LEAVES4: [Word; 4] = [ int_to_node(1), @@ -187,16 +204,16 @@ mod tests { let tree = super::MerkleTree::new(LEAVES4.to_vec()).unwrap(); // check depth 2 - assert_eq!(LEAVES4[0], tree.get_node(2, 0).unwrap()); - assert_eq!(LEAVES4[1], tree.get_node(2, 1).unwrap()); - assert_eq!(LEAVES4[2], tree.get_node(2, 2).unwrap()); - assert_eq!(LEAVES4[3], tree.get_node(2, 3).unwrap()); + assert_eq!(LEAVES4[0], tree.get_node(NodeIndex::new(2, 0)).unwrap()); + assert_eq!(LEAVES4[1], tree.get_node(NodeIndex::new(2, 1)).unwrap()); + assert_eq!(LEAVES4[2], tree.get_node(NodeIndex::new(2, 2)).unwrap()); + assert_eq!(LEAVES4[3], tree.get_node(NodeIndex::new(2, 3)).unwrap()); // check depth 1 let (_, node2, node3) = compute_internal_nodes(); - assert_eq!(node2, tree.get_node(1, 0).unwrap()); - assert_eq!(node3, tree.get_node(1, 1).unwrap()); + assert_eq!(node2, tree.get_node(NodeIndex::new(1, 0)).unwrap()); + assert_eq!(node3, tree.get_node(NodeIndex::new(1, 1)).unwrap()); } #[test] @@ -206,14 +223,26 @@ mod tests { let (_, node2, node3) = compute_internal_nodes(); // check depth 2 - assert_eq!(vec![LEAVES4[1], node3], *tree.get_path(2, 0).unwrap()); - assert_eq!(vec![LEAVES4[0], node3], *tree.get_path(2, 1).unwrap()); - assert_eq!(vec![LEAVES4[3], node2], *tree.get_path(2, 2).unwrap()); - assert_eq!(vec![LEAVES4[2], node2], *tree.get_path(2, 3).unwrap()); + assert_eq!( + vec![LEAVES4[1], node3], + *tree.get_path(NodeIndex::new(2, 0)).unwrap() + ); + assert_eq!( + vec![LEAVES4[0], node3], + *tree.get_path(NodeIndex::new(2, 1)).unwrap() + ); + assert_eq!( + vec![LEAVES4[3], node2], + *tree.get_path(NodeIndex::new(2, 2)).unwrap() + ); + assert_eq!( + vec![LEAVES4[2], node2], + *tree.get_path(NodeIndex::new(2, 3)).unwrap() + ); // check depth 1 - assert_eq!(vec![node3], *tree.get_path(1, 0).unwrap()); - assert_eq!(vec![node2], *tree.get_path(1, 1).unwrap()); + assert_eq!(vec![node3], *tree.get_path(NodeIndex::new(1, 0)).unwrap()); + assert_eq!(vec![node2], *tree.get_path(NodeIndex::new(1, 1)).unwrap()); } #[test] @@ -221,25 +250,53 @@ mod tests { let mut tree = super::MerkleTree::new(LEAVES8.to_vec()).unwrap(); // update one leaf - let index = 3; + let value = 3; let new_node = int_to_node(9); let mut expected_leaves = LEAVES8.to_vec(); - expected_leaves[index as usize] = new_node; + expected_leaves[value as usize] = new_node; let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap(); - tree.update_leaf(index, new_node).unwrap(); + tree.update_leaf(value, new_node).unwrap(); assert_eq!(expected_tree.nodes, tree.nodes); // update another leaf - let index = 6; + let value = 6; let new_node = int_to_node(10); - expected_leaves[index as usize] = new_node; + expected_leaves[value as usize] = new_node; let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap(); - tree.update_leaf(index, new_node).unwrap(); + tree.update_leaf(value, new_node).unwrap(); assert_eq!(expected_tree.nodes, tree.nodes); } + proptest! { + #[test] + fn arbitrary_word_can_be_represented_as_digest( + a in prop::num::u64::ANY, + b in prop::num::u64::ANY, + c in prop::num::u64::ANY, + d in prop::num::u64::ANY, + ) { + // this test will assert the memory equivalence between word and digest. + // it is used to safeguard the `[MerkleTee::update_leaf]` implementation + // that assumes this equivalence. + + // build a word and copy it to another address as digest + let word = [Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)]; + let digest = RpoDigest::from(word); + + // assert the addresses are different + let word_ptr = (&word).as_ptr() as *const u8; + let digest_ptr = (&digest).as_ptr() as *const u8; + assert_ne!(word_ptr, digest_ptr); + + // compare the bytes representation + let word_bytes = unsafe { slice::from_raw_parts(word_ptr, size_of::()) }; + let digest_bytes = unsafe { slice::from_raw_parts(digest_ptr, size_of::()) }; + assert_eq!(word_bytes, digest_bytes); + } + } + // HELPER FUNCTIONS // -------------------------------------------------------------------------------------------- diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index 014d4c7..04550f7 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -5,6 +5,9 @@ use super::{ }; use core::fmt; +mod index; +pub use index::NodeIndex; + mod merkle_tree; pub use merkle_tree::MerkleTree; @@ -22,11 +25,11 @@ pub use simple_smt::SimpleSmt; #[derive(Clone, Debug)] pub enum MerkleError { - DepthTooSmall(u32), - DepthTooBig(u32), + DepthTooSmall(u8), + DepthTooBig(u8), NumLeavesNotPowerOfTwo(usize), - InvalidIndex(u32, u64), - InvalidDepth(u32, u32), + InvalidIndex(NodeIndex), + InvalidDepth { expected: u8, provided: u8 }, InvalidPath(MerklePath), InvalidEntriesCount(usize, usize), NodeNotInSet(u64), @@ -41,11 +44,11 @@ impl fmt::Display for MerkleError { NumLeavesNotPowerOfTwo(leaves) => { write!(f, "the leaves count {leaves} is not a power of 2") } - InvalidIndex(depth, index) => write!( + InvalidIndex(index) => write!( f, - "the leaf index {index} is not valid for the depth {depth}" + "the index value {} is not valid for the depth {}", index.value(), index.depth() ), - InvalidDepth(expected, provided) => write!( + InvalidDepth { expected, provided } => write!( f, "the provided depth {provided} is not valid for {expected}" ), diff --git a/src/merkle/path.rs b/src/merkle/path.rs index 2b81bbe..d7edd5d 100644 --- a/src/merkle/path.rs +++ b/src/merkle/path.rs @@ -1,4 +1,4 @@ -use super::{vec, Rpo256, Vec, Word}; +use super::{vec, NodeIndex, Rpo256, Vec, Word}; use core::ops::{Deref, DerefMut}; // MERKLE PATH @@ -23,17 +23,12 @@ impl MerklePath { // -------------------------------------------------------------------------------------------- /// Computes the merkle root for this opening. - pub fn compute_root(&self, mut index: u64, node: Word) -> Word { + pub fn compute_root(&self, index_value: u64, node: Word) -> Word { + let mut index = NodeIndex::new(self.depth(), index_value); self.nodes.iter().copied().fold(node, |node, sibling| { - // build the input node, considering the parity of the current index. - let is_right_sibling = (index & 1) == 1; - let input = if is_right_sibling { - [sibling.into(), node.into()] - } else { - [node.into(), sibling.into()] - }; // compute the node and move to the next iteration. - index >>= 1; + let input = index.build_node(node.into(), sibling.into()); + index.move_up(); Rpo256::merge(&input).into() }) } diff --git a/src/merkle/path_set.rs b/src/merkle/path_set.rs index 216dff6..935929b 100644 --- a/src/merkle/path_set.rs +++ b/src/merkle/path_set.rs @@ -1,4 +1,4 @@ -use super::{BTreeMap, MerkleError, MerklePath, Rpo256, Vec, Word, ZERO}; +use super::{BTreeMap, MerkleError, MerklePath, NodeIndex, Rpo256, Vec, Word, ZERO}; // MERKLE PATH SET // ================================================================================================ @@ -7,7 +7,7 @@ use super::{BTreeMap, MerkleError, MerklePath, Rpo256, Vec, Word, ZERO}; #[derive(Debug, Clone, PartialEq, Eq)] pub struct MerklePathSet { root: Word, - total_depth: u32, + total_depth: u8, paths: BTreeMap, } @@ -16,7 +16,7 @@ impl MerklePathSet { // -------------------------------------------------------------------------------------------- /// Returns an empty MerklePathSet. - pub fn new(depth: u32) -> Result { + pub fn new(depth: u8) -> Result { let root = [ZERO; 4]; let paths = BTreeMap::new(); @@ -38,7 +38,7 @@ impl MerklePathSet { /// Returns the depth of the Merkle tree implied by the paths stored in this set. /// /// Merkle tree of depth 1 has two leaves, depth 2 has four leaves etc. - pub const fn depth(&self) -> u32 { + pub const fn depth(&self) -> u8 { self.total_depth } @@ -48,27 +48,26 @@ impl MerklePathSet { /// Returns an error if: /// * The specified index not valid for the depth of structure. /// * Requested node does not exist in the set. - pub fn get_node(&self, depth: u32, index: u64) -> Result { - if index >= 2u64.pow(self.total_depth) { - return Err(MerkleError::InvalidIndex(self.total_depth, index)); + pub fn get_node(&self, index: NodeIndex) -> Result { + if !index.with_depth(self.total_depth).is_valid() { + return Err(MerkleError::InvalidIndex( + index.with_depth(self.total_depth), + )); } - if depth != self.total_depth { - return Err(MerkleError::InvalidDepth(self.total_depth, depth)); + if index.depth() != self.total_depth { + return Err(MerkleError::InvalidDepth { + expected: self.total_depth, + provided: index.depth(), + }); } - let pos = 2u64.pow(depth) + index; - let index = pos / 2; - - match self.paths.get(&index) { - None => Err(MerkleError::NodeNotInSet(index)), - Some(path) => { - if Self::is_even(pos) { - Ok(path[0]) - } else { - Ok(path[1]) - } - } - } + let index_value = index.to_scalar_index(); + let parity = index_value & 1; + let index_value = index_value / 2; + self.paths + .get(&index_value) + .ok_or(MerkleError::NodeNotInSet(index_value)) + .map(|path| path[parity as usize]) } /// Returns a Merkle path to the node at the specified index. The node itself is @@ -78,30 +77,27 @@ impl MerklePathSet { /// Returns an error if: /// * The specified index not valid for the depth of structure. /// * Node of the requested path does not exist in the set. - pub fn get_path(&self, depth: u32, index: u64) -> Result { - if index >= 2u64.pow(self.total_depth) { - return Err(MerkleError::InvalidIndex(self.total_depth, index)); + pub fn get_path(&self, index: NodeIndex) -> Result { + if !index.with_depth(self.total_depth).is_valid() { + return Err(MerkleError::InvalidIndex(index)); } - if depth != self.total_depth { - return Err(MerkleError::InvalidDepth(self.total_depth, depth)); + if index.depth() != self.total_depth { + return Err(MerkleError::InvalidDepth { + expected: self.total_depth, + provided: index.depth(), + }); } - let pos = 2u64.pow(depth) + index; - let index = pos / 2; - - match self.paths.get(&index) { - None => Err(MerkleError::NodeNotInSet(index)), - Some(path) => { - let mut local_path = path.clone(); - if Self::is_even(pos) { - local_path.remove(0); - Ok(local_path) - } else { - local_path.remove(1); - Ok(local_path) - } - } - } + let index_value = index.to_scalar_index(); + let index = index_value / 2; + let parity = index_value & 1; + let mut path = self + .paths + .get(&index) + .cloned() + .ok_or(MerkleError::NodeNotInSet(index))?; + path.remove(parity as usize); + Ok(path) } // STATE MUTATORS @@ -118,36 +114,41 @@ impl MerklePathSet { /// different root). pub fn add_path( &mut self, - index: u64, + index_value: u64, value: Word, - path: MerklePath, + mut path: MerklePath, ) -> Result<(), MerkleError> { - let depth = (path.len() + 1) as u32; - if depth != self.total_depth { - return Err(MerkleError::InvalidDepth(self.total_depth, depth)); + let depth = (path.len() + 1) as u8; + let mut index = NodeIndex::new(depth, index_value); + if index.depth() != self.total_depth { + return Err(MerkleError::InvalidDepth { + expected: self.total_depth, + provided: index.depth(), + }); } - // Actual number of node in tree - let pos = 2u64.pow(self.total_depth) + index; - - // Index of the leaf path in map. Paths of neighboring leaves are stored in one key-value pair - let half_pos = pos / 2; + // update the current path + let index_value = index.to_scalar_index(); + let upper_index_value = index_value / 2; + let parity = index_value & 1; + path.insert(parity as usize, value); - let mut extended_path = path; - if Self::is_even(pos) { - extended_path.insert(0, value); - } else { - extended_path.insert(1, value); - } + // traverse to the root, updating the nodes + let root: Word = Rpo256::merge(&[path[0].into(), path[1].into()]).into(); + let root = path.iter().skip(2).copied().fold(root, |root, hash| { + index.move_up(); + Rpo256::merge(&index.build_node(root.into(), hash.into())).into() + }); - let root_of_current_path = Self::compute_path_root(&extended_path, depth, index); + // TODO review and document this logic if self.root == [ZERO; 4] { - self.root = root_of_current_path; - } else if self.root != root_of_current_path { - return Err(MerkleError::InvalidPath(extended_path)); + self.root = root; + } else if self.root != root { + return Err(MerkleError::InvalidPath(path)); } - self.paths.insert(half_pos, extended_path); + // finish updating the path + self.paths.insert(upper_index_value, path); Ok(()) } @@ -156,29 +157,44 @@ impl MerklePathSet { /// # Errors /// Returns an error if: /// * Requested node does not exist in the set. - pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<(), MerkleError> { + pub fn update_leaf(&mut self, base_index_value: u64, value: Word) -> Result<(), MerkleError> { let depth = self.depth(); - if index >= 2u64.pow(depth) { - return Err(MerkleError::InvalidIndex(depth, index)); + let mut index = NodeIndex::new(depth, base_index_value); + if !index.is_valid() { + return Err(MerkleError::InvalidIndex(index)); } - let pos = 2u64.pow(depth) + index; - let path = match self.paths.get_mut(&(pos / 2)) { - None => return Err(MerkleError::NodeNotInSet(index)), + let path = match self + .paths + .get_mut(&index.clone().move_up().to_scalar_index()) + { Some(path) => path, + None => return Err(MerkleError::NodeNotInSet(base_index_value)), }; // Fill old_hashes vector ----------------------------------------------------------------- - let (old_hashes, _) = Self::compute_path_trace(path, depth, index); + let mut current_index = index; + let mut old_hashes = Vec::with_capacity(path.len().saturating_sub(2)); + let mut root: Word = Rpo256::merge(&[path[0].into(), path[1].into()]).into(); + for hash in path.iter().skip(2).copied() { + old_hashes.push(root); + current_index.move_up(); + let input = current_index.build_node(hash.into(), root.into()); + root = Rpo256::merge(&input).into(); + } // Fill new_hashes vector ----------------------------------------------------------------- - if Self::is_even(pos) { - path[0] = value; - } else { - path[1] = value; + path[index.is_value_odd() as usize] = value; + + let mut new_hashes = Vec::with_capacity(path.len().saturating_sub(2)); + let mut new_root: Word = Rpo256::merge(&[path[0].into(), path[1].into()]).into(); + for path_hash in path.iter().skip(2).copied() { + new_hashes.push(new_root); + index.move_up(); + let input = current_index.build_node(path_hash.into(), new_root.into()); + new_root = Rpo256::merge(&input).into(); } - let (new_hashes, new_root) = Self::compute_path_trace(path, depth, index); self.root = new_root; // update paths --------------------------------------------------------------------------- @@ -193,59 +209,6 @@ impl MerklePathSet { Ok(()) } - - // HELPER FUNCTIONS - // -------------------------------------------------------------------------------------------- - - const fn is_even(pos: u64) -> bool { - pos & 1 == 0 - } - - /// Returns hash of the root - fn compute_path_root(path: &[Word], depth: u32, index: u64) -> Word { - let mut pos = 2u64.pow(depth) + index; - - // hash that is obtained after calculating the current hash and path hash - let mut comp_hash = Rpo256::merge(&[path[0].into(), path[1].into()]).into(); - - for path_hash in path.iter().skip(2) { - pos /= 2; - comp_hash = Self::calculate_parent_hash(comp_hash, pos, *path_hash); - } - - comp_hash - } - - /// Calculates the hash of the parent node by two sibling ones - /// - node — current node - /// - node_pos — position of the current node - /// - sibling — neighboring vertex in the tree - fn calculate_parent_hash(node: Word, node_pos: u64, sibling: Word) -> Word { - if Self::is_even(node_pos) { - Rpo256::merge(&[node.into(), sibling.into()]).into() - } else { - Rpo256::merge(&[sibling.into(), node.into()]).into() - } - } - - /// Returns vector of hashes from current to the root - fn compute_path_trace(path: &[Word], depth: u32, index: u64) -> (MerklePath, Word) { - let mut pos = 2u64.pow(depth) + index; - - let mut computed_hashes = Vec::::new(); - - let mut comp_hash = Rpo256::merge(&[path[0].into(), path[1].into()]).into(); - - if path.len() != 2 { - for path_hash in path.iter().skip(2) { - computed_hashes.push(comp_hash); - pos /= 2; - comp_hash = Self::calculate_parent_hash(comp_hash, pos, *path_hash); - } - } - - (computed_hashes.into(), comp_hash) - } } // TESTS @@ -263,10 +226,10 @@ mod tests { let leaf2 = int_to_node(2); let leaf3 = int_to_node(3); - let parent0 = MerklePathSet::calculate_parent_hash(leaf0, 0, leaf1); - let parent1 = MerklePathSet::calculate_parent_hash(leaf2, 2, leaf3); + let parent0 = calculate_parent_hash(leaf0, 0, leaf1); + let parent1 = calculate_parent_hash(leaf2, 2, leaf3); - let root_exp = MerklePathSet::calculate_parent_hash(parent0, 0, parent1); + let root_exp = calculate_parent_hash(parent0, 0, parent1); let mut set = super::MerklePathSet::new(3).unwrap(); @@ -279,29 +242,32 @@ mod tests { fn add_and_get_path() { let path_6 = vec![int_to_node(7), int_to_node(45), int_to_node(123)]; let hash_6 = int_to_node(6); - let index = 6u64; - let depth = 4u32; + let index = 6_u64; + let depth = 4_u8; let mut set = super::MerklePathSet::new(depth).unwrap(); set.add_path(index, hash_6, path_6.clone().into()).unwrap(); - let stored_path_6 = set.get_path(depth, index).unwrap(); + let stored_path_6 = set.get_path(NodeIndex::new(depth, index)).unwrap(); assert_eq!(path_6, *stored_path_6); - assert!(set.get_path(depth, 15u64).is_err()) + assert!(set.get_path(NodeIndex::new(depth, 15_u64)).is_err()) } #[test] fn get_node() { let path_6 = vec![int_to_node(7), int_to_node(45), int_to_node(123)]; let hash_6 = int_to_node(6); - let index = 6u64; - let depth = 4u32; - let mut set = super::MerklePathSet::new(depth).unwrap(); + let index = 6_u64; + let depth = 4_u8; + let mut set = MerklePathSet::new(depth).unwrap(); set.add_path(index, hash_6, path_6.into()).unwrap(); - assert_eq!(int_to_node(6u64), set.get_node(depth, index).unwrap()); - assert!(set.get_node(depth, 15u64).is_err()); + assert_eq!( + int_to_node(6u64), + set.get_node(NodeIndex::new(depth, index)).unwrap() + ); + assert!(set.get_node(NodeIndex::new(depth, 15_u64)).is_err()); } #[test] @@ -310,8 +276,8 @@ mod tests { let hash_5 = int_to_node(5); let hash_6 = int_to_node(6); let hash_7 = int_to_node(7); - let hash_45 = MerklePathSet::calculate_parent_hash(hash_4, 12u64, hash_5); - let hash_67 = MerklePathSet::calculate_parent_hash(hash_6, 14u64, hash_7); + let hash_45 = calculate_parent_hash(hash_4, 12u64, hash_5); + let hash_67 = calculate_parent_hash(hash_6, 14u64, hash_7); let hash_0123 = int_to_node(123); @@ -319,11 +285,11 @@ mod tests { let path_5 = vec![hash_4, hash_67, hash_0123]; let path_4 = vec![hash_5, hash_67, hash_0123]; - let index_6 = 6u64; - let index_5 = 5u64; - let index_4 = 4u64; - let depth = 4u32; - let mut set = super::MerklePathSet::new(depth).unwrap(); + let index_6 = 6_u64; + let index_5 = 5_u64; + let index_4 = 4_u64; + let depth = 4_u8; + let mut set = MerklePathSet::new(depth).unwrap(); set.add_path(index_6, hash_6, path_6.into()).unwrap(); set.add_path(index_5, hash_5, path_5.into()).unwrap(); @@ -333,15 +299,34 @@ mod tests { let new_hash_5 = int_to_node(55); set.update_leaf(index_6, new_hash_6).unwrap(); - let new_path_4 = set.get_path(depth, index_4).unwrap(); - let new_hash_67 = MerklePathSet::calculate_parent_hash(new_hash_6, 14u64, hash_7); + let new_path_4 = set.get_path(NodeIndex::new(depth, index_4)).unwrap(); + let new_hash_67 = calculate_parent_hash(new_hash_6, 14_u64, hash_7); assert_eq!(new_hash_67, new_path_4[1]); set.update_leaf(index_5, new_hash_5).unwrap(); - let new_path_4 = set.get_path(depth, index_4).unwrap(); - let new_path_6 = set.get_path(depth, index_6).unwrap(); - let new_hash_45 = MerklePathSet::calculate_parent_hash(new_hash_5, 13u64, hash_4); + let new_path_4 = set.get_path(NodeIndex::new(depth, index_4)).unwrap(); + let new_path_6 = set.get_path(NodeIndex::new(depth, index_6)).unwrap(); + let new_hash_45 = calculate_parent_hash(new_hash_5, 13_u64, hash_4); assert_eq!(new_hash_45, new_path_6[1]); assert_eq!(new_hash_5, new_path_4[0]); } + + // HELPER FUNCTIONS + // -------------------------------------------------------------------------------------------- + + const fn is_even(pos: u64) -> bool { + pos & 1 == 0 + } + + /// Calculates the hash of the parent node by two sibling ones + /// - node — current node + /// - node_pos — position of the current node + /// - sibling — neighboring vertex in the tree + fn calculate_parent_hash(node: Word, node_pos: u64, sibling: Word) -> Word { + if is_even(node_pos) { + Rpo256::merge(&[node.into(), sibling.into()]).into() + } else { + Rpo256::merge(&[sibling.into(), node.into()]).into() + } + } } diff --git a/src/merkle/simple_smt/mod.rs b/src/merkle/simple_smt/mod.rs index 07453d7..28623c8 100644 --- a/src/merkle/simple_smt/mod.rs +++ b/src/merkle/simple_smt/mod.rs @@ -1,4 +1,4 @@ -use super::{BTreeMap, MerkleError, MerklePath, Rpo256, RpoDigest, Vec, Word}; +use super::{BTreeMap, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, Word}; #[cfg(test)] mod tests; @@ -12,7 +12,7 @@ mod tests; #[derive(Debug, Clone, PartialEq, Eq)] pub struct SimpleSmt { root: Word, - depth: u32, + depth: u8, store: Store, } @@ -21,10 +21,10 @@ impl SimpleSmt { // -------------------------------------------------------------------------------------------- /// Minimum supported depth. - pub const MIN_DEPTH: u32 = 1; + pub const MIN_DEPTH: u8 = 1; /// Maximum supported depth. - pub const MAX_DEPTH: u32 = 63; + pub const MAX_DEPTH: u8 = 63; // CONSTRUCTORS // -------------------------------------------------------------------------------------------- @@ -37,7 +37,7 @@ impl SimpleSmt { /// /// The function will fail if the provided entries count exceed the maximum tree capacity, that /// is `2^{depth}`. - pub fn new(entries: R, depth: u32) -> Result + pub fn new(entries: R, depth: u8) -> Result where R: IntoIterator, I: Iterator + ExactSizeIterator, @@ -67,7 +67,7 @@ impl SimpleSmt { } /// Returns the depth of this Merkle tree. - pub const fn depth(&self) -> u32 { + pub const fn depth(&self) -> u8 { self.depth } @@ -82,15 +82,15 @@ impl SimpleSmt { /// Returns an error if: /// * The specified depth is greater than the depth of the tree. /// * The specified key does not exist - pub fn get_node(&self, depth: u32, key: u64) -> Result { - if depth == 0 { - Err(MerkleError::DepthTooSmall(depth)) - } else if depth > self.depth() { - Err(MerkleError::DepthTooBig(depth)) - } else if depth == self.depth() { - self.store.get_leaf_node(key) + pub fn get_node(&self, index: &NodeIndex) -> Result { + if index.is_root() { + Err(MerkleError::DepthTooSmall(index.depth())) + } else if index.depth() > self.depth() { + Err(MerkleError::DepthTooBig(index.depth())) + } else if index.depth() == self.depth() { + self.store.get_leaf_node(index.value()) } else { - let branch_node = self.store.get_branch_node(key, depth)?; + let branch_node = self.store.get_branch_node(index)?; Ok(Rpo256::merge(&[branch_node.left, branch_node.right]).into()) } } @@ -102,27 +102,23 @@ impl SimpleSmt { /// Returns an error if: /// * The specified key does not exist as a branch or leaf node /// * The specified depth is greater than the depth of the tree. - pub fn get_path(&self, depth: u32, key: u64) -> Result { - if depth == 0 { - return Err(MerkleError::DepthTooSmall(depth)); - } else if depth > self.depth() { - return Err(MerkleError::DepthTooBig(depth)); - } else if depth == self.depth() && !self.store.check_leaf_node_exists(key) { - return Err(MerkleError::InvalidIndex(self.depth(), key)); + pub fn get_path(&self, mut index: NodeIndex) -> Result { + if index.is_root() { + return Err(MerkleError::DepthTooSmall(index.depth())); + } else if index.depth() > self.depth() { + return Err(MerkleError::DepthTooBig(index.depth())); + } else if index.depth() == self.depth() && !self.store.check_leaf_node_exists(index.value()) + { + return Err(MerkleError::InvalidIndex(index.with_depth(self.depth()))); } - let mut path = Vec::with_capacity(depth as usize); - let mut curr_key = key; - for n in (0..depth).rev() { - let parent_key = curr_key >> 1; - let parent_node = self.store.get_branch_node(parent_key, n)?; - let sibling_node = if curr_key & 1 == 1 { - parent_node.left - } else { - parent_node.right - }; - path.push(sibling_node.into()); - curr_key >>= 1; + let mut path = Vec::with_capacity(index.depth() as usize); + for _ in 0..index.depth() { + let is_right = index.is_value_odd(); + index.move_up(); + let BranchNode { left, right } = self.store.get_branch_node(&index)?; + let value = if is_right { left } else { right }; + path.push(*value); } Ok(path.into()) } @@ -134,7 +130,7 @@ impl SimpleSmt { /// Returns an error if: /// * The specified key does not exist as a leaf node. pub fn get_leaf_path(&self, key: u64) -> Result { - self.get_path(self.depth(), key) + self.get_path(NodeIndex::new(self.depth(), key)) } /// Replaces the leaf located at the specified key, and recomputes hashes by walking up the tree @@ -143,7 +139,7 @@ impl SimpleSmt { /// Returns an error if the specified key is not a valid leaf index for this tree. pub fn update_leaf(&mut self, key: u64, value: Word) -> Result<(), MerkleError> { if !self.store.check_leaf_node_exists(key) { - return Err(MerkleError::InvalidIndex(self.depth(), key)); + return Err(MerkleError::InvalidIndex(NodeIndex::new(self.depth(), key))); } self.insert_leaf(key, value)?; @@ -154,27 +150,25 @@ impl SimpleSmt { pub fn insert_leaf(&mut self, key: u64, value: Word) -> Result<(), MerkleError> { self.store.insert_leaf_node(key, value); - let depth = self.depth(); - let mut curr_key = key; - let mut curr_node: RpoDigest = value.into(); - for n in (0..depth).rev() { - let parent_key = curr_key >> 1; - let parent_node = self + // TODO consider using a map `index |-> word` instead of `index |-> (word, word)` + let mut index = NodeIndex::new(self.depth(), key); + let mut value = RpoDigest::from(value); + for _ in 0..index.depth() { + let is_right = index.is_value_odd(); + index.move_up(); + let BranchNode { left, right } = self .store - .get_branch_node(parent_key, n) - .unwrap_or_else(|_| self.store.get_empty_node((n + 1) as usize)); - let (left, right) = if curr_key & 1 == 1 { - (parent_node.left, curr_node) + .get_branch_node(&index) + .unwrap_or_else(|_| self.store.get_empty_node(index.depth() as usize + 1)); + let (left, right) = if is_right { + (left, value) } else { - (curr_node, parent_node.right) + (value, right) }; - - self.store.insert_branch_node(parent_key, n, left, right); - curr_key = parent_key; - curr_node = Rpo256::merge(&[left, right]); + self.store.insert_branch_node(index, left, right); + value = Rpo256::merge(&[left, right]); } - self.root = curr_node.into(); - + self.root = value.into(); Ok(()) } } @@ -188,10 +182,10 @@ impl SimpleSmt { /// with the root hash of an empty tree, and ending with the zero value of a leaf node. #[derive(Debug, Clone, PartialEq, Eq)] struct Store { - branches: BTreeMap<(u64, u32), BranchNode>, + branches: BTreeMap, leaves: BTreeMap, empty_hashes: Vec, - depth: u32, + depth: u8, } #[derive(Debug, Default, Clone, PartialEq, Eq)] @@ -201,7 +195,7 @@ struct BranchNode { } impl Store { - fn new(depth: u32) -> (Self, Word) { + fn new(depth: u8) -> (Self, Word) { let branches = BTreeMap::new(); let leaves = BTreeMap::new(); @@ -244,23 +238,23 @@ impl Store { self.leaves .get(&key) .cloned() - .ok_or(MerkleError::InvalidIndex(self.depth, key)) + .ok_or(MerkleError::InvalidIndex(NodeIndex::new(self.depth, key))) } fn insert_leaf_node(&mut self, key: u64, node: Word) { self.leaves.insert(key, node); } - fn get_branch_node(&self, key: u64, depth: u32) -> Result { + fn get_branch_node(&self, index: &NodeIndex) -> Result { self.branches - .get(&(key, depth)) + .get(index) .cloned() - .ok_or(MerkleError::InvalidIndex(depth, key)) + .ok_or(MerkleError::InvalidIndex(*index)) } - fn insert_branch_node(&mut self, key: u64, depth: u32, left: RpoDigest, right: RpoDigest) { - let node = BranchNode { left, right }; - self.branches.insert((key, depth), node); + fn insert_branch_node(&mut self, index: NodeIndex, left: RpoDigest, right: RpoDigest) { + let branch = BranchNode { left, right }; + self.branches.insert(index, branch); } fn leaves_count(&self) -> usize { diff --git a/src/merkle/simple_smt/tests.rs b/src/merkle/simple_smt/tests.rs index 595d021..2096fd1 100644 --- a/src/merkle/simple_smt/tests.rs +++ b/src/merkle/simple_smt/tests.rs @@ -1,6 +1,6 @@ use super::{ super::{MerkleTree, RpoDigest, SimpleSmt}, - Rpo256, Vec, Word, + NodeIndex, Rpo256, Vec, Word, }; use crate::{Felt, FieldElement}; use core::iter; @@ -62,7 +62,10 @@ fn build_sparse_tree() { .expect("Failed to insert leaf"); let mt2 = MerkleTree::new(values.clone()).unwrap(); assert_eq!(mt2.root(), smt.root()); - assert_eq!(mt2.get_path(3, 6).unwrap(), smt.get_path(3, 6).unwrap()); + assert_eq!( + mt2.get_path(NodeIndex::new(3, 6)).unwrap(), + smt.get_path(NodeIndex::new(3, 6)).unwrap() + ); // insert second value at distinct leaf branch let key = 2; @@ -72,7 +75,10 @@ fn build_sparse_tree() { .expect("Failed to insert leaf"); let mt3 = MerkleTree::new(values).unwrap(); assert_eq!(mt3.root(), smt.root()); - assert_eq!(mt3.get_path(3, 2).unwrap(), smt.get_path(3, 2).unwrap()); + assert_eq!( + mt3.get_path(NodeIndex::new(3, 2)).unwrap(), + smt.get_path(NodeIndex::new(3, 2)).unwrap() + ); } #[test] @@ -81,8 +87,8 @@ fn build_full_tree() { let (root, node2, node3) = compute_internal_nodes(); assert_eq!(root, tree.root()); - assert_eq!(node2, tree.get_node(1, 0).unwrap()); - assert_eq!(node3, tree.get_node(1, 1).unwrap()); + assert_eq!(node2, tree.get_node(&NodeIndex::new(1, 0)).unwrap()); + assert_eq!(node3, tree.get_node(&NodeIndex::new(1, 1)).unwrap()); } #[test] @@ -90,10 +96,10 @@ fn get_values() { let tree = SimpleSmt::new(KEYS4.into_iter().zip(VALUES4.into_iter()), 2).unwrap(); // check depth 2 - assert_eq!(VALUES4[0], tree.get_node(2, 0).unwrap()); - assert_eq!(VALUES4[1], tree.get_node(2, 1).unwrap()); - assert_eq!(VALUES4[2], tree.get_node(2, 2).unwrap()); - assert_eq!(VALUES4[3], tree.get_node(2, 3).unwrap()); + assert_eq!(VALUES4[0], tree.get_node(&NodeIndex::new(2, 0)).unwrap()); + assert_eq!(VALUES4[1], tree.get_node(&NodeIndex::new(2, 1)).unwrap()); + assert_eq!(VALUES4[2], tree.get_node(&NodeIndex::new(2, 2)).unwrap()); + assert_eq!(VALUES4[3], tree.get_node(&NodeIndex::new(2, 3)).unwrap()); } #[test] @@ -103,14 +109,26 @@ fn get_path() { let (_, node2, node3) = compute_internal_nodes(); // check depth 2 - assert_eq!(vec![VALUES4[1], node3], *tree.get_path(2, 0).unwrap()); - assert_eq!(vec![VALUES4[0], node3], *tree.get_path(2, 1).unwrap()); - assert_eq!(vec![VALUES4[3], node2], *tree.get_path(2, 2).unwrap()); - assert_eq!(vec![VALUES4[2], node2], *tree.get_path(2, 3).unwrap()); + assert_eq!( + vec![VALUES4[1], node3], + *tree.get_path(NodeIndex::new(2, 0)).unwrap() + ); + assert_eq!( + vec![VALUES4[0], node3], + *tree.get_path(NodeIndex::new(2, 1)).unwrap() + ); + assert_eq!( + vec![VALUES4[3], node2], + *tree.get_path(NodeIndex::new(2, 2)).unwrap() + ); + assert_eq!( + vec![VALUES4[2], node2], + *tree.get_path(NodeIndex::new(2, 3)).unwrap() + ); // check depth 1 - assert_eq!(vec![node3], *tree.get_path(1, 0).unwrap()); - assert_eq!(vec![node2], *tree.get_path(1, 1).unwrap()); + assert_eq!(vec![node3], *tree.get_path(NodeIndex::new(1, 0)).unwrap()); + assert_eq!(vec![node2], *tree.get_path(NodeIndex::new(1, 1)).unwrap()); } #[test] @@ -175,7 +193,7 @@ fn small_tree_opening_is_consistent() { assert_eq!(tree.root(), Word::from(k)); - let cases: Vec<(u32, u64, Vec)> = vec![ + let cases: Vec<(u8, u64, Vec)> = vec![ (3, 0, vec![b, f, j]), (3, 1, vec![a, f, j]), (3, 4, vec![z, h, i]), @@ -189,7 +207,7 @@ fn small_tree_opening_is_consistent() { ]; for (depth, key, path) in cases { - let opening = tree.get_path(depth, key).unwrap(); + let opening = tree.get_path(NodeIndex::new(depth, key)).unwrap(); assert_eq!(path, *opening); } @@ -213,7 +231,7 @@ proptest! { // traverse to root, fetching all paths for d in 1..depth { let k = key >> (depth - d); - tree.get_path(d, k).unwrap(); + tree.get_path(NodeIndex::new(d, k)).unwrap(); } }