feat: add support for hashmaps in Smt and SimpleSmt (#363)

This commit is contained in:
polydez
2025-01-02 23:23:12 +05:00
committed by GitHub
parent e4373e54c9
commit 7ee6d7fb93
13 changed files with 171 additions and 84 deletions

View File

@@ -1,5 +1,11 @@
use alloc::string::String;
use core::{cmp::Ordering, fmt::Display, ops::Deref, slice};
use core::{
cmp::Ordering,
fmt::Display,
hash::{Hash, Hasher},
ops::Deref,
slice,
};
use thiserror::Error;
@@ -55,6 +61,12 @@ impl RpoDigest {
}
}
impl Hash for RpoDigest {
fn hash<H: Hasher>(&self, state: &mut H) {
state.write(&self.as_bytes());
}
}
impl Digest for RpoDigest {
fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
let mut result = [0; DIGEST_BYTES];

View File

@@ -3,6 +3,7 @@ use super::RpoDigest;
/// Representation of a node with two children used for iterating over containers.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(test, derive(PartialOrd, Ord))]
pub struct InnerNodeInfo {
pub value: RpoDigest,
pub left: RpoDigest,

View File

@@ -1,12 +1,8 @@
use alloc::{
collections::{BTreeMap, BTreeSet},
string::ToString,
vec::Vec,
};
use alloc::{collections::BTreeSet, string::ToString, vec::Vec};
use super::{
EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, MerklePath,
MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex, MerkleError,
MerklePath, MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
};
mod error;
@@ -30,6 +26,8 @@ pub const SMT_DEPTH: u8 = 64;
// SMT
// ================================================================================================
type Leaves = super::Leaves<SmtLeaf>;
/// Sparse Merkle tree mapping 256-bit keys to 256-bit values. Both keys and values are represented
/// by 4 field elements.
///
@@ -43,8 +41,8 @@ pub const SMT_DEPTH: u8 = 64;
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Smt {
root: RpoDigest,
leaves: BTreeMap<u64, SmtLeaf>,
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
inner_nodes: InnerNodes,
leaves: Leaves,
}
impl Smt {
@@ -64,8 +62,8 @@ impl Smt {
Self {
root,
leaves: BTreeMap::new(),
inner_nodes: BTreeMap::new(),
inner_nodes: Default::default(),
leaves: Default::default(),
}
}
@@ -148,11 +146,7 @@ impl Smt {
/// # Panics
/// With debug assertions on, this function panics if `root` does not match the root node in
/// `inner_nodes`.
pub fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, SmtLeaf>,
root: RpoDigest,
) -> Self {
pub fn from_raw_parts(inner_nodes: InnerNodes, leaves: Leaves, root: RpoDigest) -> Self {
// Our particular implementation of `from_raw_parts()` never returns `Err`.
<Self as SparseMerkleTree<SMT_DEPTH>>::from_raw_parts(inner_nodes, leaves, root).unwrap()
}
@@ -339,8 +333,8 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, SmtLeaf>,
inner_nodes: InnerNodes,
leaves: Leaves,
root: RpoDigest,
) -> Result<Self, MerkleError> {
if cfg!(debug_assertions) {

View File

@@ -1,9 +1,9 @@
use alloc::{collections::BTreeMap, vec::Vec};
use alloc::vec::Vec;
use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH};
use crate::{
merkle::{
smt::{NodeMutation, SparseMerkleTree},
smt::{NodeMutation, SparseMerkleTree, UnorderedMap},
EmptySubtreeRoots, MerkleStore, MutationSet,
},
utils::{Deserializable, Serializable},
@@ -420,7 +420,7 @@ fn test_prospective_insertion() {
assert_eq!(revert.root(), root_empty, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_1, EMPTY_WORD)]),
UnorderedMap::from_iter([(key_1, EMPTY_WORD)]),
"reverse mutations pairs did not match"
);
assert_eq!(
@@ -440,7 +440,7 @@ fn test_prospective_insertion() {
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_2, EMPTY_WORD), (key_3, EMPTY_WORD)]),
UnorderedMap::from_iter([(key_2, EMPTY_WORD), (key_3, EMPTY_WORD)]),
"reverse mutations pairs did not match"
);
@@ -454,7 +454,7 @@ fn test_prospective_insertion() {
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_3, value_3)]),
UnorderedMap::from_iter([(key_3, value_3)]),
"reverse mutations pairs did not match"
);
@@ -474,7 +474,7 @@ fn test_prospective_insertion() {
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_1, value_1), (key_2, value_2), (key_3, value_3)]),
UnorderedMap::from_iter([(key_1, value_1), (key_2, value_2), (key_3, value_3)]),
"reverse mutations pairs did not match"
);
@@ -603,21 +603,21 @@ fn test_smt_get_value() {
/// Tests that `entries()` works as expected
#[test]
fn test_smt_entries() {
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, ONE]);
let key_2: RpoDigest = RpoDigest::from([2_u32, 2_u32, 2_u32, 2_u32]);
let key_1 = RpoDigest::from([ONE, ONE, ONE, ONE]);
let key_2 = RpoDigest::from([2_u32, 2_u32, 2_u32, 2_u32]);
let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let entries = [(key_1, value_1), (key_2, value_2)];
let smt = Smt::with_entries([(key_1, value_1), (key_2, value_2)]).unwrap();
let smt = Smt::with_entries(entries).unwrap();
let mut entries = smt.entries();
let mut expected = Vec::from_iter(entries);
expected.sort_by_key(|(k, _)| *k);
let mut actual: Vec<_> = smt.entries().cloned().collect();
actual.sort_by_key(|(k, _)| *k);
// Note: for simplicity, we assume the order `(k1,v1), (k2,v2)`. If a new implementation
// switches the order, it is OK to modify the order here as well.
assert_eq!(&(key_1, value_1), entries.next().unwrap());
assert_eq!(&(key_2, value_2), entries.next().unwrap());
assert!(entries.next().is_none());
assert_eq!(actual, expected);
}
/// Tests that `EMPTY_ROOT` constant generated in the `Smt` equals to the root of the empty tree of

