feat: add merkle node index

This commit introduces a wrapper structure to encapsulate the merkle
tree traversal.

related issue: #36
This commit is contained in:
Victor Lopez
2023-02-11 12:50:52 +01:00
parent 0c242d2c51
commit 0799b1bb9d
9 changed files with 531 additions and 320 deletions

View File

@@ -38,3 +38,32 @@ pub const ZERO: Felt = Felt::ZERO;
/// Field element representing ONE in the Miden base filed.
pub const ONE: Felt = Felt::ONE;
// TESTS
// ================================================================================================
#[test]
#[should_panic]
fn debug_assert_is_checked() {
// enforce the release checks to always have `RUSTFLAGS="-C debug-assertions".
//
// some upstream tests are performed with `debug_assert`, and we want to assert its correctness
// downstream.
//
// for reference, check
// https://github.com/0xPolygonMiden/miden-vm/issues/433
debug_assert!(false);
}
#[test]
#[should_panic]
#[allow(arithmetic_overflow)]
fn overflow_panics_for_test() {
// overflows might be disabled if tests are performed in release mode. these are critical,
// mandatory checks as overflows might be attack vectors.
//
// to enable overflow checks in release mode, ensure `RUSTFLAGS="-C overflow-checks"`
let a = 1_u64;
let b = 64;
assert_ne!(a << b, 0);
}

114
src/merkle/index.rs Normal file
View File

@@ -0,0 +1,114 @@
use super::RpoDigest;
// NODE INDEX
// ================================================================================================
/// A Merkle tree address to an arbitrary node.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub struct NodeIndex {
depth: u8,
value: u64,
}
impl NodeIndex {
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
/// Creates a new node index.
pub const fn new(depth: u8, value: u64) -> Self {
Self { depth, value }
}
/// Creates a new node index pointing to the root of the tree.
pub const fn root() -> Self {
Self { depth: 0, value: 0 }
}
/// Mutates the instance and returns it, replacing the depth.
pub const fn with_depth(mut self, depth: u8) -> Self {
self.depth = depth;
self
}
/// Computes the value of the sibling of the current node.
pub fn sibling(mut self) -> Self {
self.value ^= 1;
self
}
// PROVIDERS
// --------------------------------------------------------------------------------------------
/// Builds a node to be used as input of a hash function when computing a Merkle path.
///
/// Will evaluate the parity of the current instance to define the result.
pub const fn build_node(&self, slf: RpoDigest, sibling: RpoDigest) -> [RpoDigest; 2] {
if self.is_value_odd() {
[sibling, slf]
} else {
[slf, sibling]
}
}
/// Returns the scalar representation of the depth/value pair.
///
/// It is computed as `2^depth + value`.
pub const fn to_scalar_index(&self) -> u64 {
(1 << self.depth as u64) + self.value
}
/// Returns the depth of the current instance.
pub const fn depth(&self) -> u8 {
self.depth
}
/// Returns the value of the current depth.
pub const fn value(&self) -> u64 {
self.value
}
/// Returns true if the current value fits the current depth for a binary tree.
pub const fn is_valid(&self) -> bool {
self.value < (1 << self.depth as u64)
}
/// Returns true if the current instance points to a right sibling node.
pub const fn is_value_odd(&self) -> bool {
(self.value & 1) == 1
}
/// Returns `true` if the depth is `0`.
pub const fn is_root(&self) -> bool {
self.depth == 0
}
// STATE MUTATORS
// --------------------------------------------------------------------------------------------
/// Traverse one level towards the root, decrementing the depth by `1`.
pub fn move_up(&mut self) -> &mut Self {
self.depth = self.depth.saturating_sub(1);
self.value >>= 1;
self
}
}
#[cfg(test)]
mod tests {
use super::*;
use proptest::prelude::*;
proptest! {
#[test]
fn arbitrary_index_wont_panic_on_move_up(
depth in prop::num::u8::ANY,
value in prop::num::u64::ANY,
count in prop::num::u8::ANY,
) {
let mut index = NodeIndex::new(depth, value);
for _ in 0..count {
index.move_up();
}
}
}
}

View File

