refactor: refactor crypto APIs to use RpoDigest instead of Word

This commit is contained in:
tohrnii
2023-06-09 21:18:13 +01:00
parent 59f7723221
commit fe9aa8c28c
16 changed files with 590 additions and 376 deletions

View File

@@ -1,10 +1,5 @@
use super::{
Felt, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, Word,
};
use crate::{
utils::{string::String, uninit_vector, word_to_hex},
FieldElement,
};
use super::{InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, Word};
use crate::utils::{string::String, uninit_vector, word_to_hex};
use core::{fmt, slice};
use winter_math::log2;
@@ -14,7 +9,7 @@ use winter_math::log2;
/// A fully-balanced binary Merkle tree (i.e., a tree where the number of leaves is a power of two).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MerkleTree {
nodes: Vec<Word>,
nodes: Vec<RpoDigest>,
}
impl MerkleTree {
@@ -34,10 +29,12 @@ impl MerkleTree {
// create un-initialized vector to hold all tree nodes
let mut nodes = unsafe { uninit_vector(2 * n) };
nodes[0] = [Felt::ZERO; 4];
nodes[0] = RpoDigest::default();
// copy leaves into the second part of the nodes vector
nodes[n..].copy_from_slice(&leaves);
nodes[n..].iter_mut().zip(leaves).for_each(|(node, leaf)| {
*node = RpoDigest::from(leaf);
});
// re-interpret nodes as an array of two nodes fused together
// Safety: `nodes` will never move here as it is not bound to an external lifetime (i.e.
@@ -47,7 +44,7 @@ impl MerkleTree {
// calculate all internal tree nodes
for i in (1..n).rev() {
nodes[i] = Rpo256::merge(&pairs[i]).into();
nodes[i] = Rpo256::merge(&pairs[i]);
}
Ok(Self { nodes })
@@ -57,7 +54,7 @@ impl MerkleTree {
// --------------------------------------------------------------------------------------------
/// Returns the root of this Merkle tree.
pub fn root(&self) -> Word {
pub fn root(&self) -> RpoDigest {
self.nodes[1]
}
@@ -74,7 +71,7 @@ impl MerkleTree {
/// Returns an error if:
/// * The specified depth is greater than the depth of the tree.
/// * The specified index is not valid for the specified depth.
pub fn get_node(&self, index: NodeIndex) -> Result<Word, MerkleError> {
pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
if index.is_root() {
return Err(MerkleError::DepthTooSmall(index.depth()));
} else if index.depth() > self.depth() {
@@ -120,7 +117,7 @@ impl MerkleTree {
/// Returns an iterator over the leaves of this [MerkleTree].
pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
let leaves_start = self.nodes.len() / 2;
self.nodes.iter().skip(leaves_start).enumerate().map(|(i, v)| (i as u64, v))
self.nodes.iter().skip(leaves_start).enumerate().map(|(i, v)| (i as u64, &**v))
}
/// Returns n iterator over every inner node of this [MerkleTree].
@@ -159,13 +156,13 @@ impl MerkleTree {
// update the current node
let pos = index.to_scalar_index() as usize;
self.nodes[pos] = value;
self.nodes[pos] = value.into();
// traverse to the root, updating each node with the merged values of its parents
for _ in 0..index.depth() {
index.move_up();
let pos = index.to_scalar_index() as usize;
let value = Rpo256::merge(&pairs[pos]).into();
let value = Rpo256::merge(&pairs[pos]);
self.nodes[pos] = value;
}
@@ -180,7 +177,7 @@ impl MerkleTree {
///
/// Use this to extract the data of the tree, there is no guarantee on the order of the elements.
pub struct InnerNodeIterator<'a> {
nodes: &'a Vec<Word>,
nodes: &'a Vec<RpoDigest>,
index: usize,
}
@@ -258,21 +255,25 @@ pub fn path_to_text(path: &MerklePath) -> Result<String, fmt::Error> {
#[cfg(test)]
mod tests {
use super::*;
use crate::merkle::{int_to_node, InnerNodeInfo};
use crate::{
merkle::{int_to_leaf, InnerNodeInfo},
Felt, Word, WORD_SIZE,
};
use core::mem::size_of;
use proptest::prelude::*;
const LEAVES4: [Word; 4] = [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];
const LEAVES4: [Word; WORD_SIZE] =
[int_to_leaf(1), int_to_leaf(2), int_to_leaf(3), int_to_leaf(4)];
const LEAVES8: [Word; 8] = [
int_to_node(1),
int_to_node(2),
int_to_node(3),
int_to_node(4),
int_to_node(5),
int_to_node(6),
int_to_node(7),
int_to_node(8),
int_to_leaf(1),
int_to_leaf(2),
int_to_leaf(3),
int_to_leaf(4),
int_to_leaf(5),
int_to_leaf(6),
int_to_leaf(7),
int_to_leaf(8),
];
#[test]
@@ -282,7 +283,7 @@ mod tests {
// leaves were copied correctly
for (a, b) in tree.nodes.iter().skip(4).zip(LEAVES4.iter()) {
assert_eq!(a, b);
assert_eq!(*a, RpoDigest::from(*b));
}
let (root, node2, node3) = compute_internal_nodes();
@@ -299,10 +300,10 @@ mod tests {
let tree = super::MerkleTree::new(LEAVES4.to_vec()).unwrap();
// check depth 2
assert_eq!(LEAVES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap());
assert_eq!(LEAVES4[1], tree.get_node(NodeIndex::make(2, 1)).unwrap());
assert_eq!(LEAVES4[2], tree.get_node(NodeIndex::make(2, 2)).unwrap());
assert_eq!(LEAVES4[3], tree.get_node(NodeIndex::make(2, 3)).unwrap());
assert_eq!(RpoDigest::from(LEAVES4[0]), tree.get_node(NodeIndex::make(2, 0)).unwrap());
assert_eq!(RpoDigest::from(LEAVES4[1]), tree.get_node(NodeIndex::make(2, 1)).unwrap());
assert_eq!(RpoDigest::from(LEAVES4[2]), tree.get_node(NodeIndex::make(2, 2)).unwrap());
assert_eq!(RpoDigest::from(LEAVES4[3]), tree.get_node(NodeIndex::make(2, 3)).unwrap());
// check depth 1
let (_, node2, node3) = compute_internal_nodes();
@@ -318,10 +319,22 @@ mod tests {
let (_, node2, node3) = compute_internal_nodes();
// check depth 2
assert_eq!(vec![LEAVES4[1], node3], *tree.get_path(NodeIndex::make(2, 0)).unwrap());
assert_eq!(vec![LEAVES4[0], node3], *tree.get_path(NodeIndex::make(2, 1)).unwrap());
assert_eq!(vec![LEAVES4[3], node2], *tree.get_path(NodeIndex::make(2, 2)).unwrap());
assert_eq!(vec![LEAVES4[2], node2], *tree.get_path(NodeIndex::make(2, 3)).unwrap());
assert_eq!(
vec![RpoDigest::from(LEAVES4[1]), node3],
*tree.get_path(NodeIndex::make(2, 0)).unwrap()
);
assert_eq!(
vec![RpoDigest::from(LEAVES4[0]), node3],
*tree.get_path(NodeIndex::make(2, 1)).unwrap()
);
assert_eq!(
vec![RpoDigest::from(LEAVES4[3]), node2],
*tree.get_path(NodeIndex::make(2, 2)).unwrap()
);
assert_eq!(
vec![RpoDigest::from(LEAVES4[2]), node2],
*tree.get_path(NodeIndex::make(2, 3)).unwrap()
);
// check depth 1
assert_eq!(vec![node3], *tree.get_path(NodeIndex::make(1, 0)).unwrap());
@@ -334,7 +347,7 @@ mod tests {
// update one leaf
let value = 3;
let new_node = int_to_node(9);
let new_node = int_to_leaf(9);
let mut expected_leaves = LEAVES8.to_vec();
expected_leaves[value as usize] = new_node;
let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap();
@@ -344,7 +357,7 @@ mod tests {
// update another leaf
let value = 6;
let new_node = int_to_node(10);
let new_node = int_to_leaf(10);
expected_leaves[value as usize] = new_node;
let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap();
@@ -417,11 +430,13 @@ mod tests {
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------
fn compute_internal_nodes() -> (Word, Word, Word) {
let node2 = Rpo256::hash_elements(&[LEAVES4[0], LEAVES4[1]].concat());
let node3 = Rpo256::hash_elements(&[LEAVES4[2], LEAVES4[3]].concat());
fn compute_internal_nodes() -> (RpoDigest, RpoDigest, RpoDigest) {
let node2 =
Rpo256::hash_elements(&[Word::from(LEAVES4[0]), Word::from(LEAVES4[1])].concat());
let node3 =
Rpo256::hash_elements(&[Word::from(LEAVES4[2]), Word::from(LEAVES4[3])].concat());
let root = Rpo256::merge(&[node2, node3]);
(root.into(), node2.into(), node3.into())
(root, node2, node3)
}
}