use super::{Felt, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, Word}; use crate::{utils::uninit_vector, FieldElement}; use core::slice; use winter_math::log2; // MERKLE TREE // ================================================================================================ /// A fully-balanced binary Merkle tree (i.e., a tree where the number of leaves is a power of two). #[derive(Debug, Clone, PartialEq, Eq)] pub struct MerkleTree { nodes: Vec, } impl MerkleTree { // CONSTRUCTOR // -------------------------------------------------------------------------------------------- /// Returns a Merkle tree instantiated from the provided leaves. /// /// # Errors /// Returns an error if the number of leaves is smaller than two or is not a power of two. pub fn new(leaves: Vec) -> Result { let n = leaves.len(); if n <= 1 { return Err(MerkleError::DepthTooSmall(n as u8)); } else if !n.is_power_of_two() { return Err(MerkleError::NumLeavesNotPowerOfTwo(n)); } // create un-initialized vector to hold all tree nodes let mut nodes = unsafe { uninit_vector(2 * n) }; nodes[0] = [Felt::ZERO; 4]; // copy leaves into the second part of the nodes vector nodes[n..].copy_from_slice(&leaves); // re-interpret nodes as an array of two nodes fused together // Safety: `nodes` will never move here as it is not bound to an external lifetime (i.e. // `self`). let ptr = nodes.as_ptr() as *const [RpoDigest; 2]; let pairs = unsafe { slice::from_raw_parts(ptr, n) }; // calculate all internal tree nodes for i in (1..n).rev() { nodes[i] = Rpo256::merge(&pairs[i]).into(); } Ok(Self { nodes }) } // PUBLIC ACCESSORS // -------------------------------------------------------------------------------------------- /// Returns the root of this Merkle tree. pub fn root(&self) -> Word { self.nodes[1] } /// Returns the depth of this Merkle tree. /// /// Merkle tree of depth 1 has two leaves, depth 2 has four leaves etc. pub fn depth(&self) -> u8 { log2(self.nodes.len() / 2) as u8 } /// Returns a node at the specified depth and index value. /// /// # Errors /// Returns an error if: /// * The specified depth is greater than the depth of the tree. /// * The specified index not valid for the specified depth. pub fn get_node(&self, index: NodeIndex) -> Result { if index.is_root() { return Err(MerkleError::DepthTooSmall(index.depth())); } else if index.depth() > self.depth() { return Err(MerkleError::DepthTooBig(index.depth())); } else if !index.is_valid() { return Err(MerkleError::InvalidIndex(index)); } let pos = index.to_scalar_index() as usize; Ok(self.nodes[pos]) } /// Returns a Merkle path to the node at the specified depth and index value. The node itself /// is not included in the path. /// /// # Errors /// Returns an error if: /// * The specified depth is greater than the depth of the tree. /// * The specified value not valid for the specified depth. pub fn get_path(&self, mut index: NodeIndex) -> Result { if index.is_root() { return Err(MerkleError::DepthTooSmall(index.depth())); } else if index.depth() > self.depth() { return Err(MerkleError::DepthTooBig(index.depth())); } else if !index.is_valid() { return Err(MerkleError::InvalidIndex(index)); } // TODO should we create a helper in `NodeIndex` that will encapsulate traversal to root so // we always use inlined `for` instead of `while`? the reason to use `for` is because its // easier for the compiler to vectorize. let mut path = Vec::with_capacity(index.depth() as usize); for _ in 0..index.depth() { let sibling = index.sibling().to_scalar_index() as usize; path.push(self.nodes[sibling]); index.move_up(); } Ok(path.into()) } /// Replaces the leaf at the specified index with the provided value. /// /// # Errors /// Returns an error if the specified index value is not a valid leaf value for this tree. pub fn update_leaf<'a>(&'a mut self, index_value: u64, value: Word) -> Result<(), MerkleError> { let depth = self.depth(); let mut index = NodeIndex::new(depth, index_value); if !index.is_valid() { return Err(MerkleError::InvalidIndex(index)); } // we don't need to copy the pairs into a new address as we are logically guaranteed to not // overlap write instructions. however, it's important to bind the lifetime of pairs to // `self.nodes` so the compiler will never move one without moving the other. debug_assert_eq!(self.nodes.len() & 1, 0); let n = self.nodes.len() / 2; // Safety: the length of nodes is guaranteed to contain pairs of words; hence, pairs of // digests. we explicitly bind the lifetime here so we add an extra layer of guarantee that // `self.nodes` will be moved only if `pairs` is moved as well. also, the algorithm is // logically guaranteed to not overlap write positions as the write index is always half // the index from which we read the digest input. let ptr = self.nodes.as_ptr() as *const [RpoDigest; 2]; let pairs: &'a [[RpoDigest; 2]] = unsafe { slice::from_raw_parts(ptr, n) }; // update the current node let pos = index.to_scalar_index() as usize; self.nodes[pos] = value; // traverse to the root, updating each node with the merged values of its parents for _ in 0..index.depth() { index.move_up(); let pos = index.to_scalar_index() as usize; let value = Rpo256::merge(&pairs[pos]).into(); self.nodes[pos] = value; } Ok(()) } } // TESTS // ================================================================================================ #[cfg(test)] mod tests { use super::*; use crate::merkle::int_to_node; use core::mem::size_of; use proptest::prelude::*; const LEAVES4: [Word; 4] = [ int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4), ]; const LEAVES8: [Word; 8] = [ int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4), int_to_node(5), int_to_node(6), int_to_node(7), int_to_node(8), ]; #[test] fn build_merkle_tree() { let tree = super::MerkleTree::new(LEAVES4.to_vec()).unwrap(); assert_eq!(8, tree.nodes.len()); // leaves were copied correctly for (a, b) in tree.nodes.iter().skip(4).zip(LEAVES4.iter()) { assert_eq!(a, b); } let (root, node2, node3) = compute_internal_nodes(); assert_eq!(root, tree.nodes[1]); assert_eq!(node2, tree.nodes[2]); assert_eq!(node3, tree.nodes[3]); assert_eq!(root, tree.root()); } #[test] fn get_leaf() { let tree = super::MerkleTree::new(LEAVES4.to_vec()).unwrap(); // check depth 2 assert_eq!(LEAVES4[0], tree.get_node(NodeIndex::new(2, 0)).unwrap()); assert_eq!(LEAVES4[1], tree.get_node(NodeIndex::new(2, 1)).unwrap()); assert_eq!(LEAVES4[2], tree.get_node(NodeIndex::new(2, 2)).unwrap()); assert_eq!(LEAVES4[3], tree.get_node(NodeIndex::new(2, 3)).unwrap()); // check depth 1 let (_, node2, node3) = compute_internal_nodes(); assert_eq!(node2, tree.get_node(NodeIndex::new(1, 0)).unwrap()); assert_eq!(node3, tree.get_node(NodeIndex::new(1, 1)).unwrap()); } #[test] fn get_path() { let tree = super::MerkleTree::new(LEAVES4.to_vec()).unwrap(); let (_, node2, node3) = compute_internal_nodes(); // check depth 2 assert_eq!( vec![LEAVES4[1], node3], *tree.get_path(NodeIndex::new(2, 0)).unwrap() ); assert_eq!( vec![LEAVES4[0], node3], *tree.get_path(NodeIndex::new(2, 1)).unwrap() ); assert_eq!( vec![LEAVES4[3], node2], *tree.get_path(NodeIndex::new(2, 2)).unwrap() ); assert_eq!( vec![LEAVES4[2], node2], *tree.get_path(NodeIndex::new(2, 3)).unwrap() ); // check depth 1 assert_eq!(vec![node3], *tree.get_path(NodeIndex::new(1, 0)).unwrap()); assert_eq!(vec![node2], *tree.get_path(NodeIndex::new(1, 1)).unwrap()); } #[test] fn update_leaf() { let mut tree = super::MerkleTree::new(LEAVES8.to_vec()).unwrap(); // update one leaf let value = 3; let new_node = int_to_node(9); let mut expected_leaves = LEAVES8.to_vec(); expected_leaves[value as usize] = new_node; let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap(); tree.update_leaf(value, new_node).unwrap(); assert_eq!(expected_tree.nodes, tree.nodes); // update another leaf let value = 6; let new_node = int_to_node(10); expected_leaves[value as usize] = new_node; let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap(); tree.update_leaf(value, new_node).unwrap(); assert_eq!(expected_tree.nodes, tree.nodes); } proptest! { #[test] fn arbitrary_word_can_be_represented_as_digest( a in prop::num::u64::ANY, b in prop::num::u64::ANY, c in prop::num::u64::ANY, d in prop::num::u64::ANY, ) { // this test will assert the memory equivalence between word and digest. // it is used to safeguard the `[MerkleTee::update_leaf]` implementation // that assumes this equivalence. // build a word and copy it to another address as digest let word = [Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)]; let digest = RpoDigest::from(word); // assert the addresses are different let word_ptr = (&word).as_ptr() as *const u8; let digest_ptr = (&digest).as_ptr() as *const u8; assert_ne!(word_ptr, digest_ptr); // compare the bytes representation let word_bytes = unsafe { slice::from_raw_parts(word_ptr, size_of::()) }; let digest_bytes = unsafe { slice::from_raw_parts(digest_ptr, size_of::()) }; assert_eq!(word_bytes, digest_bytes); } } // HELPER FUNCTIONS // -------------------------------------------------------------------------------------------- fn compute_internal_nodes() -> (Word, Word, Word) { let node2 = Rpo256::hash_elements(&[LEAVES4[0], LEAVES4[1]].concat()); let node3 = Rpo256::hash_elements(&[LEAVES4[2], LEAVES4[3]].concat()); let root = Rpo256::merge(&[node2, node3]); (root.into(), node2.into(), node3.into()) } }