View File

@@ -1,5 +1,5 @@
use alloc::{collections::BTreeMap, vec::Vec};
use core::mem;
use core::{hash::Hash, mem};
use num::Integer;
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
@@ -28,6 +28,15 @@ pub const SMT_MAX_DEPTH: u8 = 64;
// SPARSE MERKLE TREE
// ================================================================================================
/// A map whose keys are not guarantied to be ordered.
#[cfg(feature = "smt_hashmaps")]
type UnorderedMap<K, V> = hashbrown::HashMap<K, V>;
#[cfg(not(feature = "smt_hashmaps"))]
type UnorderedMap<K, V> = alloc::collections::BTreeMap<K, V>;
type InnerNodes = UnorderedMap<NodeIndex, InnerNode>;
type Leaves<T> = UnorderedMap<u64, T>;
type NodeMutations = UnorderedMap<NodeIndex, NodeMutation>;
/// An abstract description of a sparse Merkle tree.
///
/// A sparse Merkle tree is a key-value map which also supports proving that a given value is indeed
@@ -49,7 +58,7 @@ pub const SMT_MAX_DEPTH: u8 = 64;
/// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs.
pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
/// The type for a key
type Key: Clone + Ord;
type Key: Clone + Ord + Eq + Hash;
/// The type for a value
type Value: Clone + PartialEq;
/// The type for a leaf
@@ -173,8 +182,8 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
use NodeMutation::*;
let mut new_root = self.root();
let mut new_pairs: BTreeMap<Self::Key, Self::Value> = Default::default();
let mut node_mutations: BTreeMap<NodeIndex, NodeMutation> = Default::default();
let mut new_pairs: UnorderedMap<Self::Key, Self::Value> = Default::default();
let mut node_mutations: NodeMutations = Default::default();
for (key, value) in kv_pairs {
// If the old value and the new value are the same, there is nothing to update.
@@ -341,7 +350,7 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
});
}
let mut reverse_mutations = BTreeMap::new();
let mut reverse_mutations = NodeMutations::new();
for (index, mutation) in node_mutations {
match mutation {
Removal => {
@@ -359,7 +368,7 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
}
}
let mut reverse_pairs = BTreeMap::new();
let mut reverse_pairs = UnorderedMap::new();
for (key, value) in new_pairs {
if let Some(old_value) = self.insert_value(key.clone(), value) {
reverse_pairs.insert(key, old_value);
@@ -384,8 +393,8 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
/// Construct this type from already computed leaves and nodes. The caller ensures passed
/// arguments are correct and consistent with each other.
fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, Self::Leaf>,
inner_nodes: InnerNodes,
leaves: Leaves<Self::Leaf>,
root: RpoDigest,
) -> Result<Self, MerkleError>
where
@@ -516,7 +525,7 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
#[cfg(feature = "concurrent")]
fn build_subtrees(
mut entries: Vec<(Self::Key, Self::Value)>,
) -> (BTreeMap<NodeIndex, InnerNode>, BTreeMap<u64, Self::Leaf>) {
) -> (InnerNodes, Leaves<Self::Leaf>) {
entries.sort_by_key(|item| {
let index = Self::key_to_leaf_index(&item.0);
index.value()
@@ -531,10 +540,10 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
#[cfg(feature = "concurrent")]
fn build_subtrees_from_sorted_entries(
entries: Vec<(Self::Key, Self::Value)>,
) -> (BTreeMap<NodeIndex, InnerNode>, BTreeMap<u64, Self::Leaf>) {
) -> (InnerNodes, Leaves<Self::Leaf>) {
use rayon::prelude::*;
let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
let mut accumulated_nodes: InnerNodes = Default::default();
let PairComputations {
leaves: mut leaf_subtrees,
@@ -651,8 +660,8 @@ pub enum NodeMutation {
/// Represents a group of prospective mutations to a `SparseMerkleTree`, created by
/// `SparseMerkleTree::compute_mutations()`, and that can be applied with
/// `SparseMerkleTree::apply_mutations()`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct MutationSet<const DEPTH: u8, K, V> {
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MutationSet<const DEPTH: u8, K: Eq + Hash, V> {
/// The root of the Merkle tree this MutationSet is for, recorded at the time
/// [`SparseMerkleTree::compute_mutations()`] was called. Exists to guard against applying
/// mutations to the wrong tree or applying stale mutations to a tree that has since changed.
@@ -662,18 +671,18 @@ pub struct MutationSet<const DEPTH: u8, K, V> {
/// index overlayed, if any. Each [`NodeMutation::Addition`] corresponds to a
/// [`SparseMerkleTree::insert_inner_node()`] call, and each [`NodeMutation::Removal`]
/// corresponds to a [`SparseMerkleTree::remove_inner_node()`] call.
node_mutations: BTreeMap<NodeIndex, NodeMutation>,
node_mutations: NodeMutations,
/// The set of top-level key-value pairs we're prospectively adding to the tree, including
/// adding empty values. The "effective" value for a key is the value in this BTreeMap, falling
/// back to the existing value in the Merkle tree. Each entry corresponds to a
/// [`SparseMerkleTree::insert_value()`] call.
new_pairs: BTreeMap<K, V>,
new_pairs: UnorderedMap<K, V>,
/// The calculated root for the Merkle tree, given these mutations. Publicly retrievable with
/// [`MutationSet::root()`]. Corresponds to a [`SparseMerkleTree::set_root()`]. call.
new_root: RpoDigest,
}
impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
impl<const DEPTH: u8, K: Eq + Hash, V> MutationSet<DEPTH, K, V> {
/// Returns the SMT root that was calculated during `SparseMerkleTree::compute_mutations()`. See
/// that method for more information.
pub fn root(&self) -> RpoDigest {
@@ -686,13 +695,13 @@ impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
}
/// Returns the set of inner nodes that need to be removed or added.
pub fn node_mutations(&self) -> &BTreeMap<NodeIndex, NodeMutation> {
pub fn node_mutations(&self) -> &NodeMutations {
&self.node_mutations
}
/// Returns the set of top-level key-value pairs that need to be added, updated or deleted
/// (i.e. set to `EMPTY_WORD`).
pub fn new_pairs(&self) -> &BTreeMap<K, V> {
pub fn new_pairs(&self) -> &UnorderedMap<K, V> {
&self.new_pairs
}
}
@@ -702,8 +711,8 @@ impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
impl Serializable for InnerNode {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.left.write_into(target);
self.right.write_into(target);
target.write(self.left);
target.write(self.right);
}
}
@@ -739,23 +748,57 @@ impl Deserializable for NodeMutation {
}
}
impl<const DEPTH: u8, K: Serializable, V: Serializable> Serializable for MutationSet<DEPTH, K, V> {
impl<const DEPTH: u8, K: Serializable + Eq + Hash, V: Serializable> Serializable
for MutationSet<DEPTH, K, V>
{
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write(self.old_root);
target.write(self.new_root);
self.node_mutations.write_into(target);
self.new_pairs.write_into(target);
let inner_removals: Vec<_> = self
.node_mutations
.iter()
.filter(|(_, value)| matches!(value, NodeMutation::Removal))
.map(|(key, _)| key)
.collect();
let inner_additions: Vec<_> = self
.node_mutations
.iter()
.filter_map(|(key, value)| match value {
NodeMutation::Addition(node) => Some((key, node)),
_ => None,
})
.collect();
target.write(inner_removals);
target.write(inner_additions);
target.write_usize(self.new_pairs.len());
target.write_many(&self.new_pairs);
}
}
impl<const DEPTH: u8, K: Deserializable + Ord, V: Deserializable> Deserializable
impl<const DEPTH: u8, K: Deserializable + Ord + Eq + Hash, V: Deserializable> Deserializable
for MutationSet<DEPTH, K, V>
{
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let old_root = source.read()?;
let new_root = source.read()?;
let node_mutations = source.read()?;
let new_pairs = source.read()?;
let inner_removals: Vec<NodeIndex> = source.read()?;
let inner_additions: Vec<(NodeIndex, InnerNode)> = source.read()?;
let node_mutations = NodeMutations::from_iter(
inner_removals.into_iter().map(|index| (index, NodeMutation::Removal)).chain(
inner_additions
.into_iter()
.map(|(index, node)| (index, NodeMutation::Addition(node))),
),
);
let num_new_pairs = source.read_usize()?;
let new_pairs = source.read_many(num_new_pairs)?;
let new_pairs = UnorderedMap::from_iter(new_pairs);
Ok(Self {
old_root,
@@ -768,6 +811,7 @@ impl<const DEPTH: u8, K: Deserializable + Ord, V: Deserializable> Deserializable
// SUBTREES
// ================================================================================================
/// A subtree is of depth 8.
const SUBTREE_DEPTH: u8 = 8;
@@ -787,10 +831,10 @@ pub struct SubtreeLeaf {
}
/// Helper struct to organize the return value of [`SparseMerkleTree::sorted_pairs_to_leaves()`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone)]
pub(crate) struct PairComputations<K, L> {
/// Literal leaves to be added to the sparse Merkle tree's internal mapping.
pub nodes: BTreeMap<K, L>,
pub nodes: UnorderedMap<K, L>,
/// "Conceptual" leaves that will be used for computations.
pub leaves: Vec<Vec<SubtreeLeaf>>,
}
@@ -818,7 +862,7 @@ impl<'s> SubtreeLeavesIter<'s> {
Self { leaves: leaves.drain(..).peekable() }
}
}
impl core::iter::Iterator for SubtreeLeavesIter<'_> {
impl Iterator for SubtreeLeavesIter<'_> {
type Item = Vec<SubtreeLeaf>;
/// Each `next()` collects an entire subtree.

View File

@@ -1,11 +1,8 @@
use alloc::{
collections::{BTreeMap, BTreeSet},
vec::Vec,
};
use alloc::{collections::BTreeSet, vec::Vec};
use super::{
super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, LeafIndex, MerkleError,
MerklePath, MutationSet, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex,
MerkleError, MerklePath, MutationSet, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
SMT_MAX_DEPTH, SMT_MIN_DEPTH,
};
@@ -15,6 +12,8 @@ mod tests;
// SPARSE MERKLE TREE
// ================================================================================================
type Leaves = super::Leaves<Word>;
/// A sparse Merkle tree with 64-bit keys and 4-element leaf values, without compaction.
///
/// The root of the tree is recomputed on each new leaf update.
@@ -22,8 +21,8 @@ mod tests;
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct SimpleSmt<const DEPTH: u8> {
root: RpoDigest,
leaves: BTreeMap<u64, Word>,
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
inner_nodes: InnerNodes,
leaves: Leaves,
}
impl<const DEPTH: u8> SimpleSmt<DEPTH> {
@@ -54,8 +53,8 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
Ok(Self {
root,
leaves: BTreeMap::new(),
inner_nodes: BTreeMap::new(),
inner_nodes: Default::default(),
leaves: Default::default(),
})
}
@@ -108,11 +107,7 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
/// # Panics
/// With debug assertions on, this function panics if `root` does not match the root node in
/// `inner_nodes`.
pub fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, Word>,
root: RpoDigest,
) -> Self {
pub fn from_raw_parts(inner_nodes: InnerNodes, leaves: Leaves, root: RpoDigest) -> Self {
// Our particular implementation of `from_raw_parts()` never returns `Err`.
<Self as SparseMerkleTree<DEPTH>>::from_raw_parts(inner_nodes, leaves, root).unwrap()
}
@@ -344,8 +339,8 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(DEPTH, 0);
fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, Word>,
inner_nodes: InnerNodes,
leaves: Leaves,
root: RpoDigest,
) -> Result<Self, MerkleError> {
if cfg!(debug_assertions) {

View File

@@ -141,12 +141,15 @@ fn test_inner_node_iterator() -> Result<(), MerkleError> {
let l2n2 = tree.get_node(NodeIndex::make(2, 2))?;
let l2n3 = tree.get_node(NodeIndex::make(2, 3))?;
let nodes: Vec<InnerNodeInfo> = tree.inner_nodes().collect();
let expected = vec![
let mut nodes: Vec<InnerNodeInfo> = tree.inner_nodes().collect();
let mut expected = [
InnerNodeInfo { value: root, left: l1n0, right: l1n1 },
InnerNodeInfo { value: l1n0, left: l2n0, right: l2n1 },
InnerNodeInfo { value: l1n1, left: l2n2, right: l2n3 },
];
nodes.sort();
expected.sort();
assert_eq!(nodes, expected);
Ok(())