Browse Source

feature: add conditional support for serde

al-gkr-basic-workflow
Augusto F. Hack 1 year ago
parent
commit
8cf5e9fd2c
No known key found for this signature in database GPG Key ID: 3F3584B7FB1DFB76
18 changed files with 204 additions and 14 deletions
  1. +2
    -0
      Cargo.toml
  2. +20
    -1
      src/hash/blake/mod.rs
  3. +87
    -12
      src/hash/rpo/digest.rs
  4. +3
    -0
      src/merkle/delta.rs
  5. +1
    -0
      src/merkle/index.rs
  6. +1
    -0
      src/merkle/merkle_tree.rs
  7. +1
    -0
      src/merkle/mmr/accumulator.rs
  8. +1
    -0
      src/merkle/mmr/full.rs
  9. +1
    -0
      src/merkle/mmr/proof.rs
  10. +1
    -0
      src/merkle/node.rs
  11. +1
    -0
      src/merkle/partial_mt/mod.rs
  12. +1
    -0
      src/merkle/path.rs
  13. +2
    -0
      src/merkle/simple_smt/mod.rs
  14. +2
    -0
      src/merkle/store/mod.rs
  15. +1
    -0
      src/merkle/tiered_smt/mod.rs
  16. +1
    -0
      src/merkle/tiered_smt/nodes.rs
  17. +2
    -0
      src/merkle/tiered_smt/values.rs
  18. +76
    -1
      src/utils/mod.rs

+ 2
- 0
Cargo.toml

@ -27,12 +27,14 @@ harness = false
[features]
default = ["blake3/default", "std", "winter_crypto/default", "winter_math/default", "winter_utils/default"]
std = ["blake3/std", "winter_crypto/std", "winter_math/std", "winter_utils/std"]
serde = ["winter_math/serde", "dep:serde", "serde/alloc"]
[dependencies]
blake3 = { version = "1.4", default-features = false }
winter_crypto = { version = "0.6", package = "winter-crypto", default-features = false }
winter_math = { version = "0.6", package = "winter-math", default-features = false }
winter_utils = { version = "0.6", package = "winter-utils", default-features = false }
serde = { version = "1.0", features = [ "derive" ], optional = true, default-features = false }
[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }

+ 20
- 1
src/hash/blake/mod.rs

