use core::fmt::Display; use super::{Felt, MerkleError, RpoDigest}; use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; // NODE INDEX // ================================================================================================ /// Address to an arbitrary node in a binary tree using level order form. /// /// The position is represented by the pair `(depth, pos)`, where for a given depth `d` elements /// are numbered from $0..(2^d)-1$. Example: /// /// ```ignore /// depth /// 0 0 /// 1 0 1 /// 2 0 1 2 3 /// 3 0 1 2 3 4 5 6 7 /// ``` /// /// The root is represented by the pair $(0, 0)$, its left child is $(1, 0)$ and its right child /// $(1, 1)$. #[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct NodeIndex { depth: u8, value: u64, } impl NodeIndex { // CONSTRUCTORS // -------------------------------------------------------------------------------------------- /// Creates a new node index. /// /// # Errors /// Returns an error if the `value` is greater than or equal to 2^{depth}. pub const fn new(depth: u8, value: u64) -> Result { if (64 - value.leading_zeros()) > depth as u32 { Err(MerkleError::InvalidNodeIndex { depth, value }) } else { Ok(Self { depth, value }) } } /// Creates a new node index without checking its validity. pub const fn new_unchecked(depth: u8, value: u64) -> Self { debug_assert!((64 - value.leading_zeros()) <= depth as u32); Self { depth, value } } /// Creates a new node index for testing purposes. /// /// # Panics /// Panics if the `value` is greater than or equal to 2^{depth}. #[cfg(test)] pub fn make(depth: u8, value: u64) -> Self { Self::new(depth, value).unwrap() } /// Creates a node index from a pair of field elements representing the depth and value. /// /// # Errors /// Returns an error if: /// - `depth` doesn't fit in a `u8`. /// - `value` is greater than or equal to 2^{depth}. pub fn from_elements(depth: &Felt, value: &Felt) -> Result { let depth = depth.as_int(); let depth = u8::try_from(depth).map_err(|_| MerkleError::DepthTooBig(depth))?; let value = value.as_int(); Self::new(depth, value) } /// Creates a new node index pointing to the root of the tree. pub const fn root() -> Self { Self { depth: 0, value: 0 } } /// Computes sibling index of the current node. pub const fn sibling(mut self) -> Self { self.value ^= 1; self } /// Returns left child index of the current node. pub const fn left_child(mut self) -> Self { self.depth += 1; self.value <<= 1; self } /// Returns right child index of the current node. pub const fn right_child(mut self) -> Self { self.depth += 1; self.value = (self.value << 1) + 1; self } // PROVIDERS // -------------------------------------------------------------------------------------------- /// Builds a node to be used as input of a hash function when computing a Merkle path. /// /// Will evaluate the parity of the current instance to define the result. pub const fn build_node(&self, slf: RpoDigest, sibling: RpoDigest) -> [RpoDigest; 2] { if self.is_value_odd() { [sibling, slf] } else { [slf, sibling] } } /// Returns the scalar representation of the depth/value pair. /// /// It is computed as `2^depth + value`. pub const fn to_scalar_index(&self) -> u64 { (1 << self.depth as u64) + self.value } /// Returns the depth of the current instance. pub const fn depth(&self) -> u8 { self.depth } /// Returns the value of this index. pub const fn value(&self) -> u64 { self.value } /// Returns true if the current instance points to a right sibling node. pub const fn is_value_odd(&self) -> bool { (self.value & 1) == 1 } /// Returns `true` if the depth is `0`. pub const fn is_root(&self) -> bool { self.depth == 0 } // STATE MUTATORS // -------------------------------------------------------------------------------------------- /// Traverses one level towards the root, decrementing the depth by `1`. pub fn move_up(&mut self) { self.depth = self.depth.saturating_sub(1); self.value >>= 1; } /// Traverses towards the root until the specified depth is reached. /// /// Assumes that the specified depth is smaller than the current depth. pub fn move_up_to(&mut self, depth: u8) { debug_assert!(depth < self.depth); let delta = self.depth.saturating_sub(depth); self.depth = self.depth.saturating_sub(delta); self.value >>= delta as u32; } } impl Display for NodeIndex { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "depth={}, value={}", self.depth, self.value) } } impl Serializable for NodeIndex { fn write_into(&self, target: &mut W) { target.write_u8(self.depth); target.write_u64(self.value); } } impl Deserializable for NodeIndex { fn read_from(source: &mut R) -> Result { let depth = source.read_u8()?; let value = source.read_u64()?; NodeIndex::new(depth, value) .map_err(|_| DeserializationError::InvalidValue("Invalid index".into())) } } #[cfg(test)] mod tests { use assert_matches::assert_matches; use proptest::prelude::*; use super::*; #[test] fn test_node_index_value_too_high() { assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 }); let err = NodeIndex::new(0, 1).unwrap_err(); assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 0, value: 1 }); assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 }); let err = NodeIndex::new(1, 2).unwrap_err(); assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 1, value: 2 }); assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 }); let err = NodeIndex::new(2, 4).unwrap_err(); assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 2, value: 4 }); assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 }); let err = NodeIndex::new(3, 8).unwrap_err(); assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 3, value: 8 }); } #[test] fn test_node_index_can_represent_depth_64() { assert!(NodeIndex::new(64, u64::MAX).is_ok()); } prop_compose! { fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex { // unwrap never panics because the range of depth is 0..u64::BITS let mut depth = value.ilog2() as u8; if value > (1 << depth) { // round up depth += 1; } NodeIndex::new(depth, value).unwrap() } } proptest! { #[test] fn arbitrary_index_wont_panic_on_move_up( mut index in node_index(), count in prop::num::u8::ANY, ) { for _ in 0..count { index.move_up(); } } } }