refactor: optimize code, fix bugs

This commit is contained in:
Andrey Khmuro
2023-06-05 18:02:16 +03:00
parent 43f1a4cb64
commit 2708a23649
3 changed files with 93 additions and 131 deletions

View File

@@ -22,6 +22,7 @@ const EMPTY_DIGEST: RpoDigest = RpoDigest::new(EMPTY_WORD);
/// Tree allows to create Merkle Tree by providing Merkle paths of different lengths.
///
/// The root of the tree is recomputed on each new leaf update.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PartialMerkleTree {
max_depth: u8,
nodes: BTreeMap<NodeIndex, RpoDigest>,
@@ -112,12 +113,12 @@ impl PartialMerkleTree {
/// Returns a vector of paths from every leaf to the root.
pub fn paths(&self) -> Vec<(NodeIndex, ValuePath)> {
let mut paths = Vec::new();
self.leaves.iter().for_each(|leaf| {
self.leaves.iter().for_each(|&leaf| {
paths.push((
*leaf,
leaf,
ValuePath {
value: *self.get_node(*leaf).expect("Failed to get leaf node"),
path: self.get_path(*leaf).expect("Failed to get path"),
value: *self.get_node(leaf).expect("Failed to get leaf node"),
path: self.get_path(leaf).expect("Failed to get path"),
},
));
});
@@ -160,10 +161,10 @@ impl PartialMerkleTree {
/// Returns an iterator over the leaves of this [PartialMerkleTree].
pub fn leaves(&self) -> impl Iterator<Item = (NodeIndex, RpoDigest)> + '_ {
self.leaves.iter().map(|leaf| {
self.leaves.iter().map(|&leaf| {
(
*leaf,
self.get_node(*leaf).unwrap_or_else(|_| {
leaf,
self.get_node(leaf).unwrap_or_else(|_| {
panic!(
"Leaf with node index ({}, {}) is not in the nodes map",
leaf.depth(),
@@ -214,19 +215,25 @@ impl PartialMerkleTree {
self.nodes.insert(index_value, node);
// if the calculated node was a leaf, remove it from leaves set.
if self.leaves.contains(&index_value) {
self.leaves.remove(&index_value);
}
self.leaves.remove(&index_value);
let sibling_node = index_value.sibling();
// node became a leaf only if it is a new node (it wasn't in nodes map)
if !self.nodes.contains_key(&sibling_node) {
// Insert node from Merkle path to the nodes map. This sibling node becomes a leaf only
// if it is a new node (it wasn't in nodes map).
// Node can be in 3 states: internal node, leaf of the tree and not a node at all.
// - Internal node can only stay in this state -- addition of a new path can't make it
// a leaf or remove it from the tree.
// - Leaf node can stay in the same state (remain a leaf) or can become an internal
// node. In the first case we don't need to do anything, and the second case is handled
// in the line 219.
// - New node can be a calculated node or a "sibling" node from a Merkle Path:
// --- Calculated node, obviously, never can be a leaf.
// --- Sibling node can be only a leaf, because otherwise it is not a new node.
if self.nodes.insert(sibling_node, hash.into()).is_none() {
self.leaves.insert(sibling_node);
}
// insert node from Merkle path to the nodes map
self.nodes.insert(sibling_node, hash.into());
Rpo256::merge(&index_value.build_node(node, hash.into()))
});
@@ -238,8 +245,6 @@ impl PartialMerkleTree {
return Err(MerkleError::ConflictingRoots([*self.root(), *root].to_vec()));
}
// self.update_leaves()?;
Ok(())
}
@@ -250,7 +255,7 @@ impl PartialMerkleTree {
&mut self,
node_index: NodeIndex,
value: RpoDigest,
) -> Result<RpoDigest, MerkleError> {
) -> Result<Option<RpoDigest>, MerkleError> {
// check correctness of the depth and update it
Self::check_depth(node_index.depth())?;
self.update_depth(node_index.depth());
@@ -259,38 +264,19 @@ impl PartialMerkleTree {
self.leaves.insert(node_index);
// add node value to the nodes Map
let old_value = self.nodes.insert(node_index, value).unwrap_or(EMPTY_DIGEST);
let old_value = self.nodes.insert(node_index, value);
// if the old value and new value are the same, there is nothing to update
if value == old_value {
return Ok(value);
if old_value.is_some() && value == old_value.unwrap() {
return Ok(old_value);
}
let mut node_index = node_index;
let mut value = value;
for _ in 0..node_index.depth() {
let is_right = node_index.is_value_odd();
let (left, right) = if is_right {
let left_index = NodeIndex::new(node_index.depth(), node_index.value() - 1)?;
(
self.nodes
.get(&left_index)
.cloned()
.ok_or(MerkleError::NodeNotInSet(left_index))?,
value,
)
} else {
let right_index = NodeIndex::new(node_index.depth(), node_index.value() + 1)?;
(
value,
self.nodes
.get(&right_index)
.cloned()
.ok_or(MerkleError::NodeNotInSet(right_index))?,
)
};
let sibling = self.nodes.get(&node_index.sibling()).expect("sibling should exist");
value = Rpo256::merge(&node_index.build_node(value, *sibling));
node_index.move_up();
value = Rpo256::merge(&[left, right]);
self.nodes.insert(node_index, value);
}