diff --git a/src/merkle/store/mod.rs b/src/merkle/store/mod.rs index d78be59..f77d558 100644 --- a/src/merkle/store/mod.rs +++ b/src/merkle/store/mod.rs @@ -3,7 +3,10 @@ use super::{ MerklePathSet, MerkleTree, NodeIndex, RecordingMap, RootPath, Rpo256, RpoDigest, SimpleSmt, TieredSmt, ValuePath, Vec, }; -use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; +use crate::utils::{ + collections::{ApplyDiff, Diff, KvMapDiff}, + ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, +}; use core::borrow::Borrow; #[cfg(test)] @@ -474,6 +477,24 @@ impl> Extend for MerkleStore { } } +// DiffT & ApplyDiffT TRAIT IMPLEMENTATION +// ================================================================================================ +impl> Diff for MerkleStore { + type DiffType = KvMapDiff; + + fn diff(&self, other: &Self) -> Self::DiffType { + self.nodes.diff(&other.nodes) + } +} + +impl> ApplyDiff for MerkleStore { + type DiffType = KvMapDiff; + + fn apply(&mut self, diff: Self::DiffType) { + self.nodes.apply(diff); + } +} + // SERIALIZATION // ================================================================================================ diff --git a/src/utils/diff.rs b/src/utils/diff.rs new file mode 100644 index 0000000..48c80b6 --- /dev/null +++ b/src/utils/diff.rs @@ -0,0 +1,16 @@ +/// A trait for computing the difference between two objects. +pub trait Diff { + type DiffType; + + /// Returns a `Self::DiffType` object that represents the difference between this object and + /// other. + fn diff(&self, other: &Self) -> Self::DiffType; +} + +/// A trait for applying the difference between two objects. +pub trait ApplyDiff { + type DiffType; + + /// Applies the provided changes described by [DiffType] to the object implementing this trait. + fn apply(&mut self, diff: Self::DiffType); +} diff --git a/src/utils/kv_map.rs b/src/utils/kv_map.rs index d9b453d..063a0a0 100644 --- a/src/utils/kv_map.rs +++ b/src/utils/kv_map.rs @@ -1,3 +1,4 @@ +use super::{collections::ApplyDiff, diff::Diff}; use core::cell::RefCell; use winter_utils::{ collections::{btree_map::IntoIter, BTreeMap, BTreeSet}, @@ -18,6 +19,7 @@ pub trait KvMap: self.len() == 0 } fn insert(&mut self, key: K, value: V) -> Option; + fn remove(&mut self, key: &K) -> Option; fn iter(&self) -> Box + '_>; } @@ -42,6 +44,10 @@ impl KvMap for BTreeMap { self.insert(key, value) } + fn remove(&mut self, key: &K) -> Option { + self.remove(key) + } + fn iter(&self) -> Box + '_> { Box::new(self.iter()) } @@ -56,8 +62,9 @@ impl KvMap for BTreeMap { /// /// The [RecordingMap] is composed of three parts: /// - `data`: which contains the current set of key-value pairs in the map. -/// - `updates`: which tracks keys for which values have been since the map was instantiated. -/// updates include both insertions and updates of values under existing keys. +/// - `updates`: which tracks keys for which values have been changed since the map was +/// instantiated. updates include both insertions, removals and updates of values under existing +/// keys. /// - `trace`: which contains the key-value pairs from the original data which have been accesses /// since the map was instantiated. #[derive(Debug, Default, Clone, Eq, PartialEq)] @@ -80,6 +87,13 @@ impl RecordingMap { } } + // PUBLIC ACCESSORS + // -------------------------------------------------------------------------------------------- + + pub fn inner(&self) -> &BTreeMap { + &self.data + } + // FINALIZER // -------------------------------------------------------------------------------------------- @@ -148,6 +162,19 @@ impl KvMap for RecordingMap { }) } + /// Removes a key-value pair from the data set. + /// + /// If the key exists in the data set, the old value is returned. + fn remove(&mut self, key: &K) -> Option { + self.data.remove(key).map(|old_value| { + let new_update = self.updates.insert(key.clone()); + if new_update { + self.trace.borrow_mut().insert(key.clone(), old_value.clone()); + } + old_value + }) + } + // ITERATION // -------------------------------------------------------------------------------------------- @@ -180,6 +207,74 @@ impl IntoIterator for RecordingMap { } } +// KV MAP DIFF +// ================================================================================================ +/// [KvMapDiff] stores the difference between two key-value maps. +/// +/// The [KvMapDiff] is composed of two parts: +/// - `updates` - a map of key-value pairs that were updated in the second map compared to the +/// first map. This includes new key-value pairs. +/// - `removed` - a set of keys that were removed from the second map compared to the first map. +#[derive(Debug, Clone)] +pub struct KvMapDiff { + updated: BTreeMap, + removed: BTreeSet, +} + +impl KvMapDiff { + // CONSTRUCTOR + // -------------------------------------------------------------------------------------------- + /// Creates a new [KvMapDiff] instance. + pub fn new() -> Self { + KvMapDiff { + updated: BTreeMap::new(), + removed: BTreeSet::new(), + } + } +} + +impl Default for KvMapDiff { + fn default() -> Self { + Self::new() + } +} + +impl> Diff for T { + type DiffType = KvMapDiff; + + fn diff(&self, other: &T) -> Self::DiffType { + let mut diff = KvMapDiff::default(); + for (k, v) in self.iter() { + if let Some(other_value) = other.get(k) { + if v != other_value { + diff.updated.insert(k.clone(), other_value.clone()); + } + } else { + diff.removed.insert(k.clone()); + } + } + for (k, v) in other.iter() { + if self.get(k).is_none() { + diff.updated.insert(k.clone(), v.clone()); + } + } + diff + } +} + +impl> ApplyDiff for T { + type DiffType = KvMapDiff; + + fn apply(&mut self, diff: Self::DiffType) { + for (k, v) in diff.updated { + self.insert(k, v); + } + for k in diff.removed { + self.remove(&k); + } + } +} + // TESTS // ================================================================================================ @@ -321,4 +416,87 @@ mod tests { let map = RecordingMap::new(ITEMS.to_vec()); assert!(!map.is_empty()); } + + #[test] + fn test_remove() { + let mut map = RecordingMap::new(ITEMS.to_vec()); + + // remove an item that exists + let key = 0; + let value = map.remove(&key).unwrap(); + assert_eq!(value, ITEMS[0].1); + assert_eq!(map.len(), ITEMS.len() - 1); + assert_eq!(map.trace_len(), 1); + assert_eq!(map.updates_len(), 1); + + // add the item back and then remove it again + let key = 0; + let value = 0; + map.insert(key, value); + let value = map.remove(&key).unwrap(); + assert_eq!(value, 0); + assert_eq!(map.len(), ITEMS.len() - 1); + assert_eq!(map.trace_len(), 1); + assert_eq!(map.updates_len(), 1); + + // remove an item that does not exist + let key = 100; + let value = map.remove(&key); + assert_eq!(value, None); + assert_eq!(map.len(), ITEMS.len() - 1); + assert_eq!(map.trace_len(), 1); + assert_eq!(map.updates_len(), 1); + + // insert a new item and then remove it + let key = 100; + let value = 100; + map.insert(key, value); + let value = map.remove(&key).unwrap(); + assert_eq!(value, 100); + assert_eq!(map.len(), ITEMS.len() - 1); + assert_eq!(map.trace_len(), 1); + assert_eq!(map.updates_len(), 2); + + // convert the map into a proof + let proof = map.into_proof(); + + // check that the proof contains the expected values + for (key, value) in ITEMS.iter() { + match key { + 0 => assert_eq!(proof.get(key), Some(value)), + _ => assert_eq!(proof.get(key), None), + } + } + } + + #[test] + fn test_kv_map_diff() { + let mut initial_state = ITEMS.into_iter().collect::>(); + let mut map = RecordingMap::new(initial_state.clone()); + + // remove an item that exists + let key = 0; + let _value = map.remove(&key).unwrap(); + + // add a new item + let key = 100; + let value = 100; + map.insert(key, value); + + // update an existing item + let key = 1; + let value = 100; + map.insert(key, value); + + // compute a diff + let diff = initial_state.diff(map.inner()); + assert!(diff.updated.len() == 2); + assert!(diff.updated.iter().all(|(k, v)| [(100, 100), (1, 100)].contains(&(*k, *v)))); + assert!(diff.removed.len() == 1); + assert!(diff.removed.first() == Some(&0)); + + // apply the diff to the initial state and assert the contents are the same as the map + initial_state.apply(diff); + assert!(initial_state.iter().eq(map.iter())); + } } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 8059d26..d71cd33 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -7,6 +7,7 @@ pub use alloc::{format, vec}; #[cfg(feature = "std")] pub use std::{format, vec}; +mod diff; mod kv_map; // RE-EXPORTS @@ -17,6 +18,7 @@ pub use winter_utils::{ }; pub mod collections { + pub use super::diff::*; pub use super::kv_map::*; pub use winter_utils::collections::*; }