You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

236 lines
7.5 KiB

use core::fmt::Display;
use super::{Felt, MerkleError, RpoDigest};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
// NODE INDEX
// ================================================================================================
/// Address of an arbitrary node in a binary tree, expressed in level-order form.
///
/// The position is represented by the pair `(depth, value)`: for a given depth `d`, the nodes of
/// that level are numbered from left to right as `0, 1, ..., 2^d - 1`. Example:
///
/// ```ignore
/// depth
/// 0            0
/// 1        0       1
/// 2      0   1   2   3
/// 3     0 1 2 3 4 5 6 7
/// ```
///
/// The root is represented by the pair `(0, 0)`, its left child is `(1, 0)` and its right child
/// is `(1, 1)`.
///
/// The derived ordering compares `depth` first and `value` second, following the declaration
/// order of the fields.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct NodeIndex {
/// Level of the node in the tree; the root sits at depth `0`.
depth: u8,
/// Position of the node within its level, counted from `0`.
value: u64,
}
impl NodeIndex {
    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Creates a new node index.
    ///
    /// # Errors
    /// Returns [`MerkleError::InvalidNodeIndex`] if `value` is greater than or equal to
    /// `2^depth`, i.e. if `value` needs more than `depth` bits to be represented.
    pub const fn new(depth: u8, value: u64) -> Result<Self, MerkleError> {
        // `64 - leading_zeros` is the number of significant bits of `value`.
        if (64 - value.leading_zeros()) > depth as u32 {
            Err(MerkleError::InvalidNodeIndex { depth, value })
        } else {
            Ok(Self { depth, value })
        }
    }

    /// Creates a new node index without checking its validity.
    ///
    /// The `value < 2^depth` invariant is only verified through a `debug_assert!`; builds without
    /// debug assertions accept any pair.
    pub const fn new_unchecked(depth: u8, value: u64) -> Self {
        debug_assert!((64 - value.leading_zeros()) <= depth as u32);
        Self { depth, value }
    }

    /// Creates a new node index for testing purposes.
    ///
    /// # Panics
    /// Panics if the `value` is greater than or equal to `2^depth`.
    #[cfg(test)]
    pub fn make(depth: u8, value: u64) -> Self {
        Self::new(depth, value).unwrap()
    }

    /// Creates a node index from a pair of field elements representing the depth and value.
    ///
    /// # Errors
    /// Returns an error if:
    /// - `depth` doesn't fit in a `u8` ([`MerkleError::DepthTooBig`]).
    /// - `value` is greater than or equal to `2^depth` ([`MerkleError::InvalidNodeIndex`]).
    pub fn from_elements(depth: &Felt, value: &Felt) -> Result<Self, MerkleError> {
        let depth = depth.as_int();
        let depth = u8::try_from(depth).map_err(|_| MerkleError::DepthTooBig(depth))?;
        let value = value.as_int();
        Self::new(depth, value)
    }

    /// Creates a new node index pointing to the root of the tree.
    pub const fn root() -> Self {
        Self { depth: 0, value: 0 }
    }

    /// Computes the sibling index of the current node.
    ///
    /// The sibling shares the depth of the current node and differs only in the lowest bit of
    /// the value.
    pub const fn sibling(mut self) -> Self {
        self.value ^= 1;
        self
    }

    /// Returns the left child index of the current node.
    ///
    /// The child lives one level deeper, at position `2 * value`.
    pub const fn left_child(mut self) -> Self {
        self.depth += 1;
        self.value <<= 1;
        self
    }

    /// Returns the right child index of the current node.
    ///
    /// The child lives one level deeper, at position `2 * value + 1`.
    pub const fn right_child(mut self) -> Self {
        self.depth += 1;
        self.value = (self.value << 1) + 1;
        self
    }

    // PROVIDERS
    // --------------------------------------------------------------------------------------------

    /// Builds a node to be used as input of a hash function when computing a Merkle path.
    ///
    /// The parity of the current index decides the order: an odd value places `sibling` first
    /// and `slf` second, an even value places `slf` first and `sibling` second.
    pub const fn build_node(&self, slf: RpoDigest, sibling: RpoDigest) -> [RpoDigest; 2] {
        if self.is_value_odd() {
            [sibling, slf]
        } else {
            [slf, sibling]
        }
    }

