mirror of
https://github.com/arnaucube/miden-crypto.git
synced 2026-01-12 00:51:29 +01:00
feat: add merkle node index
This commit introduces a wrapper structure to encapsulate the merkle tree traversal. related issue: #36
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
use super::{Felt, MerkleError, MerklePath, Rpo256, RpoDigest, Vec, Word};
|
||||
use super::{Felt, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, Word};
|
||||
use crate::{utils::uninit_vector, FieldElement};
|
||||
use core::slice;
|
||||
use winter_math::log2;
|
||||
@@ -22,7 +22,7 @@ impl MerkleTree {
|
||||
pub fn new(leaves: Vec<Word>) -> Result<Self, MerkleError> {
|
||||
let n = leaves.len();
|
||||
if n <= 1 {
|
||||
return Err(MerkleError::DepthTooSmall(n as u32));
|
||||
return Err(MerkleError::DepthTooSmall(n as u8));
|
||||
} else if !n.is_power_of_two() {
|
||||
return Err(MerkleError::NumLeavesNotPowerOfTwo(n));
|
||||
}
|
||||
@@ -35,12 +35,14 @@ impl MerkleTree {
|
||||
nodes[n..].copy_from_slice(&leaves);
|
||||
|
||||
// re-interpret nodes as an array of two nodes fused together
|
||||
let two_nodes =
|
||||
unsafe { slice::from_raw_parts(nodes.as_ptr() as *const [RpoDigest; 2], n) };
|
||||
// Safety: `nodes` will never move here as it is not bound to an external lifetime (i.e.
|
||||
// `self`).
|
||||
let ptr = nodes.as_ptr() as *const [RpoDigest; 2];
|
||||
let pairs = unsafe { slice::from_raw_parts(ptr, n) };
|
||||
|
||||
// calculate all internal tree nodes
|
||||
for i in (1..n).rev() {
|
||||
nodes[i] = Rpo256::merge(&two_nodes[i]).into();
|
||||
nodes[i] = Rpo256::merge(&pairs[i]).into();
|
||||
}
|
||||
|
||||
Ok(Self { nodes })
|
||||
@@ -57,53 +59,53 @@ impl MerkleTree {
|
||||
/// Returns the depth of this Merkle tree.
|
||||
///
|
||||
/// Merkle tree of depth 1 has two leaves, depth 2 has four leaves etc.
|
||||
pub fn depth(&self) -> u32 {
|
||||
log2(self.nodes.len() / 2)
|
||||
pub fn depth(&self) -> u8 {
|
||||
log2(self.nodes.len() / 2) as u8
|
||||
}
|
||||
|
||||
/// Returns a node at the specified depth and index.
|
||||
/// Returns a node at the specified depth and index value.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// * The specified depth is greater than the depth of the tree.
|
||||
/// * The specified index not valid for the specified depth.
|
||||
pub fn get_node(&self, depth: u32, index: u64) -> Result<Word, MerkleError> {
|
||||
if depth == 0 {
|
||||
return Err(MerkleError::DepthTooSmall(depth));
|
||||
} else if depth > self.depth() {
|
||||
return Err(MerkleError::DepthTooBig(depth));
|
||||
}
|
||||
if index >= 2u64.pow(depth) {
|
||||
return Err(MerkleError::InvalidIndex(depth, index));
|
||||
pub fn get_node(&self, index: NodeIndex) -> Result<Word, MerkleError> {
|
||||
if index.is_root() {
|
||||
return Err(MerkleError::DepthTooSmall(index.depth()));
|
||||
} else if index.depth() > self.depth() {
|
||||
return Err(MerkleError::DepthTooBig(index.depth()));
|
||||
} else if !index.is_valid() {
|
||||
return Err(MerkleError::InvalidIndex(index));
|
||||
}
|
||||
|
||||
let pos = 2_usize.pow(depth) + (index as usize);
|
||||
let pos = index.to_scalar_index() as usize;
|
||||
Ok(self.nodes[pos])
|
||||
}
|
||||
|
||||
/// Returns a Merkle path to the node at the specified depth and index. The node itself is
|
||||
/// not included in the path.
|
||||
/// Returns a Merkle path to the node at the specified depth and index value. The node itself
|
||||
/// is not included in the path.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// * The specified depth is greater than the depth of the tree.
|
||||
/// * The specified index not valid for the specified depth.
|
||||
pub fn get_path(&self, depth: u32, index: u64) -> Result<MerklePath, MerkleError> {
|
||||
if depth == 0 {
|
||||
return Err(MerkleError::DepthTooSmall(depth));
|
||||
} else if depth > self.depth() {
|
||||
return Err(MerkleError::DepthTooBig(depth));
|
||||
}
|
||||
if index >= 2u64.pow(depth) {
|
||||
return Err(MerkleError::InvalidIndex(depth, index));
|
||||
/// * The specified value not valid for the specified depth.
|
||||
pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
|
||||
if index.is_root() {
|
||||
return Err(MerkleError::DepthTooSmall(index.depth()));
|
||||
} else if index.depth() > self.depth() {
|
||||
return Err(MerkleError::DepthTooBig(index.depth()));
|
||||
} else if !index.is_valid() {
|
||||
return Err(MerkleError::InvalidIndex(index));
|
||||
}
|
||||
|
||||
let mut path = Vec::with_capacity(depth as usize);
|
||||
let mut pos = 2_usize.pow(depth) + (index as usize);
|
||||
|
||||
while pos > 1 {
|
||||
path.push(self.nodes[pos ^ 1]);
|
||||
pos >>= 1;
|
||||
// TODO should we create a helper in `NodeIndex` that will encapsulate traversal to root so
|
||||
// we always use inlined `for` instead of `while`? the reason to use `for` is because its
|
||||
// easier for the compiler to vectorize.
|
||||
let mut path = Vec::with_capacity(index.depth() as usize);
|
||||
for _ in 0..index.depth() {
|
||||
let sibling = index.sibling().to_scalar_index() as usize;
|
||||
path.push(self.nodes[sibling]);
|
||||
index.move_up();
|
||||
}
|
||||
|
||||
Ok(path.into())
|
||||
@@ -112,23 +114,38 @@ impl MerkleTree {
|
||||
/// Replaces the leaf at the specified index with the provided value.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the specified index is not a valid leaf index for this tree.
|
||||
pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<(), MerkleError> {
|
||||
/// Returns an error if the specified index value is not a valid leaf value for this tree.
|
||||
pub fn update_leaf<'a>(&'a mut self, index_value: u64, value: Word) -> Result<(), MerkleError> {
|
||||
let depth = self.depth();
|
||||
if index >= 2u64.pow(depth) {
|
||||
return Err(MerkleError::InvalidIndex(depth, index));
|
||||
let mut index = NodeIndex::new(depth, index_value);
|
||||
if !index.is_valid() {
|
||||
return Err(MerkleError::InvalidIndex(index));
|
||||
}
|
||||
|
||||
let mut index = 2usize.pow(depth) + index as usize;
|
||||
self.nodes[index] = value;
|
||||
|
||||
// we don't need to copy the pairs into a new address as we are logically guaranteed to not
|
||||
// overlap write instructions. however, it's important to bind the lifetime of pairs to
|
||||
// `self.nodes` so the compiler will never move one without moving the other.
|
||||
debug_assert_eq!(self.nodes.len() & 1, 0);
|
||||
let n = self.nodes.len() / 2;
|
||||
let two_nodes =
|
||||
unsafe { slice::from_raw_parts(self.nodes.as_ptr() as *const [RpoDigest; 2], n) };
|
||||
|
||||
for _ in 0..depth {
|
||||
index /= 2;
|
||||
self.nodes[index] = Rpo256::merge(&two_nodes[index]).into();
|
||||
// Safety: the length of nodes is guaranteed to contain pairs of words; hence, pairs of
|
||||
// digests. we explicitly bind the lifetime here so we add an extra layer of guarantee that
|
||||
// `self.nodes` will be moved only if `pairs` is moved as well. also, the algorithm is
|
||||
// logically guaranteed to not overlap write positions as the write index is always half
|
||||
// the index from which we read the digest input.
|
||||
let ptr = self.nodes.as_ptr() as *const [RpoDigest; 2];
|
||||
let pairs: &'a [[RpoDigest; 2]] = unsafe { slice::from_raw_parts(ptr, n) };
|
||||
|
||||
// update the current node
|
||||
let pos = index.to_scalar_index() as usize;
|
||||
self.nodes[pos] = value;
|
||||
|
||||
// traverse to the root, updating each node with the merged values of its parents
|
||||
for _ in 0..index.depth() {
|
||||
index.move_up();
|
||||
let pos = index.to_scalar_index() as usize;
|
||||
let value = Rpo256::merge(&pairs[pos]).into();
|
||||
self.nodes[pos] = value;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -140,10 +157,10 @@ impl MerkleTree {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
super::{int_to_node, Rpo256},
|
||||
Word,
|
||||
};
|
||||
use super::*;
|
||||
use crate::merkle::int_to_node;
|
||||
use core::mem::size_of;
|
||||
use proptest::prelude::*;
|
||||
|
||||
const LEAVES4: [Word; 4] = [
|
||||
int_to_node(1),
|
||||
@@ -187,16 +204,16 @@ mod tests {
|
||||
let tree = super::MerkleTree::new(LEAVES4.to_vec()).unwrap();
|
||||
|
||||
// check depth 2
|
||||
assert_eq!(LEAVES4[0], tree.get_node(2, 0).unwrap());
|
||||
assert_eq!(LEAVES4[1], tree.get_node(2, 1).unwrap());
|
||||
assert_eq!(LEAVES4[2], tree.get_node(2, 2).unwrap());
|
||||
assert_eq!(LEAVES4[3], tree.get_node(2, 3).unwrap());
|
||||
assert_eq!(LEAVES4[0], tree.get_node(NodeIndex::new(2, 0)).unwrap());
|
||||
assert_eq!(LEAVES4[1], tree.get_node(NodeIndex::new(2, 1)).unwrap());
|
||||
assert_eq!(LEAVES4[2], tree.get_node(NodeIndex::new(2, 2)).unwrap());
|
||||
assert_eq!(LEAVES4[3], tree.get_node(NodeIndex::new(2, 3)).unwrap());
|
||||
|
||||
// check depth 1
|
||||
let (_, node2, node3) = compute_internal_nodes();
|
||||
|
||||
assert_eq!(node2, tree.get_node(1, 0).unwrap());
|
||||
assert_eq!(node3, tree.get_node(1, 1).unwrap());
|
||||
assert_eq!(node2, tree.get_node(NodeIndex::new(1, 0)).unwrap());
|
||||
assert_eq!(node3, tree.get_node(NodeIndex::new(1, 1)).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -206,14 +223,26 @@ mod tests {
|
||||
let (_, node2, node3) = compute_internal_nodes();
|
||||
|
||||
// check depth 2
|
||||
assert_eq!(vec![LEAVES4[1], node3], *tree.get_path(2, 0).unwrap());
|
||||
assert_eq!(vec![LEAVES4[0], node3], *tree.get_path(2, 1).unwrap());
|
||||
assert_eq!(vec![LEAVES4[3], node2], *tree.get_path(2, 2).unwrap());
|
||||
assert_eq!(vec![LEAVES4[2], node2], *tree.get_path(2, 3).unwrap());
|
||||
assert_eq!(
|
||||
vec![LEAVES4[1], node3],
|
||||
*tree.get_path(NodeIndex::new(2, 0)).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
vec![LEAVES4[0], node3],
|
||||
*tree.get_path(NodeIndex::new(2, 1)).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
vec![LEAVES4[3], node2],
|
||||
*tree.get_path(NodeIndex::new(2, 2)).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
vec![LEAVES4[2], node2],
|
||||
*tree.get_path(NodeIndex::new(2, 3)).unwrap()
|
||||
);
|
||||
|
||||
// check depth 1
|
||||
assert_eq!(vec![node3], *tree.get_path(1, 0).unwrap());
|
||||
assert_eq!(vec![node2], *tree.get_path(1, 1).unwrap());
|
||||
assert_eq!(vec![node3], *tree.get_path(NodeIndex::new(1, 0)).unwrap());
|
||||
assert_eq!(vec![node2], *tree.get_path(NodeIndex::new(1, 1)).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -221,25 +250,53 @@ mod tests {
|
||||
let mut tree = super::MerkleTree::new(LEAVES8.to_vec()).unwrap();
|
||||
|
||||
// update one leaf
|
||||
let index = 3;
|
||||
let value = 3;
|
||||
let new_node = int_to_node(9);
|
||||
let mut expected_leaves = LEAVES8.to_vec();
|
||||
expected_leaves[index as usize] = new_node;
|
||||
expected_leaves[value as usize] = new_node;
|
||||
let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap();
|
||||
|
||||
tree.update_leaf(index, new_node).unwrap();
|
||||
tree.update_leaf(value, new_node).unwrap();
|
||||
assert_eq!(expected_tree.nodes, tree.nodes);
|
||||
|
||||
// update another leaf
|
||||
let index = 6;
|
||||
let value = 6;
|
||||
let new_node = int_to_node(10);
|
||||
expected_leaves[index as usize] = new_node;
|
||||
expected_leaves[value as usize] = new_node;
|
||||
let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap();
|
||||
|
||||
tree.update_leaf(index, new_node).unwrap();
|
||||
tree.update_leaf(value, new_node).unwrap();
|
||||
assert_eq!(expected_tree.nodes, tree.nodes);
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn arbitrary_word_can_be_represented_as_digest(
|
||||
a in prop::num::u64::ANY,
|
||||
b in prop::num::u64::ANY,
|
||||
c in prop::num::u64::ANY,
|
||||
d in prop::num::u64::ANY,
|
||||
) {
|
||||
// this test will assert the memory equivalence between word and digest.
|
||||
// it is used to safeguard the `[MerkleTee::update_leaf]` implementation
|
||||
// that assumes this equivalence.
|
||||
|
||||
// build a word and copy it to another address as digest
|
||||
let word = [Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)];
|
||||
let digest = RpoDigest::from(word);
|
||||
|
||||
// assert the addresses are different
|
||||
let word_ptr = (&word).as_ptr() as *const u8;
|
||||
let digest_ptr = (&digest).as_ptr() as *const u8;
|
||||
assert_ne!(word_ptr, digest_ptr);
|
||||
|
||||
// compare the bytes representation
|
||||
let word_bytes = unsafe { slice::from_raw_parts(word_ptr, size_of::<Word>()) };
|
||||
let digest_bytes = unsafe { slice::from_raw_parts(digest_ptr, size_of::<RpoDigest>()) };
|
||||
assert_eq!(word_bytes, digest_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
// HELPER FUNCTIONS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
Reference in New Issue
Block a user