@@ -1,4 +1,4 @@
use super::{Felt, MerkleError, MerklePath, Rpo256, RpoDigest, Vec, Word};
use super::{Felt, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, Word};
use crate::{utils::uninit_vector, FieldElement};
use core::slice;
use winter_math::log2;
@@ -22,7 +22,7 @@ impl MerkleTree {
pub fn new(leaves: Vec<Word>) -> Result<Self, MerkleError> {
let n = leaves.len();
if n <= 1 {
return Err(MerkleError::DepthTooSmall(n as u32));
return Err(MerkleError::DepthTooSmall(n as u8));
} else if !n.is_power_of_two() {
return Err(MerkleError::NumLeavesNotPowerOfTwo(n));
}
@@ -35,12 +35,14 @@ impl MerkleTree {
nodes[n..].copy_from_slice(&leaves);
// re-interpret nodes as an array of two nodes fused together
let two_nodes =
unsafe { slice::from_raw_parts(nodes.as_ptr() as *const [RpoDigest; 2], n) };
// Safety: `nodes` will never move here as it is not bound to an external lifetime (i.e.
// `self`).
let ptr = nodes.as_ptr() as *const [RpoDigest; 2];
let pairs = unsafe { slice::from_raw_parts(ptr, n) };
// calculate all internal tree nodes
for i in (1..n).rev() {
nodes[i] = Rpo256::merge(&two_nodes[i]).into();
nodes[i] = Rpo256::merge(&pairs[i]).into();
}
Ok(Self { nodes })
@@ -57,53 +59,53 @@ impl MerkleTree {
/// Returns the depth of this Merkle tree.
///
/// Merkle tree of depth 1 has two leaves, depth 2 has four leaves etc.
pub fn depth(&self) -> u32 {
log2(self.nodes.len() / 2)
pub fn depth(&self) -> u8 {
log2(self.nodes.len() / 2) as u8
}
/// Returns a node at the specified depth and index.
/// Returns a node at the specified depth and index value.
///
/// # Errors
/// Returns an error if:
/// * The specified depth is greater than the depth of the tree.
/// * The specified index not valid for the specified depth.
pub fn get_node(&self, depth: u32, index: u64) -> Result<Word, MerkleError> {
if depth == 0 {
return Err(MerkleError::DepthTooSmall(depth));
} else if depth > self.depth() {
return Err(MerkleError::DepthTooBig(depth));
}
if index >= 2u64.pow(depth) {
return Err(MerkleError::InvalidIndex(depth, index));
pub fn get_node(&self, index: NodeIndex) -> Result<Word, MerkleError> {
if index.is_root() {
return Err(MerkleError::DepthTooSmall(index.depth()));
} else if index.depth() > self.depth() {
return Err(MerkleError::DepthTooBig(index.depth()));
} else if !index.is_valid() {
return Err(MerkleError::InvalidIndex(index));
}
let pos = 2_usize.pow(depth) + (index as usize);
let pos = index.to_scalar_index() as usize;
Ok(self.nodes[pos])
}
/// Returns a Merkle path to the node at the specified depth and index. The node itself is
/// not included in the path.
/// Returns a Merkle path to the node at the specified depth and index value. The node itself
/// is not included in the path.
///
/// # Errors
/// Returns an error if:
/// * The specified depth is greater than the depth of the tree.
/// * The specified index not valid for the specified depth.
pub fn get_path(&self, depth: u32, index: u64) -> Result<MerklePath, MerkleError> {
if depth == 0 {
return Err(MerkleError::DepthTooSmall(depth));
} else if depth > self.depth() {
return Err(MerkleError::DepthTooBig(depth));
}
if index >= 2u64.pow(depth) {
return Err(MerkleError::InvalidIndex(depth, index));
/// * The specified value not valid for the specified depth.
pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
if index.is_root() {
return Err(MerkleError::DepthTooSmall(index.depth()));
} else if index.depth() > self.depth() {
return Err(MerkleError::DepthTooBig(index.depth()));
} else if !index.is_valid() {
return Err(MerkleError::InvalidIndex(index));
}
let mut path = Vec::with_capacity(depth as usize);
let mut pos = 2_usize.pow(depth) + (index as usize);
while pos > 1 {
path.push(self.nodes[pos ^ 1]);
pos >>= 1;
// TODO should we create a helper in `NodeIndex` that will encapsulate traversal to root so
// we always use inlined `for` instead of `while`? the reason to use `for` is because its
// easier for the compiler to vectorize.
let mut path = Vec::with_capacity(index.depth() as usize);
for _ in 0..index.depth() {
let sibling = index.sibling().to_scalar_index() as usize;
path.push(self.nodes[sibling]);
index.move_up();
}
Ok(path.into())
@@ -112,23 +114,38 @@ impl MerkleTree {
/// Replaces the leaf at the specified index with the provided value.
///
/// # Errors
/// Returns an error if the specified index is not a valid leaf index for this tree.
pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<(), MerkleError> {
/// Returns an error if the specified index value is not a valid leaf value for this tree.
pub fn update_leaf<'a>(&'a mut self, index_value: u64, value: Word) -> Result<(), MerkleError> {
let depth = self.depth();
if index >= 2u64.pow(depth) {
return Err(MerkleError::InvalidIndex(depth, index));
let mut index = NodeIndex::new(depth, index_value);
if !index.is_valid() {
return Err(MerkleError::InvalidIndex(index));
}
let mut index = 2usize.pow(depth) + index as usize;
self.nodes[index] = value;
// we don't need to copy the pairs into a new address as we are logically guaranteed to not
// overlap write instructions. however, it's important to bind the lifetime of pairs to
// `self.nodes` so the compiler will never move one without moving the other.
debug_assert_eq!(self.nodes.len() & 1, 0);
let n = self.nodes.len() / 2;
let two_nodes =
unsafe { slice::from_raw_parts(self.nodes.as_ptr() as *const [RpoDigest; 2], n) };
for _ in 0..depth {
index /= 2;
self.nodes[index] = Rpo256::merge(&two_nodes[index]).into();
// Safety: the length of nodes is guaranteed to contain pairs of words; hence, pairs of
// digests. we explicitly bind the lifetime here so we add an extra layer of guarantee that
// `self.nodes` will be moved only if `pairs` is moved as well. also, the algorithm is
// logically guaranteed to not overlap write positions as the write index is always half
// the index from which we read the digest input.
let ptr = self.nodes.as_ptr() as *const [RpoDigest; 2];
let pairs: &'a [[RpoDigest; 2]] = unsafe { slice::from_raw_parts(ptr, n) };
// update the current node
let pos = index.to_scalar_index() as usize;
self.nodes[pos] = value;
// traverse to the root, updating each node with the merged values of its parents
for _ in 0..index.depth() {
index.move_up();
let pos = index.to_scalar_index() as usize;
let value = Rpo256::merge(&pairs[pos]).into();
self.nodes[pos] = value;
}
Ok(())
@@ -140,10 +157,10 @@ impl MerkleTree {
#[cfg(test)]
mod tests {
use super::{
super::{int_to_node, Rpo256},
Word,
};
use super::*;
use crate::merkle::int_to_node;
use core::mem::size_of;
use proptest::prelude::*;
const LEAVES4: [Word; 4] = [
int_to_node(1),
@@ -187,16 +204,16 @@ mod tests {
let tree = super::MerkleTree::new(LEAVES4.to_vec()).unwrap();
// check depth 2
assert_eq!(LEAVES4[0], tree.get_node(2, 0).unwrap());
assert_eq!(LEAVES4[1], tree.get_node(2, 1).unwrap());
assert_eq!(LEAVES4[2], tree.get_node(2, 2).unwrap());
assert_eq!(LEAVES4[3], tree.get_node(2, 3).unwrap());
assert_eq!(LEAVES4[0], tree.get_node(NodeIndex::new(2, 0)).unwrap());
assert_eq!(LEAVES4[1], tree.get_node(NodeIndex::new(2, 1)).unwrap());
assert_eq!(LEAVES4[2], tree.get_node(NodeIndex::new(2, 2)).unwrap());
assert_eq!(LEAVES4[3], tree.get_node(NodeIndex::new(2, 3)).unwrap());
// check depth 1
let (_, node2, node3) = compute_internal_nodes();
assert_eq!(node2, tree.get_node(1, 0).unwrap());
assert_eq!(node3, tree.get_node(1, 1).unwrap());
assert_eq!(node2, tree.get_node(NodeIndex::new(1, 0)).unwrap());
assert_eq!(node3, tree.get_node(NodeIndex::new(1, 1)).unwrap());
}
#[test]
@@ -206,14 +223,26 @@ mod tests {
let (_, node2, node3) = compute_internal_nodes();
// check depth 2
assert_eq!(vec![LEAVES4[1], node3], *tree.get_path(2, 0).unwrap());
assert_eq!(vec![LEAVES4[0], node3], *tree.get_path(2, 1).unwrap());
assert_eq!(vec![LEAVES4[3], node2], *tree.get_path(2, 2).unwrap());
assert_eq!(vec![LEAVES4[2], node2], *tree.get_path(2, 3).unwrap());
assert_eq!(
vec![LEAVES4[1], node3],
*tree.get_path(NodeIndex::new(2, 0)).unwrap()
);
assert_eq!(
vec![LEAVES4[0], node3],
*tree.get_path(NodeIndex::new(2, 1)).unwrap()
);
assert_eq!(
vec![LEAVES4[3], node2],
*tree.get_path(NodeIndex::new(2, 2)).unwrap()
);
assert_eq!(
vec![LEAVES4[2], node2],
*tree.get_path(NodeIndex::new(2, 3)).unwrap()
);
// check depth 1
assert_eq!(vec![node3], *tree.get_path(1, 0).unwrap());
assert_eq!(vec![node2], *tree.get_path(1, 1).unwrap());
assert_eq!(vec![node3], *tree.get_path(NodeIndex::new(1, 0)).unwrap());
assert_eq!(vec![node2], *tree.get_path(NodeIndex::new(1, 1)).unwrap());
}
#[test]
@@ -221,25 +250,53 @@ mod tests {
let mut tree = super::MerkleTree::new(LEAVES8.to_vec()).unwrap();
// update one leaf
let index = 3;
let value = 3;
let new_node = int_to_node(9);
let mut expected_leaves = LEAVES8.to_vec();
expected_leaves[index as usize] = new_node;
expected_leaves[value as usize] = new_node;
let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap();
tree.update_leaf(index, new_node).unwrap();
tree.update_leaf(value, new_node).unwrap();
assert_eq!(expected_tree.nodes, tree.nodes);
// update another leaf
let index = 6;
let value = 6;
let new_node = int_to_node(10);
expected_leaves[index as usize] = new_node;
expected_leaves[value as usize] = new_node;
let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap();
tree.update_leaf(index, new_node).unwrap();
tree.update_leaf(value, new_node).unwrap();
assert_eq!(expected_tree.nodes, tree.nodes);
}
proptest! {
#[test]
fn arbitrary_word_can_be_represented_as_digest(
a in prop::num::u64::ANY,
b in prop::num::u64::ANY,
c in prop::num::u64::ANY,
d in prop::num::u64::ANY,
) {
// this test will assert the memory equivalence between word and digest.
// it is used to safeguard the `[MerkleTee::update_leaf]` implementation
// that assumes this equivalence.
// build a word and copy it to another address as digest
let word = [Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)];
let digest = RpoDigest::from(word);
// assert the addresses are different
let word_ptr = (&word).as_ptr() as *const u8;
let digest_ptr = (&digest).as_ptr() as *const u8;
assert_ne!(word_ptr, digest_ptr);
// compare the bytes representation
let word_bytes = unsafe { slice::from_raw_parts(word_ptr, size_of::<Word>()) };
let digest_bytes = unsafe { slice::from_raw_parts(digest_ptr, size_of::<RpoDigest>()) };
assert_eq!(word_bytes, digest_bytes);
}
}
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------

View File

@@ -5,6 +5,9 @@ use super::{
};
use core::fmt;
mod index;
pub use index::NodeIndex;
mod merkle_tree;
pub use merkle_tree::MerkleTree;
@@ -22,11 +25,11 @@ pub use simple_smt::SimpleSmt;
#[derive(Clone, Debug)]
pub enum MerkleError {
DepthTooSmall(u32),
DepthTooBig(u32),
DepthTooSmall(u8),
DepthTooBig(u8),
NumLeavesNotPowerOfTwo(usize),
InvalidIndex(u32, u64),
InvalidDepth(u32, u32),
InvalidIndex(NodeIndex),
InvalidDepth { expected: u8, provided: u8 },
InvalidPath(MerklePath),
InvalidEntriesCount(usize, usize),
NodeNotInSet(u64),
@@ -41,11 +44,11 @@ impl fmt::Display for MerkleError {
NumLeavesNotPowerOfTwo(leaves) => {
write!(f, "the leaves count {leaves} is not a power of 2")
}
InvalidIndex(depth, index) => write!(
InvalidIndex(index) => write!(
f,
"the leaf index {index} is not valid for the depth {depth}"
"the index value {} is not valid for the depth {}", index.value(), index.depth()
),
InvalidDepth(expected, provided) => write!(
InvalidDepth { expected, provided } => write!(
f,
"the provided depth {provided} is not valid for {expected}"
),

View File

@@ -1,4 +1,4 @@
use super::{vec, Rpo256, Vec, Word};
use super::{vec, NodeIndex, Rpo256, Vec, Word};
use core::ops::{Deref, DerefMut};
// MERKLE PATH
@@ -23,17 +23,12 @@ impl MerklePath {
// --------------------------------------------------------------------------------------------
/// Computes the merkle root for this opening.
pub fn compute_root(&self, mut index: u64, node: Word) -> Word {
pub fn compute_root(&self, index_value: u64, node: Word) -> Word {
let mut index = NodeIndex::new(self.depth(), index_value);
self.nodes.iter().copied().fold(node, |node, sibling| {
// build the input node, considering the parity of the current index.
let is_right_sibling = (index & 1) == 1;
let input = if is_right_sibling {
[sibling.into(), node.into()]
} else {
[node.into(), sibling.into()]
};
// compute the node and move to the next iteration.
index >>= 1;
let input = index.build_node(node.into(), sibling.into());
index.move_up();
Rpo256::merge(&input).into()
})
}

View File

@@ -1,4 +1,4 @@
use super::{BTreeMap, MerkleError, MerklePath, Rpo256, Vec, Word, ZERO};
use super::{BTreeMap, MerkleError, MerklePath, NodeIndex, Rpo256, Vec, Word, ZERO};
// MERKLE PATH SET
// ================================================================================================
@@ -7,7 +7,7 @@ use super::{BTreeMap, MerkleError, MerklePath, Rpo256, Vec, Word, ZERO};
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MerklePathSet {
root: Word,
total_depth: u32,
total_depth: u8,
paths: BTreeMap<u64, MerklePath>,
}
@@ -16,7 +16,7 @@ impl MerklePathSet {
// --------------------------------------------------------------------------------------------
/// Returns an empty MerklePathSet.
pub fn new(depth: u32) -> Result<Self, MerkleError> {
pub fn new(depth: u8) -> Result<Self, MerkleError> {
let root = [ZERO; 4];
let paths = BTreeMap::new();
@@ -38,7 +38,7 @@ impl MerklePathSet {
/// Returns the depth of the Merkle tree implied by the paths stored in this set.
///
/// Merkle tree of depth 1 has two leaves, depth 2 has four leaves etc.
pub const fn depth(&self) -> u32 {
pub const fn depth(&self) -> u8 {
self.total_depth
}
@@ -48,27 +48,26 @@ impl MerklePathSet {
/// Returns an error if:
/// * The specified index not valid for the depth of structure.
/// * Requested node does not exist in the set.
pub fn get_node(&self, depth: u32, index: u64) -> Result<Word, MerkleError> {
if index >= 2u64.pow(self.total_depth) {
return Err(MerkleError::InvalidIndex(self.total_depth, index));
pub fn get_node(&self, index: NodeIndex) -> Result<Word, MerkleError> {
if !index.with_depth(self.total_depth).is_valid() {
return Err(MerkleError::InvalidIndex(
index.with_depth(self.total_depth),
));
}
if depth != self.total_depth {
return Err(MerkleError::InvalidDepth(self.total_depth, depth));
if index.depth() != self.total_depth {
return Err(MerkleError::InvalidDepth {
expected: self.total_depth,
provided: index.depth(),
});
}
let pos = 2u64.pow(depth) + index;
let index = pos / 2;
match self.paths.get(&index) {
None => Err(MerkleError::NodeNotInSet(index)),
Some(path) => {
if Self::is_even(pos) {
Ok(path[0])
} else {
Ok(path[1])
}
}
}
let index_value = index.to_scalar_index();
let parity = index_value & 1;
let index_value = index_value / 2;
self.paths
.get(&index_value)
.ok_or(MerkleError::NodeNotInSet(index_value))
.map(|path| path[parity as usize])
}
/// Returns a Merkle path to the node at the specified index. The node itself is
@@ -78,30 +77,27 @@ impl MerklePathSet {
/// Returns an error if:
/// * The specified index not valid for the depth of structure.
/// * Node of the requested path does not exist in the set.
pub fn get_path(&self, depth: u32, index: u64) -> Result<MerklePath, MerkleError> {
if index >= 2u64.pow(self.total_depth) {
return Err(MerkleError::InvalidIndex(self.total_depth, index));
pub fn get_path(&self, index: NodeIndex) -> Result<MerklePath, MerkleError> {
if !index.with_depth(self.total_depth).is_valid() {
return Err(MerkleError::InvalidIndex(index));
}
if depth != self.total_depth {
return Err(MerkleError::InvalidDepth(self.total_depth, depth));
if index.depth() != self.total_depth {
return Err(MerkleError::InvalidDepth {
expected: self.total_depth,
provided: index.depth(),
});
}
let pos = 2u64.pow(depth) + index;
let index = pos / 2;
match self.paths.get(&index) {
None => Err(MerkleError::NodeNotInSet(index)),
Some(path) => {
let mut local_path = path.clone();
if Self::is_even(pos) {
local_path.remove(0);
Ok(local_path)
} else {
local_path.remove(1);
Ok(local_path)
}
}
}
let index_value = index.to_scalar_index();
let index = index_value / 2;
let parity = index_value & 1;
let mut path = self
.paths
.get(&index)
.cloned()
.ok_or(MerkleError::NodeNotInSet(index))?;
path.remove(parity as usize);
Ok(path)
}
// STATE MUTATORS
@@ -118,36 +114,41 @@ impl MerklePathSet {
/// different root).
pub fn add_path(
&mut self,
index: u64,
index_value: u64,
value: Word,
path: MerklePath,
mut path: MerklePath,
) -> Result<(), MerkleError> {
let depth = (path.len() + 1) as u32;
if depth != self.total_depth {
return Err(MerkleError::InvalidDepth(self.total_depth, depth));
let depth = (path.len() + 1) as u8;
let mut index = NodeIndex::new(depth, index_value);
if index.depth() != self.total_depth {
return Err(MerkleError::InvalidDepth {
expected: self.total_depth,
provided: index.depth(),
});
}
// Actual number of node in tree
let pos = 2u64.pow(self.total_depth) + index;
// update the current path
let index_value = index.to_scalar_index();
let upper_index_value = index_value / 2;
let parity = index_value & 1;
path.insert(parity as usize, value);
// Index of the leaf path in map. Paths of neighboring leaves are stored in one key-value pair
let half_pos = pos / 2;
// traverse to the root, updating the nodes
let root: Word = Rpo256::merge(&[path[0].into(), path[1].into()]).into();
let root = path.iter().skip(2).copied().fold(root, |root, hash| {
index.move_up();
Rpo256::merge(&index.build_node(root.into(), hash.into())).into()
});
let mut extended_path = path;
if Self::is_even(pos) {
extended_path.insert(0, value);
} else {
extended_path.insert(1, value);
}
let root_of_current_path = Self::compute_path_root(&extended_path, depth, index);
// TODO review and document this logic
if self.root == [ZERO; 4] {
self.root = root_of_current_path;
} else if self.root != root_of_current_path {
return Err(MerkleError::InvalidPath(extended_path));
self.root = root;
} else if self.root != root {
return Err(MerkleError::InvalidPath(path));
}
self.paths.insert(half_pos, extended_path);
// finish updating the path
self.paths.insert(upper_index_value, path);
Ok(())
}
@@ -156,29 +157,44 @@ impl MerklePathSet {
/// # Errors
/// Returns an error if:
/// * Requested node does not exist in the set.
pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<(), MerkleError> {
pub fn update_leaf(&mut self, base_index_value: u64, value: Word) -> Result<(), MerkleError> {
let depth = self.depth();
if index >= 2u64.pow(depth) {
return Err(MerkleError::InvalidIndex(depth, index));
let mut index = NodeIndex::new(depth, base_index_value);
if !index.is_valid() {
return Err(MerkleError::InvalidIndex(index));
}
let pos = 2u64.pow(depth) + index;
let path = match self.paths.get_mut(&(pos / 2)) {
None => return Err(MerkleError::NodeNotInSet(index)),
let path = match self
.paths
.get_mut(&index.clone().move_up().to_scalar_index())
{
Some(path) => path,
None => return Err(MerkleError::NodeNotInSet(base_index_value)),
};
// Fill old_hashes vector -----------------------------------------------------------------
let (old_hashes, _) = Self::compute_path_trace(path, depth, index);
// Fill new_hashes vector -----------------------------------------------------------------
if Self::is_even(pos) {
path[0] = value;
} else {
path[1] = value;
let mut current_index = index;
let mut old_hashes = Vec::with_capacity(path.len().saturating_sub(2));
let mut root: Word = Rpo256::merge(&[path[0].into(), path[1].into()]).into();
for hash in path.iter().skip(2).copied() {
old_hashes.push(root);
current_index.move_up();
let input = current_index.build_node(hash.into(), root.into());
root = Rpo256::merge(&input).into();
}
// Fill new_hashes vector -----------------------------------------------------------------
path[index.is_value_odd() as usize] = value;
let mut new_hashes = Vec::with_capacity(path.len().saturating_sub(2));
let mut new_root: Word = Rpo256::merge(&[path[0].into(), path[1].into()]).into();
for path_hash in path.iter().skip(2).copied() {
new_hashes.push(new_root);
index.move_up();
let input = current_index.build_node(path_hash.into(), new_root.into());
new_root = Rpo256::merge(&input).into();
}
let (new_hashes, new_root) = Self::compute_path_trace(path, depth, index);
self.root = new_root;
// update paths ---------------------------------------------------------------------------
@@ -193,59 +209,6 @@ impl MerklePathSet {
Ok(())
}
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------
const fn is_even(pos: u64) -> bool {
pos & 1 == 0
}
/// Returns hash of the root
fn compute_path_root(path: &[Word], depth: u32, index: u64) -> Word {
let mut pos = 2u64.pow(depth) + index;
// hash that is obtained after calculating the current hash and path hash
let mut comp_hash = Rpo256::merge(&[path[0].into(), path[1].into()]).into();
for path_hash in path.iter().skip(2) {
pos /= 2;
comp_hash = Self::calculate_parent_hash(comp_hash, pos, *path_hash);
}
comp_hash
}
/// Calculates the hash of the parent node by two sibling ones
/// - node — current node
/// - node_pos — position of the current node
/// - sibling — neighboring vertex in the tree
fn calculate_parent_hash(node: Word, node_pos: u64, sibling: Word) -> Word {
if Self::is_even(node_pos) {
Rpo256::merge(&[node.into(), sibling.into()]).into()
} else {
Rpo256::merge(&[sibling.into(), node.into()]).into()
}
}
/// Returns vector of hashes from current to the root
fn compute_path_trace(path: &[Word], depth: u32, index: u64) -> (MerklePath, Word) {
let mut pos = 2u64.pow(depth) + index;
let mut computed_hashes = Vec::<Word>::new();
let mut comp_hash = Rpo256::merge(&[path[0].into(), path[1].into()]).into();
if path.len() != 2 {
for path_hash in path.iter().skip(2) {
computed_hashes.push(comp_hash);
pos /= 2;
comp_hash = Self::calculate_parent_hash(comp_hash, pos, *path_hash);
}
}
(computed_hashes.into(), comp_hash)
}
}
// TESTS
@@ -263,10 +226,10 @@ mod tests {
let leaf2 = int_to_node(2);
let leaf3 = int_to_node(3);
let parent0 = MerklePathSet::calculate_parent_hash(leaf0, 0, leaf1);
let parent1 = MerklePathSet::calculate_parent_hash(leaf2, 2, leaf3);
let parent0 = calculate_parent_hash(leaf0, 0, leaf1);
let parent1 = calculate_parent_hash(leaf2, 2, leaf3);
let root_exp = MerklePathSet::calculate_parent_hash(parent0, 0, parent1);
let root_exp = calculate_parent_hash(parent0, 0, parent1);
let mut set = super::MerklePathSet::new(3).unwrap();
@@ -279,29 +242,32 @@ mod tests {
fn add_and_get_path() {
let path_6 = vec![int_to_node(7), int_to_node(45), int_to_node(123)];
let hash_6 = int_to_node(6);
let index = 6u64;
let depth = 4u32;
let index = 6_u64;
let depth = 4_u8;
let mut set = super::MerklePathSet::new(depth).unwrap();
set.add_path(index, hash_6, path_6.clone().into()).unwrap();
let stored_path_6 = set.get_path(depth, index).unwrap();
let stored_path_6 = set.get_path(NodeIndex::new(depth, index)).unwrap();
assert_eq!(path_6, *stored_path_6);
assert!(set.get_path(depth, 15u64).is_err())
assert!(set.get_path(NodeIndex::new(depth, 15_u64)).is_err())
}
#[test]
fn get_node() {
let path_6 = vec![int_to_node(7), int_to_node(45), int_to_node(123)];
let hash_6 = int_to_node(6);
let index = 6u64;
let depth = 4u32;
let mut set = super::MerklePathSet::new(depth).unwrap();
let index = 6_u64;
let depth = 4_u8;
let mut set = MerklePathSet::new(depth).unwrap();
set.add_path(index, hash_6, path_6.into()).unwrap();
assert_eq!(int_to_node(6u64), set.get_node(depth, index).unwrap());
assert!(set.get_node(depth, 15u64).is_err());
assert_eq!(
int_to_node(6u64),
set.get_node(NodeIndex::new(depth, index)).unwrap()
);
assert!(set.get_node(NodeIndex::new(depth, 15_u64)).is_err());
}
#[test]
@@ -310,8 +276,8 @@ mod tests {
let hash_5 = int_to_node(5);
let hash_6 = int_to_node(6);
let hash_7 = int_to_node(7);
let hash_45 = MerklePathSet::calculate_parent_hash(hash_4, 12u64, hash_5);
let hash_67 = MerklePathSet::calculate_parent_hash(hash_6, 14u64, hash_7);
let hash_45 = calculate_parent_hash(hash_4, 12u64, hash_5);
let hash_67 = calculate_parent_hash(hash_6, 14u64, hash_7);
let hash_0123 = int_to_node(123);
@@ -319,11 +285,11 @@ mod tests {
let path_5 = vec![hash_4, hash_67, hash_0123];
let path_4 = vec![hash_5, hash_67, hash_0123];
let index_6 = 6u64;
let index_5 = 5u64;
let index_4 = 4u64;
let depth = 4u32;
let mut set = super::MerklePathSet::new(depth).unwrap();
let index_6 = 6_u64;
let index_5 = 5_u64;
let index_4 = 4_u64;
let depth = 4_u8;
let mut set = MerklePathSet::new(depth).unwrap();
set.add_path(index_6, hash_6, path_6.into()).unwrap();
set.add_path(index_5, hash_5, path_5.into()).unwrap();
@@ -333,15 +299,34 @@ mod tests {
let new_hash_5 = int_to_node(55);
set.update_leaf(index_6, new_hash_6).unwrap();
let new_path_4 = set.get_path(depth, index_4).unwrap();
let new_hash_67 = MerklePathSet::calculate_parent_hash(new_hash_6, 14u64, hash_7);
let new_path_4 = set.get_path(NodeIndex::new(depth, index_4)).unwrap();
let new_hash_67 = calculate_parent_hash(new_hash_6, 14_u64, hash_7);
assert_eq!(new_hash_67, new_path_4[1]);
set.update_leaf(index_5, new_hash_5).unwrap();
let new_path_4 = set.get_path(depth, index_4).unwrap();
let new_path_6 = set.get_path(depth, index_6).unwrap();
let new_hash_45 = MerklePathSet::calculate_parent_hash(new_hash_5, 13u64, hash_4);
let new_path_4 = set.get_path(NodeIndex::new(depth, index_4)).unwrap();
let new_path_6 = set.get_path(NodeIndex::new(depth, index_6)).unwrap();
let new_hash_45 = calculate_parent_hash(new_hash_5, 13_u64, hash_4);
assert_eq!(new_hash_45, new_path_6[1]);
assert_eq!(new_hash_5, new_path_4[0]);
}
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------
const fn is_even(pos: u64) -> bool {
pos & 1 == 0
}
/// Calculates the hash of the parent node by two sibling ones
/// - node — current node
/// - node_pos — position of the current node
/// - sibling — neighboring vertex in the tree
fn calculate_parent_hash(node: Word, node_pos: u64, sibling: Word) -> Word {
if is_even(node_pos) {
Rpo256::merge(&[node.into(), sibling.into()]).into()
} else {
Rpo256::merge(&[sibling.into(), node.into()]).into()
}
}
}

View File

@@ -1,4 +1,4 @@
use super::{BTreeMap, MerkleError, MerklePath, Rpo256, RpoDigest, Vec, Word};
use super::{BTreeMap, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, Word};
#[cfg(test)]
mod tests;
@@ -12,7 +12,7 @@ mod tests;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SimpleSmt {
root: Word,
depth: u32,
depth: u8,
store: Store,
}
@@ -21,10 +21,10 @@ impl SimpleSmt {
// --------------------------------------------------------------------------------------------
/// Minimum supported depth.
pub const MIN_DEPTH: u32 = 1;
pub const MIN_DEPTH: u8 = 1;
/// Maximum supported depth.
pub const MAX_DEPTH: u32 = 63;
pub const MAX_DEPTH: u8 = 63;
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
@@ -37,7 +37,7 @@ impl SimpleSmt {
///
/// The function will fail if the provided entries count exceed the maximum tree capacity, that
/// is `2^{depth}`.
pub fn new<R, I>(entries: R, depth: u32) -> Result<Self, MerkleError>
pub fn new<R, I>(entries: R, depth: u8) -> Result<Self, MerkleError>
where
R: IntoIterator<IntoIter = I>,
I: Iterator<Item = (u64, Word)> + ExactSizeIterator,
@@ -67,7 +67,7 @@ impl SimpleSmt {
}
/// Returns the depth of this Merkle tree.
pub const fn depth(&self) -> u32 {
pub const fn depth(&self) -> u8 {
self.depth
}
@@ -82,15 +82,15 @@ impl SimpleSmt {
/// Returns an error if:
/// * The specified depth is greater than the depth of the tree.
/// * The specified key does not exist
pub fn get_node(&self, depth: u32, key: u64) -> Result<Word, MerkleError> {
if depth == 0 {
Err(MerkleError::DepthTooSmall(depth))
} else if depth > self.depth() {
Err(MerkleError::DepthTooBig(depth))
} else if depth == self.depth() {
self.store.get_leaf_node(key)
pub fn get_node(&self, index: &NodeIndex) -> Result<Word, MerkleError> {
if index.is_root() {
Err(MerkleError::DepthTooSmall(index.depth()))
} else if index.depth() > self.depth() {
Err(MerkleError::DepthTooBig(index.depth()))
} else if index.depth() == self.depth() {
self.store.get_leaf_node(index.value())
} else {
let branch_node = self.store.get_branch_node(key, depth)?;
let branch_node = self.store.get_branch_node(index)?;
Ok(Rpo256::merge(&[branch_node.left, branch_node.right]).into())
}
}
@@ -102,27 +102,23 @@ impl SimpleSmt {
/// Returns an error if:
/// * The specified key does not exist as a branch or leaf node
/// * The specified depth is greater than the depth of the tree.
pub fn get_path(&self, depth: u32, key: u64) -> Result<MerklePath, MerkleError> {
if depth == 0 {
return Err(MerkleError::DepthTooSmall(depth));
} else if depth > self.depth() {
return Err(MerkleError::DepthTooBig(depth));
} else if depth == self.depth() && !self.store.check_leaf_node_exists(key) {
return Err(MerkleError::InvalidIndex(self.depth(), key));
pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
if index.is_root() {
return Err(MerkleError::DepthTooSmall(index.depth()));
} else if index.depth() > self.depth() {
return Err(MerkleError::DepthTooBig(index.depth()));
} else if index.depth() == self.depth() && !self.store.check_leaf_node_exists(index.value())
{
return Err(MerkleError::InvalidIndex(index.with_depth(self.depth())));
}
let mut path = Vec::with_capacity(depth as usize);
let mut curr_key = key;
for n in (0..depth).rev() {
let parent_key = curr_key >> 1;
let parent_node = self.store.get_branch_node(parent_key, n)?;
let sibling_node = if curr_key & 1 == 1 {
parent_node.left
} else {
parent_node.right
};
path.push(sibling_node.into());
curr_key >>= 1;
let mut path = Vec::with_capacity(index.depth() as usize);
for _ in 0..index.depth() {
let is_right = index.is_value_odd();
index.move_up();
let BranchNode { left, right } = self.store.get_branch_node(&index)?;
let value = if is_right { left } else { right };
path.push(*value);
}
Ok(path.into())
}
@@ -134,7 +130,7 @@ impl SimpleSmt {
/// Returns an error if:
/// * The specified key does not exist as a leaf node.
pub fn get_leaf_path(&self, key: u64) -> Result<MerklePath, MerkleError> {
self.get_path(self.depth(), key)
self.get_path(NodeIndex::new(self.depth(), key))
}
/// Replaces the leaf located at the specified key, and recomputes hashes by walking up the tree
@@ -143,7 +139,7 @@ impl SimpleSmt {
/// Returns an error if the specified key is not a valid leaf index for this tree.
pub fn update_leaf(&mut self, key: u64, value: Word) -> Result<(), MerkleError> {
if !self.store.check_leaf_node_exists(key) {
return Err(MerkleError::InvalidIndex(self.depth(), key));
return Err(MerkleError::InvalidIndex(NodeIndex::new(self.depth(), key)));
}
self.insert_leaf(key, value)?;
@@ -154,27 +150,25 @@ impl SimpleSmt {
pub fn insert_leaf(&mut self, key: u64, value: Word) -> Result<(), MerkleError> {
self.store.insert_leaf_node(key, value);
let depth = self.depth();
let mut curr_key = key;
let mut curr_node: RpoDigest = value.into();
for n in (0..depth).rev() {
let parent_key = curr_key >> 1;
let parent_node = self
// TODO consider using a map `index |-> word` instead of `index |-> (word, word)`
let mut index = NodeIndex::new(self.depth(), key);
let mut value = RpoDigest::from(value);
for _ in 0..index.depth() {
let is_right = index.is_value_odd();
index.move_up();
let BranchNode { left, right } = self
.store
.get_branch_node(parent_key, n)
.unwrap_or_else(|_| self.store.get_empty_node((n + 1) as usize));
let (left, right) = if curr_key & 1 == 1 {
(parent_node.left, curr_node)
.get_branch_node(&index)
.unwrap_or_else(|_| self.store.get_empty_node(index.depth() as usize + 1));
let (left, right) = if is_right {
(left, value)
} else {
(curr_node, parent_node.right)
(value, right)
};
self.store.insert_branch_node(parent_key, n, left, right);
curr_key = parent_key;
curr_node = Rpo256::merge(&[left, right]);
self.store.insert_branch_node(index, left, right);
value = Rpo256::merge(&[left, right]);
}
self.root = curr_node.into();
self.root = value.into();
Ok(())
}
}
@@ -188,10 +182,10 @@ impl SimpleSmt {
/// with the root hash of an empty tree, and ending with the zero value of a leaf node.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Store {
branches: BTreeMap<(u64, u32), BranchNode>,
branches: BTreeMap<NodeIndex, BranchNode>,
leaves: BTreeMap<u64, Word>,
empty_hashes: Vec<RpoDigest>,
depth: u32,
depth: u8,
}
#[derive(Debug, Default, Clone, PartialEq, Eq)]
@@ -201,7 +195,7 @@ struct BranchNode {
}
impl Store {
fn new(depth: u32) -> (Self, Word) {
fn new(depth: u8) -> (Self, Word) {
let branches = BTreeMap::new();
let leaves = BTreeMap::new();
@@ -244,23 +238,23 @@ impl Store {
self.leaves
.get(&key)
.cloned()
.ok_or(MerkleError::InvalidIndex(self.depth, key))
.ok_or(MerkleError::InvalidIndex(NodeIndex::new(self.depth, key)))
}
fn insert_leaf_node(&mut self, key: u64, node: Word) {
self.leaves.insert(key, node);
}
fn get_branch_node(&self, key: u64, depth: u32) -> Result<BranchNode, MerkleError> {
fn get_branch_node(&self, index: &NodeIndex) -> Result<BranchNode, MerkleError> {
self.branches
.get(&(key, depth))
.get(index)
.cloned()
.ok_or(MerkleError::InvalidIndex(depth, key))
.ok_or(MerkleError::InvalidIndex(*index))
}
fn insert_branch_node(&mut self, key: u64, depth: u32, left: RpoDigest, right: RpoDigest) {
let node = BranchNode { left, right };
self.branches.insert((key, depth), node);
fn insert_branch_node(&mut self, index: NodeIndex, left: RpoDigest, right: RpoDigest) {
let branch = BranchNode { left, right };
self.branches.insert(index, branch);
}
fn leaves_count(&self) -> usize {

View File

@@ -1,6 +1,6 @@
use super::{
super::{MerkleTree, RpoDigest, SimpleSmt},
Rpo256, Vec, Word,
NodeIndex, Rpo256, Vec, Word,
};
use crate::{Felt, FieldElement};
use core::iter;
@@ -62,7 +62,10 @@ fn build_sparse_tree() {
.expect("Failed to insert leaf");
let mt2 = MerkleTree::new(values.clone()).unwrap();
assert_eq!(mt2.root(), smt.root());
assert_eq!(mt2.get_path(3, 6).unwrap(), smt.get_path(3, 6).unwrap());
assert_eq!(
mt2.get_path(NodeIndex::new(3, 6)).unwrap(),
smt.get_path(NodeIndex::new(3, 6)).unwrap()
);
// insert second value at distinct leaf branch
let key = 2;
@@ -72,7 +75,10 @@ fn build_sparse_tree() {
.expect("Failed to insert leaf");
let mt3 = MerkleTree::new(values).unwrap();
assert_eq!(mt3.root(), smt.root());
assert_eq!(mt3.get_path(3, 2).unwrap(), smt.get_path(3, 2).unwrap());
assert_eq!(
mt3.get_path(NodeIndex::new(3, 2)).unwrap(),
smt.get_path(NodeIndex::new(3, 2)).unwrap()
);
}
#[test]
@@ -81,8 +87,8 @@ fn build_full_tree() {
let (root, node2, node3) = compute_internal_nodes();
assert_eq!(root, tree.root());
assert_eq!(node2, tree.get_node(1, 0).unwrap());
assert_eq!(node3, tree.get_node(1, 1).unwrap());
assert_eq!(node2, tree.get_node(&NodeIndex::new(1, 0)).unwrap());
assert_eq!(node3, tree.get_node(&NodeIndex::new(1, 1)).unwrap());
}
#[test]
@@ -90,10 +96,10 @@ fn get_values() {
let tree = SimpleSmt::new(KEYS4.into_iter().zip(VALUES4.into_iter()), 2).unwrap();
// check depth 2
assert_eq!(VALUES4[0], tree.get_node(2, 0).unwrap());
assert_eq!(VALUES4[1], tree.get_node(2, 1).unwrap());
assert_eq!(VALUES4[2], tree.get_node(2, 2).unwrap());
assert_eq!(VALUES4[3], tree.get_node(2, 3).unwrap());
assert_eq!(VALUES4[0], tree.get_node(&NodeIndex::new(2, 0)).unwrap());
assert_eq!(VALUES4[1], tree.get_node(&NodeIndex::new(2, 1)).unwrap());
assert_eq!(VALUES4[2], tree.get_node(&NodeIndex::new(2, 2)).unwrap());
assert_eq!(VALUES4[3], tree.get_node(&NodeIndex::new(2, 3)).unwrap());
}
#[test]
@@ -103,14 +109,26 @@ fn get_path() {
let (_, node2, node3) = compute_internal_nodes();
// check depth 2
assert_eq!(vec![VALUES4[1], node3], *tree.get_path(2, 0).unwrap());
assert_eq!(vec![VALUES4[0], node3], *tree.get_path(2, 1).unwrap());
assert_eq!(vec![VALUES4[3], node2], *tree.get_path(2, 2).unwrap());
assert_eq!(vec![VALUES4[2], node2], *tree.get_path(2, 3).unwrap());
assert_eq!(
vec![VALUES4[1], node3],
*tree.get_path(NodeIndex::new(2, 0)).unwrap()
);
assert_eq!(
vec![VALUES4[0], node3],
*tree.get_path(NodeIndex::new(2, 1)).unwrap()
);
assert_eq!(
vec![VALUES4[3], node2],
*tree.get_path(NodeIndex::new(2, 2)).unwrap()
);
assert_eq!(
vec![VALUES4[2], node2],
*tree.get_path(NodeIndex::new(2, 3)).unwrap()
);
// check depth 1
assert_eq!(vec![node3], *tree.get_path(1, 0).unwrap());
assert_eq!(vec![node2], *tree.get_path(1, 1).unwrap());
assert_eq!(vec![node3], *tree.get_path(NodeIndex::new(1, 0)).unwrap());
assert_eq!(vec![node2], *tree.get_path(NodeIndex::new(1, 1)).unwrap());
}
#[test]
@@ -175,7 +193,7 @@ fn small_tree_opening_is_consistent() {
assert_eq!(tree.root(), Word::from(k));
let cases: Vec<(u32, u64, Vec<Word>)> = vec![
let cases: Vec<(u8, u64, Vec<Word>)> = vec![
(3, 0, vec![b, f, j]),
(3, 1, vec![a, f, j]),
(3, 4, vec![z, h, i]),
@@ -189,7 +207,7 @@ fn small_tree_opening_is_consistent() {
];
for (depth, key, path) in cases {
let opening = tree.get_path(depth, key).unwrap();
let opening = tree.get_path(NodeIndex::new(depth, key)).unwrap();
assert_eq!(path, *opening);
}
@@ -213,7 +231,7 @@ proptest! {
// traverse to root, fetching all paths
for d in 1..depth {
let k = key >> (depth - d);
tree.get_path(d, k).unwrap();
tree.get_path(NodeIndex::new(d, k)).unwrap();
}
}