@ -1,5 +1,8 @@
use super::{Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
use crate::utils::{
bytes_to_hex_string, hex_to_bytes, string::String, ByteReader, ByteWriter, Deserializable,
DeserializationError, HexParseError, Serializable,
};
use core::{
mem::{size_of, transmute, transmute_copy},
ops::Deref,
@ -24,6 +27,8 @@ const DIGEST20_BYTES: usize = 20;
/// Note: `N` can't be greater than `32` because [`Digest::as_bytes`] currently supports only 32
/// bytes.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))]
pub struct Blake3Digest<const N: usize>([u8; N]);
impl<const N: usize> Default for Blake3Digest<N> {
@ -52,6 +57,20 @@ impl From<[u8; N]> for Blake3Digest {
}
}
impl<const N: usize> From<Blake3Digest<N>> for String {
fn from(value: Blake3Digest<N>) -> Self {
bytes_to_hex_string(value.as_bytes())
}
}
impl<const N: usize> TryFrom<&str> for Blake3Digest<N> {
type Error = HexParseError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
hex_to_bytes(value).map(|v| v.into())
}
}
impl<const N: usize> Serializable for Blake3Digest<N> {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write_bytes(&self.0);

+ 87
- 12
src/hash/rpo/digest.rs

@ -1,13 +1,19 @@
use super::{Digest, Felt, StarkField, DIGEST_SIZE, ZERO};
use crate::utils::{
string::String, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
bytes_to_hex_string, hex_to_bytes, string::String, ByteReader, ByteWriter, Deserializable,
DeserializationError, HexParseError, Serializable,
};
use core::{cmp::Ordering, fmt::Display, ops::Deref};
/// The number of bytes needed to encoded a digest
pub const DIGEST_BYTES: usize = 32;
// DIGEST TRAIT IMPLEMENTATIONS
// ================================================================================================
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))]
pub struct RpoDigest([Felt; DIGEST_SIZE]);
impl RpoDigest {
@ -19,7 +25,7 @@ impl RpoDigest {
self.as_ref()
}
pub fn as_bytes(&self) -> [u8; 32] {
pub fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
<Self as Digest>::as_bytes(self)
}
@ -32,8 +38,8 @@ impl RpoDigest {
}
impl Digest for RpoDigest {
fn as_bytes(&self) -> [u8; 32] {
let mut result = [0; 32];
fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
let mut result = [0; DIGEST_BYTES];
result[..8].copy_from_slice(&self.0[0].as_int().to_le_bytes());
result[8..16].copy_from_slice(&self.0[1].as_int().to_le_bytes());
@ -107,18 +113,73 @@ impl From for [u64; DIGEST_SIZE] {
}
}
impl From<&RpoDigest> for [u8; 32] {
impl From<&RpoDigest> for [u8; DIGEST_BYTES] {
fn from(value: &RpoDigest) -> Self {
value.as_bytes()
}
}
impl From<RpoDigest> for [u8; 32] {
impl From<RpoDigest> for [u8; DIGEST_BYTES] {
fn from(value: RpoDigest) -> Self {
value.as_bytes()
}
}
impl From<RpoDigest> for String {
fn from(value: RpoDigest) -> Self {
bytes_to_hex_string(value.as_bytes())
}
}
impl From<&RpoDigest> for String {
fn from(value: &RpoDigest) -> Self {
(*value).into()
}
}
impl TryFrom<[u8; DIGEST_BYTES]> for RpoDigest {
type Error = HexParseError;
fn try_from(value: [u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
// Note: the input length is known, the conversion from slice to array must succeed so the
// `unwrap`s below are safe
let a = u64::from_le_bytes(value[0..8].try_into().unwrap());
let b = u64::from_le_bytes(value[8..16].try_into().unwrap());
let c = u64::from_le_bytes(value[16..24].try_into().unwrap());
let d = u64::from_le_bytes(value[24..32].try_into().unwrap());
if [a, b, c, d].iter().any(|v| *v >= Felt::MODULUS) {
return Err(HexParseError::OutOfRange);
}
Ok(RpoDigest([Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)]))
}
}
impl TryFrom<&str> for RpoDigest {
type Error = HexParseError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
hex_to_bytes(value).and_then(|v| v.try_into())
}
}
impl TryFrom<String> for RpoDigest {
type Error = HexParseError;
fn try_from(value: String) -> Result<Self, Self::Error> {
value.as_str().try_into()
}
}
impl TryFrom<&String> for RpoDigest {
type Error = HexParseError;
fn try_from(value: &String) -> Result<Self, Self::Error> {
value.as_str().try_into()
}
}
impl Deref for RpoDigest {
type Target = [Felt; DIGEST_SIZE];
@ -158,9 +219,8 @@ impl PartialOrd for RpoDigest {
impl Display for RpoDigest {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
for byte in self.as_bytes() {
write!(f, "{byte:02x}")?;
}
let encoded: String = self.into();
write!(f, "{}", encoded)?;
Ok(())
}
}
@ -170,8 +230,7 @@ impl Display for RpoDigest {
#[cfg(test)]
mod tests {
use super::{Deserializable, Felt, RpoDigest, Serializable};
use super::{Deserializable, Felt, RpoDigest, Serializable, DIGEST_BYTES};
use crate::utils::SliceReader;
use rand_utils::rand_value;
@ -186,11 +245,27 @@ mod tests {
let mut bytes = vec![];
d1.write_into(&mut bytes);
assert_eq!(32, bytes.len());
assert_eq!(DIGEST_BYTES, bytes.len());
let mut reader = SliceReader::new(&bytes);
let d2 = RpoDigest::read_from(&mut reader).unwrap();
assert_eq!(d1, d2);
}
#[cfg(feature = "std")]
#[test]
fn digest_encoding() {
let digest = RpoDigest([
Felt::new(rand_value()),
Felt::new(rand_value()),
Felt::new(rand_value()),
Felt::new(rand_value()),
]);
let string: String = digest.into();
let round_trip: RpoDigest = string.try_into().expect("decoding failed");
assert_eq!(digest, round_trip);
}
}

+ 3
- 0
src/merkle/delta.rs

@ -13,6 +13,7 @@ use super::{empty_roots::EMPTY_WORD, Felt, SimpleSmt};
/// [RpoDigest] represents the root of the Merkle tree and [MerkleTreeDelta] represents the
/// differences between the initial and final Merkle tree states.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleStoreDelta(pub Vec<(RpoDigest, MerkleTreeDelta)>);
// MERKLE TREE DELTA
@ -26,6 +27,7 @@ pub struct MerkleStoreDelta(pub Vec<(RpoDigest, MerkleTreeDelta)>);
/// - updated_slots: index-value pairs of slots where values were set to non [ZERO; 4] values.
#[cfg(not(test))]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleTreeDelta {
depth: u8,
cleared_slots: Vec<u64>,
@ -107,6 +109,7 @@ pub fn merkle_tree_delta>(
// --------------------------------------------------------------------------------------------
#[cfg(test)]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleTreeDelta {
pub depth: u8,
pub cleared_slots: Vec<u64>,

+ 1
- 0
src/merkle/index.rs

@ -21,6 +21,7 @@ use core::fmt::Display;
/// The root is represented by the pair $(0, 0)$, its left child is $(1, 0)$ and its right child
/// $(1, 1)$.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct NodeIndex {
depth: u8,
value: u64,

+ 1
- 0
src/merkle/merkle_tree.rs

@ -8,6 +8,7 @@ use winter_math::log2;
/// A fully-balanced binary Merkle tree (i.e., a tree where the number of leaves is a power of two).
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleTree {
nodes: Vec<RpoDigest>,
}

+ 1
- 0
src/merkle/mmr/accumulator.rs

@ -4,6 +4,7 @@ use super::{
};
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MmrPeaks {
/// The number of leaves is used to differentiate accumulators that have the same number of
/// peaks. This happens because the number of peaks goes up-and-down as the structure is used

+ 1
- 0
src/merkle/mmr/full.rs

@ -29,6 +29,7 @@ use std::error::Error;
/// Since this is a full representation of the MMR, elements are never removed and the MMR will
/// grow roughly `O(2n)` in number of leaf elements.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Mmr {
/// Refer to the `forest` method documentation for details of the semantics of this value.
pub(super) forest: usize,

+ 1
- 0
src/merkle/mmr/proof.rs

@ -3,6 +3,7 @@ use super::super::MerklePath;
use super::full::{high_bitmask, leaf_to_corresponding_tree};
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MmrProof {
/// The state of the MMR when the MmrProof was created.
pub forest: usize,

+ 1
- 0
src/merkle/node.rs

@ -2,6 +2,7 @@ use crate::hash::rpo::RpoDigest;
/// Representation of a node with two children used for iterating over containers.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct InnerNodeInfo {
pub value: RpoDigest,
pub left: RpoDigest,

+ 1
- 0
src/merkle/partial_mt/mod.rs

@ -28,6 +28,7 @@ const EMPTY_DIGEST: RpoDigest = RpoDigest::new([ZERO; 4]);
///
/// The root of the tree is recomputed on each new leaf update.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct PartialMerkleTree {
max_depth: u8,
nodes: BTreeMap<NodeIndex, RpoDigest>,

+ 1
- 0
src/merkle/path.rs

@ -6,6 +6,7 @@ use core::ops::{Deref, DerefMut};
/// A merkle path container, composed of a sequence of nodes of a Merkle tree.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerklePath {
nodes: Vec<RpoDigest>,
}

+ 2
- 0
src/merkle/simple_smt/mod.rs

@ -13,6 +13,7 @@ mod tests;
///
/// The root of the tree is recomputed on each new leaf update.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct SimpleSmt {
depth: u8,
root: RpoDigest,
@ -265,6 +266,7 @@ impl SimpleSmt {
// ================================================================================================
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
struct BranchNode {
left: RpoDigest,
right: RpoDigest,

+ 2
- 0
src/merkle/store/mod.rs

@ -19,6 +19,7 @@ pub type DefaultMerkleStore = MerkleStore>;
pub type RecordingMerkleStore = MerkleStore<RecordingMap<RpoDigest, StoreNode>>;
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct StoreNode {
left: RpoDigest,
right: RpoDigest,
@ -87,6 +88,7 @@ pub struct StoreNode {
/// assert_eq!(store.num_internal_nodes() - 255, 10);
/// ```
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleStore<T: KvMap<RpoDigest, StoreNode> = BTreeMap<RpoDigest, StoreNode>> {
nodes: T,
}

+ 1
- 0
src/merkle/tiered_smt/mod.rs

@ -43,6 +43,7 @@ mod tests;
/// - Leaf node at depths 16, 32, or 64: hash(key, value, domain=depth).
/// - Leaf node at depth 64: hash([key_0, value_0, ..., key_n, value_n], domain=64).
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct TieredSmt {
root: RpoDigest,
nodes: NodeStore,

+ 1
- 0
src/merkle/tiered_smt/nodes.rs

@ -24,6 +24,7 @@ const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;
/// represent leaf nodes in a Tiered Sparse Merkle tree. In the current implementation, [BTreeSet]s
/// are used to determine the position of the leaves in the tree.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct NodeStore {
nodes: BTreeMap<NodeIndex, RpoDigest>,
upper_leaves: BTreeSet<NodeIndex>,

+ 2
- 0
src/merkle/tiered_smt/values.rs

@ -26,6 +26,7 @@ const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;
/// The store supports lookup by the full key (i.e. [RpoDigest]) as well as by the 64-bit key
/// prefix.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct ValueStore {
values: BTreeMap<u64, StoreEntry>,
}
@ -173,6 +174,7 @@ impl ValueStore {
/// An entry can contain either a single key-value pair or a vector of key-value pairs sorted by
/// key.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum StoreEntry {
Single((RpoDigest, Word)),
List(Vec<(RpoDigest, Word)>),

+ 76
- 1
src/utils/mod.rs

@ -1,5 +1,5 @@
use super::{utils::string::String, Word};
use core::fmt::{self, Write};
use core::fmt::{self, Display, Write};
#[cfg(not(feature = "std"))]
pub use alloc::{format, vec};
@ -36,3 +36,78 @@ pub fn word_to_hex(w: &Word) -> Result {
Ok(s)
}
/// Renders an array of bytes as hex into a String.
pub fn bytes_to_hex_string<const N: usize>(data: [u8; N]) -> String {
let mut s = String::with_capacity(N + 2);
s.push_str("0x");
for byte in data.iter() {
write!(s, "{byte:02x}").expect("formatting hex failed");
}
s
}
#[derive(Debug)]
pub enum HexParseError {
InvalidLength { expected: usize, got: usize },
MissingPrefix,
InvalidChar,
OutOfRange,
}
impl Display for HexParseError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
HexParseError::InvalidLength { expected, got } => {
write!(f, "Hex encoded RpoDigest must have length 66, including the 0x prefix. expected {expected} got {got}")
}
HexParseError::MissingPrefix => {
write!(f, "Hex encoded RpoDigest must start with 0x prefix")
}
HexParseError::InvalidChar => {
write!(f, "Hex encoded RpoDigest must contain characters [a-zA-Z0-9]")
}
HexParseError::OutOfRange => {
write!(f, "Hex encoded values of an RpoDigest must be inside the field modulus")
}
}
}
}
#[cfg(feature = "std")]
impl std::error::Error for HexParseError {}
/// Parses a hex string into an array of bytes of known size.
pub fn hex_to_bytes<const N: usize>(value: &str) -> Result<[u8; N], HexParseError> {
let expected: usize = (N * 2) + 2;
if value.len() != expected {
return Err(HexParseError::InvalidLength {
expected,
got: value.len(),
});
}
if !value.starts_with("0x") {
return Err(HexParseError::MissingPrefix);
}
let mut data = value.bytes().skip(2).map(|v| match v {
b'0'..=b'9' => Ok(v - b'0'),
b'a'..=b'f' => Ok(v - b'a' + 10),
b'A'..=b'F' => Ok(v - b'A' + 10),
_ => Err(HexParseError::InvalidChar),
});
let mut decoded = [0u8; N];
#[allow(clippy::needless_range_loop)]
for pos in 0..N {
// These `unwrap` calls are okay because the length was checked above
let high: u8 = data.next().unwrap()?;
let low: u8 = data.next().unwrap()?;
decoded[pos] = (high << 4) + low;
}
Ok(decoded)
}

Loading…
Cancel
Save