use core::fmt::Display;
|
|
|
|
use super::{Felt, MerkleError, RpoDigest};
|
|
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
|
|
|
|
// NODE INDEX
|
|
// ================================================================================================
|
|
|
|
/// Position of a node inside a binary tree, expressed in level order.
///
/// A node is identified by the pair `(depth, value)`: `depth` is the level of the node counted
/// from the root, and `value` is the zero-based offset of the node within that level. A level at
/// depth `d` therefore holds the values `0..=(2^d)-1`. Example:
///
/// ```ignore
/// depth
/// 0                0
/// 1          0           1
/// 2       0     1     2     3
/// 3      0 1   2 3   4 5   6 7
/// ```
///
/// The root is the pair `(0, 0)`; its left child is `(1, 0)` and its right child is `(1, 1)`.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct NodeIndex {
    /// Level of the node; the root sits at depth `0`.
    depth: u8,
    /// Zero-based offset of the node within its level.
    value: u64,
}
|
|
|
|
impl NodeIndex {
|
|
// CONSTRUCTORS
|
|
// --------------------------------------------------------------------------------------------
|
|
|
|
/// Creates a new node index.
|
|
///
|
|
/// # Errors
|
|
/// Returns an error if the `value` is greater than or equal to 2^{depth}.
|
|
pub const fn new(depth: u8, value: u64) -> Result<Self, MerkleError> {
|
|
if (64 - value.leading_zeros()) > depth as u32 {
|
|
Err(MerkleError::InvalidNodeIndex { depth, value })
|
|
} else {
|
|
Ok(Self { depth, value })
|
|
}
|
|
}
|
|
|
|
/// Creates a new node index without checking its validity.
|
|
pub const fn new_unchecked(depth: u8, value: u64) -> Self {
|
|
debug_assert!((64 - value.leading_zeros()) <= depth as u32);
|
|
Self { depth, value }
|
|
}
|
|
|
|
/// Creates a new node index for testing purposes.
|
|
///
|
|
/// # Panics
|
|
/// Panics if the `value` is greater than or equal to 2^{depth}.
|
|
#[cfg(test)]
|
|
pub fn make(depth: u8, value: u64) -> Self {
|
|
Self::new(depth, value).unwrap()
|
|
}
|
|
|
|
/// Creates a node index from a pair of field elements representing the depth and value.
|
|
///
|
|
/// # Errors
|
|
/// Returns an error if:
|
|
/// - `depth` doesn't fit in a `u8`.
|
|
/// - `value` is greater than or equal to 2^{depth}.
|
|
pub fn from_elements(depth: &Felt, value: &Felt) -> Result<Self, MerkleError> {
|
|
let depth = depth.as_int();
|
|
let depth = u8::try_from(depth).map_err(|_| MerkleError::DepthTooBig(depth))?;
|
|
let value = value.as_int();
|
|
Self::new(depth, value)
|
|
}
|
|
|
|
/// Creates a new node index pointing to the root of the tree.
|
|
pub const fn root() -> Self {
|
|
Self { depth: 0, value: 0 }
|
|
}
|
|
|
|
/// Computes sibling index of the current node.
|
|
pub const fn sibling(mut self) -> Self {
|
|
self.value ^= 1;
|
|
self
|
|
}
|
|
|
|
/// Returns left child index of the current node.
|
|
pub const fn left_child(mut self) -> Self {
|
|
self.depth += 1;
|
|
self.value <<= 1;
|
|
self
|
|
}
|
|
|
|
/// Returns right child index of the current node.
|
|
pub const fn right_child(mut self) -> Self {
|
|
self.depth += 1;
|
|
self.value = (self.value << 1) + 1;
|
|
self
|
|
}
|
|
|
|
// PROVIDERS
|
|
// --------------------------------------------------------------------------------------------
|
|
|
|
/// Builds a node to be used as input of a hash function when computing a Merkle path.
|
|
///
|
|
/// Will evaluate the parity of the current instance to define the result.
|
|
pub const fn build_node(&self, slf: RpoDigest, sibling: RpoDigest) -> [RpoDigest; 2] {
|
|
if self.is_value_odd() {
|
|
[sibling, slf]
|
|
} else {
|
|
[slf, sibling]
|
|
}
|
|
}
|
|
|
|
/// Returns the scalar representation of the depth/value pair.
|
|
///
|
|
/// It is computed as `2^depth + value`.
|
|
pub const fn to_scalar_index(&self) -> u64 {
|
|
(1 << self.depth as u64) + self.value
|
|
}
|
|
|
|
/// Returns the depth of the current instance.
|
|
pub const fn depth(&self) -> u8 {
|
|
self.depth
|
|
}
|
|
|
|
/// Returns the value of this index.
|
|
pub const fn value(&self) -> u64 {
|
|
self.value
|
|
}
|
|
|
|
/// Returns true if the current instance points to a right sibling node.
|
|
pub const fn is_value_odd(&self) -> bool {
|
|
(self.value & 1) == 1
|
|
}
|
|
|
|
/// Returns `true` if the depth is `0`.
|
|
pub const fn is_root(&self) -> bool {
|
|
self.depth == 0
|
|
}
|
|
|
|
// STATE MUTATORS
|
|
// --------------------------------------------------------------------------------------------
|
|
|
|
/// Traverses one level towards the root, decrementing the depth by `1`.
|
|
pub fn move_up(&mut self) {
|
|
self.depth = self.depth.saturating_sub(1);
|
|
self.value >>= 1;
|
|
}
|
|
|
|
/// Traverses towards the root until the specified depth is reached.
|
|
///
|
|
/// Assumes that the specified depth is smaller than the current depth.
|
|
pub fn move_up_to(&mut self, depth: u8) {
|
|
debug_assert!(depth < self.depth);
|
|
let delta = self.depth.saturating_sub(depth);
|
|
self.depth = self.depth.saturating_sub(delta);
|
|
self.value >>= delta as u32;
|
|
}
|
|
}
|
|
|
|
impl Display for NodeIndex {
|
|
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
|
write!(f, "depth={}, value={}", self.depth, self.value)
|
|
}
|
|
}
|
|
|
|
impl Serializable for NodeIndex {
|
|
fn write_into<W: ByteWriter>(&self, target: &mut W) {
|
|
target.write_u8(self.depth);
|
|
target.write_u64(self.value);
|
|
}
|
|
}
|
|
|
|
impl Deserializable for NodeIndex {
|
|
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
|
|
let depth = source.read_u8()?;
|
|
let value = source.read_u64()?;
|
|
NodeIndex::new(depth, value)
|
|
.map_err(|_| DeserializationError::InvalidValue("Invalid index".into()))
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use proptest::prelude::*;

