feat: reverse mutations generation, mutations serialization (#355)

* feat: revert mutations generation, mutations serialization
* tests: check both `apply_mutations` and `apply_mutations_with_reversion`
* feat: add `num_leaves` method for `Smt`
* refactor: improve ad-hoc benchmarks
* chore: update crate version to v0.13.1
This commit is contained in:
polydez
2024-12-27 07:16:38 +05:00
committed by GitHub
parent 1444bbc0f2
commit 589839fef1
12 changed files with 533 additions and 150 deletions

View File

@@ -4,8 +4,9 @@ use clap::Parser;
use miden_crypto::{
hash::rpo::{Rpo256, RpoDigest},
merkle::{MerkleError, Smt},
Felt, Word, ONE,
Felt, Word, EMPTY_WORD, ONE,
};
use rand::{prelude::IteratorRandom, thread_rng, Rng};
use rand_utils::rand_value;
#[derive(Parser, Debug)]
@@ -13,7 +14,7 @@ use rand_utils::rand_value;
pub struct BenchmarkCmd {
/// Size of the tree
#[clap(short = 's', long = "size")]
size: u64,
size: usize,
}
fn main() {
@@ -29,101 +30,153 @@ pub fn benchmark_smt() {
let mut entries = Vec::new();
for i in 0..tree_size {
let key = rand_value::<RpoDigest>();
let value = [ONE, ONE, ONE, Felt::new(i)];
let value = [ONE, ONE, ONE, Felt::new(i as u64)];
entries.push((key, value));
}
let mut tree = construction(entries, tree_size).unwrap();
insertion(&mut tree, tree_size).unwrap();
batched_insertion(&mut tree, tree_size).unwrap();
proof_generation(&mut tree, tree_size).unwrap();
let mut tree = construction(entries.clone(), tree_size).unwrap();
insertion(&mut tree).unwrap();
batched_insertion(&mut tree).unwrap();
batched_update(&mut tree, entries).unwrap();
proof_generation(&mut tree).unwrap();
}
/// Runs the construction benchmark for [`Smt`], returning the constructed tree.
pub fn construction(entries: Vec<(RpoDigest, Word)>, size: u64) -> Result<Smt, MerkleError> {
pub fn construction(entries: Vec<(RpoDigest, Word)>, size: usize) -> Result<Smt, MerkleError> {
println!("Running a construction benchmark:");
let now = Instant::now();
let tree = Smt::with_entries(entries)?;
let elapsed = now.elapsed();
println!(
"Constructed a SMT with {} key-value pairs in {:.3} seconds",
size,
elapsed.as_secs_f32(),
);
let elapsed = now.elapsed().as_secs_f32();
println!("Constructed a SMT with {size} key-value pairs in {elapsed:.1} seconds");
println!("Number of leaf nodes: {}\n", tree.leaves().count());
Ok(tree)
}
/// Runs the insertion benchmark for the [`Smt`].
pub fn insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
pub fn insertion(tree: &mut Smt) -> Result<(), MerkleError> {
const NUM_INSERTIONS: usize = 1_000;
println!("Running an insertion benchmark:");
let size = tree.num_leaves();
let mut insertion_times = Vec::new();
for i in 0..20 {
for i in 0..NUM_INSERTIONS {
let test_key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
let test_value = [ONE, ONE, ONE, Felt::new(size + i)];
let test_value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];
let now = Instant::now();
tree.insert(test_key, test_value);
let elapsed = now.elapsed();
insertion_times.push(elapsed.as_secs_f32());
insertion_times.push(elapsed.as_micros());
}
println!(
"An average insertion time measured by 20 inserts into a SMT with {} key-value pairs is {:.3} milliseconds\n",
size,
// calculate the average by dividing by 20 and convert to milliseconds by multiplying by
// 1000. As a result, we can only multiply by 50
insertion_times.iter().sum::<f32>() * 50f32,
"An average insertion time measured by {NUM_INSERTIONS} inserts into an SMT with {size} leaves is {:.0} μs\n",
// calculate the average
insertion_times.iter().sum::<u128>() as f64 / (NUM_INSERTIONS as f64),
);
Ok(())
}
pub fn batched_insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
pub fn batched_insertion(tree: &mut Smt) -> Result<(), MerkleError> {
const NUM_INSERTIONS: usize = 1_000;
println!("Running a batched insertion benchmark:");
let new_pairs: Vec<(RpoDigest, Word)> = (0..1000)
let size = tree.num_leaves();
let new_pairs: Vec<(RpoDigest, Word)> = (0..NUM_INSERTIONS)
.map(|i| {
let key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
let value = [ONE, ONE, ONE, Felt::new(size + i)];
let value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];
(key, value)
})
.collect();
let now = Instant::now();
let mutations = tree.compute_mutations(new_pairs);
let compute_elapsed = now.elapsed();
let compute_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
let now = Instant::now();
tree.apply_mutations(mutations).unwrap();
let apply_elapsed = now.elapsed();
tree.apply_mutations(mutations)?;
let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
println!(
"An average batch computation time measured by a 1k-batch into an SMT with {} key-value pairs over {:.3} milliseconds is {:.3} milliseconds",
size,
compute_elapsed.as_secs_f32() * 1000f32,
// Dividing by the number of iterations, 1000, and then multiplying by 1000 to get
// milliseconds, cancels out.
compute_elapsed.as_secs_f32(),
"An average insert-batch computation time measured by a {NUM_INSERTIONS}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
compute_elapsed,
compute_elapsed * 1000_f64 / NUM_INSERTIONS as f64, // time in μs
);
println!(
"An average batch application time measured by a 1k-batch into an SMT with {} key-value pairs over {:.3} milliseconds is {:.3} milliseconds",
size,
apply_elapsed.as_secs_f32() * 1000f32,
// Dividing by the number of iterations, 1000, and then multiplying by 1000 to get
// milliseconds, cancels out.
apply_elapsed.as_secs_f32(),
"An average insert-batch application time measured by a {NUM_INSERTIONS}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
apply_elapsed,
apply_elapsed * 1000_f64 / NUM_INSERTIONS as f64, // time in μs
);
println!(
"An average batch insertion time measured by a 1k-batch into an SMT with {} key-value pairs totals to {:.3} milliseconds",
size,
(compute_elapsed + apply_elapsed).as_secs_f32() * 1000f32,
"An average batch insertion time measured by a 1k-batch into an SMT with {size} leaves totals to {:.1} ms",
(compute_elapsed + apply_elapsed),
);
println!();
Ok(())
}
pub fn batched_update(tree: &mut Smt, entries: Vec<(RpoDigest, Word)>) -> Result<(), MerkleError> {
const NUM_UPDATES: usize = 1_000;
const REMOVAL_PROBABILITY: f64 = 0.2;
println!("Running a batched update benchmark:");
let size = tree.num_leaves();
let mut rng = thread_rng();
let new_pairs =
entries
.into_iter()
.choose_multiple(&mut rng, NUM_UPDATES)
.into_iter()
.map(|(key, _)| {
let value = if rng.gen_bool(REMOVAL_PROBABILITY) {
EMPTY_WORD
} else {
[ONE, ONE, ONE, Felt::new(rng.gen())]
};
(key, value)
});
assert_eq!(new_pairs.len(), NUM_UPDATES);
let now = Instant::now();
let mutations = tree.compute_mutations(new_pairs);
let compute_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
let now = Instant::now();
tree.apply_mutations(mutations)?;
let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
println!(
"An average update-batch computation time measured by a {NUM_UPDATES}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
compute_elapsed,
compute_elapsed * 1000_f64 / NUM_UPDATES as f64, // time in μs
);
println!(
"An average update-batch application time measured by a {NUM_UPDATES}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
apply_elapsed,
apply_elapsed * 1000_f64 / NUM_UPDATES as f64, // time in μs
);
println!(
"An average batch update time measured by a 1k-batch into an SMT with {size} leaves totals to {:.1} ms",
(compute_elapsed + apply_elapsed),
);
println!();
@@ -132,28 +185,29 @@ pub fn batched_insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
}
/// Runs the proof generation benchmark for the [`Smt`].
pub fn proof_generation(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
pub fn proof_generation(tree: &mut Smt) -> Result<(), MerkleError> {
const NUM_PROOFS: usize = 100;
println!("Running a proof generation benchmark:");
let mut insertion_times = Vec::new();
for i in 0..20 {
let size = tree.num_leaves();
for i in 0..NUM_PROOFS {
let test_key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
let test_value = [ONE, ONE, ONE, Felt::new(size + i)];
let test_value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];
tree.insert(test_key, test_value);
let now = Instant::now();
let _proof = tree.open(&test_key);
let elapsed = now.elapsed();
insertion_times.push(elapsed.as_secs_f32());
insertion_times.push(now.elapsed().as_micros());
}
println!(
"An average proving time measured by 20 value proofs in a SMT with {} key-value pairs in {:.3} microseconds",
size,
// calculate the average by dividing by 20 and convert to microseconds by multiplying by
// 1000000. As a result, we can only multiply by 50000
insertion_times.iter().sum::<f32>() * 50000f32,
"An average proving time measured by {NUM_PROOFS} value proofs in an SMT with {size} leaves in {:.0} μs",
// calculate the average
insertion_times.iter().sum::<u128>() as f64 / (NUM_PROOFS as f64),
);
Ok(())

View File

@@ -128,7 +128,7 @@ impl NodeIndex {
self.value
}
/// Returns true if the current instance points to a right sibling node.
/// Returns `true` if the current instance points to a right sibling node.
pub const fn is_value_odd(&self) -> bool {
(self.value & 1) == 1
}

View File

@@ -303,7 +303,7 @@ impl PartialMmr {
if leaf_pos + 1 == self.forest
&& path.depth() == 0
&& self.peaks.last().map_or(false, |v| *v == leaf)
&& self.peaks.last().is_some_and(|v| *v == leaf)
{
self.track_latest = true;
return Ok(());

View File

@@ -70,7 +70,7 @@ impl SmtLeaf {
Self::Single((key, value))
}
/// Returns a new single leaf with the specified entry. The leaf index is derived from the
/// Returns a new multiple leaf with the specified entries. The leaf index is derived from the
/// entries' keys.
///
/// # Errors

View File

@@ -114,6 +114,11 @@ impl Smt {
<Self as SparseMerkleTree<SMT_DEPTH>>::root(self)
}
/// Returns the number of non-empty leaves in this tree.
pub fn num_leaves(&self) -> usize {
self.leaves.len()
}
/// Returns the leaf to which `key` maps
pub fn get_leaf(&self, key: &RpoDigest) -> SmtLeaf {
<Self as SparseMerkleTree<SMT_DEPTH>>::get_leaf(self, key)
@@ -200,7 +205,7 @@ impl Smt {
<Self as SparseMerkleTree<SMT_DEPTH>>::compute_mutations(self, kv_pairs)
}
/// Apply the prospective mutations computed with [`Smt::compute_mutations()`] to this tree.
/// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to this tree.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
@@ -214,6 +219,23 @@ impl Smt {
<Self as SparseMerkleTree<SMT_DEPTH>>::apply_mutations(self, mutations)
}
/// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to this tree
/// and returns the reverse mutation set.
///
/// Applying the reverse mutation sets to the updated tree will revert the changes.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash
/// the `mutations` were computed against, and the second item is the actual current root of
/// this tree.
pub fn apply_mutations_with_reversion(
&mut self,
mutations: MutationSet<SMT_DEPTH, RpoDigest, Word>,
) -> Result<MutationSet<SMT_DEPTH, RpoDigest, Word>, MerkleError> {
<Self as SparseMerkleTree<SMT_DEPTH>>::apply_mutations_with_reversion(self, mutations)
}
// HELPERS
// --------------------------------------------------------------------------------------------
@@ -275,12 +297,12 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
.unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth()))
}
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
self.inner_nodes.insert(index, inner_node);
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode> {
self.inner_nodes.insert(index, inner_node)
}
fn remove_inner_node(&mut self, index: NodeIndex) {
let _ = self.inner_nodes.remove(&index);
fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode> {
self.inner_nodes.remove(&index)
}
fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value> {

View File

@@ -1,12 +1,14 @@
use alloc::vec::Vec;
use alloc::{collections::BTreeMap, vec::Vec};
use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH};
use crate::{
merkle::{smt::SparseMerkleTree, EmptySubtreeRoots, MerkleStore},
merkle::{
smt::{NodeMutation, SparseMerkleTree},
EmptySubtreeRoots, MerkleStore, MutationSet,
},
utils::{Deserializable, Serializable},
Word, ONE, WORD_SIZE,
};
// SMT
// --------------------------------------------------------------------------------------------
@@ -412,21 +414,49 @@ fn test_prospective_insertion() {
let mutations = smt.compute_mutations(vec![(key_1, value_1)]);
assert_eq!(mutations.root(), root_1, "prospective root 1 did not match actual root 1");
smt.apply_mutations(mutations).unwrap();
let revert = apply_mutations(&mut smt, mutations);
assert_eq!(smt.root(), root_1, "mutations before and after apply did not match");
assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
assert_eq!(revert.root(), root_empty, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_1, EMPTY_WORD)]),
"reverse mutations pairs did not match"
);
assert_eq!(
revert.node_mutations,
smt.inner_nodes.keys().map(|key| (*key, NodeMutation::Removal)).collect(),
"reverse mutations inner nodes did not match"
);
let mutations = smt.compute_mutations(vec![(key_2, value_2)]);
assert_eq!(mutations.root(), root_2, "prospective root 2 did not match actual root 2");
let mutations =
smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_2, value_2), (key_3, value_3)]);
assert_eq!(mutations.root(), root_3, "mutations before and after apply did not match");
smt.apply_mutations(mutations).unwrap();
let old_root = smt.root();
let revert = apply_mutations(&mut smt, mutations);
assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_2, EMPTY_WORD), (key_3, EMPTY_WORD)]),
"reverse mutations pairs did not match"
);
// Edge case: multiple values at the same key, where a later pair restores the original value.
let mutations = smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_3, value_3)]);
assert_eq!(mutations.root(), root_3);
smt.apply_mutations(mutations).unwrap();
let old_root = smt.root();
let revert = apply_mutations(&mut smt, mutations);
assert_eq!(smt.root(), root_3);
assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_3, value_3)]),
"reverse mutations pairs did not match"
);
// Test batch updates, and that the order doesn't matter.
let pairs =
@@ -437,8 +467,16 @@ fn test_prospective_insertion() {
root_empty,
"prospective root for batch removal did not match actual root",
);
smt.apply_mutations(mutations).unwrap();
let old_root = smt.root();
let revert = apply_mutations(&mut smt, mutations);
assert_eq!(smt.root(), root_empty, "mutations before and after apply did not match");
assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_1, value_1), (key_2, value_2), (key_3, value_3)]),
"reverse mutations pairs did not match"
);
let pairs = vec![(key_3, value_3), (key_1, value_1), (key_2, value_2)];
let mutations = smt.compute_mutations(pairs);
@@ -447,6 +485,72 @@ fn test_prospective_insertion() {
assert_eq!(smt.root(), root_3);
}
#[test]
fn test_mutations_revert() {
let mut smt = Smt::default();
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(1)]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(2)]);
let key_3: RpoDigest =
RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(3)]);
let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let value_3 = [3_u32.into(); WORD_SIZE];
smt.insert(key_1, value_1);
smt.insert(key_2, value_2);
let mutations =
smt.compute_mutations(vec![(key_1, EMPTY_WORD), (key_2, value_1), (key_3, value_3)]);
let original = smt.clone();
let revert = smt.apply_mutations_with_reversion(mutations).unwrap();
assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
assert_eq!(revert.root(), original.root(), "reverse mutations new root did not match");
smt.apply_mutations(revert).unwrap();
assert_eq!(smt, original, "SMT with applied revert mutations did not match original SMT");
}
#[test]
fn test_mutation_set_serialization() {
let mut smt = Smt::default();
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(1)]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(2)]);
let key_3: RpoDigest =
RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(3)]);
let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let value_3 = [3_u32.into(); WORD_SIZE];
smt.insert(key_1, value_1);
smt.insert(key_2, value_2);
let mutations =
smt.compute_mutations(vec![(key_1, EMPTY_WORD), (key_2, value_1), (key_3, value_3)]);
let serialized = mutations.to_bytes();
let deserialized =
MutationSet::<SMT_DEPTH, RpoDigest, Word>::read_from_bytes(&serialized).unwrap();
assert_eq!(deserialized, mutations, "deserialized mutations did not match original");
let revert = smt.apply_mutations_with_reversion(mutations).unwrap();
let serialized = revert.to_bytes();
let deserialized =
MutationSet::<SMT_DEPTH, RpoDigest, Word>::read_from_bytes(&serialized).unwrap();
assert_eq!(deserialized, revert, "deserialized mutations did not match original");
}
/// Tests that 2 key-value pairs stored in the same leaf have the same path
#[test]
fn test_smt_path_to_keys_in_same_leaf_are_equal() {
@@ -602,3 +706,19 @@ fn build_multiple_leaf_node(kv_pairs: &[(RpoDigest, Word)]) -> RpoDigest {
Rpo256::hash_elements(&elements)
}
/// Applies mutations with and without reversion to the given SMT, comparing resulting SMTs,
/// returning mutation set for reversion.
fn apply_mutations(
smt: &mut Smt,
mutation_set: MutationSet<SMT_DEPTH, RpoDigest, Word>,
) -> MutationSet<SMT_DEPTH, RpoDigest, Word> {
let mut smt2 = smt.clone();
let reversion = smt.apply_mutations_with_reversion(mutation_set.clone()).unwrap();
smt2.apply_mutations(mutation_set).unwrap();
assert_eq!(&smt2, smt);
reversion
}

View File

@@ -1,5 +1,7 @@
use alloc::{collections::BTreeMap, vec::Vec};
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex};
use crate::{
hash::rpo::{Rpo256, RpoDigest},
@@ -135,7 +137,7 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) {
// If a subtree is empty, then can remove the inner node, since it's equal to the
// default value
self.remove_inner_node(index)
self.remove_inner_node(index);
} else {
self.insert_inner_node(index, InnerNode { left, right });
}
@@ -241,7 +243,7 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
}
}
/// Apply the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to
/// Applies the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to
/// this tree.
///
/// # Errors
@@ -275,8 +277,12 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
for (index, mutation) in node_mutations {
match mutation {
Removal => self.remove_inner_node(index),
Addition(node) => self.insert_inner_node(index, node),
Removal => {
self.remove_inner_node(index);
},
Addition(node) => {
self.insert_inner_node(index, node);
},
}
}
@@ -289,6 +295,76 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
Ok(())
}
/// Applies the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to
/// this tree and returns the reverse mutation set. Applying the reverse mutation sets to the
/// updated tree will revert the changes.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash
/// the `mutations` were computed against, and the second item is the actual current root of
/// this tree.
fn apply_mutations_with_reversion(
&mut self,
mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
) -> Result<MutationSet<DEPTH, Self::Key, Self::Value>, MerkleError>
where
Self: Sized,
{
use NodeMutation::*;
let MutationSet {
old_root,
node_mutations,
new_pairs,
new_root,
} = mutations;
// Guard against accidentally trying to apply mutations that were computed against a
// different tree, including a stale version of this tree.
if old_root != self.root() {
return Err(MerkleError::ConflictingRoots {
expected_root: self.root(),
actual_root: old_root,
});
}
let mut reverse_mutations = BTreeMap::new();
for (index, mutation) in node_mutations {
match mutation {
Removal => {
if let Some(node) = self.remove_inner_node(index) {
reverse_mutations.insert(index, Addition(node));
}
},
Addition(node) => {
if let Some(old_node) = self.insert_inner_node(index, node) {
reverse_mutations.insert(index, Addition(old_node));
} else {
reverse_mutations.insert(index, Removal);
}
},
}
}
let mut reverse_pairs = BTreeMap::new();
for (key, value) in new_pairs {
if let Some(old_value) = self.insert_value(key.clone(), value) {
reverse_pairs.insert(key, old_value);
} else {
reverse_pairs.insert(key, Self::EMPTY_VALUE);
}
}
self.set_root(new_root);
Ok(MutationSet {
old_root: new_root,
node_mutations: reverse_mutations,
new_pairs: reverse_pairs,
new_root: old_root,
})
}
// REQUIRED METHODS
// ---------------------------------------------------------------------------------------------
@@ -302,10 +378,10 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
fn get_inner_node(&self, index: NodeIndex) -> InnerNode;
/// Inserts an inner node at the given index
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode);
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode>;
/// Removes an inner node at the given index
fn remove_inner_node(&mut self, index: NodeIndex);
fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode>;
/// Inserts a leaf node, and returns the value at the key if already exists
fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value>;
@@ -352,7 +428,7 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub(crate) struct InnerNode {
pub struct InnerNode {
pub left: RpoDigest,
pub right: RpoDigest,
}
@@ -423,7 +499,7 @@ impl<const DEPTH: u8> TryFrom<NodeIndex> for LeafIndex<DEPTH> {
/// [`MutationSet`] stores this type in relation to a [`NodeIndex`] to keep track of what changes
/// need to occur at which node indices.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum NodeMutation {
pub enum NodeMutation {
/// Corresponds to [`SparseMerkleTree::remove_inner_node()`].
Removal,
/// Corresponds to [`SparseMerkleTree::insert_inner_node()`].
@@ -456,9 +532,94 @@ pub struct MutationSet<const DEPTH: u8, K, V> {
}
impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
/// Queries the root that was calculated during `SparseMerkleTree::compute_mutations()`. See
/// Returns the SMT root that was calculated during `SparseMerkleTree::compute_mutations()`. See
/// that method for more information.
pub fn root(&self) -> RpoDigest {
self.new_root
}
/// Returns the SMT root before the mutations were applied.
pub fn old_root(&self) -> RpoDigest {
self.old_root
}
/// Returns the set of inner nodes that need to be removed or added.
pub fn node_mutations(&self) -> &BTreeMap<NodeIndex, NodeMutation> {
&self.node_mutations
}
/// Returns the set of top-level key-value pairs that need to be added, updated or deleted
/// (i.e. set to `EMPTY_WORD`).
pub fn new_pairs(&self) -> &BTreeMap<K, V> {
&self.new_pairs
}
}
// SERIALIZATION
// ================================================================================================
impl Serializable for InnerNode {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.left.write_into(target);
self.right.write_into(target);
}
}
impl Deserializable for InnerNode {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let left = source.read()?;
let right = source.read()?;
Ok(Self { left, right })
}
}
impl Serializable for NodeMutation {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
match self {
NodeMutation::Removal => target.write_bool(false),
NodeMutation::Addition(inner_node) => {
target.write_bool(true);
inner_node.write_into(target);
},
}
}
}
impl Deserializable for NodeMutation {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
if source.read_bool()? {
let inner_node = source.read()?;
return Ok(NodeMutation::Addition(inner_node));
}
Ok(NodeMutation::Removal)
}
}
impl<const DEPTH: u8, K: Serializable, V: Serializable> Serializable for MutationSet<DEPTH, K, V> {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write(self.old_root);
target.write(self.new_root);
self.node_mutations.write_into(target);
self.new_pairs.write_into(target);
}
}
impl<const DEPTH: u8, K: Deserializable + Ord, V: Deserializable> Deserializable
for MutationSet<DEPTH, K, V>
{
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let old_root = source.read()?;
let new_root = source.read()?;
let node_mutations = source.read()?;
let new_pairs = source.read()?;
Ok(Self {
old_root,
node_mutations,
new_pairs,
new_root,
})
}
}

