Browse Source

Enhancement of the Partial Merkle Tree (#163)

feat: implement additional functionality for the PartialMerkleTree
al-gkr-basic-workflow
Andrey Khmuro 1 year ago
committed by GitHub
parent
commit
08aec4443c
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 342 additions and 35 deletions
  1. +17
    -0
      src/merkle/index.rs
  2. +158
    -18
      src/merkle/partial_mt/mod.rs
  3. +161
    -11
      src/merkle/partial_mt/tests.rs
  4. +4
    -4
      src/merkle/store/mod.rs
  5. +2
    -2
      src/utils/mod.rs

+ 17
- 0
src/merkle/index.rs

@ -1,4 +1,5 @@
use super::{Felt, MerkleError, RpoDigest, StarkField}; use super::{Felt, MerkleError, RpoDigest, StarkField};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
use core::fmt::Display; use core::fmt::Display;
// NODE INDEX // NODE INDEX
@ -161,6 +162,22 @@ impl Display for NodeIndex {
} }
} }
impl Serializable for NodeIndex {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write_u8(self.depth);
target.write_u64(self.value);
}
}
impl Deserializable for NodeIndex {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let depth = source.read_u8()?;
let value = source.read_u64()?;
NodeIndex::new(depth, value)
.map_err(|_| DeserializationError::InvalidValue("Invalid index".into()))
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

+ 158
- 18
src/merkle/partial_mt/mod.rs

@ -1,7 +1,11 @@
use super::{ use super::{
BTreeMap, BTreeSet, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, ValuePath, Vec, ZERO,
BTreeMap, BTreeSet, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest,
ValuePath, Vec, Word, ZERO,
};
use crate::utils::{
format, string::String, vec, word_to_hex, ByteReader, ByteWriter, Deserializable,
DeserializationError, Serializable,
}; };
use crate::utils::{format, string::String, word_to_hex};
use core::fmt; use core::fmt;
#[cfg(test)] #[cfg(test)]
@ -74,6 +78,92 @@ impl PartialMerkleTree {
}) })
} }
/// Returns a new [PartialMerkleTree] instantiated with leaves map as specified by the provided
/// entries.
///
/// # Errors
/// Returns an error if:
/// - If the depth is 0 or is greater than 64.
/// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
/// - The provided entries contain an insufficient set of nodes.
pub fn with_leaves<R, I>(entries: R) -> Result<Self, MerkleError>
where
R: IntoIterator<IntoIter = I>,
I: Iterator<Item = (NodeIndex, RpoDigest)> + ExactSizeIterator,
{
let mut layers: BTreeMap<u8, Vec<u64>> = BTreeMap::new();
let mut leaves = BTreeSet::new();
let mut nodes = BTreeMap::new();
// add data to the leaves and nodes maps and also fill layers map, where the key is the
// depth of the node and value is its index.
for (node_index, hash) in entries.into_iter() {
leaves.insert(node_index);
nodes.insert(node_index, hash);
layers
.entry(node_index.depth())
.and_modify(|layer_vec| layer_vec.push(node_index.value()))
.or_insert(vec![node_index.value()]);
}
// check if the number of leaves can be accommodated by the tree's depth; we use a min
// depth of 63 because we consider passing in a vector of size 2^64 infeasible.
let max = (1_u64 << 63) as usize;
if layers.len() > max {
return Err(MerkleError::InvalidNumEntries(max, layers.len()));
}
// Get maximum depth
let max_depth = *layers.keys().next_back().unwrap_or(&0);
// fill layers without nodes with empty vector
for depth in 0..max_depth {
layers.entry(depth).or_insert(vec![]);
}
let mut layer_iter = layers.into_values().rev();
let mut parent_layer = layer_iter.next().unwrap();
let mut current_layer;
for depth in (1..max_depth + 1).rev() {
// set current_layer = parent_layer and parent_layer = layer_iter.next()
current_layer = layer_iter.next().unwrap();
core::mem::swap(&mut current_layer, &mut parent_layer);
for index_value in current_layer {
// get the parent node index
let parent_node = NodeIndex::new(depth - 1, index_value / 2)?;
// Check if the parent hash was already calculated. In about half of the cases, we
// don't need to do anything.
if !parent_layer.contains(&parent_node.value()) {
// create current node index
let index = NodeIndex::new(depth, index_value)?;
// get hash of the current node
let node = nodes.get(&index).ok_or(MerkleError::NodeNotInSet(index))?;
// get hash of the sibling node
let sibling = nodes
.get(&index.sibling())
.ok_or(MerkleError::NodeNotInSet(index.sibling()))?;
// get parent hash
let parent = Rpo256::merge(&index.build_node(*node, *sibling));
// add index value of the calculated node to the parents layer
parent_layer.push(parent_node.value());
// add index and hash to the nodes map
nodes.insert(parent_node, parent);
}
}
}
Ok(PartialMerkleTree {
max_depth,
nodes,
leaves,
})
}
// PUBLIC ACCESSORS // PUBLIC ACCESSORS
// -------------------------------------------------------------------------------------------- // --------------------------------------------------------------------------------------------
@ -101,7 +191,7 @@ impl PartialMerkleTree {
} }
/// Returns a vector of paths from every leaf to the root. /// Returns a vector of paths from every leaf to the root.
pub fn paths(&self) -> Vec<(NodeIndex, ValuePath)> {
pub fn to_paths(&self) -> Vec<(NodeIndex, ValuePath)> {
let mut paths = Vec::new(); let mut paths = Vec::new();
self.leaves.iter().for_each(|&leaf| { self.leaves.iter().for_each(|&leaf| {
paths.push(( paths.push((
@ -160,6 +250,22 @@ impl PartialMerkleTree {
}) })
} }
/// Returns an iterator over the inner nodes of this Merkle tree.
pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
let inner_nodes = self.nodes.iter().filter(|(index, _)| !self.leaves.contains(index));
inner_nodes.map(|(index, digest)| {
let left_hash =
self.nodes.get(&index.left_child()).expect("Failed to get left child hash");
let right_hash =
self.nodes.get(&index.right_child()).expect("Failed to get right child hash");
InnerNodeInfo {
value: *digest,
left: *left_hash,
right: *right_hash,
}
})
}
// STATE MUTATORS // STATE MUTATORS
// -------------------------------------------------------------------------------------------- // --------------------------------------------------------------------------------------------
@ -235,37 +341,37 @@ impl PartialMerkleTree {
/// Updates value of the leaf at the specified index returning the old leaf value. /// Updates value of the leaf at the specified index returning the old leaf value.
/// ///
/// By default the specified index is assumed to belong to the deepest layer. If the considered
/// node does not belong to the tree, the first node on the way to the root will be changed.
///
/// This also recomputes all hashes between the leaf and the root, updating the root itself. /// This also recomputes all hashes between the leaf and the root, updating the root itself.
/// ///
/// # Errors /// # Errors
/// Returns an error if: /// Returns an error if:
/// - The depth of the specified node_index is greater than 64 or smaller than 1.
/// - The specified node index is not corresponding to the leaf.
pub fn update_leaf(
&mut self,
node_index: NodeIndex,
value: RpoDigest,
) -> Result<RpoDigest, MerkleError> {
// check correctness of the depth and update it
Self::check_depth(node_index.depth())?;
self.update_depth(node_index.depth());
/// - The specified index is greater than the maximum number of nodes on the deepest layer.
pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<RpoDigest, MerkleError> {
let mut node_index = NodeIndex::new(self.max_depth(), index)?;
// insert NodeIndex to the leaves Set
self.leaves.insert(node_index);
// proceed to the leaf
for _ in 0..node_index.depth() {
if !self.leaves.contains(&node_index) {
node_index.move_up();
}
}
// add node value to the nodes Map // add node value to the nodes Map
let old_value = self let old_value = self
.nodes .nodes
.insert(node_index, value)
.insert(node_index, value.into())
.ok_or(MerkleError::NodeNotInSet(node_index))?; .ok_or(MerkleError::NodeNotInSet(node_index))?;
// if the old value and new value are the same, there is nothing to update // if the old value and new value are the same, there is nothing to update
if value == old_value {
if value == *old_value {
return Ok(old_value); return Ok(old_value);
} }
let mut node_index = node_index; let mut node_index = node_index;
let mut value = value;
let mut value = value.into();
for _ in 0..node_index.depth() { for _ in 0..node_index.depth() {
let sibling = self.nodes.get(&node_index.sibling()).expect("sibling should exist"); let sibling = self.nodes.get(&node_index.sibling()).expect("sibling should exist");
value = Rpo256::merge(&node_index.build_node(value, *sibling)); value = Rpo256::merge(&node_index.build_node(value, *sibling));
@ -327,3 +433,37 @@ impl PartialMerkleTree {
Ok(()) Ok(())
} }
} }
// SERIALIZATION
// ================================================================================================
impl Serializable for PartialMerkleTree {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
// write leaf nodes
target.write_u64(self.leaves.len() as u64);
for leaf_index in self.leaves.iter() {
leaf_index.write_into(target);
self.get_node(*leaf_index).expect("Leaf hash not found").write_into(target);
}
}
}
impl Deserializable for PartialMerkleTree {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let leaves_len = source.read_u64()? as usize;
let mut leaf_nodes = Vec::with_capacity(leaves_len);
// add leaf nodes to the vector
for _ in 0..leaves_len {
let index = NodeIndex::read_from(source)?;
let hash = RpoDigest::read_from(source)?;
leaf_nodes.push((index, hash));
}
let pmt = PartialMerkleTree::with_leaves(leaf_nodes).map_err(|_| {
DeserializationError::InvalidValue("Invalid data for PartialMerkleTree creation".into())
})?;
Ok(pmt)
}
}

+ 161
- 11
src/merkle/partial_mt/tests.rs

@ -1,9 +1,9 @@
use super::{ use super::{
super::{ super::{
digests_to_words, int_to_node, DefaultMerkleStore as MerkleStore, MerkleTree, NodeIndex,
PartialMerkleTree,
digests_to_words, int_to_node, BTreeMap, DefaultMerkleStore as MerkleStore, MerkleTree,
NodeIndex, PartialMerkleTree,
}, },
RpoDigest, ValuePath, Vec,
Deserializable, InnerNodeInfo, RpoDigest, Serializable, ValuePath, Vec,
}; };
// TEST DATA // TEST DATA
@ -13,6 +13,7 @@ const NODE10: NodeIndex = NodeIndex::new_unchecked(1, 0);
const NODE11: NodeIndex = NodeIndex::new_unchecked(1, 1); const NODE11: NodeIndex = NodeIndex::new_unchecked(1, 1);
const NODE20: NodeIndex = NodeIndex::new_unchecked(2, 0); const NODE20: NodeIndex = NodeIndex::new_unchecked(2, 0);
const NODE21: NodeIndex = NodeIndex::new_unchecked(2, 1);
const NODE22: NodeIndex = NodeIndex::new_unchecked(2, 2); const NODE22: NodeIndex = NodeIndex::new_unchecked(2, 2);
const NODE23: NodeIndex = NodeIndex::new_unchecked(2, 3); const NODE23: NodeIndex = NodeIndex::new_unchecked(2, 3);
@ -50,6 +51,43 @@ const VALUES8: [RpoDigest; 8] = [
// NodeIndex(3, 5) will be labeled as `35`. Leaves of the tree are shown as nodes with parenthesis // NodeIndex(3, 5) will be labeled as `35`. Leaves of the tree are shown as nodes with parenthesis
// (33). // (33).
/// Checks that creation of the PMT with `with_leaves()` constructor is working correctly.
#[test]
fn with_leaves() {
let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
let expected_root = mt.root();
let leaf_nodes_vec = vec![
(NODE20, mt.get_node(NODE20).unwrap()),
(NODE32, mt.get_node(NODE32).unwrap()),
(NODE33, mt.get_node(NODE33).unwrap()),
(NODE22, mt.get_node(NODE22).unwrap()),
(NODE23, mt.get_node(NODE23).unwrap()),
];
let leaf_nodes: BTreeMap<NodeIndex, RpoDigest> = leaf_nodes_vec.into_iter().collect();
let pmt = PartialMerkleTree::with_leaves(leaf_nodes).unwrap();
assert_eq!(expected_root, pmt.root())
}
/// Checks that `with_leaves()` function returns an error when using incomplete set of nodes.
#[test]
fn err_with_leaves() {
// NODE22 is missing
let leaf_nodes_vec = vec![
(NODE20, int_to_node(20)),
(NODE32, int_to_node(32)),
(NODE33, int_to_node(33)),
(NODE23, int_to_node(23)),
];
let leaf_nodes: BTreeMap<NodeIndex, RpoDigest> = leaf_nodes_vec.into_iter().collect();
assert!(PartialMerkleTree::with_leaves(leaf_nodes).is_err());
}
/// Checks that root returned by `root()` function is equal to the expected one. /// Checks that root returned by `root()` function is equal to the expected one.
#[test] #[test]
fn get_root() { fn get_root() {
@ -61,7 +99,7 @@ fn get_root() {
let pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap(); let pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();
assert_eq!(pmt.root(), expected_root);
assert_eq!(expected_root, pmt.root());
} }
/// This test checks correctness of the `add_path()` and `get_path()` functions. First it creates a /// This test checks correctness of the `add_path()` and `get_path()` functions. First it creates a
@ -121,7 +159,7 @@ fn update_leaf() {
let new_value32 = int_to_node(132); let new_value32 = int_to_node(132);
let expected_root = ms.set_node(root, NODE32, new_value32).unwrap().root; let expected_root = ms.set_node(root, NODE32, new_value32).unwrap().root;
pmt.update_leaf(NODE32, new_value32).unwrap();
pmt.update_leaf(2, *new_value32).unwrap();
let actual_root = pmt.root(); let actual_root = pmt.root();
assert_eq!(expected_root, actual_root); assert_eq!(expected_root, actual_root);
@ -129,7 +167,15 @@ fn update_leaf() {
let new_value20 = int_to_node(120); let new_value20 = int_to_node(120);
let expected_root = ms.set_node(expected_root, NODE20, new_value20).unwrap().root; let expected_root = ms.set_node(expected_root, NODE20, new_value20).unwrap().root;
pmt.update_leaf(NODE20, new_value20).unwrap();
pmt.update_leaf(0, *new_value20).unwrap();
let actual_root = pmt.root();
assert_eq!(expected_root, actual_root);
let new_value11 = int_to_node(111);
let expected_root = ms.set_node(expected_root, NODE11, new_value11).unwrap().root;
pmt.update_leaf(6, *new_value11).unwrap();
let actual_root = pmt.root(); let actual_root = pmt.root();
assert_eq!(expected_root, actual_root); assert_eq!(expected_root, actual_root);
@ -177,7 +223,7 @@ fn get_paths() {
}) })
.collect(); .collect();
let actual_paths = pmt.paths();
let actual_paths = pmt.to_paths();
assert_eq!(expected_paths, actual_paths); assert_eq!(expected_paths, actual_paths);
} }
@ -247,6 +293,113 @@ fn leaves() {
assert!(expected_leaves.eq(pmt.leaves())); assert!(expected_leaves.eq(pmt.leaves()));
} }
/// Checks that nodes of the PMT returned by `inner_nodes()` function are equal to the expected ones.
#[test]
fn test_inner_node_iterator() {
let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
let expected_root = mt.root();
let ms = MerkleStore::from(&mt);
let path33 = ms.get_path(expected_root, NODE33).unwrap();
let path22 = ms.get_path(expected_root, NODE22).unwrap();
let mut pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();
// get actual inner nodes
let actual: Vec<InnerNodeInfo> = pmt.inner_nodes().collect();
let expected_n00 = mt.root();
let expected_n10 = mt.get_node(NODE10).unwrap();
let expected_n11 = mt.get_node(NODE11).unwrap();
let expected_n20 = mt.get_node(NODE20).unwrap();
let expected_n21 = mt.get_node(NODE21).unwrap();
let expected_n32 = mt.get_node(NODE32).unwrap();
let expected_n33 = mt.get_node(NODE33).unwrap();
// create vector of the expected inner nodes
let mut expected = vec![
InnerNodeInfo {
value: expected_n00,
left: expected_n10,
right: expected_n11,
},
InnerNodeInfo {
value: expected_n10,
left: expected_n20,
right: expected_n21,
},
InnerNodeInfo {
value: expected_n21,
left: expected_n32,
right: expected_n33,
},
];
assert_eq!(actual, expected);
// add another path to the Partial Merkle Tree
pmt.add_path(2, path22.value, path22.path).unwrap();
// get new actual inner nodes
let actual: Vec<InnerNodeInfo> = pmt.inner_nodes().collect();
let expected_n22 = mt.get_node(NODE22).unwrap();
let expected_n23 = mt.get_node(NODE23).unwrap();
let info_11 = InnerNodeInfo {
value: expected_n11,
left: expected_n22,
right: expected_n23,
};
// add new inner node to the existing vertor
expected.insert(2, info_11);
assert_eq!(actual, expected);
}
/// Checks that serialization and deserialization implementations for the PMT are working
/// correctly.
#[test]
fn serialization() {
let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
let expected_root = mt.root();
let ms = MerkleStore::from(&mt);
let path33 = ms.get_path(expected_root, NODE33).unwrap();
let path22 = ms.get_path(expected_root, NODE22).unwrap();
let pmt = PartialMerkleTree::with_paths([
(3, path33.value, path33.path),
(2, path22.value, path22.path),
])
.unwrap();
let serialized_pmt = pmt.to_bytes();
let deserialized_pmt = PartialMerkleTree::read_from_bytes(&serialized_pmt).unwrap();
assert_eq!(deserialized_pmt, pmt);
}
/// Checks that deserialization fails with incorrect data.
#[test]
fn err_deserialization() {
let mut tree_bytes: Vec<u8> = vec![5];
tree_bytes.append(&mut NODE20.to_bytes());
tree_bytes.append(&mut int_to_node(20).to_bytes());
tree_bytes.append(&mut NODE21.to_bytes());
tree_bytes.append(&mut int_to_node(21).to_bytes());
// node with depth 1 could have index 0 or 1, but it has 2
tree_bytes.append(&mut vec![1, 2]);
tree_bytes.append(&mut int_to_node(11).to_bytes());
assert!(PartialMerkleTree::read_from_bytes(&tree_bytes).is_err());
}
/// Checks that addition of the path with different root will cause an error. /// Checks that addition of the path with different root will cause an error.
#[test] #[test]
fn err_add_path() { fn err_add_path() {
@ -306,8 +459,5 @@ fn err_update_leaf() {
let mut pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap(); let mut pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();
assert!(pmt.update_leaf(NODE22, int_to_node(22)).is_err());
assert!(pmt.update_leaf(NODE23, int_to_node(23)).is_err());
assert!(pmt.update_leaf(NODE30, int_to_node(30)).is_err());
assert!(pmt.update_leaf(NODE31, int_to_node(31)).is_err());
assert!(pmt.update_leaf(8, *int_to_node(38)).is_err());
} }

+ 4
- 4
src/merkle/store/mod.rs

@ -438,21 +438,21 @@ impl> From<&TieredSmt> for MerkleStore {
impl<T: KvMap<RpoDigest, StoreNode>> From<T> for MerkleStore<T> { impl<T: KvMap<RpoDigest, StoreNode>> From<T> for MerkleStore<T> {
fn from(values: T) -> Self { fn from(values: T) -> Self {
let nodes = values.into_iter().chain(empty_hashes().into_iter()).collect();
let nodes = values.into_iter().chain(empty_hashes()).collect();
Self { nodes } Self { nodes }
} }
} }
impl<T: KvMap<RpoDigest, StoreNode>> FromIterator<InnerNodeInfo> for MerkleStore<T> { impl<T: KvMap<RpoDigest, StoreNode>> FromIterator<InnerNodeInfo> for MerkleStore<T> {
fn from_iter<I: IntoIterator<Item = InnerNodeInfo>>(iter: I) -> Self { fn from_iter<I: IntoIterator<Item = InnerNodeInfo>>(iter: I) -> Self {
let nodes = combine_nodes_with_empty_hashes(iter.into_iter()).collect();
let nodes = combine_nodes_with_empty_hashes(iter).collect();
Self { nodes } Self { nodes }
} }
} }
impl<T: KvMap<RpoDigest, StoreNode>> FromIterator<(RpoDigest, StoreNode)> for MerkleStore<T> { impl<T: KvMap<RpoDigest, StoreNode>> FromIterator<(RpoDigest, StoreNode)> for MerkleStore<T> {
fn from_iter<I: IntoIterator<Item = (RpoDigest, StoreNode)>>(iter: I) -> Self { fn from_iter<I: IntoIterator<Item = (RpoDigest, StoreNode)>>(iter: I) -> Self {
let nodes = iter.into_iter().chain(empty_hashes().into_iter()).collect();
let nodes = iter.into_iter().chain(empty_hashes()).collect();
Self { nodes } Self { nodes }
} }
} }
@ -553,5 +553,5 @@ fn combine_nodes_with_empty_hashes(
}, },
) )
}) })
.chain(empty_hashes().into_iter())
.chain(empty_hashes())
} }

+ 2
- 2
src/utils/mod.rs

@ -2,10 +2,10 @@ use super::{utils::string::String, Word};
use core::fmt::{self, Write}; use core::fmt::{self, Write};
#[cfg(not(feature = "std"))] #[cfg(not(feature = "std"))]
pub use alloc::format;
pub use alloc::{format, vec};
#[cfg(feature = "std")] #[cfg(feature = "std")]
pub use std::format;
pub use std::{format, vec};
mod kv_map; mod kv_map;

Loading…
Cancel
Save