Browse Source

feat: introduce TryApplyDiff and refactor RecordingMap finalizer

al-gkr-basic-workflow
frisitano 1 year ago
parent
commit
da2d08714d
7 changed files with 277 additions and 36 deletions
  1. +153
    -0
      src/merkle/delta.rs
  2. +4
    -1
      src/merkle/mod.rs
  3. +28
    -2
      src/merkle/simple_smt/mod.rs
  4. +62
    -20
      src/merkle/store/mod.rs
  5. +1
    -1
      src/merkle/store/tests.rs
  6. +17
    -2
      src/utils/diff.rs
  7. +12
    -10
      src/utils/kv_map.rs

+ 153
- 0
src/merkle/delta.rs

@ -0,0 +1,153 @@
use super::{
BTreeMap, KvMap, MerkleError, MerkleStore, NodeIndex, RpoDigest, StoreNode, Vec, Word,
};
use crate::utils::collections::Diff;
#[cfg(test)]
use super::{empty_roots::EMPTY_WORD, Felt, SimpleSmt};
// MERKLE STORE DELTA
// ================================================================================================
/// [MerkleStoreDelta] stores a vector of ([RpoDigest], [MerkleTreeDelta]) tuples where the
/// [RpoDigest] represents the root of the Merkle tree and [MerkleTreeDelta] represents the
/// differences between the initial and final Merkle tree states.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct MerkleStoreDelta(pub Vec<(RpoDigest, MerkleTreeDelta)>);
// MERKLE TREE DELTA
// ================================================================================================
/// [MerkleDelta] stores the differences between the initial and final Merkle tree states.
///
/// The differences are represented as follows:
/// - depth: the depth of the merkle tree.
/// - cleared_slots: indexes of slots where values were set to [ZERO; 4].
/// - updated_slots: index-value pairs of slots where values were set to non [ZERO; 4] values.
#[cfg(not(test))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MerkleTreeDelta {
depth: u8,
cleared_slots: Vec<u64>,
updated_slots: Vec<(u64, Word)>,
}
impl MerkleTreeDelta {
// CONSTRUCTOR
// --------------------------------------------------------------------------------------------
pub fn new(depth: u8) -> Self {
Self {
depth,
cleared_slots: Vec::new(),
updated_slots: Vec::new(),
}
}
// ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns the depth of the Merkle tree the [MerkleDelta] is associated with.
pub fn depth(&self) -> u8 {
self.depth
}
/// Returns the indexes of slots where values were set to [ZERO; 4].
pub fn cleared_slots(&self) -> &[u64] {
&self.cleared_slots
}
/// Returns the index-value pairs of slots where values were set to non [ZERO; 4] values.
pub fn updated_slots(&self) -> &[(u64, Word)] {
&self.updated_slots
}
// MODIFIERS
// --------------------------------------------------------------------------------------------
/// Adds a slot index to the list of cleared slots.
pub fn add_cleared_slot(&mut self, index: u64) {
self.cleared_slots.push(index);
}
/// Adds a slot index and a value to the list of updated slots.
pub fn add_updated_slot(&mut self, index: u64, value: Word) {
self.updated_slots.push((index, value));
}
}
/// Extracts a [MerkleDelta] object by comparing the leaves of two Merkle trees specifies by
/// their roots and depth.
pub fn merkle_tree_delta<T: KvMap<RpoDigest, StoreNode>>(
tree_root_1: RpoDigest,
tree_root_2: RpoDigest,
depth: u8,
merkle_store: &MerkleStore<T>,
) -> Result<MerkleTreeDelta, MerkleError> {
if tree_root_1 == tree_root_2 {
return Ok(MerkleTreeDelta::new(depth));
}
let tree_1_leaves: BTreeMap<NodeIndex, RpoDigest> =
merkle_store.non_empty_leaves(tree_root_1, depth).collect();
let tree_2_leaves: BTreeMap<NodeIndex, RpoDigest> =
merkle_store.non_empty_leaves(tree_root_2, depth).collect();
let diff = tree_1_leaves.diff(&tree_2_leaves);
// TODO: Refactor this diff implementation to prevent allocation of both BTree and Vec.
Ok(MerkleTreeDelta {
depth,
cleared_slots: diff.removed.into_iter().map(|index| index.value()).collect(),
updated_slots: diff
.updated
.into_iter()
.map(|(index, leaf)| (index.value(), *leaf))
.collect(),
})
}
// INTERNALS
// --------------------------------------------------------------------------------------------
#[cfg(test)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MerkleTreeDelta {
pub depth: u8,
pub cleared_slots: Vec<u64>,
pub updated_slots: Vec<(u64, Word)>,
}
// MERKLE DELTA
// ================================================================================================
#[test]
fn test_compute_merkle_delta() {
let entries = vec![
(10, [Felt::new(0), Felt::new(1), Felt::new(2), Felt::new(3)]),
(15, [Felt::new(4), Felt::new(5), Felt::new(6), Felt::new(7)]),
(20, [Felt::new(8), Felt::new(9), Felt::new(10), Felt::new(11)]),
(31, [Felt::new(12), Felt::new(13), Felt::new(14), Felt::new(15)]),
];
let simple_smt = SimpleSmt::with_leaves(30, entries.clone()).unwrap();
let mut store: MerkleStore = (&simple_smt).into();
let root = simple_smt.root();
// add a new node
let new_value = [Felt::new(16), Felt::new(17), Felt::new(18), Felt::new(19)];
let new_index = NodeIndex::new(simple_smt.depth(), 32).unwrap();
let root = store.set_node(root, new_index, new_value.into()).unwrap().root;
// update an existing node
let update_value = [Felt::new(20), Felt::new(21), Felt::new(22), Felt::new(23)];
let update_idx = NodeIndex::new(simple_smt.depth(), entries[0].0).unwrap();
let root = store.set_node(root, update_idx, update_value.into()).unwrap().root;
// remove a node
let remove_idx = NodeIndex::new(simple_smt.depth(), entries[1].0).unwrap();
let root = store.set_node(root, remove_idx, EMPTY_WORD.into()).unwrap().root;
let merkle_delta =
merkle_tree_delta(simple_smt.root(), root, simple_smt.depth(), &store).unwrap();
let expected_merkle_delta = MerkleTreeDelta {
depth: simple_smt.depth(),
cleared_slots: vec![remove_idx.value()],
updated_slots: vec![(update_idx.value(), update_value), (new_index.value(), new_value)],
};
assert_eq!(merkle_delta, expected_merkle_delta);
}

