Browse Source

feat: add support for hashmaps in `Smt` and `SimpleSmt` (#363)

next
polydez 4 months ago
committed by GitHub
parent
commit
7ee6d7fb93
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
13 changed files with 171 additions and 84 deletions
  1. +1
    -1
      .github/workflows/test.yml
  2. +1
    -0
      CHANGELOG.md
  3. +31
    -0
      Cargo.lock
  4. +2
    -0
      Cargo.toml
  5. +4
    -1
      Makefile
  6. +1
    -0
      README.md
  7. +13
    -1
      src/hash/rescue/rpo/digest.rs
  8. +1
    -0
      src/merkle/node.rs
  9. +12
    -18
      src/merkle/smt/full/mod.rs
  10. +15
    -15
      src/merkle/smt/full/tests.rs
  11. +73
    -29
      src/merkle/smt/mod.rs
  12. +12
    -17
      src/merkle/smt/simple/mod.rs
  13. +5
    -2
      src/merkle/smt/simple/tests.rs

+ 1
- 1
.github/workflows/test.yml

@ -17,7 +17,7 @@ jobs:
matrix:
toolchain: [stable, nightly]
os: [ubuntu]
args: [default, no-std]
args: [default, smt-hashmaps, no-std]
timeout-minutes: 30
steps:
- uses: actions/checkout@main

+ 1
- 0
CHANGELOG.md

@ -7,6 +7,7 @@
- Fixed a bug in the implementation of `draw_integers` for `RpoRandomCoin` (#343).
- [BREAKING] Refactor error messages and use `thiserror` to derive errors (#344).
- [BREAKING] Updated Winterfell dependency to v0.11 (#346).
- Added support for hashmaps in `Smt` and `SimpleSmt` which gives up to 10x boost in some operations (#363).
## 0.12.0 (2024-10-30)

+ 31
- 0
Cargo.lock

@ -11,6 +11,12 @@ dependencies = [
"memchr",
]
[[package]]
name = "allocator-api2"
version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
[[package]]
name = "anes"
version = "0.1.6"
@ -349,6 +355,12 @@ version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
[[package]]
name = "equivalent"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
[[package]]
name = "errno"
version = "0.3.10"
@ -371,6 +383,12 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foldhash"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2"
[[package]]
name = "generic-array"
version = "0.14.7"
@ -410,6 +428,18 @@ dependencies = [
"crunchy",
]
[[package]]
name = "hashbrown"
version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289"
dependencies = [
"allocator-api2",
"equivalent",
"foldhash",
"serde",
]
[[package]]
name = "heck"
version = "0.5.0"
@ -535,6 +565,7 @@ dependencies = [
"criterion",
"getrandom",
"glob",
"hashbrown",
"hex",
"num",
"num-complex",

+ 2
- 0
Cargo.toml

@ -48,6 +48,7 @@ harness = false
concurrent = ["dep:rayon"]
default = ["std", "concurrent"]
executable = ["dep:clap", "dep:rand-utils", "std"]
smt_hashmaps = ["dep:hashbrown"]
internal = []
serde = ["dep:serde", "serde?/alloc", "winter-math/serde"]
std = [
@ -63,6 +64,7 @@ std = [
[dependencies]
blake3 = { version = "1.5", default-features = false }
clap = { version = "4.5", optional = true, features = ["derive"] }
hashbrown = { version = "0.15", optional = true, features = ["serde"] }
num = { version = "0.4", default-features = false, features = ["alloc", "libm"] }
num-complex = { version = "0.4", default-features = false }
rand = { version = "0.8", default-features = false }

+ 4
- 1
Makefile

@ -46,6 +46,9 @@ doc: ## Generate and check documentation
test-default: ## Run tests with default features
$(DEBUG_OVERFLOW_INFO) cargo nextest run --profile default --release --all-features
.PHONY: test-smt-hashmaps
test-smt-hashmaps: ## Run tests with `smt_hashmaps` feature enabled
$(DEBUG_OVERFLOW_INFO) cargo nextest run --profile default --release --features smt_hashmaps
.PHONY: test-no-std
test-no-std: ## Run tests with `no-default-features` (std)
@ -53,7 +56,7 @@ test-no-std: ## Run tests with `no-default-features` (std)
.PHONY: test
test: test-default test-no-std ## Run all tests
test: test-default test-smt-hashmaps test-no-std ## Run all tests
# --- checking ------------------------------------------------------------------------------------

+ 1
- 0
README.md

@ -63,6 +63,7 @@ This crate can be compiled with the following features:
- `concurrent`- enabled by default; enables multi-threaded implementation of `Smt::with_entries()` which significantly improves performance on multi-core CPUs.
- `std` - enabled by default and relies on the Rust standard library.
- `no_std` does not rely on the Rust standard library and enables compilation to WebAssembly.
- `smt_hashmaps` - uses hashbrown hashmaps in SMT implementation which significantly improves performance of SMT updating. Keys ordering in SMT iterators is not guarantied when this feature is enabled.
All of these features imply the use of [alloc](https://doc.rust-lang.org/alloc/) to support heap-allocated collections.

+ 13
- 1
src/hash/rescue/rpo/digest.rs

@ -1,5 +1,11 @@
use alloc::string::String;
use core::{cmp::Ordering, fmt::Display, ops::Deref, slice};
use core::{
cmp::Ordering,
fmt::Display,
hash::{Hash, Hasher},
ops::Deref,
slice,
};
use thiserror::Error;
@ -55,6 +61,12 @@ impl RpoDigest {
}
}
impl Hash for RpoDigest {
fn hash<H: Hasher>(&self, state: &mut H) {
state.write(&self.as_bytes());
}
}
impl Digest for RpoDigest {
fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
let mut result = [0; DIGEST_BYTES];

+ 1
- 0
src/merkle/node.rs

@ -3,6 +3,7 @@ use super::RpoDigest;
/// Representation of a node with two children used for iterating over containers.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(test, derive(PartialOrd, Ord))]
pub struct InnerNodeInfo {
pub value: RpoDigest,
pub left: RpoDigest,

+ 12
- 18
src/merkle/smt/full/mod.rs

@ -1,12 +1,8 @@
use alloc::{
collections::{BTreeMap, BTreeSet},
string::ToString,
vec::Vec,
};
use alloc::{collections::BTreeSet, string::ToString, vec::Vec};
use super::{
EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, MerklePath,
MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex, MerkleError,
MerklePath, MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
};
mod error;
@ -30,6 +26,8 @@ pub const SMT_DEPTH: u8 = 64;
// SMT
// ================================================================================================
type Leaves = super::Leaves<SmtLeaf>;
/// Sparse Merkle tree mapping 256-bit keys to 256-bit values. Both keys and values are represented
/// by 4 field elements.
///
@ -43,8 +41,8 @@ pub const SMT_DEPTH: u8 = 64;
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Smt {
root: RpoDigest,
leaves: BTreeMap<u64, SmtLeaf>,
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
inner_nodes: InnerNodes,
leaves: Leaves,
}
impl Smt {
@ -64,8 +62,8 @@ impl Smt {
Self {
root,
leaves: BTreeMap::new(),
inner_nodes: BTreeMap::new(),
inner_nodes: Default::default(),
leaves: Default::default(),
}
}
@ -148,11 +146,7 @@ impl Smt {
/// # Panics
/// With debug assertions on, this function panics if `root` does not match the root node in
/// `inner_nodes`.
pub fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, SmtLeaf>,
root: RpoDigest,
) -> Self {
pub fn from_raw_parts(inner_nodes: InnerNodes, leaves: Leaves, root: RpoDigest) -> Self {
// Our particular implementation of `from_raw_parts()` never returns `Err`.
<Self as SparseMerkleTree<SMT_DEPTH>>::from_raw_parts(inner_nodes, leaves, root).unwrap()
}
@ -339,8 +333,8 @@ impl SparseMerkleTree for Smt {
const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, SmtLeaf>,
inner_nodes: InnerNodes,
leaves: Leaves,
root: RpoDigest,
) -> Result<Self, MerkleError> {
if cfg!(debug_assertions) {

+ 15
- 15
src/merkle/smt/full/tests.rs

@ -1,9 +1,9 @@
use alloc::{collections::BTreeMap, vec::Vec};
use alloc::vec::Vec;
use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH};
use crate::{
merkle::{
smt::{NodeMutation, SparseMerkleTree},
smt::{NodeMutation, SparseMerkleTree, UnorderedMap},
EmptySubtreeRoots, MerkleStore, MutationSet,
},
utils::{Deserializable, Serializable},
@ -420,7 +420,7 @@ fn test_prospective_insertion() {
assert_eq!(revert.root(), root_empty, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_1, EMPTY_WORD)]),
UnorderedMap::from_iter([(key_1, EMPTY_WORD)]),
"reverse mutations pairs did not match"
);
assert_eq!(
@ -440,7 +440,7 @@ fn test_prospective_insertion() {
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_2, EMPTY_WORD), (key_3, EMPTY_WORD)]),
UnorderedMap::from_iter([(key_2, EMPTY_WORD), (key_3, EMPTY_WORD)]),
"reverse mutations pairs did not match"
);
@ -454,7 +454,7 @@ fn test_prospective_insertion() {
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_3, value_3)]),
UnorderedMap::from_iter([(key_3, value_3)]),
"reverse mutations pairs did not match"
);
@ -474,7 +474,7 @@ fn test_prospective_insertion() {
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_1, value_1), (key_2, value_2), (key_3, value_3)]),
UnorderedMap::from_iter([(key_1, value_1), (key_2, value_2), (key_3, value_3)]),
"reverse mutations pairs did not match"
);
@ -603,21 +603,21 @@ fn test_smt_get_value() {
/// Tests that `entries()` works as expected
#[test]
fn test_smt_entries() {
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, ONE]);
let key_2: RpoDigest = RpoDigest::from([2_u32, 2_u32, 2_u32, 2_u32]);
let key_1 = RpoDigest::from([ONE, ONE, ONE, ONE]);
let key_2 = RpoDigest::from([2_u32, 2_u32, 2_u32, 2_u32]);
let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let entries = [(key_1, value_1), (key_2, value_2)];
let smt = Smt::with_entries([(key_1, value_1), (key_2, value_2)]).unwrap();
let smt = Smt::with_entries(entries).unwrap();
let mut entries = smt.entries();
let mut expected = Vec::from_iter(entries);
expected.sort_by_key(|(k, _)| *k);
let mut actual: Vec<_> = smt.entries().cloned().collect();
actual.sort_by_key(|(k, _)| *k);
// Note: for simplicity, we assume the order `(k1,v1), (k2,v2)`. If a new implementation
// switches the order, it is OK to modify the order here as well.
assert_eq!(&(key_1, value_1), entries.next().unwrap());
assert_eq!(&(key_2, value_2), entries.next().unwrap());
assert!(entries.next().is_none());
assert_eq!(actual, expected);
}
/// Tests that `EMPTY_ROOT` constant generated in the `Smt` equals to the root of the empty tree of

+ 73
- 29
src/merkle/smt/mod.rs

@ -1,5 +1,5 @@
use alloc::{collections::BTreeMap, vec::Vec};
use core::mem;
use core::{hash::Hash, mem};
use num::Integer;
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
@ -28,6 +28,15 @@ pub const SMT_MAX_DEPTH: u8 = 64;
// SPARSE MERKLE TREE
// ================================================================================================
/// A map whose keys are not guarantied to be ordered.
#[cfg(feature = "smt_hashmaps")]
type UnorderedMap<K, V> = hashbrown::HashMap<K, V>;
#[cfg(not(feature = "smt_hashmaps"))]
type UnorderedMap<K, V> = alloc::collections::BTreeMap<K, V>;
type InnerNodes = UnorderedMap<NodeIndex, InnerNode>;
type Leaves<T> = UnorderedMap<u64, T>;
type NodeMutations = UnorderedMap<NodeIndex, NodeMutation>;
/// An abstract description of a sparse Merkle tree.
///
/// A sparse Merkle tree is a key-value map which also supports proving that a given value is indeed
@ -49,7 +58,7 @@ pub const SMT_MAX_DEPTH: u8 = 64;
/// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs.
pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
/// The type for a key
type Key: Clone + Ord;
type Key: Clone + Ord + Eq + Hash;
/// The type for a value
type Value: Clone + PartialEq;
/// The type for a leaf
@ -173,8 +182,8 @@ pub(crate) trait SparseMerkleTree {
use NodeMutation::*;
let mut new_root = self.root();
let mut new_pairs: BTreeMap<Self::Key, Self::Value> = Default::default();
let mut node_mutations: BTreeMap<NodeIndex, NodeMutation> = Default::default();
let mut new_pairs: UnorderedMap<Self::Key, Self::Value> = Default::default();
let mut node_mutations: NodeMutations = Default::default();
for (key, value) in kv_pairs {
// If the old value and the new value are the same, there is nothing to update.
@ -341,7 +350,7 @@ pub(crate) trait SparseMerkleTree {
});
}
let mut reverse_mutations = BTreeMap::new();
let mut reverse_mutations = NodeMutations::new();
for (index, mutation) in node_mutations {
match mutation {
Removal => {
@ -359,7 +368,7 @@ pub(crate) trait SparseMerkleTree {
}
}
let mut reverse_pairs = BTreeMap::new();
let mut reverse_pairs = UnorderedMap::new();
for (key, value) in new_pairs {
if let Some(old_value) = self.insert_value(key.clone(), value) {
reverse_pairs.insert(key, old_value);
@ -384,8 +393,8 @@ pub(crate) trait SparseMerkleTree {
/// Construct this type from already computed leaves and nodes. The caller ensures passed
/// arguments are correct and consistent with each other.
fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, Self::Leaf>,
inner_nodes: InnerNodes,
leaves: Leaves<Self::Leaf>,
root: RpoDigest,
) -> Result<Self, MerkleError>
where
@ -516,7 +525,7 @@ pub(crate) trait SparseMerkleTree {
#[cfg(feature = "concurrent")]
fn build_subtrees(
mut entries: Vec<(Self::Key, Self::Value)>,
) -> (BTreeMap<NodeIndex, InnerNode>, BTreeMap<u64, Self::Leaf>) {
) -> (InnerNodes, Leaves<Self::Leaf>) {
entries.sort_by_key(|item| {
let index = Self::key_to_leaf_index(&item.0);
index.value()
@ -531,10 +540,10 @@ pub(crate) trait SparseMerkleTree {
#[cfg(feature = "concurrent")]
fn build_subtrees_from_sorted_entries(
entries: Vec<(Self::Key, Self::Value)>,
) -> (BTreeMap<NodeIndex, InnerNode>, BTreeMap<u64, Self::Leaf>) {
) -> (InnerNodes, Leaves<Self::Leaf>) {
use rayon::prelude::*;
let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
let mut accumulated_nodes: InnerNodes = Default::default();
let PairComputations {
leaves: mut leaf_subtrees,
@ -651,8 +660,8 @@ pub enum NodeMutation {
/// Represents a group of prospective mutations to a `SparseMerkleTree`, created by
/// `SparseMerkleTree::compute_mutations()`, and that can be applied with
/// `SparseMerkleTree::apply_mutations()`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct MutationSet<const DEPTH: u8, K, V> {
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MutationSet<const DEPTH: u8, K: Eq + Hash, V> {
/// The root of the Merkle tree this MutationSet is for, recorded at the time
/// [`SparseMerkleTree::compute_mutations()`] was called. Exists to guard against applying
/// mutations to the wrong tree or applying stale mutations to a tree that has since changed.
@ -662,18 +671,18 @@ pub struct MutationSet {
/// index overlayed, if any. Each [`NodeMutation::Addition`] corresponds to a
/// [`SparseMerkleTree::insert_inner_node()`] call, and each [`NodeMutation::Removal`]
/// corresponds to a [`SparseMerkleTree::remove_inner_node()`] call.
node_mutations: BTreeMap<NodeIndex, NodeMutation>,
node_mutations: NodeMutations,
/// The set of top-level key-value pairs we're prospectively adding to the tree, including
/// adding empty values. The "effective" value for a key is the value in this BTreeMap, falling
/// back to the existing value in the Merkle tree. Each entry corresponds to a
/// [`SparseMerkleTree::insert_value()`] call.
new_pairs: BTreeMap<K, V>,
new_pairs: UnorderedMap<K, V>,
/// The calculated root for the Merkle tree, given these mutations. Publicly retrievable with
/// [`MutationSet::root()`]. Corresponds to a [`SparseMerkleTree::set_root()`]. call.
new_root: RpoDigest,
}
impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
impl<const DEPTH: u8, K: Eq + Hash, V> MutationSet<DEPTH, K, V> {
/// Returns the SMT root that was calculated during `SparseMerkleTree::compute_mutations()`. See
/// that method for more information.
pub fn root(&self) -> RpoDigest {
@ -686,13 +695,13 @@ impl MutationSet {
}
/// Returns the set of inner nodes that need to be removed or added.
pub fn node_mutations(&self) -> &BTreeMap<NodeIndex, NodeMutation> {
pub fn node_mutations(&self) -> &NodeMutations {
&self.node_mutations
}
/// Returns the set of top-level key-value pairs that need to be added, updated or deleted
/// (i.e. set to `EMPTY_WORD`).
pub fn new_pairs(&self) -> &BTreeMap<K, V> {
pub fn new_pairs(&self) -> &UnorderedMap<K, V> {
&self.new_pairs
}
}
@ -702,8 +711,8 @@ impl MutationSet {
impl Serializable for InnerNode {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.left.write_into(target);
self.right.write_into(target);
target.write(self.left);
target.write(self.right);
}
}
@ -739,23 +748,57 @@ impl Deserializable for NodeMutation {
}
}
impl<const DEPTH: u8, K: Serializable, V: Serializable> Serializable for MutationSet<DEPTH, K, V> {
impl<const DEPTH: u8, K: Serializable + Eq + Hash, V: Serializable> Serializable
for MutationSet<DEPTH, K, V>
{
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write(self.old_root);
target.write(self.new_root);
self.node_mutations.write_into(target);
self.new_pairs.write_into(target);
let inner_removals: Vec<_> = self
.node_mutations
.iter()
.filter(|(_, value)| matches!(value, NodeMutation::Removal))
.map(|(key, _)| key)
.collect();
let inner_additions: Vec<_> = self
.node_mutations
.iter()
.filter_map(|(key, value)| match value {
NodeMutation::Addition(node) => Some((key, node)),
_ => None,
})
.collect();
target.write(inner_removals);
target.write(inner_additions);
target.write_usize(self.new_pairs.len());
target.write_many(&self.new_pairs);
}
}
impl<const DEPTH: u8, K: Deserializable + Ord, V: Deserializable> Deserializable
impl<const DEPTH: u8, K: Deserializable + Ord + Eq + Hash, V: Deserializable> Deserializable
for MutationSet<DEPTH, K, V>
{
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let old_root = source.read()?;
let new_root = source.read()?;
let node_mutations = source.read()?;
let new_pairs = source.read()?;
let inner_removals: Vec<NodeIndex> = source.read()?;
let inner_additions: Vec<(NodeIndex, InnerNode)> = source.read()?;
let node_mutations = NodeMutations::from_iter(
inner_removals.into_iter().map(|index| (index, NodeMutation::Removal)).chain(
inner_additions
.into_iter()
.map(|(index, node)| (index, NodeMutation::Addition(node))),
),
);
let num_new_pairs = source.read_usize()?;
let new_pairs = source.read_many(num_new_pairs)?;
let new_pairs = UnorderedMap::from_iter(new_pairs);
Ok(Self {
old_root,
@ -768,6 +811,7 @@ impl Deserializable
// SUBTREES
// ================================================================================================
/// A subtree is of depth 8.
const SUBTREE_DEPTH: u8 = 8;
@ -787,10 +831,10 @@ pub struct SubtreeLeaf {
}
/// Helper struct to organize the return value of [`SparseMerkleTree::sorted_pairs_to_leaves()`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone)]
pub(crate) struct PairComputations<K, L> {
/// Literal leaves to be added to the sparse Merkle tree's internal mapping.
pub nodes: BTreeMap<K, L>,
pub nodes: UnorderedMap<K, L>,
/// "Conceptual" leaves that will be used for computations.
pub leaves: Vec<Vec<SubtreeLeaf>>,
}
@ -818,7 +862,7 @@ impl<'s> SubtreeLeavesIter<'s> {
Self { leaves: leaves.drain(..).peekable() }
}
}
impl core::iter::Iterator for SubtreeLeavesIter<'_> {
impl Iterator for SubtreeLeavesIter<'_> {
type Item = Vec<SubtreeLeaf>;
/// Each `next()` collects an entire subtree.

+ 12
- 17
src/merkle/smt/simple/mod.rs

@ -1,11 +1,8 @@
use alloc::{
collections::{BTreeMap, BTreeSet},
vec::Vec,
};
use alloc::{collections::BTreeSet, vec::Vec};
use super::{
super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, LeafIndex, MerkleError,
MerklePath, MutationSet, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex,
MerkleError, MerklePath, MutationSet, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
SMT_MAX_DEPTH, SMT_MIN_DEPTH,
};
@ -15,6 +12,8 @@ mod tests;
// SPARSE MERKLE TREE
// ================================================================================================
type Leaves = super::Leaves<Word>;
/// A sparse Merkle tree with 64-bit keys and 4-element leaf values, without compaction.
///
/// The root of the tree is recomputed on each new leaf update.
@ -22,8 +21,8 @@ mod tests;
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct SimpleSmt<const DEPTH: u8> {
root: RpoDigest,
leaves: BTreeMap<u64, Word>,
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
inner_nodes: InnerNodes,
leaves: Leaves,
}
impl<const DEPTH: u8> SimpleSmt<DEPTH> {
@ -54,8 +53,8 @@ impl SimpleSmt {
Ok(Self {
root,
leaves: BTreeMap::new(),
inner_nodes: BTreeMap::new(),
inner_nodes: Default::default(),
leaves: Default::default(),
})
}
@ -108,11 +107,7 @@ impl SimpleSmt {
/// # Panics
/// With debug assertions on, this function panics if `root` does not match the root node in
/// `inner_nodes`.
pub fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, Word>,
root: RpoDigest,
) -> Self {
pub fn from_raw_parts(inner_nodes: InnerNodes, leaves: Leaves, root: RpoDigest) -> Self {
// Our particular implementation of `from_raw_parts()` never returns `Err`.
<Self as SparseMerkleTree<DEPTH>>::from_raw_parts(inner_nodes, leaves, root).unwrap()
}
@ -344,8 +339,8 @@ impl SparseMerkleTree for SimpleSmt {
const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(DEPTH, 0);
fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, Word>,
inner_nodes: InnerNodes,
leaves: Leaves,
root: RpoDigest,
) -> Result<Self, MerkleError> {
if cfg!(debug_assertions) {

+ 5
- 2
src/merkle/smt/simple/tests.rs

@ -141,12 +141,15 @@ fn test_inner_node_iterator() -> Result<(), MerkleError> {
let l2n2 = tree.get_node(NodeIndex::make(2, 2))?;
let l2n3 = tree.get_node(NodeIndex::make(2, 3))?;
let nodes: Vec<InnerNodeInfo> = tree.inner_nodes().collect();
let expected = vec![
let mut nodes: Vec<InnerNodeInfo> = tree.inner_nodes().collect();
let mut expected = [
InnerNodeInfo { value: root, left: l1n0, right: l1n1 },
InnerNodeInfo { value: l1n0, left: l2n0, right: l2n1 },
InnerNodeInfo { value: l1n1, left: l2n2, right: l2n3 },
];
nodes.sort();
expected.sort();
assert_eq!(nodes, expected);
Ok(())

Loading…
Cancel
Save