View File

@@ -221,7 +221,7 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
<Self as SparseMerkleTree<DEPTH>>::compute_mutations(self, kv_pairs)
}
/// Apply the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to this
/// Applies the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to this
/// tree.
///
/// # Errors
@@ -236,6 +236,23 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
<Self as SparseMerkleTree<DEPTH>>::apply_mutations(self, mutations)
}
/// Applies the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to
/// this tree and returns the reverse mutation set.
///
/// Applying the reverse mutation sets to the updated tree will revert the changes.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`] with a two-item [`alloc::vec::Vec`]. The first item is the
/// root hash the `mutations` were computed against, and the second item is the actual
/// current root of this tree.
pub fn apply_mutations_with_reversion(
&mut self,
mutations: MutationSet<DEPTH, LeafIndex<DEPTH>, Word>,
) -> Result<MutationSet<DEPTH, LeafIndex<DEPTH>, Word>, MerkleError> {
<Self as SparseMerkleTree<DEPTH>>::apply_mutations_with_reversion(self, mutations)
}
/// Inserts a subtree at the specified index. The depth at which the subtree is inserted is
/// computed as `DEPTH - SUBTREE_DEPTH`.
///
@@ -321,12 +338,12 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
.unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(DEPTH, index.depth()))
}
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
self.inner_nodes.insert(index, inner_node);
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode> {
self.inner_nodes.insert(index, inner_node)
}
fn remove_inner_node(&mut self, index: NodeIndex) {
let _ = self.inner_nodes.remove(&index);
fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode> {
self.inner_nodes.remove(&index)
}
fn insert_value(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Option<Word> {