+ 4
- 1
src/merkle/mod.rs

@ -1,6 +1,6 @@
use super::{
hash::rpo::{Rpo256, RpoDigest},
utils::collections::{vec, BTreeMap, BTreeSet, KvMap, RecordingMap, Vec},
utils::collections::{vec, BTreeMap, BTreeSet, KvMap, RecordingMap, TryApplyDiff, Vec},
Felt, StarkField, Word, WORD_SIZE, ZERO,
};
use core::fmt;
@ -11,6 +11,9 @@ use core::fmt;
mod empty_roots;
pub use empty_roots::EmptySubtreeRoots;
mod delta;
pub use delta::{merkle_tree_delta, MerkleStoreDelta, MerkleTreeDelta};
mod index;
pub use index::NodeIndex;

+ 28
- 2
src/merkle/simple_smt/mod.rs

@ -1,6 +1,6 @@
use super::{
BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex,
Rpo256, RpoDigest, Vec, Word,
BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, MerkleTreeDelta,
NodeIndex, Rpo256, RpoDigest, StoreNode, TryApplyDiff, Vec, Word,
};
#[cfg(test)]
@ -275,3 +275,29 @@ impl BranchNode {
Rpo256::merge(&[self.left, self.right])
}
}
// TRY APPLY DIFF
// ================================================================================================
impl TryApplyDiff<RpoDigest, StoreNode> for SimpleSmt {
type Error = MerkleError;
type DiffType = MerkleTreeDelta;
fn try_apply(&mut self, diff: MerkleTreeDelta) -> Result<(), MerkleError> {
if diff.depth() != self.depth() {
return Err(MerkleError::InvalidDepth {
expected: self.depth(),
provided: diff.depth(),
});
}
for slot in diff.cleared_slots() {
self.update_leaf(*slot, Self::EMPTY_VALUE)?;
}
for (slot, value) in diff.updated_slots() {
self.update_leaf(*slot, *value)?;
}
Ok(())
}
}

