mirror of
https://github.com/arnaucube/miden-crypto.git
synced 2026-01-10 16:11:30 +01:00
Compare commits
40 Commits
v0.10.1
...
al-gkr-bas
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0f06fa30a9 | ||
|
|
0b074a795d | ||
|
|
862ccf54dd | ||
|
|
88bcdfd576 | ||
|
|
290894f497 | ||
|
|
4aac00884c | ||
|
|
2ef6f79656 | ||
|
|
5142e2fd31 | ||
|
|
9fb41337ec | ||
|
|
0296e05ccd | ||
|
|
499f97046d | ||
|
|
600feafe53 | ||
|
|
9d854f1fcb | ||
|
|
af76cb10d0 | ||
|
|
4758e0672f | ||
|
|
8bb080a91d | ||
|
|
e5f3b28645 | ||
|
|
29e0d07129 | ||
|
|
81a94ecbe7 | ||
|
|
223fbf887d | ||
|
|
9e77a7c9b7 | ||
|
|
894e20fe0c | ||
|
|
7ec7b06574 | ||
|
|
2499a8a2dd | ||
|
|
800994c69b | ||
|
|
26560605bf | ||
|
|
672340d0c2 | ||
|
|
8083b02aef | ||
|
|
ecb8719d45 | ||
|
|
4144f98560 | ||
|
|
c726050957 | ||
|
|
9239340888 | ||
|
|
97ee9298a4 | ||
|
|
bfae06e128 | ||
|
|
b4e2d63c10 | ||
|
|
9679329746 | ||
|
|
2bbea37dbe | ||
|
|
83000940da | ||
|
|
f44175e7a9 | ||
|
|
4cf8eebff5 |
20
.editorconfig
Normal file
20
.editorconfig
Normal file
@@ -0,0 +1,20 @@
|
||||
# Documentation available at editorconfig.org
|
||||
|
||||
root=true
|
||||
|
||||
[*]
|
||||
ident_style = space
|
||||
ident_size = 4
|
||||
end_of_line = lf
|
||||
charset = utf-8
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
|
||||
[*.rs]
|
||||
max_line_length = 100
|
||||
|
||||
[*.md]
|
||||
trim_trailing_whitespace = false
|
||||
|
||||
[*.yml]
|
||||
ident_size = 2
|
||||
@@ -1,3 +1,11 @@
|
||||
## 0.8.0 (TBD)
|
||||
|
||||
* Implemented the `PartialMmr` data structure (#195).
|
||||
* Updated Winterfell dependency to v0.7 (#200)
|
||||
* Implemented RPX hash function (#201).
|
||||
* Added `FeltRng` and `RpoRandomCoin` (#237).
|
||||
* Added `inner_nodes()` method to `PartialMmr` (#238).
|
||||
|
||||
## 0.7.1 (2023-10-10)
|
||||
|
||||
* Fixed RPO Falcon signature build on Windows.
|
||||
@@ -12,7 +20,6 @@
|
||||
* Implemented benchmarking for `TieredSmt` (#182).
|
||||
* Added more leaf traversal methods for `MerkleStore` (#185).
|
||||
* Added SVE acceleration for RPO hash function (#189).
|
||||
* Implemented the `PartialMmr` datastructure (#195).
|
||||
|
||||
## 0.6.0 (2023-06-25)
|
||||
|
||||
|
||||
17
Cargo.toml
17
Cargo.toml
@@ -1,12 +1,12 @@
|
||||
[package]
|
||||
name = "miden-crypto"
|
||||
version = "0.7.1"
|
||||
version = "0.8.0"
|
||||
description = "Miden Cryptographic primitives"
|
||||
authors = ["miden contributors"]
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
repository = "https://github.com/0xPolygonMiden/crypto"
|
||||
documentation = "https://docs.rs/miden-crypto/0.7.1"
|
||||
documentation = "https://docs.rs/miden-crypto/0.8.0"
|
||||
categories = ["cryptography", "no-std"]
|
||||
keywords = ["miden", "crypto", "hash", "merkle"]
|
||||
edition = "2021"
|
||||
@@ -42,16 +42,19 @@ sve = ["std"]
|
||||
blake3 = { version = "1.5", default-features = false }
|
||||
clap = { version = "4.4", features = ["derive"], optional = true }
|
||||
libc = { version = "0.2", default-features = false, optional = true }
|
||||
rand_utils = { version = "0.6", package = "winter-rand-utils", optional = true }
|
||||
rand_utils = { version = "0.7", package = "winter-rand-utils", optional = true }
|
||||
serde = { version = "1.0", features = [ "derive" ], default-features = false, optional = true }
|
||||
winter_crypto = { version = "0.6", package = "winter-crypto", default-features = false }
|
||||
winter_math = { version = "0.6", package = "winter-math", default-features = false }
|
||||
winter_utils = { version = "0.6", package = "winter-utils", default-features = false }
|
||||
winter_crypto = { version = "0.7", package = "winter-crypto", default-features = false }
|
||||
winter_math = { version = "0.7", package = "winter-math", default-features = false }
|
||||
winter_utils = { version = "0.7", package = "winter-utils", default-features = false }
|
||||
rayon = "1.8.0"
|
||||
rand = "0.8.4"
|
||||
rand_core = { version = "0.5", default-features = false }
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = { version = "0.5", features = ["html_reports"] }
|
||||
proptest = "1.3"
|
||||
rand_utils = { version = "0.6", package = "winter-rand-utils" }
|
||||
rand_utils = { version = "0.7", package = "winter-rand-utils" }
|
||||
|
||||
[build-dependencies]
|
||||
cc = { version = "1.0", features = ["parallel"], optional = true }
|
||||
|
||||
10
README.md
10
README.md
@@ -6,6 +6,7 @@ This crate contains cryptographic primitives used in Polygon Miden.
|
||||
|
||||
* [BLAKE3](https://github.com/BLAKE3-team/BLAKE3) hash function with 256-bit, 192-bit, or 160-bit output. The 192-bit and 160-bit outputs are obtained by truncating the 256-bit output of the standard BLAKE3.
|
||||
* [RPO](https://eprint.iacr.org/2022/1577) hash function with 256-bit output. This hash function is an algebraic hash function suitable for recursive STARKs.
|
||||
* [RPX](https://eprint.iacr.org/2023/1045) hash function with 256-bit output. Similar to RPO, this hash function is suitable for recursive STARKs but it is about 2x faster as compared to RPO.
|
||||
|
||||
For performance benchmarks of these hash functions and their comparison to other popular hash functions please see [here](./benches/).
|
||||
|
||||
@@ -16,18 +17,25 @@ For performance benchmarks of these hash functions and their comparison to other
|
||||
* `MerkleTree`: a regular fully-balanced binary Merkle tree. The depth of this tree can be at most 64.
|
||||
* `Mmr`: a Merkle mountain range structure designed to function as an append-only log.
|
||||
* `PartialMerkleTree`: a partial view of a Merkle tree where some sub-trees may not be known. This is similar to a collection of Merkle paths all resolving to the same root. The length of the paths can be at most 64.
|
||||
* `PartialMmr`: a partial view of a Merkle mountain range structure.
|
||||
* `SimpleSmt`: a Sparse Merkle Tree (with no compaction), mapping 64-bit keys to 4-element values.
|
||||
* `TieredSmt`: a Sparse Merkle tree (with compaction), mapping 4-element keys to 4-element values.
|
||||
|
||||
The module also contains additional supporting components such as `NodeIndex`, `MerklePath`, and `MerkleError` to assist with tree indexation, opening proofs, and reporting inconsistent arguments/state.
|
||||
|
||||
## Signatures
|
||||
[DAS module](./src/dsa) provides a set of digital signature schemes supported by default in the Miden VM. Currently, these schemes are:
|
||||
[DSA module](./src/dsa) provides a set of digital signature schemes supported by default in the Miden VM. Currently, these schemes are:
|
||||
|
||||
* `RPO Falcon512`: a variant of the [Falcon](https://falcon-sign.info/) signature scheme. This variant differs from the standard in that instead of using SHAKE256 hash function in the *hash-to-point* algorithm we use RPO256. This makes the signature more efficient to verify in Miden VM.
|
||||
|
||||
For the above signatures, key generation and signing is available only in the `std` context (see [crate features](#crate-features) below), while signature verification is available in `no_std` context as well.
|
||||
|
||||
## Pseudo-Random Element Generator
|
||||
[Pseudo random element generator module](./src/rand/) provides a set of traits and data structures that facilitate generating pseudo-random elements in the context of Miden VM and Miden rollup. The module currently includes:
|
||||
|
||||
* `FeltRng`: a trait for generating random field elements and random 4 field elements.
|
||||
* `RpoRandomCoin`: a struct implementing `FeltRng` as well as the [`RandomCoin`](https://github.com/facebook/winterfell/blob/main/crypto/src/random/mod.rs) trait.
|
||||
|
||||
## Crate features
|
||||
This crate can be compiled with the following features:
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ In the Miden VM, we make use of different hash functions. Some of these are "tra
|
||||
* **Poseidon** as specified [here](https://eprint.iacr.org/2019/458.pdf) and implemented [here](https://github.com/mir-protocol/plonky2/blob/806b88d7d6e69a30dc0b4775f7ba275c45e8b63b/plonky2/src/hash/poseidon_goldilocks.rs) (but in pure Rust, without vectorized instructions).
|
||||
* **Rescue Prime (RP)** as specified [here](https://eprint.iacr.org/2020/1143) and implemented [here](https://github.com/novifinancial/winterfell/blob/46dce1adf0/crypto/src/hash/rescue/rp64_256/mod.rs).
|
||||
* **Rescue Prime Optimized (RPO)** as specified [here](https://eprint.iacr.org/2022/1577) and implemented in this crate.
|
||||
* **Rescue Prime Extended (RPX)** a variant of the [xHash](https://eprint.iacr.org/2023/1045) hash function as implemented in this crate.
|
||||
|
||||
## Comparison and Instructions
|
||||
|
||||
@@ -15,28 +16,28 @@ The second scenario is that of sequential hashing where we take a sequence of le
|
||||
|
||||
#### Scenario 1: 2-to-1 hashing `h(a,b)`
|
||||
|
||||
| Function | BLAKE3 | SHA3 | Poseidon | Rp64_256 | RPO_256 |
|
||||
| ------------------- | ------ | --------| --------- | --------- | ------- |
|
||||
| Apple M1 Pro | 80 ns | 245 ns | 1.5 us | 9.1 us | 5.4 us |
|
||||
| Apple M2 | 76 ns | 233 ns | 1.3 us | 7.9 us | 5.0 us |
|
||||
| Amazon Graviton 3 | 108 ns | | | | 5.3 us |
|
||||
| AMD Ryzen 9 5950X | 64 ns | 273 ns | 1.2 us | 9.1 us | 5.5 us |
|
||||
| Intel Core i5-8279U | 80 ns | | | | 8.7 us |
|
||||
| Intel Xeon 8375C | 67 ns | | | | 8.2 us |
|
||||
| Function | BLAKE3 | SHA3 | Poseidon | Rp64_256 | RPO_256 | RPX_256 |
|
||||
| ------------------- | ------ | ------- | --------- | --------- | ------- | ------- |
|
||||
| Apple M1 Pro | 76 ns | 245 ns | 1.5 µs | 9.1 µs | 5.2 µs | 2.7 µs |
|
||||
| Apple M2 Max | 71 ns | 233 ns | 1.3 µs | 7.9 µs | 4.6 µs | 2.4 µs |
|
||||
| Amazon Graviton 3 | 108 ns | | | | 5.3 µs | 3.1 µs |
|
||||
| AMD Ryzen 9 5950X | 64 ns | 273 ns | 1.2 µs | 9.1 µs | 5.5 µs | |
|
||||
| Intel Core i5-8279U | 68 ns | 536 ns | 2.0 µs | 13.6 µs | 8.5 µs | 4.4 µs |
|
||||
| Intel Xeon 8375C | 67 ns | | | | 8.2 µs | |
|
||||
|
||||
#### Scenario 2: Sequential hashing of 100 elements `h([a_0,...,a_99])`
|
||||
|
||||
| Function | BLAKE3 | SHA3 | Poseidon | Rp64_256 | RPO_256 |
|
||||
| ------------------- | -------| ------- | --------- | --------- | ------- |
|
||||
| Apple M1 Pro | 1.0 us | 1.5 us | 19.4 us | 118 us | 70 us |
|
||||
| Apple M2 | 1.0 us | 1.5 us | 17.4 us | 103 us | 65 us |
|
||||
| Amazon Graviton 3 | 1.4 us | | | | 69 us |
|
||||
| AMD Ryzen 9 5950X | 0.8 us | 1.7 us | 15.7 us | 120 us | 72 us |
|
||||
| Intel Core i5-8279U | 1.0 us | | | | 116 us |
|
||||
| Intel Xeon 8375C | 0.8 ns | | | | 110 us |
|
||||
| Function | BLAKE3 | SHA3 | Poseidon | Rp64_256 | RPO_256 | RPX_256 |
|
||||
| ------------------- | -------| ------- | --------- | --------- | ------- | ------- |
|
||||
| Apple M1 Pro | 1.0 µs | 1.5 µs | 19.4 µs | 118 µs | 69 µs | 35 µs |
|
||||
| Apple M2 Max | 0.9 µs | 1.5 µs | 17.4 µs | 103 µs | 60 µs | 31 µs |
|
||||
| Amazon Graviton 3 | 1.4 µs | | | | 69 µs | 41 µs |
|
||||
| AMD Ryzen 9 5950X | 0.8 µs | 1.7 µs | 15.7 µs | 120 µs | 72 µs | |
|
||||
| Intel Core i5-8279U | 0.9 µs | | | | 107 µs | 56 µs |
|
||||
| Intel Xeon 8375C | 0.8 µs | | | | 110 µs | |
|
||||
|
||||
Notes:
|
||||
- On Graviton 3, RPO256 is run with SVE acceleration enabled.
|
||||
- On Graviton 3, RPO256 and RPX256 are run with SVE acceleration enabled.
|
||||
|
||||
### Instructions
|
||||
Before you can run the benchmarks, you'll need to make sure you have Rust [installed](https://www.rust-lang.org/tools/install). After that, to run the benchmarks for RPO and BLAKE3, clone the current repository, and from the root directory of the repo run the following:
|
||||
|
||||
@@ -3,6 +3,7 @@ use miden_crypto::{
|
||||
hash::{
|
||||
blake::Blake3_256,
|
||||
rpo::{Rpo256, RpoDigest},
|
||||
rpx::{Rpx256, RpxDigest},
|
||||
},
|
||||
Felt,
|
||||
};
|
||||
@@ -57,6 +58,54 @@ fn rpo256_sequential(c: &mut Criterion) {
|
||||
});
|
||||
}
|
||||
|
||||
fn rpx256_2to1(c: &mut Criterion) {
|
||||
let v: [RpxDigest; 2] = [Rpx256::hash(&[1_u8]), Rpx256::hash(&[2_u8])];
|
||||
c.bench_function("RPX256 2-to-1 hashing (cached)", |bench| {
|
||||
bench.iter(|| Rpx256::merge(black_box(&v)))
|
||||
});
|
||||
|
||||
c.bench_function("RPX256 2-to-1 hashing (random)", |bench| {
|
||||
bench.iter_batched(
|
||||
|| {
|
||||
[
|
||||
Rpx256::hash(&rand_value::<u64>().to_le_bytes()),
|
||||
Rpx256::hash(&rand_value::<u64>().to_le_bytes()),
|
||||
]
|
||||
},
|
||||
|state| Rpx256::merge(&state),
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
fn rpx256_sequential(c: &mut Criterion) {
|
||||
let v: [Felt; 100] = (0..100)
|
||||
.into_iter()
|
||||
.map(Felt::new)
|
||||
.collect::<Vec<Felt>>()
|
||||
.try_into()
|
||||
.expect("should not fail");
|
||||
c.bench_function("RPX256 sequential hashing (cached)", |bench| {
|
||||
bench.iter(|| Rpx256::hash_elements(black_box(&v)))
|
||||
});
|
||||
|
||||
c.bench_function("RPX256 sequential hashing (random)", |bench| {
|
||||
bench.iter_batched(
|
||||
|| {
|
||||
let v: [Felt; 100] = (0..100)
|
||||
.into_iter()
|
||||
.map(|_| Felt::new(rand_value()))
|
||||
.collect::<Vec<Felt>>()
|
||||
.try_into()
|
||||
.expect("should not fail");
|
||||
v
|
||||
},
|
||||
|state| Rpx256::hash_elements(&state),
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
fn blake3_2to1(c: &mut Criterion) {
|
||||
let v: [<Blake3_256 as Hasher>::Digest; 2] =
|
||||
[Blake3_256::hash(&[1_u8]), Blake3_256::hash(&[2_u8])];
|
||||
@@ -106,5 +155,13 @@ fn blake3_sequential(c: &mut Criterion) {
|
||||
});
|
||||
}
|
||||
|
||||
criterion_group!(hash_group, rpo256_2to1, rpo256_sequential, blake3_2to1, blake3_sequential);
|
||||
criterion_group!(
|
||||
hash_group,
|
||||
rpx256_2to1,
|
||||
rpx256_sequential,
|
||||
rpo256_2to1,
|
||||
rpo256_sequential,
|
||||
blake3_2to1,
|
||||
blake3_sequential
|
||||
);
|
||||
criterion_main!(hash_group);
|
||||
|
||||
@@ -147,7 +147,12 @@ impl KeyPair {
|
||||
};
|
||||
|
||||
if res == 0 {
|
||||
Ok(Signature { sig, pk: self.public_key })
|
||||
Ok(Signature {
|
||||
sig,
|
||||
pk: self.public_key,
|
||||
pk_polynomial: Default::default(),
|
||||
sig_polynomial: Default::default(),
|
||||
})
|
||||
} else {
|
||||
Err(FalconError::SigGenerationFailed)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ use super::{
|
||||
SIG_L2_BOUND, ZERO,
|
||||
};
|
||||
use crate::utils::string::ToString;
|
||||
use core::cell::OnceCell;
|
||||
|
||||
// FALCON SIGNATURE
|
||||
// ================================================================================================
|
||||
@@ -43,6 +44,10 @@ use crate::utils::string::ToString;
|
||||
pub struct Signature {
|
||||
pub(super) pk: PublicKeyBytes,
|
||||
pub(super) sig: SignatureBytes,
|
||||
|
||||
// Cached polynomial decoding for public key and signatures
|
||||
pub(super) pk_polynomial: OnceCell<Polynomial>,
|
||||
pub(super) sig_polynomial: OnceCell<Polynomial>,
|
||||
}
|
||||
|
||||
impl Signature {
|
||||
@@ -51,10 +56,11 @@ impl Signature {
|
||||
|
||||
/// Returns the public key polynomial h.
|
||||
pub fn pub_key_poly(&self) -> Polynomial {
|
||||
// TODO: memoize
|
||||
// we assume that the signature was constructed with a valid public key, and thus
|
||||
// expect() is OK here.
|
||||
Polynomial::from_pub_key(&self.pk).expect("invalid public key")
|
||||
*self.pk_polynomial.get_or_init(|| {
|
||||
// we assume that the signature was constructed with a valid public key, and thus
|
||||
// expect() is OK here.
|
||||
Polynomial::from_pub_key(&self.pk).expect("invalid public key")
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the nonce component of the signature represented as field elements.
|
||||
@@ -70,10 +76,11 @@ impl Signature {
|
||||
|
||||
// Returns the polynomial representation of the signature in Z_p[x]/(phi).
|
||||
pub fn sig_poly(&self) -> Polynomial {
|
||||
// TODO: memoize
|
||||
// we assume that the signature was constructed with a valid signature, and thus
|
||||
// expect() is OK here.
|
||||
Polynomial::from_signature(&self.sig).expect("invalid signature")
|
||||
*self.sig_polynomial.get_or_init(|| {
|
||||
// we assume that the signature was constructed with a valid signature, and thus
|
||||
// expect() is OK here.
|
||||
Polynomial::from_signature(&self.sig).expect("invalid signature")
|
||||
})
|
||||
}
|
||||
|
||||
// HASH-TO-POINT
|
||||
@@ -123,12 +130,14 @@ impl Deserializable for Signature {
|
||||
let sig: SignatureBytes = source.read_array()?;
|
||||
|
||||
// make sure public key and signature can be decoded correctly
|
||||
Polynomial::from_pub_key(&pk)
|
||||
.map_err(|err| DeserializationError::InvalidValue(err.to_string()))?;
|
||||
Polynomial::from_signature(&sig[41..])
|
||||
.map_err(|err| DeserializationError::InvalidValue(err.to_string()))?;
|
||||
let pk_polynomial = Polynomial::from_pub_key(&pk)
|
||||
.map_err(|err| DeserializationError::InvalidValue(err.to_string()))?
|
||||
.into();
|
||||
let sig_polynomial = Polynomial::from_signature(&sig[41..])
|
||||
.map_err(|err| DeserializationError::InvalidValue(err.to_string()))?
|
||||
.into();
|
||||
|
||||
Ok(Self { pk, sig })
|
||||
Ok(Self { pk, sig, pk_polynomial, sig_polynomial })
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
977
src/gkr/circuit/mod.rs
Normal file
977
src/gkr/circuit/mod.rs
Normal file
@@ -0,0 +1,977 @@
|
||||
use alloc::sync::Arc;
|
||||
use winter_crypto::{ElementHasher, RandomCoin};
|
||||
use winter_math::fields::f64::BaseElement;
|
||||
use winter_math::FieldElement;
|
||||
|
||||
use crate::gkr::multivariate::{
|
||||
ComposedMultiLinearsOracle, EqPolynomial, GkrCompositionVanilla, MultiLinearOracle,
|
||||
};
|
||||
use crate::gkr::sumcheck::{sum_check_verify, Claim};
|
||||
|
||||
use super::multivariate::{
|
||||
gen_plain_gkr_oracle, gkr_composition_from_composition_polys, ComposedMultiLinears,
|
||||
CompositionPolynomial, MultiLinear,
|
||||
};
|
||||
use super::sumcheck::{
|
||||
sum_check_prove, sum_check_verify_and_reduce, FinalEvaluationClaim,
|
||||
PartialProof as SumcheckInstanceProof, RoundProof as SumCheckRoundProof, Witness,
|
||||
};
|
||||
|
||||
/// Layered circuit for computing a sum of fractions.
|
||||
///
|
||||
/// The circuit computes a sum of fractions based on the formula a / c + b / d = (a * d + b * c) / (c * d)
|
||||
/// which defines a "gate" ((a, b), (c, d)) --> (a * d + b * c, c * d) upon which the `FractionalSumCircuit`
|
||||
/// is built.
|
||||
/// TODO: Swap 1 and 0
|
||||
#[derive(Debug)]
|
||||
pub struct FractionalSumCircuit<E: FieldElement> {
|
||||
p_1_vec: Vec<MultiLinear<E>>,
|
||||
p_0_vec: Vec<MultiLinear<E>>,
|
||||
q_1_vec: Vec<MultiLinear<E>>,
|
||||
q_0_vec: Vec<MultiLinear<E>>,
|
||||
}
|
||||
|
||||
impl<E: FieldElement> FractionalSumCircuit<E> {
|
||||
/// Computes The values of the gates outputs for each of the layers of the fractional sum circuit.
|
||||
pub fn new_(num_den: &Vec<MultiLinear<E>>) -> Self {
|
||||
let mut p_1_vec: Vec<MultiLinear<E>> = Vec::new();
|
||||
let mut p_0_vec: Vec<MultiLinear<E>> = Vec::new();
|
||||
let mut q_1_vec: Vec<MultiLinear<E>> = Vec::new();
|
||||
let mut q_0_vec: Vec<MultiLinear<E>> = Vec::new();
|
||||
|
||||
let num_layers = num_den[0].len().ilog2() as usize;
|
||||
|
||||
p_1_vec.push(num_den[0].to_owned());
|
||||
p_0_vec.push(num_den[1].to_owned());
|
||||
q_1_vec.push(num_den[2].to_owned());
|
||||
q_0_vec.push(num_den[3].to_owned());
|
||||
|
||||
for i in 0..num_layers {
|
||||
let (output_p_1, output_p_0, output_q_1, output_q_0) =
|
||||
FractionalSumCircuit::compute_layer(
|
||||
&p_1_vec[i],
|
||||
&p_0_vec[i],
|
||||
&q_1_vec[i],
|
||||
&q_0_vec[i],
|
||||
);
|
||||
p_1_vec.push(output_p_1);
|
||||
p_0_vec.push(output_p_0);
|
||||
q_1_vec.push(output_q_1);
|
||||
q_0_vec.push(output_q_0);
|
||||
}
|
||||
|
||||
FractionalSumCircuit { p_1_vec, p_0_vec, q_1_vec, q_0_vec }
|
||||
}
|
||||
|
||||
/// Compute the output values of the layer given a set of input values
|
||||
fn compute_layer(
|
||||
inp_p_1: &MultiLinear<E>,
|
||||
inp_p_0: &MultiLinear<E>,
|
||||
inp_q_1: &MultiLinear<E>,
|
||||
inp_q_0: &MultiLinear<E>,
|
||||
) -> (MultiLinear<E>, MultiLinear<E>, MultiLinear<E>, MultiLinear<E>) {
|
||||
let len = inp_q_1.len();
|
||||
let outp_p_1 = (0..len / 2)
|
||||
.map(|i| inp_p_1[i] * inp_q_0[i] + inp_p_0[i] * inp_q_1[i])
|
||||
.collect::<Vec<E>>();
|
||||
let outp_p_0 = (len / 2..len)
|
||||
.map(|i| inp_p_1[i] * inp_q_0[i] + inp_p_0[i] * inp_q_1[i])
|
||||
.collect::<Vec<E>>();
|
||||
let outp_q_1 = (0..len / 2).map(|i| inp_q_1[i] * inp_q_0[i]).collect::<Vec<E>>();
|
||||
let outp_q_0 = (len / 2..len).map(|i| inp_q_1[i] * inp_q_0[i]).collect::<Vec<E>>();
|
||||
|
||||
(
|
||||
MultiLinear::new(outp_p_1),
|
||||
MultiLinear::new(outp_p_0),
|
||||
MultiLinear::new(outp_q_1),
|
||||
MultiLinear::new(outp_q_0),
|
||||
)
|
||||
}
|
||||
|
||||
/// Computes The values of the gates outputs for each of the layers of the fractional sum circuit.
|
||||
pub fn new(poly: &MultiLinear<E>) -> Self {
|
||||
let mut p_1_vec: Vec<MultiLinear<E>> = Vec::new();
|
||||
let mut p_0_vec: Vec<MultiLinear<E>> = Vec::new();
|
||||
let mut q_1_vec: Vec<MultiLinear<E>> = Vec::new();
|
||||
let mut q_0_vec: Vec<MultiLinear<E>> = Vec::new();
|
||||
|
||||
let num_layers = poly.len().ilog2() as usize - 1;
|
||||
let (output_p, output_q) = poly.split(poly.len() / 2);
|
||||
let (output_p_1, output_p_0) = output_p.split(output_p.len() / 2);
|
||||
let (output_q_1, output_q_0) = output_q.split(output_q.len() / 2);
|
||||
|
||||
p_1_vec.push(output_p_1);
|
||||
p_0_vec.push(output_p_0);
|
||||
q_1_vec.push(output_q_1);
|
||||
q_0_vec.push(output_q_0);
|
||||
|
||||
for i in 0..num_layers - 1 {
|
||||
let (output_p_1, output_p_0, output_q_1, output_q_0) =
|
||||
FractionalSumCircuit::compute_layer(
|
||||
&p_1_vec[i],
|
||||
&p_0_vec[i],
|
||||
&q_1_vec[i],
|
||||
&q_0_vec[i],
|
||||
);
|
||||
p_1_vec.push(output_p_1);
|
||||
p_0_vec.push(output_p_0);
|
||||
q_1_vec.push(output_q_1);
|
||||
q_0_vec.push(output_q_0);
|
||||
}
|
||||
|
||||
FractionalSumCircuit { p_1_vec, p_0_vec, q_1_vec, q_0_vec }
|
||||
}
|
||||
|
||||
/// Given a value r, computes the evaluation of the last layer at r when interpreted as (two)
|
||||
/// multilinear polynomials.
|
||||
pub fn evaluate(&self, r: E) -> (E, E) {
|
||||
let len = self.p_1_vec.len();
|
||||
assert_eq!(self.p_1_vec[len - 1].num_variables(), 0);
|
||||
assert_eq!(self.p_0_vec[len - 1].num_variables(), 0);
|
||||
assert_eq!(self.q_1_vec[len - 1].num_variables(), 0);
|
||||
assert_eq!(self.q_0_vec[len - 1].num_variables(), 0);
|
||||
|
||||
let mut p_1 = self.p_1_vec[len - 1].clone();
|
||||
p_1.extend(&self.p_0_vec[len - 1]);
|
||||
let mut q_1 = self.q_1_vec[len - 1].clone();
|
||||
q_1.extend(&self.q_0_vec[len - 1]);
|
||||
|
||||
(p_1.evaluate(&[r]), q_1.evaluate(&[r]))
|
||||
}
|
||||
}
|
||||
|
||||
/// A proof for reducing a claim on the correctness of the output of a layer to that of:
|
||||
///
|
||||
/// 1. Correctness of a sumcheck proof on the claimed output.
|
||||
/// 2. Correctness of the evaluation of the input (to the said layer) at a random point when
|
||||
/// interpreted as multilinear polynomial.
|
||||
///
|
||||
/// The verifier will then have to work backward and:
|
||||
///
|
||||
/// 1. Verify that the sumcheck proof is valid.
|
||||
/// 2. Recurse on the (claimed evaluations) using the same approach as above.
|
||||
///
|
||||
/// Note that the following struct batches proofs for many circuits of the same type that
|
||||
/// are independent i.e., parallel.
|
||||
#[derive(Debug)]
|
||||
pub struct LayerProof<E: FieldElement> {
|
||||
pub proof: SumcheckInstanceProof<E>,
|
||||
pub claims_sum_p1: E,
|
||||
pub claims_sum_p0: E,
|
||||
pub claims_sum_q1: E,
|
||||
pub claims_sum_q0: E,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl<E: FieldElement<BaseField = BaseElement> + 'static> LayerProof<E> {
|
||||
/// Checks the validity of a `LayerProof`.
|
||||
///
|
||||
/// It first reduces the 2 claims to 1 claim using randomness and then checks that the sumcheck
|
||||
/// protocol was correctly executed.
|
||||
///
|
||||
/// The method outputs:
|
||||
///
|
||||
/// 1. A vector containing the randomness sent by the verifier throughout the course of the
|
||||
/// sum-check protocol.
|
||||
/// 2. The (claimed) evaluation of the inner polynomial (i.e., the one being summed) at the this random vector.
|
||||
/// 3. The random value used in the 2-to-1 reduction of the 2 sumchecks.
|
||||
pub fn verify_sum_check_before_last<
|
||||
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||
H: ElementHasher<BaseField = BaseElement>,
|
||||
>(
|
||||
&self,
|
||||
claim: (E, E),
|
||||
num_rounds: usize,
|
||||
transcript: &mut C,
|
||||
) -> ((E, Vec<E>), E) {
|
||||
// Absorb the claims
|
||||
let data = vec![claim.0, claim.1];
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
// Squeeze challenge to reduce two sumchecks to one
|
||||
let r_sum_check: E = transcript.draw().unwrap();
|
||||
|
||||
// Run the sumcheck protocol
|
||||
|
||||
// Given r_sum_check and claim, we create a Claim with the GKR composer and then call the generic sum-check verifier
|
||||
let reduced_claim = claim.0 + claim.1 * r_sum_check;
|
||||
|
||||
// Create vanilla oracle
|
||||
let oracle = gen_plain_gkr_oracle(num_rounds, r_sum_check);
|
||||
|
||||
// Create sum-check claim
|
||||
let transformed_claim = Claim {
|
||||
sum_value: reduced_claim,
|
||||
polynomial: oracle,
|
||||
};
|
||||
|
||||
let reduced_gkr_claim =
|
||||
sum_check_verify_and_reduce(&transformed_claim, self.proof.clone(), transcript);
|
||||
|
||||
(reduced_gkr_claim, r_sum_check)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct GkrClaim<E: FieldElement + 'static> {
|
||||
evaluation_point: Vec<E>,
|
||||
claimed_evaluation: (E, E),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CircuitProof<E: FieldElement + 'static> {
|
||||
pub proof: Vec<LayerProof<E>>,
|
||||
}
|
||||
|
||||
impl<E: FieldElement<BaseField = BaseElement> + 'static> CircuitProof<E> {
|
||||
pub fn prove<
|
||||
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||
H: ElementHasher<BaseField = BaseElement>,
|
||||
>(
|
||||
circuit: &mut FractionalSumCircuit<E>,
|
||||
transcript: &mut C,
|
||||
) -> (Self, Vec<E>, Vec<Vec<E>>) {
|
||||
let mut proof_layers: Vec<LayerProof<E>> = Vec::new();
|
||||
let num_layers = circuit.p_0_vec.len();
|
||||
|
||||
let data = vec![
|
||||
circuit.p_1_vec[num_layers - 1][0],
|
||||
circuit.p_0_vec[num_layers - 1][0],
|
||||
circuit.q_1_vec[num_layers - 1][0],
|
||||
circuit.q_0_vec[num_layers - 1][0],
|
||||
];
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
// Challenge to reduce p1, p0, q1, q0 to pr, qr
|
||||
let r_cord = transcript.draw().unwrap();
|
||||
|
||||
// Compute the (2-to-1 folded) claim
|
||||
let mut claim = circuit.evaluate(r_cord);
|
||||
let mut all_rand = Vec::new();
|
||||
|
||||
let mut rand = Vec::new();
|
||||
rand.push(r_cord);
|
||||
for layer_id in (0..num_layers - 1).rev() {
|
||||
let len = circuit.p_0_vec[layer_id].len();
|
||||
|
||||
// Construct the Lagrange kernel evaluated at previous GKR round randomness.
|
||||
// TODO: Treat the direction of doing sum-check more robustly.
|
||||
let mut rand_reversed = rand.clone();
|
||||
rand_reversed.reverse();
|
||||
let eq_evals = EqPolynomial::new(rand_reversed.clone()).evaluations();
|
||||
let mut poly_x = MultiLinear::from_values(&eq_evals);
|
||||
assert_eq!(poly_x.len(), len);
|
||||
|
||||
let num_rounds = poly_x.len().ilog2() as usize;
|
||||
|
||||
// 1. A is a polynomial containing the evaluations `p_1`.
|
||||
// 2. B is a polynomial containing the evaluations `p_0`.
|
||||
// 3. C is a polynomial containing the evaluations `q_1`.
|
||||
// 4. D is a polynomial containing the evaluations `q_0`.
|
||||
let poly_a: &mut MultiLinear<E>;
|
||||
let poly_b: &mut MultiLinear<E>;
|
||||
let poly_c: &mut MultiLinear<E>;
|
||||
let poly_d: &mut MultiLinear<E>;
|
||||
poly_a = &mut circuit.p_1_vec[layer_id];
|
||||
poly_b = &mut circuit.p_0_vec[layer_id];
|
||||
poly_c = &mut circuit.q_1_vec[layer_id];
|
||||
poly_d = &mut circuit.q_0_vec[layer_id];
|
||||
|
||||
let poly_vec_par = (poly_a, poly_b, poly_c, poly_d, &mut poly_x);
|
||||
|
||||
// The (non-linear) polynomial combining the multilinear polynomials
|
||||
let comb_func = |a: &E, b: &E, c: &E, d: &E, x: &E, rho: &E| -> E {
|
||||
(*a * *d + *b * *c + *rho * *c * *d) * *x
|
||||
};
|
||||
|
||||
// Run the sumcheck protocol
|
||||
let (proof, rand_sumcheck, claims_sum) = sum_check_prover_gkr_before_last::<E, _, _>(
|
||||
claim,
|
||||
num_rounds,
|
||||
poly_vec_par,
|
||||
comb_func,
|
||||
transcript,
|
||||
);
|
||||
|
||||
let (claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0, _claims_eq) =
|
||||
claims_sum;
|
||||
|
||||
let data = vec![claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0];
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
// Produce a random challenge to condense claims into a single claim
|
||||
let r_layer = transcript.draw().unwrap();
|
||||
|
||||
claim = (
|
||||
claims_sum_p1 + r_layer * (claims_sum_p0 - claims_sum_p1),
|
||||
claims_sum_q1 + r_layer * (claims_sum_q0 - claims_sum_q1),
|
||||
);
|
||||
|
||||
// Collect the randomness used for the current layer in order to construct the random
|
||||
// point where the input multilinear polynomials were evaluated.
|
||||
let mut ext = rand_sumcheck;
|
||||
ext.push(r_layer);
|
||||
all_rand.push(rand);
|
||||
rand = ext;
|
||||
|
||||
proof_layers.push(LayerProof {
|
||||
proof,
|
||||
claims_sum_p1,
|
||||
claims_sum_p0,
|
||||
claims_sum_q1,
|
||||
claims_sum_q0,
|
||||
});
|
||||
}
|
||||
|
||||
(CircuitProof { proof: proof_layers }, rand, all_rand)
|
||||
}
|
||||
|
||||
pub fn prove_virtual_bus<
|
||||
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||
H: ElementHasher<BaseField = BaseElement>,
|
||||
>(
|
||||
composition_polys: Vec<Vec<Arc<dyn CompositionPolynomial<E>>>>,
|
||||
mls: &mut Vec<MultiLinear<E>>,
|
||||
transcript: &mut C,
|
||||
) -> (Vec<E>, Self, super::sumcheck::FullProof<E>) {
|
||||
let num_evaluations = 1 << mls[0].num_variables();
|
||||
|
||||
// I) Evaluate the numerators and denominators over the boolean hyper-cube
|
||||
let mut num_den: Vec<Vec<E>> = vec![vec![]; 4];
|
||||
for i in 0..num_evaluations {
|
||||
for j in 0..4 {
|
||||
let query: Vec<E> = mls.iter().map(|ml| ml[i]).collect();
|
||||
|
||||
composition_polys[j].iter().for_each(|c| {
|
||||
let evaluation = c.as_ref().evaluate(&query);
|
||||
num_den[j].push(evaluation);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// II) Evaluate the GKR fractional sum circuit
|
||||
let input: Vec<MultiLinear<E>> =
|
||||
(0..4).map(|i| MultiLinear::from_values(&num_den[i])).collect();
|
||||
let mut circuit = FractionalSumCircuit::new_(&input);
|
||||
|
||||
// III) Run the GKR prover for all layers except the last one
|
||||
let (gkr_proofs, GkrClaim { evaluation_point, claimed_evaluation }) =
|
||||
CircuitProof::prove_before_final(&mut circuit, transcript);
|
||||
|
||||
// IV) Run the sum-check prover for the last GKR layer counting backwards i.e., first layer
|
||||
// in the circuit.
|
||||
|
||||
// 1) Build the EQ polynomial (Lagrange kernel) at the randomness sampled during the previous
|
||||
// sum-check protocol run
|
||||
let mut rand_reversed = evaluation_point.clone();
|
||||
rand_reversed.reverse();
|
||||
let eq_evals = EqPolynomial::new(rand_reversed.clone()).evaluations();
|
||||
let poly_x = MultiLinear::from_values(&eq_evals);
|
||||
|
||||
// 2) Add the Lagrange kernel to the list of MLs
|
||||
mls.push(poly_x);
|
||||
|
||||
// 3) Absorb the final sum-check claims and generate randomness for 2-to-1 sum-check reduction
|
||||
let data = vec![claimed_evaluation.0, claimed_evaluation.1];
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
// Squeeze challenge to reduce two sumchecks to one
|
||||
let r_sum_check = transcript.draw().unwrap();
|
||||
let reduced_claim = claimed_evaluation.0 + claimed_evaluation.1 * r_sum_check;
|
||||
|
||||
// 4) Create the composed ML representing the numerators and denominators of the topmost GKR layer
|
||||
let gkr_final_composed_ml = gkr_composition_from_composition_polys(
|
||||
&composition_polys,
|
||||
r_sum_check,
|
||||
1 << mls[0].num_variables,
|
||||
);
|
||||
let composed_ml =
|
||||
ComposedMultiLinears::new(Arc::new(gkr_final_composed_ml.clone()), mls.to_vec());
|
||||
|
||||
// 5) Create the composed ML oracle. This will be used for verifying the FinalEvaluationClaim downstream
|
||||
// TODO: This should be an input to the current function.
|
||||
// TODO: Make MultiLinearOracle a variant in an enum so that it is possible to capture other types of oracles.
|
||||
// For example, shifts of polynomials, Lagrange kernels at a random point or periodic (transparent) polynomials.
|
||||
let left_num_oracle = MultiLinearOracle { id: 0 };
|
||||
let right_num_oracle = MultiLinearOracle { id: 1 };
|
||||
let left_denom_oracle = MultiLinearOracle { id: 2 };
|
||||
let right_denom_oracle = MultiLinearOracle { id: 3 };
|
||||
let eq_oracle = MultiLinearOracle { id: 4 };
|
||||
let composed_ml_oracle = ComposedMultiLinearsOracle {
|
||||
composer: (Arc::new(gkr_final_composed_ml.clone())),
|
||||
multi_linears: vec![
|
||||
eq_oracle,
|
||||
left_num_oracle,
|
||||
right_num_oracle,
|
||||
left_denom_oracle,
|
||||
right_denom_oracle,
|
||||
],
|
||||
};
|
||||
|
||||
// 6) Create the claim for the final sum-check protocol.
|
||||
let claim = Claim {
|
||||
sum_value: reduced_claim,
|
||||
polynomial: composed_ml_oracle.clone(),
|
||||
};
|
||||
|
||||
// 7) Create the witness for the sum-check claim.
|
||||
let witness = Witness { polynomial: composed_ml };
|
||||
let output = sum_check_prove(&claim, composed_ml_oracle, witness, transcript);
|
||||
|
||||
// 8) Create the claimed output of the circuit.
|
||||
let circuit_outputs = vec![
|
||||
circuit.p_1_vec.last().unwrap()[0],
|
||||
circuit.p_0_vec.last().unwrap()[0],
|
||||
circuit.q_1_vec.last().unwrap()[0],
|
||||
circuit.q_0_vec.last().unwrap()[0],
|
||||
];
|
||||
|
||||
// 9) Return:
|
||||
// 1. The claimed circuit outputs.
|
||||
// 2. GKR proofs of all circuit layers except the initial layer.
|
||||
// 3. Output of the final sum-check protocol.
|
||||
(circuit_outputs, gkr_proofs, output)
|
||||
}
|
||||
|
||||
pub fn prove_before_final<
|
||||
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||
H: ElementHasher<BaseField = BaseElement>,
|
||||
>(
|
||||
sum_circuits: &mut FractionalSumCircuit<E>,
|
||||
transcript: &mut C,
|
||||
) -> (Self, GkrClaim<E>) {
|
||||
let mut proof_layers: Vec<LayerProof<E>> = Vec::new();
|
||||
let num_layers = sum_circuits.p_0_vec.len();
|
||||
|
||||
let data = vec![
|
||||
sum_circuits.p_1_vec[num_layers - 1][0],
|
||||
sum_circuits.p_0_vec[num_layers - 1][0],
|
||||
sum_circuits.q_1_vec[num_layers - 1][0],
|
||||
sum_circuits.q_0_vec[num_layers - 1][0],
|
||||
];
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
// Challenge to reduce p1, p0, q1, q0 to pr, qr
|
||||
let r_cord = transcript.draw().unwrap();
|
||||
|
||||
// Compute the (2-to-1 folded) claim
|
||||
let mut claims_to_verify = sum_circuits.evaluate(r_cord);
|
||||
let mut all_rand = Vec::new();
|
||||
|
||||
let mut rand = Vec::new();
|
||||
rand.push(r_cord);
|
||||
for layer_id in (1..num_layers - 1).rev() {
|
||||
let len = sum_circuits.p_0_vec[layer_id].len();
|
||||
|
||||
// Construct the Lagrange kernel evaluated at previous GKR round randomness.
|
||||
// TODO: Treat the direction of doing sum-check more robustly.
|
||||
let mut rand_reversed = rand.clone();
|
||||
rand_reversed.reverse();
|
||||
let eq_evals = EqPolynomial::new(rand_reversed.clone()).evaluations();
|
||||
let mut poly_x = MultiLinear::from_values(&eq_evals);
|
||||
assert_eq!(poly_x.len(), len);
|
||||
|
||||
let num_rounds = poly_x.len().ilog2() as usize;
|
||||
|
||||
// 1. A is a polynomial containing the evaluations `p_1`.
|
||||
// 2. B is a polynomial containing the evaluations `p_0`.
|
||||
// 3. C is a polynomial containing the evaluations `q_1`.
|
||||
// 4. D is a polynomial containing the evaluations `q_0`.
|
||||
let poly_a: &mut MultiLinear<E>;
|
||||
let poly_b: &mut MultiLinear<E>;
|
||||
let poly_c: &mut MultiLinear<E>;
|
||||
let poly_d: &mut MultiLinear<E>;
|
||||
poly_a = &mut sum_circuits.p_1_vec[layer_id];
|
||||
poly_b = &mut sum_circuits.p_0_vec[layer_id];
|
||||
poly_c = &mut sum_circuits.q_1_vec[layer_id];
|
||||
poly_d = &mut sum_circuits.q_0_vec[layer_id];
|
||||
|
||||
let poly_vec = (poly_a, poly_b, poly_c, poly_d, &mut poly_x);
|
||||
|
||||
let claim = claims_to_verify;
|
||||
|
||||
// The (non-linear) polynomial combining the multilinear polynomials
|
||||
let comb_func = |a: &E, b: &E, c: &E, d: &E, x: &E, rho: &E| -> E {
|
||||
(*a * *d + *b * *c + *rho * *c * *d) * *x
|
||||
};
|
||||
|
||||
// Run the sumcheck protocol
|
||||
let (proof, rand_sumcheck, claims_sum) = sum_check_prover_gkr_before_last::<E, _, _>(
|
||||
claim, num_rounds, poly_vec, comb_func, transcript,
|
||||
);
|
||||
|
||||
let (claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0, _claims_eq) =
|
||||
claims_sum;
|
||||
|
||||
let data = vec![claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0];
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
// Produce a random challenge to condense claims into a single claim
|
||||
let r_layer = transcript.draw().unwrap();
|
||||
|
||||
claims_to_verify = (
|
||||
claims_sum_p1 + r_layer * (claims_sum_p0 - claims_sum_p1),
|
||||
claims_sum_q1 + r_layer * (claims_sum_q0 - claims_sum_q1),
|
||||
);
|
||||
|
||||
// Collect the randomness used for the current layer in order to construct the random
|
||||
// point where the input multilinear polynomials were evaluated.
|
||||
let mut ext = rand_sumcheck;
|
||||
ext.push(r_layer);
|
||||
all_rand.push(rand);
|
||||
rand = ext;
|
||||
|
||||
proof_layers.push(LayerProof {
|
||||
proof,
|
||||
claims_sum_p1,
|
||||
claims_sum_p0,
|
||||
claims_sum_q1,
|
||||
claims_sum_q0,
|
||||
});
|
||||
}
|
||||
let gkr_claim = GkrClaim {
|
||||
evaluation_point: rand.clone(),
|
||||
claimed_evaluation: claims_to_verify,
|
||||
};
|
||||
|
||||
(CircuitProof { proof: proof_layers }, gkr_claim)
|
||||
}
|
||||
|
||||
pub fn verify<
|
||||
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||
H: ElementHasher<BaseField = BaseElement>,
|
||||
>(
|
||||
&self,
|
||||
claims_sum_vec: &[E],
|
||||
transcript: &mut C,
|
||||
) -> ((E, E), Vec<E>) {
|
||||
let num_layers = self.proof.len() as usize - 1;
|
||||
let mut rand: Vec<E> = Vec::new();
|
||||
|
||||
let data = claims_sum_vec;
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
let r_cord = transcript.draw().unwrap();
|
||||
|
||||
let p_poly_coef = vec![claims_sum_vec[0], claims_sum_vec[1]];
|
||||
let q_poly_coef = vec![claims_sum_vec[2], claims_sum_vec[3]];
|
||||
|
||||
let p_poly = MultiLinear::new(p_poly_coef);
|
||||
let q_poly = MultiLinear::new(q_poly_coef);
|
||||
let p_eval = p_poly.evaluate(&[r_cord]);
|
||||
let q_eval = q_poly.evaluate(&[r_cord]);
|
||||
|
||||
let mut reduced_claim = (p_eval, q_eval);
|
||||
|
||||
rand.push(r_cord);
|
||||
for (num_rounds, i) in (0..num_layers).enumerate() {
|
||||
let ((claim_last, rand_sumcheck), r_two_sumchecks) = self.proof[i]
|
||||
.verify_sum_check_before_last::<_, _>(reduced_claim, num_rounds + 1, transcript);
|
||||
|
||||
let claims_sum_p1 = &self.proof[i].claims_sum_p1;
|
||||
let claims_sum_p0 = &self.proof[i].claims_sum_p0;
|
||||
let claims_sum_q1 = &self.proof[i].claims_sum_q1;
|
||||
let claims_sum_q0 = &self.proof[i].claims_sum_q0;
|
||||
|
||||
let data = vec![
|
||||
claims_sum_p1.clone(),
|
||||
claims_sum_p0.clone(),
|
||||
claims_sum_q1.clone(),
|
||||
claims_sum_q0.clone(),
|
||||
];
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
assert_eq!(rand.len(), rand_sumcheck.len());
|
||||
|
||||
let eq: E = (0..rand.len())
|
||||
.map(|i| {
|
||||
rand[i] * rand_sumcheck[i] + (E::ONE - rand[i]) * (E::ONE - rand_sumcheck[i])
|
||||
})
|
||||
.fold(E::ONE, |acc, term| acc * term);
|
||||
|
||||
let claim_expected: E = (*claims_sum_p1 * *claims_sum_q0
|
||||
+ *claims_sum_p0 * *claims_sum_q1
|
||||
+ r_two_sumchecks * *claims_sum_q1 * *claims_sum_q0)
|
||||
* eq;
|
||||
|
||||
assert_eq!(claim_expected, claim_last);
|
||||
|
||||
// Produce a random challenge to condense claims into a single claim
|
||||
let r_layer = transcript.draw().unwrap();
|
||||
|
||||
reduced_claim = (
|
||||
*claims_sum_p1 + r_layer * (*claims_sum_p0 - *claims_sum_p1),
|
||||
*claims_sum_q1 + r_layer * (*claims_sum_q0 - *claims_sum_q1),
|
||||
);
|
||||
|
||||
// Collect the randomness' used for the current layer in order to construct the random
|
||||
// point where the input multilinear polynomials were evaluated.
|
||||
let mut ext = rand_sumcheck;
|
||||
ext.push(r_layer);
|
||||
rand = ext;
|
||||
}
|
||||
(reduced_claim, rand)
|
||||
}
|
||||
|
||||
pub fn verify_virtual_bus<
|
||||
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||
H: ElementHasher<BaseField = BaseElement>,
|
||||
>(
|
||||
&self,
|
||||
composition_polys: Vec<Vec<Arc<dyn CompositionPolynomial<E>>>>,
|
||||
final_layer_proof: super::sumcheck::FullProof<E>,
|
||||
claims_sum_vec: &[E],
|
||||
transcript: &mut C,
|
||||
) -> (FinalEvaluationClaim<E>, Vec<E>) {
|
||||
let num_layers = self.proof.len() as usize;
|
||||
let mut rand: Vec<E> = Vec::new();
|
||||
|
||||
// Check that a/b + d/e is equal to 0
|
||||
assert_ne!(claims_sum_vec[2], E::ZERO);
|
||||
assert_ne!(claims_sum_vec[3], E::ZERO);
|
||||
assert_eq!(
|
||||
claims_sum_vec[0] * claims_sum_vec[3] + claims_sum_vec[1] * claims_sum_vec[2],
|
||||
E::ZERO
|
||||
);
|
||||
|
||||
let data = claims_sum_vec;
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
let r_cord = transcript.draw().unwrap();
|
||||
|
||||
let p_poly_coef = vec![claims_sum_vec[0], claims_sum_vec[1]];
|
||||
let q_poly_coef = vec![claims_sum_vec[2], claims_sum_vec[3]];
|
||||
|
||||
let p_poly = MultiLinear::new(p_poly_coef);
|
||||
let q_poly = MultiLinear::new(q_poly_coef);
|
||||
let p_eval = p_poly.evaluate(&[r_cord]);
|
||||
let q_eval = q_poly.evaluate(&[r_cord]);
|
||||
|
||||
let mut reduced_claim = (p_eval, q_eval);
|
||||
|
||||
// I) Verify all GKR layers but for the last one counting backwards.
|
||||
rand.push(r_cord);
|
||||
for (num_rounds, i) in (0..num_layers).enumerate() {
|
||||
let ((claim_last, rand_sumcheck), r_two_sumchecks) = self.proof[i]
|
||||
.verify_sum_check_before_last::<_, _>(reduced_claim, num_rounds + 1, transcript);
|
||||
|
||||
let claims_sum_p1 = &self.proof[i].claims_sum_p1;
|
||||
let claims_sum_p0 = &self.proof[i].claims_sum_p0;
|
||||
let claims_sum_q1 = &self.proof[i].claims_sum_q1;
|
||||
let claims_sum_q0 = &self.proof[i].claims_sum_q0;
|
||||
|
||||
let data = vec![
|
||||
claims_sum_p1.clone(),
|
||||
claims_sum_p0.clone(),
|
||||
claims_sum_q1.clone(),
|
||||
claims_sum_q0.clone(),
|
||||
];
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
assert_eq!(rand.len(), rand_sumcheck.len());
|
||||
|
||||
let eq: E = (0..rand.len())
|
||||
.map(|i| {
|
||||
rand[i] * rand_sumcheck[i] + (E::ONE - rand[i]) * (E::ONE - rand_sumcheck[i])
|
||||
})
|
||||
.fold(E::ONE, |acc, term| acc * term);
|
||||
|
||||
let claim_expected: E = (*claims_sum_p1 * *claims_sum_q0
|
||||
+ *claims_sum_p0 * *claims_sum_q1
|
||||
+ r_two_sumchecks * *claims_sum_q1 * *claims_sum_q0)
|
||||
* eq;
|
||||
|
||||
assert_eq!(claim_expected, claim_last);
|
||||
|
||||
// Produce a random challenge to condense claims into a single claim
|
||||
let r_layer = transcript.draw().unwrap();
|
||||
|
||||
reduced_claim = (
|
||||
*claims_sum_p1 + r_layer * (*claims_sum_p0 - *claims_sum_p1),
|
||||
*claims_sum_q1 + r_layer * (*claims_sum_q0 - *claims_sum_q1),
|
||||
);
|
||||
|
||||
let mut ext = rand_sumcheck;
|
||||
ext.push(r_layer);
|
||||
rand = ext;
|
||||
}
|
||||
|
||||
// II) Verify the final GKR layer counting backwards.
|
||||
|
||||
// Absorb the claims
|
||||
let data = vec![reduced_claim.0, reduced_claim.1];
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
// Squeeze challenge to reduce two sumchecks to one
|
||||
let r_sum_check = transcript.draw().unwrap();
|
||||
let reduced_claim = reduced_claim.0 + reduced_claim.1 * r_sum_check;
|
||||
|
||||
let gkr_final_composed_ml = gkr_composition_from_composition_polys(
|
||||
&composition_polys,
|
||||
r_sum_check,
|
||||
1 << (num_layers + 1),
|
||||
);
|
||||
|
||||
// TODO: refactor
|
||||
let composed_ml_oracle = {
|
||||
let left_num_oracle = MultiLinearOracle { id: 0 };
|
||||
let right_num_oracle = MultiLinearOracle { id: 1 };
|
||||
let left_denom_oracle = MultiLinearOracle { id: 2 };
|
||||
let right_denom_oracle = MultiLinearOracle { id: 3 };
|
||||
let eq_oracle = MultiLinearOracle { id: 4 };
|
||||
ComposedMultiLinearsOracle {
|
||||
composer: (Arc::new(gkr_final_composed_ml.clone())),
|
||||
multi_linears: vec![
|
||||
eq_oracle,
|
||||
left_num_oracle,
|
||||
right_num_oracle,
|
||||
left_denom_oracle,
|
||||
right_denom_oracle,
|
||||
],
|
||||
}
|
||||
};
|
||||
|
||||
let claim = Claim {
|
||||
sum_value: reduced_claim,
|
||||
polynomial: composed_ml_oracle.clone(),
|
||||
};
|
||||
|
||||
let final_eval_claim = sum_check_verify(&claim, final_layer_proof, transcript);
|
||||
|
||||
(final_eval_claim, rand)
|
||||
}
|
||||
}
|
||||
|
||||
fn sum_check_prover_gkr_before_last<
|
||||
E: FieldElement<BaseField = BaseElement>,
|
||||
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||
H: ElementHasher<BaseField = BaseElement>,
|
||||
>(
|
||||
claim: (E, E),
|
||||
num_rounds: usize,
|
||||
ml_polys: (
|
||||
&mut MultiLinear<E>,
|
||||
&mut MultiLinear<E>,
|
||||
&mut MultiLinear<E>,
|
||||
&mut MultiLinear<E>,
|
||||
&mut MultiLinear<E>,
|
||||
),
|
||||
comb_func: impl Fn(&E, &E, &E, &E, &E, &E) -> E,
|
||||
transcript: &mut C,
|
||||
) -> (SumcheckInstanceProof<E>, Vec<E>, (E, E, E, E, E)) {
|
||||
// Absorb the claims
|
||||
let data = vec![claim.0, claim.1];
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
// Squeeze challenge to reduce two sumchecks to one
|
||||
let r_sum_check = transcript.draw().unwrap();
|
||||
|
||||
let (poly_a, poly_b, poly_c, poly_d, poly_x) = ml_polys;
|
||||
|
||||
let mut e = claim.0 + claim.1 * r_sum_check;
|
||||
|
||||
let mut r: Vec<E> = Vec::new();
|
||||
let mut round_proofs: Vec<SumCheckRoundProof<E>> = Vec::new();
|
||||
|
||||
for _j in 0..num_rounds {
|
||||
let evals: (E, E, E) = {
|
||||
let mut eval_point_0 = E::ZERO;
|
||||
let mut eval_point_2 = E::ZERO;
|
||||
let mut eval_point_3 = E::ZERO;
|
||||
|
||||
let len = poly_a.len() / 2;
|
||||
for i in 0..len {
|
||||
// The interpolation formula for a linear function is:
|
||||
// z * A(x) + (1 - z) * A (y)
|
||||
// z * A(1) + (1 - z) * A(0)
|
||||
|
||||
// eval at z = 0: A(1)
|
||||
eval_point_0 += comb_func(
|
||||
&poly_a[i << 1],
|
||||
&poly_b[i << 1],
|
||||
&poly_c[i << 1],
|
||||
&poly_d[i << 1],
|
||||
&poly_x[i << 1],
|
||||
&r_sum_check,
|
||||
);
|
||||
|
||||
let poly_a_u = poly_a[(i << 1) + 1];
|
||||
let poly_a_v = poly_a[i << 1];
|
||||
let poly_b_u = poly_b[(i << 1) + 1];
|
||||
let poly_b_v = poly_b[i << 1];
|
||||
let poly_c_u = poly_c[(i << 1) + 1];
|
||||
let poly_c_v = poly_c[i << 1];
|
||||
let poly_d_u = poly_d[(i << 1) + 1];
|
||||
let poly_d_v = poly_d[i << 1];
|
||||
let poly_x_u = poly_x[(i << 1) + 1];
|
||||
let poly_x_v = poly_x[i << 1];
|
||||
|
||||
// eval at z = 2: 2 * A(1) - A(0)
|
||||
let poly_a_extrapolated_point = poly_a_u + poly_a_u - poly_a_v;
|
||||
let poly_b_extrapolated_point = poly_b_u + poly_b_u - poly_b_v;
|
||||
let poly_c_extrapolated_point = poly_c_u + poly_c_u - poly_c_v;
|
||||
let poly_d_extrapolated_point = poly_d_u + poly_d_u - poly_d_v;
|
||||
let poly_x_extrapolated_point = poly_x_u + poly_x_u - poly_x_v;
|
||||
eval_point_2 += comb_func(
|
||||
&poly_a_extrapolated_point,
|
||||
&poly_b_extrapolated_point,
|
||||
&poly_c_extrapolated_point,
|
||||
&poly_d_extrapolated_point,
|
||||
&poly_x_extrapolated_point,
|
||||
&r_sum_check,
|
||||
);
|
||||
|
||||
// eval at z = 3: 3 * A(1) - 2 * A(0) = 2 * A(1) - A(0) + A(1) - A(0)
|
||||
// hence we can compute the evaluation at z + 1 from that of z for z > 1
|
||||
let poly_a_extrapolated_point = poly_a_extrapolated_point + poly_a_u - poly_a_v;
|
||||
let poly_b_extrapolated_point = poly_b_extrapolated_point + poly_b_u - poly_b_v;
|
||||
let poly_c_extrapolated_point = poly_c_extrapolated_point + poly_c_u - poly_c_v;
|
||||
let poly_d_extrapolated_point = poly_d_extrapolated_point + poly_d_u - poly_d_v;
|
||||
let poly_x_extrapolated_point = poly_x_extrapolated_point + poly_x_u - poly_x_v;
|
||||
|
||||
eval_point_3 += comb_func(
|
||||
&poly_a_extrapolated_point,
|
||||
&poly_b_extrapolated_point,
|
||||
&poly_c_extrapolated_point,
|
||||
&poly_d_extrapolated_point,
|
||||
&poly_x_extrapolated_point,
|
||||
&r_sum_check,
|
||||
);
|
||||
}
|
||||
|
||||
(eval_point_0, eval_point_2, eval_point_3)
|
||||
};
|
||||
|
||||
let eval_0 = evals.0;
|
||||
let eval_2 = evals.1;
|
||||
let eval_3 = evals.2;
|
||||
|
||||
let evals = vec![e - eval_0, eval_2, eval_3];
|
||||
let compressed_poly = SumCheckRoundProof { poly_evals: evals };
|
||||
|
||||
// append the prover's message to the transcript
|
||||
transcript.reseed(H::hash_elements(&compressed_poly.poly_evals));
|
||||
|
||||
// derive the verifier's challenge for the next round
|
||||
let r_j = transcript.draw().unwrap();
|
||||
r.push(r_j);
|
||||
|
||||
poly_a.bind_assign(r_j);
|
||||
poly_b.bind_assign(r_j);
|
||||
poly_c.bind_assign(r_j);
|
||||
poly_d.bind_assign(r_j);
|
||||
|
||||
poly_x.bind_assign(r_j);
|
||||
|
||||
e = compressed_poly.evaluate(e, r_j);
|
||||
|
||||
round_proofs.push(compressed_poly);
|
||||
}
|
||||
let claims_sum = (poly_a[0], poly_b[0], poly_c[0], poly_d[0], poly_x[0]);
|
||||
|
||||
(SumcheckInstanceProof { round_proofs }, r, claims_sum)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod sum_circuit_tests {
|
||||
use crate::rand::RpoRandomCoin;
|
||||
|
||||
use super::*;
|
||||
use rand::Rng;
|
||||
use rand_utils::rand_value;
|
||||
use BaseElement as Felt;
|
||||
|
||||
/// The following tests the fractional sum circuit to check that \sum_{i = 0}^{log(m)-1} m / 2^{i} = 2 * (m - 1)
|
||||
#[test]
|
||||
fn sum_circuit_example() {
|
||||
let n = 4; // n := log(m)
|
||||
let mut inp: Vec<Felt> = (0..n).map(|_| Felt::from(1_u64 << n)).collect();
|
||||
let inp_: Vec<Felt> = (0..n).map(|i| Felt::from(1_u64 << i)).collect();
|
||||
inp.extend(inp_.iter());
|
||||
|
||||
let summation = MultiLinear::new(inp);
|
||||
|
||||
let expected_output = Felt::from(2 * ((1_u64 << n) - 1));
|
||||
|
||||
let mut circuit = FractionalSumCircuit::new(&summation);
|
||||
|
||||
let seed = [BaseElement::ZERO; 4];
|
||||
let mut transcript = RpoRandomCoin::new(seed.into());
|
||||
|
||||
let (proof, _evals, _) = CircuitProof::prove(&mut circuit, &mut transcript);
|
||||
|
||||
let (p1, q1) = circuit.evaluate(Felt::from(1_u8));
|
||||
let (p0, q0) = circuit.evaluate(Felt::from(0_u8));
|
||||
assert_eq!(expected_output, (p1 * q0 + q1 * p0) / (q1 * q0));
|
||||
|
||||
let seed = [BaseElement::ZERO; 4];
|
||||
let mut transcript = RpoRandomCoin::new(seed.into());
|
||||
let claims = vec![p0, p1, q0, q1];
|
||||
proof.verify(&claims, &mut transcript);
|
||||
}
|
||||
|
||||
// Test the fractional sum GKR in the context of LogUp.
|
||||
#[test]
|
||||
fn log_up() {
|
||||
use rand::distributions::Slice;
|
||||
|
||||
let n: usize = 16;
|
||||
let num_w: usize = 31; // This should be of the form 2^k - 1
|
||||
let rng = rand::thread_rng();
|
||||
|
||||
let t_table: Vec<u32> = (0..(1 << n)).collect();
|
||||
let mut m_table: Vec<u32> = (0..(1 << n)).map(|_| 0).collect();
|
||||
|
||||
let t_table_slice = Slice::new(&t_table).unwrap();
|
||||
|
||||
// Construct the witness columns. Uses sampling with replacement in order to have multiplicities
|
||||
// different from 1.
|
||||
let mut w_tables = Vec::new();
|
||||
for _ in 0..num_w {
|
||||
let wi_table: Vec<u32> =
|
||||
rng.clone().sample_iter(&t_table_slice).cloned().take(1 << n).collect();
|
||||
|
||||
// Construct the multiplicities
|
||||
wi_table.iter().for_each(|w| {
|
||||
m_table[*w as usize] += 1;
|
||||
});
|
||||
w_tables.push(wi_table)
|
||||
}
|
||||
|
||||
// The numerators
|
||||
let mut p: Vec<Felt> = m_table.iter().map(|m| Felt::from(*m as u32)).collect();
|
||||
p.extend((0..(num_w * (1 << n))).map(|_| Felt::from(1_u32)).collect::<Vec<Felt>>());
|
||||
|
||||
// Sample the challenge alpha to construct the denominators.
|
||||
let alpha = rand_value();
|
||||
|
||||
// Construct the denominators
|
||||
let mut q: Vec<Felt> = t_table.iter().map(|t| Felt::from(*t) - alpha).collect();
|
||||
for w_table in w_tables {
|
||||
q.extend(w_table.iter().map(|w| alpha - Felt::from(*w)).collect::<Vec<Felt>>());
|
||||
}
|
||||
|
||||
// Build the input to the fractional sum GKR circuit
|
||||
p.extend(q);
|
||||
let input = p;
|
||||
|
||||
let summation = MultiLinear::new(input);
|
||||
|
||||
let expected_output = Felt::from(0_u8);
|
||||
|
||||
let mut circuit = FractionalSumCircuit::new(&summation);
|
||||
|
||||
let seed = [BaseElement::ZERO; 4];
|
||||
let mut transcript = RpoRandomCoin::new(seed.into());
|
||||
|
||||
let (proof, _evals, _) = CircuitProof::prove(&mut circuit, &mut transcript);
|
||||
|
||||
let (p1, q1) = circuit.evaluate(Felt::from(1_u8));
|
||||
let (p0, q0) = circuit.evaluate(Felt::from(0_u8));
|
||||
assert_eq!(expected_output, (p1 * q0 + q1 * p0) / (q1 * q0)); // This check should be part of verification
|
||||
|
||||
let seed = [BaseElement::ZERO; 4];
|
||||
let mut transcript = RpoRandomCoin::new(seed.into());
|
||||
let claims = vec![p0, p1, q0, q1];
|
||||
proof.verify(&claims, &mut transcript);
|
||||
}
|
||||
}
|
||||
7
src/gkr/mod.rs
Normal file
7
src/gkr/mod.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
#![allow(unused_imports)]
|
||||
#![allow(dead_code)]
|
||||
|
||||
mod sumcheck;
|
||||
mod multivariate;
|
||||
mod utils;
|
||||
mod circuit;
|
||||
34
src/gkr/multivariate/eq_poly.rs
Normal file
34
src/gkr/multivariate/eq_poly.rs
Normal file
@@ -0,0 +1,34 @@
|
||||
use super::FieldElement;
|
||||
|
||||
pub struct EqPolynomial<E> {
|
||||
r: Vec<E>,
|
||||
}
|
||||
|
||||
impl<E: FieldElement> EqPolynomial<E> {
|
||||
pub fn new(r: Vec<E>) -> Self {
|
||||
EqPolynomial { r }
|
||||
}
|
||||
|
||||
pub fn evaluate(&self, rho: &[E]) -> E {
|
||||
assert_eq!(self.r.len(), rho.len());
|
||||
(0..rho.len())
|
||||
.map(|i| self.r[i] * rho[i] + (E::ONE - self.r[i]) * (E::ONE - rho[i]))
|
||||
.fold(E::ONE, |acc, term| acc * term)
|
||||
}
|
||||
|
||||
pub fn evaluations(&self) -> Vec<E> {
|
||||
let nu = self.r.len();
|
||||
|
||||
let mut evals: Vec<E> = vec![E::ONE; 1 << nu];
|
||||
let mut size = 1;
|
||||
for j in 0..nu {
|
||||
size *= 2;
|
||||
for i in (0..size).rev().step_by(2) {
|
||||
let scalar = evals[i / 2];
|
||||
evals[i] = scalar * self.r[j];
|
||||
evals[i - 1] = scalar - evals[i];
|
||||
}
|
||||
}
|
||||
evals
|
||||
}
|
||||
}
|
||||
543
src/gkr/multivariate/mod.rs
Normal file
543
src/gkr/multivariate/mod.rs
Normal file
@@ -0,0 +1,543 @@
|
||||
use core::ops::Index;
|
||||
|
||||
use alloc::sync::Arc;
|
||||
use winter_math::{fields::f64::BaseElement, log2, FieldElement, StarkField};
|
||||
|
||||
mod eq_poly;
|
||||
pub use eq_poly::EqPolynomial;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MultiLinear<E: FieldElement> {
|
||||
pub num_variables: usize,
|
||||
pub evaluations: Vec<E>,
|
||||
}
|
||||
|
||||
impl<E: FieldElement> MultiLinear<E> {
|
||||
pub fn new(values: Vec<E>) -> Self {
|
||||
Self {
|
||||
num_variables: log2(values.len()) as usize,
|
||||
evaluations: values,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_values(values: &[E]) -> Self {
|
||||
Self {
|
||||
num_variables: log2(values.len()) as usize,
|
||||
evaluations: values.to_owned(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn num_variables(&self) -> usize {
|
||||
self.num_variables
|
||||
}
|
||||
|
||||
pub fn evaluations(&self) -> &[E] {
|
||||
&self.evaluations
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.evaluations.len()
|
||||
}
|
||||
|
||||
pub fn evaluate(&self, query: &[E]) -> E {
|
||||
let tensored_query = tensorize(query);
|
||||
inner_product(&self.evaluations, &tensored_query)
|
||||
}
|
||||
|
||||
pub fn bind(&self, round_challenge: E) -> Self {
|
||||
let mut result = vec![E::ZERO; 1 << (self.num_variables() - 1)];
|
||||
for i in 0..(1 << (self.num_variables() - 1)) {
|
||||
result[i] = self.evaluations[i << 1]
|
||||
+ round_challenge * (self.evaluations[(i << 1) + 1] - self.evaluations[i << 1]);
|
||||
}
|
||||
Self::from_values(&result)
|
||||
}
|
||||
|
||||
pub fn bind_assign(&mut self, round_challenge: E) {
|
||||
let mut result = vec![E::ZERO; 1 << (self.num_variables() - 1)];
|
||||
for i in 0..(1 << (self.num_variables() - 1)) {
|
||||
result[i] = self.evaluations[i << 1]
|
||||
+ round_challenge * (self.evaluations[(i << 1) + 1] - self.evaluations[i << 1]);
|
||||
}
|
||||
*self = Self::from_values(&result);
|
||||
}
|
||||
|
||||
pub fn split(&self, at: usize) -> (Self, Self) {
|
||||
assert!(at < self.len());
|
||||
(
|
||||
Self::new(self.evaluations[..at].to_vec()),
|
||||
Self::new(self.evaluations[at..2 * at].to_vec()),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn extend(&mut self, other: &MultiLinear<E>) {
|
||||
let other_vec = other.evaluations.to_vec();
|
||||
assert_eq!(other_vec.len(), self.len());
|
||||
self.evaluations.extend(other_vec);
|
||||
self.num_variables += 1;
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: FieldElement> Index<usize> for MultiLinear<E> {
|
||||
type Output = E;
|
||||
|
||||
fn index(&self, index: usize) -> &E {
|
||||
&(self.evaluations[index])
|
||||
}
|
||||
}
|
||||
|
||||
/// A multi-variate polynomial for composing individual multi-linear polynomials
|
||||
pub trait CompositionPolynomial<E: FieldElement>: Sync + Send {
|
||||
/// The number of variables when interpreted as a multi-variate polynomial.
|
||||
fn num_variables(&self) -> usize;
|
||||
|
||||
/// Maximum degree in all variables.
|
||||
fn max_degree(&self) -> usize;
|
||||
|
||||
/// Given a query, of length equal the number of variables, evaluate [Self] at this query.
|
||||
fn evaluate(&self, query: &[E]) -> E;
|
||||
}
|
||||
|
||||
pub struct ComposedMultiLinears<E: FieldElement> {
|
||||
pub composer: Arc<dyn CompositionPolynomial<E>>,
|
||||
pub multi_linears: Vec<MultiLinear<E>>,
|
||||
}
|
||||
|
||||
impl<E: FieldElement> ComposedMultiLinears<E> {
|
||||
pub fn new(
|
||||
composer: Arc<dyn CompositionPolynomial<E>>,
|
||||
multi_linears: Vec<MultiLinear<E>>,
|
||||
) -> Self {
|
||||
Self { composer, multi_linears }
|
||||
}
|
||||
|
||||
pub fn num_ml(&self) -> usize {
|
||||
self.multi_linears.len()
|
||||
}
|
||||
|
||||
pub fn num_variables(&self) -> usize {
|
||||
self.composer.num_variables()
|
||||
}
|
||||
|
||||
pub fn num_variables_ml(&self) -> usize {
|
||||
self.multi_linears[0].num_variables
|
||||
}
|
||||
|
||||
pub fn degree(&self) -> usize {
|
||||
self.composer.max_degree()
|
||||
}
|
||||
|
||||
pub fn bind(&self, round_challenge: E) -> ComposedMultiLinears<E> {
|
||||
let result: Vec<MultiLinear<E>> =
|
||||
self.multi_linears.iter().map(|f| f.bind(round_challenge)).collect();
|
||||
|
||||
Self {
|
||||
composer: self.composer.clone(),
|
||||
multi_linears: result,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ComposedMultiLinearsOracle<E: FieldElement> {
|
||||
pub composer: Arc<dyn CompositionPolynomial<E>>,
|
||||
pub multi_linears: Vec<MultiLinearOracle>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MultiLinearOracle {
|
||||
pub id: usize,
|
||||
}
|
||||
|
||||
// Composition polynomials
|
||||
|
||||
pub struct IdentityComposition {
|
||||
num_variables: usize,
|
||||
}
|
||||
|
||||
impl IdentityComposition {
|
||||
pub fn new() -> Self {
|
||||
Self { num_variables: 1 }
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> CompositionPolynomial<E> for IdentityComposition
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
fn num_variables(&self) -> usize {
|
||||
self.num_variables
|
||||
}
|
||||
|
||||
fn max_degree(&self) -> usize {
|
||||
self.num_variables
|
||||
}
|
||||
|
||||
fn evaluate(&self, query: &[E]) -> E {
|
||||
assert_eq!(query.len(), 1);
|
||||
query[0]
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ProjectionComposition {
|
||||
coordinate: usize,
|
||||
}
|
||||
|
||||
impl ProjectionComposition {
|
||||
pub fn new(coordinate: usize) -> Self {
|
||||
Self { coordinate }
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> CompositionPolynomial<E> for ProjectionComposition
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
fn num_variables(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn max_degree(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn evaluate(&self, query: &[E]) -> E {
|
||||
query[self.coordinate]
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LogUpDenominatorTableComposition<E>
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
projection_coordinate: usize,
|
||||
alpha: E,
|
||||
}
|
||||
|
||||
impl<E> LogUpDenominatorTableComposition<E>
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
pub fn new(projection_coordinate: usize, alpha: E) -> Self {
|
||||
Self { projection_coordinate, alpha }
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> CompositionPolynomial<E> for LogUpDenominatorTableComposition<E>
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
fn num_variables(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn max_degree(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn evaluate(&self, query: &[E]) -> E {
|
||||
query[self.projection_coordinate] + self.alpha
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LogUpDenominatorWitnessComposition<E>
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
projection_coordinate: usize,
|
||||
alpha: E,
|
||||
}
|
||||
|
||||
impl<E> LogUpDenominatorWitnessComposition<E>
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
pub fn new(projection_coordinate: usize, alpha: E) -> Self {
|
||||
Self { projection_coordinate, alpha }
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> CompositionPolynomial<E> for LogUpDenominatorWitnessComposition<E>
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
fn num_variables(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn max_degree(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn evaluate(&self, query: &[E]) -> E {
|
||||
-(query[self.projection_coordinate] + self.alpha)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ProductComposition {
|
||||
num_variables: usize,
|
||||
}
|
||||
|
||||
impl ProductComposition {
|
||||
pub fn new(num_variables: usize) -> Self {
|
||||
Self { num_variables }
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> CompositionPolynomial<E> for ProductComposition
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
fn num_variables(&self) -> usize {
|
||||
self.num_variables
|
||||
}
|
||||
|
||||
fn max_degree(&self) -> usize {
|
||||
self.num_variables
|
||||
}
|
||||
|
||||
fn evaluate(&self, query: &[E]) -> E {
|
||||
query.iter().fold(E::ONE, |acc, x| acc * *x)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SumComposition {
|
||||
num_variables: usize,
|
||||
}
|
||||
|
||||
impl SumComposition {
|
||||
pub fn new(num_variables: usize) -> Self {
|
||||
Self { num_variables }
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> CompositionPolynomial<E> for SumComposition
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
fn num_variables(&self) -> usize {
|
||||
self.num_variables
|
||||
}
|
||||
|
||||
fn max_degree(&self) -> usize {
|
||||
self.num_variables
|
||||
}
|
||||
|
||||
fn evaluate(&self, query: &[E]) -> E {
|
||||
query.iter().fold(E::ZERO, |acc, x| acc + *x)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct GkrCompositionVanilla<E: 'static>
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
num_variables_ml: usize,
|
||||
num_variables_merge: usize,
|
||||
combining_randomness: E,
|
||||
gkr_randomness: Vec<E>,
|
||||
}
|
||||
|
||||
impl<E> GkrCompositionVanilla<E>
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
pub fn new(
|
||||
num_variables_ml: usize,
|
||||
num_variables_merge: usize,
|
||||
combining_randomness: E,
|
||||
gkr_randomness: Vec<E>,
|
||||
) -> Self {
|
||||
Self {
|
||||
num_variables_ml,
|
||||
num_variables_merge,
|
||||
combining_randomness,
|
||||
gkr_randomness,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> CompositionPolynomial<E> for GkrCompositionVanilla<E>
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
fn num_variables(&self) -> usize {
|
||||
self.num_variables_ml // + TODO
|
||||
}
|
||||
|
||||
fn max_degree(&self) -> usize {
|
||||
self.num_variables_ml //TODO
|
||||
}
|
||||
|
||||
fn evaluate(&self, query: &[E]) -> E {
|
||||
let eval_left_numerator = query[0];
|
||||
let eval_right_numerator = query[1];
|
||||
let eval_left_denominator = query[2];
|
||||
let eval_right_denominator = query[3];
|
||||
let eq_eval = query[4];
|
||||
|
||||
eq_eval
|
||||
* ((eval_left_numerator * eval_right_denominator
|
||||
+ eval_right_numerator * eval_left_denominator)
|
||||
+ eval_left_denominator * eval_right_denominator * self.combining_randomness)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct GkrComposition<E>
|
||||
where
|
||||
E: FieldElement<BaseField = BaseElement>,
|
||||
{
|
||||
pub num_variables_ml: usize,
|
||||
pub combining_randomness: E,
|
||||
|
||||
eq_composer: Arc<dyn CompositionPolynomial<E>>,
|
||||
right_numerator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
|
||||
left_numerator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
|
||||
right_denominator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
|
||||
left_denominator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
|
||||
}
|
||||
|
||||
impl<E> GkrComposition<E>
|
||||
where
|
||||
E: FieldElement<BaseField = BaseElement>,
|
||||
{
|
||||
pub fn new(
|
||||
num_variables_ml: usize,
|
||||
combining_randomness: E,
|
||||
eq_composer: Arc<dyn CompositionPolynomial<E>>,
|
||||
right_numerator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
|
||||
left_numerator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
|
||||
right_denominator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
|
||||
left_denominator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
num_variables_ml,
|
||||
combining_randomness,
|
||||
eq_composer,
|
||||
right_numerator_composer,
|
||||
left_numerator_composer,
|
||||
right_denominator_composer,
|
||||
left_denominator_composer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> CompositionPolynomial<E> for GkrComposition<E>
|
||||
where
|
||||
E: FieldElement<BaseField = BaseElement>,
|
||||
{
|
||||
fn num_variables(&self) -> usize {
|
||||
self.num_variables_ml // + TODO
|
||||
}
|
||||
|
||||
fn max_degree(&self) -> usize {
|
||||
3 // TODO
|
||||
}
|
||||
|
||||
fn evaluate(&self, query: &[E]) -> E {
|
||||
let eval_right_numerator = self.right_numerator_composer[0].evaluate(query);
|
||||
let eval_left_numerator = self.left_numerator_composer[0].evaluate(query);
|
||||
let eval_right_denominator = self.right_denominator_composer[0].evaluate(query);
|
||||
let eval_left_denominator = self.left_denominator_composer[0].evaluate(query);
|
||||
let eq_eval = self.eq_composer.evaluate(query);
|
||||
|
||||
let res = eq_eval
|
||||
* ((eval_left_numerator * eval_right_denominator
|
||||
+ eval_right_numerator * eval_left_denominator)
|
||||
+ eval_left_denominator * eval_right_denominator * self.combining_randomness);
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates a composed ML polynomial for the initial GKR layer from a vector of composition
|
||||
/// polynomials.
|
||||
/// The composition polynomials are divided into LeftNumerator, RightNumerator, LeftDenominator
|
||||
/// and RightDenominator.
|
||||
/// TODO: Generalize this to the case where each numerator/denominator contains more than one
|
||||
/// composition polynomial i.e., a merged composed ML polynomial.
|
||||
pub fn gkr_composition_from_composition_polys<
|
||||
E: FieldElement<BaseField = BaseElement> + 'static,
|
||||
>(
|
||||
composition_polys: &Vec<Vec<Arc<dyn CompositionPolynomial<E>>>>,
|
||||
combining_randomness: E,
|
||||
num_variables: usize,
|
||||
) -> GkrComposition<E> {
|
||||
let eq_composer = Arc::new(ProjectionComposition::new(4));
|
||||
let left_numerator = composition_polys[0].to_owned();
|
||||
let right_numerator = composition_polys[1].to_owned();
|
||||
let left_denominator = composition_polys[2].to_owned();
|
||||
let right_denominator = composition_polys[3].to_owned();
|
||||
GkrComposition::new(
|
||||
num_variables,
|
||||
combining_randomness,
|
||||
eq_composer,
|
||||
right_numerator,
|
||||
left_numerator,
|
||||
right_denominator,
|
||||
left_denominator,
|
||||
)
|
||||
}
|
||||
|
||||
/// Generates a plain oracle for the sum-check protocol except the final one.
|
||||
pub fn gen_plain_gkr_oracle<E: FieldElement<BaseField = BaseElement> + 'static>(
|
||||
num_rounds: usize,
|
||||
r_sum_check: E,
|
||||
) -> ComposedMultiLinearsOracle<E> {
|
||||
let gkr_composer = Arc::new(GkrCompositionVanilla::new(num_rounds, 0, r_sum_check, vec![]));
|
||||
|
||||
let ml_oracles = vec![
|
||||
MultiLinearOracle { id: 0 },
|
||||
MultiLinearOracle { id: 1 },
|
||||
MultiLinearOracle { id: 2 },
|
||||
MultiLinearOracle { id: 3 },
|
||||
MultiLinearOracle { id: 4 },
|
||||
];
|
||||
|
||||
let oracle = ComposedMultiLinearsOracle {
|
||||
composer: gkr_composer,
|
||||
multi_linears: ml_oracles,
|
||||
};
|
||||
oracle
|
||||
}
|
||||
|
||||
fn to_index<E: FieldElement<BaseField = BaseElement>>(index: &[E]) -> usize {
|
||||
let res = index.iter().fold(E::ZERO, |acc, term| acc * E::ONE.double() + (*term));
|
||||
let res = res.base_element(0);
|
||||
res.as_int() as usize
|
||||
}
|
||||
|
||||
fn inner_product<E: FieldElement>(evaluations: &[E], tensored_query: &[E]) -> E {
|
||||
assert_eq!(evaluations.len(), tensored_query.len());
|
||||
evaluations
|
||||
.iter()
|
||||
.zip(tensored_query.iter())
|
||||
.fold(E::ZERO, |acc, (x_i, y_i)| acc + *x_i * *y_i)
|
||||
}
|
||||
|
||||
pub fn tensorize<E: FieldElement>(query: &[E]) -> Vec<E> {
|
||||
let nu = query.len();
|
||||
let n = 1 << nu;
|
||||
|
||||
(0..n).map(|i| lagrange_basis_eval(query, i)).collect()
|
||||
}
|
||||
|
||||
fn lagrange_basis_eval<E: FieldElement>(query: &[E], i: usize) -> E {
|
||||
query
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(j, x_j)| if i & (1 << j) == 0 { E::ONE - *x_j } else { *x_j })
|
||||
.fold(E::ONE, |acc, v| acc * v)
|
||||
}
|
||||
|
||||
pub fn compute_claim<E: FieldElement>(poly: &ComposedMultiLinears<E>) -> E {
|
||||
let cube_size = 1 << poly.num_variables_ml();
|
||||
let mut res = E::ZERO;
|
||||
|
||||
for i in 0..cube_size {
|
||||
let eval_point: Vec<E> =
|
||||
poly.multi_linears.iter().map(|poly| poly.evaluations[i]).collect();
|
||||
res += poly.composer.evaluate(&eval_point);
|
||||
}
|
||||
res
|
||||
}
|
||||
108
src/gkr/sumcheck/mod.rs
Normal file
108
src/gkr/sumcheck/mod.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
use super::{
|
||||
multivariate::{ComposedMultiLinears, ComposedMultiLinearsOracle},
|
||||
utils::{barycentric_weights, evaluate_barycentric},
|
||||
};
|
||||
use winter_math::FieldElement;
|
||||
|
||||
mod prover;
|
||||
pub use prover::sum_check_prove;
|
||||
mod verifier;
|
||||
pub use verifier::{sum_check_verify, sum_check_verify_and_reduce};
|
||||
mod tests;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RoundProof<E> {
|
||||
pub poly_evals: Vec<E>,
|
||||
}
|
||||
|
||||
impl<E: FieldElement> RoundProof<E> {
|
||||
pub fn to_evals(&self, claim: E) -> Vec<E> {
|
||||
let mut result = vec![];
|
||||
|
||||
// s(0) + s(1) = claim
|
||||
let c0 = claim - self.poly_evals[0];
|
||||
|
||||
result.push(c0);
|
||||
result.extend_from_slice(&self.poly_evals);
|
||||
result
|
||||
}
|
||||
|
||||
// TODO: refactor once we move to coefficient form
|
||||
pub(crate) fn evaluate(&self, claim: E, r: E) -> E {
|
||||
let poly_evals = self.to_evals(claim);
|
||||
|
||||
let points: Vec<E> = (0..poly_evals.len()).map(|i| E::from(i as u8)).collect();
|
||||
let evalss: Vec<(E, E)> =
|
||||
points.iter().zip(poly_evals.iter()).map(|(x, y)| (*x, *y)).collect();
|
||||
let weights = barycentric_weights(&evalss);
|
||||
let new_claim = evaluate_barycentric(&evalss, r, &weights);
|
||||
new_claim
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PartialProof<E> {
|
||||
pub round_proofs: Vec<RoundProof<E>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FinalEvaluationClaim<E: FieldElement> {
|
||||
pub evaluation_point: Vec<E>,
|
||||
pub claimed_evaluation: E,
|
||||
pub polynomial: ComposedMultiLinearsOracle<E>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FullProof<E: FieldElement> {
|
||||
pub sum_check_proof: PartialProof<E>,
|
||||
pub final_evaluation_claim: FinalEvaluationClaim<E>,
|
||||
}
|
||||
|
||||
pub struct Claim<E: FieldElement> {
|
||||
pub sum_value: E,
|
||||
pub polynomial: ComposedMultiLinearsOracle<E>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RoundClaim<E: FieldElement> {
|
||||
pub partial_eval_point: Vec<E>,
|
||||
pub current_claim: E,
|
||||
}
|
||||
|
||||
pub struct RoundOutput<E: FieldElement> {
|
||||
proof: PartialProof<E>,
|
||||
witness: Witness<E>,
|
||||
}
|
||||
|
||||
impl<E: FieldElement> From<Claim<E>> for RoundClaim<E> {
|
||||
fn from(value: Claim<E>) -> Self {
|
||||
Self {
|
||||
partial_eval_point: vec![],
|
||||
current_claim: value.sum_value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Witness<E: FieldElement> {
|
||||
pub(crate) polynomial: ComposedMultiLinears<E>,
|
||||
}
|
||||
|
||||
pub fn reduce_claim<E: FieldElement>(
|
||||
current_poly: RoundProof<E>,
|
||||
current_round_claim: RoundClaim<E>,
|
||||
round_challenge: E,
|
||||
) -> RoundClaim<E> {
|
||||
let poly_evals = current_poly.to_evals(current_round_claim.current_claim);
|
||||
let points: Vec<E> = (0..poly_evals.len()).map(|i| E::from(i as u8)).collect();
|
||||
let evalss: Vec<(E, E)> = points.iter().zip(poly_evals.iter()).map(|(x, y)| (*x, *y)).collect();
|
||||
let weights = barycentric_weights(&evalss);
|
||||
let new_claim = evaluate_barycentric(&evalss, round_challenge, &weights);
|
||||
|
||||
let mut new_partial_eval_point = current_round_claim.partial_eval_point;
|
||||
new_partial_eval_point.push(round_challenge);
|
||||
|
||||
RoundClaim {
|
||||
partial_eval_point: new_partial_eval_point,
|
||||
current_claim: new_claim,
|
||||
}
|
||||
}
|
||||
109
src/gkr/sumcheck/prover.rs
Normal file
109
src/gkr/sumcheck/prover.rs
Normal file
@@ -0,0 +1,109 @@
|
||||
use super::{Claim, FullProof, RoundProof, Witness};
|
||||
use crate::gkr::{
|
||||
multivariate::{ComposedMultiLinears, ComposedMultiLinearsOracle},
|
||||
sumcheck::{reduce_claim, FinalEvaluationClaim, PartialProof, RoundClaim, RoundOutput},
|
||||
};
|
||||
use rayon::iter::{IntoParallelIterator, ParallelIterator};
|
||||
use winter_crypto::{ElementHasher, RandomCoin};
|
||||
use winter_math::{fields::f64::BaseElement, FieldElement};
|
||||
|
||||
pub fn sum_check_prove<
|
||||
E: FieldElement<BaseField = BaseElement>,
|
||||
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||
H: ElementHasher<BaseField = BaseElement>,
|
||||
>(
|
||||
claim: &Claim<E>,
|
||||
oracle: ComposedMultiLinearsOracle<E>,
|
||||
witness: Witness<E>,
|
||||
coin: &mut C,
|
||||
) -> FullProof<E> {
|
||||
// Setup first round
|
||||
let mut prev_claim = RoundClaim {
|
||||
partial_eval_point: vec![],
|
||||
current_claim: claim.sum_value.clone(),
|
||||
};
|
||||
let prev_proof = PartialProof { round_proofs: vec![] };
|
||||
let num_vars = witness.polynomial.num_variables_ml();
|
||||
let prev_output = RoundOutput { proof: prev_proof, witness };
|
||||
|
||||
let mut output = sumcheck_round(prev_output);
|
||||
let poly_evals = &output.proof.round_proofs[0].poly_evals;
|
||||
coin.reseed(H::hash_elements(&poly_evals));
|
||||
|
||||
for i in 1..num_vars {
|
||||
let round_challenge = coin.draw().unwrap();
|
||||
let new_claim = reduce_claim(
|
||||
output.proof.round_proofs.last().unwrap().clone(),
|
||||
prev_claim,
|
||||
round_challenge,
|
||||
);
|
||||
output.witness.polynomial = output.witness.polynomial.bind(round_challenge);
|
||||
|
||||
output = sumcheck_round(output);
|
||||
prev_claim = new_claim;
|
||||
|
||||
let poly_evals = &output.proof.round_proofs[i].poly_evals;
|
||||
coin.reseed(H::hash_elements(&poly_evals));
|
||||
}
|
||||
|
||||
let round_challenge = coin.draw().unwrap();
|
||||
let RoundClaim { partial_eval_point, current_claim } = reduce_claim(
|
||||
output.proof.round_proofs.last().unwrap().clone(),
|
||||
prev_claim,
|
||||
round_challenge,
|
||||
);
|
||||
let final_eval_claim = FinalEvaluationClaim {
|
||||
evaluation_point: partial_eval_point,
|
||||
claimed_evaluation: current_claim,
|
||||
polynomial: oracle,
|
||||
};
|
||||
|
||||
FullProof {
|
||||
sum_check_proof: output.proof,
|
||||
final_evaluation_claim: final_eval_claim,
|
||||
}
|
||||
}
|
||||
|
||||
fn sumcheck_round<E: FieldElement>(prev_proof: RoundOutput<E>) -> RoundOutput<E> {
|
||||
let RoundOutput { mut proof, witness } = prev_proof;
|
||||
|
||||
let polynomial = witness.polynomial;
|
||||
let num_ml = polynomial.num_ml();
|
||||
let num_vars = polynomial.num_variables_ml();
|
||||
let num_rounds = num_vars - 1;
|
||||
|
||||
let mut evals_zero = vec![E::ZERO; num_ml];
|
||||
let mut evals_one = vec![E::ZERO; num_ml];
|
||||
let mut deltas = vec![E::ZERO; num_ml];
|
||||
let mut evals_x = vec![E::ZERO; num_ml];
|
||||
|
||||
let total_evals = (0..1 << num_rounds).into_iter().map(|i| {
|
||||
for (j, ml) in polynomial.multi_linears.iter().enumerate() {
|
||||
evals_zero[j] = ml.evaluations[(i << 1) as usize];
|
||||
evals_one[j] = ml.evaluations[(i << 1) + 1];
|
||||
}
|
||||
let mut total_evals = vec![E::ZERO; polynomial.degree()];
|
||||
total_evals[0] = polynomial.composer.evaluate(&evals_one);
|
||||
evals_zero
|
||||
.iter()
|
||||
.zip(evals_one.iter().zip(deltas.iter_mut().zip(evals_x.iter_mut())))
|
||||
.for_each(|(a0, (a1, (delta, evx)))| {
|
||||
*delta = *a1 - *a0;
|
||||
*evx = *a1;
|
||||
});
|
||||
total_evals.iter_mut().skip(1).for_each(|e| {
|
||||
evals_x.iter_mut().zip(deltas.iter()).for_each(|(evx, delta)| {
|
||||
*evx += *delta;
|
||||
});
|
||||
*e = polynomial.composer.evaluate(&evals_x);
|
||||
});
|
||||
total_evals
|
||||
});
|
||||
let evaluations = total_evals.fold(vec![E::ZERO; polynomial.degree()], |mut acc, evals| {
|
||||
acc.iter_mut().zip(evals.iter()).for_each(|(a, ev)| *a += *ev);
|
||||
acc
|
||||
});
|
||||
let proof_update = RoundProof { poly_evals: evaluations };
|
||||
proof.round_proofs.push(proof_update);
|
||||
RoundOutput { proof, witness: Witness { polynomial } }
|
||||
}
|
||||
199
src/gkr/sumcheck/tests.rs
Normal file
199
src/gkr/sumcheck/tests.rs
Normal file
@@ -0,0 +1,199 @@
|
||||
use alloc::sync::Arc;
|
||||
use rand::{distributions::Uniform, SeedableRng};
|
||||
use winter_crypto::RandomCoin;
|
||||
use winter_math::{fields::f64::BaseElement, FieldElement};
|
||||
|
||||
use crate::{
|
||||
gkr::{
|
||||
circuit::{CircuitProof, FractionalSumCircuit},
|
||||
multivariate::{
|
||||
compute_claim, gkr_composition_from_composition_polys, ComposedMultiLinears,
|
||||
ComposedMultiLinearsOracle, CompositionPolynomial, EqPolynomial, GkrComposition,
|
||||
GkrCompositionVanilla, LogUpDenominatorTableComposition,
|
||||
LogUpDenominatorWitnessComposition, MultiLinear, MultiLinearOracle,
|
||||
ProjectionComposition, SumComposition,
|
||||
},
|
||||
sumcheck::{
|
||||
prover::sum_check_prove, verifier::sum_check_verify, Claim, FinalEvaluationClaim,
|
||||
FullProof, Witness,
|
||||
},
|
||||
},
|
||||
hash::rpo::Rpo256,
|
||||
rand::RpoRandomCoin,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn gkr_workflow() {
|
||||
// generate the data witness for the LogUp argument
|
||||
let mut mls = generate_logup_witness::<BaseElement>(3);
|
||||
|
||||
// the is sampled after receiving the main trace commitment
|
||||
let alpha = rand_utils::rand_value();
|
||||
|
||||
// the composition polynomials defining the numerators/denominators
|
||||
let composition_polys: Vec<Vec<Arc<dyn CompositionPolynomial<BaseElement>>>> = vec![
|
||||
// left num
|
||||
vec![Arc::new(ProjectionComposition::new(0))],
|
||||
// right num
|
||||
vec![Arc::new(ProjectionComposition::new(1))],
|
||||
// left den
|
||||
vec![Arc::new(LogUpDenominatorTableComposition::new(2, alpha))],
|
||||
// right den
|
||||
vec![Arc::new(LogUpDenominatorWitnessComposition::new(3, alpha))],
|
||||
];
|
||||
|
||||
// run the GKR prover to obtain:
|
||||
// 1. The fractional sum circuit output.
|
||||
// 2. GKR proofs up to the last circuit layer counting backwards.
|
||||
// 3. GKR proof (i.e., a sum-check proof) for the last circuit layer counting backwards.
|
||||
let seed = [BaseElement::ZERO; 4];
|
||||
let mut transcript = RpoRandomCoin::new(seed.into());
|
||||
let (circuit_outputs, gkr_before_last_proof, final_layer_proof) =
|
||||
CircuitProof::prove_virtual_bus(composition_polys.clone(), &mut mls, &mut transcript);
|
||||
|
||||
let seed = [BaseElement::ZERO; 4];
|
||||
let mut transcript = RpoRandomCoin::new(seed.into());
|
||||
|
||||
// run the GKR verifier to obtain:
|
||||
// 1. A final evaluation claim.
|
||||
// 2. Randomness defining the Lagrange kernel in the final sum-check protocol. Note that this
|
||||
// Lagrange kernel is different from the one used by the STARK (outer) prover to open the MLs
|
||||
// at the evaluation point.
|
||||
let (final_eval_claim, gkr_lagrange_kernel_rand) = gkr_before_last_proof.verify_virtual_bus(
|
||||
composition_polys.clone(),
|
||||
final_layer_proof,
|
||||
&circuit_outputs,
|
||||
&mut transcript,
|
||||
);
|
||||
|
||||
// the final verification step is composed of:
|
||||
// 1. Querying the oracles for the openings at the evaluation point. This will be done by the
|
||||
// (outer) STARK prover using:
|
||||
// a. The Lagrange kernel (auxiliary) column at the evaluation point.
|
||||
// b. An extra (auxiliary) column to compute an inner product between two vectors. The first
|
||||
// being the Lagrange kernel and the second being (\sum_{j=0}^3 mls[j][i] * \lambda_i)_{i\in\{0,..,n\}}
|
||||
// 2. Evaluating the composition polynomial at the previous openings and checking equality with
|
||||
// the claimed evaluation.
|
||||
|
||||
// 1. Querying the oracles
|
||||
|
||||
let FinalEvaluationClaim {
|
||||
evaluation_point,
|
||||
claimed_evaluation,
|
||||
polynomial,
|
||||
} = final_eval_claim;
|
||||
|
||||
// The evaluation of the EQ polynomial can be done by the verifier directly
|
||||
let eq = (0..gkr_lagrange_kernel_rand.len())
|
||||
.map(|i| {
|
||||
gkr_lagrange_kernel_rand[i] * evaluation_point[i]
|
||||
+ (BaseElement::ONE - gkr_lagrange_kernel_rand[i])
|
||||
* (BaseElement::ONE - evaluation_point[i])
|
||||
})
|
||||
.fold(BaseElement::ONE, |acc, term| acc * term);
|
||||
|
||||
// These are the queries to the oracles.
|
||||
// They should be provided by the prover non-deterministically
|
||||
let left_num_eval = mls[0].evaluate(&evaluation_point);
|
||||
let right_num_eval = mls[1].evaluate(&evaluation_point);
|
||||
let left_den_eval = mls[2].evaluate(&evaluation_point);
|
||||
let right_den_eval = mls[3].evaluate(&evaluation_point);
|
||||
|
||||
// The verifier absorbs the claimed openings and generates batching randomness lambda
|
||||
let mut query = vec![left_num_eval, right_num_eval, left_den_eval, right_den_eval];
|
||||
transcript.reseed(Rpo256::hash_elements(&query));
|
||||
let lambdas: Vec<BaseElement> = vec![
|
||||
transcript.draw().unwrap(),
|
||||
transcript.draw().unwrap(),
|
||||
transcript.draw().unwrap(),
|
||||
];
|
||||
let batched_query =
|
||||
query[0] + query[1] * lambdas[0] + query[2] * lambdas[1] + query[3] * lambdas[2];
|
||||
|
||||
// The prover generates the Lagrange kernel as an auxiliary column
|
||||
let mut rev_evaluation_point = evaluation_point;
|
||||
rev_evaluation_point.reverse();
|
||||
let lagrange_kernel = EqPolynomial::new(rev_evaluation_point).evaluations();
|
||||
|
||||
// The prover generates the additional auxiliary column for the inner product
|
||||
let tmp_col: Vec<BaseElement> = (0..mls[0].len())
|
||||
.map(|i| {
|
||||
mls[0][i] + mls[1][i] * lambdas[0] + mls[2][i] * lambdas[1] + mls[3][i] * lambdas[2]
|
||||
})
|
||||
.collect();
|
||||
let mut running_sum_col = vec![BaseElement::ZERO; tmp_col.len() + 1];
|
||||
running_sum_col[0] = BaseElement::ZERO;
|
||||
for i in 1..(tmp_col.len() + 1) {
|
||||
running_sum_col[i] = running_sum_col[i - 1] + tmp_col[i - 1] * lagrange_kernel[i - 1];
|
||||
}
|
||||
|
||||
// Boundary constraint to check correctness of openings
|
||||
assert_eq!(batched_query, *running_sum_col.last().unwrap());
|
||||
|
||||
// 2) Final evaluation and check
|
||||
query.push(eq);
|
||||
let verifier_computed = polynomial.composer.evaluate(&query);
|
||||
|
||||
assert_eq!(verifier_computed, claimed_evaluation);
|
||||
}
|
||||
|
||||
pub fn generate_logup_witness<E: FieldElement>(trace_len: usize) -> Vec<MultiLinear<E>> {
|
||||
let num_variables_ml = trace_len;
|
||||
let num_evaluations = 1 << num_variables_ml;
|
||||
let num_witnesses = 1;
|
||||
let (p, q) = generate_logup_data::<E>(num_variables_ml, num_witnesses);
|
||||
let numerators: Vec<Vec<E>> = p.chunks(num_evaluations).map(|x| x.into()).collect();
|
||||
let denominators: Vec<Vec<E>> = q.chunks(num_evaluations).map(|x| x.into()).collect();
|
||||
|
||||
let mut mls = vec![];
|
||||
for i in 0..2 {
|
||||
let ml = MultiLinear::from_values(&numerators[i]);
|
||||
mls.push(ml);
|
||||
}
|
||||
for i in 0..2 {
|
||||
let ml = MultiLinear::from_values(&denominators[i]);
|
||||
mls.push(ml);
|
||||
}
|
||||
mls
|
||||
}
|
||||
|
||||
pub fn generate_logup_data<E: FieldElement>(
|
||||
trace_len: usize,
|
||||
num_witnesses: usize,
|
||||
) -> (Vec<E>, Vec<E>) {
|
||||
use rand::distributions::Slice;
|
||||
use rand::Rng;
|
||||
let n: usize = trace_len;
|
||||
let num_w: usize = num_witnesses; // This should be of the form 2^k - 1
|
||||
let rng = rand::rngs::StdRng::seed_from_u64(0);
|
||||
|
||||
let t_table: Vec<u32> = (0..(1 << n)).collect();
|
||||
let mut m_table: Vec<u32> = (0..(1 << n)).map(|_| 0).collect();
|
||||
|
||||
let t_table_slice = Slice::new(&t_table).unwrap();
|
||||
|
||||
// Construct the witness columns. Uses sampling with replacement in order to have multiplicities
|
||||
// different from 1.
|
||||
let mut w_tables = Vec::new();
|
||||
for _ in 0..num_w {
|
||||
let wi_table: Vec<u32> =
|
||||
rng.clone().sample_iter(&t_table_slice).cloned().take(1 << n).collect();
|
||||
|
||||
// Construct the multiplicities
|
||||
wi_table.iter().for_each(|w| {
|
||||
m_table[*w as usize] += 1;
|
||||
});
|
||||
w_tables.push(wi_table)
|
||||
}
|
||||
|
||||
// The numerators
|
||||
let mut p: Vec<E> = m_table.iter().map(|m| E::from(*m as u32)).collect();
|
||||
p.extend((0..(num_w * (1 << n))).map(|_| E::from(1_u32)).collect::<Vec<E>>());
|
||||
|
||||
// Construct the denominators
|
||||
let mut q: Vec<E> = t_table.iter().map(|t| E::from(*t)).collect();
|
||||
for w_table in w_tables {
|
||||
q.extend(w_table.iter().map(|w| E::from(*w)).collect::<Vec<E>>());
|
||||
}
|
||||
(p, q)
|
||||
}
|
||||
71
src/gkr/sumcheck/verifier.rs
Normal file
71
src/gkr/sumcheck/verifier.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
use winter_crypto::{ElementHasher, RandomCoin};
|
||||
use winter_math::{fields::f64::BaseElement, FieldElement};
|
||||
|
||||
use crate::gkr::utils::{barycentric_weights, evaluate_barycentric};
|
||||
|
||||
use super::{Claim, FinalEvaluationClaim, FullProof, PartialProof};
|
||||
|
||||
pub fn sum_check_verify_and_reduce<
|
||||
E: FieldElement<BaseField = BaseElement>,
|
||||
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||
H: ElementHasher<BaseField = BaseElement>,
|
||||
>(
|
||||
claim: &Claim<E>,
|
||||
proofs: PartialProof<E>,
|
||||
coin: &mut C,
|
||||
) -> (E, Vec<E>) {
|
||||
let degree = 3;
|
||||
let points: Vec<E> = (0..degree + 1).map(|x| E::from(x as u8)).collect();
|
||||
let mut sum_value = claim.sum_value.clone();
|
||||
let mut randomness = vec![];
|
||||
|
||||
for proof in proofs.round_proofs {
|
||||
let partial_evals = proof.poly_evals.clone();
|
||||
coin.reseed(H::hash_elements(&partial_evals));
|
||||
|
||||
// get r
|
||||
let r: E = coin.draw().unwrap();
|
||||
randomness.push(r);
|
||||
let evals = proof.to_evals(sum_value);
|
||||
|
||||
let point_evals: Vec<_> = points.iter().zip(evals.iter()).map(|(x, y)| (*x, *y)).collect();
|
||||
let weights = barycentric_weights(&point_evals);
|
||||
sum_value = evaluate_barycentric(&point_evals, r, &weights);
|
||||
}
|
||||
(sum_value, randomness)
|
||||
}
|
||||
|
||||
pub fn sum_check_verify<
|
||||
E: FieldElement<BaseField = BaseElement>,
|
||||
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||
H: ElementHasher<BaseField = BaseElement>,
|
||||
>(
|
||||
claim: &Claim<E>,
|
||||
proofs: FullProof<E>,
|
||||
coin: &mut C,
|
||||
) -> FinalEvaluationClaim<E> {
|
||||
let FullProof {
|
||||
sum_check_proof: proofs,
|
||||
final_evaluation_claim,
|
||||
} = proofs;
|
||||
let Claim { mut sum_value, polynomial } = claim;
|
||||
let degree = polynomial.composer.max_degree();
|
||||
let points: Vec<E> = (0..degree + 1).map(|x| E::from(x as u8)).collect();
|
||||
|
||||
for proof in proofs.round_proofs {
|
||||
let partial_evals = proof.poly_evals.clone();
|
||||
coin.reseed(H::hash_elements(&partial_evals));
|
||||
|
||||
// get r
|
||||
let r: E = coin.draw().unwrap();
|
||||
let evals = proof.to_evals(sum_value);
|
||||
|
||||
let point_evals: Vec<_> = points.iter().zip(evals.iter()).map(|(x, y)| (*x, *y)).collect();
|
||||
let weights = barycentric_weights(&point_evals);
|
||||
sum_value = evaluate_barycentric(&point_evals, r, &weights);
|
||||
}
|
||||
|
||||
assert_eq!(final_evaluation_claim.claimed_evaluation, sum_value);
|
||||
|
||||
final_evaluation_claim
|
||||
}
|
||||
33
src/gkr/utils/mod.rs
Normal file
33
src/gkr/utils/mod.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
use winter_math::{FieldElement, batch_inversion};
|
||||
|
||||
|
||||
pub fn barycentric_weights<E: FieldElement>(points: &[(E, E)]) -> Vec<E> {
|
||||
let n = points.len();
|
||||
let tmp = (0..n)
|
||||
.map(|i| (0..n).filter(|&j| j != i).fold(E::ONE, |acc, j| acc * (points[i].0 - points[j].0)))
|
||||
.collect::<Vec<_>>();
|
||||
batch_inversion(&tmp)
|
||||
}
|
||||
|
||||
pub fn evaluate_barycentric<E: FieldElement>(
|
||||
points: &[(E, E)],
|
||||
x: E,
|
||||
barycentric_weights: &[E],
|
||||
) -> E {
|
||||
for &(x_i, y_i) in points {
|
||||
if x_i == x {
|
||||
return y_i;
|
||||
}
|
||||
}
|
||||
|
||||
let l_x: E = points.iter().fold(E::ONE, |acc, &(x_i, _y_i)| acc * (x - x_i));
|
||||
|
||||
let sum = (0..points.len()).fold(E::ZERO, |acc, i| {
|
||||
let x_i = points[i].0;
|
||||
let y_i = points[i].1;
|
||||
let w_i = barycentric_weights[i];
|
||||
acc + (w_i / (x - x_i) * y_i)
|
||||
});
|
||||
|
||||
l_x * sum
|
||||
}
|
||||
@@ -1,9 +1,17 @@
|
||||
//! Cryptographic hash functions used by the Miden VM and the Miden rollup.
|
||||
|
||||
use super::{Felt, FieldElement, StarkField, ONE, ZERO};
|
||||
use super::{CubeExtension, Felt, FieldElement, StarkField, ONE, ZERO};
|
||||
|
||||
pub mod blake;
|
||||
pub mod rpo;
|
||||
|
||||
mod rescue;
|
||||
pub mod rpo {
|
||||
pub use super::rescue::{Rpo256, RpoDigest};
|
||||
}
|
||||
|
||||
pub mod rpx {
|
||||
pub use super::rescue::{Rpx256, RpxDigest};
|
||||
}
|
||||
|
||||
// RE-EXPORTS
|
||||
// ================================================================================================
|
||||
|
||||
101
src/hash/rescue/arch/mod.rs
Normal file
101
src/hash/rescue/arch/mod.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
#[cfg(all(target_feature = "sve", feature = "sve"))]
|
||||
pub mod optimized {
|
||||
use crate::hash::rescue::STATE_WIDTH;
|
||||
use crate::Felt;
|
||||
|
||||
mod ffi {
|
||||
#[link(name = "rpo_sve", kind = "static")]
|
||||
extern "C" {
|
||||
pub fn add_constants_and_apply_sbox(
|
||||
state: *mut std::ffi::c_ulong,
|
||||
constants: *const std::ffi::c_ulong,
|
||||
) -> bool;
|
||||
pub fn add_constants_and_apply_inv_sbox(
|
||||
state: *mut std::ffi::c_ulong,
|
||||
constants: *const std::ffi::c_ulong,
|
||||
) -> bool;
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn add_constants_and_apply_sbox(
|
||||
state: &mut [Felt; STATE_WIDTH],
|
||||
ark: &[Felt; STATE_WIDTH],
|
||||
) -> bool {
|
||||
unsafe {
|
||||
ffi::add_constants_and_apply_sbox(
|
||||
state.as_mut_ptr() as *mut u64,
|
||||
ark.as_ptr() as *const u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn add_constants_and_apply_inv_sbox(
|
||||
state: &mut [Felt; STATE_WIDTH],
|
||||
ark: &[Felt; STATE_WIDTH],
|
||||
) -> bool {
|
||||
unsafe {
|
||||
ffi::add_constants_and_apply_inv_sbox(
|
||||
state.as_mut_ptr() as *mut u64,
|
||||
ark.as_ptr() as *const u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_feature = "avx2")]
|
||||
mod x86_64_avx2;
|
||||
|
||||
#[cfg(target_feature = "avx2")]
|
||||
pub mod optimized {
|
||||
use super::x86_64_avx2::{apply_inv_sbox, apply_sbox};
|
||||
use crate::hash::rescue::{add_constants, STATE_WIDTH};
|
||||
use crate::Felt;
|
||||
|
||||
#[inline(always)]
|
||||
pub fn add_constants_and_apply_sbox(
|
||||
state: &mut [Felt; STATE_WIDTH],
|
||||
ark: &[Felt; STATE_WIDTH],
|
||||
) -> bool {
|
||||
add_constants(state, ark);
|
||||
unsafe {
|
||||
apply_sbox(std::mem::transmute(state));
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn add_constants_and_apply_inv_sbox(
|
||||
state: &mut [Felt; STATE_WIDTH],
|
||||
ark: &[Felt; STATE_WIDTH],
|
||||
) -> bool {
|
||||
add_constants(state, ark);
|
||||
unsafe {
|
||||
apply_inv_sbox(std::mem::transmute(state));
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(any(target_feature = "avx2", all(target_feature = "sve", feature = "sve"))))]
|
||||
pub mod optimized {
|
||||
use crate::hash::rescue::STATE_WIDTH;
|
||||
use crate::Felt;
|
||||
|
||||
#[inline(always)]
|
||||
pub fn add_constants_and_apply_sbox(
|
||||
_state: &mut [Felt; STATE_WIDTH],
|
||||
_ark: &[Felt; STATE_WIDTH],
|
||||
) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn add_constants_and_apply_inv_sbox(
|
||||
_state: &mut [Felt; STATE_WIDTH],
|
||||
_ark: &[Felt; STATE_WIDTH],
|
||||
) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
325
src/hash/rescue/arch/x86_64_avx2.rs
Normal file
325
src/hash/rescue/arch/x86_64_avx2.rs
Normal file
@@ -0,0 +1,325 @@
|
||||
use core::arch::x86_64::*;
|
||||
|
||||
// The following AVX2 implementation has been copied from plonky2:
|
||||
// https://github.com/0xPolygonZero/plonky2/blob/main/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs
|
||||
|
||||
// Preliminary notes:
|
||||
// 1. AVX does not support addition with carry but 128-bit (2-word) addition can be easily
|
||||
// emulated. The method recognizes that for a + b overflowed iff (a + b) < a:
|
||||
// i. res_lo = a_lo + b_lo
|
||||
// ii. carry_mask = res_lo < a_lo
|
||||
// iii. res_hi = a_hi + b_hi - carry_mask
|
||||
// Notice that carry_mask is subtracted, not added. This is because AVX comparison instructions
|
||||
// return -1 (all bits 1) for true and 0 for false.
|
||||
//
|
||||
// 2. AVX does not have unsigned 64-bit comparisons. Those can be emulated with signed comparisons
|
||||
// by recognizing that a <u b iff a + (1 << 63) <s b + (1 << 63), where the addition wraps around
|
||||
// and the comparisons are unsigned and signed respectively. The shift function adds/subtracts
|
||||
// 1 << 63 to enable this trick.
|
||||
// Example: addition with carry.
|
||||
// i. a_lo_s = shift(a_lo)
|
||||
// ii. res_lo_s = a_lo_s + b_lo
|
||||
// iii. carry_mask = res_lo_s <s a_lo_s
|
||||
// iv. res_lo = shift(res_lo_s)
|
||||
// v. res_hi = a_hi + b_hi - carry_mask
|
||||
// The suffix _s denotes a value that has been shifted by 1 << 63. The result of addition is
|
||||
// shifted if exactly one of the operands is shifted, as is the case on line ii. Line iii.
|
||||
// performs a signed comparison res_lo_s <s a_lo_s on shifted values to emulate unsigned
|
||||
// comparison res_lo <u a_lo on unshifted values. Finally, line iv. reverses the shift so the
|
||||
// result can be returned.
|
||||
// When performing a chain of calculations, we can often save instructions by letting the shift
|
||||
// propagate through and only undoing it when necessary. For example, to compute the addition of
|
||||
// three two-word (128-bit) numbers we can do:
|
||||
// i. a_lo_s = shift(a_lo)
|
||||
// ii. tmp_lo_s = a_lo_s + b_lo
|
||||
// iii. tmp_carry_mask = tmp_lo_s <s a_lo_s
|
||||
// iv. tmp_hi = a_hi + b_hi - tmp_carry_mask
|
||||
// v. res_lo_s = tmp_lo_s + c_lo
|
||||
// vi. res_carry_mask = res_lo_s <s tmp_lo_s
|
||||
// vii. res_lo = shift(res_lo_s)
|
||||
// viii. res_hi = tmp_hi + c_hi - res_carry_mask
|
||||
// Notice that the above 3-value addition still only requires two calls to shift, just like our
|
||||
// 2-value addition.
|
||||
|
||||
#[inline(always)]
|
||||
pub fn branch_hint() {
|
||||
// NOTE: These are the currently supported assembly architectures. See the
|
||||
// [nightly reference](https://doc.rust-lang.org/nightly/reference/inline-assembly.html) for
|
||||
// the most up-to-date list.
|
||||
#[cfg(any(
|
||||
target_arch = "aarch64",
|
||||
target_arch = "arm",
|
||||
target_arch = "riscv32",
|
||||
target_arch = "riscv64",
|
||||
target_arch = "x86",
|
||||
target_arch = "x86_64",
|
||||
))]
|
||||
unsafe {
|
||||
core::arch::asm!("", options(nomem, nostack, preserves_flags));
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! map3 {
|
||||
($f:ident::<$l:literal>, $v:ident) => {
|
||||
($f::<$l>($v.0), $f::<$l>($v.1), $f::<$l>($v.2))
|
||||
};
|
||||
($f:ident::<$l:literal>, $v1:ident, $v2:ident) => {
|
||||
($f::<$l>($v1.0, $v2.0), $f::<$l>($v1.1, $v2.1), $f::<$l>($v1.2, $v2.2))
|
||||
};
|
||||
($f:ident, $v:ident) => {
|
||||
($f($v.0), $f($v.1), $f($v.2))
|
||||
};
|
||||
($f:ident, $v0:ident, $v1:ident) => {
|
||||
($f($v0.0, $v1.0), $f($v0.1, $v1.1), $f($v0.2, $v1.2))
|
||||
};
|
||||
($f:ident, rep $v0:ident, $v1:ident) => {
|
||||
($f($v0, $v1.0), $f($v0, $v1.1), $f($v0, $v1.2))
|
||||
};
|
||||
|
||||
($f:ident, $v0:ident, rep $v1:ident) => {
|
||||
($f($v0.0, $v1), $f($v0.1, $v1), $f($v0.2, $v1))
|
||||
};
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn square3(
|
||||
x: (__m256i, __m256i, __m256i),
|
||||
) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
|
||||
let x_hi = {
|
||||
// Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster than
|
||||
// bitshift. This instruction only has a floating-point flavor, so we cast to/from float.
|
||||
// This is safe and free.
|
||||
let x_ps = map3!(_mm256_castsi256_ps, x);
|
||||
let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps);
|
||||
map3!(_mm256_castps_si256, x_hi_ps)
|
||||
};
|
||||
|
||||
// All pairwise multiplications.
|
||||
let mul_ll = map3!(_mm256_mul_epu32, x, x);
|
||||
let mul_lh = map3!(_mm256_mul_epu32, x, x_hi);
|
||||
let mul_hh = map3!(_mm256_mul_epu32, x_hi, x_hi);
|
||||
|
||||
// Bignum addition, but mul_lh is shifted by 33 bits (not 32).
|
||||
let mul_ll_hi = map3!(_mm256_srli_epi64::<33>, mul_ll);
|
||||
let t0 = map3!(_mm256_add_epi64, mul_lh, mul_ll_hi);
|
||||
let t0_hi = map3!(_mm256_srli_epi64::<31>, t0);
|
||||
let res_hi = map3!(_mm256_add_epi64, mul_hh, t0_hi);
|
||||
|
||||
// Form low result by adding the mul_ll and the low 31 bits of mul_lh (shifted to the high
|
||||
// position).
|
||||
let mul_lh_lo = map3!(_mm256_slli_epi64::<33>, mul_lh);
|
||||
let res_lo = map3!(_mm256_add_epi64, mul_ll, mul_lh_lo);
|
||||
|
||||
(res_lo, res_hi)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn mul3(
|
||||
x: (__m256i, __m256i, __m256i),
|
||||
y: (__m256i, __m256i, __m256i),
|
||||
) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
|
||||
let epsilon = _mm256_set1_epi64x(0xffffffff);
|
||||
let x_hi = {
|
||||
// Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster than
|
||||
// bitshift. This instruction only has a floating-point flavor, so we cast to/from float.
|
||||
// This is safe and free.
|
||||
let x_ps = map3!(_mm256_castsi256_ps, x);
|
||||
let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps);
|
||||
map3!(_mm256_castps_si256, x_hi_ps)
|
||||
};
|
||||
let y_hi = {
|
||||
let y_ps = map3!(_mm256_castsi256_ps, y);
|
||||
let y_hi_ps = map3!(_mm256_movehdup_ps, y_ps);
|
||||
map3!(_mm256_castps_si256, y_hi_ps)
|
||||
};
|
||||
|
||||
// All four pairwise multiplications
|
||||
let mul_ll = map3!(_mm256_mul_epu32, x, y);
|
||||
let mul_lh = map3!(_mm256_mul_epu32, x, y_hi);
|
||||
let mul_hl = map3!(_mm256_mul_epu32, x_hi, y);
|
||||
let mul_hh = map3!(_mm256_mul_epu32, x_hi, y_hi);
|
||||
|
||||
// Bignum addition
|
||||
// Extract high 32 bits of mul_ll and add to mul_hl. This cannot overflow.
|
||||
let mul_ll_hi = map3!(_mm256_srli_epi64::<32>, mul_ll);
|
||||
let t0 = map3!(_mm256_add_epi64, mul_hl, mul_ll_hi);
|
||||
// Extract low 32 bits of t0 and add to mul_lh. Again, this cannot overflow.
|
||||
// Also, extract high 32 bits of t0 and add to mul_hh.
|
||||
let t0_lo = map3!(_mm256_and_si256, t0, rep epsilon);
|
||||
let t0_hi = map3!(_mm256_srli_epi64::<32>, t0);
|
||||
let t1 = map3!(_mm256_add_epi64, mul_lh, t0_lo);
|
||||
let t2 = map3!(_mm256_add_epi64, mul_hh, t0_hi);
|
||||
// Lastly, extract the high 32 bits of t1 and add to t2.
|
||||
let t1_hi = map3!(_mm256_srli_epi64::<32>, t1);
|
||||
let res_hi = map3!(_mm256_add_epi64, t2, t1_hi);
|
||||
|
||||
// Form res_lo by combining the low half of mul_ll with the low half of t1 (shifted into high
|
||||
// position).
|
||||
let t1_lo = {
|
||||
let t1_ps = map3!(_mm256_castsi256_ps, t1);
|
||||
let t1_lo_ps = map3!(_mm256_moveldup_ps, t1_ps);
|
||||
map3!(_mm256_castps_si256, t1_lo_ps)
|
||||
};
|
||||
let res_lo = map3!(_mm256_blend_epi32::<0xaa>, mul_ll, t1_lo);
|
||||
|
||||
(res_lo, res_hi)
|
||||
}
|
||||
|
||||
/// Addition, where the second operand is `0 <= y < 0xffffffff00000001`.
|
||||
#[inline(always)]
|
||||
unsafe fn add_small(
|
||||
x_s: (__m256i, __m256i, __m256i),
|
||||
y: (__m256i, __m256i, __m256i),
|
||||
) -> (__m256i, __m256i, __m256i) {
|
||||
let res_wrapped_s = map3!(_mm256_add_epi64, x_s, y);
|
||||
let mask = map3!(_mm256_cmpgt_epi32, x_s, res_wrapped_s);
|
||||
let wrapback_amt = map3!(_mm256_srli_epi64::<32>, mask); // EPSILON if overflowed else 0.
|
||||
let res_s = map3!(_mm256_add_epi64, res_wrapped_s, wrapback_amt);
|
||||
res_s
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn maybe_adj_sub(res_wrapped_s: __m256i, mask: __m256i) -> __m256i {
|
||||
// The subtraction is very unlikely to overflow so we're best off branching.
|
||||
// The even u32s in `mask` are meaningless, so we want to ignore them. `_mm256_testz_pd`
|
||||
// branches depending on the sign bit of double-precision (64-bit) floats. Bit cast `mask` to
|
||||
// floating-point (this is free).
|
||||
let mask_pd = _mm256_castsi256_pd(mask);
|
||||
// `_mm256_testz_pd(mask_pd, mask_pd) == 1` iff all sign bits are 0, meaning that underflow
|
||||
// did not occur for any of the vector elements.
|
||||
if _mm256_testz_pd(mask_pd, mask_pd) == 1 {
|
||||
res_wrapped_s
|
||||
} else {
|
||||
branch_hint();
|
||||
// Highly unlikely: underflow did occur. Find adjustment per element and apply it.
|
||||
let adj_amount = _mm256_srli_epi64::<32>(mask); // EPSILON if underflow.
|
||||
_mm256_sub_epi64(res_wrapped_s, adj_amount)
|
||||
}
|
||||
}
|
||||
|
||||
/// Addition, where the second operand is much smaller than `0xffffffff00000001`.
|
||||
#[inline(always)]
|
||||
unsafe fn sub_tiny(
|
||||
x_s: (__m256i, __m256i, __m256i),
|
||||
y: (__m256i, __m256i, __m256i),
|
||||
) -> (__m256i, __m256i, __m256i) {
|
||||
let res_wrapped_s = map3!(_mm256_sub_epi64, x_s, y);
|
||||
let mask = map3!(_mm256_cmpgt_epi32, res_wrapped_s, x_s);
|
||||
let res_s = map3!(maybe_adj_sub, res_wrapped_s, mask);
|
||||
res_s
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn reduce3(
|
||||
(lo0, hi0): ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)),
|
||||
) -> (__m256i, __m256i, __m256i) {
|
||||
let sign_bit = _mm256_set1_epi64x(i64::MIN);
|
||||
let epsilon = _mm256_set1_epi64x(0xffffffff);
|
||||
let lo0_s = map3!(_mm256_xor_si256, lo0, rep sign_bit);
|
||||
let hi_hi0 = map3!(_mm256_srli_epi64::<32>, hi0);
|
||||
let lo1_s = sub_tiny(lo0_s, hi_hi0);
|
||||
let t1 = map3!(_mm256_mul_epu32, hi0, rep epsilon);
|
||||
let lo2_s = add_small(lo1_s, t1);
|
||||
let lo2 = map3!(_mm256_xor_si256, lo2_s, rep sign_bit);
|
||||
lo2
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn mul_reduce(
|
||||
a: (__m256i, __m256i, __m256i),
|
||||
b: (__m256i, __m256i, __m256i),
|
||||
) -> (__m256i, __m256i, __m256i) {
|
||||
reduce3(mul3(a, b))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn square_reduce(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
|
||||
reduce3(square3(state))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn exp_acc(
|
||||
high: (__m256i, __m256i, __m256i),
|
||||
low: (__m256i, __m256i, __m256i),
|
||||
exp: usize,
|
||||
) -> (__m256i, __m256i, __m256i) {
|
||||
let mut result = high;
|
||||
for _ in 0..exp {
|
||||
result = square_reduce(result);
|
||||
}
|
||||
mul_reduce(result, low)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn do_apply_sbox(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
|
||||
let state2 = square_reduce(state);
|
||||
let state4_unreduced = square3(state2);
|
||||
let state3_unreduced = mul3(state2, state);
|
||||
let state4 = reduce3(state4_unreduced);
|
||||
let state3 = reduce3(state3_unreduced);
|
||||
let state7_unreduced = mul3(state3, state4);
|
||||
let state7 = reduce3(state7_unreduced);
|
||||
state7
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn do_apply_inv_sbox(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
|
||||
// compute base^10540996611094048183 using 72 multiplications per array element
|
||||
// 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111
|
||||
|
||||
// compute base^10
|
||||
let t1 = square_reduce(state);
|
||||
|
||||
// compute base^100
|
||||
let t2 = square_reduce(t1);
|
||||
|
||||
// compute base^100100
|
||||
let t3 = exp_acc(t2, t2, 3);
|
||||
|
||||
// compute base^100100100100
|
||||
let t4 = exp_acc(t3, t3, 6);
|
||||
|
||||
// compute base^100100100100100100100100
|
||||
let t5 = exp_acc(t4, t4, 12);
|
||||
|
||||
// compute base^100100100100100100100100100100
|
||||
let t6 = exp_acc(t5, t3, 6);
|
||||
|
||||
// compute base^1001001001001001001001001001000100100100100100100100100100100
|
||||
let t7 = exp_acc(t6, t6, 31);
|
||||
|
||||
// compute base^1001001001001001001001001001000110110110110110110110110110110111
|
||||
let a = square_reduce(square_reduce(mul_reduce(square_reduce(t7), t6)));
|
||||
let b = mul_reduce(t1, mul_reduce(t2, state));
|
||||
mul_reduce(a, b)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn avx2_load(state: &[u64; 12]) -> (__m256i, __m256i, __m256i) {
|
||||
(
|
||||
_mm256_loadu_si256((&state[0..4]).as_ptr().cast::<__m256i>()),
|
||||
_mm256_loadu_si256((&state[4..8]).as_ptr().cast::<__m256i>()),
|
||||
_mm256_loadu_si256((&state[8..12]).as_ptr().cast::<__m256i>()),
|
||||
)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn avx2_store(buf: &mut [u64; 12], state: (__m256i, __m256i, __m256i)) {
|
||||
_mm256_storeu_si256((&mut buf[0..4]).as_mut_ptr().cast::<__m256i>(), state.0);
|
||||
_mm256_storeu_si256((&mut buf[4..8]).as_mut_ptr().cast::<__m256i>(), state.1);
|
||||
_mm256_storeu_si256((&mut buf[8..12]).as_mut_ptr().cast::<__m256i>(), state.2);
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub unsafe fn apply_sbox(buffer: &mut [u64; 12]) {
|
||||
let mut state = avx2_load(&buffer);
|
||||
state = do_apply_sbox(state);
|
||||
avx2_store(buffer, state);
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub unsafe fn apply_inv_sbox(buffer: &mut [u64; 12]) {
|
||||
let mut state = avx2_load(&buffer);
|
||||
state = do_apply_inv_sbox(state);
|
||||
avx2_store(buffer, state);
|
||||
}
|
||||
@@ -11,7 +11,8 @@
|
||||
/// divisions by 2 and repeated modular reductions. This is because of our explicit choice of
|
||||
/// an MDS matrix that has small powers of 2 entries in frequency domain.
|
||||
/// The following implementation has benefited greatly from the discussions and insights of
|
||||
/// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero.
|
||||
/// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero and is base on Nabaglo's Plonky2
|
||||
/// implementation.
|
||||
|
||||
// Rescue MDS matrix in frequency domain.
|
||||
// More precisely, this is the output of the three 4-point (real) FFTs of the first column of
|
||||
@@ -26,7 +27,7 @@ const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-8, 1, 1];
|
||||
|
||||
// We use split 3 x 4 FFT transform in order to transform our vectors into the frequency domain.
|
||||
#[inline(always)]
|
||||
pub(crate) const fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] {
|
||||
pub const fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] {
|
||||
let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] = state;
|
||||
|
||||
let (u0, u1, u2) = fft4_real([s0, s3, s6, s9]);
|
||||
@@ -156,7 +157,7 @@ const fn block3(x: [i64; 3], y: [i64; 3]) -> [i64; 3] {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::super::{Felt, Rpo256, MDS, ZERO};
|
||||
use super::super::{apply_mds, Felt, MDS, ZERO};
|
||||
use proptest::prelude::*;
|
||||
|
||||
const STATE_WIDTH: usize = 12;
|
||||
@@ -185,7 +186,7 @@ mod tests {
|
||||
v2 = v1;
|
||||
|
||||
apply_mds_naive(&mut v1);
|
||||
Rpo256::apply_mds(&mut v2);
|
||||
apply_mds(&mut v2);
|
||||
|
||||
prop_assert_eq!(v1, v2);
|
||||
}
|
||||
214
src/hash/rescue/mds/mod.rs
Normal file
214
src/hash/rescue/mds/mod.rs
Normal file
@@ -0,0 +1,214 @@
|
||||
use super::{Felt, STATE_WIDTH, ZERO};
|
||||
|
||||
mod freq;
|
||||
pub use freq::mds_multiply_freq;
|
||||
|
||||
// MDS MULTIPLICATION
|
||||
// ================================================================================================
|
||||
|
||||
#[inline(always)]
|
||||
pub fn apply_mds(state: &mut [Felt; STATE_WIDTH]) {
|
||||
let mut result = [ZERO; STATE_WIDTH];
|
||||
|
||||
// Using the linearity of the operations we can split the state into a low||high decomposition
|
||||
// and operate on each with no overflow and then combine/reduce the result to a field element.
|
||||
// The no overflow is guaranteed by the fact that the MDS matrix is a small powers of two in
|
||||
// frequency domain.
|
||||
let mut state_l = [0u64; STATE_WIDTH];
|
||||
let mut state_h = [0u64; STATE_WIDTH];
|
||||
|
||||
for r in 0..STATE_WIDTH {
|
||||
let s = state[r].inner();
|
||||
state_h[r] = s >> 32;
|
||||
state_l[r] = (s as u32) as u64;
|
||||
}
|
||||
|
||||
let state_h = mds_multiply_freq(state_h);
|
||||
let state_l = mds_multiply_freq(state_l);
|
||||
|
||||
for r in 0..STATE_WIDTH {
|
||||
let s = state_l[r] as u128 + ((state_h[r] as u128) << 32);
|
||||
let s_hi = (s >> 64) as u64;
|
||||
let s_lo = s as u64;
|
||||
let z = (s_hi << 32) - s_hi;
|
||||
let (res, over) = s_lo.overflowing_add(z);
|
||||
|
||||
result[r] = Felt::from_mont(res.wrapping_add(0u32.wrapping_sub(over as u32) as u64));
|
||||
}
|
||||
*state = result;
|
||||
}
|
||||
|
||||
// MDS MATRIX
|
||||
// ================================================================================================
|
||||
|
||||
/// RPO MDS matrix
|
||||
pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = [
|
||||
[
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
],
|
||||
[
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
],
|
||||
[
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
],
|
||||
[
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
],
|
||||
[
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
],
|
||||
[
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
],
|
||||
[
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
],
|
||||
[
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
],
|
||||
[
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
],
|
||||
[
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
],
|
||||
[
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
],
|
||||
[
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
],
|
||||
];
|
||||
348
src/hash/rescue/mod.rs
Normal file
348
src/hash/rescue/mod.rs
Normal file
@@ -0,0 +1,348 @@
|
||||
use super::{
|
||||
CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ONE, ZERO,
|
||||
};
|
||||
use core::ops::Range;
|
||||
|
||||
mod arch;
|
||||
pub use arch::optimized::{add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox};
|
||||
|
||||
mod mds;
|
||||
use mds::{apply_mds, MDS};
|
||||
|
||||
mod rpo;
|
||||
pub use rpo::{Rpo256, RpoDigest};
|
||||
|
||||
mod rpx;
|
||||
pub use rpx::{Rpx256, RpxDigest};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
// CONSTANTS
|
||||
// ================================================================================================
|
||||
|
||||
/// The number of rounds is set to 7. For the RPO hash functions all rounds are uniform. For the
|
||||
/// RPX hash function, there are 3 different types of rounds.
|
||||
const NUM_ROUNDS: usize = 7;
|
||||
|
||||
/// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
|
||||
/// the remaining 4 elements are reserved for capacity.
|
||||
const STATE_WIDTH: usize = 12;
|
||||
|
||||
/// The rate portion of the state is located in elements 4 through 11.
|
||||
const RATE_RANGE: Range<usize> = 4..12;
|
||||
const RATE_WIDTH: usize = RATE_RANGE.end - RATE_RANGE.start;
|
||||
|
||||
const INPUT1_RANGE: Range<usize> = 4..8;
|
||||
const INPUT2_RANGE: Range<usize> = 8..12;
|
||||
|
||||
/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
|
||||
const CAPACITY_RANGE: Range<usize> = 0..4;
|
||||
|
||||
/// The output of the hash function is a digest which consists of 4 field elements or 32 bytes.
|
||||
///
|
||||
/// The digest is returned from state elements 4, 5, 6, and 7 (the first four elements of the
|
||||
/// rate portion).
|
||||
const DIGEST_RANGE: Range<usize> = 4..8;
|
||||
const DIGEST_SIZE: usize = DIGEST_RANGE.end - DIGEST_RANGE.start;
|
||||
|
||||
/// The number of bytes needed to encoded a digest
|
||||
const DIGEST_BYTES: usize = 32;
|
||||
|
||||
/// The number of byte chunks defining a field element when hashing a sequence of bytes
|
||||
const BINARY_CHUNK_SIZE: usize = 7;
|
||||
|
||||
/// S-Box and Inverse S-Box powers;
|
||||
///
|
||||
/// The constants are defined for tests only because the exponentiations in the code are unrolled
|
||||
/// for efficiency reasons.
|
||||
#[cfg(test)]
|
||||
const ALPHA: u64 = 7;
|
||||
#[cfg(test)]
|
||||
const INV_ALPHA: u64 = 10540996611094048183;
|
||||
|
||||
// SBOX FUNCTION
|
||||
// ================================================================================================
|
||||
|
||||
#[inline(always)]
|
||||
fn apply_sbox(state: &mut [Felt; STATE_WIDTH]) {
|
||||
state[0] = state[0].exp7();
|
||||
state[1] = state[1].exp7();
|
||||
state[2] = state[2].exp7();
|
||||
state[3] = state[3].exp7();
|
||||
state[4] = state[4].exp7();
|
||||
state[5] = state[5].exp7();
|
||||
state[6] = state[6].exp7();
|
||||
state[7] = state[7].exp7();
|
||||
state[8] = state[8].exp7();
|
||||
state[9] = state[9].exp7();
|
||||
state[10] = state[10].exp7();
|
||||
state[11] = state[11].exp7();
|
||||
}
|
||||
|
||||
// INVERSE SBOX FUNCTION
|
||||
// ================================================================================================
|
||||
|
||||
#[inline(always)]
|
||||
fn apply_inv_sbox(state: &mut [Felt; STATE_WIDTH]) {
|
||||
// compute base^10540996611094048183 using 72 multiplications per array element
|
||||
// 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111
|
||||
|
||||
// compute base^10
|
||||
let mut t1 = *state;
|
||||
t1.iter_mut().for_each(|t| *t = t.square());
|
||||
|
||||
// compute base^100
|
||||
let mut t2 = t1;
|
||||
t2.iter_mut().for_each(|t| *t = t.square());
|
||||
|
||||
// compute base^100100
|
||||
let t3 = exp_acc::<Felt, STATE_WIDTH, 3>(t2, t2);
|
||||
|
||||
// compute base^100100100100
|
||||
let t4 = exp_acc::<Felt, STATE_WIDTH, 6>(t3, t3);
|
||||
|
||||
// compute base^100100100100100100100100
|
||||
let t5 = exp_acc::<Felt, STATE_WIDTH, 12>(t4, t4);
|
||||
|
||||
// compute base^100100100100100100100100100100
|
||||
let t6 = exp_acc::<Felt, STATE_WIDTH, 6>(t5, t3);
|
||||
|
||||
// compute base^1001001001001001001001001001000100100100100100100100100100100
|
||||
let t7 = exp_acc::<Felt, STATE_WIDTH, 31>(t6, t6);
|
||||
|
||||
// compute base^1001001001001001001001001001000110110110110110110110110110110111
|
||||
for (i, s) in state.iter_mut().enumerate() {
|
||||
let a = (t7[i].square() * t6[i]).square().square();
|
||||
let b = t1[i] * t2[i] * *s;
|
||||
*s = a * b;
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn exp_acc<B: StarkField, const N: usize, const M: usize>(
|
||||
base: [B; N],
|
||||
tail: [B; N],
|
||||
) -> [B; N] {
|
||||
let mut result = base;
|
||||
for _ in 0..M {
|
||||
result.iter_mut().for_each(|r| *r = r.square());
|
||||
}
|
||||
result.iter_mut().zip(tail).for_each(|(r, t)| *r *= t);
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn add_constants(state: &mut [Felt; STATE_WIDTH], ark: &[Felt; STATE_WIDTH]) {
|
||||
state.iter_mut().zip(ark).for_each(|(s, &k)| *s += k);
|
||||
}
|
||||
|
||||
// ROUND CONSTANTS
|
||||
// ================================================================================================
|
||||
|
||||
/// Rescue round constants;
|
||||
/// computed as in [specifications](https://github.com/ASDiscreteMathematics/rpo)
|
||||
///
|
||||
/// The constants are broken up into two arrays ARK1 and ARK2; ARK1 contains the constants for the
|
||||
/// first half of RPO round, and ARK2 contains constants for the second half of RPO round.
|
||||
const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
|
||||
[
|
||||
Felt::new(5789762306288267392),
|
||||
Felt::new(6522564764413701783),
|
||||
Felt::new(17809893479458208203),
|
||||
Felt::new(107145243989736508),
|
||||
Felt::new(6388978042437517382),
|
||||
Felt::new(15844067734406016715),
|
||||
Felt::new(9975000513555218239),
|
||||
Felt::new(3344984123768313364),
|
||||
Felt::new(9959189626657347191),
|
||||
Felt::new(12960773468763563665),
|
||||
Felt::new(9602914297752488475),
|
||||
Felt::new(16657542370200465908),
|
||||
],
|
||||
[
|
||||
Felt::new(12987190162843096997),
|
||||
Felt::new(653957632802705281),
|
||||
Felt::new(4441654670647621225),
|
||||
Felt::new(4038207883745915761),
|
||||
Felt::new(5613464648874830118),
|
||||
Felt::new(13222989726778338773),
|
||||
Felt::new(3037761201230264149),
|
||||
Felt::new(16683759727265180203),
|
||||
Felt::new(8337364536491240715),
|
||||
Felt::new(3227397518293416448),
|
||||
Felt::new(8110510111539674682),
|
||||
Felt::new(2872078294163232137),
|
||||
],
|
||||
[
|
||||
Felt::new(18072785500942327487),
|
||||
Felt::new(6200974112677013481),
|
||||
Felt::new(17682092219085884187),
|
||||
Felt::new(10599526828986756440),
|
||||
Felt::new(975003873302957338),
|
||||
Felt::new(8264241093196931281),
|
||||
Felt::new(10065763900435475170),
|
||||
Felt::new(2181131744534710197),
|
||||
Felt::new(6317303992309418647),
|
||||
Felt::new(1401440938888741532),
|
||||
Felt::new(8884468225181997494),
|
||||
Felt::new(13066900325715521532),
|
||||
],
|
||||
[
|
||||
Felt::new(5674685213610121970),
|
||||
Felt::new(5759084860419474071),
|
||||
Felt::new(13943282657648897737),
|
||||
Felt::new(1352748651966375394),
|
||||
Felt::new(17110913224029905221),
|
||||
Felt::new(1003883795902368422),
|
||||
Felt::new(4141870621881018291),
|
||||
Felt::new(8121410972417424656),
|
||||
Felt::new(14300518605864919529),
|
||||
Felt::new(13712227150607670181),
|
||||
Felt::new(17021852944633065291),
|
||||
Felt::new(6252096473787587650),
|
||||
],
|
||||
[
|
||||
Felt::new(4887609836208846458),
|
||||
Felt::new(3027115137917284492),
|
||||
Felt::new(9595098600469470675),
|
||||
Felt::new(10528569829048484079),
|
||||
Felt::new(7864689113198939815),
|
||||
Felt::new(17533723827845969040),
|
||||
Felt::new(5781638039037710951),
|
||||
Felt::new(17024078752430719006),
|
||||
Felt::new(109659393484013511),
|
||||
Felt::new(7158933660534805869),
|
||||
Felt::new(2955076958026921730),
|
||||
Felt::new(7433723648458773977),
|
||||
],
|
||||
[
|
||||
Felt::new(16308865189192447297),
|
||||
Felt::new(11977192855656444890),
|
||||
Felt::new(12532242556065780287),
|
||||
Felt::new(14594890931430968898),
|
||||
Felt::new(7291784239689209784),
|
||||
Felt::new(5514718540551361949),
|
||||
Felt::new(10025733853830934803),
|
||||
Felt::new(7293794580341021693),
|
||||
Felt::new(6728552937464861756),
|
||||
Felt::new(6332385040983343262),
|
||||
Felt::new(13277683694236792804),
|
||||
Felt::new(2600778905124452676),
|
||||
],
|
||||
[
|
||||
Felt::new(7123075680859040534),
|
||||
Felt::new(1034205548717903090),
|
||||
Felt::new(7717824418247931797),
|
||||
Felt::new(3019070937878604058),
|
||||
Felt::new(11403792746066867460),
|
||||
Felt::new(10280580802233112374),
|
||||
Felt::new(337153209462421218),
|
||||
Felt::new(13333398568519923717),
|
||||
Felt::new(3596153696935337464),
|
||||
Felt::new(8104208463525993784),
|
||||
Felt::new(14345062289456085693),
|
||||
Felt::new(17036731477169661256),
|
||||
],
|
||||
];
|
||||
|
||||
const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
|
||||
[
|
||||
Felt::new(6077062762357204287),
|
||||
Felt::new(15277620170502011191),
|
||||
Felt::new(5358738125714196705),
|
||||
Felt::new(14233283787297595718),
|
||||
Felt::new(13792579614346651365),
|
||||
Felt::new(11614812331536767105),
|
||||
Felt::new(14871063686742261166),
|
||||
Felt::new(10148237148793043499),
|
||||
Felt::new(4457428952329675767),
|
||||
Felt::new(15590786458219172475),
|
||||
Felt::new(10063319113072092615),
|
||||
Felt::new(14200078843431360086),
|
||||
],
|
||||
[
|
||||
Felt::new(6202948458916099932),
|
||||
Felt::new(17690140365333231091),
|
||||
Felt::new(3595001575307484651),
|
||||
Felt::new(373995945117666487),
|
||||
Felt::new(1235734395091296013),
|
||||
Felt::new(14172757457833931602),
|
||||
Felt::new(707573103686350224),
|
||||
Felt::new(15453217512188187135),
|
||||
Felt::new(219777875004506018),
|
||||
Felt::new(17876696346199469008),
|
||||
Felt::new(17731621626449383378),
|
||||
Felt::new(2897136237748376248),
|
||||
],
|
||||
[
|
||||
Felt::new(8023374565629191455),
|
||||
Felt::new(15013690343205953430),
|
||||
Felt::new(4485500052507912973),
|
||||
Felt::new(12489737547229155153),
|
||||
Felt::new(9500452585969030576),
|
||||
Felt::new(2054001340201038870),
|
||||
Felt::new(12420704059284934186),
|
||||
Felt::new(355990932618543755),
|
||||
Felt::new(9071225051243523860),
|
||||
Felt::new(12766199826003448536),
|
||||
Felt::new(9045979173463556963),
|
||||
Felt::new(12934431667190679898),
|
||||
],
|
||||
[
|
||||
Felt::new(18389244934624494276),
|
||||
Felt::new(16731736864863925227),
|
||||
Felt::new(4440209734760478192),
|
||||
Felt::new(17208448209698888938),
|
||||
Felt::new(8739495587021565984),
|
||||
Felt::new(17000774922218161967),
|
||||
Felt::new(13533282547195532087),
|
||||
Felt::new(525402848358706231),
|
||||
Felt::new(16987541523062161972),
|
||||
Felt::new(5466806524462797102),
|
||||
Felt::new(14512769585918244983),
|
||||
Felt::new(10973956031244051118),
|
||||
],
|
||||
[
|
||||
Felt::new(6982293561042362913),
|
||||
Felt::new(14065426295947720331),
|
||||
Felt::new(16451845770444974180),
|
||||
Felt::new(7139138592091306727),
|
||||
Felt::new(9012006439959783127),
|
||||
Felt::new(14619614108529063361),
|
||||
Felt::new(1394813199588124371),
|
||||
Felt::new(4635111139507788575),
|
||||
Felt::new(16217473952264203365),
|
||||
Felt::new(10782018226466330683),
|
||||
Felt::new(6844229992533662050),
|
||||
Felt::new(7446486531695178711),
|
||||
],
|
||||
[
|
||||
Felt::new(3736792340494631448),
|
||||
Felt::new(577852220195055341),
|
||||
Felt::new(6689998335515779805),
|
||||
Felt::new(13886063479078013492),
|
||||
Felt::new(14358505101923202168),
|
||||
Felt::new(7744142531772274164),
|
||||
Felt::new(16135070735728404443),
|
||||
Felt::new(12290902521256031137),
|
||||
Felt::new(12059913662657709804),
|
||||
Felt::new(16456018495793751911),
|
||||
Felt::new(4571485474751953524),
|
||||
Felt::new(17200392109565783176),
|
||||
],
|
||||
[
|
||||
Felt::new(17130398059294018733),
|
||||
Felt::new(519782857322261988),
|
||||
Felt::new(9625384390925085478),
|
||||
Felt::new(1664893052631119222),
|
||||
Felt::new(7629576092524553570),
|
||||
Felt::new(3485239601103661425),
|
||||
Felt::new(9755891797164033838),
|
||||
Felt::new(15218148195153269027),
|
||||
Felt::new(16460604813734957368),
|
||||
Felt::new(9643968136937729763),
|
||||
Felt::new(3611348709641382851),
|
||||
Felt::new(18256379591337759196),
|
||||
],
|
||||
];
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::{Digest, Felt, StarkField, DIGEST_SIZE, ZERO};
|
||||
use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO};
|
||||
use crate::utils::{
|
||||
bytes_to_hex_string, hex_to_bytes, string::String, ByteReader, ByteWriter, Deserializable,
|
||||
DeserializationError, HexParseError, Serializable,
|
||||
@@ -6,9 +6,6 @@ use crate::utils::{
|
||||
use core::{cmp::Ordering, fmt::Display, ops::Deref};
|
||||
use winter_utils::Randomizable;
|
||||
|
||||
/// The number of bytes needed to encoded a digest
|
||||
pub const DIGEST_BYTES: usize = 32;
|
||||
|
||||
// DIGEST TRAIT IMPLEMENTATIONS
|
||||
// ================================================================================================
|
||||
|
||||
@@ -172,9 +169,21 @@ impl From<&RpoDigest> for String {
|
||||
}
|
||||
}
|
||||
|
||||
// CONVERSIONS: TO DIGEST
|
||||
// CONVERSIONS: TO RPO DIGEST
|
||||
// ================================================================================================
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub enum RpoDigestError {
|
||||
/// The provided u64 integer does not fit in the field's moduli.
|
||||
InvalidInteger,
|
||||
}
|
||||
|
||||
impl From<&[Felt; DIGEST_SIZE]> for RpoDigest {
|
||||
fn from(value: &[Felt; DIGEST_SIZE]) -> Self {
|
||||
Self(*value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<[Felt; DIGEST_SIZE]> for RpoDigest {
|
||||
fn from(value: [Felt; DIGEST_SIZE]) -> Self {
|
||||
Self(value)
|
||||
@@ -200,6 +209,46 @@ impl TryFrom<[u8; DIGEST_BYTES]> for RpoDigest {
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&[u8; DIGEST_BYTES]> for RpoDigest {
|
||||
type Error = HexParseError;
|
||||
|
||||
fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
|
||||
(*value).try_into()
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&[u8]> for RpoDigest {
|
||||
type Error = HexParseError;
|
||||
|
||||
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
|
||||
(*value).try_into()
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<[u64; DIGEST_SIZE]> for RpoDigest {
|
||||
type Error = RpoDigestError;
|
||||
|
||||
fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
|
||||
if value[0] >= Felt::MODULUS
|
||||
|| value[1] >= Felt::MODULUS
|
||||
|| value[2] >= Felt::MODULUS
|
||||
|| value[3] >= Felt::MODULUS
|
||||
{
|
||||
return Err(RpoDigestError::InvalidInteger);
|
||||
}
|
||||
|
||||
Ok(Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()]))
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&[u64; DIGEST_SIZE]> for RpoDigest {
|
||||
type Error = RpoDigestError;
|
||||
|
||||
fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
|
||||
(*value).try_into()
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for RpoDigest {
|
||||
type Error = HexParseError;
|
||||
|
||||
@@ -253,13 +302,24 @@ impl Deserializable for RpoDigest {
|
||||
}
|
||||
}
|
||||
|
||||
// ITERATORS
|
||||
// ================================================================================================
|
||||
impl IntoIterator for RpoDigest {
|
||||
type Item = Felt;
|
||||
type IntoIter = <[Felt; 4] as IntoIterator>::IntoIter;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.0.into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
// TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{Deserializable, Felt, RpoDigest, Serializable, DIGEST_BYTES};
|
||||
use crate::utils::SliceReader;
|
||||
use super::{Deserializable, Felt, RpoDigest, Serializable, DIGEST_BYTES, DIGEST_SIZE};
|
||||
use crate::utils::{string::String, SliceReader};
|
||||
use rand_utils::rand_value;
|
||||
|
||||
#[test]
|
||||
@@ -281,7 +341,6 @@ mod tests {
|
||||
assert_eq!(d1, d2);
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
#[test]
|
||||
fn digest_encoding() {
|
||||
let digest = RpoDigest([
|
||||
@@ -296,4 +355,54 @@ mod tests {
|
||||
|
||||
assert_eq!(digest, round_trip);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conversions() {
|
||||
let digest = RpoDigest([
|
||||
Felt::new(rand_value()),
|
||||
Felt::new(rand_value()),
|
||||
Felt::new(rand_value()),
|
||||
Felt::new(rand_value()),
|
||||
]);
|
||||
|
||||
let v: [Felt; DIGEST_SIZE] = digest.into();
|
||||
let v2: RpoDigest = v.into();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: [Felt; DIGEST_SIZE] = (&digest).into();
|
||||
let v2: RpoDigest = v.into();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: [u64; DIGEST_SIZE] = digest.into();
|
||||
let v2: RpoDigest = v.try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: [u64; DIGEST_SIZE] = (&digest).into();
|
||||
let v2: RpoDigest = v.try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: [u8; DIGEST_BYTES] = digest.into();
|
||||
let v2: RpoDigest = v.try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: [u8; DIGEST_BYTES] = (&digest).into();
|
||||
let v2: RpoDigest = v.try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: String = digest.into();
|
||||
let v2: RpoDigest = v.try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: String = (&digest).into();
|
||||
let v2: RpoDigest = v.try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: [u8; DIGEST_BYTES] = digest.into();
|
||||
let v2: RpoDigest = (&v).try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: [u8; DIGEST_BYTES] = (&digest).into();
|
||||
let v2: RpoDigest = (&v).try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
}
|
||||
}
|
||||
323
src/hash/rescue/rpo/mod.rs
Normal file
323
src/hash/rescue/rpo/mod.rs
Normal file
@@ -0,0 +1,323 @@
|
||||
use super::{
|
||||
add_constants, add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox, apply_inv_sbox,
|
||||
apply_mds, apply_sbox, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ARK1,
|
||||
ARK2, BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE, DIGEST_SIZE, INPUT1_RANGE,
|
||||
INPUT2_RANGE, MDS, NUM_ROUNDS, ONE, RATE_RANGE, RATE_WIDTH, STATE_WIDTH, ZERO,
|
||||
};
|
||||
use core::{convert::TryInto, ops::Range};
|
||||
|
||||
mod digest;
|
||||
pub use digest::RpoDigest;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
// HASHER IMPLEMENTATION
|
||||
// ================================================================================================
|
||||
|
||||
/// Implementation of the Rescue Prime Optimized hash function with 256-bit output.
|
||||
///
|
||||
/// The hash function is implemented according to the Rescue Prime Optimized
|
||||
/// [specifications](https://eprint.iacr.org/2022/1577)
|
||||
///
|
||||
/// The parameters used to instantiate the function are:
|
||||
/// * Field: 64-bit prime field with modulus 2^64 - 2^32 + 1.
|
||||
/// * State width: 12 field elements.
|
||||
/// * Capacity size: 4 field elements.
|
||||
/// * Number of founds: 7.
|
||||
/// * S-Box degree: 7.
|
||||
///
|
||||
/// The above parameters target 128-bit security level. The digest consists of four field elements
|
||||
/// and it can be serialized into 32 bytes (256 bits).
|
||||
///
|
||||
/// ## Hash output consistency
|
||||
/// Functions [hash_elements()](Rpo256::hash_elements), [merge()](Rpo256::merge), and
|
||||
/// [merge_with_int()](Rpo256::merge_with_int) are internally consistent. That is, computing
|
||||
/// a hash for the same set of elements using these functions will always produce the same
|
||||
/// result. For example, merging two digests using [merge()](Rpo256::merge) will produce the
|
||||
/// same result as hashing 8 elements which make up these digests using
|
||||
/// [hash_elements()](Rpo256::hash_elements) function.
|
||||
///
|
||||
/// However, [hash()](Rpo256::hash) function is not consistent with functions mentioned above.
|
||||
/// For example, if we take two field elements, serialize them to bytes and hash them using
|
||||
/// [hash()](Rpo256::hash), the result will differ from the result obtained by hashing these
|
||||
/// elements directly using [hash_elements()](Rpo256::hash_elements) function. The reason for
|
||||
/// this difference is that [hash()](Rpo256::hash) function needs to be able to handle
|
||||
/// arbitrary binary strings, which may or may not encode valid field elements - and thus,
|
||||
/// deserialization procedure used by this function is different from the procedure used to
|
||||
/// deserialize valid field elements.
|
||||
///
|
||||
/// Thus, if the underlying data consists of valid field elements, it might make more sense
|
||||
/// to deserialize them into field elements and then hash them using
|
||||
/// [hash_elements()](Rpo256::hash_elements) function rather then hashing the serialized bytes
|
||||
/// using [hash()](Rpo256::hash) function.
|
||||
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
|
||||
pub struct Rpo256();
|
||||
|
||||
impl Hasher for Rpo256 {
|
||||
/// Rpo256 collision resistance is the same as the security level, that is 128-bits.
|
||||
///
|
||||
/// #### Collision resistance
|
||||
///
|
||||
/// However, our setup of the capacity registers might drop it to 126.
|
||||
///
|
||||
/// Related issue: [#69](https://github.com/0xPolygonMiden/crypto/issues/69)
|
||||
const COLLISION_RESISTANCE: u32 = 128;
|
||||
|
||||
type Digest = RpoDigest;
|
||||
|
||||
fn hash(bytes: &[u8]) -> Self::Digest {
|
||||
// initialize the state with zeroes
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
|
||||
// set the capacity (first element) to a flag on whether or not the input length is evenly
|
||||
// divided by the rate. this will prevent collisions between padded and non-padded inputs,
|
||||
// and will rule out the need to perform an extra permutation in case of evenly divided
|
||||
// inputs.
|
||||
let is_rate_multiple = bytes.len() % RATE_WIDTH == 0;
|
||||
if !is_rate_multiple {
|
||||
state[CAPACITY_RANGE.start] = ONE;
|
||||
}
|
||||
|
||||
// initialize a buffer to receive the little-endian elements.
|
||||
let mut buf = [0_u8; 8];
|
||||
|
||||
// iterate the chunks of bytes, creating a field element from each chunk and copying it
|
||||
// into the state.
|
||||
//
|
||||
// every time the rate range is filled, a permutation is performed. if the final value of
|
||||
// `i` is not zero, then the chunks count wasn't enough to fill the state range, and an
|
||||
// additional permutation must be performed.
|
||||
let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| {
|
||||
// the last element of the iteration may or may not be a full chunk. if it's not, then
|
||||
// we need to pad the remainder bytes of the chunk with zeroes, separated by a `1`.
|
||||
// this will avoid collisions.
|
||||
if chunk.len() == BINARY_CHUNK_SIZE {
|
||||
buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
|
||||
} else {
|
||||
buf.fill(0);
|
||||
buf[..chunk.len()].copy_from_slice(chunk);
|
||||
buf[chunk.len()] = 1;
|
||||
}
|
||||
|
||||
// set the current rate element to the input. since we take at most 7 bytes, we are
|
||||
// guaranteed that the inputs data will fit into a single field element.
|
||||
state[RATE_RANGE.start + i] = Felt::new(u64::from_le_bytes(buf));
|
||||
|
||||
// proceed filling the range. if it's full, then we apply a permutation and reset the
|
||||
// counter to the beginning of the range.
|
||||
if i == RATE_WIDTH - 1 {
|
||||
Self::apply_permutation(&mut state);
|
||||
0
|
||||
} else {
|
||||
i + 1
|
||||
}
|
||||
});
|
||||
|
||||
// if we absorbed some elements but didn't apply a permutation to them (would happen when
|
||||
// the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation. we
|
||||
// don't need to apply any extra padding because the first capacity element contains a
|
||||
// flag indicating whether the input is evenly divisible by the rate.
|
||||
if i != 0 {
|
||||
state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO);
|
||||
state[RATE_RANGE.start + i] = ONE;
|
||||
Self::apply_permutation(&mut state);
|
||||
}
|
||||
|
||||
// return the first 4 elements of the rate as hash result.
|
||||
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
||||
}
|
||||
|
||||
fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
|
||||
// initialize the state by copying the digest elements into the rate portion of the state
|
||||
// (8 total elements), and set the capacity elements to 0.
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
let it = Self::Digest::digests_as_elements(values.iter());
|
||||
for (i, v) in it.enumerate() {
|
||||
state[RATE_RANGE.start + i] = *v;
|
||||
}
|
||||
|
||||
// apply the RPO permutation and return the first four elements of the state
|
||||
Self::apply_permutation(&mut state);
|
||||
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
||||
}
|
||||
|
||||
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
|
||||
// initialize the state as follows:
|
||||
// - seed is copied into the first 4 elements of the rate portion of the state.
|
||||
// - if the value fits into a single field element, copy it into the fifth rate element
|
||||
// and set the sixth rate element to 1.
|
||||
// - if the value doesn't fit into a single field element, split it into two field
|
||||
// elements, copy them into rate elements 5 and 6, and set the seventh rate element
|
||||
// to 1.
|
||||
// - set the first capacity element to 1
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
|
||||
state[INPUT2_RANGE.start] = Felt::new(value);
|
||||
if value < Felt::MODULUS {
|
||||
state[INPUT2_RANGE.start + 1] = ONE;
|
||||
} else {
|
||||
state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
|
||||
state[INPUT2_RANGE.start + 2] = ONE;
|
||||
}
|
||||
|
||||
// common padding for both cases
|
||||
state[CAPACITY_RANGE.start] = ONE;
|
||||
|
||||
// apply the RPO permutation and return the first four elements of the state
|
||||
Self::apply_permutation(&mut state);
|
||||
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl ElementHasher for Rpo256 {
|
||||
type BaseField = Felt;
|
||||
|
||||
fn hash_elements<E: FieldElement<BaseField = Self::BaseField>>(elements: &[E]) -> Self::Digest {
|
||||
// convert the elements into a list of base field elements
|
||||
let elements = E::slice_as_base_elements(elements);
|
||||
|
||||
// initialize state to all zeros, except for the first element of the capacity part, which
|
||||
// is set to 1 if the number of elements is not a multiple of RATE_WIDTH.
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
if elements.len() % RATE_WIDTH != 0 {
|
||||
state[CAPACITY_RANGE.start] = ONE;
|
||||
}
|
||||
|
||||
// absorb elements into the state one by one until the rate portion of the state is filled
|
||||
// up; then apply the Rescue permutation and start absorbing again; repeat until all
|
||||
// elements have been absorbed
|
||||
let mut i = 0;
|
||||
for &element in elements.iter() {
|
||||
state[RATE_RANGE.start + i] = element;
|
||||
i += 1;
|
||||
if i % RATE_WIDTH == 0 {
|
||||
Self::apply_permutation(&mut state);
|
||||
i = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// if we absorbed some elements but didn't apply a permutation to them (would happen when
|
||||
// the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation after
|
||||
// padding by appending a 1 followed by as many 0 as necessary to make the input length a
|
||||
// multiple of the RATE_WIDTH.
|
||||
if i > 0 {
|
||||
state[RATE_RANGE.start + i] = ONE;
|
||||
i += 1;
|
||||
while i != RATE_WIDTH {
|
||||
state[RATE_RANGE.start + i] = ZERO;
|
||||
i += 1;
|
||||
}
|
||||
Self::apply_permutation(&mut state);
|
||||
}
|
||||
|
||||
// return the first 4 elements of the state as hash result
|
||||
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
// HASH FUNCTION IMPLEMENTATION
|
||||
// ================================================================================================
|
||||
|
||||
impl Rpo256 {
|
||||
// CONSTANTS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// The number of rounds is set to 7 to target 128-bit security level.
|
||||
pub const NUM_ROUNDS: usize = NUM_ROUNDS;
|
||||
|
||||
/// Sponge state is set to 12 field elements or 768 bytes; 8 elements are reserved for rate and
|
||||
/// the remaining 4 elements are reserved for capacity.
|
||||
pub const STATE_WIDTH: usize = STATE_WIDTH;
|
||||
|
||||
/// The rate portion of the state is located in elements 4 through 11 (inclusive).
|
||||
pub const RATE_RANGE: Range<usize> = RATE_RANGE;
|
||||
|
||||
/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
|
||||
pub const CAPACITY_RANGE: Range<usize> = CAPACITY_RANGE;
|
||||
|
||||
/// The output of the hash function can be read from state elements 4, 5, 6, and 7.
|
||||
pub const DIGEST_RANGE: Range<usize> = DIGEST_RANGE;
|
||||
|
||||
/// MDS matrix used for computing the linear layer in a RPO round.
|
||||
pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS;
|
||||
|
||||
/// Round constants added to the hasher state in the first half of the RPO round.
|
||||
pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1;
|
||||
|
||||
/// Round constants added to the hasher state in the second half of the RPO round.
|
||||
pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2;
|
||||
|
||||
// TRAIT PASS-THROUGH FUNCTIONS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a hash of the provided sequence of bytes.
|
||||
#[inline(always)]
|
||||
pub fn hash(bytes: &[u8]) -> RpoDigest {
|
||||
<Self as Hasher>::hash(bytes)
|
||||
}
|
||||
|
||||
/// Returns a hash of two digests. This method is intended for use in construction of
|
||||
/// Merkle trees and verification of Merkle paths.
|
||||
#[inline(always)]
|
||||
pub fn merge(values: &[RpoDigest; 2]) -> RpoDigest {
|
||||
<Self as Hasher>::merge(values)
|
||||
}
|
||||
|
||||
/// Returns a hash of the provided field elements.
|
||||
#[inline(always)]
|
||||
pub fn hash_elements<E: FieldElement<BaseField = Felt>>(elements: &[E]) -> RpoDigest {
|
||||
<Self as ElementHasher>::hash_elements(elements)
|
||||
}
|
||||
|
||||
// DOMAIN IDENTIFIER
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a hash of two digests and a domain identifier.
|
||||
pub fn merge_in_domain(values: &[RpoDigest; 2], domain: Felt) -> RpoDigest {
|
||||
// initialize the state by copying the digest elements into the rate portion of the state
|
||||
// (8 total elements), and set the capacity elements to 0.
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
let it = RpoDigest::digests_as_elements(values.iter());
|
||||
for (i, v) in it.enumerate() {
|
||||
state[RATE_RANGE.start + i] = *v;
|
||||
}
|
||||
|
||||
// set the second capacity element to the domain value. The first capacity element is used
|
||||
// for padding purposes.
|
||||
state[CAPACITY_RANGE.start + 1] = domain;
|
||||
|
||||
// apply the RPO permutation and return the first four elements of the state
|
||||
Self::apply_permutation(&mut state);
|
||||
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
||||
}
|
||||
|
||||
// RESCUE PERMUTATION
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Applies RPO permutation to the provided state.
|
||||
#[inline(always)]
|
||||
pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) {
|
||||
for i in 0..NUM_ROUNDS {
|
||||
Self::apply_round(state, i);
|
||||
}
|
||||
}
|
||||
|
||||
/// RPO round function.
|
||||
#[inline(always)]
|
||||
pub fn apply_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
|
||||
// apply first half of RPO round
|
||||
apply_mds(state);
|
||||
if !add_constants_and_apply_sbox(state, &ARK1[round]) {
|
||||
add_constants(state, &ARK1[round]);
|
||||
apply_sbox(state);
|
||||
}
|
||||
|
||||
// apply second half of RPO round
|
||||
apply_mds(state);
|
||||
if !add_constants_and_apply_inv_sbox(state, &ARK2[round]) {
|
||||
add_constants(state, &ARK2[round]);
|
||||
apply_inv_sbox(state);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
use super::{
|
||||
Felt, FieldElement, Hasher, Rpo256, RpoDigest, StarkField, ALPHA, INV_ALPHA, ONE, STATE_WIDTH,
|
||||
ZERO,
|
||||
super::{apply_inv_sbox, apply_sbox, ALPHA, INV_ALPHA},
|
||||
Felt, FieldElement, Hasher, Rpo256, RpoDigest, StarkField, ONE, STATE_WIDTH, ZERO,
|
||||
};
|
||||
use crate::{
|
||||
utils::collections::{BTreeSet, Vec},
|
||||
@@ -10,13 +10,6 @@ use core::convert::TryInto;
|
||||
use proptest::prelude::*;
|
||||
use rand_utils::rand_value;
|
||||
|
||||
#[test]
|
||||
fn test_alphas() {
|
||||
let e: Felt = Felt::new(rand_value());
|
||||
let e_exp = e.exp(ALPHA);
|
||||
assert_eq!(e, e_exp.exp(INV_ALPHA));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sbox() {
|
||||
let state = [Felt::new(rand_value()); STATE_WIDTH];
|
||||
@@ -25,7 +18,7 @@ fn test_sbox() {
|
||||
expected.iter_mut().for_each(|v| *v = v.exp(ALPHA));
|
||||
|
||||
let mut actual = state;
|
||||
Rpo256::apply_sbox(&mut actual);
|
||||
apply_sbox(&mut actual);
|
||||
|
||||
assert_eq!(expected, actual);
|
||||
}
|
||||
@@ -38,7 +31,7 @@ fn test_inv_sbox() {
|
||||
expected.iter_mut().for_each(|v| *v = v.exp(INV_ALPHA));
|
||||
|
||||
let mut actual = state;
|
||||
Rpo256::apply_inv_sbox(&mut actual);
|
||||
apply_inv_sbox(&mut actual);
|
||||
|
||||
assert_eq!(expected, actual);
|
||||
}
|
||||
398
src/hash/rescue/rpx/digest.rs
Normal file
398
src/hash/rescue/rpx/digest.rs
Normal file
@@ -0,0 +1,398 @@
|
||||
use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO};
|
||||
use crate::utils::{
|
||||
bytes_to_hex_string, hex_to_bytes, string::String, ByteReader, ByteWriter, Deserializable,
|
||||
DeserializationError, HexParseError, Serializable,
|
||||
};
|
||||
use core::{cmp::Ordering, fmt::Display, ops::Deref};
|
||||
use winter_utils::Randomizable;
|
||||
|
||||
// DIGEST TRAIT IMPLEMENTATIONS
|
||||
// ================================================================================================
|
||||
|
||||
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))]
|
||||
pub struct RpxDigest([Felt; DIGEST_SIZE]);
|
||||
|
||||
impl RpxDigest {
|
||||
pub const fn new(value: [Felt; DIGEST_SIZE]) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
|
||||
pub fn as_elements(&self) -> &[Felt] {
|
||||
self.as_ref()
|
||||
}
|
||||
|
||||
pub fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
|
||||
<Self as Digest>::as_bytes(self)
|
||||
}
|
||||
|
||||
pub fn digests_as_elements<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt>
|
||||
where
|
||||
I: Iterator<Item = &'a Self>,
|
||||
{
|
||||
digests.flat_map(|d| d.0.iter())
|
||||
}
|
||||
}
|
||||
|
||||
impl Digest for RpxDigest {
|
||||
fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
|
||||
let mut result = [0; DIGEST_BYTES];
|
||||
|
||||
result[..8].copy_from_slice(&self.0[0].as_int().to_le_bytes());
|
||||
result[8..16].copy_from_slice(&self.0[1].as_int().to_le_bytes());
|
||||
result[16..24].copy_from_slice(&self.0[2].as_int().to_le_bytes());
|
||||
result[24..].copy_from_slice(&self.0[3].as_int().to_le_bytes());
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for RpxDigest {
|
||||
type Target = [Felt; DIGEST_SIZE];
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for RpxDigest {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
// compare the inner u64 of both elements.
|
||||
//
|
||||
// it will iterate the elements and will return the first computation different than
|
||||
// `Equal`. Otherwise, the ordering is equal.
|
||||
//
|
||||
// the endianness is irrelevant here because since, this being a cryptographically secure
|
||||
// hash computation, the digest shouldn't have any ordered property of its input.
|
||||
//
|
||||
// finally, we use `Felt::inner` instead of `Felt::as_int` so we avoid performing a
|
||||
// montgomery reduction for every limb. that is safe because every inner element of the
|
||||
// digest is guaranteed to be in its canonical form (that is, `x in [0,p)`).
|
||||
self.0.iter().map(Felt::inner).zip(other.0.iter().map(Felt::inner)).fold(
|
||||
Ordering::Equal,
|
||||
|ord, (a, b)| match ord {
|
||||
Ordering::Equal => a.cmp(&b),
|
||||
_ => ord,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for RpxDigest {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for RpxDigest {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
let encoded: String = self.into();
|
||||
write!(f, "{}", encoded)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Randomizable for RpxDigest {
|
||||
const VALUE_SIZE: usize = DIGEST_BYTES;
|
||||
|
||||
fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
|
||||
let bytes_array: Option<[u8; 32]> = bytes.try_into().ok();
|
||||
if let Some(bytes_array) = bytes_array {
|
||||
Self::try_from(bytes_array).ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CONVERSIONS: FROM RPX DIGEST
|
||||
// ================================================================================================
|
||||
|
||||
impl From<&RpxDigest> for [Felt; DIGEST_SIZE] {
|
||||
fn from(value: &RpxDigest) -> Self {
|
||||
value.0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RpxDigest> for [Felt; DIGEST_SIZE] {
|
||||
fn from(value: RpxDigest) -> Self {
|
||||
value.0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&RpxDigest> for [u64; DIGEST_SIZE] {
|
||||
fn from(value: &RpxDigest) -> Self {
|
||||
[
|
||||
value.0[0].as_int(),
|
||||
value.0[1].as_int(),
|
||||
value.0[2].as_int(),
|
||||
value.0[3].as_int(),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RpxDigest> for [u64; DIGEST_SIZE] {
|
||||
fn from(value: RpxDigest) -> Self {
|
||||
[
|
||||
value.0[0].as_int(),
|
||||
value.0[1].as_int(),
|
||||
value.0[2].as_int(),
|
||||
value.0[3].as_int(),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&RpxDigest> for [u8; DIGEST_BYTES] {
|
||||
fn from(value: &RpxDigest) -> Self {
|
||||
value.as_bytes()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RpxDigest> for [u8; DIGEST_BYTES] {
|
||||
fn from(value: RpxDigest) -> Self {
|
||||
value.as_bytes()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RpxDigest> for String {
|
||||
/// The returned string starts with `0x`.
|
||||
fn from(value: RpxDigest) -> Self {
|
||||
bytes_to_hex_string(value.as_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&RpxDigest> for String {
|
||||
/// The returned string starts with `0x`.
|
||||
fn from(value: &RpxDigest) -> Self {
|
||||
(*value).into()
|
||||
}
|
||||
}
|
||||
|
||||
// CONVERSIONS: TO RPX DIGEST
|
||||
// ================================================================================================
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub enum RpxDigestError {
|
||||
/// The provided u64 integer does not fit in the field's moduli.
|
||||
InvalidInteger,
|
||||
}
|
||||
|
||||
impl From<&[Felt; DIGEST_SIZE]> for RpxDigest {
|
||||
fn from(value: &[Felt; DIGEST_SIZE]) -> Self {
|
||||
Self(*value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<[Felt; DIGEST_SIZE]> for RpxDigest {
|
||||
fn from(value: [Felt; DIGEST_SIZE]) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<[u8; DIGEST_BYTES]> for RpxDigest {
|
||||
type Error = HexParseError;
|
||||
|
||||
fn try_from(value: [u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
|
||||
// Note: the input length is known, the conversion from slice to array must succeed so the
|
||||
// `unwrap`s below are safe
|
||||
let a = u64::from_le_bytes(value[0..8].try_into().unwrap());
|
||||
let b = u64::from_le_bytes(value[8..16].try_into().unwrap());
|
||||
let c = u64::from_le_bytes(value[16..24].try_into().unwrap());
|
||||
let d = u64::from_le_bytes(value[24..32].try_into().unwrap());
|
||||
|
||||
if [a, b, c, d].iter().any(|v| *v >= Felt::MODULUS) {
|
||||
return Err(HexParseError::OutOfRange);
|
||||
}
|
||||
|
||||
Ok(RpxDigest([Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)]))
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&[u8; DIGEST_BYTES]> for RpxDigest {
|
||||
type Error = HexParseError;
|
||||
|
||||
fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
|
||||
(*value).try_into()
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&[u8]> for RpxDigest {
|
||||
type Error = HexParseError;
|
||||
|
||||
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
|
||||
(*value).try_into()
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<[u64; DIGEST_SIZE]> for RpxDigest {
|
||||
type Error = RpxDigestError;
|
||||
|
||||
fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
|
||||
if value[0] >= Felt::MODULUS
|
||||
|| value[1] >= Felt::MODULUS
|
||||
|| value[2] >= Felt::MODULUS
|
||||
|| value[3] >= Felt::MODULUS
|
||||
{
|
||||
return Err(RpxDigestError::InvalidInteger);
|
||||
}
|
||||
|
||||
Ok(Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()]))
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&[u64; DIGEST_SIZE]> for RpxDigest {
|
||||
type Error = RpxDigestError;
|
||||
|
||||
fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
|
||||
(*value).try_into()
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for RpxDigest {
|
||||
type Error = HexParseError;
|
||||
|
||||
/// Expects the string to start with `0x`.
|
||||
fn try_from(value: &str) -> Result<Self, Self::Error> {
|
||||
hex_to_bytes(value).and_then(|v| v.try_into())
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<String> for RpxDigest {
|
||||
type Error = HexParseError;
|
||||
|
||||
/// Expects the string to start with `0x`.
|
||||
fn try_from(value: String) -> Result<Self, Self::Error> {
|
||||
value.as_str().try_into()
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&String> for RpxDigest {
|
||||
type Error = HexParseError;
|
||||
|
||||
/// Expects the string to start with `0x`.
|
||||
fn try_from(value: &String) -> Result<Self, Self::Error> {
|
||||
value.as_str().try_into()
|
||||
}
|
||||
}
|
||||
|
||||
// SERIALIZATION / DESERIALIZATION
|
||||
// ================================================================================================
|
||||
|
||||
impl Serializable for RpxDigest {
|
||||
fn write_into<W: ByteWriter>(&self, target: &mut W) {
|
||||
target.write_bytes(&self.as_bytes());
|
||||
}
|
||||
}
|
||||
|
||||
impl Deserializable for RpxDigest {
|
||||
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
|
||||
let mut inner: [Felt; DIGEST_SIZE] = [ZERO; DIGEST_SIZE];
|
||||
for inner in inner.iter_mut() {
|
||||
let e = source.read_u64()?;
|
||||
if e >= Felt::MODULUS {
|
||||
return Err(DeserializationError::InvalidValue(String::from(
|
||||
"Value not in the appropriate range",
|
||||
)));
|
||||
}
|
||||
*inner = Felt::new(e);
|
||||
}
|
||||
|
||||
Ok(Self(inner))
|
||||
}
|
||||
}
|
||||
|
||||
// TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{Deserializable, Felt, RpxDigest, Serializable, DIGEST_BYTES, DIGEST_SIZE};
|
||||
use crate::utils::{string::String, SliceReader};
|
||||
use rand_utils::rand_value;
|
||||
|
||||
#[test]
|
||||
fn digest_serialization() {
|
||||
let e1 = Felt::new(rand_value());
|
||||
let e2 = Felt::new(rand_value());
|
||||
let e3 = Felt::new(rand_value());
|
||||
let e4 = Felt::new(rand_value());
|
||||
|
||||
let d1 = RpxDigest([e1, e2, e3, e4]);
|
||||
|
||||
let mut bytes = vec![];
|
||||
d1.write_into(&mut bytes);
|
||||
assert_eq!(DIGEST_BYTES, bytes.len());
|
||||
|
||||
let mut reader = SliceReader::new(&bytes);
|
||||
let d2 = RpxDigest::read_from(&mut reader).unwrap();
|
||||
|
||||
assert_eq!(d1, d2);
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
#[test]
|
||||
fn digest_encoding() {
|
||||
let digest = RpxDigest([
|
||||
Felt::new(rand_value()),
|
||||
Felt::new(rand_value()),
|
||||
Felt::new(rand_value()),
|
||||
Felt::new(rand_value()),
|
||||
]);
|
||||
|
||||
let string: String = digest.into();
|
||||
let round_trip: RpxDigest = string.try_into().expect("decoding failed");
|
||||
|
||||
assert_eq!(digest, round_trip);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conversions() {
|
||||
let digest = RpxDigest([
|
||||
Felt::new(rand_value()),
|
||||
Felt::new(rand_value()),
|
||||
Felt::new(rand_value()),
|
||||
Felt::new(rand_value()),
|
||||
]);
|
||||
|
||||
let v: [Felt; DIGEST_SIZE] = digest.into();
|
||||
let v2: RpxDigest = v.into();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: [Felt; DIGEST_SIZE] = (&digest).into();
|
||||
let v2: RpxDigest = v.into();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: [u64; DIGEST_SIZE] = digest.into();
|
||||
let v2: RpxDigest = v.try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: [u64; DIGEST_SIZE] = (&digest).into();
|
||||
let v2: RpxDigest = v.try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: [u8; DIGEST_BYTES] = digest.into();
|
||||
let v2: RpxDigest = v.try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: [u8; DIGEST_BYTES] = (&digest).into();
|
||||
let v2: RpxDigest = v.try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: String = digest.into();
|
||||
let v2: RpxDigest = v.try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: String = (&digest).into();
|
||||
let v2: RpxDigest = v.try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: [u8; DIGEST_BYTES] = digest.into();
|
||||
let v2: RpxDigest = (&v).try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: [u8; DIGEST_BYTES] = (&digest).into();
|
||||
let v2: RpxDigest = (&v).try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
}
|
||||
}
|
||||
366
src/hash/rescue/rpx/mod.rs
Normal file
366
src/hash/rescue/rpx/mod.rs
Normal file
@@ -0,0 +1,366 @@
|
||||
use super::{
|
||||
add_constants, add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox, apply_inv_sbox,
|
||||
apply_mds, apply_sbox, CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher,
|
||||
StarkField, ARK1, ARK2, BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE,
|
||||
DIGEST_SIZE, INPUT1_RANGE, INPUT2_RANGE, MDS, NUM_ROUNDS, ONE, RATE_RANGE, RATE_WIDTH,
|
||||
STATE_WIDTH, ZERO,
|
||||
};
|
||||
use core::{convert::TryInto, ops::Range};
|
||||
|
||||
mod digest;
|
||||
pub use digest::RpxDigest;
|
||||
|
||||
pub type CubicExtElement = CubeExtension<Felt>;
|
||||
|
||||
// HASHER IMPLEMENTATION
|
||||
// ================================================================================================
|
||||
|
||||
/// Implementation of the Rescue Prime eXtension hash function with 256-bit output.
|
||||
///
|
||||
/// The hash function is based on the XHash12 construction in [specifications](https://eprint.iacr.org/2023/1045)
|
||||
///
|
||||
/// The parameters used to instantiate the function are:
|
||||
/// * Field: 64-bit prime field with modulus 2^64 - 2^32 + 1.
|
||||
/// * State width: 12 field elements.
|
||||
/// * Capacity size: 4 field elements.
|
||||
/// * S-Box degree: 7.
|
||||
/// * Rounds: There are 3 different types of rounds:
|
||||
/// - (FB): `apply_mds` → `add_constants` → `apply_sbox` → `apply_mds` → `add_constants` → `apply_inv_sbox`.
|
||||
/// - (E): `add_constants` → `ext_sbox` (which is raising to power 7 in the degree 3 extension field).
|
||||
/// - (M): `apply_mds` → `add_constants`.
|
||||
/// * Permutation: (FB) (E) (FB) (E) (FB) (E) (M).
|
||||
///
|
||||
/// The above parameters target 128-bit security level. The digest consists of four field elements
|
||||
/// and it can be serialized into 32 bytes (256 bits).
|
||||
///
|
||||
/// ## Hash output consistency
|
||||
/// Functions [hash_elements()](Rpx256::hash_elements), [merge()](Rpx256::merge), and
|
||||
/// [merge_with_int()](Rpx256::merge_with_int) are internally consistent. That is, computing
|
||||
/// a hash for the same set of elements using these functions will always produce the same
|
||||
/// result. For example, merging two digests using [merge()](Rpx256::merge) will produce the
|
||||
/// same result as hashing 8 elements which make up these digests using
|
||||
/// [hash_elements()](Rpx256::hash_elements) function.
|
||||
///
|
||||
/// However, [hash()](Rpx256::hash) function is not consistent with functions mentioned above.
|
||||
/// For example, if we take two field elements, serialize them to bytes and hash them using
|
||||
/// [hash()](Rpx256::hash), the result will differ from the result obtained by hashing these
|
||||
/// elements directly using [hash_elements()](Rpx256::hash_elements) function. The reason for
|
||||
/// this difference is that [hash()](Rpx256::hash) function needs to be able to handle
|
||||
/// arbitrary binary strings, which may or may not encode valid field elements - and thus,
|
||||
/// deserialization procedure used by this function is different from the procedure used to
|
||||
/// deserialize valid field elements.
|
||||
///
|
||||
/// Thus, if the underlying data consists of valid field elements, it might make more sense
|
||||
/// to deserialize them into field elements and then hash them using
|
||||
/// [hash_elements()](Rpx256::hash_elements) function rather then hashing the serialized bytes
|
||||
/// using [hash()](Rpx256::hash) function.
|
||||
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
|
||||
pub struct Rpx256();
|
||||
|
||||
impl Hasher for Rpx256 {
|
||||
/// Rpx256 collision resistance is the same as the security level, that is 128-bits.
|
||||
///
|
||||
/// #### Collision resistance
|
||||
///
|
||||
/// However, our setup of the capacity registers might drop it to 126.
|
||||
///
|
||||
/// Related issue: [#69](https://github.com/0xPolygonMiden/crypto/issues/69)
|
||||
const COLLISION_RESISTANCE: u32 = 128;
|
||||
|
||||
type Digest = RpxDigest;
|
||||
|
||||
fn hash(bytes: &[u8]) -> Self::Digest {
|
||||
// initialize the state with zeroes
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
|
||||
// set the capacity (first element) to a flag on whether or not the input length is evenly
|
||||
// divided by the rate. this will prevent collisions between padded and non-padded inputs,
|
||||
// and will rule out the need to perform an extra permutation in case of evenly divided
|
||||
// inputs.
|
||||
let is_rate_multiple = bytes.len() % RATE_WIDTH == 0;
|
||||
if !is_rate_multiple {
|
||||
state[CAPACITY_RANGE.start] = ONE;
|
||||
}
|
||||
|
||||
// initialize a buffer to receive the little-endian elements.
|
||||
let mut buf = [0_u8; 8];
|
||||
|
||||
// iterate the chunks of bytes, creating a field element from each chunk and copying it
|
||||
// into the state.
|
||||
//
|
||||
// every time the rate range is filled, a permutation is performed. if the final value of
|
||||
// `i` is not zero, then the chunks count wasn't enough to fill the state range, and an
|
||||
// additional permutation must be performed.
|
||||
let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| {
|
||||
// the last element of the iteration may or may not be a full chunk. if it's not, then
|
||||
// we need to pad the remainder bytes of the chunk with zeroes, separated by a `1`.
|
||||
// this will avoid collisions.
|
||||
if chunk.len() == BINARY_CHUNK_SIZE {
|
||||
buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
|
||||
} else {
|
||||
buf.fill(0);
|
||||
buf[..chunk.len()].copy_from_slice(chunk);
|
||||
buf[chunk.len()] = 1;
|
||||
}
|
||||
|
||||
// set the current rate element to the input. since we take at most 7 bytes, we are
|
||||
// guaranteed that the inputs data will fit into a single field element.
|
||||
state[RATE_RANGE.start + i] = Felt::new(u64::from_le_bytes(buf));
|
||||
|
||||
// proceed filling the range. if it's full, then we apply a permutation and reset the
|
||||
// counter to the beginning of the range.
|
||||
if i == RATE_WIDTH - 1 {
|
||||
Self::apply_permutation(&mut state);
|
||||
0
|
||||
} else {
|
||||
i + 1
|
||||
}
|
||||
});
|
||||
|
||||
// if we absorbed some elements but didn't apply a permutation to them (would happen when
|
||||
// the number of elements is not a multiple of RATE_WIDTH), apply the RPX permutation. we
|
||||
// don't need to apply any extra padding because the first capacity element contains a
|
||||
// flag indicating whether the input is evenly divisible by the rate.
|
||||
if i != 0 {
|
||||
state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO);
|
||||
state[RATE_RANGE.start + i] = ONE;
|
||||
Self::apply_permutation(&mut state);
|
||||
}
|
||||
|
||||
// return the first 4 elements of the rate as hash result.
|
||||
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
||||
}
|
||||
|
||||
fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
|
||||
// initialize the state by copying the digest elements into the rate portion of the state
|
||||
// (8 total elements), and set the capacity elements to 0.
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
let it = Self::Digest::digests_as_elements(values.iter());
|
||||
for (i, v) in it.enumerate() {
|
||||
state[RATE_RANGE.start + i] = *v;
|
||||
}
|
||||
|
||||
// apply the RPX permutation and return the first four elements of the state
|
||||
Self::apply_permutation(&mut state);
|
||||
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
||||
}
|
||||
|
||||
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
|
||||
// initialize the state as follows:
|
||||
// - seed is copied into the first 4 elements of the rate portion of the state.
|
||||
// - if the value fits into a single field element, copy it into the fifth rate element
|
||||
// and set the sixth rate element to 1.
|
||||
// - if the value doesn't fit into a single field element, split it into two field
|
||||
// elements, copy them into rate elements 5 and 6, and set the seventh rate element
|
||||
// to 1.
|
||||
// - set the first capacity element to 1
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
|
||||
state[INPUT2_RANGE.start] = Felt::new(value);
|
||||
if value < Felt::MODULUS {
|
||||
state[INPUT2_RANGE.start + 1] = ONE;
|
||||
} else {
|
||||
state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
|
||||
state[INPUT2_RANGE.start + 2] = ONE;
|
||||
}
|
||||
|
||||
// common padding for both cases
|
||||
state[CAPACITY_RANGE.start] = ONE;
|
||||
|
||||
// apply the RPX permutation and return the first four elements of the state
|
||||
Self::apply_permutation(&mut state);
|
||||
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl ElementHasher for Rpx256 {
|
||||
type BaseField = Felt;
|
||||
|
||||
fn hash_elements<E: FieldElement<BaseField = Self::BaseField>>(elements: &[E]) -> Self::Digest {
|
||||
// convert the elements into a list of base field elements
|
||||
let elements = E::slice_as_base_elements(elements);
|
||||
|
||||
// initialize state to all zeros, except for the first element of the capacity part, which
|
||||
// is set to 1 if the number of elements is not a multiple of RATE_WIDTH.
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
if elements.len() % RATE_WIDTH != 0 {
|
||||
state[CAPACITY_RANGE.start] = ONE;
|
||||
}
|
||||
|
||||
// absorb elements into the state one by one until the rate portion of the state is filled
|
||||
// up; then apply the Rescue permutation and start absorbing again; repeat until all
|
||||
// elements have been absorbed
|
||||
let mut i = 0;
|
||||
for &element in elements.iter() {
|
||||
state[RATE_RANGE.start + i] = element;
|
||||
i += 1;
|
||||
if i % RATE_WIDTH == 0 {
|
||||
Self::apply_permutation(&mut state);
|
||||
i = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// if we absorbed some elements but didn't apply a permutation to them (would happen when
|
||||
// the number of elements is not a multiple of RATE_WIDTH), apply the RPX permutation after
|
||||
// padding by appending a 1 followed by as many 0 as necessary to make the input length a
|
||||
// multiple of the RATE_WIDTH.
|
||||
if i > 0 {
|
||||
state[RATE_RANGE.start + i] = ONE;
|
||||
i += 1;
|
||||
while i != RATE_WIDTH {
|
||||
state[RATE_RANGE.start + i] = ZERO;
|
||||
i += 1;
|
||||
}
|
||||
Self::apply_permutation(&mut state);
|
||||
}
|
||||
|
||||
// return the first 4 elements of the state as hash result
|
||||
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
// HASH FUNCTION IMPLEMENTATION
|
||||
// ================================================================================================
|
||||
|
||||
impl Rpx256 {
|
||||
// CONSTANTS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Sponge state is set to 12 field elements or 768 bytes; 8 elements are reserved for rate and
|
||||
/// the remaining 4 elements are reserved for capacity.
|
||||
pub const STATE_WIDTH: usize = STATE_WIDTH;
|
||||
|
||||
/// The rate portion of the state is located in elements 4 through 11 (inclusive).
|
||||
pub const RATE_RANGE: Range<usize> = RATE_RANGE;
|
||||
|
||||
/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
|
||||
pub const CAPACITY_RANGE: Range<usize> = CAPACITY_RANGE;
|
||||
|
||||
/// The output of the hash function can be read from state elements 4, 5, 6, and 7.
|
||||
pub const DIGEST_RANGE: Range<usize> = DIGEST_RANGE;
|
||||
|
||||
/// MDS matrix used for computing the linear layer in the (FB) and (E) rounds.
|
||||
pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS;
|
||||
|
||||
/// Round constants added to the hasher state in the first half of the round.
|
||||
pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1;
|
||||
|
||||
/// Round constants added to the hasher state in the second half of the round.
|
||||
pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2;
|
||||
|
||||
// TRAIT PASS-THROUGH FUNCTIONS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a hash of the provided sequence of bytes.
|
||||
#[inline(always)]
|
||||
pub fn hash(bytes: &[u8]) -> RpxDigest {
|
||||
<Self as Hasher>::hash(bytes)
|
||||
}
|
||||
|
||||
/// Returns a hash of two digests. This method is intended for use in construction of
|
||||
/// Merkle trees and verification of Merkle paths.
|
||||
#[inline(always)]
|
||||
pub fn merge(values: &[RpxDigest; 2]) -> RpxDigest {
|
||||
<Self as Hasher>::merge(values)
|
||||
}
|
||||
|
||||
/// Returns a hash of the provided field elements.
|
||||
#[inline(always)]
|
||||
pub fn hash_elements<E: FieldElement<BaseField = Felt>>(elements: &[E]) -> RpxDigest {
|
||||
<Self as ElementHasher>::hash_elements(elements)
|
||||
}
|
||||
|
||||
// DOMAIN IDENTIFIER
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a hash of two digests and a domain identifier.
|
||||
pub fn merge_in_domain(values: &[RpxDigest; 2], domain: Felt) -> RpxDigest {
|
||||
// initialize the state by copying the digest elements into the rate portion of the state
|
||||
// (8 total elements), and set the capacity elements to 0.
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
let it = RpxDigest::digests_as_elements(values.iter());
|
||||
for (i, v) in it.enumerate() {
|
||||
state[RATE_RANGE.start + i] = *v;
|
||||
}
|
||||
|
||||
// set the second capacity element to the domain value. The first capacity element is used
|
||||
// for padding purposes.
|
||||
state[CAPACITY_RANGE.start + 1] = domain;
|
||||
|
||||
// apply the RPX permutation and return the first four elements of the state
|
||||
Self::apply_permutation(&mut state);
|
||||
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
||||
}
|
||||
|
||||
// RPX PERMUTATION
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Applies RPX permutation to the provided state.
|
||||
#[inline(always)]
|
||||
pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) {
|
||||
Self::apply_fb_round(state, 0);
|
||||
Self::apply_ext_round(state, 1);
|
||||
Self::apply_fb_round(state, 2);
|
||||
Self::apply_ext_round(state, 3);
|
||||
Self::apply_fb_round(state, 4);
|
||||
Self::apply_ext_round(state, 5);
|
||||
Self::apply_final_round(state, 6);
|
||||
}
|
||||
|
||||
// RPX PERMUTATION ROUND FUNCTIONS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// (FB) round function.
|
||||
#[inline(always)]
|
||||
pub fn apply_fb_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
|
||||
apply_mds(state);
|
||||
if !add_constants_and_apply_sbox(state, &ARK1[round]) {
|
||||
add_constants(state, &ARK1[round]);
|
||||
apply_sbox(state);
|
||||
}
|
||||
|
||||
apply_mds(state);
|
||||
if !add_constants_and_apply_inv_sbox(state, &ARK2[round]) {
|
||||
add_constants(state, &ARK2[round]);
|
||||
apply_inv_sbox(state);
|
||||
}
|
||||
}
|
||||
|
||||
/// (E) round function.
|
||||
#[inline(always)]
|
||||
pub fn apply_ext_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
|
||||
// add constants
|
||||
add_constants(state, &ARK1[round]);
|
||||
|
||||
// decompose the state into 4 elements in the cubic extension field and apply the power 7
|
||||
// map to each of the elements
|
||||
let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] = *state;
|
||||
let ext0 = Self::exp7(CubicExtElement::new(s0, s1, s2));
|
||||
let ext1 = Self::exp7(CubicExtElement::new(s3, s4, s5));
|
||||
let ext2 = Self::exp7(CubicExtElement::new(s6, s7, s8));
|
||||
let ext3 = Self::exp7(CubicExtElement::new(s9, s10, s11));
|
||||
|
||||
// decompose the state back into 12 base field elements
|
||||
let arr_ext = [ext0, ext1, ext2, ext3];
|
||||
*state = CubicExtElement::slice_as_base_elements(&arr_ext)
|
||||
.try_into()
|
||||
.expect("shouldn't fail");
|
||||
}
|
||||
|
||||
/// (M) round function.
|
||||
#[inline(always)]
|
||||
pub fn apply_final_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
|
||||
apply_mds(state);
|
||||
add_constants(state, &ARK1[round]);
|
||||
}
|
||||
|
||||
/// Computes an exponentiation to the power 7 in cubic extension field
|
||||
#[inline(always)]
|
||||
pub fn exp7(x: CubeExtension<Felt>) -> CubeExtension<Felt> {
|
||||
let x2 = x.square();
|
||||
let x4 = x2.square();
|
||||
|
||||
let x3 = x2 * x;
|
||||
x3 * x4
|
||||
}
|
||||
}
|
||||
9
src/hash/rescue/tests.rs
Normal file
9
src/hash/rescue/tests.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
use super::{Felt, FieldElement, ALPHA, INV_ALPHA};
|
||||
use rand_utils::rand_value;
|
||||
|
||||
#[test]
|
||||
fn test_alphas() {
|
||||
let e: Felt = Felt::new(rand_value());
|
||||
let e_exp = e.exp(ALPHA);
|
||||
assert_eq!(e, e_exp.exp(INV_ALPHA));
|
||||
}
|
||||
@@ -1,905 +0,0 @@
|
||||
use super::{Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ONE, ZERO};
|
||||
use core::{convert::TryInto, ops::Range};
|
||||
|
||||
mod digest;
|
||||
pub use digest::RpoDigest;
|
||||
|
||||
mod mds_freq;
|
||||
use mds_freq::mds_multiply_freq;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
#[cfg(all(target_feature = "sve", feature = "sve"))]
|
||||
#[link(name = "rpo_sve", kind = "static")]
|
||||
extern "C" {
|
||||
fn add_constants_and_apply_sbox(
|
||||
state: *mut std::ffi::c_ulong,
|
||||
constants: *const std::ffi::c_ulong,
|
||||
) -> bool;
|
||||
fn add_constants_and_apply_inv_sbox(
|
||||
state: *mut std::ffi::c_ulong,
|
||||
constants: *const std::ffi::c_ulong,
|
||||
) -> bool;
|
||||
}
|
||||
|
||||
// CONSTANTS
|
||||
// ================================================================================================
|
||||
|
||||
/// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
|
||||
/// the remaining 4 elements are reserved for capacity.
|
||||
const STATE_WIDTH: usize = 12;
|
||||
|
||||
/// The rate portion of the state is located in elements 4 through 11.
|
||||
const RATE_RANGE: Range<usize> = 4..12;
|
||||
const RATE_WIDTH: usize = RATE_RANGE.end - RATE_RANGE.start;
|
||||
|
||||
const INPUT1_RANGE: Range<usize> = 4..8;
|
||||
const INPUT2_RANGE: Range<usize> = 8..12;
|
||||
|
||||
/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
|
||||
const CAPACITY_RANGE: Range<usize> = 0..4;
|
||||
|
||||
/// The output of the hash function is a digest which consists of 4 field elements or 32 bytes.
|
||||
///
|
||||
/// The digest is returned from state elements 4, 5, 6, and 7 (the first four elements of the
|
||||
/// rate portion).
|
||||
const DIGEST_RANGE: Range<usize> = 4..8;
|
||||
const DIGEST_SIZE: usize = DIGEST_RANGE.end - DIGEST_RANGE.start;
|
||||
|
||||
/// The number of rounds is set to 7 to target 128-bit security level
|
||||
const NUM_ROUNDS: usize = 7;
|
||||
|
||||
/// The number of byte chunks defining a field element when hashing a sequence of bytes
|
||||
const BINARY_CHUNK_SIZE: usize = 7;
|
||||
|
||||
/// S-Box and Inverse S-Box powers;
|
||||
///
|
||||
/// The constants are defined for tests only because the exponentiations in the code are unrolled
|
||||
/// for efficiency reasons.
|
||||
#[cfg(test)]
|
||||
const ALPHA: u64 = 7;
|
||||
#[cfg(test)]
|
||||
const INV_ALPHA: u64 = 10540996611094048183;
|
||||
|
||||
// HASHER IMPLEMENTATION
|
||||
// ================================================================================================
|
||||
|
||||
/// Implementation of the Rescue Prime Optimized hash function with 256-bit output.
|
||||
///
|
||||
/// The hash function is implemented according to the Rescue Prime Optimized
|
||||
/// [specifications](https://eprint.iacr.org/2022/1577)
|
||||
///
|
||||
/// The parameters used to instantiate the function are:
|
||||
/// * Field: 64-bit prime field with modulus 2^64 - 2^32 + 1.
|
||||
/// * State width: 12 field elements.
|
||||
/// * Capacity size: 4 field elements.
|
||||
/// * Number of founds: 7.
|
||||
/// * S-Box degree: 7.
|
||||
///
|
||||
/// The above parameters target 128-bit security level. The digest consists of four field elements
|
||||
/// and it can be serialized into 32 bytes (256 bits).
|
||||
///
|
||||
/// ## Hash output consistency
|
||||
/// Functions [hash_elements()](Rpo256::hash_elements), [merge()](Rpo256::merge), and
|
||||
/// [merge_with_int()](Rpo256::merge_with_int) are internally consistent. That is, computing
|
||||
/// a hash for the same set of elements using these functions will always produce the same
|
||||
/// result. For example, merging two digests using [merge()](Rpo256::merge) will produce the
|
||||
/// same result as hashing 8 elements which make up these digests using
|
||||
/// [hash_elements()](Rpo256::hash_elements) function.
|
||||
///
|
||||
/// However, [hash()](Rpo256::hash) function is not consistent with functions mentioned above.
|
||||
/// For example, if we take two field elements, serialize them to bytes and hash them using
|
||||
/// [hash()](Rpo256::hash), the result will differ from the result obtained by hashing these
|
||||
/// elements directly using [hash_elements()](Rpo256::hash_elements) function. The reason for
|
||||
/// this difference is that [hash()](Rpo256::hash) function needs to be able to handle
|
||||
/// arbitrary binary strings, which may or may not encode valid field elements - and thus,
|
||||
/// deserialization procedure used by this function is different from the procedure used to
|
||||
/// deserialize valid field elements.
|
||||
///
|
||||
/// Thus, if the underlying data consists of valid field elements, it might make more sense
|
||||
/// to deserialize them into field elements and then hash them using
|
||||
/// [hash_elements()](Rpo256::hash_elements) function rather then hashing the serialized bytes
|
||||
/// using [hash()](Rpo256::hash) function.
|
||||
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
|
||||
pub struct Rpo256();
|
||||
|
||||
impl Hasher for Rpo256 {
|
||||
/// Rpo256 collision resistance is the same as the security level, that is 128-bits.
|
||||
///
|
||||
/// #### Collision resistance
|
||||
///
|
||||
/// However, our setup of the capacity registers might drop it to 126.
|
||||
///
|
||||
/// Related issue: [#69](https://github.com/0xPolygonMiden/crypto/issues/69)
|
||||
const COLLISION_RESISTANCE: u32 = 128;
|
||||
|
||||
type Digest = RpoDigest;
|
||||
|
||||
fn hash(bytes: &[u8]) -> Self::Digest {
|
||||
// initialize the state with zeroes
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
|
||||
// set the capacity (first element) to a flag on whether or not the input length is evenly
|
||||
// divided by the rate. this will prevent collisions between padded and non-padded inputs,
|
||||
// and will rule out the need to perform an extra permutation in case of evenly divided
|
||||
// inputs.
|
||||
let is_rate_multiple = bytes.len() % RATE_WIDTH == 0;
|
||||
if !is_rate_multiple {
|
||||
state[CAPACITY_RANGE.start] = ONE;
|
||||
}
|
||||
|
||||
// initialize a buffer to receive the little-endian elements.
|
||||
let mut buf = [0_u8; 8];
|
||||
|
||||
// iterate the chunks of bytes, creating a field element from each chunk and copying it
|
||||
// into the state.
|
||||
//
|
||||
// every time the rate range is filled, a permutation is performed. if the final value of
|
||||
// `i` is not zero, then the chunks count wasn't enough to fill the state range, and an
|
||||
// additional permutation must be performed.
|
||||
let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| {
|
||||
// the last element of the iteration may or may not be a full chunk. if it's not, then
|
||||
// we need to pad the remainder bytes of the chunk with zeroes, separated by a `1`.
|
||||
// this will avoid collisions.
|
||||
if chunk.len() == BINARY_CHUNK_SIZE {
|
||||
buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
|
||||
} else {
|
||||
buf.fill(0);
|
||||
buf[..chunk.len()].copy_from_slice(chunk);
|
||||
buf[chunk.len()] = 1;
|
||||
}
|
||||
|
||||
// set the current rate element to the input. since we take at most 7 bytes, we are
|
||||
// guaranteed that the inputs data will fit into a single field element.
|
||||
state[RATE_RANGE.start + i] = Felt::new(u64::from_le_bytes(buf));
|
||||
|
||||
// proceed filling the range. if it's full, then we apply a permutation and reset the
|
||||
// counter to the beginning of the range.
|
||||
if i == RATE_WIDTH - 1 {
|
||||
Self::apply_permutation(&mut state);
|
||||
0
|
||||
} else {
|
||||
i + 1
|
||||
}
|
||||
});
|
||||
|
||||
// if we absorbed some elements but didn't apply a permutation to them (would happen when
|
||||
// the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation. we
|
||||
// don't need to apply any extra padding because the first capacity element contains a
|
||||
// flag indicating whether the input is evenly divisible by the rate.
|
||||
if i != 0 {
|
||||
state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO);
|
||||
state[RATE_RANGE.start + i] = ONE;
|
||||
Self::apply_permutation(&mut state);
|
||||
}
|
||||
|
||||
// return the first 4 elements of the rate as hash result.
|
||||
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
||||
}
|
||||
|
||||
fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
|
||||
// initialize the state by copying the digest elements into the rate portion of the state
|
||||
// (8 total elements), and set the capacity elements to 0.
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
let it = Self::Digest::digests_as_elements(values.iter());
|
||||
for (i, v) in it.enumerate() {
|
||||
state[RATE_RANGE.start + i] = *v;
|
||||
}
|
||||
|
||||
// apply the RPO permutation and return the first four elements of the state
|
||||
Self::apply_permutation(&mut state);
|
||||
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
||||
}
|
||||
|
||||
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
|
||||
// initialize the state as follows:
|
||||
// - seed is copied into the first 4 elements of the rate portion of the state.
|
||||
// - if the value fits into a single field element, copy it into the fifth rate element
|
||||
// and set the sixth rate element to 1.
|
||||
// - if the value doesn't fit into a single field element, split it into two field
|
||||
// elements, copy them into rate elements 5 and 6, and set the seventh rate element
|
||||
// to 1.
|
||||
// - set the first capacity element to 1
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
|
||||
state[INPUT2_RANGE.start] = Felt::new(value);
|
||||
if value < Felt::MODULUS {
|
||||
state[INPUT2_RANGE.start + 1] = ONE;
|
||||
} else {
|
||||
state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
|
||||
state[INPUT2_RANGE.start + 2] = ONE;
|
||||
}
|
||||
|
||||
// common padding for both cases
|
||||
state[CAPACITY_RANGE.start] = ONE;
|
||||
|
||||
// apply the RPO permutation and return the first four elements of the state
|
||||
Self::apply_permutation(&mut state);
|
||||
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl ElementHasher for Rpo256 {
|
||||
type BaseField = Felt;
|
||||
|
||||
fn hash_elements<E: FieldElement<BaseField = Self::BaseField>>(elements: &[E]) -> Self::Digest {
|
||||
// convert the elements into a list of base field elements
|
||||
let elements = E::slice_as_base_elements(elements);
|
||||
|
||||
// initialize state to all zeros, except for the first element of the capacity part, which
|
||||
// is set to 1 if the number of elements is not a multiple of RATE_WIDTH.
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
if elements.len() % RATE_WIDTH != 0 {
|
||||
state[CAPACITY_RANGE.start] = ONE;
|
||||
}
|
||||
|
||||
// absorb elements into the state one by one until the rate portion of the state is filled
|
||||
// up; then apply the Rescue permutation and start absorbing again; repeat until all
|
||||
// elements have been absorbed
|
||||
let mut i = 0;
|
||||
for &element in elements.iter() {
|
||||
state[RATE_RANGE.start + i] = element;
|
||||
i += 1;
|
||||
if i % RATE_WIDTH == 0 {
|
||||
Self::apply_permutation(&mut state);
|
||||
i = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// if we absorbed some elements but didn't apply a permutation to them (would happen when
|
||||
// the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation after
|
||||
// padding by appending a 1 followed by as many 0 as necessary to make the input length a
|
||||
// multiple of the RATE_WIDTH.
|
||||
if i > 0 {
|
||||
state[RATE_RANGE.start + i] = ONE;
|
||||
i += 1;
|
||||
while i != RATE_WIDTH {
|
||||
state[RATE_RANGE.start + i] = ZERO;
|
||||
i += 1;
|
||||
}
|
||||
Self::apply_permutation(&mut state);
|
||||
}
|
||||
|
||||
// return the first 4 elements of the state as hash result
|
||||
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
// HASH FUNCTION IMPLEMENTATION
|
||||
// ================================================================================================
|
||||
|
||||
impl Rpo256 {
|
||||
// CONSTANTS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// The number of rounds is set to 7 to target 128-bit security level.
|
||||
pub const NUM_ROUNDS: usize = NUM_ROUNDS;
|
||||
|
||||
/// Sponge state is set to 12 field elements or 768 bytes; 8 elements are reserved for rate and
|
||||
/// the remaining 4 elements are reserved for capacity.
|
||||
pub const STATE_WIDTH: usize = STATE_WIDTH;
|
||||
|
||||
/// The rate portion of the state is located in elements 4 through 11 (inclusive).
|
||||
pub const RATE_RANGE: Range<usize> = RATE_RANGE;
|
||||
|
||||
/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
|
||||
pub const CAPACITY_RANGE: Range<usize> = CAPACITY_RANGE;
|
||||
|
||||
/// The output of the hash function can be read from state elements 4, 5, 6, and 7.
|
||||
pub const DIGEST_RANGE: Range<usize> = DIGEST_RANGE;
|
||||
|
||||
/// MDS matrix used for computing the linear layer in a RPO round.
|
||||
pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS;
|
||||
|
||||
/// Round constants added to the hasher state in the first half of the RPO round.
|
||||
pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1;
|
||||
|
||||
/// Round constants added to the hasher state in the second half of the RPO round.
|
||||
pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2;
|
||||
|
||||
// TRAIT PASS-THROUGH FUNCTIONS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a hash of the provided sequence of bytes.
|
||||
#[inline(always)]
|
||||
pub fn hash(bytes: &[u8]) -> RpoDigest {
|
||||
<Self as Hasher>::hash(bytes)
|
||||
}
|
||||
|
||||
/// Returns a hash of two digests. This method is intended for use in construction of
|
||||
/// Merkle trees and verification of Merkle paths.
|
||||
#[inline(always)]
|
||||
pub fn merge(values: &[RpoDigest; 2]) -> RpoDigest {
|
||||
<Self as Hasher>::merge(values)
|
||||
}
|
||||
|
||||
/// Returns a hash of the provided field elements.
|
||||
#[inline(always)]
|
||||
pub fn hash_elements<E: FieldElement<BaseField = Felt>>(elements: &[E]) -> RpoDigest {
|
||||
<Self as ElementHasher>::hash_elements(elements)
|
||||
}
|
||||
|
||||
// DOMAIN IDENTIFIER
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a hash of two digests and a domain identifier.
|
||||
pub fn merge_in_domain(values: &[RpoDigest; 2], domain: Felt) -> RpoDigest {
|
||||
// initialize the state by copying the digest elements into the rate portion of the state
|
||||
// (8 total elements), and set the capacity elements to 0.
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
let it = RpoDigest::digests_as_elements(values.iter());
|
||||
for (i, v) in it.enumerate() {
|
||||
state[RATE_RANGE.start + i] = *v;
|
||||
}
|
||||
|
||||
// set the second capacity element to the domain value. The first capacity element is used
|
||||
// for padding purposes.
|
||||
state[CAPACITY_RANGE.start + 1] = domain;
|
||||
|
||||
// apply the RPO permutation and return the first four elements of the state
|
||||
Self::apply_permutation(&mut state);
|
||||
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
||||
}
|
||||
|
||||
// RESCUE PERMUTATION
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Applies RPO permutation to the provided state.
|
||||
#[inline(always)]
|
||||
pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) {
|
||||
for i in 0..NUM_ROUNDS {
|
||||
Self::apply_round(state, i);
|
||||
}
|
||||
}
|
||||
|
||||
/// RPO round function.
|
||||
#[inline(always)]
|
||||
pub fn apply_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
|
||||
// apply first half of RPO round
|
||||
Self::apply_mds(state);
|
||||
if !Self::optimized_add_constants_and_apply_sbox(state, &ARK1[round]) {
|
||||
Self::add_constants(state, &ARK1[round]);
|
||||
Self::apply_sbox(state);
|
||||
}
|
||||
|
||||
// apply second half of RPO round
|
||||
Self::apply_mds(state);
|
||||
if !Self::optimized_add_constants_and_apply_inv_sbox(state, &ARK2[round]) {
|
||||
Self::add_constants(state, &ARK2[round]);
|
||||
Self::apply_inv_sbox(state);
|
||||
}
|
||||
}
|
||||
|
||||
// HELPER FUNCTIONS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
#[inline(always)]
|
||||
#[cfg(all(target_feature = "sve", feature = "sve"))]
|
||||
fn optimized_add_constants_and_apply_sbox(
|
||||
state: &mut [Felt; STATE_WIDTH],
|
||||
ark: &[Felt; STATE_WIDTH],
|
||||
) -> bool {
|
||||
unsafe {
|
||||
add_constants_and_apply_sbox(state.as_mut_ptr() as *mut u64, ark.as_ptr() as *const u64)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[cfg(not(all(target_feature = "sve", feature = "sve")))]
|
||||
fn optimized_add_constants_and_apply_sbox(
|
||||
_state: &mut [Felt; STATE_WIDTH],
|
||||
_ark: &[Felt; STATE_WIDTH],
|
||||
) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[cfg(all(target_feature = "sve", feature = "sve"))]
|
||||
fn optimized_add_constants_and_apply_inv_sbox(
|
||||
state: &mut [Felt; STATE_WIDTH],
|
||||
ark: &[Felt; STATE_WIDTH],
|
||||
) -> bool {
|
||||
unsafe {
|
||||
add_constants_and_apply_inv_sbox(
|
||||
state.as_mut_ptr() as *mut u64,
|
||||
ark.as_ptr() as *const u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[cfg(not(all(target_feature = "sve", feature = "sve")))]
|
||||
fn optimized_add_constants_and_apply_inv_sbox(
|
||||
_state: &mut [Felt; STATE_WIDTH],
|
||||
_ark: &[Felt; STATE_WIDTH],
|
||||
) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn apply_mds(state: &mut [Felt; STATE_WIDTH]) {
|
||||
let mut result = [ZERO; STATE_WIDTH];
|
||||
|
||||
// Using the linearity of the operations we can split the state into a low||high decomposition
|
||||
// and operate on each with no overflow and then combine/reduce the result to a field element.
|
||||
// The no overflow is guaranteed by the fact that the MDS matrix is a small powers of two in
|
||||
// frequency domain.
|
||||
let mut state_l = [0u64; STATE_WIDTH];
|
||||
let mut state_h = [0u64; STATE_WIDTH];
|
||||
|
||||
for r in 0..STATE_WIDTH {
|
||||
let s = state[r].inner();
|
||||
state_h[r] = s >> 32;
|
||||
state_l[r] = (s as u32) as u64;
|
||||
}
|
||||
|
||||
let state_h = mds_multiply_freq(state_h);
|
||||
let state_l = mds_multiply_freq(state_l);
|
||||
|
||||
for r in 0..STATE_WIDTH {
|
||||
let s = state_l[r] as u128 + ((state_h[r] as u128) << 32);
|
||||
let s_hi = (s >> 64) as u64;
|
||||
let s_lo = s as u64;
|
||||
let z = (s_hi << 32) - s_hi;
|
||||
let (res, over) = s_lo.overflowing_add(z);
|
||||
|
||||
result[r] = Felt::from_mont(res.wrapping_add(0u32.wrapping_sub(over as u32) as u64));
|
||||
}
|
||||
*state = result;
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn add_constants(state: &mut [Felt; STATE_WIDTH], ark: &[Felt; STATE_WIDTH]) {
|
||||
state.iter_mut().zip(ark).for_each(|(s, &k)| *s += k);
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn apply_sbox(state: &mut [Felt; STATE_WIDTH]) {
|
||||
state[0] = state[0].exp7();
|
||||
state[1] = state[1].exp7();
|
||||
state[2] = state[2].exp7();
|
||||
state[3] = state[3].exp7();
|
||||
state[4] = state[4].exp7();
|
||||
state[5] = state[5].exp7();
|
||||
state[6] = state[6].exp7();
|
||||
state[7] = state[7].exp7();
|
||||
state[8] = state[8].exp7();
|
||||
state[9] = state[9].exp7();
|
||||
state[10] = state[10].exp7();
|
||||
state[11] = state[11].exp7();
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn apply_inv_sbox(state: &mut [Felt; STATE_WIDTH]) {
|
||||
// compute base^10540996611094048183 using 72 multiplications per array element
|
||||
// 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111
|
||||
|
||||
// compute base^10
|
||||
let mut t1 = *state;
|
||||
t1.iter_mut().for_each(|t| *t = t.square());
|
||||
|
||||
// compute base^100
|
||||
let mut t2 = t1;
|
||||
t2.iter_mut().for_each(|t| *t = t.square());
|
||||
|
||||
// compute base^100100
|
||||
let t3 = Self::exp_acc::<Felt, STATE_WIDTH, 3>(t2, t2);
|
||||
|
||||
// compute base^100100100100
|
||||
let t4 = Self::exp_acc::<Felt, STATE_WIDTH, 6>(t3, t3);
|
||||
|
||||
// compute base^100100100100100100100100
|
||||
let t5 = Self::exp_acc::<Felt, STATE_WIDTH, 12>(t4, t4);
|
||||
|
||||
// compute base^100100100100100100100100100100
|
||||
let t6 = Self::exp_acc::<Felt, STATE_WIDTH, 6>(t5, t3);
|
||||
|
||||
// compute base^1001001001001001001001001001000100100100100100100100100100100
|
||||
let t7 = Self::exp_acc::<Felt, STATE_WIDTH, 31>(t6, t6);
|
||||
|
||||
// compute base^1001001001001001001001001001000110110110110110110110110110110111
|
||||
for (i, s) in state.iter_mut().enumerate() {
|
||||
let a = (t7[i].square() * t6[i]).square().square();
|
||||
let b = t1[i] * t2[i] * *s;
|
||||
*s = a * b;
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn exp_acc<B: StarkField, const N: usize, const M: usize>(
|
||||
base: [B; N],
|
||||
tail: [B; N],
|
||||
) -> [B; N] {
|
||||
let mut result = base;
|
||||
for _ in 0..M {
|
||||
result.iter_mut().for_each(|r| *r = r.square());
|
||||
}
|
||||
result.iter_mut().zip(tail).for_each(|(r, t)| *r *= t);
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
// MDS
|
||||
// ================================================================================================
|
||||
/// RPO MDS matrix
|
||||
const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = [
|
||||
[
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
],
|
||||
[
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
],
|
||||
[
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
],
|
||||
[
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
],
|
||||
[
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
],
|
||||
[
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
],
|
||||
[
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
],
|
||||
[
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
],
|
||||
[
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
],
|
||||
[
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
],
|
||||
[
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
Felt::new(23),
|
||||
],
|
||||
[
|
||||
Felt::new(23),
|
||||
Felt::new(8),
|
||||
Felt::new(26),
|
||||
Felt::new(13),
|
||||
Felt::new(10),
|
||||
Felt::new(9),
|
||||
Felt::new(7),
|
||||
Felt::new(6),
|
||||
Felt::new(22),
|
||||
Felt::new(21),
|
||||
Felt::new(8),
|
||||
Felt::new(7),
|
||||
],
|
||||
];
|
||||
|
||||
// ROUND CONSTANTS
|
||||
// ================================================================================================
|
||||
|
||||
/// Rescue round constants;
|
||||
/// computed as in [specifications](https://github.com/ASDiscreteMathematics/rpo)
|
||||
///
|
||||
/// The constants are broken up into two arrays ARK1 and ARK2; ARK1 contains the constants for the
|
||||
/// first half of RPO round, and ARK2 contains constants for the second half of RPO round.
|
||||
const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
|
||||
[
|
||||
Felt::new(5789762306288267392),
|
||||
Felt::new(6522564764413701783),
|
||||
Felt::new(17809893479458208203),
|
||||
Felt::new(107145243989736508),
|
||||
Felt::new(6388978042437517382),
|
||||
Felt::new(15844067734406016715),
|
||||
Felt::new(9975000513555218239),
|
||||
Felt::new(3344984123768313364),
|
||||
Felt::new(9959189626657347191),
|
||||
Felt::new(12960773468763563665),
|
||||
Felt::new(9602914297752488475),
|
||||
Felt::new(16657542370200465908),
|
||||
],
|
||||
[
|
||||
Felt::new(12987190162843096997),
|
||||
Felt::new(653957632802705281),
|
||||
Felt::new(4441654670647621225),
|
||||
Felt::new(4038207883745915761),
|
||||
Felt::new(5613464648874830118),
|
||||
Felt::new(13222989726778338773),
|
||||
Felt::new(3037761201230264149),
|
||||
Felt::new(16683759727265180203),
|
||||
Felt::new(8337364536491240715),
|
||||
Felt::new(3227397518293416448),
|
||||
Felt::new(8110510111539674682),
|
||||
Felt::new(2872078294163232137),
|
||||
],
|
||||
[
|
||||
Felt::new(18072785500942327487),
|
||||
Felt::new(6200974112677013481),
|
||||
Felt::new(17682092219085884187),
|
||||
Felt::new(10599526828986756440),
|
||||
Felt::new(975003873302957338),
|
||||
Felt::new(8264241093196931281),
|
||||
Felt::new(10065763900435475170),
|
||||
Felt::new(2181131744534710197),
|
||||
Felt::new(6317303992309418647),
|
||||
Felt::new(1401440938888741532),
|
||||
Felt::new(8884468225181997494),
|
||||
Felt::new(13066900325715521532),
|
||||
],
|
||||
[
|
||||
Felt::new(5674685213610121970),
|
||||
Felt::new(5759084860419474071),
|
||||
Felt::new(13943282657648897737),
|
||||
Felt::new(1352748651966375394),
|
||||
Felt::new(17110913224029905221),
|
||||
Felt::new(1003883795902368422),
|
||||
Felt::new(4141870621881018291),
|
||||
Felt::new(8121410972417424656),
|
||||
Felt::new(14300518605864919529),
|
||||
Felt::new(13712227150607670181),
|
||||
Felt::new(17021852944633065291),
|
||||
Felt::new(6252096473787587650),
|
||||
],
|
||||
[
|
||||
Felt::new(4887609836208846458),
|
||||
Felt::new(3027115137917284492),
|
||||
Felt::new(9595098600469470675),
|
||||
Felt::new(10528569829048484079),
|
||||
Felt::new(7864689113198939815),
|
||||
Felt::new(17533723827845969040),
|
||||
Felt::new(5781638039037710951),
|
||||
Felt::new(17024078752430719006),
|
||||
Felt::new(109659393484013511),
|
||||
Felt::new(7158933660534805869),
|
||||
Felt::new(2955076958026921730),
|
||||
Felt::new(7433723648458773977),
|
||||
],
|
||||
[
|
||||
Felt::new(16308865189192447297),
|
||||
Felt::new(11977192855656444890),
|
||||
Felt::new(12532242556065780287),
|
||||
Felt::new(14594890931430968898),
|
||||
Felt::new(7291784239689209784),
|
||||
Felt::new(5514718540551361949),
|
||||
Felt::new(10025733853830934803),
|
||||
Felt::new(7293794580341021693),
|
||||
Felt::new(6728552937464861756),
|
||||
Felt::new(6332385040983343262),
|
||||
Felt::new(13277683694236792804),
|
||||
Felt::new(2600778905124452676),
|
||||
],
|
||||
[
|
||||
Felt::new(7123075680859040534),
|
||||
Felt::new(1034205548717903090),
|
||||
Felt::new(7717824418247931797),
|
||||
Felt::new(3019070937878604058),
|
||||
Felt::new(11403792746066867460),
|
||||
Felt::new(10280580802233112374),
|
||||
Felt::new(337153209462421218),
|
||||
Felt::new(13333398568519923717),
|
||||
Felt::new(3596153696935337464),
|
||||
Felt::new(8104208463525993784),
|
||||
Felt::new(14345062289456085693),
|
||||
Felt::new(17036731477169661256),
|
||||
],
|
||||
];
|
||||
|
||||
const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
|
||||
[
|
||||
Felt::new(6077062762357204287),
|
||||
Felt::new(15277620170502011191),
|
||||
Felt::new(5358738125714196705),
|
||||
Felt::new(14233283787297595718),
|
||||
Felt::new(13792579614346651365),
|
||||
Felt::new(11614812331536767105),
|
||||
Felt::new(14871063686742261166),
|
||||
Felt::new(10148237148793043499),
|
||||
Felt::new(4457428952329675767),
|
||||
Felt::new(15590786458219172475),
|
||||
Felt::new(10063319113072092615),
|
||||
Felt::new(14200078843431360086),
|
||||
],
|
||||
[
|
||||
Felt::new(6202948458916099932),
|
||||
Felt::new(17690140365333231091),
|
||||
Felt::new(3595001575307484651),
|
||||
Felt::new(373995945117666487),
|
||||
Felt::new(1235734395091296013),
|
||||
Felt::new(14172757457833931602),
|
||||
Felt::new(707573103686350224),
|
||||
Felt::new(15453217512188187135),
|
||||
Felt::new(219777875004506018),
|
||||
Felt::new(17876696346199469008),
|
||||
Felt::new(17731621626449383378),
|
||||
Felt::new(2897136237748376248),
|
||||
],
|
||||
[
|
||||
Felt::new(8023374565629191455),
|
||||
Felt::new(15013690343205953430),
|
||||
Felt::new(4485500052507912973),
|
||||
Felt::new(12489737547229155153),
|
||||
Felt::new(9500452585969030576),
|
||||
Felt::new(2054001340201038870),
|
||||
Felt::new(12420704059284934186),
|
||||
Felt::new(355990932618543755),
|
||||
Felt::new(9071225051243523860),
|
||||
Felt::new(12766199826003448536),
|
||||
Felt::new(9045979173463556963),
|
||||
Felt::new(12934431667190679898),
|
||||
],
|
||||
[
|
||||
Felt::new(18389244934624494276),
|
||||
Felt::new(16731736864863925227),
|
||||
Felt::new(4440209734760478192),
|
||||
Felt::new(17208448209698888938),
|
||||
Felt::new(8739495587021565984),
|
||||
Felt::new(17000774922218161967),
|
||||
Felt::new(13533282547195532087),
|
||||
Felt::new(525402848358706231),
|
||||
Felt::new(16987541523062161972),
|
||||
Felt::new(5466806524462797102),
|
||||
Felt::new(14512769585918244983),
|
||||
Felt::new(10973956031244051118),
|
||||
],
|
||||
[
|
||||
Felt::new(6982293561042362913),
|
||||
Felt::new(14065426295947720331),
|
||||
Felt::new(16451845770444974180),
|
||||
Felt::new(7139138592091306727),
|
||||
Felt::new(9012006439959783127),
|
||||
Felt::new(14619614108529063361),
|
||||
Felt::new(1394813199588124371),
|
||||
Felt::new(4635111139507788575),
|
||||
Felt::new(16217473952264203365),
|
||||
Felt::new(10782018226466330683),
|
||||
Felt::new(6844229992533662050),
|
||||
Felt::new(7446486531695178711),
|
||||
],
|
||||
[
|
||||
Felt::new(3736792340494631448),
|
||||
Felt::new(577852220195055341),
|
||||
Felt::new(6689998335515779805),
|
||||
Felt::new(13886063479078013492),
|
||||
Felt::new(14358505101923202168),
|
||||
Felt::new(7744142531772274164),
|
||||
Felt::new(16135070735728404443),
|
||||
Felt::new(12290902521256031137),
|
||||
Felt::new(12059913662657709804),
|
||||
Felt::new(16456018495793751911),
|
||||
Felt::new(4571485474751953524),
|
||||
Felt::new(17200392109565783176),
|
||||
],
|
||||
[
|
||||
Felt::new(17130398059294018733),
|
||||
Felt::new(519782857322261988),
|
||||
Felt::new(9625384390925085478),
|
||||
Felt::new(1664893052631119222),
|
||||
Felt::new(7629576092524553570),
|
||||
Felt::new(3485239601103661425),
|
||||
Felt::new(9755891797164033838),
|
||||
Felt::new(15218148195153269027),
|
||||
Felt::new(16460604813734957368),
|
||||
Felt::new(9643968136937729763),
|
||||
Felt::new(3611348709641382851),
|
||||
Felt::new(18256379591337759196),
|
||||
],
|
||||
];
|
||||
10
src/lib.rs
10
src/lib.rs
@@ -1,7 +1,7 @@
|
||||
#![cfg_attr(not(feature = "std"), no_std)]
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
#[cfg_attr(test, macro_use)]
|
||||
//#[cfg(not(feature = "std"))]
|
||||
//#[cfg_attr(test, macro_use)]
|
||||
extern crate alloc;
|
||||
|
||||
pub mod dsa;
|
||||
@@ -9,11 +9,15 @@ pub mod hash;
|
||||
pub mod merkle;
|
||||
pub mod rand;
|
||||
pub mod utils;
|
||||
pub mod gkr;
|
||||
|
||||
// RE-EXPORTS
|
||||
// ================================================================================================
|
||||
|
||||
pub use winter_math::{fields::f64::BaseElement as Felt, FieldElement, StarkField};
|
||||
pub use winter_math::{
|
||||
fields::{f64::BaseElement as Felt, CubeExtension, QuadExtension},
|
||||
FieldElement, StarkField,
|
||||
};
|
||||
|
||||
// TYPE ALIASES
|
||||
// ================================================================================================
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
use clap::Parser;
|
||||
use miden_crypto::{
|
||||
hash::rpo::RpoDigest,
|
||||
merkle::MerkleError,
|
||||
hash::rpo::{Rpo256, RpoDigest},
|
||||
merkle::{MerkleError, TieredSmt},
|
||||
Felt, Word, ONE,
|
||||
{hash::rpo::Rpo256, merkle::TieredSmt},
|
||||
};
|
||||
use rand_utils::rand_value;
|
||||
use std::time::Instant;
|
||||
|
||||
@@ -10,12 +10,19 @@ pub struct EmptySubtreeRoots;
|
||||
impl EmptySubtreeRoots {
|
||||
/// Returns a static slice with roots of empty subtrees of a Merkle tree starting at the
|
||||
/// specified depth.
|
||||
pub const fn empty_hashes(depth: u8) -> &'static [RpoDigest] {
|
||||
let ptr = &EMPTY_SUBTREES[255 - depth as usize] as *const RpoDigest;
|
||||
pub const fn empty_hashes(tree_depth: u8) -> &'static [RpoDigest] {
|
||||
let ptr = &EMPTY_SUBTREES[255 - tree_depth as usize] as *const RpoDigest;
|
||||
// Safety: this is a static/constant array, so it will never be outlived. If we attempt to
|
||||
// use regular slices, this wouldn't be a `const` function, meaning we won't be able to use
|
||||
// the returned value for static/constant definitions.
|
||||
unsafe { slice::from_raw_parts(ptr, depth as usize + 1) }
|
||||
unsafe { slice::from_raw_parts(ptr, tree_depth as usize + 1) }
|
||||
}
|
||||
|
||||
/// Returns the node's digest for a sub-tree with all its leaves set to the empty word.
|
||||
pub const fn entry(tree_depth: u8, node_depth: u8) -> &'static RpoDigest {
|
||||
assert!(node_depth <= tree_depth);
|
||||
let pos = 255 - tree_depth + node_depth;
|
||||
&EMPTY_SUBTREES[pos as usize]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1583,3 +1590,16 @@ fn all_depths_opens_to_zero() {
|
||||
.for_each(|(x, computed)| assert_eq!(x, computed));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_entry() {
|
||||
// check the leaf is always the empty work
|
||||
for depth in 0..255 {
|
||||
assert_eq!(EmptySubtreeRoots::entry(depth, depth), &RpoDigest::new(EMPTY_WORD));
|
||||
}
|
||||
|
||||
// check the root matches the first element of empty_hashes
|
||||
for depth in 0..255 {
|
||||
assert_eq!(EmptySubtreeRoots::entry(depth, 0), &EmptySubtreeRoots::empty_hashes(depth)[0]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,8 +13,9 @@ pub enum MerkleError {
|
||||
DuplicateValuesForKey(RpoDigest),
|
||||
InvalidIndex { depth: u8, value: u64 },
|
||||
InvalidDepth { expected: u8, provided: u8 },
|
||||
InvalidSubtreeDepth { subtree_depth: u8, tree_depth: u8 },
|
||||
InvalidPath(MerklePath),
|
||||
InvalidNumEntries(usize, usize),
|
||||
InvalidNumEntries(usize),
|
||||
NodeNotInSet(NodeIndex),
|
||||
NodeNotInStore(RpoDigest, NodeIndex),
|
||||
NumLeavesNotPowerOfTwo(usize),
|
||||
@@ -30,18 +31,21 @@ impl fmt::Display for MerkleError {
|
||||
DepthTooBig(depth) => write!(f, "the provided depth {depth} is too big"),
|
||||
DuplicateValuesForIndex(key) => write!(f, "multiple values provided for key {key}"),
|
||||
DuplicateValuesForKey(key) => write!(f, "multiple values provided for key {key}"),
|
||||
InvalidIndex{ depth, value} => write!(
|
||||
f,
|
||||
"the index value {value} is not valid for the depth {depth}"
|
||||
),
|
||||
InvalidDepth { expected, provided } => write!(
|
||||
f,
|
||||
"the provided depth {provided} is not valid for {expected}"
|
||||
),
|
||||
InvalidIndex { depth, value } => {
|
||||
write!(f, "the index value {value} is not valid for the depth {depth}")
|
||||
}
|
||||
InvalidDepth { expected, provided } => {
|
||||
write!(f, "the provided depth {provided} is not valid for {expected}")
|
||||
}
|
||||
InvalidSubtreeDepth { subtree_depth, tree_depth } => {
|
||||
write!(f, "tried inserting a subtree of depth {subtree_depth} into a tree of depth {tree_depth}")
|
||||
}
|
||||
InvalidPath(_path) => write!(f, "the provided path is not valid"),
|
||||
InvalidNumEntries(max, provided) => write!(f, "the provided number of entries is {provided}, but the maximum for the given depth is {max}"),
|
||||
InvalidNumEntries(max) => write!(f, "number of entries exceeded the maximum: {max}"),
|
||||
NodeNotInSet(index) => write!(f, "the node with index ({index}) is not in the set"),
|
||||
NodeNotInStore(hash, index) => write!(f, "the node {hash:?} with index ({index}) is not in the store"),
|
||||
NodeNotInStore(hash, index) => {
|
||||
write!(f, "the node {hash:?} with index ({index}) is not in the store")
|
||||
}
|
||||
NumLeavesNotPowerOfTwo(leaves) => {
|
||||
write!(f, "the leaves count {leaves} is not a power of 2")
|
||||
}
|
||||
|
||||
@@ -187,13 +187,20 @@ mod tests {
|
||||
#[test]
|
||||
fn test_node_index_value_too_high() {
|
||||
assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 });
|
||||
match NodeIndex::new(0, 1) {
|
||||
Err(MerkleError::InvalidIndex { depth, value }) => {
|
||||
assert_eq!(depth, 0);
|
||||
assert_eq!(value, 1);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
let err = NodeIndex::new(0, 1).unwrap_err();
|
||||
assert_eq!(err, MerkleError::InvalidIndex { depth: 0, value: 1 });
|
||||
|
||||
assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 });
|
||||
let err = NodeIndex::new(1, 2).unwrap_err();
|
||||
assert_eq!(err, MerkleError::InvalidIndex { depth: 1, value: 2 });
|
||||
|
||||
assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 });
|
||||
let err = NodeIndex::new(2, 4).unwrap_err();
|
||||
assert_eq!(err, MerkleError::InvalidIndex { depth: 2, value: 4 });
|
||||
|
||||
assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 });
|
||||
let err = NodeIndex::new(3, 8).unwrap_err();
|
||||
assert_eq!(err, MerkleError::InvalidIndex { depth: 3, value: 8 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -9,11 +9,12 @@
|
||||
//! least number of leaves. The structure preserves the invariant that each tree has different
|
||||
//! depths, i.e. as part of adding adding a new element to the forest the trees with same depth are
|
||||
//! merged, creating a new tree with depth d+1, this process is continued until the property is
|
||||
//! restabilished.
|
||||
//! reestablished.
|
||||
use super::{
|
||||
super::{InnerNodeInfo, MerklePath, RpoDigest, Vec},
|
||||
super::{InnerNodeInfo, MerklePath, Vec},
|
||||
bit::TrueBitPositionIterator,
|
||||
leaf_to_corresponding_tree, nodes_in_forest, MmrDelta, MmrError, MmrPeaks, MmrProof, Rpo256,
|
||||
RpoDigest,
|
||||
};
|
||||
|
||||
// MMR
|
||||
@@ -76,13 +77,13 @@ impl Mmr {
|
||||
/// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were
|
||||
/// added, this corresponds to the MMR size _prior_ to adding the element. So the 1st element
|
||||
/// has position 0, the second position 1, and so on.
|
||||
pub fn open(&self, pos: usize) -> Result<MmrProof, MmrError> {
|
||||
pub fn open(&self, pos: usize, target_forest: usize) -> Result<MmrProof, MmrError> {
|
||||
// find the target tree responsible for the MMR position
|
||||
let tree_bit =
|
||||
leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::InvalidPosition(pos))?;
|
||||
leaf_to_corresponding_tree(pos, target_forest).ok_or(MmrError::InvalidPosition(pos))?;
|
||||
|
||||
// isolate the trees before the target
|
||||
let forest_before = self.forest & high_bitmask(tree_bit + 1);
|
||||
let forest_before = target_forest & high_bitmask(tree_bit + 1);
|
||||
let index_offset = nodes_in_forest(forest_before);
|
||||
|
||||
// update the value position from global to the target tree
|
||||
@@ -92,7 +93,7 @@ impl Mmr {
|
||||
let (_, path) = self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset);
|
||||
|
||||
Ok(MmrProof {
|
||||
forest: self.forest,
|
||||
forest: target_forest,
|
||||
position: pos,
|
||||
merkle_path: MerklePath::new(path),
|
||||
})
|
||||
@@ -143,9 +144,13 @@ impl Mmr {
|
||||
self.forest += 1;
|
||||
}
|
||||
|
||||
/// Returns an accumulator representing the current state of the MMR.
|
||||
pub fn accumulator(&self) -> MmrPeaks {
|
||||
let peaks: Vec<RpoDigest> = TrueBitPositionIterator::new(self.forest)
|
||||
/// Returns an peaks of the MMR for the version specified by `forest`.
|
||||
pub fn peaks(&self, forest: usize) -> Result<MmrPeaks, MmrError> {
|
||||
if forest > self.forest {
|
||||
return Err(MmrError::InvalidPeaks);
|
||||
}
|
||||
|
||||
let peaks: Vec<RpoDigest> = TrueBitPositionIterator::new(forest)
|
||||
.rev()
|
||||
.map(|bit| nodes_in_forest(1 << bit))
|
||||
.scan(0, |offset, el| {
|
||||
@@ -156,39 +161,41 @@ impl Mmr {
|
||||
.collect();
|
||||
|
||||
// Safety: the invariant is maintained by the [Mmr]
|
||||
MmrPeaks::new(self.forest, peaks).unwrap()
|
||||
let peaks = MmrPeaks::new(forest, peaks).unwrap();
|
||||
|
||||
Ok(peaks)
|
||||
}
|
||||
|
||||
/// Compute the required update to `original_forest`.
|
||||
///
|
||||
/// The result is a packed sequence of the authentication elements required to update the trees
|
||||
/// that have been merged together, followed by the new peaks of the [Mmr].
|
||||
pub fn get_delta(&self, original_forest: usize) -> Result<MmrDelta, MmrError> {
|
||||
if original_forest > self.forest {
|
||||
pub fn get_delta(&self, from_forest: usize, to_forest: usize) -> Result<MmrDelta, MmrError> {
|
||||
if to_forest > self.forest || from_forest > to_forest {
|
||||
return Err(MmrError::InvalidPeaks);
|
||||
}
|
||||
|
||||
if original_forest == self.forest {
|
||||
return Ok(MmrDelta { forest: self.forest, data: Vec::new() });
|
||||
if from_forest == to_forest {
|
||||
return Ok(MmrDelta { forest: to_forest, data: Vec::new() });
|
||||
}
|
||||
|
||||
let mut result = Vec::new();
|
||||
|
||||
// Find the largest tree in this [Mmr] which is new to `original_forest`.
|
||||
let candidate_trees = self.forest ^ original_forest;
|
||||
// Find the largest tree in this [Mmr] which is new to `from_forest`.
|
||||
let candidate_trees = to_forest ^ from_forest;
|
||||
let mut new_high = 1 << candidate_trees.ilog2();
|
||||
|
||||
// Collect authentication nodes used for tree merges
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
// Find the trees from `original_forest` that have been merged into `new_high`.
|
||||
let mut merges = original_forest & (new_high - 1);
|
||||
// Find the trees from `from_forest` that have been merged into `new_high`.
|
||||
let mut merges = from_forest & (new_high - 1);
|
||||
|
||||
// Find the peaks that are common to `original_forest` and this [Mmr]
|
||||
let common_trees = original_forest ^ merges;
|
||||
// Find the peaks that are common to `from_forest` and this [Mmr]
|
||||
let common_trees = from_forest ^ merges;
|
||||
|
||||
if merges != 0 {
|
||||
// Skip the smallest trees unknown to `original_forest`.
|
||||
// Skip the smallest trees unknown to `from_forest`.
|
||||
let mut target = 1 << merges.trailing_zeros();
|
||||
|
||||
// Collect siblings required to computed the merged tree's peak
|
||||
@@ -213,15 +220,15 @@ impl Mmr {
|
||||
}
|
||||
} else {
|
||||
// The new high tree may not be the result of any merges, if it is smaller than all the
|
||||
// trees of `original_forest`.
|
||||
// trees of `from_forest`.
|
||||
new_high = 0;
|
||||
}
|
||||
|
||||
// Collect the new [Mmr] peaks
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
let mut new_peaks = self.forest ^ common_trees ^ new_high;
|
||||
let old_peaks = self.forest ^ new_peaks;
|
||||
let mut new_peaks = to_forest ^ common_trees ^ new_high;
|
||||
let old_peaks = to_forest ^ new_peaks;
|
||||
let mut offset = nodes_in_forest(old_peaks);
|
||||
while new_peaks != 0 {
|
||||
let target = 1 << new_peaks.ilog2();
|
||||
@@ -230,7 +237,7 @@ impl Mmr {
|
||||
new_peaks ^= target;
|
||||
}
|
||||
|
||||
Ok(MmrDelta { forest: self.forest, data: result })
|
||||
Ok(MmrDelta { forest: to_forest, data: result })
|
||||
}
|
||||
|
||||
/// An iterator over inner nodes in the MMR. The order of iteration is unspecified.
|
||||
|
||||
@@ -6,6 +6,9 @@
|
||||
//! leaves count.
|
||||
use core::num::NonZeroUsize;
|
||||
|
||||
// IN-ORDER INDEX
|
||||
// ================================================================================================
|
||||
|
||||
/// Index of nodes in a perfectly balanced binary tree based on an in-order tree walk.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub struct InOrderIndex {
|
||||
@@ -13,15 +16,17 @@ pub struct InOrderIndex {
|
||||
}
|
||||
|
||||
impl InOrderIndex {
|
||||
/// Constructor for a new [InOrderIndex].
|
||||
// CONSTRUCTORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a new [InOrderIndex] instantiated from the provided value.
|
||||
pub fn new(idx: NonZeroUsize) -> InOrderIndex {
|
||||
InOrderIndex { idx: idx.get() }
|
||||
}
|
||||
|
||||
/// Constructs an index from a leaf position.
|
||||
///
|
||||
/// Panics:
|
||||
/// Return a new [InOrderIndex] instantiated from the specified leaf position.
|
||||
///
|
||||
/// # Panics:
|
||||
/// If `leaf` is higher than or equal to `usize::MAX / 2`.
|
||||
pub fn from_leaf_pos(leaf: usize) -> InOrderIndex {
|
||||
// Convert the position from 0-indexed to 1-indexed, since the bit manipulation in this
|
||||
@@ -30,6 +35,9 @@ impl InOrderIndex {
|
||||
InOrderIndex { idx: pos * 2 - 1 }
|
||||
}
|
||||
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// True if the index is pointing at a leaf.
|
||||
///
|
||||
/// Every odd number represents a leaf.
|
||||
@@ -37,6 +45,11 @@ impl InOrderIndex {
|
||||
self.idx & 1 == 1
|
||||
}
|
||||
|
||||
/// Returns true if this note is a left child of its parent.
|
||||
pub fn is_left_child(&self) -> bool {
|
||||
self.parent().left_child() == *self
|
||||
}
|
||||
|
||||
/// Returns the level of the index.
|
||||
///
|
||||
/// Starts at level zero for leaves and increases by one for each parent.
|
||||
@@ -46,8 +59,7 @@ impl InOrderIndex {
|
||||
|
||||
/// Returns the index of the left child.
|
||||
///
|
||||
/// Panics:
|
||||
///
|
||||
/// # Panics:
|
||||
/// If the index corresponds to a leaf.
|
||||
pub fn left_child(&self) -> InOrderIndex {
|
||||
// The left child is itself a parent, with an index that splits its left/right subtrees. To
|
||||
@@ -59,8 +71,7 @@ impl InOrderIndex {
|
||||
|
||||
/// Returns the index of the right child.
|
||||
///
|
||||
/// Panics:
|
||||
///
|
||||
/// # Panics:
|
||||
/// If the index corresponds to a leaf.
|
||||
pub fn right_child(&self) -> InOrderIndex {
|
||||
// To compute the index of the parent of the right subtree it is sufficient to add the size
|
||||
@@ -94,8 +105,25 @@ impl InOrderIndex {
|
||||
parent.right_child()
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the inner value of this [InOrderIndex].
|
||||
pub fn inner(&self) -> u64 {
|
||||
self.idx as u64
|
||||
}
|
||||
}
|
||||
|
||||
// CONVERSIONS FROM IN-ORDER INDEX
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
|
||||
impl From<InOrderIndex> for u64 {
|
||||
fn from(index: InOrderIndex) -> Self {
|
||||
index.idx as u64
|
||||
}
|
||||
}
|
||||
|
||||
// TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::InOrderIndex;
|
||||
|
||||
@@ -10,7 +10,7 @@ mod proof;
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use super::{Felt, Rpo256, Word};
|
||||
use super::{Felt, Rpo256, RpoDigest, Word};
|
||||
|
||||
// REEXPORTS
|
||||
// ================================================================================================
|
||||
@@ -40,10 +40,10 @@ const fn leaf_to_corresponding_tree(pos: usize, forest: usize) -> Option<u32> {
|
||||
// - each bit in the forest is a unique tree and the bit position its power-of-two size
|
||||
// - each tree owns a consecutive range of positions equal to its size from left-to-right
|
||||
// - this means the first tree owns from `0` up to the `2^k_0` first positions, where `k_0`
|
||||
// is the highest true bit position, the second tree from `2^k_0 + 1` up to `2^k_1` where
|
||||
// `k_1` is the second highest bit, so on.
|
||||
// is the highest true bit position, the second tree from `2^k_0 + 1` up to `2^k_1` where
|
||||
// `k_1` is the second highest bit, so on.
|
||||
// - this means the highest bits work as a category marker, and the position is owned by
|
||||
// the first tree which doesn't share a high bit with the position
|
||||
// the first tree which doesn't share a high bit with the position
|
||||
let before = forest & pos;
|
||||
let after = forest ^ before;
|
||||
let tree = after.ilog2();
|
||||
|
||||
@@ -1,57 +1,60 @@
|
||||
use super::{MmrDelta, MmrProof, Rpo256, RpoDigest};
|
||||
use crate::{
|
||||
hash::rpo::{Rpo256, RpoDigest},
|
||||
merkle::{
|
||||
mmr::{leaf_to_corresponding_tree, nodes_in_forest},
|
||||
InOrderIndex, MerklePath, MmrError, MmrPeaks,
|
||||
InOrderIndex, InnerNodeInfo, MerklePath, MmrError, MmrPeaks,
|
||||
},
|
||||
utils::{
|
||||
collections::{BTreeMap, BTreeSet, Vec},
|
||||
vec,
|
||||
},
|
||||
utils::collections::{BTreeMap, Vec},
|
||||
};
|
||||
|
||||
use super::{MmrDelta, MmrProof};
|
||||
|
||||
/// Partially materialized [Mmr], used to efficiently store and update the authentication paths for
|
||||
/// a subset of the elements in a full [Mmr].
|
||||
// PARTIAL MERKLE MOUNTAIN RANGE
|
||||
// ================================================================================================
|
||||
/// Partially materialized Merkle Mountain Range (MMR), used to efficiently store and update the
|
||||
/// authentication paths for a subset of the elements in a full MMR.
|
||||
///
|
||||
/// This structure store only the authentication path for a value, the value itself is stored
|
||||
/// separately.
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct PartialMmr {
|
||||
/// The version of the [Mmr].
|
||||
/// The version of the MMR.
|
||||
///
|
||||
/// This value serves the following purposes:
|
||||
///
|
||||
/// - The forest is a counter for the total number of elements in the [Mmr].
|
||||
/// - Since the [Mmr] is an append-only structure, every change to it causes a change to the
|
||||
/// - The forest is a counter for the total number of elements in the MMR.
|
||||
/// - Since the MMR is an append-only structure, every change to it causes a change to the
|
||||
/// `forest`, so this value has a dual purpose as a version tag.
|
||||
/// - The bits in the forest also corresponds to the count and size of every perfect binary
|
||||
/// tree that composes the [Mmr] structure, which server to compute indexes and perform
|
||||
/// tree that composes the MMR structure, which server to compute indexes and perform
|
||||
/// validation.
|
||||
pub(crate) forest: usize,
|
||||
|
||||
/// The [Mmr] peaks.
|
||||
/// The MMR peaks.
|
||||
///
|
||||
/// The peaks are used for two reasons:
|
||||
///
|
||||
/// 1. It authenticates the addition of an element to the [PartialMmr], ensuring only valid
|
||||
/// elements are tracked.
|
||||
/// 2. During a [Mmr] update peaks can be merged by hashing the left and right hand sides. The
|
||||
/// 2. During a MMR update peaks can be merged by hashing the left and right hand sides. The
|
||||
/// peaks are used as the left hand.
|
||||
///
|
||||
/// All the peaks of every tree in the [Mmr] forest. The peaks are always ordered by number of
|
||||
/// All the peaks of every tree in the MMR forest. The peaks are always ordered by number of
|
||||
/// leaves, starting from the peak with most children, to the one with least.
|
||||
pub(crate) peaks: Vec<RpoDigest>,
|
||||
|
||||
/// Authentication nodes used to construct merkle paths for a subset of the [Mmr]'s leaves.
|
||||
/// Authentication nodes used to construct merkle paths for a subset of the MMR's leaves.
|
||||
///
|
||||
/// This does not include the [Mmr]'s peaks nor the tracked nodes, only the elements required
|
||||
/// This does not include the MMR's peaks nor the tracked nodes, only the elements required
|
||||
/// to construct their authentication paths. This property is used to detect when elements can
|
||||
/// be safely removed from, because they are no longer required to authenticate any element in
|
||||
/// the [PartialMmr].
|
||||
///
|
||||
/// The elements in the [Mmr] are referenced using a in-order tree index. This indexing scheme
|
||||
/// The elements in the MMR are referenced using a in-order tree index. This indexing scheme
|
||||
/// permits for easy computation of the relative nodes (left/right children, sibling, parent),
|
||||
/// which is useful for traversal. The indexing is also stable, meaning that merges to the
|
||||
/// trees in the [Mmr] can be represented without rewrites of the indexes.
|
||||
/// trees in the MMR can be represented without rewrites of the indexes.
|
||||
pub(crate) nodes: BTreeMap<InOrderIndex, RpoDigest>,
|
||||
|
||||
/// Flag indicating if the odd element should be tracked.
|
||||
@@ -66,33 +69,42 @@ impl PartialMmr {
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Constructs a [PartialMmr] from the given [MmrPeaks].
|
||||
pub fn from_peaks(accumulator: MmrPeaks) -> Self {
|
||||
let forest = accumulator.num_leaves();
|
||||
let peaks = accumulator.peaks().to_vec();
|
||||
pub fn from_peaks(peaks: MmrPeaks) -> Self {
|
||||
let forest = peaks.num_leaves();
|
||||
let peaks = peaks.peaks().to_vec();
|
||||
let nodes = BTreeMap::new();
|
||||
let track_latest = false;
|
||||
|
||||
Self { forest, peaks, nodes, track_latest }
|
||||
}
|
||||
|
||||
// ACCESSORS
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
// Gets the current `forest`.
|
||||
//
|
||||
// This value corresponds to the version of the [PartialMmr] and the number of leaves in it.
|
||||
/// Returns the current `forest` of this [PartialMmr].
|
||||
///
|
||||
/// This value corresponds to the version of the [PartialMmr] and the number of leaves in the
|
||||
/// underlying MMR.
|
||||
pub fn forest(&self) -> usize {
|
||||
self.forest
|
||||
}
|
||||
|
||||
// Returns a reference to the current peaks in the [PartialMmr]
|
||||
pub fn peaks(&self) -> &[RpoDigest] {
|
||||
&self.peaks
|
||||
/// Returns the number of leaves in the underlying MMR for this [PartialMmr].
|
||||
pub fn num_leaves(&self) -> usize {
|
||||
self.forest
|
||||
}
|
||||
|
||||
/// Given a leaf position, returns the Merkle path to its corresponding peak. If the position
|
||||
/// is greater-or-equal than the tree size an error is returned. If the requested value is not
|
||||
/// tracked returns `None`.
|
||||
/// Returns the peaks of the MMR for this [PartialMmr].
|
||||
pub fn peaks(&self) -> MmrPeaks {
|
||||
// expect() is OK here because the constructor ensures that MMR peaks can be constructed
|
||||
// correctly
|
||||
MmrPeaks::new(self.forest, self.peaks.clone()).expect("invalid MMR peaks")
|
||||
}
|
||||
|
||||
/// Given a leaf position, returns the Merkle path to its corresponding peak.
|
||||
///
|
||||
/// If the position is greater-or-equal than the tree size an error is returned. If the
|
||||
/// requested value is not tracked returns `None`.
|
||||
///
|
||||
/// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were
|
||||
/// added, this corresponds to the MMR size _prior_ to adding the element. So the 1st element
|
||||
@@ -125,14 +137,45 @@ impl PartialMmr {
|
||||
}
|
||||
}
|
||||
|
||||
// MODIFIERS
|
||||
// ITERATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns an iterator nodes of all authentication paths of this [PartialMmr].
|
||||
pub fn nodes(&self) -> impl Iterator<Item = (&InOrderIndex, &RpoDigest)> {
|
||||
self.nodes.iter()
|
||||
}
|
||||
|
||||
/// Returns an iterator over inner nodes of this [PartialMmr] for the specified leaves.
|
||||
///
|
||||
/// The order of iteration is not defined. If a leaf is not presented in this partial MMR it
|
||||
/// is silently ignored.
|
||||
pub fn inner_nodes<'a, I: Iterator<Item = &'a (usize, RpoDigest)> + 'a>(
|
||||
&'a self,
|
||||
mut leaves: I,
|
||||
) -> impl Iterator<Item = InnerNodeInfo> + '_ {
|
||||
let stack = if let Some((pos, leaf)) = leaves.next() {
|
||||
let idx = InOrderIndex::from_leaf_pos(*pos);
|
||||
vec![(idx, *leaf)]
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
InnerNodeIterator {
|
||||
nodes: &self.nodes,
|
||||
leaves,
|
||||
stack,
|
||||
seen_nodes: BTreeSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
// STATE MUTATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Add the authentication path represented by [MerklePath] if it is valid.
|
||||
///
|
||||
/// The `index` refers to the global position of the leaf in the [Mmr], these are 0-indexed
|
||||
/// values assigned in a strictly monotonic fashion as elements are inserted into the [Mmr],
|
||||
/// this value corresponds to the values used in the [Mmr] structure.
|
||||
/// The `index` refers to the global position of the leaf in the MMR, these are 0-indexed
|
||||
/// values assigned in a strictly monotonic fashion as elements are inserted into the MMR,
|
||||
/// this value corresponds to the values used in the MMR structure.
|
||||
///
|
||||
/// The `node` corresponds to the value at `index`, and `path` is the authentication path for
|
||||
/// that element up to its corresponding Mmr peak. The `node` is only used to compute the root
|
||||
@@ -185,7 +228,7 @@ impl PartialMmr {
|
||||
|
||||
/// Remove a leaf of the [PartialMmr] and the unused nodes from the authentication path.
|
||||
///
|
||||
/// Note: `leaf_pos` corresponds to the position the [Mmr] and not on an individual tree.
|
||||
/// Note: `leaf_pos` corresponds to the position in the MMR and not on an individual tree.
|
||||
pub fn remove(&mut self, leaf_pos: usize) {
|
||||
let mut idx = InOrderIndex::from_leaf_pos(leaf_pos);
|
||||
|
||||
@@ -202,18 +245,21 @@ impl PartialMmr {
|
||||
}
|
||||
}
|
||||
|
||||
/// Applies updates to the [PartialMmr].
|
||||
pub fn apply(&mut self, delta: MmrDelta) -> Result<(), MmrError> {
|
||||
/// Applies updates to this [PartialMmr] and returns a vector of new authentication nodes
|
||||
/// inserted into the partial MMR.
|
||||
pub fn apply(&mut self, delta: MmrDelta) -> Result<Vec<(InOrderIndex, RpoDigest)>, MmrError> {
|
||||
if delta.forest < self.forest {
|
||||
return Err(MmrError::InvalidPeaks);
|
||||
}
|
||||
|
||||
let mut inserted_nodes = Vec::new();
|
||||
|
||||
if delta.forest == self.forest {
|
||||
if !delta.data.is_empty() {
|
||||
return Err(MmrError::InvalidUpdate);
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
return Ok(inserted_nodes);
|
||||
}
|
||||
|
||||
// find the tree merges
|
||||
@@ -268,16 +314,21 @@ impl PartialMmr {
|
||||
// check if either the left or right subtrees have saved for authentication paths.
|
||||
// If so, turn tracking on to update those paths.
|
||||
if target != 1 && !track {
|
||||
let left_child = peak_idx.left_child();
|
||||
let right_child = peak_idx.right_child();
|
||||
track = self.nodes.contains_key(&left_child)
|
||||
| self.nodes.contains_key(&right_child);
|
||||
track = self.is_tracked_node(&peak_idx);
|
||||
}
|
||||
|
||||
// update data only contains the nodes from the right subtrees, left nodes are
|
||||
// either previously known peaks or computed values
|
||||
let (left, right) = if target & merges != 0 {
|
||||
let peak = self.peaks[peak_count];
|
||||
let sibling_idx = peak_idx.sibling();
|
||||
|
||||
// if the sibling peak is tracked, add this peaks to the set of
|
||||
// authentication nodes
|
||||
if self.is_tracked_node(&sibling_idx) {
|
||||
self.nodes.insert(peak_idx, new);
|
||||
inserted_nodes.push((peak_idx, new));
|
||||
}
|
||||
peak_count += 1;
|
||||
(peak, new)
|
||||
} else {
|
||||
@@ -287,7 +338,14 @@ impl PartialMmr {
|
||||
};
|
||||
|
||||
if track {
|
||||
self.nodes.insert(peak_idx.sibling(), right);
|
||||
let sibling_idx = peak_idx.sibling();
|
||||
if peak_idx.is_left_child() {
|
||||
self.nodes.insert(sibling_idx, right);
|
||||
inserted_nodes.push((sibling_idx, right));
|
||||
} else {
|
||||
self.nodes.insert(sibling_idx, left);
|
||||
inserted_nodes.push((sibling_idx, left));
|
||||
}
|
||||
}
|
||||
|
||||
peak_idx = peak_idx.parent();
|
||||
@@ -313,7 +371,22 @@ impl PartialMmr {
|
||||
|
||||
debug_assert!(self.peaks.len() == (self.forest.count_ones() as usize));
|
||||
|
||||
Ok(())
|
||||
Ok(inserted_nodes)
|
||||
}
|
||||
|
||||
// HELPER METHODS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns true if this [PartialMmr] tracks authentication path for the node at the specified
|
||||
/// index.
|
||||
fn is_tracked_node(&self, node_index: &InOrderIndex) -> bool {
|
||||
if node_index.is_leaf() {
|
||||
self.nodes.contains_key(&node_index.sibling())
|
||||
} else {
|
||||
let left_child = node_index.left_child();
|
||||
let right_child = node_index.right_child();
|
||||
self.nodes.contains_key(&left_child) | self.nodes.contains_key(&right_child)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -348,12 +421,59 @@ impl From<&PartialMmr> for MmrPeaks {
|
||||
}
|
||||
}
|
||||
|
||||
// ITERATORS
|
||||
// ================================================================================================
|
||||
|
||||
/// An iterator over every inner node of the [PartialMmr].
|
||||
pub struct InnerNodeIterator<'a, I: Iterator<Item = &'a (usize, RpoDigest)>> {
|
||||
nodes: &'a BTreeMap<InOrderIndex, RpoDigest>,
|
||||
leaves: I,
|
||||
stack: Vec<(InOrderIndex, RpoDigest)>,
|
||||
seen_nodes: BTreeSet<InOrderIndex>,
|
||||
}
|
||||
|
||||
impl<'a, I: Iterator<Item = &'a (usize, RpoDigest)>> Iterator for InnerNodeIterator<'a, I> {
|
||||
type Item = InnerNodeInfo;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
while let Some((idx, node)) = self.stack.pop() {
|
||||
let parent_idx = idx.parent();
|
||||
let new_node = self.seen_nodes.insert(parent_idx);
|
||||
|
||||
// if we haven't seen this node's parent before, and the node has a sibling, return
|
||||
// the inner node defined by the parent of this node, and move up the branch
|
||||
if new_node {
|
||||
if let Some(sibling) = self.nodes.get(&idx.sibling()) {
|
||||
let (left, right) = if parent_idx.left_child() == idx {
|
||||
(node, *sibling)
|
||||
} else {
|
||||
(*sibling, node)
|
||||
};
|
||||
let parent = Rpo256::merge(&[left, right]);
|
||||
let inner_node = InnerNodeInfo { value: parent, left, right };
|
||||
|
||||
self.stack.push((parent_idx, parent));
|
||||
return Some(inner_node);
|
||||
}
|
||||
}
|
||||
|
||||
// the previous leaf has been processed, try to process the next leaf
|
||||
if let Some((pos, leaf)) = self.leaves.next() {
|
||||
let idx = InOrderIndex::from_leaf_pos(*pos);
|
||||
self.stack.push((idx, *leaf));
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
// UTILS
|
||||
// ================================================================================================
|
||||
|
||||
/// Given the description of a `forest`, returns the index of the root element of the smallest tree
|
||||
/// in it.
|
||||
pub fn forest_to_root_index(forest: usize) -> InOrderIndex {
|
||||
fn forest_to_root_index(forest: usize) -> InOrderIndex {
|
||||
// Count total size of all trees in the forest.
|
||||
let nodes = nodes_in_forest(forest);
|
||||
|
||||
@@ -370,10 +490,23 @@ pub fn forest_to_root_index(forest: usize) -> InOrderIndex {
|
||||
InOrderIndex::new(idx.try_into().unwrap())
|
||||
}
|
||||
|
||||
// TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::forest_to_root_index;
|
||||
use crate::merkle::InOrderIndex;
|
||||
mod tests {
|
||||
use super::{forest_to_root_index, BTreeSet, InOrderIndex, PartialMmr, RpoDigest, Vec};
|
||||
use crate::merkle::{int_to_node, MerkleStore, Mmr, NodeIndex};
|
||||
|
||||
const LEAVES: [RpoDigest; 7] = [
|
||||
int_to_node(0),
|
||||
int_to_node(1),
|
||||
int_to_node(2),
|
||||
int_to_node(3),
|
||||
int_to_node(4),
|
||||
int_to_node(5),
|
||||
int_to_node(6),
|
||||
];
|
||||
|
||||
#[test]
|
||||
fn test_forest_to_root_index() {
|
||||
@@ -400,4 +533,171 @@ mod test {
|
||||
assert_eq!(forest_to_root_index(0b1100), idx(20));
|
||||
assert_eq!(forest_to_root_index(0b1110), idx(26));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partial_mmr_apply_delta() {
|
||||
// build an MMR with 10 nodes (2 peaks) and a partial MMR based on it
|
||||
let mut mmr = Mmr::default();
|
||||
(0..10).for_each(|i| mmr.add(int_to_node(i)));
|
||||
let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();
|
||||
|
||||
// add authentication path for position 1 and 8
|
||||
{
|
||||
let node = mmr.get(1).unwrap();
|
||||
let proof = mmr.open(1, mmr.forest()).unwrap();
|
||||
partial_mmr.add(1, node, &proof.merkle_path).unwrap();
|
||||
}
|
||||
|
||||
{
|
||||
let node = mmr.get(8).unwrap();
|
||||
let proof = mmr.open(8, mmr.forest()).unwrap();
|
||||
partial_mmr.add(8, node, &proof.merkle_path).unwrap();
|
||||
}
|
||||
|
||||
// add 2 more nodes into the MMR and validate apply_delta()
|
||||
(10..12).for_each(|i| mmr.add(int_to_node(i)));
|
||||
validate_apply_delta(&mmr, &mut partial_mmr);
|
||||
|
||||
// add 1 more node to the MMR, validate apply_delta() and start tracking the node
|
||||
mmr.add(int_to_node(12));
|
||||
validate_apply_delta(&mmr, &mut partial_mmr);
|
||||
{
|
||||
let node = mmr.get(12).unwrap();
|
||||
let proof = mmr.open(12, mmr.forest()).unwrap();
|
||||
partial_mmr.add(12, node, &proof.merkle_path).unwrap();
|
||||
assert!(partial_mmr.track_latest);
|
||||
}
|
||||
|
||||
// by this point we are tracking authentication paths for positions: 1, 8, and 12
|
||||
|
||||
// add 3 more nodes to the MMR (collapses to 1 peak) and validate apply_delta()
|
||||
(13..16).for_each(|i| mmr.add(int_to_node(i)));
|
||||
validate_apply_delta(&mmr, &mut partial_mmr);
|
||||
}
|
||||
|
||||
fn validate_apply_delta(mmr: &Mmr, partial: &mut PartialMmr) {
|
||||
let tracked_leaves = partial
|
||||
.nodes
|
||||
.iter()
|
||||
.filter_map(|(index, _)| if index.is_leaf() { Some(index.sibling()) } else { None })
|
||||
.collect::<Vec<_>>();
|
||||
let nodes_before = partial.nodes.clone();
|
||||
|
||||
// compute and apply delta
|
||||
let delta = mmr.get_delta(partial.forest(), mmr.forest()).unwrap();
|
||||
let nodes_delta = partial.apply(delta).unwrap();
|
||||
|
||||
// new peaks were computed correctly
|
||||
assert_eq!(mmr.peaks(mmr.forest()).unwrap(), partial.peaks());
|
||||
|
||||
let mut expected_nodes = nodes_before;
|
||||
for (key, value) in nodes_delta {
|
||||
// nodes should not be duplicated
|
||||
assert!(expected_nodes.insert(key, value).is_none());
|
||||
}
|
||||
|
||||
// new nodes should be a combination of original nodes and delta
|
||||
assert_eq!(expected_nodes, partial.nodes);
|
||||
|
||||
// make sure tracked leaves open to the same proofs as in the underlying MMR
|
||||
for index in tracked_leaves {
|
||||
let index_value: u64 = index.into();
|
||||
let pos = index_value / 2;
|
||||
let proof1 = partial.open(pos as usize).unwrap().unwrap();
|
||||
let proof2 = mmr.open(pos as usize, mmr.forest()).unwrap();
|
||||
assert_eq!(proof1, proof2);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partial_mmr_inner_nodes_iterator() {
|
||||
// build the MMR
|
||||
let mmr: Mmr = LEAVES.into();
|
||||
let first_peak = mmr.peaks(mmr.forest).unwrap().peaks()[0];
|
||||
|
||||
// -- test single tree ----------------------------
|
||||
|
||||
// get path and node for position 1
|
||||
let node1 = mmr.get(1).unwrap();
|
||||
let proof1 = mmr.open(1, mmr.forest()).unwrap();
|
||||
|
||||
// create partial MMR and add authentication path to node at position 1
|
||||
let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();
|
||||
partial_mmr.add(1, node1, &proof1.merkle_path).unwrap();
|
||||
|
||||
// empty iterator should have no nodes
|
||||
assert_eq!(partial_mmr.inner_nodes([].iter()).next(), None);
|
||||
|
||||
// build Merkle store from authentication paths in partial MMR
|
||||
let mut store: MerkleStore = MerkleStore::new();
|
||||
store.extend(partial_mmr.inner_nodes([(1, node1)].iter()));
|
||||
|
||||
let index1 = NodeIndex::new(2, 1).unwrap();
|
||||
let path1 = store.get_path(first_peak, index1).unwrap().path;
|
||||
|
||||
assert_eq!(path1, proof1.merkle_path);
|
||||
|
||||
// -- test no duplicates --------------------------
|
||||
|
||||
// build the partial MMR
|
||||
let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();
|
||||
|
||||
let node0 = mmr.get(0).unwrap();
|
||||
let proof0 = mmr.open(0, mmr.forest()).unwrap();
|
||||
|
||||
let node2 = mmr.get(2).unwrap();
|
||||
let proof2 = mmr.open(2, mmr.forest()).unwrap();
|
||||
|
||||
partial_mmr.add(0, node0, &proof0.merkle_path).unwrap();
|
||||
partial_mmr.add(1, node1, &proof1.merkle_path).unwrap();
|
||||
partial_mmr.add(2, node2, &proof2.merkle_path).unwrap();
|
||||
|
||||
// make sure there are no duplicates
|
||||
let leaves = [(0, node0), (1, node1), (2, node2)];
|
||||
let mut nodes = BTreeSet::new();
|
||||
for node in partial_mmr.inner_nodes(leaves.iter()) {
|
||||
assert!(nodes.insert(node.value));
|
||||
}
|
||||
|
||||
// and also that the store is still be built correctly
|
||||
store.extend(partial_mmr.inner_nodes(leaves.iter()));
|
||||
|
||||
let index0 = NodeIndex::new(2, 0).unwrap();
|
||||
let index1 = NodeIndex::new(2, 1).unwrap();
|
||||
let index2 = NodeIndex::new(2, 2).unwrap();
|
||||
|
||||
let path0 = store.get_path(first_peak, index0).unwrap().path;
|
||||
let path1 = store.get_path(first_peak, index1).unwrap().path;
|
||||
let path2 = store.get_path(first_peak, index2).unwrap().path;
|
||||
|
||||
assert_eq!(path0, proof0.merkle_path);
|
||||
assert_eq!(path1, proof1.merkle_path);
|
||||
assert_eq!(path2, proof2.merkle_path);
|
||||
|
||||
// -- test multiple trees -------------------------
|
||||
|
||||
// build the partial MMR
|
||||
let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();
|
||||
|
||||
let node5 = mmr.get(5).unwrap();
|
||||
let proof5 = mmr.open(5, mmr.forest()).unwrap();
|
||||
|
||||
partial_mmr.add(1, node1, &proof1.merkle_path).unwrap();
|
||||
partial_mmr.add(5, node5, &proof5.merkle_path).unwrap();
|
||||
|
||||
// build Merkle store from authentication paths in partial MMR
|
||||
let mut store: MerkleStore = MerkleStore::new();
|
||||
store.extend(partial_mmr.inner_nodes([(1, node1), (5, node5)].iter()));
|
||||
|
||||
let index1 = NodeIndex::new(2, 1).unwrap();
|
||||
let index5 = NodeIndex::new(1, 1).unwrap();
|
||||
|
||||
let second_peak = mmr.peaks(mmr.forest).unwrap().peaks()[1];
|
||||
|
||||
let path1 = store.get_path(first_peak, index1).unwrap().path;
|
||||
let path5 = store.get_path(second_peak, index5).unwrap().path;
|
||||
|
||||
assert_eq!(path1, proof1.merkle_path);
|
||||
assert_eq!(path5, proof5.merkle_path);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,17 +3,20 @@ use super::{
|
||||
Felt, MmrError, MmrProof, Rpo256, Word,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
// MMR PEAKS
|
||||
// ================================================================================================
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct MmrPeaks {
|
||||
/// The number of leaves is used to differentiate accumulators that have the same number of
|
||||
/// peaks. This happens because the number of peaks goes up-and-down as the structure is used
|
||||
/// causing existing trees to be merged and new ones to be created. As an example, every time
|
||||
/// the [Mmr] has a power-of-two number of leaves there is a single peak.
|
||||
/// The number of leaves is used to differentiate MMRs that have the same number of peaks. This
|
||||
/// happens because the number of peaks goes up-and-down as the structure is used causing
|
||||
/// existing trees to be merged and new ones to be created. As an example, every time the MMR
|
||||
/// has a power-of-two number of leaves there is a single peak.
|
||||
///
|
||||
/// Every tree in the [Mmr] forest has a distinct power-of-two size, this means only the right
|
||||
/// most tree can have an odd number of elements (e.g. `1`). Additionally this means that the bits in
|
||||
/// `num_leaves` conveniently encode the size of each individual tree.
|
||||
/// Every tree in the MMR forest has a distinct power-of-two size, this means only the right-
|
||||
/// most tree can have an odd number of elements (e.g. `1`). Additionally this means that the
|
||||
/// bits in `num_leaves` conveniently encode the size of each individual tree.
|
||||
///
|
||||
/// Examples:
|
||||
///
|
||||
@@ -25,7 +28,7 @@ pub struct MmrPeaks {
|
||||
/// leftmost tree has `2**3=8` elements, and the right most has `2**2=4` elements.
|
||||
num_leaves: usize,
|
||||
|
||||
/// All the peaks of every tree in the [Mmr] forest. The peaks are always ordered by number of
|
||||
/// All the peaks of every tree in the MMR forest. The peaks are always ordered by number of
|
||||
/// leaves, starting from the peak with most children, to the one with least.
|
||||
///
|
||||
/// Invariant: The length of `peaks` must be equal to the number of true bits in `num_leaves`.
|
||||
@@ -33,6 +36,14 @@ pub struct MmrPeaks {
|
||||
}
|
||||
|
||||
impl MmrPeaks {
|
||||
// CONSTRUCTOR
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns new [MmrPeaks] instantiated from the provided vector of peaks and the number of
|
||||
/// leaves in the underlying MMR.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the number of leaves and the number of peaks are inconsistent.
|
||||
pub fn new(num_leaves: usize, peaks: Vec<RpoDigest>) -> Result<Self, MmrError> {
|
||||
if num_leaves.count_ones() as usize != peaks.len() {
|
||||
return Err(MmrError::InvalidPeaks);
|
||||
@@ -44,23 +55,34 @@ impl MmrPeaks {
|
||||
// ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a count of the [Mmr]'s leaves.
|
||||
/// Returns a count of leaves in the underlying MMR.
|
||||
pub fn num_leaves(&self) -> usize {
|
||||
self.num_leaves
|
||||
}
|
||||
|
||||
/// Returns the current peaks of the [Mmr].
|
||||
/// Returns the number of peaks of the underlying MMR.
|
||||
pub fn num_peaks(&self) -> usize {
|
||||
self.peaks.len()
|
||||
}
|
||||
|
||||
/// Returns the list of peaks of the underlying MMR.
|
||||
pub fn peaks(&self) -> &[RpoDigest] {
|
||||
&self.peaks
|
||||
}
|
||||
|
||||
/// Converts this [MmrPeaks] into its components: number of leaves and a vector of peaks of
|
||||
/// the underlying MMR.
|
||||
pub fn into_parts(self) -> (usize, Vec<RpoDigest>) {
|
||||
(self.num_leaves, self.peaks)
|
||||
}
|
||||
|
||||
/// Hashes the peaks.
|
||||
///
|
||||
/// The procedure will:
|
||||
/// - Flatten and pad the peaks to a vector of Felts.
|
||||
/// - Hash the vector of Felts.
|
||||
pub fn hash_peaks(&self) -> Word {
|
||||
Rpo256::hash_elements(&self.flatten_and_pad_peaks()).into()
|
||||
pub fn hash_peaks(&self) -> RpoDigest {
|
||||
Rpo256::hash_elements(&self.flatten_and_pad_peaks())
|
||||
}
|
||||
|
||||
pub fn verify(&self, value: RpoDigest, opening: MmrProof) -> bool {
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
use super::{
|
||||
super::{InnerNodeInfo, Vec},
|
||||
super::{InnerNodeInfo, Rpo256, RpoDigest, Vec},
|
||||
bit::TrueBitPositionIterator,
|
||||
full::high_bitmask,
|
||||
leaf_to_corresponding_tree, nodes_in_forest, Mmr, MmrPeaks, PartialMmr, Rpo256,
|
||||
leaf_to_corresponding_tree, nodes_in_forest, Mmr, MmrPeaks, PartialMmr,
|
||||
};
|
||||
use crate::{
|
||||
hash::rpo::RpoDigest,
|
||||
merkle::{int_to_node, InOrderIndex, MerklePath, MerkleTree, MmrProof, NodeIndex},
|
||||
Felt, Word,
|
||||
};
|
||||
@@ -137,7 +136,7 @@ fn test_mmr_simple() {
|
||||
assert_eq!(mmr.nodes.len(), 1);
|
||||
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
|
||||
|
||||
let acc = mmr.accumulator();
|
||||
let acc = mmr.peaks(mmr.forest()).unwrap();
|
||||
assert_eq!(acc.num_leaves(), 1);
|
||||
assert_eq!(acc.peaks(), &[postorder[0]]);
|
||||
|
||||
@@ -146,7 +145,7 @@ fn test_mmr_simple() {
|
||||
assert_eq!(mmr.nodes.len(), 3);
|
||||
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
|
||||
|
||||
let acc = mmr.accumulator();
|
||||
let acc = mmr.peaks(mmr.forest()).unwrap();
|
||||
assert_eq!(acc.num_leaves(), 2);
|
||||
assert_eq!(acc.peaks(), &[postorder[2]]);
|
||||
|
||||
@@ -155,7 +154,7 @@ fn test_mmr_simple() {
|
||||
assert_eq!(mmr.nodes.len(), 4);
|
||||
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
|
||||
|
||||
let acc = mmr.accumulator();
|
||||
let acc = mmr.peaks(mmr.forest()).unwrap();
|
||||
assert_eq!(acc.num_leaves(), 3);
|
||||
assert_eq!(acc.peaks(), &[postorder[2], postorder[3]]);
|
||||
|
||||
@@ -164,7 +163,7 @@ fn test_mmr_simple() {
|
||||
assert_eq!(mmr.nodes.len(), 7);
|
||||
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
|
||||
|
||||
let acc = mmr.accumulator();
|
||||
let acc = mmr.peaks(mmr.forest()).unwrap();
|
||||
assert_eq!(acc.num_leaves(), 4);
|
||||
assert_eq!(acc.peaks(), &[postorder[6]]);
|
||||
|
||||
@@ -173,7 +172,7 @@ fn test_mmr_simple() {
|
||||
assert_eq!(mmr.nodes.len(), 8);
|
||||
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
|
||||
|
||||
let acc = mmr.accumulator();
|
||||
let acc = mmr.peaks(mmr.forest()).unwrap();
|
||||
assert_eq!(acc.num_leaves(), 5);
|
||||
assert_eq!(acc.peaks(), &[postorder[6], postorder[7]]);
|
||||
|
||||
@@ -182,7 +181,7 @@ fn test_mmr_simple() {
|
||||
assert_eq!(mmr.nodes.len(), 10);
|
||||
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
|
||||
|
||||
let acc = mmr.accumulator();
|
||||
let acc = mmr.peaks(mmr.forest()).unwrap();
|
||||
assert_eq!(acc.num_leaves(), 6);
|
||||
assert_eq!(acc.peaks(), &[postorder[6], postorder[9]]);
|
||||
|
||||
@@ -191,7 +190,7 @@ fn test_mmr_simple() {
|
||||
assert_eq!(mmr.nodes.len(), 11);
|
||||
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
|
||||
|
||||
let acc = mmr.accumulator();
|
||||
let acc = mmr.peaks(mmr.forest()).unwrap();
|
||||
assert_eq!(acc.num_leaves(), 7);
|
||||
assert_eq!(acc.peaks(), &[postorder[6], postorder[9], postorder[10]]);
|
||||
}
|
||||
@@ -203,96 +202,139 @@ fn test_mmr_open() {
|
||||
let h23 = merge(LEAVES[2], LEAVES[3]);
|
||||
|
||||
// node at pos 7 is the root
|
||||
assert!(mmr.open(7).is_err(), "Element 7 is not in the tree, result should be None");
|
||||
assert!(
|
||||
mmr.open(7, mmr.forest()).is_err(),
|
||||
"Element 7 is not in the tree, result should be None"
|
||||
);
|
||||
|
||||
// node at pos 6 is the root
|
||||
let empty: MerklePath = MerklePath::new(vec![]);
|
||||
let opening = mmr
|
||||
.open(6)
|
||||
.open(6, mmr.forest())
|
||||
.expect("Element 6 is contained in the tree, expected an opening result.");
|
||||
assert_eq!(opening.merkle_path, empty);
|
||||
assert_eq!(opening.forest, mmr.forest);
|
||||
assert_eq!(opening.position, 6);
|
||||
assert!(
|
||||
mmr.accumulator().verify(LEAVES[6], opening),
|
||||
mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[6], opening),
|
||||
"MmrProof should be valid for the current accumulator."
|
||||
);
|
||||
|
||||
// nodes 4,5 are depth 1
|
||||
let root_to_path = MerklePath::new(vec![LEAVES[4]]);
|
||||
let opening = mmr
|
||||
.open(5)
|
||||
.open(5, mmr.forest())
|
||||
.expect("Element 5 is contained in the tree, expected an opening result.");
|
||||
assert_eq!(opening.merkle_path, root_to_path);
|
||||
assert_eq!(opening.forest, mmr.forest);
|
||||
assert_eq!(opening.position, 5);
|
||||
assert!(
|
||||
mmr.accumulator().verify(LEAVES[5], opening),
|
||||
mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[5], opening),
|
||||
"MmrProof should be valid for the current accumulator."
|
||||
);
|
||||
|
||||
let root_to_path = MerklePath::new(vec![LEAVES[5]]);
|
||||
let opening = mmr
|
||||
.open(4)
|
||||
.open(4, mmr.forest())
|
||||
.expect("Element 4 is contained in the tree, expected an opening result.");
|
||||
assert_eq!(opening.merkle_path, root_to_path);
|
||||
assert_eq!(opening.forest, mmr.forest);
|
||||
assert_eq!(opening.position, 4);
|
||||
assert!(
|
||||
mmr.accumulator().verify(LEAVES[4], opening),
|
||||
mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[4], opening),
|
||||
"MmrProof should be valid for the current accumulator."
|
||||
);
|
||||
|
||||
// nodes 0,1,2,3 are detph 2
|
||||
let root_to_path = MerklePath::new(vec![LEAVES[2], h01]);
|
||||
let opening = mmr
|
||||
.open(3)
|
||||
.open(3, mmr.forest())
|
||||
.expect("Element 3 is contained in the tree, expected an opening result.");
|
||||
assert_eq!(opening.merkle_path, root_to_path);
|
||||
assert_eq!(opening.forest, mmr.forest);
|
||||
assert_eq!(opening.position, 3);
|
||||
assert!(
|
||||
mmr.accumulator().verify(LEAVES[3], opening),
|
||||
mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[3], opening),
|
||||
"MmrProof should be valid for the current accumulator."
|
||||
);
|
||||
|
||||
let root_to_path = MerklePath::new(vec![LEAVES[3], h01]);
|
||||
let opening = mmr
|
||||
.open(2)
|
||||
.open(2, mmr.forest())
|
||||
.expect("Element 2 is contained in the tree, expected an opening result.");
|
||||
assert_eq!(opening.merkle_path, root_to_path);
|
||||
assert_eq!(opening.forest, mmr.forest);
|
||||
assert_eq!(opening.position, 2);
|
||||
assert!(
|
||||
mmr.accumulator().verify(LEAVES[2], opening),
|
||||
mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[2], opening),
|
||||
"MmrProof should be valid for the current accumulator."
|
||||
);
|
||||
|
||||
let root_to_path = MerklePath::new(vec![LEAVES[0], h23]);
|
||||
let opening = mmr
|
||||
.open(1)
|
||||
.open(1, mmr.forest())
|
||||
.expect("Element 1 is contained in the tree, expected an opening result.");
|
||||
assert_eq!(opening.merkle_path, root_to_path);
|
||||
assert_eq!(opening.forest, mmr.forest);
|
||||
assert_eq!(opening.position, 1);
|
||||
assert!(
|
||||
mmr.accumulator().verify(LEAVES[1], opening),
|
||||
mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[1], opening),
|
||||
"MmrProof should be valid for the current accumulator."
|
||||
);
|
||||
|
||||
let root_to_path = MerklePath::new(vec![LEAVES[1], h23]);
|
||||
let opening = mmr
|
||||
.open(0)
|
||||
.open(0, mmr.forest())
|
||||
.expect("Element 0 is contained in the tree, expected an opening result.");
|
||||
assert_eq!(opening.merkle_path, root_to_path);
|
||||
assert_eq!(opening.forest, mmr.forest);
|
||||
assert_eq!(opening.position, 0);
|
||||
assert!(
|
||||
mmr.accumulator().verify(LEAVES[0], opening),
|
||||
mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[0], opening),
|
||||
"MmrProof should be valid for the current accumulator."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mmr_open_older_version() {
|
||||
let mmr: Mmr = LEAVES.into();
|
||||
|
||||
fn is_even(v: &usize) -> bool {
|
||||
v & 1 == 0
|
||||
}
|
||||
|
||||
// merkle path of a node is empty if there are no elements to pair with it
|
||||
for pos in (0..mmr.forest()).filter(is_even) {
|
||||
let forest = pos + 1;
|
||||
let proof = mmr.open(pos, forest).unwrap();
|
||||
assert_eq!(proof.forest, forest);
|
||||
assert_eq!(proof.merkle_path.nodes(), []);
|
||||
assert_eq!(proof.position, pos);
|
||||
}
|
||||
|
||||
// openings match that of a merkle tree
|
||||
let mtree: MerkleTree = LEAVES[..4].try_into().unwrap();
|
||||
for forest in 4..=LEAVES.len() {
|
||||
for pos in 0..4 {
|
||||
let idx = NodeIndex::new(2, pos).unwrap();
|
||||
let path = mtree.get_path(idx).unwrap();
|
||||
let proof = mmr.open(pos as usize, forest).unwrap();
|
||||
assert_eq!(path, proof.merkle_path);
|
||||
}
|
||||
}
|
||||
let mtree: MerkleTree = LEAVES[4..6].try_into().unwrap();
|
||||
for forest in 6..=LEAVES.len() {
|
||||
for pos in 0..2 {
|
||||
let idx = NodeIndex::new(1, pos).unwrap();
|
||||
let path = mtree.get_path(idx).unwrap();
|
||||
// account for the bigger tree with 4 elements
|
||||
let mmr_pos = (pos + 4) as usize;
|
||||
let proof = mmr.open(mmr_pos, forest).unwrap();
|
||||
assert_eq!(path, proof.merkle_path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tests the openings of a simple Mmr with a single tree of depth 8.
|
||||
#[test]
|
||||
fn test_mmr_open_eight() {
|
||||
@@ -313,49 +355,49 @@ fn test_mmr_open_eight() {
|
||||
let root = mtree.root();
|
||||
|
||||
let position = 0;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let proof = mmr.open(position, mmr.forest()).unwrap();
|
||||
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
|
||||
|
||||
let position = 1;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let proof = mmr.open(position, mmr.forest()).unwrap();
|
||||
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
|
||||
|
||||
let position = 2;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let proof = mmr.open(position, mmr.forest()).unwrap();
|
||||
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
|
||||
|
||||
let position = 3;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let proof = mmr.open(position, mmr.forest()).unwrap();
|
||||
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
|
||||
|
||||
let position = 4;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let proof = mmr.open(position, mmr.forest()).unwrap();
|
||||
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
|
||||
|
||||
let position = 5;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let proof = mmr.open(position, mmr.forest()).unwrap();
|
||||
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
|
||||
|
||||
let position = 6;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let proof = mmr.open(position, mmr.forest()).unwrap();
|
||||
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
|
||||
|
||||
let position = 7;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let proof = mmr.open(position, mmr.forest()).unwrap();
|
||||
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
|
||||
@@ -371,47 +413,47 @@ fn test_mmr_open_seven() {
|
||||
let mmr: Mmr = LEAVES.into();
|
||||
|
||||
let position = 0;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let proof = mmr.open(position, mmr.forest()).unwrap();
|
||||
let merkle_path: MerklePath =
|
||||
mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(0, LEAVES[0]).unwrap(), mtree1.root());
|
||||
|
||||
let position = 1;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let proof = mmr.open(position, mmr.forest()).unwrap();
|
||||
let merkle_path: MerklePath =
|
||||
mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(1, LEAVES[1]).unwrap(), mtree1.root());
|
||||
|
||||
let position = 2;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let proof = mmr.open(position, mmr.forest()).unwrap();
|
||||
let merkle_path: MerklePath =
|
||||
mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(2, LEAVES[2]).unwrap(), mtree1.root());
|
||||
|
||||
let position = 3;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let proof = mmr.open(position, mmr.forest()).unwrap();
|
||||
let merkle_path: MerklePath =
|
||||
mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(3, LEAVES[3]).unwrap(), mtree1.root());
|
||||
|
||||
let position = 4;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let proof = mmr.open(position, mmr.forest()).unwrap();
|
||||
let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 0u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(0, LEAVES[4]).unwrap(), mtree2.root());
|
||||
|
||||
let position = 5;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let proof = mmr.open(position, mmr.forest()).unwrap();
|
||||
let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 1u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(1, LEAVES[5]).unwrap(), mtree2.root());
|
||||
|
||||
let position = 6;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let proof = mmr.open(position, mmr.forest()).unwrap();
|
||||
let merkle_path: MerklePath = [].as_ref().into();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(0, LEAVES[6]).unwrap(), LEAVES[6]);
|
||||
@@ -435,7 +477,7 @@ fn test_mmr_invariants() {
|
||||
let mut mmr = Mmr::new();
|
||||
for v in 1..=1028 {
|
||||
mmr.add(int_to_node(v));
|
||||
let accumulator = mmr.accumulator();
|
||||
let accumulator = mmr.peaks(mmr.forest()).unwrap();
|
||||
assert_eq!(v as usize, mmr.forest(), "MMR leaf count must increase by one on every add");
|
||||
assert_eq!(
|
||||
v as usize,
|
||||
@@ -516,10 +558,50 @@ fn test_mmr_inner_nodes() {
|
||||
assert_eq!(postorder, nodes);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mmr_peaks() {
|
||||
let mmr: Mmr = LEAVES.into();
|
||||
|
||||
let forest = 0b0001;
|
||||
let acc = mmr.peaks(forest).unwrap();
|
||||
assert_eq!(acc.num_leaves(), forest);
|
||||
assert_eq!(acc.peaks(), &[mmr.nodes[0]]);
|
||||
|
||||
let forest = 0b0010;
|
||||
let acc = mmr.peaks(forest).unwrap();
|
||||
assert_eq!(acc.num_leaves(), forest);
|
||||
assert_eq!(acc.peaks(), &[mmr.nodes[2]]);
|
||||
|
||||
let forest = 0b0011;
|
||||
let acc = mmr.peaks(forest).unwrap();
|
||||
assert_eq!(acc.num_leaves(), forest);
|
||||
assert_eq!(acc.peaks(), &[mmr.nodes[2], mmr.nodes[3]]);
|
||||
|
||||
let forest = 0b0100;
|
||||
let acc = mmr.peaks(forest).unwrap();
|
||||
assert_eq!(acc.num_leaves(), forest);
|
||||
assert_eq!(acc.peaks(), &[mmr.nodes[6]]);
|
||||
|
||||
let forest = 0b0101;
|
||||
let acc = mmr.peaks(forest).unwrap();
|
||||
assert_eq!(acc.num_leaves(), forest);
|
||||
assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[7]]);
|
||||
|
||||
let forest = 0b0110;
|
||||
let acc = mmr.peaks(forest).unwrap();
|
||||
assert_eq!(acc.num_leaves(), forest);
|
||||
assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9]]);
|
||||
|
||||
let forest = 0b0111;
|
||||
let acc = mmr.peaks(forest).unwrap();
|
||||
assert_eq!(acc.num_leaves(), forest);
|
||||
assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9], mmr.nodes[10]]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mmr_hash_peaks() {
|
||||
let mmr: Mmr = LEAVES.into();
|
||||
let peaks = mmr.accumulator();
|
||||
let peaks = mmr.peaks(mmr.forest()).unwrap();
|
||||
|
||||
let first_peak = Rpo256::merge(&[
|
||||
Rpo256::merge(&[LEAVES[0], LEAVES[1]]),
|
||||
@@ -531,10 +613,7 @@ fn test_mmr_hash_peaks() {
|
||||
// minimum length is 16
|
||||
let mut expected_peaks = [first_peak, second_peak, third_peak].to_vec();
|
||||
expected_peaks.resize(16, RpoDigest::default());
|
||||
assert_eq!(
|
||||
peaks.hash_peaks(),
|
||||
*Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
|
||||
);
|
||||
assert_eq!(peaks.hash_peaks(), Rpo256::hash_elements(&digests_to_elements(&expected_peaks)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -552,7 +631,7 @@ fn test_mmr_peaks_hash_less_than_16() {
|
||||
expected_peaks.resize(16, RpoDigest::default());
|
||||
assert_eq!(
|
||||
accumulator.hash_peaks(),
|
||||
*Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
|
||||
Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -569,47 +648,47 @@ fn test_mmr_peaks_hash_odd() {
|
||||
expected_peaks.resize(18, RpoDigest::default());
|
||||
assert_eq!(
|
||||
accumulator.hash_peaks(),
|
||||
*Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
|
||||
Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mmr_updates() {
|
||||
fn test_mmr_delta() {
|
||||
let mmr: Mmr = LEAVES.into();
|
||||
let acc = mmr.accumulator();
|
||||
let acc = mmr.peaks(mmr.forest()).unwrap();
|
||||
|
||||
// original_forest can't have more elements
|
||||
assert!(
|
||||
mmr.get_delta(LEAVES.len() + 1).is_err(),
|
||||
mmr.get_delta(LEAVES.len() + 1, mmr.forest()).is_err(),
|
||||
"Can not provide updates for a newer Mmr"
|
||||
);
|
||||
|
||||
// if the number of elements is the same there is no change
|
||||
assert!(
|
||||
mmr.get_delta(LEAVES.len()).unwrap().data.is_empty(),
|
||||
mmr.get_delta(LEAVES.len(), mmr.forest()).unwrap().data.is_empty(),
|
||||
"There are no updates for the same Mmr version"
|
||||
);
|
||||
|
||||
// missing the last element added, which is itself a tree peak
|
||||
assert_eq!(mmr.get_delta(6).unwrap().data, vec![acc.peaks()[2]], "one peak");
|
||||
assert_eq!(mmr.get_delta(6, mmr.forest()).unwrap().data, vec![acc.peaks()[2]], "one peak");
|
||||
|
||||
// missing the sibling to complete the tree of depth 2, and the last element
|
||||
assert_eq!(
|
||||
mmr.get_delta(5).unwrap().data,
|
||||
mmr.get_delta(5, mmr.forest()).unwrap().data,
|
||||
vec![LEAVES[5], acc.peaks()[2]],
|
||||
"one sibling, one peak"
|
||||
);
|
||||
|
||||
// missing the whole last two trees, only send the peaks
|
||||
assert_eq!(
|
||||
mmr.get_delta(4).unwrap().data,
|
||||
mmr.get_delta(4, mmr.forest()).unwrap().data,
|
||||
vec![acc.peaks()[1], acc.peaks()[2]],
|
||||
"two peaks"
|
||||
);
|
||||
|
||||
// missing the sibling to complete the first tree, and the two last trees
|
||||
assert_eq!(
|
||||
mmr.get_delta(3).unwrap().data,
|
||||
mmr.get_delta(3, mmr.forest()).unwrap().data,
|
||||
vec![LEAVES[3], acc.peaks()[1], acc.peaks()[2]],
|
||||
"one sibling, two peaks"
|
||||
);
|
||||
@@ -617,35 +696,77 @@ fn test_mmr_updates() {
|
||||
// missing half of the first tree, only send the computed element (not the leaves), and the new
|
||||
// peaks
|
||||
assert_eq!(
|
||||
mmr.get_delta(2).unwrap().data,
|
||||
mmr.get_delta(2, mmr.forest()).unwrap().data,
|
||||
vec![mmr.nodes[5], acc.peaks()[1], acc.peaks()[2]],
|
||||
"one sibling, two peaks"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
mmr.get_delta(1).unwrap().data,
|
||||
mmr.get_delta(1, mmr.forest()).unwrap().data,
|
||||
vec![LEAVES[1], mmr.nodes[5], acc.peaks()[1], acc.peaks()[2]],
|
||||
"one sibling, two peaks"
|
||||
);
|
||||
|
||||
assert_eq!(&mmr.get_delta(0).unwrap().data, acc.peaks(), "all peaks");
|
||||
assert_eq!(&mmr.get_delta(0, mmr.forest()).unwrap().data, acc.peaks(), "all peaks");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mmr_delta_old_forest() {
|
||||
let mmr: Mmr = LEAVES.into();
|
||||
|
||||
// from_forest must be smaller-or-equal to to_forest
|
||||
for version in 1..=mmr.forest() {
|
||||
assert!(mmr.get_delta(version + 1, version).is_err());
|
||||
}
|
||||
|
||||
// when from_forest and to_forest are equal, there are no updates
|
||||
for version in 1..=mmr.forest() {
|
||||
let delta = mmr.get_delta(version, version).unwrap();
|
||||
assert!(delta.data.is_empty());
|
||||
assert_eq!(delta.forest, version);
|
||||
}
|
||||
|
||||
// test update which merges the odd peak to the right
|
||||
for count in 0..(mmr.forest() / 2) {
|
||||
// *2 because every iteration tests a pair
|
||||
// +1 because the Mmr is 1-indexed
|
||||
let from_forest = (count * 2) + 1;
|
||||
let to_forest = (count * 2) + 2;
|
||||
let delta = mmr.get_delta(from_forest, to_forest).unwrap();
|
||||
|
||||
// *2 because every iteration tests a pair
|
||||
// +1 because sibling is the odd element
|
||||
let sibling = (count * 2) + 1;
|
||||
assert_eq!(delta.data, [LEAVES[sibling]]);
|
||||
assert_eq!(delta.forest, to_forest);
|
||||
}
|
||||
|
||||
let version = 4;
|
||||
let delta = mmr.get_delta(1, version).unwrap();
|
||||
assert_eq!(delta.data, [mmr.nodes[1], mmr.nodes[5]]);
|
||||
assert_eq!(delta.forest, version);
|
||||
|
||||
let version = 5;
|
||||
let delta = mmr.get_delta(1, version).unwrap();
|
||||
assert_eq!(delta.data, [mmr.nodes[1], mmr.nodes[5], mmr.nodes[7]]);
|
||||
assert_eq!(delta.forest, version);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partial_mmr_simple() {
|
||||
let mmr: Mmr = LEAVES.into();
|
||||
let acc = mmr.accumulator();
|
||||
let mut partial: PartialMmr = acc.clone().into();
|
||||
let peaks = mmr.peaks(mmr.forest()).unwrap();
|
||||
let mut partial: PartialMmr = peaks.clone().into();
|
||||
|
||||
// check initial state of the partial mmr
|
||||
assert_eq!(partial.peaks(), acc.peaks());
|
||||
assert_eq!(partial.forest(), acc.num_leaves());
|
||||
assert_eq!(partial.peaks(), peaks);
|
||||
assert_eq!(partial.forest(), peaks.num_leaves());
|
||||
assert_eq!(partial.forest(), LEAVES.len());
|
||||
assert_eq!(partial.peaks().len(), 3);
|
||||
assert_eq!(partial.peaks().num_peaks(), 3);
|
||||
assert_eq!(partial.nodes.len(), 0);
|
||||
|
||||
// check state after adding tracking one element
|
||||
let proof1 = mmr.open(0).unwrap();
|
||||
let proof1 = mmr.open(0, mmr.forest()).unwrap();
|
||||
let el1 = mmr.get(proof1.position).unwrap();
|
||||
partial.add(proof1.position, el1, &proof1.merkle_path).unwrap();
|
||||
|
||||
@@ -657,7 +778,7 @@ fn test_partial_mmr_simple() {
|
||||
let idx = idx.parent();
|
||||
assert_eq!(partial.nodes[&idx.sibling()], proof1.merkle_path[1]);
|
||||
|
||||
let proof2 = mmr.open(1).unwrap();
|
||||
let proof2 = mmr.open(1, mmr.forest()).unwrap();
|
||||
let el2 = mmr.get(proof2.position).unwrap();
|
||||
partial.add(proof2.position, el2, &proof2.merkle_path).unwrap();
|
||||
|
||||
@@ -675,21 +796,21 @@ fn test_partial_mmr_update_single() {
|
||||
let mut full = Mmr::new();
|
||||
let zero = int_to_node(0);
|
||||
full.add(zero);
|
||||
let mut partial: PartialMmr = full.accumulator().into();
|
||||
let mut partial: PartialMmr = full.peaks(full.forest()).unwrap().into();
|
||||
|
||||
let proof = full.open(0).unwrap();
|
||||
let proof = full.open(0, full.forest()).unwrap();
|
||||
partial.add(proof.position, zero, &proof.merkle_path).unwrap();
|
||||
|
||||
for i in 1..100 {
|
||||
let node = int_to_node(i);
|
||||
full.add(node);
|
||||
let delta = full.get_delta(partial.forest()).unwrap();
|
||||
let delta = full.get_delta(partial.forest(), full.forest()).unwrap();
|
||||
partial.apply(delta).unwrap();
|
||||
|
||||
assert_eq!(partial.forest(), full.forest());
|
||||
assert_eq!(partial.peaks(), full.accumulator().peaks());
|
||||
assert_eq!(partial.peaks(), full.peaks(full.forest()).unwrap());
|
||||
|
||||
let proof1 = full.open(i as usize).unwrap();
|
||||
let proof1 = full.open(i as usize, full.forest()).unwrap();
|
||||
partial.add(proof1.position, node, &proof1.merkle_path).unwrap();
|
||||
let proof2 = partial.open(proof1.position).unwrap().unwrap();
|
||||
assert_eq!(proof1.merkle_path, proof2.merkle_path);
|
||||
@@ -699,7 +820,7 @@ fn test_partial_mmr_update_single() {
|
||||
#[test]
|
||||
fn test_mmr_add_invalid_odd_leaf() {
|
||||
let mmr: Mmr = LEAVES.into();
|
||||
let acc = mmr.accumulator();
|
||||
let acc = mmr.peaks(mmr.forest()).unwrap();
|
||||
let mut partial: PartialMmr = acc.clone().into();
|
||||
|
||||
let empty = MerklePath::new(Vec::new());
|
||||
|
||||
@@ -31,7 +31,7 @@ mod tiered_smt;
|
||||
pub use tiered_smt::{TieredSmt, TieredSmtProof, TieredSmtProofError};
|
||||
|
||||
mod mmr;
|
||||
pub use mmr::{InOrderIndex, Mmr, MmrError, MmrPeaks, MmrProof, PartialMmr};
|
||||
pub use mmr::{InOrderIndex, Mmr, MmrDelta, MmrError, MmrPeaks, MmrProof, PartialMmr};
|
||||
|
||||
mod store;
|
||||
pub use store::{DefaultMerkleStore, MerkleStore, RecordingMerkleStore, StoreNode};
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::hash::rpo::RpoDigest;
|
||||
use super::RpoDigest;
|
||||
|
||||
/// Representation of a node with two children used for iterating over containers.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
|
||||
@@ -109,9 +109,9 @@ impl PartialMerkleTree {
|
||||
|
||||
// check if the number of leaves can be accommodated by the tree's depth; we use a min
|
||||
// depth of 63 because we consider passing in a vector of size 2^64 infeasible.
|
||||
let max = (1_u64 << 63) as usize;
|
||||
let max = 2usize.pow(63);
|
||||
if layers.len() > max {
|
||||
return Err(MerkleError::InvalidNumEntries(max, layers.len()));
|
||||
return Err(MerkleError::InvalidNumEntries(max));
|
||||
}
|
||||
|
||||
// Get maximum depth
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use super::{vec, InnerNodeInfo, MerkleError, NodeIndex, Rpo256, RpoDigest, Vec};
|
||||
use core::ops::{Deref, DerefMut};
|
||||
use winter_utils::{ByteReader, Deserializable, DeserializationError, Serializable};
|
||||
|
||||
// MERKLE PATH
|
||||
// ================================================================================================
|
||||
@@ -17,6 +18,7 @@ impl MerklePath {
|
||||
|
||||
/// Creates a new Merkle path from a list of nodes.
|
||||
pub fn new(nodes: Vec<RpoDigest>) -> Self {
|
||||
assert!(nodes.len() <= u8::MAX.into(), "MerklePath may have at most 256 items");
|
||||
Self { nodes }
|
||||
}
|
||||
|
||||
@@ -189,6 +191,55 @@ pub struct RootPath {
|
||||
pub path: MerklePath,
|
||||
}
|
||||
|
||||
// SERIALIZATION
|
||||
// ================================================================================================
|
||||
|
||||
impl Serializable for MerklePath {
|
||||
fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
|
||||
assert!(self.nodes.len() <= u8::MAX.into(), "Length enforced in the constructor");
|
||||
target.write_u8(self.nodes.len() as u8);
|
||||
self.nodes.write_into(target);
|
||||
}
|
||||
}
|
||||
|
||||
impl Deserializable for MerklePath {
|
||||
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
|
||||
let count = source.read_u8()?.into();
|
||||
let nodes = RpoDigest::read_batch_from(source, count)?;
|
||||
Ok(Self { nodes })
|
||||
}
|
||||
}
|
||||
|
||||
impl Serializable for ValuePath {
|
||||
fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
|
||||
self.value.write_into(target);
|
||||
self.path.write_into(target);
|
||||
}
|
||||
}
|
||||
|
||||
impl Deserializable for ValuePath {
|
||||
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
|
||||
let value = RpoDigest::read_from(source)?;
|
||||
let path = MerklePath::read_from(source)?;
|
||||
Ok(Self { value, path })
|
||||
}
|
||||
}
|
||||
|
||||
impl Serializable for RootPath {
|
||||
fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
|
||||
self.root.write_into(target);
|
||||
self.path.write_into(target);
|
||||
}
|
||||
}
|
||||
|
||||
impl Deserializable for RootPath {
|
||||
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
|
||||
let root = RpoDigest::read_from(source)?;
|
||||
let path = MerklePath::read_from(source)?;
|
||||
Ok(Self { root, path })
|
||||
}
|
||||
}
|
||||
|
||||
// TESTS
|
||||
// ================================================================================================
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ pub struct SimpleSmt {
|
||||
root: RpoDigest,
|
||||
leaves: BTreeMap<u64, Word>,
|
||||
branches: BTreeMap<NodeIndex, BranchNode>,
|
||||
empty_hashes: Vec<RpoDigest>,
|
||||
}
|
||||
|
||||
impl SimpleSmt {
|
||||
@@ -52,13 +51,11 @@ impl SimpleSmt {
|
||||
return Err(MerkleError::DepthTooBig(depth as u64));
|
||||
}
|
||||
|
||||
let empty_hashes = EmptySubtreeRoots::empty_hashes(depth).to_vec();
|
||||
let root = empty_hashes[0];
|
||||
let root = *EmptySubtreeRoots::entry(depth, 0);
|
||||
|
||||
Ok(Self {
|
||||
root,
|
||||
depth,
|
||||
empty_hashes,
|
||||
leaves: BTreeMap::new(),
|
||||
branches: BTreeMap::new(),
|
||||
})
|
||||
@@ -74,39 +71,54 @@ impl SimpleSmt {
|
||||
/// - If the depth is 0 or is greater than 64.
|
||||
/// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
|
||||
/// - The provided entries contain multiple values for the same key.
|
||||
pub fn with_leaves<R, I>(depth: u8, entries: R) -> Result<Self, MerkleError>
|
||||
where
|
||||
R: IntoIterator<IntoIter = I>,
|
||||
I: Iterator<Item = (u64, Word)> + ExactSizeIterator,
|
||||
{
|
||||
pub fn with_leaves(
|
||||
depth: u8,
|
||||
entries: impl IntoIterator<Item = (u64, Word)>,
|
||||
) -> Result<Self, MerkleError> {
|
||||
// create an empty tree
|
||||
let mut tree = Self::new(depth)?;
|
||||
|
||||
// check if the number of leaves can be accommodated by the tree's depth; we use a min
|
||||
// depth of 63 because we consider passing in a vector of size 2^64 infeasible.
|
||||
let entries = entries.into_iter();
|
||||
let max = 1 << tree.depth.min(63);
|
||||
if entries.len() > max {
|
||||
return Err(MerkleError::InvalidNumEntries(max, entries.len()));
|
||||
}
|
||||
// compute the max number of entries. We use an upper bound of depth 63 because we consider
|
||||
// passing in a vector of size 2^64 infeasible.
|
||||
let max_num_entries = 2_usize.pow(tree.depth.min(63).into());
|
||||
|
||||
// This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so
|
||||
// entries with the empty value need additional tracking.
|
||||
let mut key_set_to_zero = BTreeSet::new();
|
||||
|
||||
for (idx, (key, value)) in entries.into_iter().enumerate() {
|
||||
if idx >= max_num_entries {
|
||||
return Err(MerkleError::InvalidNumEntries(max_num_entries));
|
||||
}
|
||||
|
||||
// append leaves to the tree returning an error if a duplicate entry for the same key
|
||||
// is found
|
||||
let mut empty_entries = BTreeSet::new();
|
||||
for (key, value) in entries {
|
||||
let old_value = tree.update_leaf(key, value)?;
|
||||
if old_value != Self::EMPTY_VALUE || empty_entries.contains(&key) {
|
||||
return Err(MerkleError::DuplicateValuesForIndex(key));
|
||||
}
|
||||
// if we've processed an empty entry, add the key to the set of empty entry keys, and
|
||||
// if this key was already in the set, return an error
|
||||
if value == Self::EMPTY_VALUE && !empty_entries.insert(key) {
|
||||
|
||||
if old_value != Self::EMPTY_VALUE || key_set_to_zero.contains(&key) {
|
||||
return Err(MerkleError::DuplicateValuesForIndex(key));
|
||||
}
|
||||
|
||||
if value == Self::EMPTY_VALUE {
|
||||
key_set_to_zero.insert(key);
|
||||
};
|
||||
}
|
||||
Ok(tree)
|
||||
}
|
||||
|
||||
/// Wrapper around [`SimpleSmt::with_leaves`] which inserts leaves at contiguous indices
|
||||
/// starting at index 0.
|
||||
pub fn with_contiguous_leaves(
|
||||
depth: u8,
|
||||
entries: impl IntoIterator<Item = Word>,
|
||||
) -> Result<Self, MerkleError> {
|
||||
Self::with_leaves(
|
||||
depth,
|
||||
entries
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(idx, word)| (idx.try_into().expect("tree max depth is 2^8"), word)),
|
||||
)
|
||||
}
|
||||
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
@@ -133,10 +145,12 @@ impl SimpleSmt {
|
||||
} else if index.depth() == self.depth() {
|
||||
// the lookup in empty_hashes could fail only if empty_hashes were not built correctly
|
||||
// by the constructor as we check the depth of the lookup above.
|
||||
Ok(RpoDigest::from(
|
||||
self.get_leaf_node(index.value())
|
||||
.unwrap_or_else(|| *self.empty_hashes[index.depth() as usize]),
|
||||
))
|
||||
let leaf_pos = index.value();
|
||||
let leaf = match self.get_leaf_node(leaf_pos) {
|
||||
Some(word) => word.into(),
|
||||
None => *EmptySubtreeRoots::entry(self.depth, index.depth()),
|
||||
};
|
||||
Ok(leaf)
|
||||
} else {
|
||||
Ok(self.get_branch_node(&index).parent())
|
||||
}
|
||||
@@ -214,6 +228,9 @@ impl SimpleSmt {
|
||||
/// # Errors
|
||||
/// Returns an error if the index is greater than the maximum tree capacity, that is 2^{depth}.
|
||||
pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<Word, MerkleError> {
|
||||
// validate the index before modifying the structure
|
||||
let idx = NodeIndex::new(self.depth(), index)?;
|
||||
|
||||
let old_value = self.insert_leaf_node(index, value).unwrap_or(Self::EMPTY_VALUE);
|
||||
|
||||
// if the old value and new value are the same, there is nothing to update
|
||||
@@ -221,8 +238,82 @@ impl SimpleSmt {
|
||||
return Ok(value);
|
||||
}
|
||||
|
||||
let mut index = NodeIndex::new(self.depth(), index)?;
|
||||
let mut value = RpoDigest::from(value);
|
||||
self.recompute_nodes_from_index_to_root(idx, RpoDigest::from(value));
|
||||
|
||||
Ok(old_value)
|
||||
}
|
||||
|
||||
/// Inserts a subtree at the specified index. The depth at which the subtree is inserted is
|
||||
/// computed as `self.depth() - subtree.depth()`.
|
||||
///
|
||||
/// Returns the new root.
|
||||
pub fn set_subtree(
|
||||
&mut self,
|
||||
subtree_insertion_index: u64,
|
||||
subtree: SimpleSmt,
|
||||
) -> Result<RpoDigest, MerkleError> {
|
||||
if subtree.depth() > self.depth() {
|
||||
return Err(MerkleError::InvalidSubtreeDepth {
|
||||
subtree_depth: subtree.depth(),
|
||||
tree_depth: self.depth(),
|
||||
});
|
||||
}
|
||||
|
||||
// Verify that `subtree_insertion_index` is valid.
|
||||
let subtree_root_insertion_depth = self.depth() - subtree.depth();
|
||||
let subtree_root_index =
|
||||
NodeIndex::new(subtree_root_insertion_depth, subtree_insertion_index)?;
|
||||
|
||||
// add leaves
|
||||
// --------------
|
||||
|
||||
// The subtree's leaf indices live in their own context - i.e. a subtree of depth `d`. If we
|
||||
// insert the subtree at `subtree_insertion_index = 0`, then the subtree leaf indices are
|
||||
// valid as they are. However, consider what happens when we insert at
|
||||
// `subtree_insertion_index = 1`. The first leaf of our subtree now will have index `2^d`;
|
||||
// you can see it as there's a full subtree sitting on its left. In general, for
|
||||
// `subtree_insertion_index = i`, there are `i` subtrees sitting before the subtree we want
|
||||
// to insert, so we need to adjust all its leaves by `i * 2^d`.
|
||||
let leaf_index_shift: u64 = subtree_insertion_index * 2_u64.pow(subtree.depth().into());
|
||||
for (subtree_leaf_idx, leaf_value) in subtree.leaves() {
|
||||
let new_leaf_idx = leaf_index_shift + subtree_leaf_idx;
|
||||
debug_assert!(new_leaf_idx < 2_u64.pow(self.depth().into()));
|
||||
|
||||
self.insert_leaf_node(new_leaf_idx, *leaf_value);
|
||||
}
|
||||
|
||||
// add subtree's branch nodes (which includes the root)
|
||||
// --------------
|
||||
for (branch_idx, branch_node) in subtree.branches {
|
||||
let new_branch_idx = {
|
||||
let new_depth = subtree_root_insertion_depth + branch_idx.depth();
|
||||
let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into())
|
||||
+ branch_idx.value();
|
||||
|
||||
NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid")
|
||||
};
|
||||
|
||||
self.branches.insert(new_branch_idx, branch_node);
|
||||
}
|
||||
|
||||
// recompute nodes starting from subtree root
|
||||
// --------------
|
||||
self.recompute_nodes_from_index_to_root(subtree_root_index, subtree.root);
|
||||
|
||||
Ok(self.root)
|
||||
}
|
||||
|
||||
// HELPER METHODS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Recomputes the branch nodes (including the root) from `index` all the way to the root.
|
||||
/// `node_hash_at_index` is the hash of the node stored at index.
|
||||
fn recompute_nodes_from_index_to_root(
|
||||
&mut self,
|
||||
mut index: NodeIndex,
|
||||
node_hash_at_index: RpoDigest,
|
||||
) {
|
||||
let mut value = node_hash_at_index;
|
||||
for _ in 0..index.depth() {
|
||||
let is_right = index.is_value_odd();
|
||||
index.move_up();
|
||||
@@ -232,12 +323,8 @@ impl SimpleSmt {
|
||||
value = Rpo256::merge(&[left, right]);
|
||||
}
|
||||
self.root = value;
|
||||
Ok(old_value)
|
||||
}
|
||||
|
||||
// HELPER METHODS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
fn get_leaf_node(&self, key: u64) -> Option<Word> {
|
||||
self.leaves.get(&key).copied()
|
||||
}
|
||||
@@ -248,8 +335,8 @@ impl SimpleSmt {
|
||||
|
||||
fn get_branch_node(&self, index: &NodeIndex) -> BranchNode {
|
||||
self.branches.get(index).cloned().unwrap_or_else(|| {
|
||||
let node = self.empty_hashes[index.depth() as usize + 1];
|
||||
BranchNode { left: node, right: node }
|
||||
let node = EmptySubtreeRoots::entry(self.depth, index.depth() + 1);
|
||||
BranchNode { left: *node, right: *node }
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ use super::{
|
||||
NodeIndex, Rpo256, Vec,
|
||||
};
|
||||
use crate::{
|
||||
merkle::{digests_to_words, int_to_leaf, int_to_node},
|
||||
merkle::{digests_to_words, int_to_leaf, int_to_node, EmptySubtreeRoots},
|
||||
Word,
|
||||
};
|
||||
|
||||
@@ -71,6 +71,21 @@ fn build_sparse_tree() {
|
||||
assert_eq!(old_value, EMPTY_WORD);
|
||||
}
|
||||
|
||||
/// Tests that [`SimpleSmt::with_contiguous_leaves`] works as expected
|
||||
#[test]
|
||||
fn build_contiguous_tree() {
|
||||
let tree_with_leaves = SimpleSmt::with_leaves(
|
||||
2,
|
||||
[0, 1, 2, 3].into_iter().zip(digests_to_words(&VALUES4).into_iter()),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let tree_with_contiguous_leaves =
|
||||
SimpleSmt::with_contiguous_leaves(2, digests_to_words(&VALUES4).into_iter()).unwrap();
|
||||
|
||||
assert_eq!(tree_with_leaves, tree_with_contiguous_leaves);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_depth2_tree() {
|
||||
let tree =
|
||||
@@ -214,22 +229,31 @@ fn small_tree_opening_is_consistent() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fail_on_duplicates() {
|
||||
let entries = [(1_u64, int_to_leaf(1)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(3))];
|
||||
let smt = SimpleSmt::with_leaves(64, entries);
|
||||
assert!(smt.is_err());
|
||||
fn test_simplesmt_fail_on_duplicates() {
|
||||
let values = [
|
||||
// same key, same value
|
||||
(int_to_leaf(1), int_to_leaf(1)),
|
||||
// same key, different values
|
||||
(int_to_leaf(1), int_to_leaf(2)),
|
||||
// same key, set to zero
|
||||
(EMPTY_WORD, int_to_leaf(1)),
|
||||
// same key, re-set to zero
|
||||
(int_to_leaf(1), EMPTY_WORD),
|
||||
// same key, set to zero twice
|
||||
(EMPTY_WORD, EMPTY_WORD),
|
||||
];
|
||||
|
||||
let entries = [(1_u64, int_to_leaf(0)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(0))];
|
||||
let smt = SimpleSmt::with_leaves(64, entries);
|
||||
assert!(smt.is_err());
|
||||
for (first, second) in values.iter() {
|
||||
// consecutive
|
||||
let entries = [(1, *first), (1, *second)];
|
||||
let smt = SimpleSmt::with_leaves(64, entries);
|
||||
assert_eq!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));
|
||||
|
||||
let entries = [(1_u64, int_to_leaf(0)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(1))];
|
||||
let smt = SimpleSmt::with_leaves(64, entries);
|
||||
assert!(smt.is_err());
|
||||
|
||||
let entries = [(1_u64, int_to_leaf(1)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(0))];
|
||||
let smt = SimpleSmt::with_leaves(64, entries);
|
||||
assert!(smt.is_err());
|
||||
// not consecutive
|
||||
let entries = [(1, *first), (5, int_to_leaf(5)), (1, *second)];
|
||||
let smt = SimpleSmt::with_leaves(64, entries);
|
||||
assert_eq!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -239,6 +263,227 @@ fn with_no_duplicates_empty_node() {
|
||||
assert!(smt.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simplesmt_update_nonexisting_leaf_with_zero() {
|
||||
// TESTING WITH EMPTY WORD
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
// Depth 1 has 2 leaf. Position is 0-indexed, position 2 doesn't exist.
|
||||
let mut smt = SimpleSmt::new(1).unwrap();
|
||||
let result = smt.update_leaf(2, EMPTY_WORD);
|
||||
assert!(!smt.leaves.contains_key(&2));
|
||||
assert!(result.is_err());
|
||||
|
||||
// Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
|
||||
let mut smt = SimpleSmt::new(2).unwrap();
|
||||
let result = smt.update_leaf(4, EMPTY_WORD);
|
||||
assert!(!smt.leaves.contains_key(&4));
|
||||
assert!(result.is_err());
|
||||
|
||||
// Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
|
||||
let mut smt = SimpleSmt::new(3).unwrap();
|
||||
let result = smt.update_leaf(8, EMPTY_WORD);
|
||||
assert!(!smt.leaves.contains_key(&8));
|
||||
assert!(result.is_err());
|
||||
|
||||
// TESTING WITH A VALUE
|
||||
// --------------------------------------------------------------------------------------------
|
||||
let value = int_to_node(1);
|
||||
|
||||
// Depth 1 has 2 leaves. Position is 0-indexed, position 1 doesn't exist.
|
||||
let mut smt = SimpleSmt::new(1).unwrap();
|
||||
let result = smt.update_leaf(2, *value);
|
||||
assert!(!smt.leaves.contains_key(&2));
|
||||
assert!(result.is_err());
|
||||
|
||||
// Depth 2 has 4 leaves. Position is 0-indexed, position 2 doesn't exist.
|
||||
let mut smt = SimpleSmt::new(2).unwrap();
|
||||
let result = smt.update_leaf(4, *value);
|
||||
assert!(!smt.leaves.contains_key(&4));
|
||||
assert!(result.is_err());
|
||||
|
||||
// Depth 3 has 8 leaves. Position is 0-indexed, position 4 doesn't exist.
|
||||
let mut smt = SimpleSmt::new(3).unwrap();
|
||||
let result = smt.update_leaf(8, *value);
|
||||
assert!(!smt.leaves.contains_key(&8));
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simplesmt_with_leaves_nonexisting_leaf() {
|
||||
// TESTING WITH EMPTY WORD
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
// Depth 1 has 2 leaf. Position is 0-indexed, position 2 doesn't exist.
|
||||
let leaves = [(2, EMPTY_WORD)];
|
||||
let result = SimpleSmt::with_leaves(1, leaves);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
|
||||
let leaves = [(4, EMPTY_WORD)];
|
||||
let result = SimpleSmt::with_leaves(2, leaves);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
|
||||
let leaves = [(8, EMPTY_WORD)];
|
||||
let result = SimpleSmt::with_leaves(3, leaves);
|
||||
assert!(result.is_err());
|
||||
|
||||
// TESTING WITH A VALUE
|
||||
// --------------------------------------------------------------------------------------------
|
||||
let value = int_to_node(1);
|
||||
|
||||
// Depth 1 has 2 leaves. Position is 0-indexed, position 2 doesn't exist.
|
||||
let leaves = [(2, *value)];
|
||||
let result = SimpleSmt::with_leaves(1, leaves);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
|
||||
let leaves = [(4, *value)];
|
||||
let result = SimpleSmt::with_leaves(2, leaves);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
|
||||
let leaves = [(8, *value)];
|
||||
let result = SimpleSmt::with_leaves(3, leaves);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simplesmt_set_subtree() {
|
||||
// Final Tree:
|
||||
//
|
||||
// ____k____
|
||||
// / \
|
||||
// _i_ _j_
|
||||
// / \ / \
|
||||
// e f g h
|
||||
// / \ / \ / \ / \
|
||||
// a b 0 0 c 0 0 d
|
||||
|
||||
let z = EMPTY_WORD;
|
||||
|
||||
let a = Word::from(Rpo256::merge(&[z.into(); 2]));
|
||||
let b = Word::from(Rpo256::merge(&[a.into(); 2]));
|
||||
let c = Word::from(Rpo256::merge(&[b.into(); 2]));
|
||||
let d = Word::from(Rpo256::merge(&[c.into(); 2]));
|
||||
|
||||
let e = Rpo256::merge(&[a.into(), b.into()]);
|
||||
let f = Rpo256::merge(&[z.into(), z.into()]);
|
||||
let g = Rpo256::merge(&[c.into(), z.into()]);
|
||||
let h = Rpo256::merge(&[z.into(), d.into()]);
|
||||
|
||||
let i = Rpo256::merge(&[e, f]);
|
||||
let j = Rpo256::merge(&[g, h]);
|
||||
|
||||
let k = Rpo256::merge(&[i, j]);
|
||||
|
||||
// subtree:
|
||||
// g
|
||||
// / \
|
||||
// c 0
|
||||
let subtree = {
|
||||
let depth = 1;
|
||||
let entries = vec![(0, c)];
|
||||
SimpleSmt::with_leaves(depth, entries).unwrap()
|
||||
};
|
||||
|
||||
// insert subtree
|
||||
let tree = {
|
||||
let depth = 3;
|
||||
let entries = vec![(0, a), (1, b), (7, d)];
|
||||
let mut tree = SimpleSmt::with_leaves(depth, entries).unwrap();
|
||||
|
||||
tree.set_subtree(2, subtree).unwrap();
|
||||
|
||||
tree
|
||||
};
|
||||
|
||||
assert_eq!(tree.root(), k);
|
||||
assert_eq!(tree.get_leaf(4).unwrap(), c);
|
||||
assert_eq!(tree.get_branch_node(&NodeIndex::new_unchecked(2, 2)).parent(), g);
|
||||
}
|
||||
|
||||
/// Ensures that an invalid input node index into `set_subtree()` incurs no mutation of the tree
|
||||
#[test]
|
||||
fn test_simplesmt_set_subtree_unchanged_for_wrong_index() {
|
||||
// Final Tree:
|
||||
//
|
||||
// ____k____
|
||||
// / \
|
||||
// _i_ _j_
|
||||
// / \ / \
|
||||
// e f g h
|
||||
// / \ / \ / \ / \
|
||||
// a b 0 0 c 0 0 d
|
||||
|
||||
let z = EMPTY_WORD;
|
||||
|
||||
let a = Word::from(Rpo256::merge(&[z.into(); 2]));
|
||||
let b = Word::from(Rpo256::merge(&[a.into(); 2]));
|
||||
let c = Word::from(Rpo256::merge(&[b.into(); 2]));
|
||||
let d = Word::from(Rpo256::merge(&[c.into(); 2]));
|
||||
|
||||
// subtree:
|
||||
// g
|
||||
// / \
|
||||
// c 0
|
||||
let subtree = {
|
||||
let depth = 1;
|
||||
let entries = vec![(0, c)];
|
||||
SimpleSmt::with_leaves(depth, entries).unwrap()
|
||||
};
|
||||
|
||||
let mut tree = {
|
||||
let depth = 3;
|
||||
let entries = vec![(0, a), (1, b), (7, d)];
|
||||
SimpleSmt::with_leaves(depth, entries).unwrap()
|
||||
};
|
||||
let tree_root_before_insertion = tree.root();
|
||||
|
||||
// insert subtree
|
||||
assert!(tree.set_subtree(500, subtree).is_err());
|
||||
|
||||
assert_eq!(tree.root(), tree_root_before_insertion);
|
||||
}
|
||||
|
||||
/// We insert an empty subtree that has the same depth as the original tree
|
||||
#[test]
|
||||
fn test_simplesmt_set_subtree_entire_tree() {
|
||||
// Initial Tree:
|
||||
//
|
||||
// ____k____
|
||||
// / \
|
||||
// _i_ _j_
|
||||
// / \ / \
|
||||
// e f g h
|
||||
// / \ / \ / \ / \
|
||||
// a b 0 0 c 0 0 d
|
||||
|
||||
let z = EMPTY_WORD;
|
||||
|
||||
let a = Word::from(Rpo256::merge(&[z.into(); 2]));
|
||||
let b = Word::from(Rpo256::merge(&[a.into(); 2]));
|
||||
let c = Word::from(Rpo256::merge(&[b.into(); 2]));
|
||||
let d = Word::from(Rpo256::merge(&[c.into(); 2]));
|
||||
|
||||
let depth = 3;
|
||||
|
||||
// subtree: E3
|
||||
let subtree = { SimpleSmt::with_leaves(depth, Vec::new()).unwrap() };
|
||||
assert_eq!(subtree.root(), *EmptySubtreeRoots::entry(depth, 0));
|
||||
|
||||
// insert subtree
|
||||
let mut tree = {
|
||||
let entries = vec![(0, a), (1, b), (4, c), (7, d)];
|
||||
SimpleSmt::with_leaves(depth, entries).unwrap()
|
||||
};
|
||||
|
||||
tree.set_subtree(0, subtree).unwrap();
|
||||
|
||||
assert_eq!(tree.root(), *EmptySubtreeRoots::entry(depth, 0));
|
||||
}
|
||||
|
||||
// HELPER FUNCTIONS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
use super::{
|
||||
DefaultMerkleStore as MerkleStore, EmptySubtreeRoots, MerkleError, MerklePath, NodeIndex,
|
||||
PartialMerkleTree, RecordingMerkleStore, RpoDigest,
|
||||
PartialMerkleTree, RecordingMerkleStore, Rpo256, RpoDigest,
|
||||
};
|
||||
use crate::{
|
||||
hash::rpo::Rpo256,
|
||||
merkle::{digests_to_words, int_to_leaf, int_to_node, MerkleTree, SimpleSmt},
|
||||
Felt, Word, ONE, WORD_SIZE, ZERO,
|
||||
};
|
||||
|
||||
@@ -85,18 +85,26 @@ impl TieredSmtProof {
|
||||
/// Note: this method cannot be used to assert non-membership. That is, if false is returned,
|
||||
/// it does not mean that the provided key-value pair is not in the tree.
|
||||
pub fn verify_membership(&self, key: &RpoDigest, value: &Word, root: &RpoDigest) -> bool {
|
||||
if self.is_value_empty() {
|
||||
if value != &EMPTY_VALUE {
|
||||
return false;
|
||||
}
|
||||
// if the proof is for an empty value, we can verify it against any key which has a
|
||||
// common prefix with the key storied in entries, but the prefix must be greater than
|
||||
// the path length
|
||||
let common_prefix_tier = get_common_prefix_tier_depth(key, &self.entries[0].0);
|
||||
if common_prefix_tier < self.path.depth() {
|
||||
return false;
|
||||
}
|
||||
} else if !self.entries.contains(&(*key, *value)) {
|
||||
// Handles the following scenarios:
|
||||
// - the value is set
|
||||
// - empty leaf, there is an explicit entry for the key with the empty value
|
||||
// - shared 64-bit prefix, the target key is not included in the entries list, the value is implicitly the empty word
|
||||
let v = match self.entries.iter().find(|(k, _)| k == key) {
|
||||
Some((_, v)) => v,
|
||||
None => &EMPTY_VALUE,
|
||||
};
|
||||
|
||||
// The value must match for the proof to be valid
|
||||
if v != value {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If the proof is for an empty value, we can verify it against any key which has a common
|
||||
// prefix with the key storied in entries, but the prefix must be greater than the path
|
||||
// length
|
||||
if self.is_value_empty()
|
||||
&& get_common_prefix_tier_depth(key, &self.entries[0].0) < self.path.depth()
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -715,6 +715,38 @@ fn tsmt_bottom_tier_two() {
|
||||
// GET PROOF TESTS
|
||||
// ================================================================================================
|
||||
|
||||
/// Tests the membership and non-membership proof for a single at depth 64
|
||||
#[test]
|
||||
fn tsmt_get_proof_single_element_64() {
|
||||
let mut smt = TieredSmt::default();
|
||||
|
||||
let raw_a = 0b_00000000_00000001_00000000_00000001_00000000_00000001_00000000_00000001_u64;
|
||||
let key_a = [ONE, ONE, ONE, raw_a.into()].into();
|
||||
let value_a = [ONE, ONE, ONE, ONE];
|
||||
smt.insert(key_a, value_a);
|
||||
|
||||
// push element `a` to depth 64, by inserting another value that shares the 48-bit prefix
|
||||
let raw_b = 0b_00000000_00000001_00000000_00000001_00000000_00000001_00000000_00000000_u64;
|
||||
let key_b = [ONE, ONE, ONE, raw_b.into()].into();
|
||||
smt.insert(key_b, [ONE, ONE, ONE, ONE]);
|
||||
|
||||
// verify the proof for element `a`
|
||||
let proof = smt.prove(key_a);
|
||||
assert!(proof.verify_membership(&key_a, &value_a, &smt.root()));
|
||||
|
||||
// check that a value that is not inserted in the tree produces a valid membership proof for the
|
||||
// empty word
|
||||
let key = [ZERO, ZERO, ZERO, ZERO].into();
|
||||
let proof = smt.prove(key);
|
||||
assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root()));
|
||||
|
||||
// check that a key that shared the 64-bit prefix with `a`, but is not inserted, also has a
|
||||
// valid membership proof for the empty word
|
||||
let key = [ONE, ONE, ZERO, raw_a.into()].into();
|
||||
let proof = smt.prove(key);
|
||||
assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tsmt_get_proof() {
|
||||
let mut smt = TieredSmt::default();
|
||||
|
||||
@@ -1,3 +1,19 @@
|
||||
//! Pseudo-random element generation.
|
||||
|
||||
pub use winter_crypto::{RandomCoin, RandomCoinError};
|
||||
pub use winter_crypto::{DefaultRandomCoin as WinterRandomCoin, RandomCoin, RandomCoinError};
|
||||
|
||||
use crate::{Felt, FieldElement, StarkField, Word, ZERO};
|
||||
|
||||
mod rpo;
|
||||
pub use rpo::RpoRandomCoin;
|
||||
|
||||
/// Pseudo-random element generator.
|
||||
///
|
||||
/// An instance can be used to draw, uniformly at random, base field elements as well as [Word]s.
|
||||
pub trait FeltRng {
|
||||
/// Draw, uniformly at random, a base field element.
|
||||
fn draw_element(&mut self) -> Felt;
|
||||
|
||||
/// Draw, uniformly at random, a [Word].
|
||||
fn draw_word(&mut self) -> Word;
|
||||
}
|
||||
|
||||
267
src/rand/rpo.rs
Normal file
267
src/rand/rpo.rs
Normal file
@@ -0,0 +1,267 @@
|
||||
use super::{Felt, FeltRng, FieldElement, StarkField, Word, ZERO};
|
||||
use crate::{
|
||||
hash::rpo::{Rpo256, RpoDigest},
|
||||
utils::{
|
||||
collections::Vec, string::ToString, vec, ByteReader, ByteWriter, Deserializable,
|
||||
DeserializationError, Serializable,
|
||||
},
|
||||
};
|
||||
pub use winter_crypto::{RandomCoin, RandomCoinError};
|
||||
|
||||
// CONSTANTS
|
||||
// ================================================================================================
|
||||
|
||||
const STATE_WIDTH: usize = Rpo256::STATE_WIDTH;
|
||||
const RATE_START: usize = Rpo256::RATE_RANGE.start;
|
||||
const RATE_END: usize = Rpo256::RATE_RANGE.end;
|
||||
const HALF_RATE_WIDTH: usize = (Rpo256::RATE_RANGE.end - Rpo256::RATE_RANGE.start) / 2;
|
||||
|
||||
// RPO RANDOM COIN
|
||||
// ================================================================================================
|
||||
/// A simplified version of the `SPONGE_PRG` reseedable pseudo-random number generator algorithm
|
||||
/// described in https://eprint.iacr.org/2011/499.pdf.
|
||||
///
|
||||
/// The simplification is related to the following facts:
|
||||
/// 1. A call to the reseed method implies one and only one call to the permutation function.
|
||||
/// This is possible because in our case we never reseed with more than 4 field elements.
|
||||
/// 2. As a result of the previous point, we don't make use of an input buffer to accumulate seed
|
||||
/// material.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct RpoRandomCoin {
|
||||
state: [Felt; STATE_WIDTH],
|
||||
current: usize,
|
||||
}
|
||||
|
||||
impl RpoRandomCoin {
|
||||
/// Returns a new [RpoRandomCoin] initialize with the specified seed.
|
||||
pub fn new(seed: Word) -> Self {
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
|
||||
for i in 0..HALF_RATE_WIDTH {
|
||||
state[RATE_START + i] += seed[i];
|
||||
}
|
||||
|
||||
// Absorb
|
||||
Rpo256::apply_permutation(&mut state);
|
||||
|
||||
RpoRandomCoin { state, current: RATE_START }
|
||||
}
|
||||
|
||||
/// Returns an [RpoRandomCoin] instantiated from the provided components.
|
||||
///
|
||||
/// # Panics
|
||||
/// Panics if `current` is smaller than 4 or greater than or equal to 12.
|
||||
pub fn from_parts(state: [Felt; STATE_WIDTH], current: usize) -> Self {
|
||||
assert!(
|
||||
(RATE_START..RATE_END).contains(¤t),
|
||||
"current value outside of valid range"
|
||||
);
|
||||
Self { state, current }
|
||||
}
|
||||
|
||||
/// Returns components of this random coin.
|
||||
pub fn into_parts(self) -> ([Felt; STATE_WIDTH], usize) {
|
||||
(self.state, self.current)
|
||||
}
|
||||
|
||||
fn draw_basefield(&mut self) -> Felt {
|
||||
if self.current == RATE_END {
|
||||
Rpo256::apply_permutation(&mut self.state);
|
||||
self.current = RATE_START;
|
||||
}
|
||||
|
||||
self.current += 1;
|
||||
self.state[self.current - 1]
|
||||
}
|
||||
}
|
||||
|
||||
// RANDOM COIN IMPLEMENTATION
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
|
||||
impl RandomCoin for RpoRandomCoin {
|
||||
type BaseField = Felt;
|
||||
type Hasher = Rpo256;
|
||||
|
||||
fn new(seed: &[Self::BaseField]) -> Self {
|
||||
let digest: Word = Rpo256::hash_elements(seed).into();
|
||||
Self::new(digest)
|
||||
}
|
||||
|
||||
fn reseed(&mut self, data: RpoDigest) {
|
||||
// Reset buffer
|
||||
self.current = RATE_START;
|
||||
|
||||
// Add the new seed material to the first half of the rate portion of the RPO state
|
||||
let data: Word = data.into();
|
||||
|
||||
self.state[RATE_START] += data[0];
|
||||
self.state[RATE_START + 1] += data[1];
|
||||
self.state[RATE_START + 2] += data[2];
|
||||
self.state[RATE_START + 3] += data[3];
|
||||
|
||||
// Absorb
|
||||
Rpo256::apply_permutation(&mut self.state);
|
||||
}
|
||||
|
||||
fn check_leading_zeros(&self, value: u64) -> u32 {
|
||||
let value = Felt::new(value);
|
||||
let mut state_tmp = self.state;
|
||||
|
||||
state_tmp[RATE_START] += value;
|
||||
|
||||
Rpo256::apply_permutation(&mut state_tmp);
|
||||
|
||||
let first_rate_element = state_tmp[RATE_START].as_int();
|
||||
first_rate_element.trailing_zeros()
|
||||
}
|
||||
|
||||
fn draw<E: FieldElement<BaseField = Felt>>(&mut self) -> Result<E, RandomCoinError> {
|
||||
let ext_degree = E::EXTENSION_DEGREE;
|
||||
let mut result = vec![ZERO; ext_degree];
|
||||
for r in result.iter_mut().take(ext_degree) {
|
||||
*r = self.draw_basefield();
|
||||
}
|
||||
|
||||
let result = E::slice_from_base_elements(&result);
|
||||
Ok(result[0])
|
||||
}
|
||||
|
||||
fn draw_integers(
|
||||
&mut self,
|
||||
num_values: usize,
|
||||
domain_size: usize,
|
||||
nonce: u64,
|
||||
) -> Result<Vec<usize>, RandomCoinError> {
|
||||
assert!(domain_size.is_power_of_two(), "domain size must be a power of two");
|
||||
assert!(num_values < domain_size, "number of values must be smaller than domain size");
|
||||
|
||||
// absorb the nonce
|
||||
let nonce = Felt::new(nonce);
|
||||
self.state[RATE_START] += nonce;
|
||||
Rpo256::apply_permutation(&mut self.state);
|
||||
|
||||
// reset the buffer
|
||||
self.current = RATE_START;
|
||||
|
||||
// determine how many bits are needed to represent valid values in the domain
|
||||
let v_mask = (domain_size - 1) as u64;
|
||||
|
||||
// draw values from PRNG until we get as many unique values as specified by num_queries
|
||||
let mut values = Vec::new();
|
||||
for _ in 0..1000 {
|
||||
// get the next pseudo-random field element
|
||||
let value = self.draw_basefield().as_int();
|
||||
|
||||
// use the mask to get a value within the range
|
||||
let value = (value & v_mask) as usize;
|
||||
|
||||
values.push(value);
|
||||
if values.len() == num_values {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if values.len() < num_values {
|
||||
return Err(RandomCoinError::FailedToDrawIntegers(num_values, values.len(), 1000));
|
||||
}
|
||||
|
||||
Ok(values)
|
||||
}
|
||||
}
|
||||
|
||||
// FELT RNG IMPLEMENTATION
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
|
||||
impl FeltRng for RpoRandomCoin {
|
||||
fn draw_element(&mut self) -> Felt {
|
||||
self.draw_basefield()
|
||||
}
|
||||
|
||||
fn draw_word(&mut self) -> Word {
|
||||
let mut output = [ZERO; 4];
|
||||
for o in output.iter_mut() {
|
||||
*o = self.draw_basefield();
|
||||
}
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
// SERIALIZATION
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
|
||||
impl Serializable for RpoRandomCoin {
|
||||
fn write_into<W: ByteWriter>(&self, target: &mut W) {
|
||||
self.state.iter().for_each(|v| v.write_into(target));
|
||||
// casting to u8 is OK because `current` is always between 4 and 12.
|
||||
target.write_u8(self.current as u8);
|
||||
}
|
||||
}
|
||||
|
||||
impl Deserializable for RpoRandomCoin {
|
||||
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
|
||||
let state = [
|
||||
Felt::read_from(source)?,
|
||||
Felt::read_from(source)?,
|
||||
Felt::read_from(source)?,
|
||||
Felt::read_from(source)?,
|
||||
Felt::read_from(source)?,
|
||||
Felt::read_from(source)?,
|
||||
Felt::read_from(source)?,
|
||||
Felt::read_from(source)?,
|
||||
Felt::read_from(source)?,
|
||||
Felt::read_from(source)?,
|
||||
Felt::read_from(source)?,
|
||||
Felt::read_from(source)?,
|
||||
];
|
||||
let current = source.read_u8()? as usize;
|
||||
if !(RATE_START..RATE_END).contains(¤t) {
|
||||
return Err(DeserializationError::InvalidValue(
|
||||
"current value outside of valid range".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(Self { state, current })
|
||||
}
|
||||
}
|
||||
|
||||
// TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{Deserializable, FeltRng, RpoRandomCoin, Serializable, ZERO};
|
||||
use crate::ONE;
|
||||
|
||||
#[test]
|
||||
fn test_feltrng_felt() {
|
||||
let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
|
||||
let output = rpocoin.draw_element();
|
||||
|
||||
let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
|
||||
let expected = rpocoin.draw_basefield();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_feltrng_word() {
|
||||
let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
|
||||
let output = rpocoin.draw_word();
|
||||
|
||||
let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
|
||||
let mut expected = [ZERO; 4];
|
||||
for o in expected.iter_mut() {
|
||||
*o = rpocoin.draw_basefield();
|
||||
}
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_feltrng_serialization() {
|
||||
let coin1 = RpoRandomCoin::from_parts([ONE; 12], 5);
|
||||
|
||||
let bytes = coin1.to_bytes();
|
||||
let coin2 = RpoRandomCoin::read_from_bytes(&bytes).unwrap();
|
||||
assert_eq!(coin1, coin2);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user