    use super::*;

    /// For depths 0 to 3, the largest value of the level is accepted and the next one is
    /// rejected with `MerkleError::InvalidNodeIndex`.
    #[test]
    fn test_node_index_value_too_high() {
        assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 });
        let err = NodeIndex::new(0, 1).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 0, value: 1 });

        assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 });
        let err = NodeIndex::new(1, 2).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 1, value: 2 });

        assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 });
        let err = NodeIndex::new(2, 4).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 2, value: 4 });

        assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 });
        let err = NodeIndex::new(3, 8).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 3, value: 8 });
    }

    /// The full `u64` range is addressable at depth 64.
    #[test]
    fn test_node_index_can_represent_depth_64() {
        assert!(NodeIndex::new(64, u64::MAX).is_ok());
    }

    prop_compose! {
        fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex {
            // The shallowest level able to hold `value` has a depth equal to the bit length of
            // `value`. Unlike `ilog2`, this is defined for 0 and is also correct for exact
            // powers of two, so `NodeIndex::new` accepts the pair and the unwrap never panics.
            let depth = (u64::BITS - value.leading_zeros()) as u8;
            NodeIndex::new(depth, value).unwrap()
        }
    }

    proptest! {
        #[test]
        fn arbitrary_index_wont_panic_on_move_up(
            mut index in node_index(),
            count in prop::num::u8::ANY,
        ) {
            for _ in 0..count {
                index.move_up();
            }
        }
    }
}
|