mirror of
https://github.com/arnaucube/miden-crypto.git
synced 2026-01-11 16:41:29 +01:00
feat: add merkle node index
This commit introduces a wrapper structure to encapsulate the merkle tree traversal. related issue: #36
This commit is contained in:
114
src/merkle/index.rs
Normal file
114
src/merkle/index.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
use super::RpoDigest;
|
||||
|
||||
// NODE INDEX
|
||||
// ================================================================================================
|
||||
|
||||
/// A Merkle tree address to an arbitrary node.
|
||||
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
|
||||
pub struct NodeIndex {
|
||||
depth: u8,
|
||||
value: u64,
|
||||
}
|
||||
|
||||
impl NodeIndex {
|
||||
// CONSTRUCTORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Creates a new node index.
|
||||
pub const fn new(depth: u8, value: u64) -> Self {
|
||||
Self { depth, value }
|
||||
}
|
||||
|
||||
/// Creates a new node index pointing to the root of the tree.
|
||||
pub const fn root() -> Self {
|
||||
Self { depth: 0, value: 0 }
|
||||
}
|
||||
|
||||
/// Mutates the instance and returns it, replacing the depth.
|
||||
pub const fn with_depth(mut self, depth: u8) -> Self {
|
||||
self.depth = depth;
|
||||
self
|
||||
}
|
||||
|
||||
/// Computes the value of the sibling of the current node.
|
||||
pub fn sibling(mut self) -> Self {
|
||||
self.value ^= 1;
|
||||
self
|
||||
}
|
||||
|
||||
// PROVIDERS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Builds a node to be used as input of a hash function when computing a Merkle path.
|
||||
///
|
||||
/// Will evaluate the parity of the current instance to define the result.
|
||||
pub const fn build_node(&self, slf: RpoDigest, sibling: RpoDigest) -> [RpoDigest; 2] {
|
||||
if self.is_value_odd() {
|
||||
[sibling, slf]
|
||||
} else {
|
||||
[slf, sibling]
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the scalar representation of the depth/value pair.
|
||||
///
|
||||
/// It is computed as `2^depth + value`.
|
||||
pub const fn to_scalar_index(&self) -> u64 {
|
||||
(1 << self.depth as u64) + self.value
|
||||
}
|
||||
|
||||
/// Returns the depth of the current instance.
|
||||
pub const fn depth(&self) -> u8 {
|
||||
self.depth
|
||||
}
|
||||
|
||||
/// Returns the value of the current depth.
|
||||
pub const fn value(&self) -> u64 {
|
||||
self.value
|
||||
}
|
||||
|
||||
/// Returns true if the current value fits the current depth for a binary tree.
|
||||
pub const fn is_valid(&self) -> bool {
|
||||
self.value < (1 << self.depth as u64)
|
||||
}
|
||||
|
||||
/// Returns true if the current instance points to a right sibling node.
|
||||
pub const fn is_value_odd(&self) -> bool {
|
||||
(self.value & 1) == 1
|
||||
}
|
||||
|
||||
/// Returns `true` if the depth is `0`.
|
||||
pub const fn is_root(&self) -> bool {
|
||||
self.depth == 0
|
||||
}
|
||||
|
||||
// STATE MUTATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Traverse one level towards the root, decrementing the depth by `1`.
|
||||
pub fn move_up(&mut self) -> &mut Self {
|
||||
self.depth = self.depth.saturating_sub(1);
|
||||
self.value >>= 1;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn arbitrary_index_wont_panic_on_move_up(
|
||||
depth in prop::num::u8::ANY,
|
||||
value in prop::num::u64::ANY,
|
||||
count in prop::num::u8::ANY,
|
||||
) {
|
||||
let mut index = NodeIndex::new(depth, value);
|
||||
for _ in 0..count {
|
||||
index.move_up();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::{Felt, MerkleError, MerklePath, Rpo256, RpoDigest, Vec, Word};
|
||||
use super::{Felt, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, Word};
|
||||
use crate::{utils::uninit_vector, FieldElement};
|
||||
use core::slice;
|
||||
use winter_math::log2;
|
||||
@@ -22,7 +22,7 @@ impl MerkleTree {
|
||||
pub fn new(leaves: Vec<Word>) -> Result<Self, MerkleError> {
|
||||
let n = leaves.len();
|
||||
if n <= 1 {
|
||||
return Err(MerkleError::DepthTooSmall(n as u32));
|
||||
return Err(MerkleError::DepthTooSmall(n as u8));
|
||||
} else if !n.is_power_of_two() {
|
||||
return Err(MerkleError::NumLeavesNotPowerOfTwo(n));
|
||||
}
|
||||
@@ -35,12 +35,14 @@ impl MerkleTree {
|
||||
nodes[n..].copy_from_slice(&leaves);
|
||||
|
||||
// re-interpret nodes as an array of two nodes fused together
|
||||
let two_nodes =
|
||||
unsafe { slice::from_raw_parts(nodes.as_ptr() as *const [RpoDigest; 2], n) };
|
||||
// Safety: `nodes` will never move here as it is not bound to an external lifetime (i.e.
|
||||
// `self`).
|
||||
let ptr = nodes.as_ptr() as *const [RpoDigest; 2];
|
||||
let pairs = unsafe { slice::from_raw_parts(ptr, n) };
|
||||
|
||||
// calculate all internal tree nodes
|
||||
for i in (1..n).rev() {
|
||||
nodes[i] = Rpo256::merge(&two_nodes[i]).into();
|
||||
nodes[i] = Rpo256::merge(&pairs[i]).into();
|
||||
}
|
||||
|
||||
Ok(Self { nodes })
|
||||
@@ -57,53 +59,53 @@ impl MerkleTree {
|
||||
/// Returns the depth of this Merkle tree.
|
||||
///
|
||||
/// Merkle tree of depth 1 has two leaves, depth 2 has four leaves etc.
|
||||
pub fn depth(&self) -> u32 {
|
||||
log2(self.nodes.len() / 2)
|
||||
pub fn depth(&self) -> u8 {
|
||||
log2(self.nodes.len() / 2) as u8
|
||||
}
|
||||
|
||||
/// Returns a node at the specified depth and index.
|
||||
/// Returns a node at the specified depth and index value.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// * The specified depth is greater than the depth of the tree.
|
||||
/// * The specified index not valid for the specified depth.
|
||||
pub fn get_node(&self, depth: u32, index: u64) -> Result<Word, MerkleError> {
|
||||
if depth == 0 {
|
||||
return Err(MerkleError::DepthTooSmall(depth));
|
||||
} else if depth > self.depth() {
|
||||
return Err(MerkleError::DepthTooBig(depth));
|
||||
}
|
||||
if index >= 2u64.pow(depth) {
|
||||
return Err(MerkleError::InvalidIndex(depth, index));
|
||||
pub fn get_node(&self, index: NodeIndex) -> Result<Word, MerkleError> {
|
||||
if index.is_root() {
|
||||
return Err(MerkleError::DepthTooSmall(index.depth()));
|
||||
} else if index.depth() > self.depth() {
|
||||
return Err(MerkleError::DepthTooBig(index.depth()));
|
||||
} else if !index.is_valid() {
|
||||
return Err(MerkleError::InvalidIndex(index));
|
||||
}
|
||||
|
||||
let pos = 2_usize.pow(depth) + (index as usize);
|
||||
let pos = index.to_scalar_index() as usize;
|
||||
Ok(self.nodes[pos])
|
||||
}
|
||||
|
||||
/// Returns a Merkle path to the node at the specified depth and index. The node itself is
|
||||
/// not included in the path.
|
||||
/// Returns a Merkle path to the node at the specified depth and index value. The node itself
|
||||
/// is not included in the path.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// * The specified depth is greater than the depth of the tree.
|
||||
/// * The specified index not valid for the specified depth.
|
||||
pub fn get_path(&self, depth: u32, index: u64) -> Result<MerklePath, MerkleError> {
|
||||
if depth == 0 {
|
||||
return Err(MerkleError::DepthTooSmall(depth));
|
||||
} else if depth > self.depth() {
|
||||
return Err(MerkleError::DepthTooBig(depth));
|
||||
}
|
||||
if index >= 2u64.pow(depth) {
|
||||
return Err(MerkleError::InvalidIndex(depth, index));
|
||||
/// * The specified value not valid for the specified depth.
|
||||
pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
|
||||
if index.is_root() {
|
||||
return Err(MerkleError::DepthTooSmall(index.depth()));
|
||||
} else if index.depth() > self.depth() {
|
||||
return Err(MerkleError::DepthTooBig(index.depth()));
|
||||
} else if !index.is_valid() {
|
||||
return Err(MerkleError::InvalidIndex(index));
|
||||
}
|
||||
|
||||
let mut path = Vec::with_capacity(depth as usize);
|
||||
let mut pos = 2_usize.pow(depth) + (index as usize);
|
||||
|
||||
while pos > 1 {
|
||||
path.push(self.nodes[pos ^ 1]);
|
||||
pos >>= 1;
|
||||
// TODO should we create a helper in `NodeIndex` that will encapsulate traversal to root so
|
||||
// we always use inlined `for` instead of `while`? the reason to use `for` is because its
|
||||
// easier for the compiler to vectorize.
|
||||
let mut path = Vec::with_capacity(index.depth() as usize);
|
||||
for _ in 0..index.depth() {
|
||||
let sibling = index.sibling().to_scalar_index() as usize;
|
||||
path.push(self.nodes[sibling]);
|
||||
index.move_up();
|
||||
}
|
||||
|
||||
Ok(path.into())
|
||||
@@ -112,23 +114,38 @@ impl MerkleTree {
|
||||
/// Replaces the leaf at the specified index with the provided value.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the specified index is not a valid leaf index for this tree.
|
||||
pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<(), MerkleError> {
|
||||
/// Returns an error if the specified index value is not a valid leaf value for this tree.
|
||||
pub fn update_leaf<'a>(&'a mut self, index_value: u64, value: Word) -> Result<(), MerkleError> {
|
||||
let depth = self.depth();
|
||||
if index >= 2u64.pow(depth) {
|
||||
return Err(MerkleError::InvalidIndex(depth, index));
|
||||
let mut index = NodeIndex::new(depth, index_value);
|
||||
if !index.is_valid() {
|
||||
return Err(MerkleError::InvalidIndex(index));
|
||||
}
|
||||
|
||||
let mut index = 2usize.pow(depth) + index as usize;
|
||||
self.nodes[index] = value;
|
||||
|
||||
// we don't need to copy the pairs into a new address as we are logically guaranteed to not
|
||||
// overlap write instructions. however, it's important to bind the lifetime of pairs to
|
||||
// `self.nodes` so the compiler will never move one without moving the other.
|
||||
debug_assert_eq!(self.nodes.len() & 1, 0);
|
||||
let n = self.nodes.len() / 2;
|
||||
let two_nodes =
|
||||
unsafe { slice::from_raw_parts(self.nodes.as_ptr() as *const [RpoDigest; 2], n) };
|
||||
|
||||
for _ in 0..depth {
|
||||
index /= 2;
|
||||
self.nodes[index] = Rpo256::merge(&two_nodes[index]).into();
|
||||
// Safety: the length of nodes is guaranteed to contain pairs of words; hence, pairs of
|
||||
// digests. we explicitly bind the lifetime here so we add an extra layer of guarantee that
|
||||
// `self.nodes` will be moved only if `pairs` is moved as well. also, the algorithm is
|
||||
// logically guaranteed to not overlap write positions as the write index is always half
|
||||
// the index from which we read the digest input.
|
||||
let ptr = self.nodes.as_ptr() as *const [RpoDigest; 2];
|
||||
let pairs: &'a [[RpoDigest; 2]] = unsafe { slice::from_raw_parts(ptr, n) };
|
||||
|
||||
// update the current node
|
||||
let pos = index.to_scalar_index() as usize;
|
||||
self.nodes[pos] = value;
|
||||
|
||||
// traverse to the root, updating each node with the merged values of its parents
|
||||
for _ in 0..index.depth() {
|
||||
index.move_up();
|
||||
let pos = index.to_scalar_index() as usize;
|
||||
let value = Rpo256::merge(&pairs[pos]).into();
|
||||
self.nodes[pos] = value;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -140,10 +157,10 @@ impl MerkleTree {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
super::{int_to_node, Rpo256},
|
||||
Word,
|
||||
};
|
||||
use super::*;
|
||||
use crate::merkle::int_to_node;
|
||||
use core::mem::size_of;
|
||||
use proptest::prelude::*;
|
||||
|
||||
const LEAVES4: [Word; 4] = [
|
||||
int_to_node(1),
|
||||
@@ -187,16 +204,16 @@ mod tests {
|
||||
let tree = super::MerkleTree::new(LEAVES4.to_vec()).unwrap();
|
||||
|
||||
// check depth 2
|
||||
assert_eq!(LEAVES4[0], tree.get_node(2, 0).unwrap());
|
||||
assert_eq!(LEAVES4[1], tree.get_node(2, 1).unwrap());
|
||||
assert_eq!(LEAVES4[2], tree.get_node(2, 2).unwrap());
|
||||
assert_eq!(LEAVES4[3], tree.get_node(2, 3).unwrap());
|
||||
assert_eq!(LEAVES4[0], tree.get_node(NodeIndex::new(2, 0)).unwrap());
|
||||
assert_eq!(LEAVES4[1], tree.get_node(NodeIndex::new(2, 1)).unwrap());
|
||||
assert_eq!(LEAVES4[2], tree.get_node(NodeIndex::new(2, 2)).unwrap());
|
||||
assert_eq!(LEAVES4[3], tree.get_node(NodeIndex::new(2, 3)).unwrap());
|
||||
|
||||
// check depth 1
|
||||
let (_, node2, node3) = compute_internal_nodes();
|
||||
|
||||
assert_eq!(node2, tree.get_node(1, 0).unwrap());
|
||||
assert_eq!(node3, tree.get_node(1, 1).unwrap());
|
||||
assert_eq!(node2, tree.get_node(NodeIndex::new(1, 0)).unwrap());
|
||||
assert_eq!(node3, tree.get_node(NodeIndex::new(1, 1)).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -206,14 +223,26 @@ mod tests {
|
||||
let (_, node2, node3) = compute_internal_nodes();
|
||||
|
||||
// check depth 2
|
||||
assert_eq!(vec![LEAVES4[1], node3], *tree.get_path(2, 0).unwrap());
|
||||
assert_eq!(vec![LEAVES4[0], node3], *tree.get_path(2, 1).unwrap());
|
||||
assert_eq!(vec![LEAVES4[3], node2], *tree.get_path(2, 2).unwrap());
|
||||
assert_eq!(vec![LEAVES4[2], node2], *tree.get_path(2, 3).unwrap());
|
||||
assert_eq!(
|
||||
vec![LEAVES4[1], node3],
|
||||
*tree.get_path(NodeIndex::new(2, 0)).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
vec![LEAVES4[0], node3],
|
||||
*tree.get_path(NodeIndex::new(2, 1)).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
vec![LEAVES4[3], node2],
|
||||
*tree.get_path(NodeIndex::new(2, 2)).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
vec![LEAVES4[2], node2],
|
||||
*tree.get_path(NodeIndex::new(2, 3)).unwrap()
|
||||
);
|
||||
|
||||
// check depth 1
|
||||
assert_eq!(vec![node3], *tree.get_path(1, 0).unwrap());
|
||||
assert_eq!(vec![node2], *tree.get_path(1, 1).unwrap());
|
||||
assert_eq!(vec![node3], *tree.get_path(NodeIndex::new(1, 0)).unwrap());
|
||||
assert_eq!(vec![node2], *tree.get_path(NodeIndex::new(1, 1)).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -221,25 +250,53 @@ mod tests {
|
||||
let mut tree = super::MerkleTree::new(LEAVES8.to_vec()).unwrap();
|
||||
|
||||
// update one leaf
|
||||
let index = 3;
|
||||
let value = 3;
|
||||
let new_node = int_to_node(9);
|
||||
let mut expected_leaves = LEAVES8.to_vec();
|
||||
expected_leaves[index as usize] = new_node;
|
||||
expected_leaves[value as usize] = new_node;
|
||||
let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap();
|
||||
|
||||
tree.update_leaf(index, new_node).unwrap();
|
||||
tree.update_leaf(value, new_node).unwrap();
|
||||
assert_eq!(expected_tree.nodes, tree.nodes);
|
||||
|
||||
// update another leaf
|
||||
let index = 6;
|
||||
let value = 6;
|
||||
let new_node = int_to_node(10);
|
||||
expected_leaves[index as usize] = new_node;
|
||||
expected_leaves[value as usize] = new_node;
|
||||
let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap();
|
||||
|
||||
tree.update_leaf(index, new_node).unwrap();
|
||||
tree.update_leaf(value, new_node).unwrap();
|
||||
assert_eq!(expected_tree.nodes, tree.nodes);
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn arbitrary_word_can_be_represented_as_digest(
|
||||
a in prop::num::u64::ANY,
|
||||
b in prop::num::u64::ANY,
|
||||
c in prop::num::u64::ANY,
|
||||
d in prop::num::u64::ANY,
|
||||
) {
|
||||
// this test will assert the memory equivalence between word and digest.
|
||||
// it is used to safeguard the `[MerkleTee::update_leaf]` implementation
|
||||
// that assumes this equivalence.
|
||||
|
||||
// build a word and copy it to another address as digest
|
||||
let word = [Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)];
|
||||
let digest = RpoDigest::from(word);
|
||||
|
||||
// assert the addresses are different
|
||||
let word_ptr = (&word).as_ptr() as *const u8;
|
||||
let digest_ptr = (&digest).as_ptr() as *const u8;
|
||||
assert_ne!(word_ptr, digest_ptr);
|
||||
|
||||
// compare the bytes representation
|
||||
let word_bytes = unsafe { slice::from_raw_parts(word_ptr, size_of::<Word>()) };
|
||||
let digest_bytes = unsafe { slice::from_raw_parts(digest_ptr, size_of::<RpoDigest>()) };
|
||||
assert_eq!(word_bytes, digest_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
// HELPER FUNCTIONS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -5,6 +5,9 @@ use super::{
|
||||
};
|
||||
use core::fmt;
|
||||
|
||||
mod index;
|
||||
pub use index::NodeIndex;
|
||||
|
||||
mod merkle_tree;
|
||||
pub use merkle_tree::MerkleTree;
|
||||
|
||||
@@ -22,11 +25,11 @@ pub use simple_smt::SimpleSmt;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum MerkleError {
|
||||
DepthTooSmall(u32),
|
||||
DepthTooBig(u32),
|
||||
DepthTooSmall(u8),
|
||||
DepthTooBig(u8),
|
||||
NumLeavesNotPowerOfTwo(usize),
|
||||
InvalidIndex(u32, u64),
|
||||
InvalidDepth(u32, u32),
|
||||
InvalidIndex(NodeIndex),
|
||||
InvalidDepth { expected: u8, provided: u8 },
|
||||
InvalidPath(MerklePath),
|
||||
InvalidEntriesCount(usize, usize),
|
||||
NodeNotInSet(u64),
|
||||
@@ -41,11 +44,11 @@ impl fmt::Display for MerkleError {
|
||||
NumLeavesNotPowerOfTwo(leaves) => {
|
||||
write!(f, "the leaves count {leaves} is not a power of 2")
|
||||
}
|
||||
InvalidIndex(depth, index) => write!(
|
||||
InvalidIndex(index) => write!(
|
||||
f,
|
||||
"the leaf index {index} is not valid for the depth {depth}"
|
||||
"the index value {} is not valid for the depth {}", index.value(), index.depth()
|
||||
),
|
||||
InvalidDepth(expected, provided) => write!(
|
||||
InvalidDepth { expected, provided } => write!(
|
||||
f,
|
||||
"the provided depth {provided} is not valid for {expected}"
|
||||
),
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::{vec, Rpo256, Vec, Word};
|
||||
use super::{vec, NodeIndex, Rpo256, Vec, Word};
|
||||
use core::ops::{Deref, DerefMut};
|
||||
|
||||
// MERKLE PATH
|
||||
@@ -23,17 +23,12 @@ impl MerklePath {
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Computes the merkle root for this opening.
|
||||
pub fn compute_root(&self, mut index: u64, node: Word) -> Word {
|
||||
pub fn compute_root(&self, index_value: u64, node: Word) -> Word {
|
||||
let mut index = NodeIndex::new(self.depth(), index_value);
|
||||
self.nodes.iter().copied().fold(node, |node, sibling| {
|
||||
// build the input node, considering the parity of the current index.
|
||||
let is_right_sibling = (index & 1) == 1;
|
||||
let input = if is_right_sibling {
|
||||
[sibling.into(), node.into()]
|
||||
} else {
|
||||
[node.into(), sibling.into()]
|
||||
};
|
||||
// compute the node and move to the next iteration.
|
||||
index >>= 1;
|
||||
let input = index.build_node(node.into(), sibling.into());
|
||||
index.move_up();
|
||||
Rpo256::merge(&input).into()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::{BTreeMap, MerkleError, MerklePath, Rpo256, Vec, Word, ZERO};
|
||||
use super::{BTreeMap, MerkleError, MerklePath, NodeIndex, Rpo256, Vec, Word, ZERO};
|
||||
|
||||
// MERKLE PATH SET
|
||||
// ================================================================================================
|
||||
@@ -7,7 +7,7 @@ use super::{BTreeMap, MerkleError, MerklePath, Rpo256, Vec, Word, ZERO};
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct MerklePathSet {
|
||||
root: Word,
|
||||
total_depth: u32,
|
||||
total_depth: u8,
|
||||
paths: BTreeMap<u64, MerklePath>,
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ impl MerklePathSet {
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns an empty MerklePathSet.
|
||||
pub fn new(depth: u32) -> Result<Self, MerkleError> {
|
||||
pub fn new(depth: u8) -> Result<Self, MerkleError> {
|
||||
let root = [ZERO; 4];
|
||||
let paths = BTreeMap::new();
|
||||
|
||||
@@ -38,7 +38,7 @@ impl MerklePathSet {
|
||||
/// Returns the depth of the Merkle tree implied by the paths stored in this set.
|
||||
///
|
||||
/// Merkle tree of depth 1 has two leaves, depth 2 has four leaves etc.
|
||||
pub const fn depth(&self) -> u32 {
|
||||
pub const fn depth(&self) -> u8 {
|
||||
self.total_depth
|
||||
}
|
||||
|
||||
@@ -48,27 +48,26 @@ impl MerklePathSet {
|
||||
/// Returns an error if:
|
||||
/// * The specified index not valid for the depth of structure.
|
||||
/// * Requested node does not exist in the set.
|
||||
pub fn get_node(&self, depth: u32, index: u64) -> Result<Word, MerkleError> {
|
||||
if index >= 2u64.pow(self.total_depth) {
|
||||
return Err(MerkleError::InvalidIndex(self.total_depth, index));
|
||||
pub fn get_node(&self, index: NodeIndex) -> Result<Word, MerkleError> {
|
||||
if !index.with_depth(self.total_depth).is_valid() {
|
||||
return Err(MerkleError::InvalidIndex(
|
||||
index.with_depth(self.total_depth),
|
||||
));
|
||||
}
|
||||
if depth != self.total_depth {
|
||||
return Err(MerkleError::InvalidDepth(self.total_depth, depth));
|
||||
if index.depth() != self.total_depth {
|
||||
return Err(MerkleError::InvalidDepth {
|
||||
expected: self.total_depth,
|
||||
provided: index.depth(),
|
||||
});
|
||||
}
|
||||
|
||||
let pos = 2u64.pow(depth) + index;
|
||||
let index = pos / 2;
|
||||
|
||||
match self.paths.get(&index) {
|
||||
None => Err(MerkleError::NodeNotInSet(index)),
|
||||
Some(path) => {
|
||||
if Self::is_even(pos) {
|
||||
Ok(path[0])
|
||||
} else {
|
||||
Ok(path[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
let index_value = index.to_scalar_index();
|
||||
let parity = index_value & 1;
|
||||
let index_value = index_value / 2;
|
||||
self.paths
|
||||
.get(&index_value)
|
||||
.ok_or(MerkleError::NodeNotInSet(index_value))
|
||||
.map(|path| path[parity as usize])
|
||||
}
|
||||
|
||||
/// Returns a Merkle path to the node at the specified index. The node itself is
|
||||
@@ -78,30 +77,27 @@ impl MerklePathSet {
|
||||
/// Returns an error if:
|
||||
/// * The specified index not valid for the depth of structure.
|
||||
/// * Node of the requested path does not exist in the set.
|
||||
pub fn get_path(&self, depth: u32, index: u64) -> Result<MerklePath, MerkleError> {
|
||||
if index >= 2u64.pow(self.total_depth) {
|
||||
return Err(MerkleError::InvalidIndex(self.total_depth, index));
|
||||
pub fn get_path(&self, index: NodeIndex) -> Result<MerklePath, MerkleError> {
|
||||
if !index.with_depth(self.total_depth).is_valid() {
|
||||
return Err(MerkleError::InvalidIndex(index));
|
||||
}
|
||||
if depth != self.total_depth {
|
||||
return Err(MerkleError::InvalidDepth(self.total_depth, depth));
|
||||
if index.depth() != self.total_depth {
|
||||
return Err(MerkleError::InvalidDepth {
|
||||
expected: self.total_depth,
|
||||
provided: index.depth(),
|
||||
});
|
||||
}
|
||||
|
||||
let pos = 2u64.pow(depth) + index;
|
||||
let index = pos / 2;
|
||||
|
||||
match self.paths.get(&index) {
|
||||
None => Err(MerkleError::NodeNotInSet(index)),
|
||||
Some(path) => {
|
||||
let mut local_path = path.clone();
|
||||
if Self::is_even(pos) {
|
||||
local_path.remove(0);
|
||||
Ok(local_path)
|
||||
} else {
|
||||
local_path.remove(1);
|
||||
Ok(local_path)
|
||||
}
|
||||
}
|
||||
}
|
||||
let index_value = index.to_scalar_index();
|
||||
let index = index_value / 2;
|
||||
let parity = index_value & 1;
|
||||
let mut path = self
|
||||
.paths
|
||||
.get(&index)
|
||||
.cloned()
|
||||
.ok_or(MerkleError::NodeNotInSet(index))?;
|
||||
path.remove(parity as usize);
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
// STATE MUTATORS
|
||||
@@ -118,36 +114,41 @@ impl MerklePathSet {
|
||||
/// different root).
|
||||
pub fn add_path(
|
||||
&mut self,
|
||||
index: u64,
|
||||
index_value: u64,
|
||||
value: Word,
|
||||
path: MerklePath,
|
||||
mut path: MerklePath,
|
||||
) -> Result<(), MerkleError> {
|
||||
let depth = (path.len() + 1) as u32;
|
||||
if depth != self.total_depth {
|
||||
return Err(MerkleError::InvalidDepth(self.total_depth, depth));
|
||||
let depth = (path.len() + 1) as u8;
|
||||
let mut index = NodeIndex::new(depth, index_value);
|
||||
if index.depth() != self.total_depth {
|
||||
return Err(MerkleError::InvalidDepth {
|
||||
expected: self.total_depth,
|
||||
provided: index.depth(),
|
||||
});
|
||||
}
|
||||
|
||||
// Actual number of node in tree
|
||||
let pos = 2u64.pow(self.total_depth) + index;
|
||||
// update the current path
|
||||
let index_value = index.to_scalar_index();
|
||||
let upper_index_value = index_value / 2;
|
||||
let parity = index_value & 1;
|
||||
path.insert(parity as usize, value);
|
||||
|
||||
// Index of the leaf path in map. Paths of neighboring leaves are stored in one key-value pair
|
||||
let half_pos = pos / 2;
|
||||
// traverse to the root, updating the nodes
|
||||
let root: Word = Rpo256::merge(&[path[0].into(), path[1].into()]).into();
|
||||
let root = path.iter().skip(2).copied().fold(root, |root, hash| {
|
||||
index.move_up();
|
||||
Rpo256::merge(&index.build_node(root.into(), hash.into())).into()
|
||||
});
|
||||
|
||||
let mut extended_path = path;
|
||||
if Self::is_even(pos) {
|
||||
extended_path.insert(0, value);
|
||||
} else {
|
||||
extended_path.insert(1, value);
|
||||
}
|
||||
|
||||
let root_of_current_path = Self::compute_path_root(&extended_path, depth, index);
|
||||
// TODO review and document this logic
|
||||
if self.root == [ZERO; 4] {
|
||||
self.root = root_of_current_path;
|
||||
} else if self.root != root_of_current_path {
|
||||
return Err(MerkleError::InvalidPath(extended_path));
|
||||
self.root = root;
|
||||
} else if self.root != root {
|
||||
return Err(MerkleError::InvalidPath(path));
|
||||
}
|
||||
self.paths.insert(half_pos, extended_path);
|
||||
|
||||
// finish updating the path
|
||||
self.paths.insert(upper_index_value, path);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -156,29 +157,44 @@ impl MerklePathSet {
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// * Requested node does not exist in the set.
|
||||
pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<(), MerkleError> {
|
||||
pub fn update_leaf(&mut self, base_index_value: u64, value: Word) -> Result<(), MerkleError> {
|
||||
let depth = self.depth();
|
||||
if index >= 2u64.pow(depth) {
|
||||
return Err(MerkleError::InvalidIndex(depth, index));
|
||||
let mut index = NodeIndex::new(depth, base_index_value);
|
||||
if !index.is_valid() {
|
||||
return Err(MerkleError::InvalidIndex(index));
|
||||
}
|
||||
let pos = 2u64.pow(depth) + index;
|
||||
|
||||
let path = match self.paths.get_mut(&(pos / 2)) {
|
||||
None => return Err(MerkleError::NodeNotInSet(index)),
|
||||
let path = match self
|
||||
.paths
|
||||
.get_mut(&index.clone().move_up().to_scalar_index())
|
||||
{
|
||||
Some(path) => path,
|
||||
None => return Err(MerkleError::NodeNotInSet(base_index_value)),
|
||||
};
|
||||
|
||||
// Fill old_hashes vector -----------------------------------------------------------------
|
||||
let (old_hashes, _) = Self::compute_path_trace(path, depth, index);
|
||||
|
||||
// Fill new_hashes vector -----------------------------------------------------------------
|
||||
if Self::is_even(pos) {
|
||||
path[0] = value;
|
||||
} else {
|
||||
path[1] = value;
|
||||
let mut current_index = index;
|
||||
let mut old_hashes = Vec::with_capacity(path.len().saturating_sub(2));
|
||||
let mut root: Word = Rpo256::merge(&[path[0].into(), path[1].into()]).into();
|
||||
for hash in path.iter().skip(2).copied() {
|
||||
old_hashes.push(root);
|
||||
current_index.move_up();
|
||||
let input = current_index.build_node(hash.into(), root.into());
|
||||
root = Rpo256::merge(&input).into();
|
||||
}
|
||||
|
||||
// Fill new_hashes vector -----------------------------------------------------------------
|
||||
path[index.is_value_odd() as usize] = value;
|
||||
|
||||
let mut new_hashes = Vec::with_capacity(path.len().saturating_sub(2));
|
||||
let mut new_root: Word = Rpo256::merge(&[path[0].into(), path[1].into()]).into();
|
||||
for path_hash in path.iter().skip(2).copied() {
|
||||
new_hashes.push(new_root);
|
||||
index.move_up();
|
||||
let input = current_index.build_node(path_hash.into(), new_root.into());
|
||||
new_root = Rpo256::merge(&input).into();
|
||||
}
|
||||
|
||||
let (new_hashes, new_root) = Self::compute_path_trace(path, depth, index);
|
||||
self.root = new_root;
|
||||
|
||||
// update paths ---------------------------------------------------------------------------
|
||||
@@ -193,59 +209,6 @@ impl MerklePathSet {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// HELPER FUNCTIONS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
const fn is_even(pos: u64) -> bool {
|
||||
pos & 1 == 0
|
||||
}
|
||||
|
||||
/// Returns hash of the root
|
||||
fn compute_path_root(path: &[Word], depth: u32, index: u64) -> Word {
|
||||
let mut pos = 2u64.pow(depth) + index;
|
||||
|
||||
// hash that is obtained after calculating the current hash and path hash
|
||||
let mut comp_hash = Rpo256::merge(&[path[0].into(), path[1].into()]).into();
|
||||
|
||||
for path_hash in path.iter().skip(2) {
|
||||
pos /= 2;
|
||||
comp_hash = Self::calculate_parent_hash(comp_hash, pos, *path_hash);
|
||||
}
|
||||
|
||||
comp_hash
|
||||
}
|
||||
|
||||
/// Calculates the hash of the parent node by two sibling ones
|
||||
/// - node — current node
|
||||
/// - node_pos — position of the current node
|
||||
/// - sibling — neighboring vertex in the tree
|
||||
fn calculate_parent_hash(node: Word, node_pos: u64, sibling: Word) -> Word {
|
||||
if Self::is_even(node_pos) {
|
||||
Rpo256::merge(&[node.into(), sibling.into()]).into()
|
||||
} else {
|
||||
Rpo256::merge(&[sibling.into(), node.into()]).into()
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns vector of hashes from current to the root
|
||||
fn compute_path_trace(path: &[Word], depth: u32, index: u64) -> (MerklePath, Word) {
|
||||
let mut pos = 2u64.pow(depth) + index;
|
||||
|
||||
let mut computed_hashes = Vec::<Word>::new();
|
||||
|
||||
let mut comp_hash = Rpo256::merge(&[path[0].into(), path[1].into()]).into();
|
||||
|
||||
if path.len() != 2 {
|
||||
for path_hash in path.iter().skip(2) {
|
||||
computed_hashes.push(comp_hash);
|
||||
pos /= 2;
|
||||
comp_hash = Self::calculate_parent_hash(comp_hash, pos, *path_hash);
|
||||
}
|
||||
}
|
||||
|
||||
(computed_hashes.into(), comp_hash)
|
||||
}
|
||||
}
|
||||
|
||||
// TESTS
|
||||
@@ -263,10 +226,10 @@ mod tests {
|
||||
let leaf2 = int_to_node(2);
|
||||
let leaf3 = int_to_node(3);
|
||||
|
||||
let parent0 = MerklePathSet::calculate_parent_hash(leaf0, 0, leaf1);
|
||||
let parent1 = MerklePathSet::calculate_parent_hash(leaf2, 2, leaf3);
|
||||
let parent0 = calculate_parent_hash(leaf0, 0, leaf1);
|
||||
let parent1 = calculate_parent_hash(leaf2, 2, leaf3);
|
||||
|
||||
let root_exp = MerklePathSet::calculate_parent_hash(parent0, 0, parent1);
|
||||
let root_exp = calculate_parent_hash(parent0, 0, parent1);
|
||||
|
||||
let mut set = super::MerklePathSet::new(3).unwrap();
|
||||
|
||||
@@ -279,29 +242,32 @@ mod tests {
|
||||
fn add_and_get_path() {
|
||||
let path_6 = vec![int_to_node(7), int_to_node(45), int_to_node(123)];
|
||||
let hash_6 = int_to_node(6);
|
||||
let index = 6u64;
|
||||
let depth = 4u32;
|
||||
let index = 6_u64;
|
||||
let depth = 4_u8;
|
||||
let mut set = super::MerklePathSet::new(depth).unwrap();
|
||||
|
||||
set.add_path(index, hash_6, path_6.clone().into()).unwrap();
|
||||
let stored_path_6 = set.get_path(depth, index).unwrap();
|
||||
let stored_path_6 = set.get_path(NodeIndex::new(depth, index)).unwrap();
|
||||
|
||||
assert_eq!(path_6, *stored_path_6);
|
||||
assert!(set.get_path(depth, 15u64).is_err())
|
||||
assert!(set.get_path(NodeIndex::new(depth, 15_u64)).is_err())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_node() {
|
||||
let path_6 = vec![int_to_node(7), int_to_node(45), int_to_node(123)];
|
||||
let hash_6 = int_to_node(6);
|
||||
let index = 6u64;
|
||||
let depth = 4u32;
|
||||
let mut set = super::MerklePathSet::new(depth).unwrap();
|
||||
let index = 6_u64;
|
||||
let depth = 4_u8;
|
||||
let mut set = MerklePathSet::new(depth).unwrap();
|
||||
|
||||
set.add_path(index, hash_6, path_6.into()).unwrap();
|
||||
|
||||
assert_eq!(int_to_node(6u64), set.get_node(depth, index).unwrap());
|
||||
assert!(set.get_node(depth, 15u64).is_err());
|
||||
assert_eq!(
|
||||
int_to_node(6u64),
|
||||
set.get_node(NodeIndex::new(depth, index)).unwrap()
|
||||
);
|
||||
assert!(set.get_node(NodeIndex::new(depth, 15_u64)).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -310,8 +276,8 @@ mod tests {
|
||||
let hash_5 = int_to_node(5);
|
||||
let hash_6 = int_to_node(6);
|
||||
let hash_7 = int_to_node(7);
|
||||
let hash_45 = MerklePathSet::calculate_parent_hash(hash_4, 12u64, hash_5);
|
||||
let hash_67 = MerklePathSet::calculate_parent_hash(hash_6, 14u64, hash_7);
|
||||
let hash_45 = calculate_parent_hash(hash_4, 12u64, hash_5);
|
||||
let hash_67 = calculate_parent_hash(hash_6, 14u64, hash_7);
|
||||
|
||||
let hash_0123 = int_to_node(123);
|
||||
|
||||
@@ -319,11 +285,11 @@ mod tests {
|
||||
let path_5 = vec![hash_4, hash_67, hash_0123];
|
||||
let path_4 = vec![hash_5, hash_67, hash_0123];
|
||||
|
||||
let index_6 = 6u64;
|
||||
let index_5 = 5u64;
|
||||
let index_4 = 4u64;
|
||||
let depth = 4u32;
|
||||
let mut set = super::MerklePathSet::new(depth).unwrap();
|
||||
let index_6 = 6_u64;
|
||||
let index_5 = 5_u64;
|
||||
let index_4 = 4_u64;
|
||||
let depth = 4_u8;
|
||||
let mut set = MerklePathSet::new(depth).unwrap();
|
||||
|
||||
set.add_path(index_6, hash_6, path_6.into()).unwrap();
|
||||
set.add_path(index_5, hash_5, path_5.into()).unwrap();
|
||||
@@ -333,15 +299,34 @@ mod tests {
|
||||
let new_hash_5 = int_to_node(55);
|
||||
|
||||
set.update_leaf(index_6, new_hash_6).unwrap();
|
||||
let new_path_4 = set.get_path(depth, index_4).unwrap();
|
||||
let new_hash_67 = MerklePathSet::calculate_parent_hash(new_hash_6, 14u64, hash_7);
|
||||
let new_path_4 = set.get_path(NodeIndex::new(depth, index_4)).unwrap();
|
||||
let new_hash_67 = calculate_parent_hash(new_hash_6, 14_u64, hash_7);
|
||||
assert_eq!(new_hash_67, new_path_4[1]);
|
||||
|
||||
set.update_leaf(index_5, new_hash_5).unwrap();
|
||||
let new_path_4 = set.get_path(depth, index_4).unwrap();
|
||||
let new_path_6 = set.get_path(depth, index_6).unwrap();
|
||||
let new_hash_45 = MerklePathSet::calculate_parent_hash(new_hash_5, 13u64, hash_4);
|
||||
let new_path_4 = set.get_path(NodeIndex::new(depth, index_4)).unwrap();
|
||||
let new_path_6 = set.get_path(NodeIndex::new(depth, index_6)).unwrap();
|
||||
let new_hash_45 = calculate_parent_hash(new_hash_5, 13_u64, hash_4);
|
||||
assert_eq!(new_hash_45, new_path_6[1]);
|
||||
assert_eq!(new_hash_5, new_path_4[0]);
|
||||
}
|
||||
|
||||
// HELPER FUNCTIONS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
const fn is_even(pos: u64) -> bool {
|
||||
pos & 1 == 0
|
||||
}
|
||||
|
||||
/// Calculates the hash of the parent node by two sibling ones
|
||||
/// - node — current node
|
||||
/// - node_pos — position of the current node
|
||||
/// - sibling — neighboring vertex in the tree
|
||||
fn calculate_parent_hash(node: Word, node_pos: u64, sibling: Word) -> Word {
|
||||
if is_even(node_pos) {
|
||||
Rpo256::merge(&[node.into(), sibling.into()]).into()
|
||||
} else {
|
||||
Rpo256::merge(&[sibling.into(), node.into()]).into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::{BTreeMap, MerkleError, MerklePath, Rpo256, RpoDigest, Vec, Word};
|
||||
use super::{BTreeMap, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, Word};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
@@ -12,7 +12,7 @@ mod tests;
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct SimpleSmt {
|
||||
root: Word,
|
||||
depth: u32,
|
||||
depth: u8,
|
||||
store: Store,
|
||||
}
|
||||
|
||||
@@ -21,10 +21,10 @@ impl SimpleSmt {
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Minimum supported depth.
|
||||
pub const MIN_DEPTH: u32 = 1;
|
||||
pub const MIN_DEPTH: u8 = 1;
|
||||
|
||||
/// Maximum supported depth.
|
||||
pub const MAX_DEPTH: u32 = 63;
|
||||
pub const MAX_DEPTH: u8 = 63;
|
||||
|
||||
// CONSTRUCTORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
@@ -37,7 +37,7 @@ impl SimpleSmt {
|
||||
///
|
||||
/// The function will fail if the provided entries count exceed the maximum tree capacity, that
|
||||
/// is `2^{depth}`.
|
||||
pub fn new<R, I>(entries: R, depth: u32) -> Result<Self, MerkleError>
|
||||
pub fn new<R, I>(entries: R, depth: u8) -> Result<Self, MerkleError>
|
||||
where
|
||||
R: IntoIterator<IntoIter = I>,
|
||||
I: Iterator<Item = (u64, Word)> + ExactSizeIterator,
|
||||
@@ -67,7 +67,7 @@ impl SimpleSmt {
|
||||
}
|
||||
|
||||
/// Returns the depth of this Merkle tree.
|
||||
pub const fn depth(&self) -> u32 {
|
||||
pub const fn depth(&self) -> u8 {
|
||||
self.depth
|
||||
}
|
||||
|
||||
@@ -82,15 +82,15 @@ impl SimpleSmt {
|
||||
/// Returns an error if:
|
||||
/// * The specified depth is greater than the depth of the tree.
|
||||
/// * The specified key does not exist
|
||||
pub fn get_node(&self, depth: u32, key: u64) -> Result<Word, MerkleError> {
|
||||
if depth == 0 {
|
||||
Err(MerkleError::DepthTooSmall(depth))
|
||||
} else if depth > self.depth() {
|
||||
Err(MerkleError::DepthTooBig(depth))
|
||||
} else if depth == self.depth() {
|
||||
self.store.get_leaf_node(key)
|
||||
pub fn get_node(&self, index: &NodeIndex) -> Result<Word, MerkleError> {
|
||||
if index.is_root() {
|
||||
Err(MerkleError::DepthTooSmall(index.depth()))
|
||||
} else if index.depth() > self.depth() {
|
||||
Err(MerkleError::DepthTooBig(index.depth()))
|
||||
} else if index.depth() == self.depth() {
|
||||
self.store.get_leaf_node(index.value())
|
||||
} else {
|
||||
let branch_node = self.store.get_branch_node(key, depth)?;
|
||||
let branch_node = self.store.get_branch_node(index)?;
|
||||
Ok(Rpo256::merge(&[branch_node.left, branch_node.right]).into())
|
||||
}
|
||||
}
|
||||
@@ -102,27 +102,23 @@ impl SimpleSmt {
|
||||
/// Returns an error if:
|
||||
/// * The specified key does not exist as a branch or leaf node
|
||||
/// * The specified depth is greater than the depth of the tree.
|
||||
pub fn get_path(&self, depth: u32, key: u64) -> Result<MerklePath, MerkleError> {
|
||||
if depth == 0 {
|
||||
return Err(MerkleError::DepthTooSmall(depth));
|
||||
} else if depth > self.depth() {
|
||||
return Err(MerkleError::DepthTooBig(depth));
|
||||
} else if depth == self.depth() && !self.store.check_leaf_node_exists(key) {
|
||||
return Err(MerkleError::InvalidIndex(self.depth(), key));
|
||||
pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
|
||||
if index.is_root() {
|
||||
return Err(MerkleError::DepthTooSmall(index.depth()));
|
||||
} else if index.depth() > self.depth() {
|
||||
return Err(MerkleError::DepthTooBig(index.depth()));
|
||||
} else if index.depth() == self.depth() && !self.store.check_leaf_node_exists(index.value())
|
||||
{
|
||||
return Err(MerkleError::InvalidIndex(index.with_depth(self.depth())));
|
||||
}
|
||||
|
||||
let mut path = Vec::with_capacity(depth as usize);
|
||||
let mut curr_key = key;
|
||||
for n in (0..depth).rev() {
|
||||
let parent_key = curr_key >> 1;
|
||||
let parent_node = self.store.get_branch_node(parent_key, n)?;
|
||||
let sibling_node = if curr_key & 1 == 1 {
|
||||
parent_node.left
|
||||
} else {
|
||||
parent_node.right
|
||||
};
|
||||
path.push(sibling_node.into());
|
||||
curr_key >>= 1;
|
||||
let mut path = Vec::with_capacity(index.depth() as usize);
|
||||
for _ in 0..index.depth() {
|
||||
let is_right = index.is_value_odd();
|
||||
index.move_up();
|
||||
let BranchNode { left, right } = self.store.get_branch_node(&index)?;
|
||||
let value = if is_right { left } else { right };
|
||||
path.push(*value);
|
||||
}
|
||||
Ok(path.into())
|
||||
}
|
||||
@@ -134,7 +130,7 @@ impl SimpleSmt {
|
||||
/// Returns an error if:
|
||||
/// * The specified key does not exist as a leaf node.
|
||||
pub fn get_leaf_path(&self, key: u64) -> Result<MerklePath, MerkleError> {
|
||||
self.get_path(self.depth(), key)
|
||||
self.get_path(NodeIndex::new(self.depth(), key))
|
||||
}
|
||||
|
||||
/// Replaces the leaf located at the specified key, and recomputes hashes by walking up the tree
|
||||
@@ -143,7 +139,7 @@ impl SimpleSmt {
|
||||
/// Returns an error if the specified key is not a valid leaf index for this tree.
|
||||
pub fn update_leaf(&mut self, key: u64, value: Word) -> Result<(), MerkleError> {
|
||||
if !self.store.check_leaf_node_exists(key) {
|
||||
return Err(MerkleError::InvalidIndex(self.depth(), key));
|
||||
return Err(MerkleError::InvalidIndex(NodeIndex::new(self.depth(), key)));
|
||||
}
|
||||
self.insert_leaf(key, value)?;
|
||||
|
||||
@@ -154,27 +150,25 @@ impl SimpleSmt {
|
||||
pub fn insert_leaf(&mut self, key: u64, value: Word) -> Result<(), MerkleError> {
|
||||
self.store.insert_leaf_node(key, value);
|
||||
|
||||
let depth = self.depth();
|
||||
let mut curr_key = key;
|
||||
let mut curr_node: RpoDigest = value.into();
|
||||
for n in (0..depth).rev() {
|
||||
let parent_key = curr_key >> 1;
|
||||
let parent_node = self
|
||||
// TODO consider using a map `index |-> word` instead of `index |-> (word, word)`
|
||||
let mut index = NodeIndex::new(self.depth(), key);
|
||||
let mut value = RpoDigest::from(value);
|
||||
for _ in 0..index.depth() {
|
||||
let is_right = index.is_value_odd();
|
||||
index.move_up();
|
||||
let BranchNode { left, right } = self
|
||||
.store
|
||||
.get_branch_node(parent_key, n)
|
||||
.unwrap_or_else(|_| self.store.get_empty_node((n + 1) as usize));
|
||||
let (left, right) = if curr_key & 1 == 1 {
|
||||
(parent_node.left, curr_node)
|
||||
.get_branch_node(&index)
|
||||
.unwrap_or_else(|_| self.store.get_empty_node(index.depth() as usize + 1));
|
||||
let (left, right) = if is_right {
|
||||
(left, value)
|
||||
} else {
|
||||
(curr_node, parent_node.right)
|
||||
(value, right)
|
||||
};
|
||||
|
||||
self.store.insert_branch_node(parent_key, n, left, right);
|
||||
curr_key = parent_key;
|
||||
curr_node = Rpo256::merge(&[left, right]);
|
||||
self.store.insert_branch_node(index, left, right);
|
||||
value = Rpo256::merge(&[left, right]);
|
||||
}
|
||||
self.root = curr_node.into();
|
||||
|
||||
self.root = value.into();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -188,10 +182,10 @@ impl SimpleSmt {
|
||||
/// with the root hash of an empty tree, and ending with the zero value of a leaf node.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct Store {
|
||||
branches: BTreeMap<(u64, u32), BranchNode>,
|
||||
branches: BTreeMap<NodeIndex, BranchNode>,
|
||||
leaves: BTreeMap<u64, Word>,
|
||||
empty_hashes: Vec<RpoDigest>,
|
||||
depth: u32,
|
||||
depth: u8,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, PartialEq, Eq)]
|
||||
@@ -201,7 +195,7 @@ struct BranchNode {
|
||||
}
|
||||
|
||||
impl Store {
|
||||
fn new(depth: u32) -> (Self, Word) {
|
||||
fn new(depth: u8) -> (Self, Word) {
|
||||
let branches = BTreeMap::new();
|
||||
let leaves = BTreeMap::new();
|
||||
|
||||
@@ -244,23 +238,23 @@ impl Store {
|
||||
self.leaves
|
||||
.get(&key)
|
||||
.cloned()
|
||||
.ok_or(MerkleError::InvalidIndex(self.depth, key))
|
||||
.ok_or(MerkleError::InvalidIndex(NodeIndex::new(self.depth, key)))
|
||||
}
|
||||
|
||||
fn insert_leaf_node(&mut self, key: u64, node: Word) {
|
||||
self.leaves.insert(key, node);
|
||||
}
|
||||
|
||||
fn get_branch_node(&self, key: u64, depth: u32) -> Result<BranchNode, MerkleError> {
|
||||
fn get_branch_node(&self, index: &NodeIndex) -> Result<BranchNode, MerkleError> {
|
||||
self.branches
|
||||
.get(&(key, depth))
|
||||
.get(index)
|
||||
.cloned()
|
||||
.ok_or(MerkleError::InvalidIndex(depth, key))
|
||||
.ok_or(MerkleError::InvalidIndex(*index))
|
||||
}
|
||||
|
||||
fn insert_branch_node(&mut self, key: u64, depth: u32, left: RpoDigest, right: RpoDigest) {
|
||||
let node = BranchNode { left, right };
|
||||
self.branches.insert((key, depth), node);
|
||||
fn insert_branch_node(&mut self, index: NodeIndex, left: RpoDigest, right: RpoDigest) {
|
||||
let branch = BranchNode { left, right };
|
||||
self.branches.insert(index, branch);
|
||||
}
|
||||
|
||||
fn leaves_count(&self) -> usize {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use super::{
|
||||
super::{MerkleTree, RpoDigest, SimpleSmt},
|
||||
Rpo256, Vec, Word,
|
||||
NodeIndex, Rpo256, Vec, Word,
|
||||
};
|
||||
use crate::{Felt, FieldElement};
|
||||
use core::iter;
|
||||
@@ -62,7 +62,10 @@ fn build_sparse_tree() {
|
||||
.expect("Failed to insert leaf");
|
||||
let mt2 = MerkleTree::new(values.clone()).unwrap();
|
||||
assert_eq!(mt2.root(), smt.root());
|
||||
assert_eq!(mt2.get_path(3, 6).unwrap(), smt.get_path(3, 6).unwrap());
|
||||
assert_eq!(
|
||||
mt2.get_path(NodeIndex::new(3, 6)).unwrap(),
|
||||
smt.get_path(NodeIndex::new(3, 6)).unwrap()
|
||||
);
|
||||
|
||||
// insert second value at distinct leaf branch
|
||||
let key = 2;
|
||||
@@ -72,7 +75,10 @@ fn build_sparse_tree() {
|
||||
.expect("Failed to insert leaf");
|
||||
let mt3 = MerkleTree::new(values).unwrap();
|
||||
assert_eq!(mt3.root(), smt.root());
|
||||
assert_eq!(mt3.get_path(3, 2).unwrap(), smt.get_path(3, 2).unwrap());
|
||||
assert_eq!(
|
||||
mt3.get_path(NodeIndex::new(3, 2)).unwrap(),
|
||||
smt.get_path(NodeIndex::new(3, 2)).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -81,8 +87,8 @@ fn build_full_tree() {
|
||||
|
||||
let (root, node2, node3) = compute_internal_nodes();
|
||||
assert_eq!(root, tree.root());
|
||||
assert_eq!(node2, tree.get_node(1, 0).unwrap());
|
||||
assert_eq!(node3, tree.get_node(1, 1).unwrap());
|
||||
assert_eq!(node2, tree.get_node(&NodeIndex::new(1, 0)).unwrap());
|
||||
assert_eq!(node3, tree.get_node(&NodeIndex::new(1, 1)).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -90,10 +96,10 @@ fn get_values() {
|
||||
let tree = SimpleSmt::new(KEYS4.into_iter().zip(VALUES4.into_iter()), 2).unwrap();
|
||||
|
||||
// check depth 2
|
||||
assert_eq!(VALUES4[0], tree.get_node(2, 0).unwrap());
|
||||
assert_eq!(VALUES4[1], tree.get_node(2, 1).unwrap());
|
||||
assert_eq!(VALUES4[2], tree.get_node(2, 2).unwrap());
|
||||
assert_eq!(VALUES4[3], tree.get_node(2, 3).unwrap());
|
||||
assert_eq!(VALUES4[0], tree.get_node(&NodeIndex::new(2, 0)).unwrap());
|
||||
assert_eq!(VALUES4[1], tree.get_node(&NodeIndex::new(2, 1)).unwrap());
|
||||
assert_eq!(VALUES4[2], tree.get_node(&NodeIndex::new(2, 2)).unwrap());
|
||||
assert_eq!(VALUES4[3], tree.get_node(&NodeIndex::new(2, 3)).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -103,14 +109,26 @@ fn get_path() {
|
||||
let (_, node2, node3) = compute_internal_nodes();
|
||||
|
||||
// check depth 2
|
||||
assert_eq!(vec![VALUES4[1], node3], *tree.get_path(2, 0).unwrap());
|
||||
assert_eq!(vec![VALUES4[0], node3], *tree.get_path(2, 1).unwrap());
|
||||
assert_eq!(vec![VALUES4[3], node2], *tree.get_path(2, 2).unwrap());
|
||||
assert_eq!(vec![VALUES4[2], node2], *tree.get_path(2, 3).unwrap());
|
||||
assert_eq!(
|
||||
vec![VALUES4[1], node3],
|
||||
*tree.get_path(NodeIndex::new(2, 0)).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
vec![VALUES4[0], node3],
|
||||
*tree.get_path(NodeIndex::new(2, 1)).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
vec![VALUES4[3], node2],
|
||||
*tree.get_path(NodeIndex::new(2, 2)).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
vec![VALUES4[2], node2],
|
||||
*tree.get_path(NodeIndex::new(2, 3)).unwrap()
|
||||
);
|
||||
|
||||
// check depth 1
|
||||
assert_eq!(vec![node3], *tree.get_path(1, 0).unwrap());
|
||||
assert_eq!(vec![node2], *tree.get_path(1, 1).unwrap());
|
||||
assert_eq!(vec![node3], *tree.get_path(NodeIndex::new(1, 0)).unwrap());
|
||||
assert_eq!(vec![node2], *tree.get_path(NodeIndex::new(1, 1)).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -175,7 +193,7 @@ fn small_tree_opening_is_consistent() {
|
||||
|
||||
assert_eq!(tree.root(), Word::from(k));
|
||||
|
||||
let cases: Vec<(u32, u64, Vec<Word>)> = vec![
|
||||
let cases: Vec<(u8, u64, Vec<Word>)> = vec![
|
||||
(3, 0, vec![b, f, j]),
|
||||
(3, 1, vec![a, f, j]),
|
||||
(3, 4, vec![z, h, i]),
|
||||
@@ -189,7 +207,7 @@ fn small_tree_opening_is_consistent() {
|
||||
];
|
||||
|
||||
for (depth, key, path) in cases {
|
||||
let opening = tree.get_path(depth, key).unwrap();
|
||||
let opening = tree.get_path(NodeIndex::new(depth, key)).unwrap();
|
||||
|
||||
assert_eq!(path, *opening);
|
||||
}
|
||||
@@ -213,7 +231,7 @@ proptest! {
|
||||
// traverse to root, fetching all paths
|
||||
for d in 1..depth {
|
||||
let k = key >> (depth - d);
|
||||
tree.get_path(d, k).unwrap();
|
||||
tree.get_path(NodeIndex::new(d, k)).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user