    /// Returns the scalar representation of the depth/value pair.
    ///
    /// It is computed as `2^depth + value`.
    ///
    /// The result only fits into a `u64` for depths below 64; for deeper indices the shift
    /// overflows (which panics when overflow checks are enabled).
    pub const fn to_scalar_index(&self) -> u64 {
        (1 << self.depth as u64) + self.value
    }

    /// Returns the depth of the current instance.
    pub const fn depth(&self) -> u8 {
        self.depth
    }

    /// Returns the value of this index.
    pub const fn value(&self) -> u64 {
        self.value
    }

    /// Returns `true` if the current instance points to a right sibling node, i.e. if its value
    /// is odd.
    pub const fn is_value_odd(&self) -> bool {
        (self.value & 1) == 1
    }

    /// Returns `true` if the depth is `0`.
    pub const fn is_root(&self) -> bool {
        self.depth == 0
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Traverses one level towards the root, decrementing the depth by `1`.
    ///
    /// The depth saturates at `0`, so calling this on the root leaves it unchanged.
    pub fn move_up(&mut self) {
        self.depth = self.depth.saturating_sub(1);
        self.value >>= 1;
    }

    /// Traverses towards the root until the specified depth is reached.
    ///
    /// Assumes that the specified depth is smaller than the current depth; this is checked with
    /// a `debug_assert!` only.
    pub fn move_up_to(&mut self, depth: u8) {
        debug_assert!(depth < self.depth);
        let delta = self.depth.saturating_sub(depth);
        self.depth = self.depth.saturating_sub(delta);
        // A plain `>>=` overflows once `delta` reaches 64 (e.g. moving from depth 64 to the
        // root). Every bit is shifted out in that case, so the resulting value is `0`.
        self.value = self.value.checked_shr(u32::from(delta)).unwrap_or(0);
    }
}
impl Display for NodeIndex {
    /// Renders the index as `depth=<depth>, value=<value>`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self { depth, value } = *self;
        f.write_fmt(format_args!("depth={}, value={}", depth, value))
    }
}
impl Serializable for NodeIndex {
    /// Writes the index into `target`: the depth as a single `u8`, followed by the value as a
    /// `u64`.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let Self { depth, value } = *self;
        target.write_u8(depth);
        target.write_u64(value);
    }
}
impl Deserializable for NodeIndex {
    /// Reads a depth (`u8`) followed by a value (`u64`) from `source` and validates the pair.
    ///
    /// # Errors
    /// Propagates any read error from `source`, and returns
    /// `DeserializationError::InvalidValue` when the pair is rejected by [`NodeIndex::new`].
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let depth = source.read_u8()?;
        let value = source.read_u64()?;
        match Self::new(depth, value) {
            Ok(index) => Ok(index),
            Err(_) => Err(DeserializationError::InvalidValue("Invalid index".into())),
        }
    }
}
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use proptest::prelude::*;

    use super::*;

    /// For each depth, the largest valid value (`2^depth - 1`) is accepted and the next one
    /// (`2^depth`) is rejected with `InvalidNodeIndex`.
    #[test]
    fn test_node_index_value_too_high() {
        assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 });
        let err = NodeIndex::new(0, 1).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 0, value: 1 });

        assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 });
        let err = NodeIndex::new(1, 2).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 1, value: 2 });

        assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 });
        let err = NodeIndex::new(2, 4).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 2, value: 4 });

        assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 });
        let err = NodeIndex::new(3, 8).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 3, value: 8 });
    }

    /// The full `u64` value range is addressable at depth 64.
    #[test]
    fn test_node_index_can_represent_depth_64() {
        assert!(NodeIndex::new(64, u64::MAX).is_ok());
    }

    prop_compose! {
        /// Generates a valid index by pairing an arbitrary value with the shallowest depth that
        /// can hold it.
        fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex {
            // The shallowest valid depth is the bit length of `value`. Computing it from
            // `leading_zeros` also covers `0` (where `ilog2` would panic) and exact powers of
            // two (which need one more bit than `ilog2` reports), so `unwrap` never panics.
            let depth = (u64::BITS - value.leading_zeros()) as u8;
            NodeIndex::new(depth, value).unwrap()
        }
    }

    proptest! {
        /// Moving up any number of times, including past the root, must not panic.
        #[test]
        fn arbitrary_index_wont_panic_on_move_up(
            mut index in node_index(),
            count in prop::num::u8::ANY,
        ) {
            for _ in 0..count {
                index.move_up();
            }
        }
    }
}