@ -1,5 +1,5 @@
use alloc ::{ collections ::BTreeMap , vec ::Vec } ;
use core ::mem ;
use core ::{ hash ::Hash , mem } ;
use num ::Integer ;
use winter_utils ::{ ByteReader , ByteWriter , Deserializable , DeserializationError , Serializable } ;
@ -28,6 +28,15 @@ pub const SMT_MAX_DEPTH: u8 = 64;
// SPARSE MERKLE TREE
// ================================================================================================
/// A map whose keys are not guarantied to be ordered.
#[ cfg(feature = " smt_hashmaps " ) ]
type UnorderedMap < K , V > = hashbrown ::HashMap < K , V > ;
#[ cfg(not(feature = " smt_hashmaps " )) ]
type UnorderedMap < K , V > = alloc ::collections ::BTreeMap < K , V > ;
type InnerNodes = UnorderedMap < NodeIndex , InnerNode > ;
type Leaves < T > = UnorderedMap < u64 , T > ;
type NodeMutations = UnorderedMap < NodeIndex , NodeMutation > ;
/// An abstract description of a sparse Merkle tree.
///
/// A sparse Merkle tree is a key-value map which also supports proving that a given value is indeed
@ -49,7 +58,7 @@ pub const SMT_MAX_DEPTH: u8 = 64;
/// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs.
pub ( crate ) trait SparseMerkleTree < const DEPTH : u8 > {
/// The type for a key
type Key : Clone + Ord ;
type Key : Clone + Ord + Eq + Hash ;
/// The type for a value
type Value : Clone + PartialEq ;
/// The type for a leaf
@ -173,8 +182,8 @@ pub(crate) trait SparseMerkleTree {
use NodeMutation ::* ;
let mut new_root = self . root ( ) ;
let mut new_pairs : BTree Map< Self ::Key , Self ::Value > = Default ::default ( ) ;
let mut node_mutations : BTreeMap < NodeIndex , NodeMutation > = Default ::default ( ) ;
let mut new_pairs : Unordered Map< Self ::Key , Self ::Value > = Default ::default ( ) ;
let mut node_mutations : NodeMutations = Default ::default ( ) ;
for ( key , value ) in kv_pairs {
// If the old value and the new value are the same, there is nothing to update.
@ -341,7 +350,7 @@ pub(crate) trait SparseMerkleTree {
} ) ;
}
let mut reverse_mutations = BTreeMap ::new ( ) ;
let mut reverse_mutations = NodeMutations ::new ( ) ;
for ( index , mutation ) in node_mutations {
match mutation {
Removal = > {
@ -359,7 +368,7 @@ pub(crate) trait SparseMerkleTree {
}
}
let mut reverse_pairs = BTree Map ::new ( ) ;
let mut reverse_pairs = Unordered Map ::new ( ) ;
for ( key , value ) in new_pairs {
if let Some ( old_value ) = self . insert_value ( key . clone ( ) , value ) {
reverse_pairs . insert ( key , old_value ) ;
@ -384,8 +393,8 @@ pub(crate) trait SparseMerkleTree {
/// Construct this type from already computed leaves and nodes. The caller ensures passed
/// arguments are correct and consistent with each other.
fn from_raw_parts (
inner_nodes : BTreeMap < NodeIndex , InnerNode > ,
leaves : BTreeMap < u64 , Self ::Leaf > ,
inner_nodes : InnerNodes ,
leaves : Leaves < Self ::Leaf > ,
root : RpoDigest ,
) -> Result < Self , MerkleError >
where
@ -516,7 +525,7 @@ pub(crate) trait SparseMerkleTree {
#[ cfg(feature = " concurrent " ) ]
fn build_subtrees (
mut entries : Vec < ( Self ::Key , Self ::Value ) > ,
) -> ( BTreeMap < NodeIndex , InnerNode > , BTreeMap < u64 , Self ::Leaf > ) {
) -> ( InnerNodes , Leaves < Self ::Leaf > ) {
entries . sort_by_key ( | item | {
let index = Self ::key_to_leaf_index ( & item . 0 ) ;
index . value ( )
@ -531,10 +540,10 @@ pub(crate) trait SparseMerkleTree {
#[ cfg(feature = " concurrent " ) ]
fn build_subtrees_from_sorted_entries (
entries : Vec < ( Self ::Key , Self ::Value ) > ,
) -> ( BTreeMap < NodeIndex , InnerNode > , BTreeMap < u64 , Self ::Leaf > ) {
) -> ( InnerNodes , Leaves < Self ::Leaf > ) {
use rayon ::prelude ::* ;
let mut accumulated_nodes : BTreeMap < NodeIndex , InnerNode > = Default ::default ( ) ;
let mut accumulated_nodes : InnerNodes = Default ::default ( ) ;
let PairComputations {
leaves : mut leaf_subtrees ,
@ -651,8 +660,8 @@ pub enum NodeMutation {
/// Represents a group of prospective mutations to a `SparseMerkleTree`, created by
/// `SparseMerkleTree::compute_mutations()`, and that can be applied with
/// `SparseMerkleTree::apply_mutations()`.
#[ derive(Debug, Clone, PartialEq, Eq, Default ) ]
pub struct MutationSet < const DEPTH : u8 , K , V > {
#[ derive(Debug, Clone, Default, PartialEq, Eq) ]
pub struct MutationSet < const DEPTH : u8 , K : Eq + Hash , V > {
/// The root of the Merkle tree this MutationSet is for, recorded at the time
/// [`SparseMerkleTree::compute_mutations()`] was called. Exists to guard against applying
/// mutations to the wrong tree or applying stale mutations to a tree that has since changed.
@ -662,18 +671,18 @@ pub struct MutationSet {
/// index overlayed, if any. Each [`NodeMutation::Addition`] corresponds to a
/// [`SparseMerkleTree::insert_inner_node()`] call, and each [`NodeMutation::Removal`]
/// corresponds to a [`SparseMerkleTree::remove_inner_node()`] call.
node_mutations : BTreeMap < NodeIndex , NodeMutation > ,
node_mutations : NodeMutations ,
/// The set of top-level key-value pairs we're prospectively adding to the tree, including
/// adding empty values. The "effective" value for a key is the value in this BTreeMap, falling
/// back to the existing value in the Merkle tree. Each entry corresponds to a
/// [`SparseMerkleTree::insert_value()`] call.
new_pairs : BTree Map< K , V > ,
new_pairs : Unordered Map< K , V > ,
/// The calculated root for the Merkle tree, given these mutations. Publicly retrievable with
/// [`MutationSet::root()`]. Corresponds to a [`SparseMerkleTree::set_root()`]. call.
new_root : RpoDigest ,
}
impl < const DEPTH : u8 , K , V > MutationSet < DEPTH , K , V > {
impl < const DEPTH : u8 , K : Eq + Hash , V > MutationSet < DEPTH , K , V > {
/// Returns the SMT root that was calculated during `SparseMerkleTree::compute_mutations()`. See
/// that method for more information.
pub fn root ( & self ) -> RpoDigest {
@ -686,13 +695,13 @@ impl MutationSet {
}
/// Returns the set of inner nodes that need to be removed or added.
pub fn node_mutations ( & self ) -> & BTreeMap < NodeIndex , NodeMutation > {
pub fn node_mutations ( & self ) -> & NodeMutations {
& self . node_mutations
}
/// Returns the set of top-level key-value pairs that need to be added, updated or deleted
/// (i.e. set to `EMPTY_WORD`).
pub fn new_pairs ( & self ) -> & BTree Map< K , V > {
pub fn new_pairs ( & self ) -> & Unordered Map< K , V > {
& self . new_pairs
}
}
@ -702,8 +711,8 @@ impl MutationSet {
impl Serializable for InnerNode {
fn write_into < W : ByteWriter > ( & self , target : & mut W ) {
self . left . write_into ( targe t) ;
self . right . write_into ( targe t) ;
target . write ( self . lef t) ;
target . write ( self . righ t) ;
}
}
@ -739,23 +748,57 @@ impl Deserializable for NodeMutation {
}
}
impl < const DEPTH : u8 , K : Serializable , V : Serializable > Serializable for MutationSet < DEPTH , K , V > {
impl < const DEPTH : u8 , K : Serializable + Eq + Hash , V : Serializable > Serializable
for MutationSet < DEPTH , K , V >
{
fn write_into < W : ByteWriter > ( & self , target : & mut W ) {
target . write ( self . old_root ) ;
target . write ( self . new_root ) ;
self . node_mutations . write_into ( target ) ;
self . new_pairs . write_into ( target ) ;
let inner_removals : Vec < _ > = self
. node_mutations
. iter ( )
. filter ( | ( _ , value ) | matches ! ( value , NodeMutation ::Removal ) )
. map ( | ( key , _ ) | key )
. collect ( ) ;
let inner_additions : Vec < _ > = self
. node_mutations
. iter ( )
. filter_map ( | ( key , value ) | match value {
NodeMutation ::Addition ( node ) = > Some ( ( key , node ) ) ,
_ = > None ,
} )
. collect ( ) ;
target . write ( inner_removals ) ;
target . write ( inner_additions ) ;
target . write_usize ( self . new_pairs . len ( ) ) ;
target . write_many ( & self . new_pairs ) ;
}
}
impl < const DEPTH : u8 , K : Deserializable + Ord , V : Deserializable > Deserializable
impl < const DEPTH : u8 , K : Deserializable + Ord + Eq + Hash , V : Deserializable > Deserializable
for MutationSet < DEPTH , K , V >
{
fn read_from < R : ByteReader > ( source : & mut R ) -> Result < Self , DeserializationError > {
let old_root = source . read ( ) ? ;
let new_root = source . read ( ) ? ;
let node_mutations = source . read ( ) ? ;
let new_pairs = source . read ( ) ? ;
let inner_removals : Vec < NodeIndex > = source . read ( ) ? ;
let inner_additions : Vec < ( NodeIndex , InnerNode ) > = source . read ( ) ? ;
let node_mutations = NodeMutations ::from_iter (
inner_removals . into_iter ( ) . map ( | index | ( index , NodeMutation ::Removal ) ) . chain (
inner_additions
. into_iter ( )
. map ( | ( index , node ) | ( index , NodeMutation ::Addition ( node ) ) ) ,
) ,
) ;
let num_new_pairs = source . read_usize ( ) ? ;
let new_pairs = source . read_many ( num_new_pairs ) ? ;
let new_pairs = UnorderedMap ::from_iter ( new_pairs ) ;
Ok ( Self {
old_root ,
@ -768,6 +811,7 @@ impl Deserializable
// SUBTREES
// ================================================================================================
/// A subtree is of depth 8.
const SUBTREE_DEPTH : u8 = 8 ;
@ -787,10 +831,10 @@ pub struct SubtreeLeaf {
}
/// Helper struct to organize the return value of [`SparseMerkleTree::sorted_pairs_to_leaves()`].
#[ derive(Debug, Clone, PartialEq, Eq ) ]
#[ derive(Debug, Clone) ]
pub ( crate ) struct PairComputations < K , L > {
/// Literal leaves to be added to the sparse Merkle tree's internal mapping.
pub nodes : BTree Map< K , L > ,
pub nodes : Unordered Map< K , L > ,
/// "Conceptual" leaves that will be used for computations.
pub leaves : Vec < Vec < SubtreeLeaf > > ,
}
@ -818,7 +862,7 @@ impl<'s> SubtreeLeavesIter<'s> {
Self { leaves : leaves . drain ( . . ) . peekable ( ) }
}
}
impl core ::iter ::Iterator for SubtreeLeavesIter < '_ > {
impl Iterator for SubtreeLeavesIter < '_ > {
type Item = Vec < SubtreeLeaf > ;
/// Each `next()` collects an entire subtree.