+ 62
- 20
src/merkle/store/mod.rs

@ -1,12 +1,9 @@
use super::{
mmr::Mmr, BTreeMap, EmptySubtreeRoots, InnerNodeInfo, KvMap, MerkleError, MerklePath,
MerklePathSet, MerkleTree, NodeIndex, RecordingMap, RootPath, Rpo256, RpoDigest, SimpleSmt,
TieredSmt, ValuePath, Vec,
};
use crate::utils::{
collections::{ApplyDiff, Diff, KvMapDiff},
ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
empty_roots::EMPTY_WORD, mmr::Mmr, BTreeMap, EmptySubtreeRoots, InnerNodeInfo, KvMap,
MerkleError, MerklePath, MerklePathSet, MerkleStoreDelta, MerkleTree, NodeIndex, RecordingMap,
RootPath, Rpo256, RpoDigest, SimpleSmt, TieredSmt, TryApplyDiff, ValuePath, Vec,
};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
use core::borrow::Borrow;
#[cfg(test)]
@ -280,6 +277,37 @@ impl> MerkleStore {
})
}
/// Iterator over the non-empty leaves of the Merkle tree associated with the specified `root`
/// and `max_depth`.
pub fn non_empty_leaves(
&self,
root: RpoDigest,
max_depth: u8,
) -> impl Iterator<Item = (NodeIndex, RpoDigest)> + '_ {
let empty_roots = EmptySubtreeRoots::empty_hashes(max_depth);
let mut stack = Vec::new();
stack.push((NodeIndex::new_unchecked(0, 0), root));
core::iter::from_fn(move || {
while let Some((index, node_hash)) = stack.pop() {
if index.depth() == max_depth {
return Some((index, node_hash));
}
if let Some(node) = self.nodes.get(&node_hash) {
if !empty_roots.contains(&node.left) {
stack.push((index.left_child(), node.left));
}
if !empty_roots.contains(&node.right) {
stack.push((index.right_child(), node.right));
}
}
}
None
})
}
// STATE MUTATORS
// --------------------------------------------------------------------------------------------
@ -462,7 +490,6 @@ impl> FromIterator<(RpoDigest, StoreNode)> for Me
// ITERATORS
// ================================================================================================
impl<T: KvMap<RpoDigest, StoreNode>> Extend<InnerNodeInfo> for MerkleStore<T> {
fn extend<I: IntoIterator<Item = InnerNodeInfo>>(&mut self, iter: I) {
self.nodes.extend(iter.into_iter().map(|info| {
@ -479,19 +506,34 @@ impl> Extend for MerkleStore {
// DiffT & ApplyDiffT TRAIT IMPLEMENTATION
// ================================================================================================
impl<T: KvMap<RpoDigest, StoreNode>> Diff<RpoDigest, StoreNode> for MerkleStore<T> {
type DiffType = KvMapDiff<RpoDigest, StoreNode>;
fn diff(&self, other: &Self) -> Self::DiffType {
self.nodes.diff(&other.nodes)
}
}
impl<T: KvMap<RpoDigest, StoreNode>> ApplyDiff<RpoDigest, StoreNode> for MerkleStore<T> {
type DiffType = KvMapDiff<RpoDigest, StoreNode>;
impl<T: KvMap<RpoDigest, StoreNode>> TryApplyDiff<RpoDigest, StoreNode> for MerkleStore<T> {
type Error = MerkleError;
type DiffType = MerkleStoreDelta;
fn try_apply(&mut self, diff: Self::DiffType) -> Result<(), MerkleError> {
for (root, delta) in diff.0 {
let mut root = root;
for cleared_slot in delta.cleared_slots() {
root = self
.set_node(
root,
NodeIndex::new(delta.depth(), *cleared_slot)?,
EMPTY_WORD.into(),
)?
.root;
}
for (updated_slot, updated_value) in delta.updated_slots() {
root = self
.set_node(
root,
NodeIndex::new(delta.depth(), *updated_slot)?,
(*updated_value).into(),
)?
.root;
}
}
fn apply(&mut self, diff: Self::DiffType) {
self.nodes.apply(diff);
Ok(())
}
}

+ 1
- 1
src/merkle/store/tests.rs

@ -847,7 +847,7 @@ fn test_recorder() {
// construct the proof
let rec_map = recorder.into_inner();
let proof = rec_map.into_proof();
let (_, proof) = rec_map.finalize();
let merkle_store: MerkleStore = proof.into();
// make sure the proof contains all nodes from both trees

+ 17
- 2
src/utils/diff.rs

@ -1,16 +1,31 @@
/// A trait for computing the difference between two objects.
pub trait Diff<K: Ord + Clone, V: Clone> {
/// The type that describes the difference between two objects.
type DiffType;
/// Returns a `Self::DiffType` object that represents the difference between this object and
/// Returns a [Self::DiffType] object that represents the difference between this object and
/// other.
fn diff(&self, other: &Self) -> Self::DiffType;
}
/// A trait for applying the difference between two objects.
pub trait ApplyDiff<K: Ord + Clone, V: Clone> {
/// The type that describes the difference between two objects.
type DiffType;
/// Applies the provided changes described by [DiffType] to the object implementing this trait.
/// Applies the provided changes described by [Self::DiffType] to the object implementing this trait.
fn apply(&mut self, diff: Self::DiffType);
}
/// A trait for applying the difference between two objects with the possibility of failure.
pub trait TryApplyDiff<K: Ord + Clone, V: Clone> {
/// The type that describes the difference between two objects.
type DiffType;
/// An error type that can be returned if the changes cannot be applied.
type Error;
/// Applies the provided changes described by [Self::DiffType] to the object implementing this trait.
/// Returns an error if the changes cannot be applied.
fn try_apply(&mut self, diff: Self::DiffType) -> Result<(), Self::Error>;
}

+ 12
- 10
src/utils/kv_map.rs

@ -97,10 +97,12 @@ impl RecordingMap {
// FINALIZER
// --------------------------------------------------------------------------------------------
/// Consumes the [RecordingMap] and returns a [BTreeMap] containing the key-value pairs from
/// the initial data set that were read during recording.
pub fn into_proof(self) -> BTreeMap<K, V> {
self.trace.take()
/// Consumes the [RecordingMap] and returns a ([BTreeMap], [BTreeMap]) tuple. The first
/// element of the tuple is a map that represents the state of the map at the time `.finalize()`
/// is called. The second element contains the key-value pairs from the initial data set that
/// were read during recording.
pub fn finalize(self) -> (BTreeMap<K, V>, BTreeMap<K, V>) {
(self.data, self.trace.take())
}
// TEST HELPERS
@ -217,8 +219,8 @@ impl IntoIterator for RecordingMap {
/// - `removed` - a set of keys that were removed from the second map compared to the first map.
#[derive(Debug, Clone)]
pub struct KvMapDiff<K, V> {
updated: BTreeMap<K, V>,
removed: BTreeSet<K>,
pub updated: BTreeMap<K, V>,
pub removed: BTreeSet<K>,
}
impl<K, V> KvMapDiff<K, V> {
@ -296,7 +298,7 @@ mod tests {
}
// convert the map into a proof
let proof = map.into_proof();
let (_, proof) = map.finalize();
// check that the proof contains the expected values
for (key, value) in ITEMS.iter() {
@ -319,7 +321,7 @@ mod tests {
}
// convert the map into a proof
let proof = map.into_proof();
let (_, proof) = map.finalize();
// check that the proof contains the expected values
for (key, _) in ITEMS.iter() {
@ -383,7 +385,7 @@ mod tests {
// Note: The length reported by the proof will be different to the length originally
// reported by the map.
let proof = map.into_proof();
let (_, proof) = map.finalize();
// length of the proof should be equal to get_items + 1. The extra item is the original
// value at key = 4u64
@ -458,7 +460,7 @@ mod tests {
assert_eq!(map.updates_len(), 2);
// convert the map into a proof
let proof = map.into_proof();
let (_, proof) = map.finalize();
// check that the proof contains the expected values
for (key, value) in ITEMS.iter() {

Loading…
Cancel
Save