Mirror of https://github.com/arnaucube/miden-crypto.git (synced 2026-01-10 16:11:30 +01:00)

Compare commits: v0.13.2...al-gkr-bas (40 commits)
| SHA1 |
| --- |
| 0f06fa30a9 |
| 0b074a795d |
| 862ccf54dd |
| 88bcdfd576 |
| 290894f497 |
| 4aac00884c |
| 2ef6f79656 |
| 5142e2fd31 |
| 9fb41337ec |
| 0296e05ccd |
| 499f97046d |
| 600feafe53 |
| 9d854f1fcb |
| af76cb10d0 |
| 4758e0672f |
| 8bb080a91d |
| e5f3b28645 |
| 29e0d07129 |
| 81a94ecbe7 |
| 223fbf887d |
| 9e77a7c9b7 |
| 894e20fe0c |
| 7ec7b06574 |
| 2499a8a2dd |
| 800994c69b |
| 26560605bf |
| 672340d0c2 |
| 8083b02aef |
| ecb8719d45 |
| 4144f98560 |
| c726050957 |
| 9239340888 |
| 97ee9298a4 |
| bfae06e128 |
| b4e2d63c10 |
| 9679329746 |
| 2bbea37dbe |
| 83000940da |
| f44175e7a9 |
| 4cf8eebff5 |
.editorconfig (new file, 20 lines)
@@ -0,0 +1,20 @@
# Documentation available at editorconfig.org
root=true

[*]
indent_style = space
indent_size = 4
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true

[*.rs]
max_line_length = 100

[*.md]
trim_trailing_whitespace = false

[*.yml]
indent_size = 2
CHANGELOG.md
@@ -1,3 +1,11 @@
+## 0.8.0 (TBD)
+
+* Implemented the `PartialMmr` data structure (#195).
+* Updated Winterfell dependency to v0.7 (#200).
+* Implemented RPX hash function (#201).
+* Added `FeltRng` and `RpoRandomCoin` (#237).
+* Added `inner_nodes()` method to `PartialMmr` (#238).
+
 ## 0.7.1 (2023-10-10)

 * Fixed RPO Falcon signature build on Windows.
@@ -12,7 +20,6 @@
 * Implemented benchmarking for `TieredSmt` (#182).
 * Added more leaf traversal methods for `MerkleStore` (#185).
 * Added SVE acceleration for RPO hash function (#189).
-* Implemented the `PartialMmr` datastructure (#195).

 ## 0.6.0 (2023-06-25)
Cargo.toml (17 lines changed)
@@ -1,12 +1,12 @@
 [package]
 name = "miden-crypto"
-version = "0.7.1"
+version = "0.8.0"
 description = "Miden Cryptographic primitives"
 authors = ["miden contributors"]
 readme = "README.md"
 license = "MIT"
 repository = "https://github.com/0xPolygonMiden/crypto"
-documentation = "https://docs.rs/miden-crypto/0.7.1"
+documentation = "https://docs.rs/miden-crypto/0.8.0"
 categories = ["cryptography", "no-std"]
 keywords = ["miden", "crypto", "hash", "merkle"]
 edition = "2021"
@@ -42,16 +42,19 @@ sve = ["std"]
|
|||||||
blake3 = { version = "1.5", default-features = false }
|
blake3 = { version = "1.5", default-features = false }
|
||||||
clap = { version = "4.4", features = ["derive"], optional = true }
|
clap = { version = "4.4", features = ["derive"], optional = true }
|
||||||
libc = { version = "0.2", default-features = false, optional = true }
|
libc = { version = "0.2", default-features = false, optional = true }
|
||||||
rand_utils = { version = "0.6", package = "winter-rand-utils", optional = true }
|
rand_utils = { version = "0.7", package = "winter-rand-utils", optional = true }
|
||||||
serde = { version = "1.0", features = [ "derive" ], default-features = false, optional = true }
|
serde = { version = "1.0", features = [ "derive" ], default-features = false, optional = true }
|
||||||
winter_crypto = { version = "0.6", package = "winter-crypto", default-features = false }
|
winter_crypto = { version = "0.7", package = "winter-crypto", default-features = false }
|
||||||
winter_math = { version = "0.6", package = "winter-math", default-features = false }
|
winter_math = { version = "0.7", package = "winter-math", default-features = false }
|
||||||
winter_utils = { version = "0.6", package = "winter-utils", default-features = false }
|
winter_utils = { version = "0.7", package = "winter-utils", default-features = false }
|
||||||
|
rayon = "1.8.0"
|
||||||
|
rand = "0.8.4"
|
||||||
|
rand_core = { version = "0.5", default-features = false }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
criterion = { version = "0.5", features = ["html_reports"] }
|
criterion = { version = "0.5", features = ["html_reports"] }
|
||||||
proptest = "1.3"
|
proptest = "1.3"
|
||||||
rand_utils = { version = "0.6", package = "winter-rand-utils" }
|
rand_utils = { version = "0.7", package = "winter-rand-utils" }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
cc = { version = "1.0", features = ["parallel"], optional = true }
|
cc = { version = "1.0", features = ["parallel"], optional = true }
|
||||||
README.md (10 lines changed)
@@ -6,6 +6,7 @@ This crate contains cryptographic primitives used in Polygon Miden.

 * [BLAKE3](https://github.com/BLAKE3-team/BLAKE3) hash function with 256-bit, 192-bit, or 160-bit output. The 192-bit and 160-bit outputs are obtained by truncating the 256-bit output of the standard BLAKE3.
 * [RPO](https://eprint.iacr.org/2022/1577) hash function with 256-bit output. This hash function is an algebraic hash function suitable for recursive STARKs.
+* [RPX](https://eprint.iacr.org/2023/1045) hash function with 256-bit output. Similar to RPO, this hash function is suitable for recursive STARKs but it is about 2x faster as compared to RPO.

 For performance benchmarks of these hash functions and their comparison to other popular hash functions please see [here](./benches/).
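As a quick orientation, here is a minimal sketch of the 2-to-1 and sequential hashing modes that the benchmarks exercise. It mirrors calls that appear verbatim in the `benches/hash.rs` diff further down this page (`hash`, `merge`, `hash_elements`) and is illustrative rather than normative:

```rust
use miden_crypto::{
    hash::rpo::{Rpo256, RpoDigest},
    Felt,
};

fn main() {
    // Hash raw bytes to a 256-bit digest.
    let d1: RpoDigest = Rpo256::hash(&[1_u8]);
    let d2: RpoDigest = Rpo256::hash(&[2_u8]);

    // 2-to-1 hashing: merge two digests (the Merkle-tree inner-node operation).
    let parent = Rpo256::merge(&[d1, d2]);

    // Sequential hashing of a list of field elements.
    let elements: Vec<Felt> = (0..100u64).map(Felt::new).collect();
    let digest = Rpo256::hash_elements(&elements);

    println!("{parent:?} {digest:?}");
}
```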
@@ -16,18 +17,25 @@ For performance benchmarks of these hash functions and their comparison to other

 * `MerkleTree`: a regular fully-balanced binary Merkle tree. The depth of this tree can be at most 64.
 * `Mmr`: a Merkle mountain range structure designed to function as an append-only log.
 * `PartialMerkleTree`: a partial view of a Merkle tree where some sub-trees may not be known. This is similar to a collection of Merkle paths all resolving to the same root. The length of the paths can be at most 64.
+* `PartialMmr`: a partial view of a Merkle mountain range structure.
 * `SimpleSmt`: a Sparse Merkle Tree (with no compaction), mapping 64-bit keys to 4-element values.
 * `TieredSmt`: a Sparse Merkle tree (with compaction), mapping 4-element keys to 4-element values.

 The module also contains additional supporting components such as `NodeIndex`, `MerklePath`, and `MerkleError` to assist with tree indexation, opening proofs, and reporting inconsistent arguments/state.

 ## Signatures
-[DAS module](./src/dsa) provides a set of digital signature schemes supported by default in the Miden VM. Currently, these schemes are:
+[DSA module](./src/dsa) provides a set of digital signature schemes supported by default in the Miden VM. Currently, these schemes are:

 * `RPO Falcon512`: a variant of the [Falcon](https://falcon-sign.info/) signature scheme. This variant differs from the standard in that instead of using SHAKE256 hash function in the *hash-to-point* algorithm we use RPO256. This makes the signature more efficient to verify in Miden VM.

 For the above signatures, key generation and signing is available only in the `std` context (see [crate features](#crate-features) below), while signature verification is available in `no_std` context as well.
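A minimal signing sketch follows. Only the `KeyPair` and `Signature` types and the `pub_key_poly` method are confirmed by the Falcon diffs further down this page; the `KeyPair::new` constructor and the `Word`-typed message for `sign` are assumptions here:

```rust
use miden_crypto::{dsa::rpo_falcon512::KeyPair, Felt, Word};

fn main() {
    // Key generation and signing are std-only, per the note above.
    // `KeyPair::new()` is assumed to sample a fresh Falcon512 key pair.
    let key_pair = KeyPair::new().expect("key generation failed");

    // Sign a 4-element message (a `Word`); on success this returns the
    // `Signature` populated in the `Ok(Signature { .. })` arm shown in the
    // KeyPair diff below.
    let message: Word = [Felt::new(1), Felt::new(2), Felt::new(3), Felt::new(4)];
    let signature = key_pair.sign(message).expect("signing failed");

    // The signature exposes its decoded public key polynomial (see the
    // `pub_key_poly` method in the signature diff below).
    let _h = signature.pub_key_poly();
}
```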
+## Pseudo-Random Element Generator
+[Pseudo random element generator module](./src/rand/) provides a set of traits and data structures that facilitate generating pseudo-random elements in the context of Miden VM and Miden rollup. The module currently includes:
+
+* `FeltRng`: a trait for generating random field elements and random 4 field elements.
+* `RpoRandomCoin`: a struct implementing `FeltRng` as well as the [`RandomCoin`](https://github.com/facebook/winterfell/blob/main/crypto/src/random/mod.rs) trait.
+
 ## Crate features
 This crate can be compiled with the following features:
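A minimal sketch of drawing randomness with the new module. The `draw_element`/`draw_word` method names and the `Word`-shaped seed are assumptions based on the trait description above, not confirmed by this page:

```rust
use miden_crypto::{
    rand::{FeltRng, RpoRandomCoin},
    Felt, Word,
};

fn main() {
    // Seed the RPO-based random coin (the seed shape is an assumption).
    let mut rng = RpoRandomCoin::new([Felt::new(0); 4]);

    // Draw a single random field element, then a random 4-element word.
    let element: Felt = rng.draw_element();
    let word: Word = rng.draw_word();

    println!("{element:?} {word:?}");
}
```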
benches/README.md
@@ -6,6 +6,7 @@ In the Miden VM, we make use of different hash functions. Some of these are "tra
 * **Poseidon** as specified [here](https://eprint.iacr.org/2019/458.pdf) and implemented [here](https://github.com/mir-protocol/plonky2/blob/806b88d7d6e69a30dc0b4775f7ba275c45e8b63b/plonky2/src/hash/poseidon_goldilocks.rs) (but in pure Rust, without vectorized instructions).
 * **Rescue Prime (RP)** as specified [here](https://eprint.iacr.org/2020/1143) and implemented [here](https://github.com/novifinancial/winterfell/blob/46dce1adf0/crypto/src/hash/rescue/rp64_256/mod.rs).
 * **Rescue Prime Optimized (RPO)** as specified [here](https://eprint.iacr.org/2022/1577) and implemented in this crate.
+* **Rescue Prime Extended (RPX)**, a variant of the [xHash](https://eprint.iacr.org/2023/1045) hash function, as implemented in this crate.

 ## Comparison and Instructions
@@ -15,28 +16,28 @@ The second scenario is that of sequential hashing where we take a sequence of le

 #### Scenario 1: 2-to-1 hashing `h(a,b)`

-| Function            | BLAKE3 | SHA3    | Poseidon  | Rp64_256  | RPO_256 |
-| ------------------- | ------ | --------| --------- | --------- | ------- |
-| Apple M1 Pro        | 80 ns  | 245 ns  | 1.5 us    | 9.1 us    | 5.4 us  |
-| Apple M2            | 76 ns  | 233 ns  | 1.3 us    | 7.9 us    | 5.0 us  |
-| Amazon Graviton 3   | 108 ns |         |           |           | 5.3 us  |
-| AMD Ryzen 9 5950X   | 64 ns  | 273 ns  | 1.2 us    | 9.1 us    | 5.5 us  |
-| Intel Core i5-8279U | 80 ns  |         |           |           | 8.7 us  |
-| Intel Xeon 8375C    | 67 ns  |         |           |           | 8.2 us  |
+| Function            | BLAKE3 | SHA3    | Poseidon  | Rp64_256  | RPO_256 | RPX_256 |
+| ------------------- | ------ | ------- | --------- | --------- | ------- | ------- |
+| Apple M1 Pro        | 76 ns  | 245 ns  | 1.5 µs    | 9.1 µs    | 5.2 µs  | 2.7 µs  |
+| Apple M2 Max        | 71 ns  | 233 ns  | 1.3 µs    | 7.9 µs    | 4.6 µs  | 2.4 µs  |
+| Amazon Graviton 3   | 108 ns |         |           |           | 5.3 µs  | 3.1 µs  |
+| AMD Ryzen 9 5950X   | 64 ns  | 273 ns  | 1.2 µs    | 9.1 µs    | 5.5 µs  |         |
+| Intel Core i5-8279U | 68 ns  | 536 ns  | 2.0 µs    | 13.6 µs   | 8.5 µs  | 4.4 µs  |
+| Intel Xeon 8375C    | 67 ns  |         |           |           | 8.2 µs  |         |

 #### Scenario 2: Sequential hashing of 100 elements `h([a_0,...,a_99])`

-| Function            | BLAKE3 | SHA3    | Poseidon  | Rp64_256  | RPO_256 |
-| ------------------- | -------| ------- | --------- | --------- | ------- |
-| Apple M1 Pro        | 1.0 us | 1.5 us  | 19.4 us   | 118 us    | 70 us   |
-| Apple M2            | 1.0 us | 1.5 us  | 17.4 us   | 103 us    | 65 us   |
-| Amazon Graviton 3   | 1.4 us |         |           |           | 69 us   |
-| AMD Ryzen 9 5950X   | 0.8 us | 1.7 us  | 15.7 us   | 120 us    | 72 us   |
-| Intel Core i5-8279U | 1.0 us |         |           |           | 116 us  |
-| Intel Xeon 8375C    | 0.8 ns |         |           |           | 110 us  |
+| Function            | BLAKE3 | SHA3    | Poseidon  | Rp64_256  | RPO_256 | RPX_256 |
+| ------------------- | -------| ------- | --------- | --------- | ------- | ------- |
+| Apple M1 Pro        | 1.0 µs | 1.5 µs  | 19.4 µs   | 118 µs    | 69 µs   | 35 µs   |
+| Apple M2 Max        | 0.9 µs | 1.5 µs  | 17.4 µs   | 103 µs    | 60 µs   | 31 µs   |
+| Amazon Graviton 3   | 1.4 µs |         |           |           | 69 µs   | 41 µs   |
+| AMD Ryzen 9 5950X   | 0.8 µs | 1.7 µs  | 15.7 µs   | 120 µs    | 72 µs   |         |
+| Intel Core i5-8279U | 0.9 µs |         |           |           | 107 µs  | 56 µs   |
+| Intel Xeon 8375C    | 0.8 µs |         |           |           | 110 µs  |         |

 Notes:
-- On Graviton 3, RPO256 is run with SVE acceleration enabled.
+- On Graviton 3, RPO256 and RPX256 are run with SVE acceleration enabled.

 ### Instructions
 Before you can run the benchmarks, you'll need to make sure you have Rust [installed](https://www.rust-lang.org/tools/install). After that, to run the benchmarks for RPO and BLAKE3, clone the current repository, and from the root directory of the repo run the following:
benches/hash.rs
@@ -3,6 +3,7 @@ use miden_crypto::{
     hash::{
         blake::Blake3_256,
         rpo::{Rpo256, RpoDigest},
+        rpx::{Rpx256, RpxDigest},
     },
     Felt,
 };
@@ -57,6 +58,54 @@ fn rpo256_sequential(c: &mut Criterion) {
     });
 }

+fn rpx256_2to1(c: &mut Criterion) {
+    let v: [RpxDigest; 2] = [Rpx256::hash(&[1_u8]), Rpx256::hash(&[2_u8])];
+    c.bench_function("RPX256 2-to-1 hashing (cached)", |bench| {
+        bench.iter(|| Rpx256::merge(black_box(&v)))
+    });
+
+    c.bench_function("RPX256 2-to-1 hashing (random)", |bench| {
+        bench.iter_batched(
+            || {
+                [
+                    Rpx256::hash(&rand_value::<u64>().to_le_bytes()),
+                    Rpx256::hash(&rand_value::<u64>().to_le_bytes()),
+                ]
+            },
+            |state| Rpx256::merge(&state),
+            BatchSize::SmallInput,
+        )
+    });
+}
+
+fn rpx256_sequential(c: &mut Criterion) {
+    let v: [Felt; 100] = (0..100)
+        .into_iter()
+        .map(Felt::new)
+        .collect::<Vec<Felt>>()
+        .try_into()
+        .expect("should not fail");
+    c.bench_function("RPX256 sequential hashing (cached)", |bench| {
+        bench.iter(|| Rpx256::hash_elements(black_box(&v)))
+    });
+
+    c.bench_function("RPX256 sequential hashing (random)", |bench| {
+        bench.iter_batched(
+            || {
+                let v: [Felt; 100] = (0..100)
+                    .into_iter()
+                    .map(|_| Felt::new(rand_value()))
+                    .collect::<Vec<Felt>>()
+                    .try_into()
+                    .expect("should not fail");
+                v
+            },
+            |state| Rpx256::hash_elements(&state),
+            BatchSize::SmallInput,
+        )
+    });
+}
+
 fn blake3_2to1(c: &mut Criterion) {
     let v: [<Blake3_256 as Hasher>::Digest; 2] =
         [Blake3_256::hash(&[1_u8]), Blake3_256::hash(&[2_u8])];
@@ -106,5 +155,13 @@ fn blake3_sequential(c: &mut Criterion) {
     });
 }

-criterion_group!(hash_group, rpo256_2to1, rpo256_sequential, blake3_2to1, blake3_sequential);
+criterion_group!(
+    hash_group,
+    rpx256_2to1,
+    rpx256_sequential,
+    rpo256_2to1,
+    rpo256_sequential,
+    blake3_2to1,
+    blake3_sequential
+);
 criterion_main!(hash_group);
src/dsa/rpo_falcon512/keys.rs
@@ -147,7 +147,12 @@ impl KeyPair {
         };

         if res == 0 {
-            Ok(Signature { sig, pk: self.public_key })
+            Ok(Signature {
+                sig,
+                pk: self.public_key,
+                pk_polynomial: Default::default(),
+                sig_polynomial: Default::default(),
+            })
         } else {
             Err(FalconError::SigGenerationFailed)
         }
src/dsa/rpo_falcon512/signature.rs
@@ -4,6 +4,7 @@ use super::{
     SIG_L2_BOUND, ZERO,
 };
 use crate::utils::string::ToString;
+use core::cell::OnceCell;

 // FALCON SIGNATURE
 // ================================================================================================
@@ -43,6 +44,10 @@ use crate::utils::string::ToString;
 pub struct Signature {
     pub(super) pk: PublicKeyBytes,
     pub(super) sig: SignatureBytes,
+
+    // Cached polynomial decoding for public key and signatures
+    pub(super) pk_polynomial: OnceCell<Polynomial>,
+    pub(super) sig_polynomial: OnceCell<Polynomial>,
 }

 impl Signature {
@@ -51,10 +56,11 @@ impl Signature {

     /// Returns the public key polynomial h.
     pub fn pub_key_poly(&self) -> Polynomial {
-        // TODO: memoize
-        // we assume that the signature was constructed with a valid public key, and thus
-        // expect() is OK here.
-        Polynomial::from_pub_key(&self.pk).expect("invalid public key")
+        *self.pk_polynomial.get_or_init(|| {
+            // we assume that the signature was constructed with a valid public key, and thus
+            // expect() is OK here.
+            Polynomial::from_pub_key(&self.pk).expect("invalid public key")
+        })
     }

     /// Returns the nonce component of the signature represented as field elements.
@@ -70,10 +76,11 @@ impl Signature {

     // Returns the polynomial representation of the signature in Z_p[x]/(phi).
     pub fn sig_poly(&self) -> Polynomial {
-        // TODO: memoize
-        // we assume that the signature was constructed with a valid signature, and thus
-        // expect() is OK here.
-        Polynomial::from_signature(&self.sig).expect("invalid signature")
+        *self.sig_polynomial.get_or_init(|| {
+            // we assume that the signature was constructed with a valid signature, and thus
+            // expect() is OK here.
+            Polynomial::from_signature(&self.sig).expect("invalid signature")
+        })
     }

     // HASH-TO-POINT
@@ -123,12 +130,14 @@ impl Deserializable for Signature {
         let sig: SignatureBytes = source.read_array()?;

         // make sure public key and signature can be decoded correctly
-        Polynomial::from_pub_key(&pk)
-            .map_err(|err| DeserializationError::InvalidValue(err.to_string()))?;
-        Polynomial::from_signature(&sig[41..])
-            .map_err(|err| DeserializationError::InvalidValue(err.to_string()))?;
+        let pk_polynomial = Polynomial::from_pub_key(&pk)
+            .map_err(|err| DeserializationError::InvalidValue(err.to_string()))?
+            .into();
+        let sig_polynomial = Polynomial::from_signature(&sig[41..])
+            .map_err(|err| DeserializationError::InvalidValue(err.to_string()))?
+            .into();

-        Ok(Self { pk, sig })
+        Ok(Self { pk, sig, pk_polynomial, sig_polynomial })
     }
 }
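The change above threads `OnceCell` through `Signature`: `Default::default()` yields an empty cell at signing time, and the polynomial decoding in `pub_key_poly`/`sig_poly` runs at most once. A self-contained sketch of the same caching pattern, using only the standard-library `OnceCell` (the expensive computation is a stand-in):

```rust
use core::cell::OnceCell;

struct Cached {
    value: OnceCell<u64>,
}

impl Cached {
    fn get(&self) -> u64 {
        // `get_or_init` runs the closure on the first call only; later calls
        // return the cached value. This mirrors `pk_polynomial.get_or_init(..)`.
        *self.value.get_or_init(|| (0..1_000u64).sum())
    }
}

fn main() {
    let c = Cached { value: OnceCell::new() }; // same as Default::default()
    assert_eq!(c.get(), c.get()); // the second call hits the cache
}
```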
src/gkr/circuit/mod.rs (new file, 977 lines)
@@ -0,0 +1,977 @@
use alloc::sync::Arc;

use winter_crypto::{ElementHasher, RandomCoin};
use winter_math::fields::f64::BaseElement;
use winter_math::FieldElement;

use crate::gkr::multivariate::{
    ComposedMultiLinearsOracle, EqPolynomial, GkrCompositionVanilla, MultiLinearOracle,
};
use crate::gkr::sumcheck::{sum_check_verify, Claim};

use super::multivariate::{
    gen_plain_gkr_oracle, gkr_composition_from_composition_polys, ComposedMultiLinears,
    CompositionPolynomial, MultiLinear,
};
use super::sumcheck::{
    sum_check_prove, sum_check_verify_and_reduce, FinalEvaluationClaim,
    PartialProof as SumcheckInstanceProof, RoundProof as SumCheckRoundProof, Witness,
};
/// Layered circuit for computing a sum of fractions.
///
/// The circuit computes a sum of fractions based on the formula
/// a / c + b / d = (a * d + b * c) / (c * d), which defines a "gate"
/// ((a, b), (c, d)) --> (a * d + b * c, c * d) upon which the `FractionalSumCircuit` is built.
/// TODO: Swap 1 and 0
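/// For example, the gate with inputs ((1, 3), (2, 4)) outputs
/// (1 * 4 + 3 * 2, 2 * 4) = (10, 8), matching 1 / 2 + 3 / 4 = 10 / 8.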
#[derive(Debug)]
pub struct FractionalSumCircuit<E: FieldElement> {
    p_1_vec: Vec<MultiLinear<E>>,
    p_0_vec: Vec<MultiLinear<E>>,
    q_1_vec: Vec<MultiLinear<E>>,
    q_0_vec: Vec<MultiLinear<E>>,
}
impl<E: FieldElement> FractionalSumCircuit<E> {
|
||||||
|
/// Computes The values of the gates outputs for each of the layers of the fractional sum circuit.
|
||||||
|
pub fn new_(num_den: &Vec<MultiLinear<E>>) -> Self {
|
||||||
|
let mut p_1_vec: Vec<MultiLinear<E>> = Vec::new();
|
||||||
|
let mut p_0_vec: Vec<MultiLinear<E>> = Vec::new();
|
||||||
|
let mut q_1_vec: Vec<MultiLinear<E>> = Vec::new();
|
||||||
|
let mut q_0_vec: Vec<MultiLinear<E>> = Vec::new();
|
||||||
|
|
||||||
|
let num_layers = num_den[0].len().ilog2() as usize;
|
||||||
|
|
||||||
|
p_1_vec.push(num_den[0].to_owned());
|
||||||
|
p_0_vec.push(num_den[1].to_owned());
|
||||||
|
q_1_vec.push(num_den[2].to_owned());
|
||||||
|
q_0_vec.push(num_den[3].to_owned());
|
||||||
|
|
||||||
|
for i in 0..num_layers {
|
||||||
|
let (output_p_1, output_p_0, output_q_1, output_q_0) =
|
||||||
|
FractionalSumCircuit::compute_layer(
|
||||||
|
&p_1_vec[i],
|
||||||
|
&p_0_vec[i],
|
||||||
|
&q_1_vec[i],
|
||||||
|
&q_0_vec[i],
|
||||||
|
);
|
||||||
|
p_1_vec.push(output_p_1);
|
||||||
|
p_0_vec.push(output_p_0);
|
||||||
|
q_1_vec.push(output_q_1);
|
||||||
|
q_0_vec.push(output_q_0);
|
||||||
|
}
|
||||||
|
|
||||||
|
FractionalSumCircuit { p_1_vec, p_0_vec, q_1_vec, q_0_vec }
|
||||||
|
}
|
||||||
|
|
||||||
|
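    // In `compute_layer` below, gate i combines the numerator pair (p_1[i], p_0[i]) with the
    // denominator pair (q_1[i], q_0[i]): the output numerator is
    // p_1[i] * q_0[i] + p_0[i] * q_1[i] and the output denominator is q_1[i] * q_0[i].
    // The first half of the gate outputs becomes the next layer's `_1` vectors and the
    // second half its `_0` vectors.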
    /// Computes the output values of the layer given a set of input values.
    fn compute_layer(
        inp_p_1: &MultiLinear<E>,
        inp_p_0: &MultiLinear<E>,
        inp_q_1: &MultiLinear<E>,
        inp_q_0: &MultiLinear<E>,
    ) -> (MultiLinear<E>, MultiLinear<E>, MultiLinear<E>, MultiLinear<E>) {
        let len = inp_q_1.len();
        let outp_p_1 = (0..len / 2)
            .map(|i| inp_p_1[i] * inp_q_0[i] + inp_p_0[i] * inp_q_1[i])
            .collect::<Vec<E>>();
        let outp_p_0 = (len / 2..len)
            .map(|i| inp_p_1[i] * inp_q_0[i] + inp_p_0[i] * inp_q_1[i])
            .collect::<Vec<E>>();
        let outp_q_1 = (0..len / 2).map(|i| inp_q_1[i] * inp_q_0[i]).collect::<Vec<E>>();
        let outp_q_0 = (len / 2..len).map(|i| inp_q_1[i] * inp_q_0[i]).collect::<Vec<E>>();

        (
            MultiLinear::new(outp_p_1),
            MultiLinear::new(outp_p_0),
            MultiLinear::new(outp_q_1),
            MultiLinear::new(outp_q_0),
        )
    }
    /// Computes the values of the gate outputs for each layer of the fractional sum circuit.
    pub fn new(poly: &MultiLinear<E>) -> Self {
        let mut p_1_vec: Vec<MultiLinear<E>> = Vec::new();
        let mut p_0_vec: Vec<MultiLinear<E>> = Vec::new();
        let mut q_1_vec: Vec<MultiLinear<E>> = Vec::new();
        let mut q_0_vec: Vec<MultiLinear<E>> = Vec::new();

        let num_layers = poly.len().ilog2() as usize - 1;
        let (output_p, output_q) = poly.split(poly.len() / 2);
        let (output_p_1, output_p_0) = output_p.split(output_p.len() / 2);
        let (output_q_1, output_q_0) = output_q.split(output_q.len() / 2);

        p_1_vec.push(output_p_1);
        p_0_vec.push(output_p_0);
        q_1_vec.push(output_q_1);
        q_0_vec.push(output_q_0);

        for i in 0..num_layers - 1 {
            let (output_p_1, output_p_0, output_q_1, output_q_0) =
                FractionalSumCircuit::compute_layer(
                    &p_1_vec[i],
                    &p_0_vec[i],
                    &q_1_vec[i],
                    &q_0_vec[i],
                );
            p_1_vec.push(output_p_1);
            p_0_vec.push(output_p_0);
            q_1_vec.push(output_q_1);
            q_0_vec.push(output_q_0);
        }

        FractionalSumCircuit { p_1_vec, p_0_vec, q_1_vec, q_0_vec }
    }
    /// Given a value r, computes the evaluation of the last layer at r when interpreted as (two)
    /// multilinear polynomials.
    pub fn evaluate(&self, r: E) -> (E, E) {
        let len = self.p_1_vec.len();
        assert_eq!(self.p_1_vec[len - 1].num_variables(), 0);
        assert_eq!(self.p_0_vec[len - 1].num_variables(), 0);
        assert_eq!(self.q_1_vec[len - 1].num_variables(), 0);
        assert_eq!(self.q_0_vec[len - 1].num_variables(), 0);

        let mut p_1 = self.p_1_vec[len - 1].clone();
        p_1.extend(&self.p_0_vec[len - 1]);
        let mut q_1 = self.q_1_vec[len - 1].clone();
        q_1.extend(&self.q_0_vec[len - 1]);

        (p_1.evaluate(&[r]), q_1.evaluate(&[r]))
    }
}
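// Note: for the two-entry multilinears built in `evaluate` above, `evaluate(&[r])` is the
// 2-to-1 fold p_1 + r * (p_0 - p_1); this is the same folding that the layer proofs below
// apply with the per-layer challenge `r_layer`.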
/// A proof for reducing a claim on the correctness of the output of a layer to that of:
///
/// 1. Correctness of a sumcheck proof on the claimed output.
/// 2. Correctness of the evaluation of the input (to the said layer) at a random point when
/// interpreted as a multilinear polynomial.
///
/// The verifier will then have to work backward and:
///
/// 1. Verify that the sumcheck proof is valid.
/// 2. Recurse on the claimed evaluations using the same approach as above.
///
/// Note that the following struct batches proofs for many circuits of the same type that
/// are independent, i.e., parallel.
#[derive(Debug)]
pub struct LayerProof<E: FieldElement> {
    pub proof: SumcheckInstanceProof<E>,
    pub claims_sum_p1: E,
    pub claims_sum_p0: E,
    pub claims_sum_q1: E,
    pub claims_sum_q0: E,
}
#[allow(dead_code)]
impl<E: FieldElement<BaseField = BaseElement> + 'static> LayerProof<E> {
    /// Checks the validity of a `LayerProof`.
    ///
    /// It first reduces the 2 claims to 1 claim using randomness and then checks that the sumcheck
    /// protocol was correctly executed.
    ///
    /// The method outputs:
    ///
    /// 1. A vector containing the randomness sent by the verifier throughout the course of the
    /// sum-check protocol.
    /// 2. The (claimed) evaluation of the inner polynomial (i.e., the one being summed) at this
    /// random vector.
    /// 3. The random value used in the 2-to-1 reduction of the 2 sumchecks.
    pub fn verify_sum_check_before_last<
        C: RandomCoin<Hasher = H, BaseField = BaseElement>,
        H: ElementHasher<BaseField = BaseElement>,
    >(
        &self,
        claim: (E, E),
        num_rounds: usize,
        transcript: &mut C,
    ) -> ((E, Vec<E>), E) {
        // Absorb the claims
        let data = vec![claim.0, claim.1];
        transcript.reseed(H::hash_elements(&data));

        // Squeeze challenge to reduce two sumchecks to one
        let r_sum_check: E = transcript.draw().unwrap();

        // Run the sumcheck protocol

        // Given r_sum_check and claim, we create a Claim with the GKR composer and then call the
        // generic sum-check verifier
        let reduced_claim = claim.0 + claim.1 * r_sum_check;

        // Create vanilla oracle
        let oracle = gen_plain_gkr_oracle(num_rounds, r_sum_check);

        // Create sum-check claim
        let transformed_claim = Claim {
            sum_value: reduced_claim,
            polynomial: oracle,
        };

        let reduced_gkr_claim =
            sum_check_verify_and_reduce(&transformed_claim, self.proof.clone(), transcript);

        (reduced_gkr_claim, r_sum_check)
    }
}
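// The 2-to-1 reduction above relies on a random linear combination: if both claimed sums are
// correct then so is claim.0 + r * claim.1 for every r, while an incorrect pair of claims
// passes for at most one value of the random challenge r (a Schwartz-Zippel style argument
// over the field).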
#[derive(Debug)]
pub struct GkrClaim<E: FieldElement + 'static> {
    evaluation_point: Vec<E>,
    claimed_evaluation: (E, E),
}

#[derive(Debug)]
pub struct CircuitProof<E: FieldElement + 'static> {
    pub proof: Vec<LayerProof<E>>,
}
impl<E: FieldElement<BaseField = BaseElement> + 'static> CircuitProof<E> {
|
||||||
|
pub fn prove<
|
||||||
|
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||||
|
H: ElementHasher<BaseField = BaseElement>,
|
||||||
|
>(
|
||||||
|
circuit: &mut FractionalSumCircuit<E>,
|
||||||
|
transcript: &mut C,
|
||||||
|
) -> (Self, Vec<E>, Vec<Vec<E>>) {
|
||||||
|
let mut proof_layers: Vec<LayerProof<E>> = Vec::new();
|
||||||
|
let num_layers = circuit.p_0_vec.len();
|
||||||
|
|
||||||
|
let data = vec![
|
||||||
|
circuit.p_1_vec[num_layers - 1][0],
|
||||||
|
circuit.p_0_vec[num_layers - 1][0],
|
||||||
|
circuit.q_1_vec[num_layers - 1][0],
|
||||||
|
circuit.q_0_vec[num_layers - 1][0],
|
||||||
|
];
|
||||||
|
transcript.reseed(H::hash_elements(&data));
|
||||||
|
|
||||||
|
// Challenge to reduce p1, p0, q1, q0 to pr, qr
|
||||||
|
let r_cord = transcript.draw().unwrap();
|
||||||
|
|
||||||
|
// Compute the (2-to-1 folded) claim
|
||||||
|
let mut claim = circuit.evaluate(r_cord);
|
||||||
|
let mut all_rand = Vec::new();
|
||||||
|
|
||||||
|
let mut rand = Vec::new();
|
||||||
|
rand.push(r_cord);
|
||||||
|
for layer_id in (0..num_layers - 1).rev() {
|
||||||
|
let len = circuit.p_0_vec[layer_id].len();
|
||||||
|
|
||||||
|
// Construct the Lagrange kernel evaluated at previous GKR round randomness.
|
||||||
|
// TODO: Treat the direction of doing sum-check more robustly.
|
||||||
|
let mut rand_reversed = rand.clone();
|
||||||
|
rand_reversed.reverse();
|
||||||
|
let eq_evals = EqPolynomial::new(rand_reversed.clone()).evaluations();
|
||||||
|
let mut poly_x = MultiLinear::from_values(&eq_evals);
|
||||||
|
assert_eq!(poly_x.len(), len);
|
||||||
|
|
||||||
|
let num_rounds = poly_x.len().ilog2() as usize;
|
||||||
|
|
||||||
|
// 1. A is a polynomial containing the evaluations `p_1`.
|
||||||
|
// 2. B is a polynomial containing the evaluations `p_0`.
|
||||||
|
// 3. C is a polynomial containing the evaluations `q_1`.
|
||||||
|
// 4. D is a polynomial containing the evaluations `q_0`.
|
||||||
|
let poly_a: &mut MultiLinear<E>;
|
||||||
|
let poly_b: &mut MultiLinear<E>;
|
||||||
|
let poly_c: &mut MultiLinear<E>;
|
||||||
|
let poly_d: &mut MultiLinear<E>;
|
||||||
|
poly_a = &mut circuit.p_1_vec[layer_id];
|
||||||
|
poly_b = &mut circuit.p_0_vec[layer_id];
|
||||||
|
poly_c = &mut circuit.q_1_vec[layer_id];
|
||||||
|
poly_d = &mut circuit.q_0_vec[layer_id];
|
||||||
|
|
||||||
|
let poly_vec_par = (poly_a, poly_b, poly_c, poly_d, &mut poly_x);
|
||||||
|
|
||||||
|
// The (non-linear) polynomial combining the multilinear polynomials
|
||||||
|
let comb_func = |a: &E, b: &E, c: &E, d: &E, x: &E, rho: &E| -> E {
|
||||||
|
(*a * *d + *b * *c + *rho * *c * *d) * *x
|
||||||
|
};
|
||||||
|
|
||||||
|
// Run the sumcheck protocol
|
||||||
|
let (proof, rand_sumcheck, claims_sum) = sum_check_prover_gkr_before_last::<E, _, _>(
|
||||||
|
claim,
|
||||||
|
num_rounds,
|
||||||
|
poly_vec_par,
|
||||||
|
comb_func,
|
||||||
|
transcript,
|
||||||
|
);
|
||||||
|
|
||||||
|
let (claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0, _claims_eq) =
|
||||||
|
claims_sum;
|
||||||
|
|
||||||
|
let data = vec![claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0];
|
||||||
|
transcript.reseed(H::hash_elements(&data));
|
||||||
|
|
||||||
|
// Produce a random challenge to condense claims into a single claim
|
||||||
|
let r_layer = transcript.draw().unwrap();
|
||||||
|
|
||||||
|
claim = (
|
||||||
|
claims_sum_p1 + r_layer * (claims_sum_p0 - claims_sum_p1),
|
||||||
|
claims_sum_q1 + r_layer * (claims_sum_q0 - claims_sum_q1),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Collect the randomness used for the current layer in order to construct the random
|
||||||
|
// point where the input multilinear polynomials were evaluated.
|
||||||
|
let mut ext = rand_sumcheck;
|
||||||
|
ext.push(r_layer);
|
||||||
|
all_rand.push(rand);
|
||||||
|
rand = ext;
|
||||||
|
|
||||||
|
proof_layers.push(LayerProof {
|
||||||
|
proof,
|
||||||
|
claims_sum_p1,
|
||||||
|
claims_sum_p0,
|
||||||
|
claims_sum_q1,
|
||||||
|
claims_sum_q0,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
(CircuitProof { proof: proof_layers }, rand, all_rand)
|
||||||
|
}
|
||||||
|
|
||||||
|
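    // The `composition_polys` argument below follows the same convention as
    // `FractionalSumCircuit::new_`: index 0 holds the left numerators, index 1 the right
    // numerators, index 2 the left denominators, and index 3 the right denominators; this
    // matches the `left_num`/`right_num`/`left_denom`/`right_denom` oracle ids used further
    // down.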
pub fn prove_virtual_bus<
|
||||||
|
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||||
|
H: ElementHasher<BaseField = BaseElement>,
|
||||||
|
>(
|
||||||
|
composition_polys: Vec<Vec<Arc<dyn CompositionPolynomial<E>>>>,
|
||||||
|
mls: &mut Vec<MultiLinear<E>>,
|
||||||
|
transcript: &mut C,
|
||||||
|
) -> (Vec<E>, Self, super::sumcheck::FullProof<E>) {
|
||||||
|
let num_evaluations = 1 << mls[0].num_variables();
|
||||||
|
|
||||||
|
// I) Evaluate the numerators and denominators over the boolean hyper-cube
|
||||||
|
let mut num_den: Vec<Vec<E>> = vec![vec![]; 4];
|
||||||
|
for i in 0..num_evaluations {
|
||||||
|
for j in 0..4 {
|
||||||
|
let query: Vec<E> = mls.iter().map(|ml| ml[i]).collect();
|
||||||
|
|
||||||
|
composition_polys[j].iter().for_each(|c| {
|
||||||
|
let evaluation = c.as_ref().evaluate(&query);
|
||||||
|
num_den[j].push(evaluation);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// II) Evaluate the GKR fractional sum circuit
|
||||||
|
let input: Vec<MultiLinear<E>> =
|
||||||
|
(0..4).map(|i| MultiLinear::from_values(&num_den[i])).collect();
|
||||||
|
let mut circuit = FractionalSumCircuit::new_(&input);
|
||||||
|
|
||||||
|
// III) Run the GKR prover for all layers except the last one
|
||||||
|
let (gkr_proofs, GkrClaim { evaluation_point, claimed_evaluation }) =
|
||||||
|
CircuitProof::prove_before_final(&mut circuit, transcript);
|
||||||
|
|
||||||
|
// IV) Run the sum-check prover for the last GKR layer counting backwards i.e., first layer
|
||||||
|
// in the circuit.
|
||||||
|
|
||||||
|
// 1) Build the EQ polynomial (Lagrange kernel) at the randomness sampled during the previous
|
||||||
|
// sum-check protocol run
|
||||||
|
let mut rand_reversed = evaluation_point.clone();
|
||||||
|
rand_reversed.reverse();
|
||||||
|
let eq_evals = EqPolynomial::new(rand_reversed.clone()).evaluations();
|
||||||
|
let poly_x = MultiLinear::from_values(&eq_evals);
|
||||||
|
|
||||||
|
// 2) Add the Lagrange kernel to the list of MLs
|
||||||
|
mls.push(poly_x);
|
||||||
|
|
||||||
|
// 3) Absorb the final sum-check claims and generate randomness for 2-to-1 sum-check reduction
|
||||||
|
let data = vec![claimed_evaluation.0, claimed_evaluation.1];
|
||||||
|
transcript.reseed(H::hash_elements(&data));
|
||||||
|
// Squeeze challenge to reduce two sumchecks to one
|
||||||
|
let r_sum_check = transcript.draw().unwrap();
|
||||||
|
let reduced_claim = claimed_evaluation.0 + claimed_evaluation.1 * r_sum_check;
|
||||||
|
|
||||||
|
// 4) Create the composed ML representing the numerators and denominators of the topmost GKR layer
|
||||||
|
let gkr_final_composed_ml = gkr_composition_from_composition_polys(
|
||||||
|
&composition_polys,
|
||||||
|
r_sum_check,
|
||||||
|
1 << mls[0].num_variables,
|
||||||
|
);
|
||||||
|
let composed_ml =
|
||||||
|
ComposedMultiLinears::new(Arc::new(gkr_final_composed_ml.clone()), mls.to_vec());
|
||||||
|
|
||||||
|
// 5) Create the composed ML oracle. This will be used for verifying the FinalEvaluationClaim downstream
|
||||||
|
// TODO: This should be an input to the current function.
|
||||||
|
// TODO: Make MultiLinearOracle a variant in an enum so that it is possible to capture other types of oracles.
|
||||||
|
// For example, shifts of polynomials, Lagrange kernels at a random point or periodic (transparent) polynomials.
|
||||||
|
let left_num_oracle = MultiLinearOracle { id: 0 };
|
||||||
|
let right_num_oracle = MultiLinearOracle { id: 1 };
|
||||||
|
let left_denom_oracle = MultiLinearOracle { id: 2 };
|
||||||
|
let right_denom_oracle = MultiLinearOracle { id: 3 };
|
||||||
|
let eq_oracle = MultiLinearOracle { id: 4 };
|
||||||
|
let composed_ml_oracle = ComposedMultiLinearsOracle {
|
||||||
|
composer: (Arc::new(gkr_final_composed_ml.clone())),
|
||||||
|
multi_linears: vec![
|
||||||
|
eq_oracle,
|
||||||
|
left_num_oracle,
|
||||||
|
right_num_oracle,
|
||||||
|
left_denom_oracle,
|
||||||
|
right_denom_oracle,
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
// 6) Create the claim for the final sum-check protocol.
|
||||||
|
let claim = Claim {
|
||||||
|
sum_value: reduced_claim,
|
||||||
|
polynomial: composed_ml_oracle.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// 7) Create the witness for the sum-check claim.
|
||||||
|
let witness = Witness { polynomial: composed_ml };
|
||||||
|
let output = sum_check_prove(&claim, composed_ml_oracle, witness, transcript);
|
||||||
|
|
||||||
|
// 8) Create the claimed output of the circuit.
|
||||||
|
let circuit_outputs = vec![
|
||||||
|
circuit.p_1_vec.last().unwrap()[0],
|
||||||
|
circuit.p_0_vec.last().unwrap()[0],
|
||||||
|
circuit.q_1_vec.last().unwrap()[0],
|
||||||
|
circuit.q_0_vec.last().unwrap()[0],
|
||||||
|
];
|
||||||
|
|
||||||
|
// 9) Return:
|
||||||
|
// 1. The claimed circuit outputs.
|
||||||
|
// 2. GKR proofs of all circuit layers except the initial layer.
|
||||||
|
// 3. Output of the final sum-check protocol.
|
||||||
|
(circuit_outputs, gkr_proofs, output)
|
||||||
|
}
|
||||||
|
|
||||||
|
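    // Note: unlike `prove` above, `prove_before_final` below stops its layer loop at
    // `layer_id == 1`; the first circuit layer is instead handled by the final sum-check run
    // in `prove_virtual_bus`, which works directly with the composition polynomials.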
pub fn prove_before_final<
|
||||||
|
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||||
|
H: ElementHasher<BaseField = BaseElement>,
|
||||||
|
>(
|
||||||
|
sum_circuits: &mut FractionalSumCircuit<E>,
|
||||||
|
transcript: &mut C,
|
||||||
|
) -> (Self, GkrClaim<E>) {
|
||||||
|
let mut proof_layers: Vec<LayerProof<E>> = Vec::new();
|
||||||
|
let num_layers = sum_circuits.p_0_vec.len();
|
||||||
|
|
||||||
|
let data = vec![
|
||||||
|
sum_circuits.p_1_vec[num_layers - 1][0],
|
||||||
|
sum_circuits.p_0_vec[num_layers - 1][0],
|
||||||
|
sum_circuits.q_1_vec[num_layers - 1][0],
|
||||||
|
sum_circuits.q_0_vec[num_layers - 1][0],
|
||||||
|
];
|
||||||
|
transcript.reseed(H::hash_elements(&data));
|
||||||
|
|
||||||
|
// Challenge to reduce p1, p0, q1, q0 to pr, qr
|
||||||
|
let r_cord = transcript.draw().unwrap();
|
||||||
|
|
||||||
|
// Compute the (2-to-1 folded) claim
|
||||||
|
let mut claims_to_verify = sum_circuits.evaluate(r_cord);
|
||||||
|
let mut all_rand = Vec::new();
|
||||||
|
|
||||||
|
let mut rand = Vec::new();
|
||||||
|
rand.push(r_cord);
|
||||||
|
for layer_id in (1..num_layers - 1).rev() {
|
||||||
|
let len = sum_circuits.p_0_vec[layer_id].len();
|
||||||
|
|
||||||
|
// Construct the Lagrange kernel evaluated at previous GKR round randomness.
|
||||||
|
// TODO: Treat the direction of doing sum-check more robustly.
|
||||||
|
let mut rand_reversed = rand.clone();
|
||||||
|
rand_reversed.reverse();
|
||||||
|
let eq_evals = EqPolynomial::new(rand_reversed.clone()).evaluations();
|
||||||
|
let mut poly_x = MultiLinear::from_values(&eq_evals);
|
||||||
|
assert_eq!(poly_x.len(), len);
|
||||||
|
|
||||||
|
let num_rounds = poly_x.len().ilog2() as usize;
|
||||||
|
|
||||||
|
// 1. A is a polynomial containing the evaluations `p_1`.
|
||||||
|
// 2. B is a polynomial containing the evaluations `p_0`.
|
||||||
|
// 3. C is a polynomial containing the evaluations `q_1`.
|
||||||
|
// 4. D is a polynomial containing the evaluations `q_0`.
|
||||||
|
let poly_a: &mut MultiLinear<E>;
|
||||||
|
let poly_b: &mut MultiLinear<E>;
|
||||||
|
let poly_c: &mut MultiLinear<E>;
|
||||||
|
let poly_d: &mut MultiLinear<E>;
|
||||||
|
poly_a = &mut sum_circuits.p_1_vec[layer_id];
|
||||||
|
poly_b = &mut sum_circuits.p_0_vec[layer_id];
|
||||||
|
poly_c = &mut sum_circuits.q_1_vec[layer_id];
|
||||||
|
poly_d = &mut sum_circuits.q_0_vec[layer_id];
|
||||||
|
|
||||||
|
let poly_vec = (poly_a, poly_b, poly_c, poly_d, &mut poly_x);
|
||||||
|
|
||||||
|
let claim = claims_to_verify;
|
||||||
|
|
||||||
|
// The (non-linear) polynomial combining the multilinear polynomials
|
||||||
|
let comb_func = |a: &E, b: &E, c: &E, d: &E, x: &E, rho: &E| -> E {
|
||||||
|
(*a * *d + *b * *c + *rho * *c * *d) * *x
|
||||||
|
};
|
||||||
|
|
||||||
|
// Run the sumcheck protocol
|
||||||
|
let (proof, rand_sumcheck, claims_sum) = sum_check_prover_gkr_before_last::<E, _, _>(
|
||||||
|
claim, num_rounds, poly_vec, comb_func, transcript,
|
||||||
|
);
|
||||||
|
|
||||||
|
let (claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0, _claims_eq) =
|
||||||
|
claims_sum;
|
||||||
|
|
||||||
|
let data = vec![claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0];
|
||||||
|
transcript.reseed(H::hash_elements(&data));
|
||||||
|
|
||||||
|
// Produce a random challenge to condense claims into a single claim
|
||||||
|
let r_layer = transcript.draw().unwrap();
|
||||||
|
|
||||||
|
claims_to_verify = (
|
||||||
|
claims_sum_p1 + r_layer * (claims_sum_p0 - claims_sum_p1),
|
||||||
|
claims_sum_q1 + r_layer * (claims_sum_q0 - claims_sum_q1),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Collect the randomness used for the current layer in order to construct the random
|
||||||
|
// point where the input multilinear polynomials were evaluated.
|
||||||
|
let mut ext = rand_sumcheck;
|
||||||
|
ext.push(r_layer);
|
||||||
|
all_rand.push(rand);
|
||||||
|
rand = ext;
|
||||||
|
|
||||||
|
proof_layers.push(LayerProof {
|
||||||
|
proof,
|
||||||
|
claims_sum_p1,
|
||||||
|
claims_sum_p0,
|
||||||
|
claims_sum_q1,
|
||||||
|
claims_sum_q0,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let gkr_claim = GkrClaim {
|
||||||
|
evaluation_point: rand.clone(),
|
||||||
|
claimed_evaluation: claims_to_verify,
|
||||||
|
};
|
||||||
|
|
||||||
|
(CircuitProof { proof: proof_layers }, gkr_claim)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn verify<
|
||||||
|
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||||
|
H: ElementHasher<BaseField = BaseElement>,
|
||||||
|
>(
|
||||||
|
&self,
|
||||||
|
claims_sum_vec: &[E],
|
||||||
|
transcript: &mut C,
|
||||||
|
) -> ((E, E), Vec<E>) {
|
||||||
|
let num_layers = self.proof.len() as usize - 1;
|
||||||
|
let mut rand: Vec<E> = Vec::new();
|
||||||
|
|
||||||
|
let data = claims_sum_vec;
|
||||||
|
transcript.reseed(H::hash_elements(&data));
|
||||||
|
|
||||||
|
let r_cord = transcript.draw().unwrap();
|
||||||
|
|
||||||
|
let p_poly_coef = vec![claims_sum_vec[0], claims_sum_vec[1]];
|
||||||
|
let q_poly_coef = vec![claims_sum_vec[2], claims_sum_vec[3]];
|
||||||
|
|
||||||
|
let p_poly = MultiLinear::new(p_poly_coef);
|
||||||
|
let q_poly = MultiLinear::new(q_poly_coef);
|
||||||
|
let p_eval = p_poly.evaluate(&[r_cord]);
|
||||||
|
let q_eval = q_poly.evaluate(&[r_cord]);
|
||||||
|
|
||||||
|
let mut reduced_claim = (p_eval, q_eval);
|
||||||
|
|
||||||
|
rand.push(r_cord);
|
||||||
|
for (num_rounds, i) in (0..num_layers).enumerate() {
|
||||||
|
let ((claim_last, rand_sumcheck), r_two_sumchecks) = self.proof[i]
|
||||||
|
.verify_sum_check_before_last::<_, _>(reduced_claim, num_rounds + 1, transcript);
|
||||||
|
|
||||||
|
let claims_sum_p1 = &self.proof[i].claims_sum_p1;
|
||||||
|
let claims_sum_p0 = &self.proof[i].claims_sum_p0;
|
||||||
|
let claims_sum_q1 = &self.proof[i].claims_sum_q1;
|
||||||
|
let claims_sum_q0 = &self.proof[i].claims_sum_q0;
|
||||||
|
|
||||||
|
let data = vec![
|
||||||
|
claims_sum_p1.clone(),
|
||||||
|
claims_sum_p0.clone(),
|
||||||
|
claims_sum_q1.clone(),
|
||||||
|
claims_sum_q0.clone(),
|
||||||
|
];
|
||||||
|
transcript.reseed(H::hash_elements(&data));
|
||||||
|
|
||||||
|
assert_eq!(rand.len(), rand_sumcheck.len());
|
||||||
|
|
||||||
|
let eq: E = (0..rand.len())
|
||||||
|
.map(|i| {
|
||||||
|
rand[i] * rand_sumcheck[i] + (E::ONE - rand[i]) * (E::ONE - rand_sumcheck[i])
|
||||||
|
})
|
||||||
|
.fold(E::ONE, |acc, term| acc * term);
|
||||||
|
|
||||||
|
let claim_expected: E = (*claims_sum_p1 * *claims_sum_q0
|
||||||
|
+ *claims_sum_p0 * *claims_sum_q1
|
||||||
|
+ r_two_sumchecks * *claims_sum_q1 * *claims_sum_q0)
|
||||||
|
* eq;
|
||||||
|
|
||||||
|
assert_eq!(claim_expected, claim_last);
|
||||||
|
|
||||||
|
// Produce a random challenge to condense claims into a single claim
|
||||||
|
let r_layer = transcript.draw().unwrap();
|
||||||
|
|
||||||
|
reduced_claim = (
|
||||||
|
*claims_sum_p1 + r_layer * (*claims_sum_p0 - *claims_sum_p1),
|
||||||
|
*claims_sum_q1 + r_layer * (*claims_sum_q0 - *claims_sum_q1),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Collect the randomness' used for the current layer in order to construct the random
|
||||||
|
// point where the input multilinear polynomials were evaluated.
|
||||||
|
let mut ext = rand_sumcheck;
|
||||||
|
ext.push(r_layer);
|
||||||
|
rand = ext;
|
||||||
|
}
|
||||||
|
(reduced_claim, rand)
|
||||||
|
}
|
||||||
|
|
||||||
|
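    // In both `verify` above and `verify_virtual_bus` below, the factor
    //   eq(r, s) = prod_i (r_i * s_i + (1 - r_i) * (1 - s_i))
    // is the Lagrange kernel over the boolean hypercube: it lets the verifier evaluate the
    // eq multilinear at the sum-check point on its own, without help from the prover.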
pub fn verify_virtual_bus<
|
||||||
|
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||||
|
H: ElementHasher<BaseField = BaseElement>,
|
||||||
|
>(
|
||||||
|
&self,
|
||||||
|
composition_polys: Vec<Vec<Arc<dyn CompositionPolynomial<E>>>>,
|
||||||
|
final_layer_proof: super::sumcheck::FullProof<E>,
|
||||||
|
claims_sum_vec: &[E],
|
||||||
|
transcript: &mut C,
|
||||||
|
) -> (FinalEvaluationClaim<E>, Vec<E>) {
|
||||||
|
let num_layers = self.proof.len() as usize;
|
||||||
|
let mut rand: Vec<E> = Vec::new();
|
||||||
|
|
||||||
|
// Check that a/b + d/e is equal to 0
|
||||||
|
assert_ne!(claims_sum_vec[2], E::ZERO);
|
||||||
|
assert_ne!(claims_sum_vec[3], E::ZERO);
|
||||||
|
assert_eq!(
|
||||||
|
claims_sum_vec[0] * claims_sum_vec[3] + claims_sum_vec[1] * claims_sum_vec[2],
|
||||||
|
E::ZERO
|
||||||
|
);
|
||||||
|
|
||||||
|
let data = claims_sum_vec;
|
||||||
|
transcript.reseed(H::hash_elements(&data));
|
||||||
|
|
||||||
|
let r_cord = transcript.draw().unwrap();
|
||||||
|
|
||||||
|
let p_poly_coef = vec![claims_sum_vec[0], claims_sum_vec[1]];
|
||||||
|
let q_poly_coef = vec![claims_sum_vec[2], claims_sum_vec[3]];
|
||||||
|
|
||||||
|
let p_poly = MultiLinear::new(p_poly_coef);
|
||||||
|
let q_poly = MultiLinear::new(q_poly_coef);
|
||||||
|
let p_eval = p_poly.evaluate(&[r_cord]);
|
||||||
|
let q_eval = q_poly.evaluate(&[r_cord]);
|
||||||
|
|
||||||
|
let mut reduced_claim = (p_eval, q_eval);
|
||||||
|
|
||||||
|
// I) Verify all GKR layers but for the last one counting backwards.
|
||||||
|
rand.push(r_cord);
|
||||||
|
for (num_rounds, i) in (0..num_layers).enumerate() {
|
||||||
|
let ((claim_last, rand_sumcheck), r_two_sumchecks) = self.proof[i]
|
||||||
|
.verify_sum_check_before_last::<_, _>(reduced_claim, num_rounds + 1, transcript);
|
||||||
|
|
||||||
|
let claims_sum_p1 = &self.proof[i].claims_sum_p1;
|
||||||
|
let claims_sum_p0 = &self.proof[i].claims_sum_p0;
|
||||||
|
let claims_sum_q1 = &self.proof[i].claims_sum_q1;
|
||||||
|
let claims_sum_q0 = &self.proof[i].claims_sum_q0;
|
||||||
|
|
||||||
|
let data = vec![
|
||||||
|
claims_sum_p1.clone(),
|
||||||
|
claims_sum_p0.clone(),
|
||||||
|
claims_sum_q1.clone(),
|
||||||
|
claims_sum_q0.clone(),
|
||||||
|
];
|
||||||
|
transcript.reseed(H::hash_elements(&data));
|
||||||
|
|
||||||
|
assert_eq!(rand.len(), rand_sumcheck.len());
|
||||||
|
|
||||||
|
let eq: E = (0..rand.len())
|
||||||
|
.map(|i| {
|
||||||
|
rand[i] * rand_sumcheck[i] + (E::ONE - rand[i]) * (E::ONE - rand_sumcheck[i])
|
||||||
|
})
|
||||||
|
.fold(E::ONE, |acc, term| acc * term);
|
||||||
|
|
||||||
|
let claim_expected: E = (*claims_sum_p1 * *claims_sum_q0
|
||||||
|
+ *claims_sum_p0 * *claims_sum_q1
|
||||||
|
+ r_two_sumchecks * *claims_sum_q1 * *claims_sum_q0)
|
||||||
|
* eq;
|
||||||
|
|
||||||
|
assert_eq!(claim_expected, claim_last);
|
||||||
|
|
||||||
|
// Produce a random challenge to condense claims into a single claim
|
||||||
|
let r_layer = transcript.draw().unwrap();
|
||||||
|
|
||||||
|
reduced_claim = (
|
||||||
|
*claims_sum_p1 + r_layer * (*claims_sum_p0 - *claims_sum_p1),
|
||||||
|
*claims_sum_q1 + r_layer * (*claims_sum_q0 - *claims_sum_q1),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut ext = rand_sumcheck;
|
||||||
|
ext.push(r_layer);
|
||||||
|
rand = ext;
|
||||||
|
}
|
||||||
|
|
||||||
|
// II) Verify the final GKR layer counting backwards.
|
||||||
|
|
||||||
|
// Absorb the claims
|
||||||
|
let data = vec![reduced_claim.0, reduced_claim.1];
|
||||||
|
transcript.reseed(H::hash_elements(&data));
|
||||||
|
|
||||||
|
// Squeeze challenge to reduce two sumchecks to one
|
||||||
|
let r_sum_check = transcript.draw().unwrap();
|
||||||
|
let reduced_claim = reduced_claim.0 + reduced_claim.1 * r_sum_check;
|
||||||
|
|
||||||
|
let gkr_final_composed_ml = gkr_composition_from_composition_polys(
|
||||||
|
&composition_polys,
|
||||||
|
r_sum_check,
|
||||||
|
1 << (num_layers + 1),
|
||||||
|
);
|
||||||
|
|
||||||
|
// TODO: refactor
|
||||||
|
let composed_ml_oracle = {
|
||||||
|
let left_num_oracle = MultiLinearOracle { id: 0 };
|
||||||
|
let right_num_oracle = MultiLinearOracle { id: 1 };
|
||||||
|
let left_denom_oracle = MultiLinearOracle { id: 2 };
|
||||||
|
let right_denom_oracle = MultiLinearOracle { id: 3 };
|
||||||
|
let eq_oracle = MultiLinearOracle { id: 4 };
|
||||||
|
ComposedMultiLinearsOracle {
|
||||||
|
composer: (Arc::new(gkr_final_composed_ml.clone())),
|
||||||
|
multi_linears: vec![
|
||||||
|
eq_oracle,
|
||||||
|
left_num_oracle,
|
||||||
|
right_num_oracle,
|
||||||
|
left_denom_oracle,
|
||||||
|
right_denom_oracle,
|
||||||
|
],
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let claim = Claim {
|
||||||
|
sum_value: reduced_claim,
|
||||||
|
polynomial: composed_ml_oracle.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let final_eval_claim = sum_check_verify(&claim, final_layer_proof, transcript);
|
||||||
|
|
||||||
|
(final_eval_claim, rand)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sum_check_prover_gkr_before_last<
    E: FieldElement<BaseField = BaseElement>,
    C: RandomCoin<Hasher = H, BaseField = BaseElement>,
    H: ElementHasher<BaseField = BaseElement>,
>(
    claim: (E, E),
    num_rounds: usize,
    ml_polys: (
        &mut MultiLinear<E>,
        &mut MultiLinear<E>,
        &mut MultiLinear<E>,
        &mut MultiLinear<E>,
        &mut MultiLinear<E>,
    ),
    comb_func: impl Fn(&E, &E, &E, &E, &E, &E) -> E,
    transcript: &mut C,
) -> (SumcheckInstanceProof<E>, Vec<E>, (E, E, E, E, E)) {
    // Absorb the claims
    let data = vec![claim.0, claim.1];
    transcript.reseed(H::hash_elements(&data));

    // Squeeze a challenge to reduce the two sum-checks to one
    let r_sum_check = transcript.draw().unwrap();

    let (poly_a, poly_b, poly_c, poly_d, poly_x) = ml_polys;

    let mut e = claim.0 + claim.1 * r_sum_check;

    let mut r: Vec<E> = Vec::new();
    let mut round_proofs: Vec<SumCheckRoundProof<E>> = Vec::new();

    for _j in 0..num_rounds {
        let evals: (E, E, E) = {
            let mut eval_point_0 = E::ZERO;
            let mut eval_point_2 = E::ZERO;
            let mut eval_point_3 = E::ZERO;

            let len = poly_a.len() / 2;
            for i in 0..len {
                // The interpolation formula for a linear function is:
                // A(z) = z * A(1) + (1 - z) * A(0)

                // eval at z = 0: A(0)
                eval_point_0 += comb_func(
                    &poly_a[i << 1],
                    &poly_b[i << 1],
                    &poly_c[i << 1],
                    &poly_d[i << 1],
                    &poly_x[i << 1],
                    &r_sum_check,
                );

                let poly_a_u = poly_a[(i << 1) + 1];
                let poly_a_v = poly_a[i << 1];
                let poly_b_u = poly_b[(i << 1) + 1];
                let poly_b_v = poly_b[i << 1];
                let poly_c_u = poly_c[(i << 1) + 1];
                let poly_c_v = poly_c[i << 1];
                let poly_d_u = poly_d[(i << 1) + 1];
                let poly_d_v = poly_d[i << 1];
                let poly_x_u = poly_x[(i << 1) + 1];
                let poly_x_v = poly_x[i << 1];

                // eval at z = 2: 2 * A(1) - A(0)
                let poly_a_extrapolated_point = poly_a_u + poly_a_u - poly_a_v;
                let poly_b_extrapolated_point = poly_b_u + poly_b_u - poly_b_v;
                let poly_c_extrapolated_point = poly_c_u + poly_c_u - poly_c_v;
                let poly_d_extrapolated_point = poly_d_u + poly_d_u - poly_d_v;
                let poly_x_extrapolated_point = poly_x_u + poly_x_u - poly_x_v;

                eval_point_2 += comb_func(
                    &poly_a_extrapolated_point,
                    &poly_b_extrapolated_point,
                    &poly_c_extrapolated_point,
                    &poly_d_extrapolated_point,
                    &poly_x_extrapolated_point,
                    &r_sum_check,
                );

                // eval at z = 3: 3 * A(1) - 2 * A(0) = (2 * A(1) - A(0)) + (A(1) - A(0));
                // hence we can compute the evaluation at z + 1 from that at z for z > 1
                let poly_a_extrapolated_point = poly_a_extrapolated_point + poly_a_u - poly_a_v;
                let poly_b_extrapolated_point = poly_b_extrapolated_point + poly_b_u - poly_b_v;
                let poly_c_extrapolated_point = poly_c_extrapolated_point + poly_c_u - poly_c_v;
                let poly_d_extrapolated_point = poly_d_extrapolated_point + poly_d_u - poly_d_v;
                let poly_x_extrapolated_point = poly_x_extrapolated_point + poly_x_u - poly_x_v;

                eval_point_3 += comb_func(
                    &poly_a_extrapolated_point,
                    &poly_b_extrapolated_point,
                    &poly_c_extrapolated_point,
                    &poly_d_extrapolated_point,
                    &poly_x_extrapolated_point,
                    &r_sum_check,
                );
            }

            (eval_point_0, eval_point_2, eval_point_3)
        };

        let eval_0 = evals.0;
        let eval_2 = evals.1;
        let eval_3 = evals.2;

        let evals = vec![e - eval_0, eval_2, eval_3];
        let compressed_poly = SumCheckRoundProof { poly_evals: evals };

        // append the prover's message to the transcript
        transcript.reseed(H::hash_elements(&compressed_poly.poly_evals));

        // derive the verifier's challenge for the next round
        let r_j = transcript.draw().unwrap();
        r.push(r_j);

        poly_a.bind_assign(r_j);
        poly_b.bind_assign(r_j);
        poly_c.bind_assign(r_j);
        poly_d.bind_assign(r_j);
        poly_x.bind_assign(r_j);

        e = compressed_poly.evaluate(e, r_j);

        round_proofs.push(compressed_poly);
    }
    let claims_sum = (poly_a[0], poly_b[0], poly_c[0], poly_d[0], poly_x[0]);

    (SumcheckInstanceProof { round_proofs }, r, claims_sum)
}
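// Illustrative sketch (not part of the original diff): the round polynomial above has degree 3,
// so the prover sends only s(1), s(2), s(3) and the verifier recovers s(0) = claim - s(1). The
// evaluations at z = 2 and z = 3 come from repeated addition of the slope A(1) - A(0), as in
// the extrapolation steps above. Plain integers stand in for field elements here.
#[cfg(test)]
#[test]
fn linear_extrapolation_sketch() {
    // A(0) and A(1) of a multilinear restricted to its last variable
    let (a0, a1): (i64, i64) = (5, 9);
    let slope = a1 - a0;
    let a2 = a1 + slope; // eval at z = 2: 2 * A(1) - A(0)
    let a3 = a2 + slope; // eval at z = 3: 3 * A(1) - 2 * A(0)
    assert_eq!(a2, 2 * a1 - a0);
    assert_eq!(a3, 3 * a1 - 2 * a0);
}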
#[cfg(test)]
mod sum_circuit_tests {
    use crate::rand::RpoRandomCoin;

    use super::*;
    use rand::Rng;
    use rand_utils::rand_value;
    use BaseElement as Felt;

    /// Tests the fractional sum circuit by checking that \sum_{i = 0}^{log(m) - 1} m / 2^i = 2 * (m - 1).
    #[test]
    fn sum_circuit_example() {
        let n = 4; // n := log(m)
        let mut inp: Vec<Felt> = (0..n).map(|_| Felt::from(1_u64 << n)).collect();
        let inp_: Vec<Felt> = (0..n).map(|i| Felt::from(1_u64 << i)).collect();
        inp.extend(inp_.iter());

        let summation = MultiLinear::new(inp);

        let expected_output = Felt::from(2 * ((1_u64 << n) - 1));

        let mut circuit = FractionalSumCircuit::new(&summation);

        let seed = [BaseElement::ZERO; 4];
        let mut transcript = RpoRandomCoin::new(seed.into());

        let (proof, _evals, _) = CircuitProof::prove(&mut circuit, &mut transcript);

        let (p1, q1) = circuit.evaluate(Felt::from(1_u8));
        let (p0, q0) = circuit.evaluate(Felt::from(0_u8));
        assert_eq!(expected_output, (p1 * q0 + q1 * p0) / (q1 * q0));

        let seed = [BaseElement::ZERO; 4];
        let mut transcript = RpoRandomCoin::new(seed.into());
        let claims = vec![p0, p1, q0, q1];
        proof.verify(&claims, &mut transcript);
    }

    // Tests the fractional sum GKR circuit in the context of LogUp.
    #[test]
    fn log_up() {
        use rand::distributions::Slice;

        let n: usize = 16;
        let num_w: usize = 31; // This should be of the form 2^k - 1
        let rng = rand::thread_rng();

        let t_table: Vec<u32> = (0..(1 << n)).collect();
        let mut m_table: Vec<u32> = (0..(1 << n)).map(|_| 0).collect();

        let t_table_slice = Slice::new(&t_table).unwrap();

        // Construct the witness columns. Uses sampling with replacement in order to have
        // multiplicities different from 1.
        let mut w_tables = Vec::new();
        for _ in 0..num_w {
            let wi_table: Vec<u32> =
                rng.clone().sample_iter(&t_table_slice).cloned().take(1 << n).collect();

            // Construct the multiplicities
            wi_table.iter().for_each(|w| {
                m_table[*w as usize] += 1;
            });
            w_tables.push(wi_table)
        }

        // The numerators
        let mut p: Vec<Felt> = m_table.iter().map(|m| Felt::from(*m as u32)).collect();
        p.extend((0..(num_w * (1 << n))).map(|_| Felt::from(1_u32)).collect::<Vec<Felt>>());

        // Sample the challenge alpha to construct the denominators.
        let alpha = rand_value();

        // Construct the denominators
        let mut q: Vec<Felt> = t_table.iter().map(|t| Felt::from(*t) - alpha).collect();
        for w_table in w_tables {
            q.extend(w_table.iter().map(|w| alpha - Felt::from(*w)).collect::<Vec<Felt>>());
        }

        // Build the input to the fractional sum GKR circuit
        p.extend(q);
        let input = p;

        let summation = MultiLinear::new(input);

        let expected_output = Felt::from(0_u8);

        let mut circuit = FractionalSumCircuit::new(&summation);

        let seed = [BaseElement::ZERO; 4];
        let mut transcript = RpoRandomCoin::new(seed.into());

        let (proof, _evals, _) = CircuitProof::prove(&mut circuit, &mut transcript);

        let (p1, q1) = circuit.evaluate(Felt::from(1_u8));
        let (p0, q0) = circuit.evaluate(Felt::from(0_u8));
        assert_eq!(expected_output, (p1 * q0 + q1 * p0) / (q1 * q0)); // This check should be part of verification

        let seed = [BaseElement::ZERO; 4];
        let mut transcript = RpoRandomCoin::new(seed.into());
        let claims = vec![p0, p1, q0, q1];
        proof.verify(&claims, &mut transcript);
    }
}
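// Worked instance of the identity checked above (not part of the original diff): for m = 16,
// 16/1 + 16/2 + 16/4 + 16/8 = 16 + 8 + 4 + 2 = 30 = 2 * (16 - 1).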
7
src/gkr/mod.rs
Normal file
@@ -0,0 +1,7 @@
#![allow(unused_imports)]
#![allow(dead_code)]

mod sumcheck;
mod multivariate;
mod utils;
mod circuit;
34
src/gkr/multivariate/eq_poly.rs
Normal file
@@ -0,0 +1,34 @@
use super::FieldElement;

pub struct EqPolynomial<E> {
    r: Vec<E>,
}

impl<E: FieldElement> EqPolynomial<E> {
    pub fn new(r: Vec<E>) -> Self {
        EqPolynomial { r }
    }

    pub fn evaluate(&self, rho: &[E]) -> E {
        assert_eq!(self.r.len(), rho.len());
        (0..rho.len())
            .map(|i| self.r[i] * rho[i] + (E::ONE - self.r[i]) * (E::ONE - rho[i]))
            .fold(E::ONE, |acc, term| acc * term)
    }

    pub fn evaluations(&self) -> Vec<E> {
        let nu = self.r.len();

        let mut evals: Vec<E> = vec![E::ONE; 1 << nu];
        let mut size = 1;
        for j in 0..nu {
            size *= 2;
            for i in (0..size).rev().step_by(2) {
                let scalar = evals[i / 2];
                evals[i] = scalar * self.r[j];
                evals[i - 1] = scalar - evals[i];
            }
        }
        evals
    }
}
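// Illustrative sketch (not part of the original diff): `evaluations()` tabulates the Lagrange
// kernel eq(r, x) over the Boolean cube, so each table entry should agree with `evaluate` called
// on the bit decomposition of its index (most significant bit first). Assumes only
// `EqPolynomial` from this file and the Goldilocks `BaseElement` as the field.
#[cfg(test)]
#[test]
fn eq_poly_table_matches_pointwise_sketch() {
    use winter_math::{fields::f64::BaseElement, FieldElement};

    let r = vec![BaseElement::new(3), BaseElement::new(7)];
    let table = EqPolynomial::new(r.clone()).evaluations();

    for (i, entry) in table.iter().enumerate() {
        // decompose the index into Boolean coordinates, most significant bit first
        let rho: Vec<BaseElement> = (0..r.len())
            .map(|j| {
                if (i >> (r.len() - 1 - j)) & 1 == 1 {
                    BaseElement::ONE
                } else {
                    BaseElement::ZERO
                }
            })
            .collect();
        assert_eq!(*entry, EqPolynomial::new(r.clone()).evaluate(&rho));
    }
}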
543
src/gkr/multivariate/mod.rs
Normal file
@@ -0,0 +1,543 @@
use core::ops::Index;

use alloc::sync::Arc;
use winter_math::{fields::f64::BaseElement, log2, FieldElement, StarkField};

mod eq_poly;
pub use eq_poly::EqPolynomial;

#[derive(Clone, Debug)]
pub struct MultiLinear<E: FieldElement> {
    pub num_variables: usize,
    pub evaluations: Vec<E>,
}

impl<E: FieldElement> MultiLinear<E> {
    pub fn new(values: Vec<E>) -> Self {
        Self {
            num_variables: log2(values.len()) as usize,
            evaluations: values,
        }
    }

    pub fn from_values(values: &[E]) -> Self {
        Self {
            num_variables: log2(values.len()) as usize,
            evaluations: values.to_owned(),
        }
    }

    pub fn num_variables(&self) -> usize {
        self.num_variables
    }

    pub fn evaluations(&self) -> &[E] {
        &self.evaluations
    }

    pub fn len(&self) -> usize {
        self.evaluations.len()
    }

    pub fn evaluate(&self, query: &[E]) -> E {
        let tensored_query = tensorize(query);
        inner_product(&self.evaluations, &tensored_query)
    }

    pub fn bind(&self, round_challenge: E) -> Self {
        let mut result = vec![E::ZERO; 1 << (self.num_variables() - 1)];
        for i in 0..(1 << (self.num_variables() - 1)) {
            result[i] = self.evaluations[i << 1]
                + round_challenge * (self.evaluations[(i << 1) + 1] - self.evaluations[i << 1]);
        }
        Self::from_values(&result)
    }

    pub fn bind_assign(&mut self, round_challenge: E) {
        let mut result = vec![E::ZERO; 1 << (self.num_variables() - 1)];
        for i in 0..(1 << (self.num_variables() - 1)) {
            result[i] = self.evaluations[i << 1]
                + round_challenge * (self.evaluations[(i << 1) + 1] - self.evaluations[i << 1]);
        }
        *self = Self::from_values(&result);
    }

    pub fn split(&self, at: usize) -> (Self, Self) {
        assert!(at < self.len());
        (
            Self::new(self.evaluations[..at].to_vec()),
            Self::new(self.evaluations[at..2 * at].to_vec()),
        )
    }

    pub fn extend(&mut self, other: &MultiLinear<E>) {
        let other_vec = other.evaluations.to_vec();
        assert_eq!(other_vec.len(), self.len());
        self.evaluations.extend(other_vec);
        self.num_variables += 1;
    }
}

impl<E: FieldElement> Index<usize> for MultiLinear<E> {
    type Output = E;

    fn index(&self, index: usize) -> &E {
        &self.evaluations[index]
    }
}

/// A multi-variate polynomial for composing individual multi-linear polynomials.
pub trait CompositionPolynomial<E: FieldElement>: Sync + Send {
    /// The number of variables when interpreted as a multi-variate polynomial.
    fn num_variables(&self) -> usize;

    /// Maximum degree in all variables.
    fn max_degree(&self) -> usize;

    /// Given a query of length equal to the number of variables, evaluates [Self] at this query.
    fn evaluate(&self, query: &[E]) -> E;
}

pub struct ComposedMultiLinears<E: FieldElement> {
    pub composer: Arc<dyn CompositionPolynomial<E>>,
    pub multi_linears: Vec<MultiLinear<E>>,
}

impl<E: FieldElement> ComposedMultiLinears<E> {
    pub fn new(
        composer: Arc<dyn CompositionPolynomial<E>>,
        multi_linears: Vec<MultiLinear<E>>,
    ) -> Self {
        Self { composer, multi_linears }
    }

    pub fn num_ml(&self) -> usize {
        self.multi_linears.len()
    }

    pub fn num_variables(&self) -> usize {
        self.composer.num_variables()
    }

    pub fn num_variables_ml(&self) -> usize {
        self.multi_linears[0].num_variables
    }

    pub fn degree(&self) -> usize {
        self.composer.max_degree()
    }

    pub fn bind(&self, round_challenge: E) -> ComposedMultiLinears<E> {
        let result: Vec<MultiLinear<E>> =
            self.multi_linears.iter().map(|f| f.bind(round_challenge)).collect();

        Self {
            composer: self.composer.clone(),
            multi_linears: result,
        }
    }
}

#[derive(Clone)]
pub struct ComposedMultiLinearsOracle<E: FieldElement> {
    pub composer: Arc<dyn CompositionPolynomial<E>>,
    pub multi_linears: Vec<MultiLinearOracle>,
}

#[derive(Debug, Clone)]
pub struct MultiLinearOracle {
    pub id: usize,
}

// Composition polynomials

pub struct IdentityComposition {
    num_variables: usize,
}

impl IdentityComposition {
    pub fn new() -> Self {
        Self { num_variables: 1 }
    }
}

impl<E> CompositionPolynomial<E> for IdentityComposition
where
    E: FieldElement,
{
    fn num_variables(&self) -> usize {
        self.num_variables
    }

    fn max_degree(&self) -> usize {
        self.num_variables
    }

    fn evaluate(&self, query: &[E]) -> E {
        assert_eq!(query.len(), 1);
        query[0]
    }
}

pub struct ProjectionComposition {
    coordinate: usize,
}

impl ProjectionComposition {
    pub fn new(coordinate: usize) -> Self {
        Self { coordinate }
    }
}

impl<E> CompositionPolynomial<E> for ProjectionComposition
where
    E: FieldElement,
{
    fn num_variables(&self) -> usize {
        1
    }

    fn max_degree(&self) -> usize {
        1
    }

    fn evaluate(&self, query: &[E]) -> E {
        query[self.coordinate]
    }
}

pub struct LogUpDenominatorTableComposition<E>
where
    E: FieldElement,
{
    projection_coordinate: usize,
    alpha: E,
}

impl<E> LogUpDenominatorTableComposition<E>
where
    E: FieldElement,
{
    pub fn new(projection_coordinate: usize, alpha: E) -> Self {
        Self { projection_coordinate, alpha }
    }
}

impl<E> CompositionPolynomial<E> for LogUpDenominatorTableComposition<E>
where
    E: FieldElement,
{
    fn num_variables(&self) -> usize {
        1
    }

    fn max_degree(&self) -> usize {
        1
    }

    fn evaluate(&self, query: &[E]) -> E {
        query[self.projection_coordinate] + self.alpha
    }
}

pub struct LogUpDenominatorWitnessComposition<E>
where
    E: FieldElement,
{
    projection_coordinate: usize,
    alpha: E,
}

impl<E> LogUpDenominatorWitnessComposition<E>
where
    E: FieldElement,
{
    pub fn new(projection_coordinate: usize, alpha: E) -> Self {
        Self { projection_coordinate, alpha }
    }
}

impl<E> CompositionPolynomial<E> for LogUpDenominatorWitnessComposition<E>
where
    E: FieldElement,
{
    fn num_variables(&self) -> usize {
        1
    }

    fn max_degree(&self) -> usize {
        1
    }

    fn evaluate(&self, query: &[E]) -> E {
        -(query[self.projection_coordinate] + self.alpha)
    }
}

pub struct ProductComposition {
    num_variables: usize,
}

impl ProductComposition {
    pub fn new(num_variables: usize) -> Self {
        Self { num_variables }
    }
}

impl<E> CompositionPolynomial<E> for ProductComposition
where
    E: FieldElement,
{
    fn num_variables(&self) -> usize {
        self.num_variables
    }

    fn max_degree(&self) -> usize {
        self.num_variables
    }

    fn evaluate(&self, query: &[E]) -> E {
        query.iter().fold(E::ONE, |acc, x| acc * *x)
    }
}

pub struct SumComposition {
    num_variables: usize,
}

impl SumComposition {
    pub fn new(num_variables: usize) -> Self {
        Self { num_variables }
    }
}

impl<E> CompositionPolynomial<E> for SumComposition
where
    E: FieldElement,
{
    fn num_variables(&self) -> usize {
        self.num_variables
    }

    fn max_degree(&self) -> usize {
        self.num_variables
    }

    fn evaluate(&self, query: &[E]) -> E {
        query.iter().fold(E::ZERO, |acc, x| acc + *x)
    }
}

pub struct GkrCompositionVanilla<E: 'static>
where
    E: FieldElement,
{
    num_variables_ml: usize,
    num_variables_merge: usize,
    combining_randomness: E,
    gkr_randomness: Vec<E>,
}

impl<E> GkrCompositionVanilla<E>
where
    E: FieldElement,
{
    pub fn new(
        num_variables_ml: usize,
        num_variables_merge: usize,
        combining_randomness: E,
        gkr_randomness: Vec<E>,
    ) -> Self {
        Self {
            num_variables_ml,
            num_variables_merge,
            combining_randomness,
            gkr_randomness,
        }
    }
}

impl<E> CompositionPolynomial<E> for GkrCompositionVanilla<E>
where
    E: FieldElement,
{
    fn num_variables(&self) -> usize {
        self.num_variables_ml // + TODO
    }

    fn max_degree(&self) -> usize {
        self.num_variables_ml // TODO
    }

    fn evaluate(&self, query: &[E]) -> E {
        let eval_left_numerator = query[0];
        let eval_right_numerator = query[1];
        let eval_left_denominator = query[2];
        let eval_right_denominator = query[3];
        let eq_eval = query[4];

        eq_eval
            * ((eval_left_numerator * eval_right_denominator
                + eval_right_numerator * eval_left_denominator)
                + eval_left_denominator * eval_right_denominator * self.combining_randomness)
    }
}

#[derive(Clone)]
pub struct GkrComposition<E>
where
    E: FieldElement<BaseField = BaseElement>,
{
    pub num_variables_ml: usize,
    pub combining_randomness: E,

    eq_composer: Arc<dyn CompositionPolynomial<E>>,
    right_numerator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
    left_numerator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
    right_denominator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
    left_denominator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
}

impl<E> GkrComposition<E>
where
    E: FieldElement<BaseField = BaseElement>,
{
    pub fn new(
        num_variables_ml: usize,
        combining_randomness: E,
        eq_composer: Arc<dyn CompositionPolynomial<E>>,
        right_numerator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
        left_numerator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
        right_denominator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
        left_denominator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
    ) -> Self {
        Self {
            num_variables_ml,
            combining_randomness,
            eq_composer,
            right_numerator_composer,
            left_numerator_composer,
            right_denominator_composer,
            left_denominator_composer,
        }
    }
}

impl<E> CompositionPolynomial<E> for GkrComposition<E>
where
    E: FieldElement<BaseField = BaseElement>,
{
    fn num_variables(&self) -> usize {
        self.num_variables_ml // + TODO
    }

    fn max_degree(&self) -> usize {
        3 // TODO
    }

    fn evaluate(&self, query: &[E]) -> E {
        let eval_right_numerator = self.right_numerator_composer[0].evaluate(query);
        let eval_left_numerator = self.left_numerator_composer[0].evaluate(query);
        let eval_right_denominator = self.right_denominator_composer[0].evaluate(query);
        let eval_left_denominator = self.left_denominator_composer[0].evaluate(query);
        let eq_eval = self.eq_composer.evaluate(query);

        eq_eval
            * ((eval_left_numerator * eval_right_denominator
                + eval_right_numerator * eval_left_denominator)
                + eval_left_denominator * eval_right_denominator * self.combining_randomness)
    }
}

/// Generates a composed ML polynomial for the initial GKR layer from a vector of composition
/// polynomials.
///
/// The composition polynomials are divided into LeftNumerator, RightNumerator, LeftDenominator
/// and RightDenominator.
///
/// TODO: Generalize this to the case where each numerator/denominator contains more than one
/// composition polynomial, i.e., a merged composed ML polynomial.
pub fn gkr_composition_from_composition_polys<
    E: FieldElement<BaseField = BaseElement> + 'static,
>(
    composition_polys: &Vec<Vec<Arc<dyn CompositionPolynomial<E>>>>,
    combining_randomness: E,
    num_variables: usize,
) -> GkrComposition<E> {
    let eq_composer = Arc::new(ProjectionComposition::new(4));
    let left_numerator = composition_polys[0].to_owned();
    let right_numerator = composition_polys[1].to_owned();
    let left_denominator = composition_polys[2].to_owned();
    let right_denominator = composition_polys[3].to_owned();
    GkrComposition::new(
        num_variables,
        combining_randomness,
        eq_composer,
        right_numerator,
        left_numerator,
        right_denominator,
        left_denominator,
    )
}

/// Generates a plain oracle for all the sum-check protocols except the final one.
pub fn gen_plain_gkr_oracle<E: FieldElement<BaseField = BaseElement> + 'static>(
    num_rounds: usize,
    r_sum_check: E,
) -> ComposedMultiLinearsOracle<E> {
    let gkr_composer = Arc::new(GkrCompositionVanilla::new(num_rounds, 0, r_sum_check, vec![]));

    let ml_oracles = vec![
        MultiLinearOracle { id: 0 },
        MultiLinearOracle { id: 1 },
        MultiLinearOracle { id: 2 },
        MultiLinearOracle { id: 3 },
        MultiLinearOracle { id: 4 },
    ];

    ComposedMultiLinearsOracle {
        composer: gkr_composer,
        multi_linears: ml_oracles,
    }
}

fn to_index<E: FieldElement<BaseField = BaseElement>>(index: &[E]) -> usize {
    let res = index.iter().fold(E::ZERO, |acc, term| acc * E::ONE.double() + (*term));
    let res = res.base_element(0);
    res.as_int() as usize
}

fn inner_product<E: FieldElement>(evaluations: &[E], tensored_query: &[E]) -> E {
    assert_eq!(evaluations.len(), tensored_query.len());
    evaluations
        .iter()
        .zip(tensored_query.iter())
        .fold(E::ZERO, |acc, (x_i, y_i)| acc + *x_i * *y_i)
}

pub fn tensorize<E: FieldElement>(query: &[E]) -> Vec<E> {
    let nu = query.len();
    let n = 1 << nu;

    (0..n).map(|i| lagrange_basis_eval(query, i)).collect()
}

fn lagrange_basis_eval<E: FieldElement>(query: &[E], i: usize) -> E {
    query
        .iter()
        .enumerate()
        .map(|(j, x_j)| if i & (1 << j) == 0 { E::ONE - *x_j } else { *x_j })
        .fold(E::ONE, |acc, v| acc * v)
}

pub fn compute_claim<E: FieldElement>(poly: &ComposedMultiLinears<E>) -> E {
    let cube_size = 1 << poly.num_variables_ml();
    let mut res = E::ZERO;

    for i in 0..cube_size {
        let eval_point: Vec<E> =
            poly.multi_linears.iter().map(|poly| poly.evaluations[i]).collect();
        res += poly.composer.evaluate(&eval_point);
    }
    res
}
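// Illustrative sketch (not part of the original diff): `evaluate` computes the multilinear
// extension as an inner product of the evaluation table with the tensorized query, i.e.
// f(x) = \sum_i f(i) * L_i(x) with L_i the multilinear Lagrange basis and the first variable
// on the least significant index bit. A two-variable cross-check over the Goldilocks field;
// the concrete values are arbitrary.
#[cfg(test)]
#[test]
fn multilinear_evaluate_sketch() {
    use winter_math::fields::f64::BaseElement;

    // table of f over {0,1}^2 with f(0,0) = 1, f(1,0) = 2, f(0,1) = 3, f(1,1) = 4
    let f = MultiLinear::new(vec![
        BaseElement::new(1),
        BaseElement::new(2),
        BaseElement::new(3),
        BaseElement::new(4),
    ]);
    let query = vec![BaseElement::new(5), BaseElement::new(6)];

    // direct evaluation of the multilinear extension at the query point
    let (x0, x1) = (query[0], query[1]);
    let one = BaseElement::ONE;
    let expected = (one - x0) * (one - x1) * f[0]
        + x0 * (one - x1) * f[1]
        + (one - x0) * x1 * f[2]
        + x0 * x1 * f[3];

    assert_eq!(expected, f.evaluate(&query));
}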
108
src/gkr/sumcheck/mod.rs
Normal file
@@ -0,0 +1,108 @@
use super::{
    multivariate::{ComposedMultiLinears, ComposedMultiLinearsOracle},
    utils::{barycentric_weights, evaluate_barycentric},
};
use winter_math::FieldElement;

mod prover;
pub use prover::sum_check_prove;
mod verifier;
pub use verifier::{sum_check_verify, sum_check_verify_and_reduce};
mod tests;

#[derive(Debug, Clone)]
pub struct RoundProof<E> {
    pub poly_evals: Vec<E>,
}

impl<E: FieldElement> RoundProof<E> {
    pub fn to_evals(&self, claim: E) -> Vec<E> {
        let mut result = vec![];

        // s(0) + s(1) = claim
        let c0 = claim - self.poly_evals[0];

        result.push(c0);
        result.extend_from_slice(&self.poly_evals);
        result
    }

    // TODO: refactor once we move to coefficient form
    pub(crate) fn evaluate(&self, claim: E, r: E) -> E {
        let poly_evals = self.to_evals(claim);

        let points: Vec<E> = (0..poly_evals.len()).map(|i| E::from(i as u8)).collect();
        let evalss: Vec<(E, E)> =
            points.iter().zip(poly_evals.iter()).map(|(x, y)| (*x, *y)).collect();
        let weights = barycentric_weights(&evalss);
        evaluate_barycentric(&evalss, r, &weights)
    }
}

#[derive(Debug, Clone)]
pub struct PartialProof<E> {
    pub round_proofs: Vec<RoundProof<E>>,
}

#[derive(Clone)]
pub struct FinalEvaluationClaim<E: FieldElement> {
    pub evaluation_point: Vec<E>,
    pub claimed_evaluation: E,
    pub polynomial: ComposedMultiLinearsOracle<E>,
}

#[derive(Clone)]
pub struct FullProof<E: FieldElement> {
    pub sum_check_proof: PartialProof<E>,
    pub final_evaluation_claim: FinalEvaluationClaim<E>,
}

pub struct Claim<E: FieldElement> {
    pub sum_value: E,
    pub polynomial: ComposedMultiLinearsOracle<E>,
}

#[derive(Debug)]
pub struct RoundClaim<E: FieldElement> {
    pub partial_eval_point: Vec<E>,
    pub current_claim: E,
}

pub struct RoundOutput<E: FieldElement> {
    proof: PartialProof<E>,
    witness: Witness<E>,
}

impl<E: FieldElement> From<Claim<E>> for RoundClaim<E> {
    fn from(value: Claim<E>) -> Self {
        Self {
            partial_eval_point: vec![],
            current_claim: value.sum_value,
        }
    }
}

pub struct Witness<E: FieldElement> {
    pub(crate) polynomial: ComposedMultiLinears<E>,
}

pub fn reduce_claim<E: FieldElement>(
    current_poly: RoundProof<E>,
    current_round_claim: RoundClaim<E>,
    round_challenge: E,
) -> RoundClaim<E> {
    let poly_evals = current_poly.to_evals(current_round_claim.current_claim);
    let points: Vec<E> = (0..poly_evals.len()).map(|i| E::from(i as u8)).collect();
    let evalss: Vec<(E, E)> = points.iter().zip(poly_evals.iter()).map(|(x, y)| (*x, *y)).collect();
    let weights = barycentric_weights(&evalss);
    let new_claim = evaluate_barycentric(&evalss, round_challenge, &weights);

    let mut new_partial_eval_point = current_round_claim.partial_eval_point;
    new_partial_eval_point.push(round_challenge);

    RoundClaim {
        partial_eval_point: new_partial_eval_point,
        current_claim: new_claim,
    }
}
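// Illustrative sketch (not part of the original diff): a round proof stores only s(1)..s(d);
// the verifier recovers s(0) from the running claim via s(0) = claim - s(1), then interpolates
// over the points 0..=d. A small numeric instance over the Goldilocks field.
#[cfg(test)]
#[test]
fn round_proof_to_evals_sketch() {
    use winter_math::fields::f64::BaseElement;

    // claim = s(0) + s(1) with s(0) = 4 and transmitted evals [s(1), s(2), s(3)] = [6, 7, 11]
    let proof = RoundProof {
        poly_evals: vec![BaseElement::new(6), BaseElement::new(7), BaseElement::new(11)],
    };
    let claim = BaseElement::new(10);

    let evals = proof.to_evals(claim);
    assert_eq!(evals[0], BaseElement::new(4)); // recovered s(0) = claim - s(1)

    // evaluating the interpolant at 2 must return the transmitted s(2)
    assert_eq!(proof.evaluate(claim, BaseElement::new(2)), BaseElement::new(7));
}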
109
src/gkr/sumcheck/prover.rs
Normal file
@@ -0,0 +1,109 @@
use super::{Claim, FullProof, RoundProof, Witness};
use crate::gkr::{
    multivariate::{ComposedMultiLinears, ComposedMultiLinearsOracle},
    sumcheck::{reduce_claim, FinalEvaluationClaim, PartialProof, RoundClaim, RoundOutput},
};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use winter_crypto::{ElementHasher, RandomCoin};
use winter_math::{fields::f64::BaseElement, FieldElement};

pub fn sum_check_prove<
    E: FieldElement<BaseField = BaseElement>,
    C: RandomCoin<Hasher = H, BaseField = BaseElement>,
    H: ElementHasher<BaseField = BaseElement>,
>(
    claim: &Claim<E>,
    oracle: ComposedMultiLinearsOracle<E>,
    witness: Witness<E>,
    coin: &mut C,
) -> FullProof<E> {
    // Set up the first round
    let mut prev_claim = RoundClaim {
        partial_eval_point: vec![],
        current_claim: claim.sum_value.clone(),
    };
    let prev_proof = PartialProof { round_proofs: vec![] };
    let num_vars = witness.polynomial.num_variables_ml();
    let prev_output = RoundOutput { proof: prev_proof, witness };

    let mut output = sumcheck_round(prev_output);
    let poly_evals = &output.proof.round_proofs[0].poly_evals;
    coin.reseed(H::hash_elements(poly_evals));

    for i in 1..num_vars {
        let round_challenge = coin.draw().unwrap();
        let new_claim = reduce_claim(
            output.proof.round_proofs.last().unwrap().clone(),
            prev_claim,
            round_challenge,
        );
        output.witness.polynomial = output.witness.polynomial.bind(round_challenge);

        output = sumcheck_round(output);
        prev_claim = new_claim;

        let poly_evals = &output.proof.round_proofs[i].poly_evals;
        coin.reseed(H::hash_elements(poly_evals));
    }

    let round_challenge = coin.draw().unwrap();
    let RoundClaim { partial_eval_point, current_claim } = reduce_claim(
        output.proof.round_proofs.last().unwrap().clone(),
        prev_claim,
        round_challenge,
    );
    let final_eval_claim = FinalEvaluationClaim {
        evaluation_point: partial_eval_point,
        claimed_evaluation: current_claim,
        polynomial: oracle,
    };

    FullProof {
        sum_check_proof: output.proof,
        final_evaluation_claim: final_eval_claim,
    }
}

fn sumcheck_round<E: FieldElement>(prev_proof: RoundOutput<E>) -> RoundOutput<E> {
    let RoundOutput { mut proof, witness } = prev_proof;

    let polynomial = witness.polynomial;
    let num_ml = polynomial.num_ml();
    let num_vars = polynomial.num_variables_ml();
    let num_rounds = num_vars - 1;

    let mut evals_zero = vec![E::ZERO; num_ml];
    let mut evals_one = vec![E::ZERO; num_ml];
    let mut deltas = vec![E::ZERO; num_ml];
    let mut evals_x = vec![E::ZERO; num_ml];

    let total_evals = (0..1 << num_rounds).into_iter().map(|i| {
        for (j, ml) in polynomial.multi_linears.iter().enumerate() {
            evals_zero[j] = ml.evaluations[(i << 1) as usize];
            evals_one[j] = ml.evaluations[(i << 1) + 1];
        }
        let mut total_evals = vec![E::ZERO; polynomial.degree()];
        total_evals[0] = polynomial.composer.evaluate(&evals_one);
        evals_zero
            .iter()
            .zip(evals_one.iter().zip(deltas.iter_mut().zip(evals_x.iter_mut())))
            .for_each(|(a0, (a1, (delta, evx)))| {
                *delta = *a1 - *a0;
                *evx = *a1;
            });
        total_evals.iter_mut().skip(1).for_each(|e| {
            evals_x.iter_mut().zip(deltas.iter()).for_each(|(evx, delta)| {
                *evx += *delta;
            });
            *e = polynomial.composer.evaluate(&evals_x);
        });
        total_evals
    });
    let evaluations = total_evals.fold(vec![E::ZERO; polynomial.degree()], |mut acc, evals| {
        acc.iter_mut().zip(evals.iter()).for_each(|(a, ev)| *a += *ev);
        acc
    });
    let proof_update = RoundProof { poly_evals: evaluations };
    proof.round_proofs.push(proof_update);
    RoundOutput { proof, witness: Witness { polynomial } }
}
199
src/gkr/sumcheck/tests.rs
Normal file
@@ -0,0 +1,199 @@
use alloc::sync::Arc;
use rand::{distributions::Uniform, SeedableRng};
use winter_crypto::RandomCoin;
use winter_math::{fields::f64::BaseElement, FieldElement};

use crate::{
    gkr::{
        circuit::{CircuitProof, FractionalSumCircuit},
        multivariate::{
            compute_claim, gkr_composition_from_composition_polys, ComposedMultiLinears,
            ComposedMultiLinearsOracle, CompositionPolynomial, EqPolynomial, GkrComposition,
            GkrCompositionVanilla, LogUpDenominatorTableComposition,
            LogUpDenominatorWitnessComposition, MultiLinear, MultiLinearOracle,
            ProjectionComposition, SumComposition,
        },
        sumcheck::{
            prover::sum_check_prove, verifier::sum_check_verify, Claim, FinalEvaluationClaim,
            FullProof, Witness,
        },
    },
    hash::rpo::Rpo256,
    rand::RpoRandomCoin,
};

#[test]
fn gkr_workflow() {
    // generate the data witness for the LogUp argument
    let mut mls = generate_logup_witness::<BaseElement>(3);

    // the challenge alpha is sampled after receiving the main trace commitment
    let alpha = rand_utils::rand_value();

    // the composition polynomials defining the numerators/denominators
    let composition_polys: Vec<Vec<Arc<dyn CompositionPolynomial<BaseElement>>>> = vec![
        // left numerator
        vec![Arc::new(ProjectionComposition::new(0))],
        // right numerator
        vec![Arc::new(ProjectionComposition::new(1))],
        // left denominator
        vec![Arc::new(LogUpDenominatorTableComposition::new(2, alpha))],
        // right denominator
        vec![Arc::new(LogUpDenominatorWitnessComposition::new(3, alpha))],
    ];

    // run the GKR prover to obtain:
    // 1. The fractional sum circuit output.
    // 2. GKR proofs up to the last circuit layer counting backwards.
    // 3. GKR proof (i.e., a sum-check proof) for the last circuit layer counting backwards.
    let seed = [BaseElement::ZERO; 4];
    let mut transcript = RpoRandomCoin::new(seed.into());
    let (circuit_outputs, gkr_before_last_proof, final_layer_proof) =
        CircuitProof::prove_virtual_bus(composition_polys.clone(), &mut mls, &mut transcript);

    let seed = [BaseElement::ZERO; 4];
    let mut transcript = RpoRandomCoin::new(seed.into());

    // run the GKR verifier to obtain:
    // 1. A final evaluation claim.
    // 2. Randomness defining the Lagrange kernel in the final sum-check protocol. Note that this
    // Lagrange kernel is different from the one used by the STARK (outer) prover to open the MLs
    // at the evaluation point.
    let (final_eval_claim, gkr_lagrange_kernel_rand) = gkr_before_last_proof.verify_virtual_bus(
        composition_polys.clone(),
        final_layer_proof,
        &circuit_outputs,
        &mut transcript,
    );

    // the final verification step is composed of:
    // 1. Querying the oracles for the openings at the evaluation point. This will be done by the
    // (outer) STARK prover using:
    //    a. The Lagrange kernel (auxiliary) column at the evaluation point.
    //    b. An extra (auxiliary) column to compute an inner product between two vectors: the
    //       first being the Lagrange kernel and the second being
    //       (mls[0][i] + \sum_{j=1}^{3} mls[j][i] * \lambda_{j-1})_{i \in \{0, ..., n\}}.
    // 2. Evaluating the composition polynomial at the previous openings and checking equality
    // with the claimed evaluation.

    // 1. Querying the oracles
    let FinalEvaluationClaim {
        evaluation_point,
        claimed_evaluation,
        polynomial,
    } = final_eval_claim;

    // The evaluation of the EQ polynomial can be done by the verifier directly
    let eq = (0..gkr_lagrange_kernel_rand.len())
        .map(|i| {
            gkr_lagrange_kernel_rand[i] * evaluation_point[i]
                + (BaseElement::ONE - gkr_lagrange_kernel_rand[i])
                    * (BaseElement::ONE - evaluation_point[i])
        })
        .fold(BaseElement::ONE, |acc, term| acc * term);

    // These are the queries to the oracles.
    // They should be provided by the prover non-deterministically.
    let left_num_eval = mls[0].evaluate(&evaluation_point);
    let right_num_eval = mls[1].evaluate(&evaluation_point);
    let left_den_eval = mls[2].evaluate(&evaluation_point);
    let right_den_eval = mls[3].evaluate(&evaluation_point);

    // The verifier absorbs the claimed openings and generates batching randomness lambda
    let mut query = vec![left_num_eval, right_num_eval, left_den_eval, right_den_eval];
    transcript.reseed(Rpo256::hash_elements(&query));
    let lambdas: Vec<BaseElement> = vec![
        transcript.draw().unwrap(),
        transcript.draw().unwrap(),
        transcript.draw().unwrap(),
    ];
    let batched_query =
        query[0] + query[1] * lambdas[0] + query[2] * lambdas[1] + query[3] * lambdas[2];

    // The prover generates the Lagrange kernel as an auxiliary column
    let mut rev_evaluation_point = evaluation_point;
    rev_evaluation_point.reverse();
    let lagrange_kernel = EqPolynomial::new(rev_evaluation_point).evaluations();

    // The prover generates the additional auxiliary column for the inner product
    let tmp_col: Vec<BaseElement> = (0..mls[0].len())
        .map(|i| {
            mls[0][i] + mls[1][i] * lambdas[0] + mls[2][i] * lambdas[1] + mls[3][i] * lambdas[2]
        })
        .collect();
    let mut running_sum_col = vec![BaseElement::ZERO; tmp_col.len() + 1];
    running_sum_col[0] = BaseElement::ZERO;
    for i in 1..(tmp_col.len() + 1) {
        running_sum_col[i] = running_sum_col[i - 1] + tmp_col[i - 1] * lagrange_kernel[i - 1];
    }

    // Boundary constraint to check correctness of the openings
    assert_eq!(batched_query, *running_sum_col.last().unwrap());

    // 2. Final evaluation and check
    query.push(eq);
    let verifier_computed = polynomial.composer.evaluate(&query);

    assert_eq!(verifier_computed, claimed_evaluation);
}

pub fn generate_logup_witness<E: FieldElement>(trace_len: usize) -> Vec<MultiLinear<E>> {
    let num_variables_ml = trace_len;
    let num_evaluations = 1 << num_variables_ml;
    let num_witnesses = 1;
    let (p, q) = generate_logup_data::<E>(num_variables_ml, num_witnesses);
    let numerators: Vec<Vec<E>> = p.chunks(num_evaluations).map(|x| x.into()).collect();
    let denominators: Vec<Vec<E>> = q.chunks(num_evaluations).map(|x| x.into()).collect();

    let mut mls = vec![];
    for i in 0..2 {
        let ml = MultiLinear::from_values(&numerators[i]);
        mls.push(ml);
    }
    for i in 0..2 {
        let ml = MultiLinear::from_values(&denominators[i]);
        mls.push(ml);
    }
    mls
}

pub fn generate_logup_data<E: FieldElement>(
    trace_len: usize,
    num_witnesses: usize,
) -> (Vec<E>, Vec<E>) {
    use rand::distributions::Slice;
    use rand::Rng;
    let n: usize = trace_len;
    let num_w: usize = num_witnesses; // This should be of the form 2^k - 1
    let rng = rand::rngs::StdRng::seed_from_u64(0);

    let t_table: Vec<u32> = (0..(1 << n)).collect();
    let mut m_table: Vec<u32> = (0..(1 << n)).map(|_| 0).collect();

    let t_table_slice = Slice::new(&t_table).unwrap();

    // Construct the witness columns. Uses sampling with replacement in order to have
    // multiplicities different from 1.
    let mut w_tables = Vec::new();
    for _ in 0..num_w {
        let wi_table: Vec<u32> =
            rng.clone().sample_iter(&t_table_slice).cloned().take(1 << n).collect();

        // Construct the multiplicities
        wi_table.iter().for_each(|w| {
            m_table[*w as usize] += 1;
        });
        w_tables.push(wi_table)
    }

    // The numerators
    let mut p: Vec<E> = m_table.iter().map(|m| E::from(*m as u32)).collect();
    p.extend((0..(num_w * (1 << n))).map(|_| E::from(1_u32)).collect::<Vec<E>>());

    // Construct the denominators
    let mut q: Vec<E> = t_table.iter().map(|t| E::from(*t)).collect();
    for w_table in w_tables {
        q.extend(w_table.iter().map(|w| E::from(*w)).collect::<Vec<E>>());
    }
    (p, q)
}
71
src/gkr/sumcheck/verifier.rs
Normal file
@@ -0,0 +1,71 @@
use winter_crypto::{ElementHasher, RandomCoin};
use winter_math::{fields::f64::BaseElement, FieldElement};

use crate::gkr::utils::{barycentric_weights, evaluate_barycentric};

use super::{Claim, FinalEvaluationClaim, FullProof, PartialProof};

pub fn sum_check_verify_and_reduce<
    E: FieldElement<BaseField = BaseElement>,
    C: RandomCoin<Hasher = H, BaseField = BaseElement>,
    H: ElementHasher<BaseField = BaseElement>,
>(
    claim: &Claim<E>,
    proofs: PartialProof<E>,
    coin: &mut C,
) -> (E, Vec<E>) {
    let degree = 3;
    let points: Vec<E> = (0..degree + 1).map(|x| E::from(x as u8)).collect();
    let mut sum_value = claim.sum_value.clone();
    let mut randomness = vec![];

    for proof in proofs.round_proofs {
        let partial_evals = proof.poly_evals.clone();
        coin.reseed(H::hash_elements(&partial_evals));

        // draw the round challenge r
        let r: E = coin.draw().unwrap();
        randomness.push(r);
        let evals = proof.to_evals(sum_value);

        let point_evals: Vec<_> = points.iter().zip(evals.iter()).map(|(x, y)| (*x, *y)).collect();
        let weights = barycentric_weights(&point_evals);
        sum_value = evaluate_barycentric(&point_evals, r, &weights);
    }
    (sum_value, randomness)
}

pub fn sum_check_verify<
    E: FieldElement<BaseField = BaseElement>,
    C: RandomCoin<Hasher = H, BaseField = BaseElement>,
    H: ElementHasher<BaseField = BaseElement>,
>(
    claim: &Claim<E>,
    proofs: FullProof<E>,
    coin: &mut C,
) -> FinalEvaluationClaim<E> {
    let FullProof {
        sum_check_proof: proofs,
        final_evaluation_claim,
    } = proofs;
    let Claim { sum_value, polynomial } = claim;
    let mut sum_value = *sum_value;
    let degree = polynomial.composer.max_degree();
    let points: Vec<E> = (0..degree + 1).map(|x| E::from(x as u8)).collect();

    for proof in proofs.round_proofs {
        let partial_evals = proof.poly_evals.clone();
        coin.reseed(H::hash_elements(&partial_evals));

        // draw the round challenge r
        let r: E = coin.draw().unwrap();
        let evals = proof.to_evals(sum_value);

        let point_evals: Vec<_> = points.iter().zip(evals.iter()).map(|(x, y)| (*x, *y)).collect();
        let weights = barycentric_weights(&point_evals);
        sum_value = evaluate_barycentric(&point_evals, r, &weights);
    }

    assert_eq!(final_evaluation_claim.claimed_evaluation, sum_value);

    final_evaluation_claim
}
33
src/gkr/utils/mod.rs
Normal file
@@ -0,0 +1,33 @@
use winter_math::{batch_inversion, FieldElement};

pub fn barycentric_weights<E: FieldElement>(points: &[(E, E)]) -> Vec<E> {
    let n = points.len();
    let tmp = (0..n)
        .map(|i| {
            (0..n).filter(|&j| j != i).fold(E::ONE, |acc, j| acc * (points[i].0 - points[j].0))
        })
        .collect::<Vec<_>>();
    batch_inversion(&tmp)
}

pub fn evaluate_barycentric<E: FieldElement>(
    points: &[(E, E)],
    x: E,
    barycentric_weights: &[E],
) -> E {
    for &(x_i, y_i) in points {
        if x_i == x {
            return y_i;
        }
    }

    let l_x: E = points.iter().fold(E::ONE, |acc, &(x_i, _y_i)| acc * (x - x_i));

    let sum = (0..points.len()).fold(E::ZERO, |acc, i| {
        let x_i = points[i].0;
        let y_i = points[i].1;
        let w_i = barycentric_weights[i];
        acc + (w_i / (x - x_i) * y_i)
    });

    l_x * sum
}
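// Illustrative sketch (not part of the original diff): barycentric evaluation of the degree-2
// interpolant through (0, 1), (1, 2), (2, 5), i.e. p(X) = X^2 + 1, checked at x = 3 over the
// Goldilocks field.
#[cfg(test)]
#[test]
fn barycentric_evaluation_sketch() {
    use winter_math::fields::f64::BaseElement;

    let points: Vec<(BaseElement, BaseElement)> = [(0u64, 1u64), (1, 2), (2, 5)]
        .iter()
        .map(|&(x, y)| (BaseElement::new(x), BaseElement::new(y)))
        .collect();
    let weights = barycentric_weights(&points);

    // p(3) = 3^2 + 1 = 10
    assert_eq!(
        evaluate_barycentric(&points, BaseElement::new(3), &weights),
        BaseElement::new(10)
    );
}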
@@ -1,9 +1,17 @@
 //! Cryptographic hash functions used by the Miden VM and the Miden rollup.
 
-use super::{Felt, FieldElement, StarkField, ONE, ZERO};
+use super::{CubeExtension, Felt, FieldElement, StarkField, ONE, ZERO};
 
 pub mod blake;
-pub mod rpo;
+mod rescue;
+
+pub mod rpo {
+    pub use super::rescue::{Rpo256, RpoDigest};
+}
+
+pub mod rpx {
+    pub use super::rescue::{Rpx256, RpxDigest};
+}
 
 // RE-EXPORTS
 // ================================================================================================
101
src/hash/rescue/arch/mod.rs
Normal file
@@ -0,0 +1,101 @@
#[cfg(all(target_feature = "sve", feature = "sve"))]
pub mod optimized {
    use crate::hash::rescue::STATE_WIDTH;
    use crate::Felt;

    mod ffi {
        #[link(name = "rpo_sve", kind = "static")]
        extern "C" {
            pub fn add_constants_and_apply_sbox(
                state: *mut std::ffi::c_ulong,
                constants: *const std::ffi::c_ulong,
            ) -> bool;
            pub fn add_constants_and_apply_inv_sbox(
                state: *mut std::ffi::c_ulong,
                constants: *const std::ffi::c_ulong,
            ) -> bool;
        }
    }

    #[inline(always)]
    pub fn add_constants_and_apply_sbox(
        state: &mut [Felt; STATE_WIDTH],
        ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        unsafe {
            ffi::add_constants_and_apply_sbox(
                state.as_mut_ptr() as *mut u64,
                ark.as_ptr() as *const u64,
            )
        }
    }

    #[inline(always)]
    pub fn add_constants_and_apply_inv_sbox(
        state: &mut [Felt; STATE_WIDTH],
        ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        unsafe {
            ffi::add_constants_and_apply_inv_sbox(
                state.as_mut_ptr() as *mut u64,
                ark.as_ptr() as *const u64,
            )
        }
    }
}

#[cfg(target_feature = "avx2")]
mod x86_64_avx2;

#[cfg(target_feature = "avx2")]
pub mod optimized {
    use super::x86_64_avx2::{apply_inv_sbox, apply_sbox};
    use crate::hash::rescue::{add_constants, STATE_WIDTH};
    use crate::Felt;

    #[inline(always)]
    pub fn add_constants_and_apply_sbox(
        state: &mut [Felt; STATE_WIDTH],
        ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        add_constants(state, ark);
        unsafe {
            apply_sbox(std::mem::transmute(state));
        }
        true
    }

    #[inline(always)]
    pub fn add_constants_and_apply_inv_sbox(
        state: &mut [Felt; STATE_WIDTH],
        ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        add_constants(state, ark);
        unsafe {
            apply_inv_sbox(std::mem::transmute(state));
        }
        true
    }
}

#[cfg(not(any(target_feature = "avx2", all(target_feature = "sve", feature = "sve"))))]
pub mod optimized {
    use crate::hash::rescue::STATE_WIDTH;
    use crate::Felt;

    #[inline(always)]
    pub fn add_constants_and_apply_sbox(
        _state: &mut [Felt; STATE_WIDTH],
        _ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        false
    }

    #[inline(always)]
    pub fn add_constants_and_apply_inv_sbox(
        _state: &mut [Felt; STATE_WIDTH],
        _ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        false
    }
}
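// Illustrative usage sketch (not part of the original diff): every backend returns a `bool` so
// that the portable round function can fall back to scalar code when no accelerated
// implementation is compiled in. A caller would dispatch roughly as follows, where
// `apply_sbox_scalar` is a hypothetical stand-in for the portable S-box path:
//
//     if !optimized::add_constants_and_apply_sbox(&mut state, &ark) {
//         add_constants(&mut state, &ark);
//         apply_sbox_scalar(&mut state);
//     }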
325
src/hash/rescue/arch/x86_64_avx2.rs
Normal file
@@ -0,0 +1,325 @@
use core::arch::x86_64::*;

// The following AVX2 implementation has been copied from plonky2:
// https://github.com/0xPolygonZero/plonky2/blob/main/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs

// Preliminary notes:
// 1. AVX does not support addition with carry but 128-bit (2-word) addition can be easily
// emulated. The method recognizes that a + b overflowed iff (a + b) < a:
//     i. res_lo = a_lo + b_lo
//    ii. carry_mask = res_lo < a_lo
//   iii. res_hi = a_hi + b_hi - carry_mask
// Notice that carry_mask is subtracted, not added. This is because AVX comparison instructions
// return -1 (all bits 1) for true and 0 for false.
//
// 2. AVX does not have unsigned 64-bit comparisons. Those can be emulated with signed comparisons
// by recognizing that a <u b iff a + (1 << 63) <s b + (1 << 63), where the addition wraps around
// and the comparisons are unsigned and signed respectively. The shift function adds/subtracts
// 1 << 63 to enable this trick.
// Example: addition with carry.
//     i. a_lo_s = shift(a_lo)
//    ii. res_lo_s = a_lo_s + b_lo
//   iii. carry_mask = res_lo_s <s a_lo_s
//    iv. res_lo = shift(res_lo_s)
//     v. res_hi = a_hi + b_hi - carry_mask
// The suffix _s denotes a value that has been shifted by 1 << 63. The result of addition is
// shifted if exactly one of the operands is shifted, as is the case on line ii. Line iii.
// performs a signed comparison res_lo_s <s a_lo_s on shifted values to emulate unsigned
// comparison res_lo <u a_lo on unshifted values. Finally, line iv. reverses the shift so the
// result can be returned.
// When performing a chain of calculations, we can often save instructions by letting the shift
// propagate through and only undoing it when necessary. For example, to compute the addition of
// three two-word (128-bit) numbers we can do:
//     i. a_lo_s = shift(a_lo)
//    ii. tmp_lo_s = a_lo_s + b_lo
//   iii. tmp_carry_mask = tmp_lo_s <s a_lo_s
//    iv. tmp_hi = a_hi + b_hi - tmp_carry_mask
//     v. res_lo_s = tmp_lo_s + c_lo
//    vi. res_carry_mask = res_lo_s <s tmp_lo_s
//   vii. res_lo = shift(res_lo_s)
//  viii. res_hi = tmp_hi + c_hi - res_carry_mask
// Notice that the above 3-value addition still only requires two calls to shift, just like our
// 2-value addition.
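// Illustrative sketch (not part of the original plonky2 notes): the carry-mask trick from
// note 1, modeled on scalar u64 values. An AVX comparison that is "true" produces an all-ones
// lane (u64::MAX here), and subtracting it is the same as adding 1 modulo 2^64.
//
//     fn add_128(a_lo: u64, a_hi: u64, b_lo: u64, b_hi: u64) -> (u64, u64) {
//         let res_lo = a_lo.wrapping_add(b_lo);
//         let carry_mask = if res_lo < a_lo { u64::MAX } else { 0 };
//         let res_hi = a_hi.wrapping_add(b_hi).wrapping_sub(carry_mask);
//         (res_lo, res_hi)
//     }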
#[inline(always)]
|
||||||
|
pub fn branch_hint() {
|
||||||
|
// NOTE: These are the currently supported assembly architectures. See the
|
||||||
|
// [nightly reference](https://doc.rust-lang.org/nightly/reference/inline-assembly.html) for
|
||||||
|
// the most up-to-date list.
|
||||||
|
#[cfg(any(
|
||||||
|
target_arch = "aarch64",
|
||||||
|
target_arch = "arm",
|
||||||
|
target_arch = "riscv32",
|
||||||
|
target_arch = "riscv64",
|
||||||
|
target_arch = "x86",
|
||||||
|
target_arch = "x86_64",
|
||||||
|
))]
|
||||||
|
unsafe {
|
||||||
|
core::arch::asm!("", options(nomem, nostack, preserves_flags));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! map3 {
|
||||||
|
($f:ident::<$l:literal>, $v:ident) => {
|
||||||
|
($f::<$l>($v.0), $f::<$l>($v.1), $f::<$l>($v.2))
|
||||||
|
};
|
||||||
|
($f:ident::<$l:literal>, $v1:ident, $v2:ident) => {
|
||||||
|
($f::<$l>($v1.0, $v2.0), $f::<$l>($v1.1, $v2.1), $f::<$l>($v1.2, $v2.2))
|
||||||
|
};
|
||||||
|
($f:ident, $v:ident) => {
|
||||||
|
($f($v.0), $f($v.1), $f($v.2))
|
||||||
|
};
|
||||||
|
($f:ident, $v0:ident, $v1:ident) => {
|
||||||
|
($f($v0.0, $v1.0), $f($v0.1, $v1.1), $f($v0.2, $v1.2))
|
||||||
|
};
|
||||||
|
($f:ident, rep $v0:ident, $v1:ident) => {
|
||||||
|
($f($v0, $v1.0), $f($v0, $v1.1), $f($v0, $v1.2))
|
||||||
|
};
|
||||||
|
|
||||||
|
($f:ident, $v0:ident, rep $v1:ident) => {
|
||||||
|
($f($v0.0, $v1), $f($v0.1, $v1), $f($v0.2, $v1))
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
unsafe fn square3(
|
||||||
|
x: (__m256i, __m256i, __m256i),
|
||||||
|
) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
|
||||||
|
let x_hi = {
|
||||||
|
// Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster than
|
||||||
|
// bitshift. This instruction only has a floating-point flavor, so we cast to/from float.
|
||||||
|
// This is safe and free.
|
||||||
|
let x_ps = map3!(_mm256_castsi256_ps, x);
|
||||||
|
let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps);
|
||||||
|
map3!(_mm256_castps_si256, x_hi_ps)
|
||||||
|
};
|
||||||
|
|
||||||
|
// All pairwise multiplications.
|
||||||
|
let mul_ll = map3!(_mm256_mul_epu32, x, x);
|
||||||
|
let mul_lh = map3!(_mm256_mul_epu32, x, x_hi);
|
||||||
|
let mul_hh = map3!(_mm256_mul_epu32, x_hi, x_hi);
|
||||||
|
|
||||||
|
// Bignum addition, but mul_lh is shifted by 33 bits (not 32).
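    // Editorial note: the extra bit comes from squaring doubling the cross term:
    //   (lo + 2^32 * hi)^2 = lo^2 + 2^33 * lo * hi + 2^64 * hi^2,
    // so mul_lh contributes at bit position 33 rather than 32.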
    let mul_ll_hi = map3!(_mm256_srli_epi64::<33>, mul_ll);
    let t0 = map3!(_mm256_add_epi64, mul_lh, mul_ll_hi);
    let t0_hi = map3!(_mm256_srli_epi64::<31>, t0);
    let res_hi = map3!(_mm256_add_epi64, mul_hh, t0_hi);

    // Form low result by adding the mul_ll and the low 31 bits of mul_lh (shifted to the high
    // position).
    let mul_lh_lo = map3!(_mm256_slli_epi64::<33>, mul_lh);
    let res_lo = map3!(_mm256_add_epi64, mul_ll, mul_lh_lo);

    (res_lo, res_hi)
}

#[inline(always)]
unsafe fn mul3(
    x: (__m256i, __m256i, __m256i),
    y: (__m256i, __m256i, __m256i),
) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
    let epsilon = _mm256_set1_epi64x(0xffffffff);
    let x_hi = {
        // Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster
        // than bitshift. This instruction only has a floating-point flavor, so we cast to/from
        // float. This is safe and free.
        let x_ps = map3!(_mm256_castsi256_ps, x);
        let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps);
        map3!(_mm256_castps_si256, x_hi_ps)
    };
    let y_hi = {
        let y_ps = map3!(_mm256_castsi256_ps, y);
        let y_hi_ps = map3!(_mm256_movehdup_ps, y_ps);
        map3!(_mm256_castps_si256, y_hi_ps)
    };

    // All four pairwise multiplications
    let mul_ll = map3!(_mm256_mul_epu32, x, y);
    let mul_lh = map3!(_mm256_mul_epu32, x, y_hi);
    let mul_hl = map3!(_mm256_mul_epu32, x_hi, y);
    let mul_hh = map3!(_mm256_mul_epu32, x_hi, y_hi);

    // Bignum addition
    // Extract high 32 bits of mul_ll and add to mul_hl. This cannot overflow.
    let mul_ll_hi = map3!(_mm256_srli_epi64::<32>, mul_ll);
    let t0 = map3!(_mm256_add_epi64, mul_hl, mul_ll_hi);
    // Extract low 32 bits of t0 and add to mul_lh. Again, this cannot overflow.
    // Also, extract high 32 bits of t0 and add to mul_hh.
    let t0_lo = map3!(_mm256_and_si256, t0, rep epsilon);
    let t0_hi = map3!(_mm256_srli_epi64::<32>, t0);
    let t1 = map3!(_mm256_add_epi64, mul_lh, t0_lo);
    let t2 = map3!(_mm256_add_epi64, mul_hh, t0_hi);
    // Lastly, extract the high 32 bits of t1 and add to t2.
    let t1_hi = map3!(_mm256_srli_epi64::<32>, t1);
    let res_hi = map3!(_mm256_add_epi64, t2, t1_hi);

    // Form res_lo by combining the low half of mul_ll with the low half of t1 (shifted into high
    // position).
    let t1_lo = {
        let t1_ps = map3!(_mm256_castsi256_ps, t1);
        let t1_lo_ps = map3!(_mm256_moveldup_ps, t1_ps);
        map3!(_mm256_castps_si256, t1_lo_ps)
    };
    let res_lo = map3!(_mm256_blend_epi32::<0xaa>, mul_ll, t1_lo);

    (res_lo, res_hi)
}

/// Addition, where the second operand is `0 <= y < 0xffffffff00000001`.
#[inline(always)]
unsafe fn add_small(
    x_s: (__m256i, __m256i, __m256i),
    y: (__m256i, __m256i, __m256i),
) -> (__m256i, __m256i, __m256i) {
    let res_wrapped_s = map3!(_mm256_add_epi64, x_s, y);
    let mask = map3!(_mm256_cmpgt_epi32, x_s, res_wrapped_s);
    let wrapback_amt = map3!(_mm256_srli_epi64::<32>, mask); // EPSILON if overflowed else 0.
    let res_s = map3!(_mm256_add_epi64, res_wrapped_s, wrapback_amt);
    res_s
}

#[inline(always)]
unsafe fn maybe_adj_sub(res_wrapped_s: __m256i, mask: __m256i) -> __m256i {
    // The subtraction is very unlikely to overflow so we're best off branching.
    // The even u32s in `mask` are meaningless, so we want to ignore them. `_mm256_testz_pd`
    // branches depending on the sign bit of double-precision (64-bit) floats. Bit cast `mask` to
    // floating-point (this is free).
    let mask_pd = _mm256_castsi256_pd(mask);
    // `_mm256_testz_pd(mask_pd, mask_pd) == 1` iff all sign bits are 0, meaning that underflow
    // did not occur for any of the vector elements.
    if _mm256_testz_pd(mask_pd, mask_pd) == 1 {
        res_wrapped_s
    } else {
        branch_hint();
        // Highly unlikely: underflow did occur. Find adjustment per element and apply it.
        let adj_amount = _mm256_srli_epi64::<32>(mask); // EPSILON if underflow.
        _mm256_sub_epi64(res_wrapped_s, adj_amount)
    }
}

/// Subtraction, where the second operand is much smaller than `0xffffffff00000001`.
#[inline(always)]
unsafe fn sub_tiny(
    x_s: (__m256i, __m256i, __m256i),
    y: (__m256i, __m256i, __m256i),
) -> (__m256i, __m256i, __m256i) {
    let res_wrapped_s = map3!(_mm256_sub_epi64, x_s, y);
    let mask = map3!(_mm256_cmpgt_epi32, res_wrapped_s, x_s);
    let res_s = map3!(maybe_adj_sub, res_wrapped_s, mask);
    res_s
}

#[inline(always)]
unsafe fn reduce3(
    (lo0, hi0): ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)),
) -> (__m256i, __m256i, __m256i) {
    let sign_bit = _mm256_set1_epi64x(i64::MIN);
    let epsilon = _mm256_set1_epi64x(0xffffffff);
    let lo0_s = map3!(_mm256_xor_si256, lo0, rep sign_bit);
    let hi_hi0 = map3!(_mm256_srli_epi64::<32>, hi0);
    let lo1_s = sub_tiny(lo0_s, hi_hi0);
    let t1 = map3!(_mm256_mul_epu32, hi0, rep epsilon);
    let lo2_s = add_small(lo1_s, t1);
    let lo2 = map3!(_mm256_xor_si256, lo2_s, rep sign_bit);
    lo2
}
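
// Editorial aside: `reduce3` is the vectorized form of the standard Goldilocks
// reduction of a 128-bit product mod p = 2^64 - 2^32 + 1, which uses
// 2^64 ≡ 2^32 - 1 (mod p) and 2^96 ≡ -1 (mod p). A minimal scalar sketch (the name
// `reduce128` is illustrative, not from this file):
#[allow(dead_code)]
fn reduce128(x: u128) -> u64 {
    const EPSILON: u64 = (1 << 32) - 1; // 2^64 mod p
    let x_lo = x as u64;
    let x_hi = (x >> 64) as u64;
    let x_hi_hi = x_hi >> 32; // contributes -x_hi_hi (mod p)
    let x_hi_lo = x_hi & EPSILON; // contributes EPSILON * x_hi_lo (mod p)
    let (mut t0, borrow) = x_lo.overflowing_sub(x_hi_hi);
    if borrow {
        t0 = t0.wrapping_sub(EPSILON); // borrowing 2^64 ≡ borrowing EPSILON
    }
    let t1 = x_hi_lo * EPSILON; // < 2^64, cannot overflow
    let (mut res, carry) = t0.overflowing_add(t1);
    if carry {
        res = res.wrapping_add(EPSILON); // wrapping past 2^64 ≡ adding EPSILON
    }
    res // congruent to x mod p; not necessarily fully reduced below p
}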

#[inline(always)]
unsafe fn mul_reduce(
    a: (__m256i, __m256i, __m256i),
    b: (__m256i, __m256i, __m256i),
) -> (__m256i, __m256i, __m256i) {
    reduce3(mul3(a, b))
}

#[inline(always)]
unsafe fn square_reduce(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
    reduce3(square3(state))
}

#[inline(always)]
unsafe fn exp_acc(
    high: (__m256i, __m256i, __m256i),
    low: (__m256i, __m256i, __m256i),
    exp: usize,
) -> (__m256i, __m256i, __m256i) {
    let mut result = high;
    for _ in 0..exp {
        result = square_reduce(result);
    }
    mul_reduce(result, low)
}

#[inline(always)]
unsafe fn do_apply_sbox(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
    let state2 = square_reduce(state);
    let state4_unreduced = square3(state2);
    let state3_unreduced = mul3(state2, state);
    let state4 = reduce3(state4_unreduced);
    let state3 = reduce3(state3_unreduced);
    let state7_unreduced = mul3(state3, state4);
    let state7 = reduce3(state7_unreduced);
    state7
}

#[inline(always)]
unsafe fn do_apply_inv_sbox(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
    // compute base^10540996611094048183 using 72 multiplications per array element
    // 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111

    // compute base^10
    let t1 = square_reduce(state);

    // compute base^100
    let t2 = square_reduce(t1);

    // compute base^100100
    let t3 = exp_acc(t2, t2, 3);

    // compute base^100100100100
    let t4 = exp_acc(t3, t3, 6);

    // compute base^100100100100100100100100
    let t5 = exp_acc(t4, t4, 12);

    // compute base^100100100100100100100100100100
    let t6 = exp_acc(t5, t3, 6);

    // compute base^1001001001001001001001001001000100100100100100100100100100100
    let t7 = exp_acc(t6, t6, 31);

    // compute base^1001001001001001001001001001000110110110110110110110110110110111
    let a = square_reduce(square_reduce(mul_reduce(square_reduce(t7), t6)));
    let b = mul_reduce(t1, mul_reduce(t2, state));
    mul_reduce(a, b)
}

#[inline(always)]
unsafe fn avx2_load(state: &[u64; 12]) -> (__m256i, __m256i, __m256i) {
    (
        _mm256_loadu_si256((&state[0..4]).as_ptr().cast::<__m256i>()),
        _mm256_loadu_si256((&state[4..8]).as_ptr().cast::<__m256i>()),
        _mm256_loadu_si256((&state[8..12]).as_ptr().cast::<__m256i>()),
    )
}

#[inline(always)]
unsafe fn avx2_store(buf: &mut [u64; 12], state: (__m256i, __m256i, __m256i)) {
    _mm256_storeu_si256((&mut buf[0..4]).as_mut_ptr().cast::<__m256i>(), state.0);
    _mm256_storeu_si256((&mut buf[4..8]).as_mut_ptr().cast::<__m256i>(), state.1);
    _mm256_storeu_si256((&mut buf[8..12]).as_mut_ptr().cast::<__m256i>(), state.2);
}

#[inline(always)]
pub unsafe fn apply_sbox(buffer: &mut [u64; 12]) {
    let mut state = avx2_load(&buffer);
    state = do_apply_sbox(state);
    avx2_store(buffer, state);
}

#[inline(always)]
pub unsafe fn apply_inv_sbox(buffer: &mut [u64; 12]) {
    let mut state = avx2_load(&buffer);
    state = do_apply_inv_sbox(state);
    avx2_store(buffer, state);
}
@@ -11,7 +11,8 @@
/// divisions by 2 and repeated modular reductions. This is because of our explicit choice of
/// an MDS matrix that has small powers of 2 entries in frequency domain.
/// The following implementation has benefited greatly from the discussions and insights of
/// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero.
/// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero and is based on Nabaglo's Plonky2
/// implementation.

// Rescue MDS matrix in frequency domain.
// More precisely, this is the output of the three 4-point (real) FFTs of the first column of
@@ -26,7 +27,7 @@ const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-8, 1, 1];

// We use a split 3 x 4 FFT transform in order to transform our vectors into the frequency domain.
#[inline(always)]
pub(crate) const fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] {
pub const fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] {
    let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] = state;

    let (u0, u1, u2) = fft4_real([s0, s3, s6, s9]);
@@ -156,7 +157,7 @@ const fn block3(x: [i64; 3], y: [i64; 3]) -> [i64; 3] {

#[cfg(test)]
mod tests {
    use super::super::{Felt, Rpo256, MDS, ZERO};
    use super::super::{apply_mds, Felt, MDS, ZERO};
    use proptest::prelude::*;

    const STATE_WIDTH: usize = 12;
@@ -185,7 +186,7 @@ mod tests {
        v2 = v1;

        apply_mds_naive(&mut v1);
        Rpo256::apply_mds(&mut v2);
        apply_mds(&mut v2);

        prop_assert_eq!(v1, v2);
    }
214
src/hash/rescue/mds/mod.rs
Normal file
@@ -0,0 +1,214 @@
use super::{Felt, STATE_WIDTH, ZERO};

mod freq;
pub use freq::mds_multiply_freq;

// MDS MULTIPLICATION
// ================================================================================================

#[inline(always)]
pub fn apply_mds(state: &mut [Felt; STATE_WIDTH]) {
    let mut result = [ZERO; STATE_WIDTH];

    // Using the linearity of the operations we can split the state into a low||high decomposition
    // and operate on each with no overflow and then combine/reduce the result to a field element.
    // The absence of overflow is guaranteed by the fact that the MDS matrix consists of small
    // powers of two in the frequency domain.
    let mut state_l = [0u64; STATE_WIDTH];
    let mut state_h = [0u64; STATE_WIDTH];

    for r in 0..STATE_WIDTH {
        let s = state[r].inner();
        state_h[r] = s >> 32;
        state_l[r] = (s as u32) as u64;
    }

    let state_h = mds_multiply_freq(state_h);
    let state_l = mds_multiply_freq(state_l);

    for r in 0..STATE_WIDTH {
        let s = state_l[r] as u128 + ((state_h[r] as u128) << 32);
        let s_hi = (s >> 64) as u64;
        let s_lo = s as u64;
        let z = (s_hi << 32) - s_hi;
        let (res, over) = s_lo.overflowing_add(z);

        result[r] = Felt::from_mont(res.wrapping_add(0u32.wrapping_sub(over as u32) as u64));
    }
    *state = result;
}
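
// Editorial note: the recombination above is the same Goldilocks reduction trick used
// elsewhere in this diff. With p = 2^64 - 2^32 + 1 and s = s_lo + 2^64 * s_hi, we have
// 2^64 ≡ 2^32 - 1 (mod p), so s ≡ s_lo + s_hi * (2^32 - 1), and
// z = (s_hi << 32) - s_hi computes exactly s_hi * (2^32 - 1). The `over` adjustment
// adds 2^32 - 1 (that is, `0u32.wrapping_sub(1)` zero-extended) whenever the addition
// wraps past 2^64.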

// MDS MATRIX
// ================================================================================================

/// RPO MDS matrix
pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = [
    [
        Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10),
        Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8),
    ],
    [
        Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13),
        Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21),
    ],
    [
        Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26),
        Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22),
    ],
    [
        Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8),
        Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6),
    ],
    [
        Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23),
        Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7),
    ],
    [
        Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7),
        Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9),
    ],
    [
        Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8),
        Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10),
    ],
    [
        Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21),
        Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13),
    ],
    [
        Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22),
        Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26),
    ],
    [
        Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6),
        Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8),
    ],
    [
        Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7),
        Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23),
    ],
    [
        Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9),
        Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7),
    ],
];
348
src/hash/rescue/mod.rs
Normal file
@@ -0,0 +1,348 @@
use super::{
    CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ONE, ZERO,
};
use core::ops::Range;

mod arch;
pub use arch::optimized::{add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox};

mod mds;
use mds::{apply_mds, MDS};

mod rpo;
pub use rpo::{Rpo256, RpoDigest};

mod rpx;
pub use rpx::{Rpx256, RpxDigest};

#[cfg(test)]
mod tests;

// CONSTANTS
// ================================================================================================

/// The number of rounds is set to 7. For the RPO hash function all rounds are uniform. For the
/// RPX hash function, there are 3 different types of rounds.
const NUM_ROUNDS: usize = 7;

/// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
/// the remaining 4 elements are reserved for capacity.
const STATE_WIDTH: usize = 12;

/// The rate portion of the state is located in elements 4 through 11.
const RATE_RANGE: Range<usize> = 4..12;
const RATE_WIDTH: usize = RATE_RANGE.end - RATE_RANGE.start;

const INPUT1_RANGE: Range<usize> = 4..8;
const INPUT2_RANGE: Range<usize> = 8..12;

/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
const CAPACITY_RANGE: Range<usize> = 0..4;

/// The output of the hash function is a digest which consists of 4 field elements or 32 bytes.
///
/// The digest is returned from state elements 4, 5, 6, and 7 (the first four elements of the
/// rate portion).
const DIGEST_RANGE: Range<usize> = 4..8;
const DIGEST_SIZE: usize = DIGEST_RANGE.end - DIGEST_RANGE.start;

/// The number of bytes needed to encode a digest
const DIGEST_BYTES: usize = 32;

/// The number of bytes in each chunk that defines a field element when hashing a sequence of bytes
const BINARY_CHUNK_SIZE: usize = 7;

/// S-Box and Inverse S-Box powers;
///
/// The constants are defined for tests only because the exponentiations in the code are unrolled
/// for efficiency reasons.
#[cfg(test)]
const ALPHA: u64 = 7;
#[cfg(test)]
const INV_ALPHA: u64 = 10540996611094048183;

// SBOX FUNCTION
// ================================================================================================

#[inline(always)]
fn apply_sbox(state: &mut [Felt; STATE_WIDTH]) {
    state[0] = state[0].exp7();
    state[1] = state[1].exp7();
    state[2] = state[2].exp7();
    state[3] = state[3].exp7();
    state[4] = state[4].exp7();
    state[5] = state[5].exp7();
    state[6] = state[6].exp7();
    state[7] = state[7].exp7();
    state[8] = state[8].exp7();
    state[9] = state[9].exp7();
    state[10] = state[10].exp7();
    state[11] = state[11].exp7();
}

// INVERSE SBOX FUNCTION
// ================================================================================================

#[inline(always)]
fn apply_inv_sbox(state: &mut [Felt; STATE_WIDTH]) {
    // compute base^10540996611094048183 using 72 multiplications per array element
    // 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111

    // compute base^10
    let mut t1 = *state;
    t1.iter_mut().for_each(|t| *t = t.square());

    // compute base^100
    let mut t2 = t1;
    t2.iter_mut().for_each(|t| *t = t.square());

    // compute base^100100
    let t3 = exp_acc::<Felt, STATE_WIDTH, 3>(t2, t2);

    // compute base^100100100100
    let t4 = exp_acc::<Felt, STATE_WIDTH, 6>(t3, t3);

    // compute base^100100100100100100100100
    let t5 = exp_acc::<Felt, STATE_WIDTH, 12>(t4, t4);

    // compute base^100100100100100100100100100100
    let t6 = exp_acc::<Felt, STATE_WIDTH, 6>(t5, t3);

    // compute base^1001001001001001001001001001000100100100100100100100100100100
    let t7 = exp_acc::<Felt, STATE_WIDTH, 31>(t6, t6);

    // compute base^1001001001001001001001001001000110110110110110110110110110110111
    for (i, s) in state.iter_mut().enumerate() {
        let a = (t7[i].square() * t6[i]).square().square();
        let b = t1[i] * t2[i] * *s;
        *s = a * b;
    }

    #[inline(always)]
    fn exp_acc<B: StarkField, const N: usize, const M: usize>(
        base: [B; N],
        tail: [B; N],
    ) -> [B; N] {
        let mut result = base;
        for _ in 0..M {
            result.iter_mut().for_each(|r| *r = r.square());
        }
        result.iter_mut().zip(tail).for_each(|(r, t)| *r *= t);
        result
    }
}
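
// Editorial aside: the addition chain above can be sanity-checked by tracking exponents
// instead of field elements (squaring doubles an exponent, multiplication adds exponents).
// This illustrative test is not part of the original diff.
#[cfg(test)]
#[test]
fn inv_sbox_addition_chain_matches_inv_alpha() {
    // mirrors exp_acc::<_, _, M>(base, tail): M squarings, then one multiplication
    const fn acc(base: u64, tail: u64, m: u32) -> u64 {
        (base << m) + tail
    }
    let t1 = 2u64; // base^0b10
    let t2 = 4u64; // base^0b100
    let t3 = acc(t2, t2, 3); // base^0b100100
    let t4 = acc(t3, t3, 6);
    let t5 = acc(t4, t4, 12);
    let t6 = acc(t5, t3, 6);
    let t7 = acc(t6, t6, 31);
    let a = (2 * t7 + t6) << 2; // exponent of ((t7^2 * t6)^2)^2
    let b = t1 + t2 + 1; // exponent of t1 * t2 * base
    assert_eq!(a + b, INV_ALPHA);
}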

#[inline(always)]
fn add_constants(state: &mut [Felt; STATE_WIDTH], ark: &[Felt; STATE_WIDTH]) {
    state.iter_mut().zip(ark).for_each(|(s, &k)| *s += k);
}

// ROUND CONSTANTS
// ================================================================================================

/// Rescue round constants;
/// computed as in [specifications](https://github.com/ASDiscreteMathematics/rpo)
///
/// The constants are broken up into two arrays ARK1 and ARK2; ARK1 contains the constants for the
/// first half of RPO round, and ARK2 contains constants for the second half of RPO round.
const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
    [
        Felt::new(5789762306288267392),
        Felt::new(6522564764413701783),
        Felt::new(17809893479458208203),
        Felt::new(107145243989736508),
        Felt::new(6388978042437517382),
        Felt::new(15844067734406016715),
        Felt::new(9975000513555218239),
        Felt::new(3344984123768313364),
        Felt::new(9959189626657347191),
        Felt::new(12960773468763563665),
        Felt::new(9602914297752488475),
        Felt::new(16657542370200465908),
    ],
    [
        Felt::new(12987190162843096997),
        Felt::new(653957632802705281),
        Felt::new(4441654670647621225),
        Felt::new(4038207883745915761),
        Felt::new(5613464648874830118),
        Felt::new(13222989726778338773),
        Felt::new(3037761201230264149),
        Felt::new(16683759727265180203),
        Felt::new(8337364536491240715),
        Felt::new(3227397518293416448),
        Felt::new(8110510111539674682),
        Felt::new(2872078294163232137),
    ],
    [
        Felt::new(18072785500942327487),
        Felt::new(6200974112677013481),
        Felt::new(17682092219085884187),
        Felt::new(10599526828986756440),
        Felt::new(975003873302957338),
        Felt::new(8264241093196931281),
        Felt::new(10065763900435475170),
        Felt::new(2181131744534710197),
        Felt::new(6317303992309418647),
        Felt::new(1401440938888741532),
        Felt::new(8884468225181997494),
        Felt::new(13066900325715521532),
    ],
    [
        Felt::new(5674685213610121970),
        Felt::new(5759084860419474071),
        Felt::new(13943282657648897737),
        Felt::new(1352748651966375394),
        Felt::new(17110913224029905221),
        Felt::new(1003883795902368422),
        Felt::new(4141870621881018291),
        Felt::new(8121410972417424656),
        Felt::new(14300518605864919529),
        Felt::new(13712227150607670181),
        Felt::new(17021852944633065291),
        Felt::new(6252096473787587650),
    ],
    [
        Felt::new(4887609836208846458),
        Felt::new(3027115137917284492),
        Felt::new(9595098600469470675),
        Felt::new(10528569829048484079),
        Felt::new(7864689113198939815),
        Felt::new(17533723827845969040),
        Felt::new(5781638039037710951),
        Felt::new(17024078752430719006),
        Felt::new(109659393484013511),
        Felt::new(7158933660534805869),
        Felt::new(2955076958026921730),
        Felt::new(7433723648458773977),
    ],
    [
        Felt::new(16308865189192447297),
        Felt::new(11977192855656444890),
        Felt::new(12532242556065780287),
        Felt::new(14594890931430968898),
        Felt::new(7291784239689209784),
        Felt::new(5514718540551361949),
        Felt::new(10025733853830934803),
        Felt::new(7293794580341021693),
        Felt::new(6728552937464861756),
        Felt::new(6332385040983343262),
        Felt::new(13277683694236792804),
        Felt::new(2600778905124452676),
    ],
    [
        Felt::new(7123075680859040534),
        Felt::new(1034205548717903090),
        Felt::new(7717824418247931797),
        Felt::new(3019070937878604058),
        Felt::new(11403792746066867460),
        Felt::new(10280580802233112374),
        Felt::new(337153209462421218),
        Felt::new(13333398568519923717),
        Felt::new(3596153696935337464),
        Felt::new(8104208463525993784),
        Felt::new(14345062289456085693),
        Felt::new(17036731477169661256),
    ],
];

const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
    [
        Felt::new(6077062762357204287),
        Felt::new(15277620170502011191),
        Felt::new(5358738125714196705),
        Felt::new(14233283787297595718),
        Felt::new(13792579614346651365),
        Felt::new(11614812331536767105),
        Felt::new(14871063686742261166),
        Felt::new(10148237148793043499),
        Felt::new(4457428952329675767),
        Felt::new(15590786458219172475),
        Felt::new(10063319113072092615),
        Felt::new(14200078843431360086),
    ],
    [
        Felt::new(6202948458916099932),
        Felt::new(17690140365333231091),
        Felt::new(3595001575307484651),
        Felt::new(373995945117666487),
        Felt::new(1235734395091296013),
        Felt::new(14172757457833931602),
        Felt::new(707573103686350224),
        Felt::new(15453217512188187135),
        Felt::new(219777875004506018),
        Felt::new(17876696346199469008),
        Felt::new(17731621626449383378),
        Felt::new(2897136237748376248),
    ],
    [
        Felt::new(8023374565629191455),
        Felt::new(15013690343205953430),
        Felt::new(4485500052507912973),
        Felt::new(12489737547229155153),
        Felt::new(9500452585969030576),
        Felt::new(2054001340201038870),
        Felt::new(12420704059284934186),
        Felt::new(355990932618543755),
        Felt::new(9071225051243523860),
        Felt::new(12766199826003448536),
        Felt::new(9045979173463556963),
        Felt::new(12934431667190679898),
    ],
    [
        Felt::new(18389244934624494276),
        Felt::new(16731736864863925227),
        Felt::new(4440209734760478192),
        Felt::new(17208448209698888938),
        Felt::new(8739495587021565984),
        Felt::new(17000774922218161967),
        Felt::new(13533282547195532087),
        Felt::new(525402848358706231),
        Felt::new(16987541523062161972),
        Felt::new(5466806524462797102),
        Felt::new(14512769585918244983),
        Felt::new(10973956031244051118),
    ],
    [
        Felt::new(6982293561042362913),
        Felt::new(14065426295947720331),
        Felt::new(16451845770444974180),
        Felt::new(7139138592091306727),
        Felt::new(9012006439959783127),
        Felt::new(14619614108529063361),
        Felt::new(1394813199588124371),
        Felt::new(4635111139507788575),
        Felt::new(16217473952264203365),
        Felt::new(10782018226466330683),
        Felt::new(6844229992533662050),
        Felt::new(7446486531695178711),
    ],
    [
        Felt::new(3736792340494631448),
        Felt::new(577852220195055341),
        Felt::new(6689998335515779805),
        Felt::new(13886063479078013492),
        Felt::new(14358505101923202168),
        Felt::new(7744142531772274164),
        Felt::new(16135070735728404443),
        Felt::new(12290902521256031137),
        Felt::new(12059913662657709804),
        Felt::new(16456018495793751911),
        Felt::new(4571485474751953524),
        Felt::new(17200392109565783176),
    ],
    [
        Felt::new(17130398059294018733),
        Felt::new(519782857322261988),
        Felt::new(9625384390925085478),
        Felt::new(1664893052631119222),
        Felt::new(7629576092524553570),
        Felt::new(3485239601103661425),
        Felt::new(9755891797164033838),
        Felt::new(15218148195153269027),
        Felt::new(16460604813734957368),
        Felt::new(9643968136937729763),
        Felt::new(3611348709641382851),
        Felt::new(18256379591337759196),
    ],
];
@@ -1,4 +1,4 @@
use super::{Digest, Felt, StarkField, DIGEST_SIZE, ZERO};
use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO};
use crate::utils::{
    bytes_to_hex_string, hex_to_bytes, string::String, ByteReader, ByteWriter, Deserializable,
    DeserializationError, HexParseError, Serializable,
@@ -6,9 +6,6 @@ use crate::utils::{
use core::{cmp::Ordering, fmt::Display, ops::Deref};
use winter_utils::Randomizable;

/// The number of bytes needed to encoded a digest
pub const DIGEST_BYTES: usize = 32;

// DIGEST TRAIT IMPLEMENTATIONS
// ================================================================================================

@@ -172,9 +169,21 @@ impl From<&RpoDigest> for String {
    }
}

// CONVERSIONS: TO DIGEST
// CONVERSIONS: TO RPO DIGEST
// ================================================================================================

#[derive(Copy, Clone, Debug)]
pub enum RpoDigestError {
    /// The provided u64 integer does not fit in the field's modulus.
    InvalidInteger,
}

impl From<&[Felt; DIGEST_SIZE]> for RpoDigest {
    fn from(value: &[Felt; DIGEST_SIZE]) -> Self {
        Self(*value)
    }
}

impl From<[Felt; DIGEST_SIZE]> for RpoDigest {
    fn from(value: [Felt; DIGEST_SIZE]) -> Self {
        Self(value)
@@ -200,6 +209,46 @@ impl TryFrom<[u8; DIGEST_BYTES]> for RpoDigest {
    }
}

impl TryFrom<&[u8; DIGEST_BYTES]> for RpoDigest {
    type Error = HexParseError;

    fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<&[u8]> for RpoDigest {
    type Error = HexParseError;

    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<[u64; DIGEST_SIZE]> for RpoDigest {
    type Error = RpoDigestError;

    fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
        if value[0] >= Felt::MODULUS
            || value[1] >= Felt::MODULUS
            || value[2] >= Felt::MODULUS
            || value[3] >= Felt::MODULUS
        {
            return Err(RpoDigestError::InvalidInteger);
        }

        Ok(Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()]))
    }
}

impl TryFrom<&[u64; DIGEST_SIZE]> for RpoDigest {
    type Error = RpoDigestError;

    fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
        (*value).try_into()
    }
}
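
// Editorial example (illustrative, not part of the diff): with the conversions above,
//
//     let d = RpoDigest::try_from([1u64, 2, 3, 4]).unwrap();
//
// succeeds, while any component >= Felt::MODULUS yields
// `Err(RpoDigestError::InvalidInteger)`.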

impl TryFrom<&str> for RpoDigest {
    type Error = HexParseError;

@@ -253,13 +302,24 @@ impl Deserializable for RpoDigest {
    }
}

// ITERATORS
// ================================================================================================
impl IntoIterator for RpoDigest {
    type Item = Felt;
    type IntoIter = <[Felt; 4] as IntoIterator>::IntoIter;

    fn into_iter(self) -> Self::IntoIter {
        self.0.into_iter()
    }
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::{Deserializable, Felt, RpoDigest, Serializable, DIGEST_BYTES};
    use super::{Deserializable, Felt, RpoDigest, Serializable, DIGEST_BYTES, DIGEST_SIZE};
    use crate::utils::SliceReader;
    use crate::utils::{string::String, SliceReader};
    use rand_utils::rand_value;

    #[test]
@@ -281,7 +341,6 @@ mod tests {
        assert_eq!(d1, d2);
    }

    #[cfg(feature = "std")]
    #[test]
    fn digest_encoding() {
        let digest = RpoDigest([
@@ -296,4 +355,54 @@ mod tests {

        assert_eq!(digest, round_trip);
    }

    #[test]
    fn test_conversions() {
        let digest = RpoDigest([
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
        ]);

        let v: [Felt; DIGEST_SIZE] = digest.into();
        let v2: RpoDigest = v.into();
        assert_eq!(digest, v2);

        let v: [Felt; DIGEST_SIZE] = (&digest).into();
        let v2: RpoDigest = v.into();
        assert_eq!(digest, v2);

        let v: [u64; DIGEST_SIZE] = digest.into();
        let v2: RpoDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u64; DIGEST_SIZE] = (&digest).into();
        let v2: RpoDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = digest.into();
        let v2: RpoDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = (&digest).into();
        let v2: RpoDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: String = digest.into();
        let v2: RpoDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: String = (&digest).into();
        let v2: RpoDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = digest.into();
        let v2: RpoDigest = (&v).try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = (&digest).into();
        let v2: RpoDigest = (&v).try_into().unwrap();
        assert_eq!(digest, v2);
    }
}
323
src/hash/rescue/rpo/mod.rs
Normal file
@@ -0,0 +1,323 @@
use super::{
    add_constants, add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox, apply_inv_sbox,
    apply_mds, apply_sbox, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ARK1,
    ARK2, BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE, DIGEST_SIZE, INPUT1_RANGE,
    INPUT2_RANGE, MDS, NUM_ROUNDS, ONE, RATE_RANGE, RATE_WIDTH, STATE_WIDTH, ZERO,
};
use core::{convert::TryInto, ops::Range};

mod digest;
pub use digest::RpoDigest;

#[cfg(test)]
mod tests;

// HASHER IMPLEMENTATION
// ================================================================================================

/// Implementation of the Rescue Prime Optimized hash function with 256-bit output.
///
/// The hash function is implemented according to the Rescue Prime Optimized
/// [specifications](https://eprint.iacr.org/2022/1577)
///
/// The parameters used to instantiate the function are:
/// * Field: 64-bit prime field with modulus 2^64 - 2^32 + 1.
/// * State width: 12 field elements.
/// * Capacity size: 4 field elements.
/// * Number of rounds: 7.
/// * S-Box degree: 7.
///
/// The above parameters target a 128-bit security level. The digest consists of four field
/// elements and it can be serialized into 32 bytes (256 bits).
///
/// ## Hash output consistency
/// Functions [hash_elements()](Rpo256::hash_elements), [merge()](Rpo256::merge), and
/// [merge_with_int()](Rpo256::merge_with_int) are internally consistent. That is, computing
/// a hash for the same set of elements using these functions will always produce the same
/// result. For example, merging two digests using [merge()](Rpo256::merge) will produce the
/// same result as hashing 8 elements which make up these digests using the
/// [hash_elements()](Rpo256::hash_elements) function.
///
/// However, the [hash()](Rpo256::hash) function is not consistent with the functions mentioned
/// above. For example, if we take two field elements, serialize them to bytes, and hash them
/// using [hash()](Rpo256::hash), the result will differ from the result obtained by hashing
/// these elements directly using the [hash_elements()](Rpo256::hash_elements) function. The
/// reason for this difference is that the [hash()](Rpo256::hash) function needs to be able to
/// handle arbitrary binary strings, which may or may not encode valid field elements - and thus,
/// the deserialization procedure used by this function is different from the procedure used to
/// deserialize valid field elements.
///
/// Thus, if the underlying data consists of valid field elements, it might make more sense
/// to deserialize them into field elements and then hash them using the
/// [hash_elements()](Rpo256::hash_elements) function rather than hashing the serialized bytes
/// using the [hash()](Rpo256::hash) function.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Rpo256();

impl Hasher for Rpo256 {
    /// Rpo256 collision resistance is the same as the security level, that is 128-bits.
    ///
    /// #### Collision resistance
    ///
    /// However, our setup of the capacity registers might drop it to 126.
    ///
    /// Related issue: [#69](https://github.com/0xPolygonMiden/crypto/issues/69)
    const COLLISION_RESISTANCE: u32 = 128;

    type Digest = RpoDigest;

    fn hash(bytes: &[u8]) -> Self::Digest {
        // initialize the state with zeroes
        let mut state = [ZERO; STATE_WIDTH];

        // set the capacity (first element) to a flag on whether or not the input length is evenly
        // divided by the rate. this will prevent collisions between padded and non-padded inputs,
        // and will rule out the need to perform an extra permutation in case of evenly divided
        // inputs.
        let is_rate_multiple = bytes.len() % RATE_WIDTH == 0;
        if !is_rate_multiple {
            state[CAPACITY_RANGE.start] = ONE;
        }

        // initialize a buffer to receive the little-endian elements.
        let mut buf = [0_u8; 8];

        // iterate the chunks of bytes, creating a field element from each chunk and copying it
        // into the state.
        //
        // every time the rate range is filled, a permutation is performed. if the final value of
        // `i` is not zero, then the chunks count wasn't enough to fill the state range, and an
        // additional permutation must be performed.
        let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| {
            // the last element of the iteration may or may not be a full chunk. if it's not, then
            // we need to pad the remainder bytes of the chunk with zeroes, separated by a `1`.
            // this will avoid collisions.
            if chunk.len() == BINARY_CHUNK_SIZE {
                buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
            } else {
                buf.fill(0);
                buf[..chunk.len()].copy_from_slice(chunk);
                buf[chunk.len()] = 1;
            }
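            // Editorial example: a final 3-byte chunk [a, b, c] is absorbed as
            // u64::from_le_bytes([a, b, c, 1, 0, 0, 0, 0]), so a short chunk can never
            // collide with a full 7-byte chunk ending in the same bytes.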

            // set the current rate element to the input. since we take at most 7 bytes, we are
            // guaranteed that the input data will fit into a single field element.
            state[RATE_RANGE.start + i] = Felt::new(u64::from_le_bytes(buf));

            // proceed filling the range. if it's full, then we apply a permutation and reset the
            // counter to the beginning of the range.
            if i == RATE_WIDTH - 1 {
                Self::apply_permutation(&mut state);
                0
            } else {
                i + 1
            }
        });

        // if we absorbed some elements but didn't apply a permutation to them (would happen when
        // the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation. we
        // don't need to apply any extra padding because the first capacity element contains a
        // flag indicating whether the input is evenly divisible by the rate.
        if i != 0 {
            state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO);
            state[RATE_RANGE.start + i] = ONE;
            Self::apply_permutation(&mut state);
        }

        // return the first 4 elements of the rate as hash result.
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
        // initialize the state by copying the digest elements into the rate portion of the state
        // (8 total elements), and set the capacity elements to 0.
        let mut state = [ZERO; STATE_WIDTH];
        let it = Self::Digest::digests_as_elements(values.iter());
        for (i, v) in it.enumerate() {
            state[RATE_RANGE.start + i] = *v;
        }

        // apply the RPO permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
        // initialize the state as follows:
        // - seed is copied into the first 4 elements of the rate portion of the state.
        // - if the value fits into a single field element, copy it into the fifth rate element
        //   and set the sixth rate element to 1.
        // - if the value doesn't fit into a single field element, split it into two field
        //   elements, copy them into rate elements 5 and 6, and set the seventh rate element
        //   to 1.
        // - set the first capacity element to 1
        let mut state = [ZERO; STATE_WIDTH];
        state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
        state[INPUT2_RANGE.start] = Felt::new(value);
        if value < Felt::MODULUS {
            state[INPUT2_RANGE.start + 1] = ONE;
        } else {
            state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
            state[INPUT2_RANGE.start + 2] = ONE;
        }

        // common padding for both cases
        state[CAPACITY_RANGE.start] = ONE;

        // apply the RPO permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
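
    // Editorial example: merge_with_int(seed, 5) absorbs the rate as
    // [seed[0], seed[1], seed[2], seed[3], 5, 1, 0, 0], while a value >= Felt::MODULUS
    // additionally carries value / Felt::MODULUS in the sixth rate element and moves
    // the 1 marker to the seventh.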
}

impl ElementHasher for Rpo256 {
    type BaseField = Felt;

    fn hash_elements<E: FieldElement<BaseField = Self::BaseField>>(elements: &[E]) -> Self::Digest {
        // convert the elements into a list of base field elements
        let elements = E::slice_as_base_elements(elements);

        // initialize state to all zeros, except for the first element of the capacity part, which
        // is set to 1 if the number of elements is not a multiple of RATE_WIDTH.
        let mut state = [ZERO; STATE_WIDTH];
        if elements.len() % RATE_WIDTH != 0 {
            state[CAPACITY_RANGE.start] = ONE;
        }

        // absorb elements into the state one by one until the rate portion of the state is filled
        // up; then apply the Rescue permutation and start absorbing again; repeat until all
        // elements have been absorbed
        let mut i = 0;
        for &element in elements.iter() {
            state[RATE_RANGE.start + i] = element;
            i += 1;
            if i % RATE_WIDTH == 0 {
                Self::apply_permutation(&mut state);
                i = 0;
            }
        }

        // if we absorbed some elements but didn't apply a permutation to them (would happen when
        // the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation after
        // padding by appending a 1 followed by as many 0 as necessary to make the input length a
        // multiple of the RATE_WIDTH.
        if i > 0 {
            state[RATE_RANGE.start + i] = ONE;
            i += 1;
            while i != RATE_WIDTH {
                state[RATE_RANGE.start + i] = ZERO;
                i += 1;
            }
            Self::apply_permutation(&mut state);
        }

        // return the first 4 elements of the state as hash result
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
}

// HASH FUNCTION IMPLEMENTATION
// ================================================================================================

impl Rpo256 {
    // CONSTANTS
    // --------------------------------------------------------------------------------------------

    /// The number of rounds is set to 7 to target a 128-bit security level.
    pub const NUM_ROUNDS: usize = NUM_ROUNDS;

    /// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
    /// the remaining 4 elements are reserved for capacity.
    pub const STATE_WIDTH: usize = STATE_WIDTH;

    /// The rate portion of the state is located in elements 4 through 11 (inclusive).
    pub const RATE_RANGE: Range<usize> = RATE_RANGE;

    /// The capacity portion of the state is located in elements 0, 1, 2, and 3.
    pub const CAPACITY_RANGE: Range<usize> = CAPACITY_RANGE;

    /// The output of the hash function can be read from state elements 4, 5, 6, and 7.
    pub const DIGEST_RANGE: Range<usize> = DIGEST_RANGE;

    /// MDS matrix used for computing the linear layer in a RPO round.
    pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS;

    /// Round constants added to the hasher state in the first half of the RPO round.
    pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1;

    /// Round constants added to the hasher state in the second half of the RPO round.
    pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2;

    // TRAIT PASS-THROUGH FUNCTIONS
    // --------------------------------------------------------------------------------------------

    /// Returns a hash of the provided sequence of bytes.
    #[inline(always)]
    pub fn hash(bytes: &[u8]) -> RpoDigest {
        <Self as Hasher>::hash(bytes)
    }

    /// Returns a hash of two digests. This method is intended for use in construction of
    /// Merkle trees and verification of Merkle paths.
    #[inline(always)]
    pub fn merge(values: &[RpoDigest; 2]) -> RpoDigest {
        <Self as Hasher>::merge(values)
    }

    /// Returns a hash of the provided field elements.
    #[inline(always)]
    pub fn hash_elements<E: FieldElement<BaseField = Felt>>(elements: &[E]) -> RpoDigest {
        <Self as ElementHasher>::hash_elements(elements)
    }

    // DOMAIN IDENTIFIER
    // --------------------------------------------------------------------------------------------

    /// Returns a hash of two digests and a domain identifier.
    pub fn merge_in_domain(values: &[RpoDigest; 2], domain: Felt) -> RpoDigest {
        // initialize the state by copying the digest elements into the rate portion of the state
        // (8 total elements), and set the capacity elements to 0.
        let mut state = [ZERO; STATE_WIDTH];
        let it = RpoDigest::digests_as_elements(values.iter());
        for (i, v) in it.enumerate() {
            state[RATE_RANGE.start + i] = *v;
        }

        // set the second capacity element to the domain value. The first capacity element is used
        // for padding purposes.
        state[CAPACITY_RANGE.start + 1] = domain;

        // apply the RPO permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    // RESCUE PERMUTATION
    // --------------------------------------------------------------------------------------------

    /// Applies RPO permutation to the provided state.
    #[inline(always)]
    pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) {
        for i in 0..NUM_ROUNDS {
            Self::apply_round(state, i);
        }
    }

    /// RPO round function.
    #[inline(always)]
    pub fn apply_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
        // apply first half of RPO round; if the optimized (vectorized) routine reports that it
        // is not available on this target, fall back to the scalar path.
        apply_mds(state);
        if !add_constants_and_apply_sbox(state, &ARK1[round]) {
            add_constants(state, &ARK1[round]);
            apply_sbox(state);
        }

        // apply second half of RPO round
        apply_mds(state);
        if !add_constants_and_apply_inv_sbox(state, &ARK2[round]) {
            add_constants(state, &ARK2[round]);
            apply_inv_sbox(state);
        }
    }
}
|
||||||
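A minimal usage sketch (not part of the diff) of the domain-separation API above: `merge_in_domain` injects the domain value into the second capacity element before the permutation, so the same pair of digests merged under different domains yields different results, while a zero domain coincides with a plain `merge`. The import paths are assumptions: they presume the crate exposes these types as `miden_crypto::hash::rpo::Rpo256` and `miden_crypto::Felt`.

use miden_crypto::{hash::rpo::Rpo256, Felt}; // assumed re-export paths

fn main() {
    let left = Rpo256::hash(b"left child");
    let right = Rpo256::hash(b"right child");

    // the domain value lands in capacity element 1, so different domains
    // produce different digests for the same inputs
    let d1 = Rpo256::merge_in_domain(&[left, right], Felt::new(1));
    let d2 = Rpo256::merge_in_domain(&[left, right], Felt::new(2));
    assert_ne!(d1, d2);

    // a zero domain leaves the capacity all-zero, matching a plain merge
    let d0 = Rpo256::merge_in_domain(&[left, right], Felt::new(0));
    assert_eq!(d0, Rpo256::merge(&[left, right]));
}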
@@ -1,6 +1,6 @@
 use super::{
-    Felt, FieldElement, Hasher, Rpo256, RpoDigest, StarkField, ALPHA, INV_ALPHA, ONE, STATE_WIDTH,
-    ZERO,
+    super::{apply_inv_sbox, apply_sbox, ALPHA, INV_ALPHA},
+    Felt, FieldElement, Hasher, Rpo256, RpoDigest, StarkField, ONE, STATE_WIDTH, ZERO,
 };
 use crate::{
     utils::collections::{BTreeSet, Vec},
@@ -10,13 +10,6 @@ use core::convert::TryInto;
 use proptest::prelude::*;
 use rand_utils::rand_value;
 
-#[test]
-fn test_alphas() {
-    let e: Felt = Felt::new(rand_value());
-    let e_exp = e.exp(ALPHA);
-    assert_eq!(e, e_exp.exp(INV_ALPHA));
-}
-
 #[test]
 fn test_sbox() {
     let state = [Felt::new(rand_value()); STATE_WIDTH];
@@ -25,7 +18,7 @@ fn test_sbox() {
     expected.iter_mut().for_each(|v| *v = v.exp(ALPHA));
 
     let mut actual = state;
-    Rpo256::apply_sbox(&mut actual);
+    apply_sbox(&mut actual);
 
     assert_eq!(expected, actual);
 }
@@ -38,7 +31,7 @@ fn test_inv_sbox() {
     expected.iter_mut().for_each(|v| *v = v.exp(INV_ALPHA));
 
     let mut actual = state;
-    Rpo256::apply_inv_sbox(&mut actual);
+    apply_inv_sbox(&mut actual);
 
     assert_eq!(expected, actual);
 }
398  src/hash/rescue/rpx/digest.rs  Normal file
@@ -0,0 +1,398 @@
use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO};
use crate::utils::{
    bytes_to_hex_string, hex_to_bytes, string::String, ByteReader, ByteWriter, Deserializable,
    DeserializationError, HexParseError, Serializable,
};
use core::{cmp::Ordering, fmt::Display, ops::Deref};
use winter_utils::Randomizable;

// DIGEST TRAIT IMPLEMENTATIONS
// ================================================================================================

#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))]
pub struct RpxDigest([Felt; DIGEST_SIZE]);

impl RpxDigest {
    pub const fn new(value: [Felt; DIGEST_SIZE]) -> Self {
        Self(value)
    }

    pub fn as_elements(&self) -> &[Felt] {
        self.as_ref()
    }

    pub fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
        <Self as Digest>::as_bytes(self)
    }

    pub fn digests_as_elements<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt>
    where
        I: Iterator<Item = &'a Self>,
    {
        digests.flat_map(|d| d.0.iter())
    }
}

impl Digest for RpxDigest {
    fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
        let mut result = [0; DIGEST_BYTES];

        result[..8].copy_from_slice(&self.0[0].as_int().to_le_bytes());
        result[8..16].copy_from_slice(&self.0[1].as_int().to_le_bytes());
        result[16..24].copy_from_slice(&self.0[2].as_int().to_le_bytes());
        result[24..].copy_from_slice(&self.0[3].as_int().to_le_bytes());

        result
    }
}

impl Deref for RpxDigest {
    type Target = [Felt; DIGEST_SIZE];

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl Ord for RpxDigest {
    fn cmp(&self, other: &Self) -> Ordering {
        // compare the inner u64 of both elements.
        //
        // it will iterate the elements and will return the first comparison different than
        // `Equal`. Otherwise, the ordering is equal.
        //
        // the endianness is irrelevant here because, this being a cryptographically secure
        // hash computation, the digest shouldn't preserve any ordered property of its input.
        //
        // finally, we use `Felt::inner` instead of `Felt::as_int` so we avoid performing a
        // Montgomery reduction for every limb. that is safe because every inner element of the
        // digest is guaranteed to be in its canonical form (that is, `x in [0,p)`).
        self.0.iter().map(Felt::inner).zip(other.0.iter().map(Felt::inner)).fold(
            Ordering::Equal,
            |ord, (a, b)| match ord {
                Ordering::Equal => a.cmp(&b),
                _ => ord,
            },
        )
    }
}

impl PartialOrd for RpxDigest {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Display for RpxDigest {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let encoded: String = self.into();
        write!(f, "{}", encoded)?;
        Ok(())
    }
}

impl Randomizable for RpxDigest {
    const VALUE_SIZE: usize = DIGEST_BYTES;

    fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
        let bytes_array: Option<[u8; 32]> = bytes.try_into().ok();
        if let Some(bytes_array) = bytes_array {
            Self::try_from(bytes_array).ok()
        } else {
            None
        }
    }
}

// CONVERSIONS: FROM RPX DIGEST
// ================================================================================================

impl From<&RpxDigest> for [Felt; DIGEST_SIZE] {
    fn from(value: &RpxDigest) -> Self {
        value.0
    }
}

impl From<RpxDigest> for [Felt; DIGEST_SIZE] {
    fn from(value: RpxDigest) -> Self {
        value.0
    }
}

impl From<&RpxDigest> for [u64; DIGEST_SIZE] {
    fn from(value: &RpxDigest) -> Self {
        [
            value.0[0].as_int(),
            value.0[1].as_int(),
            value.0[2].as_int(),
            value.0[3].as_int(),
        ]
    }
}

impl From<RpxDigest> for [u64; DIGEST_SIZE] {
    fn from(value: RpxDigest) -> Self {
        [
            value.0[0].as_int(),
            value.0[1].as_int(),
            value.0[2].as_int(),
            value.0[3].as_int(),
        ]
    }
}

impl From<&RpxDigest> for [u8; DIGEST_BYTES] {
    fn from(value: &RpxDigest) -> Self {
        value.as_bytes()
    }
}

impl From<RpxDigest> for [u8; DIGEST_BYTES] {
    fn from(value: RpxDigest) -> Self {
        value.as_bytes()
    }
}

impl From<RpxDigest> for String {
    /// The returned string starts with `0x`.
    fn from(value: RpxDigest) -> Self {
        bytes_to_hex_string(value.as_bytes())
    }
}

impl From<&RpxDigest> for String {
    /// The returned string starts with `0x`.
    fn from(value: &RpxDigest) -> Self {
        (*value).into()
    }
}

// CONVERSIONS: TO RPX DIGEST
// ================================================================================================

#[derive(Copy, Clone, Debug)]
pub enum RpxDigestError {
    /// The provided u64 integer does not fit in the field's modulus.
    InvalidInteger,
}

impl From<&[Felt; DIGEST_SIZE]> for RpxDigest {
    fn from(value: &[Felt; DIGEST_SIZE]) -> Self {
        Self(*value)
    }
}

impl From<[Felt; DIGEST_SIZE]> for RpxDigest {
    fn from(value: [Felt; DIGEST_SIZE]) -> Self {
        Self(value)
    }
}

impl TryFrom<[u8; DIGEST_BYTES]> for RpxDigest {
    type Error = HexParseError;

    fn try_from(value: [u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
        // Note: the input length is known, so the conversions from slice to array below must
        // succeed and the `unwrap`s are safe
        let a = u64::from_le_bytes(value[0..8].try_into().unwrap());
        let b = u64::from_le_bytes(value[8..16].try_into().unwrap());
        let c = u64::from_le_bytes(value[16..24].try_into().unwrap());
        let d = u64::from_le_bytes(value[24..32].try_into().unwrap());

        if [a, b, c, d].iter().any(|v| *v >= Felt::MODULUS) {
            return Err(HexParseError::OutOfRange);
        }

        Ok(RpxDigest([Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)]))
    }
}

impl TryFrom<&[u8; DIGEST_BYTES]> for RpxDigest {
    type Error = HexParseError;

    fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<&[u8]> for RpxDigest {
    type Error = HexParseError;

    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<[u64; DIGEST_SIZE]> for RpxDigest {
    type Error = RpxDigestError;

    fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
        if value[0] >= Felt::MODULUS
            || value[1] >= Felt::MODULUS
            || value[2] >= Felt::MODULUS
            || value[3] >= Felt::MODULUS
        {
            return Err(RpxDigestError::InvalidInteger);
        }

        Ok(Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()]))
    }
}

impl TryFrom<&[u64; DIGEST_SIZE]> for RpxDigest {
    type Error = RpxDigestError;

    fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
        (*value).try_into()
    }
}

impl TryFrom<&str> for RpxDigest {
    type Error = HexParseError;

    /// Expects the string to start with `0x`.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        hex_to_bytes(value).and_then(|v| v.try_into())
    }
}

impl TryFrom<String> for RpxDigest {
    type Error = HexParseError;

    /// Expects the string to start with `0x`.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        value.as_str().try_into()
    }
}

impl TryFrom<&String> for RpxDigest {
    type Error = HexParseError;

    /// Expects the string to start with `0x`.
    fn try_from(value: &String) -> Result<Self, Self::Error> {
        value.as_str().try_into()
    }
}

// SERIALIZATION / DESERIALIZATION
// ================================================================================================

impl Serializable for RpxDigest {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_bytes(&self.as_bytes());
    }
}

impl Deserializable for RpxDigest {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let mut inner: [Felt; DIGEST_SIZE] = [ZERO; DIGEST_SIZE];
        for inner in inner.iter_mut() {
            let e = source.read_u64()?;
            if e >= Felt::MODULUS {
                return Err(DeserializationError::InvalidValue(String::from(
                    "Value not in the appropriate range",
                )));
            }
            *inner = Felt::new(e);
        }

        Ok(Self(inner))
    }
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::{Deserializable, Felt, RpxDigest, Serializable, DIGEST_BYTES, DIGEST_SIZE};
    use crate::utils::{string::String, SliceReader};
    use rand_utils::rand_value;

    #[test]
    fn digest_serialization() {
        let e1 = Felt::new(rand_value());
        let e2 = Felt::new(rand_value());
        let e3 = Felt::new(rand_value());
        let e4 = Felt::new(rand_value());

        let d1 = RpxDigest([e1, e2, e3, e4]);

        let mut bytes = vec![];
        d1.write_into(&mut bytes);
        assert_eq!(DIGEST_BYTES, bytes.len());

        let mut reader = SliceReader::new(&bytes);
        let d2 = RpxDigest::read_from(&mut reader).unwrap();

        assert_eq!(d1, d2);
    }

    #[cfg(feature = "std")]
    #[test]
    fn digest_encoding() {
        let digest = RpxDigest([
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
        ]);

        let string: String = digest.into();
        let round_trip: RpxDigest = string.try_into().expect("decoding failed");

        assert_eq!(digest, round_trip);
    }

    #[test]
    fn test_conversions() {
        let digest = RpxDigest([
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
        ]);

        let v: [Felt; DIGEST_SIZE] = digest.into();
        let v2: RpxDigest = v.into();
        assert_eq!(digest, v2);

        let v: [Felt; DIGEST_SIZE] = (&digest).into();
        let v2: RpxDigest = v.into();
        assert_eq!(digest, v2);

        let v: [u64; DIGEST_SIZE] = digest.into();
        let v2: RpxDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u64; DIGEST_SIZE] = (&digest).into();
        let v2: RpxDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = digest.into();
        let v2: RpxDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = (&digest).into();
        let v2: RpxDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: String = digest.into();
        let v2: RpxDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: String = (&digest).into();
        let v2: RpxDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = digest.into();
        let v2: RpxDigest = (&v).try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = (&digest).into();
        let v2: RpxDigest = (&v).try_into().unwrap();
        assert_eq!(digest, v2);
    }
}
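A minimal sketch (not part of the diff) exercising the conversion surface defined above. The import path is an assumption: it presumes the new type is re-exported as `miden_crypto::hash::rpx::{Rpx256, RpxDigest}`, mirroring the existing `rpo` module.

use miden_crypto::hash::rpx::{Rpx256, RpxDigest}; // assumed re-export path

fn main() {
    let digest: RpxDigest = Rpx256::hash(b"hello world");

    // hex round trip: From<RpxDigest> for String pairs with TryFrom<&str>
    let hex: String = digest.into();
    let parsed = RpxDigest::try_from(hex.as_str()).expect("valid 0x-prefixed hex");
    assert_eq!(digest, parsed);

    // byte round trip: as_bytes() pairs with TryFrom<[u8; 32]>, which rejects
    // limbs that are not below the field modulus
    let bytes = digest.as_bytes();
    let parsed = RpxDigest::try_from(bytes).expect("limbs below the modulus");
    assert_eq!(digest, parsed);
}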
366  src/hash/rescue/rpx/mod.rs  Normal file
@@ -0,0 +1,366 @@
use super::{
    add_constants, add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox, apply_inv_sbox,
    apply_mds, apply_sbox, CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher,
    StarkField, ARK1, ARK2, BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE,
    DIGEST_SIZE, INPUT1_RANGE, INPUT2_RANGE, MDS, NUM_ROUNDS, ONE, RATE_RANGE, RATE_WIDTH,
    STATE_WIDTH, ZERO,
};
use core::{convert::TryInto, ops::Range};

mod digest;
pub use digest::RpxDigest;

pub type CubicExtElement = CubeExtension<Felt>;

// HASHER IMPLEMENTATION
// ================================================================================================

/// Implementation of the Rescue Prime eXtension hash function with 256-bit output.
///
/// The hash function is based on the XHash12 construction described in the
/// [specifications](https://eprint.iacr.org/2023/1045).
///
/// The parameters used to instantiate the function are:
/// * Field: 64-bit prime field with modulus 2^64 - 2^32 + 1.
/// * State width: 12 field elements.
/// * Capacity size: 4 field elements.
/// * S-Box degree: 7.
/// * Rounds: There are 3 different types of rounds:
///   - (FB): `apply_mds` → `add_constants` → `apply_sbox` → `apply_mds` → `add_constants` →
///     `apply_inv_sbox`.
///   - (E): `add_constants` → `ext_sbox` (which is raising to power 7 in the degree 3
///     extension field).
///   - (M): `apply_mds` → `add_constants`.
/// * Permutation: (FB) (E) (FB) (E) (FB) (E) (M).
///
/// The above parameters target a 128-bit security level. The digest consists of four field
/// elements and it can be serialized into 32 bytes (256 bits).
///
/// ## Hash output consistency
/// Functions [hash_elements()](Rpx256::hash_elements), [merge()](Rpx256::merge), and
/// [merge_with_int()](Rpx256::merge_with_int) are internally consistent. That is, computing
/// a hash for the same set of elements using these functions will always produce the same
/// result. For example, merging two digests using [merge()](Rpx256::merge) will produce the
/// same result as hashing the 8 elements which make up these digests using the
/// [hash_elements()](Rpx256::hash_elements) function.
///
/// However, the [hash()](Rpx256::hash) function is not consistent with the functions mentioned
/// above. For example, if we take two field elements, serialize them to bytes, and hash them
/// using [hash()](Rpx256::hash), the result will differ from the result obtained by hashing
/// these elements directly using the [hash_elements()](Rpx256::hash_elements) function. The
/// reason for this difference is that the [hash()](Rpx256::hash) function needs to be able to
/// handle arbitrary binary strings, which may or may not encode valid field elements - and
/// thus, the deserialization procedure used by this function is different from the procedure
/// used to deserialize valid field elements.
///
/// Thus, if the underlying data consists of valid field elements, it might make more sense
/// to deserialize them into field elements and then hash them using the
/// [hash_elements()](Rpx256::hash_elements) function rather than hashing the serialized bytes
/// using the [hash()](Rpx256::hash) function.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Rpx256();

impl Hasher for Rpx256 {
    /// Rpx256 collision resistance is the same as the security level, that is 128-bits.
    ///
    /// #### Collision resistance
    ///
    /// However, our setup of the capacity registers might drop it to 126.
    ///
    /// Related issue: [#69](https://github.com/0xPolygonMiden/crypto/issues/69)
    const COLLISION_RESISTANCE: u32 = 128;

    type Digest = RpxDigest;

    fn hash(bytes: &[u8]) -> Self::Digest {
        // initialize the state with zeroes
        let mut state = [ZERO; STATE_WIDTH];

        // set the capacity (first element) to a flag on whether or not the input length is evenly
        // divided by the rate. this will prevent collisions between padded and non-padded inputs,
        // and will rule out the need to perform an extra permutation in case of evenly divided
        // inputs.
        let is_rate_multiple = bytes.len() % RATE_WIDTH == 0;
        if !is_rate_multiple {
            state[CAPACITY_RANGE.start] = ONE;
        }

        // initialize a buffer to receive the little-endian elements.
        let mut buf = [0_u8; 8];

        // iterate the chunks of bytes, creating a field element from each chunk and copying it
        // into the state.
        //
        // every time the rate range is filled, a permutation is performed. if the final value of
        // `i` is not zero, then the chunks count wasn't enough to fill the state range, and an
        // additional permutation must be performed.
        let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| {
            // the last element of the iteration may or may not be a full chunk. if it's not, then
            // we need to pad the remainder bytes of the chunk with zeroes, separated by a `1`.
            // this will avoid collisions.
            if chunk.len() == BINARY_CHUNK_SIZE {
                buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
            } else {
                buf.fill(0);
                buf[..chunk.len()].copy_from_slice(chunk);
                buf[chunk.len()] = 1;
            }

            // set the current rate element to the input. since we take at most 7 bytes, we are
            // guaranteed that the input data will fit into a single field element.
            state[RATE_RANGE.start + i] = Felt::new(u64::from_le_bytes(buf));

            // proceed filling the range. if it's full, then we apply a permutation and reset the
            // counter to the beginning of the range.
            if i == RATE_WIDTH - 1 {
                Self::apply_permutation(&mut state);
                0
            } else {
                i + 1
            }
        });

        // if we absorbed some elements but didn't apply a permutation to them (would happen when
        // the number of elements is not a multiple of RATE_WIDTH), apply the RPX permutation. we
        // don't need to apply any extra padding because the first capacity element contains a
        // flag indicating whether the input is evenly divisible by the rate.
        if i != 0 {
            state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO);
            state[RATE_RANGE.start + i] = ONE;
            Self::apply_permutation(&mut state);
        }

        // return the first 4 elements of the rate as hash result.
        RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
        // initialize the state by copying the digest elements into the rate portion of the state
        // (8 total elements), and set the capacity elements to 0.
        let mut state = [ZERO; STATE_WIDTH];
        let it = Self::Digest::digests_as_elements(values.iter());
        for (i, v) in it.enumerate() {
            state[RATE_RANGE.start + i] = *v;
        }

        // apply the RPX permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
        // initialize the state as follows:
        // - seed is copied into the first 4 elements of the rate portion of the state.
        // - if the value fits into a single field element, copy it into the fifth rate element
        //   and set the sixth rate element to 1.
        // - if the value doesn't fit into a single field element, split it into two field
        //   elements, copy them into rate elements 5 and 6, and set the seventh rate element
        //   to 1.
        // - set the first capacity element to 1
        let mut state = [ZERO; STATE_WIDTH];
        state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
        state[INPUT2_RANGE.start] = Felt::new(value);
        if value < Felt::MODULUS {
            state[INPUT2_RANGE.start + 1] = ONE;
        } else {
            state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
            state[INPUT2_RANGE.start + 2] = ONE;
        }

        // common padding for both cases
        state[CAPACITY_RANGE.start] = ONE;

        // apply the RPX permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
}

impl ElementHasher for Rpx256 {
    type BaseField = Felt;

    fn hash_elements<E: FieldElement<BaseField = Self::BaseField>>(elements: &[E]) -> Self::Digest {
        // convert the elements into a list of base field elements
        let elements = E::slice_as_base_elements(elements);

        // initialize state to all zeros, except for the first element of the capacity part, which
        // is set to 1 if the number of elements is not a multiple of RATE_WIDTH.
        let mut state = [ZERO; STATE_WIDTH];
        if elements.len() % RATE_WIDTH != 0 {
            state[CAPACITY_RANGE.start] = ONE;
        }

        // absorb elements into the state one by one until the rate portion of the state is filled
        // up; then apply the Rescue permutation and start absorbing again; repeat until all
        // elements have been absorbed
        let mut i = 0;
        for &element in elements.iter() {
            state[RATE_RANGE.start + i] = element;
            i += 1;
            if i % RATE_WIDTH == 0 {
                Self::apply_permutation(&mut state);
                i = 0;
            }
        }

        // if we absorbed some elements but didn't apply a permutation to them (would happen when
        // the number of elements is not a multiple of RATE_WIDTH), apply the RPX permutation after
        // padding by appending a 1 followed by as many 0 as necessary to make the input length a
        // multiple of the RATE_WIDTH.
        if i > 0 {
            state[RATE_RANGE.start + i] = ONE;
            i += 1;
            while i != RATE_WIDTH {
                state[RATE_RANGE.start + i] = ZERO;
                i += 1;
            }
            Self::apply_permutation(&mut state);
        }

        // return the first 4 elements of the state as hash result
        RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
}

// HASH FUNCTION IMPLEMENTATION
// ================================================================================================

impl Rpx256 {
    // CONSTANTS
    // --------------------------------------------------------------------------------------------

    /// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
    /// the remaining 4 elements are reserved for capacity.
    pub const STATE_WIDTH: usize = STATE_WIDTH;

    /// The rate portion of the state is located in elements 4 through 11 (inclusive).
    pub const RATE_RANGE: Range<usize> = RATE_RANGE;

    /// The capacity portion of the state is located in elements 0, 1, 2, and 3.
    pub const CAPACITY_RANGE: Range<usize> = CAPACITY_RANGE;

    /// The output of the hash function can be read from state elements 4, 5, 6, and 7.
    pub const DIGEST_RANGE: Range<usize> = DIGEST_RANGE;

    /// MDS matrix used for computing the linear layer in the (FB) and (M) rounds.
    pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS;

    /// Round constants added to the hasher state in the first half of the round.
    pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1;

    /// Round constants added to the hasher state in the second half of the round.
    pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2;

    // TRAIT PASS-THROUGH FUNCTIONS
    // --------------------------------------------------------------------------------------------

    /// Returns a hash of the provided sequence of bytes.
    #[inline(always)]
    pub fn hash(bytes: &[u8]) -> RpxDigest {
        <Self as Hasher>::hash(bytes)
    }

    /// Returns a hash of two digests. This method is intended for use in construction of
    /// Merkle trees and verification of Merkle paths.
    #[inline(always)]
    pub fn merge(values: &[RpxDigest; 2]) -> RpxDigest {
        <Self as Hasher>::merge(values)
    }

    /// Returns a hash of the provided field elements.
    #[inline(always)]
    pub fn hash_elements<E: FieldElement<BaseField = Felt>>(elements: &[E]) -> RpxDigest {
        <Self as ElementHasher>::hash_elements(elements)
    }

    // DOMAIN IDENTIFIER
    // --------------------------------------------------------------------------------------------

    /// Returns a hash of two digests and a domain identifier.
    pub fn merge_in_domain(values: &[RpxDigest; 2], domain: Felt) -> RpxDigest {
        // initialize the state by copying the digest elements into the rate portion of the state
        // (8 total elements), and set the capacity elements to 0.
        let mut state = [ZERO; STATE_WIDTH];
        let it = RpxDigest::digests_as_elements(values.iter());
        for (i, v) in it.enumerate() {
            state[RATE_RANGE.start + i] = *v;
        }

        // set the second capacity element to the domain value. The first capacity element is used
        // for padding purposes.
        state[CAPACITY_RANGE.start + 1] = domain;

        // apply the RPX permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    // RPX PERMUTATION
    // --------------------------------------------------------------------------------------------

    /// Applies RPX permutation to the provided state.
    #[inline(always)]
    pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) {
        Self::apply_fb_round(state, 0);
        Self::apply_ext_round(state, 1);
        Self::apply_fb_round(state, 2);
        Self::apply_ext_round(state, 3);
        Self::apply_fb_round(state, 4);
        Self::apply_ext_round(state, 5);
        Self::apply_final_round(state, 6);
    }

    // RPX PERMUTATION ROUND FUNCTIONS
    // --------------------------------------------------------------------------------------------

    /// (FB) round function.
    #[inline(always)]
    pub fn apply_fb_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
        apply_mds(state);
        if !add_constants_and_apply_sbox(state, &ARK1[round]) {
            add_constants(state, &ARK1[round]);
            apply_sbox(state);
        }

        apply_mds(state);
        if !add_constants_and_apply_inv_sbox(state, &ARK2[round]) {
            add_constants(state, &ARK2[round]);
            apply_inv_sbox(state);
        }
    }

    /// (E) round function.
    #[inline(always)]
    pub fn apply_ext_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
        // add constants
        add_constants(state, &ARK1[round]);

        // decompose the state into 4 elements in the cubic extension field and apply the power 7
        // map to each of the elements
        let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] = *state;
        let ext0 = Self::exp7(CubicExtElement::new(s0, s1, s2));
        let ext1 = Self::exp7(CubicExtElement::new(s3, s4, s5));
        let ext2 = Self::exp7(CubicExtElement::new(s6, s7, s8));
        let ext3 = Self::exp7(CubicExtElement::new(s9, s10, s11));

        // convert the extension field elements back into 12 base field elements
        let arr_ext = [ext0, ext1, ext2, ext3];
        *state = CubicExtElement::slice_as_base_elements(&arr_ext)
            .try_into()
            .expect("shouldn't fail");
    }

    /// (M) round function.
    #[inline(always)]
    pub fn apply_final_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
        apply_mds(state);
        add_constants(state, &ARK1[round]);
    }

    /// Computes an exponentiation to the power 7 in the cubic extension field.
    #[inline(always)]
    pub fn exp7(x: CubeExtension<Felt>) -> CubeExtension<Felt> {
        // x^7 = x^4 * x^3 = x^4 * x^2 * x
        let x2 = x.square();
        let x4 = x2.square();

        let x3 = x2 * x;
        x3 * x4
    }
}
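A minimal sketch (not part of the diff) of the hash output consistency property described in the doc comment above: merging two digests must equal hashing their eight underlying field elements directly. The import paths are assumptions, as in the previous sketch.

use miden_crypto::{hash::rpx::Rpx256, Felt}; // assumed re-export paths

fn main() {
    let a = Rpx256::hash_elements(&[Felt::new(1), Felt::new(2), Felt::new(3), Felt::new(4)]);
    let b = Rpx256::hash_elements(&[Felt::new(5), Felt::new(6), Felt::new(7), Felt::new(8)]);

    // merge() absorbs the 8 digest elements into the rate and applies one permutation
    let merged = Rpx256::merge(&[a, b]);

    // hashing the same 8 elements directly must produce the same digest
    let mut elements = Vec::with_capacity(8);
    elements.extend_from_slice(a.as_elements());
    elements.extend_from_slice(b.as_elements());
    assert_eq!(merged, Rpx256::hash_elements(&elements));
}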
9  src/hash/rescue/tests.rs  Normal file
@@ -0,0 +1,9 @@
use super::{Felt, FieldElement, ALPHA, INV_ALPHA};
use rand_utils::rand_value;

#[test]
fn test_alphas() {
    let e: Felt = Felt::new(rand_value());
    let e_exp = e.exp(ALPHA);
    assert_eq!(e, e_exp.exp(INV_ALPHA));
}
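The test above relies on ALPHA and INV_ALPHA being inverse exponents. A quick standalone check (not part of the diff): by Fermat's little theorem, exponents act modulo p - 1 on non-zero field elements, so it suffices that ALPHA * INV_ALPHA ≡ 1 (mod p - 1) for the Goldilocks prime p = 2^64 - 2^32 + 1.

fn main() {
    const P: u128 = 18_446_744_069_414_584_321; // 2^64 - 2^32 + 1
    const ALPHA: u128 = 7;
    const INV_ALPHA: u128 = 10_540_996_611_094_048_183;

    // x^(ALPHA * INV_ALPHA) = x^(1 + k*(p-1)) = x for every non-zero x
    assert_eq!(ALPHA * INV_ALPHA % (P - 1), 1);
}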
@@ -1,905 +0,0 @@
|
|||||||
use super::{Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ONE, ZERO};
|
|
||||||
use core::{convert::TryInto, ops::Range};
|
|
||||||
|
|
||||||
mod digest;
|
|
||||||
pub use digest::RpoDigest;
|
|
||||||
|
|
||||||
mod mds_freq;
|
|
||||||
use mds_freq::mds_multiply_freq;
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests;
|
|
||||||
|
|
||||||
#[cfg(all(target_feature = "sve", feature = "sve"))]
|
|
||||||
#[link(name = "rpo_sve", kind = "static")]
|
|
||||||
extern "C" {
|
|
||||||
fn add_constants_and_apply_sbox(
|
|
||||||
state: *mut std::ffi::c_ulong,
|
|
||||||
constants: *const std::ffi::c_ulong,
|
|
||||||
) -> bool;
|
|
||||||
fn add_constants_and_apply_inv_sbox(
|
|
||||||
state: *mut std::ffi::c_ulong,
|
|
||||||
constants: *const std::ffi::c_ulong,
|
|
||||||
) -> bool;
|
|
||||||
}
|
|
||||||
|
|
||||||
// CONSTANTS
|
|
||||||
// ================================================================================================
|
|
||||||
|
|
||||||
/// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
|
|
||||||
/// the remaining 4 elements are reserved for capacity.
|
|
||||||
const STATE_WIDTH: usize = 12;
|
|
||||||
|
|
||||||
/// The rate portion of the state is located in elements 4 through 11.
|
|
||||||
const RATE_RANGE: Range<usize> = 4..12;
|
|
||||||
const RATE_WIDTH: usize = RATE_RANGE.end - RATE_RANGE.start;
|
|
||||||
|
|
||||||
const INPUT1_RANGE: Range<usize> = 4..8;
|
|
||||||
const INPUT2_RANGE: Range<usize> = 8..12;
|
|
||||||
|
|
||||||
/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
|
|
||||||
const CAPACITY_RANGE: Range<usize> = 0..4;
|
|
||||||
|
|
||||||
/// The output of the hash function is a digest which consists of 4 field elements or 32 bytes.
|
|
||||||
///
|
|
||||||
/// The digest is returned from state elements 4, 5, 6, and 7 (the first four elements of the
|
|
||||||
/// rate portion).
|
|
||||||
const DIGEST_RANGE: Range<usize> = 4..8;
|
|
||||||
const DIGEST_SIZE: usize = DIGEST_RANGE.end - DIGEST_RANGE.start;
|
|
||||||
|
|
||||||
/// The number of rounds is set to 7 to target 128-bit security level
|
|
||||||
const NUM_ROUNDS: usize = 7;
|
|
||||||
|
|
||||||
/// The number of byte chunks defining a field element when hashing a sequence of bytes
|
|
||||||
const BINARY_CHUNK_SIZE: usize = 7;
|
|
||||||
|
|
||||||
/// S-Box and Inverse S-Box powers;
|
|
||||||
///
|
|
||||||
/// The constants are defined for tests only because the exponentiations in the code are unrolled
|
|
||||||
/// for efficiency reasons.
|
|
||||||
#[cfg(test)]
|
|
||||||
const ALPHA: u64 = 7;
|
|
||||||
#[cfg(test)]
|
|
||||||
const INV_ALPHA: u64 = 10540996611094048183;
|
|
||||||
|
|
||||||
// HASHER IMPLEMENTATION
|
|
||||||
// ================================================================================================
|
|
||||||
|
|
||||||
/// Implementation of the Rescue Prime Optimized hash function with 256-bit output.
|
|
||||||
///
|
|
||||||
/// The hash function is implemented according to the Rescue Prime Optimized
|
|
||||||
/// [specifications](https://eprint.iacr.org/2022/1577)
|
|
||||||
///
|
|
||||||
/// The parameters used to instantiate the function are:
|
|
||||||
/// * Field: 64-bit prime field with modulus 2^64 - 2^32 + 1.
|
|
||||||
/// * State width: 12 field elements.
|
|
||||||
/// * Capacity size: 4 field elements.
|
|
||||||
/// * Number of founds: 7.
|
|
||||||
/// * S-Box degree: 7.
|
|
||||||
///
|
|
||||||
/// The above parameters target 128-bit security level. The digest consists of four field elements
|
|
||||||
/// and it can be serialized into 32 bytes (256 bits).
|
|
||||||
///
|
|
||||||
/// ## Hash output consistency
|
|
||||||
/// Functions [hash_elements()](Rpo256::hash_elements), [merge()](Rpo256::merge), and
|
|
||||||
/// [merge_with_int()](Rpo256::merge_with_int) are internally consistent. That is, computing
|
|
||||||
/// a hash for the same set of elements using these functions will always produce the same
|
|
||||||
/// result. For example, merging two digests using [merge()](Rpo256::merge) will produce the
|
|
||||||
/// same result as hashing 8 elements which make up these digests using
|
|
||||||
/// [hash_elements()](Rpo256::hash_elements) function.
|
|
||||||
///
|
|
||||||
/// However, [hash()](Rpo256::hash) function is not consistent with functions mentioned above.
|
|
||||||
/// For example, if we take two field elements, serialize them to bytes and hash them using
|
|
||||||
/// [hash()](Rpo256::hash), the result will differ from the result obtained by hashing these
|
|
||||||
/// elements directly using [hash_elements()](Rpo256::hash_elements) function. The reason for
|
|
||||||
/// this difference is that [hash()](Rpo256::hash) function needs to be able to handle
|
|
||||||
/// arbitrary binary strings, which may or may not encode valid field elements - and thus,
|
|
||||||
/// deserialization procedure used by this function is different from the procedure used to
|
|
||||||
/// deserialize valid field elements.
|
|
||||||
///
|
|
||||||
/// Thus, if the underlying data consists of valid field elements, it might make more sense
|
|
||||||
/// to deserialize them into field elements and then hash them using
|
|
||||||
/// [hash_elements()](Rpo256::hash_elements) function rather then hashing the serialized bytes
|
|
||||||
/// using [hash()](Rpo256::hash) function.
|
|
||||||
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
|
|
||||||
pub struct Rpo256();
|
|
||||||
|
|
||||||
impl Hasher for Rpo256 {
|
|
||||||
/// Rpo256 collision resistance is the same as the security level, that is 128-bits.
|
|
||||||
///
|
|
||||||
/// #### Collision resistance
|
|
||||||
///
|
|
||||||
/// However, our setup of the capacity registers might drop it to 126.
|
|
||||||
///
|
|
||||||
/// Related issue: [#69](https://github.com/0xPolygonMiden/crypto/issues/69)
|
|
||||||
const COLLISION_RESISTANCE: u32 = 128;
|
|
||||||
|
|
||||||
type Digest = RpoDigest;
|
|
||||||
|
|
||||||
fn hash(bytes: &[u8]) -> Self::Digest {
|
|
||||||
// initialize the state with zeroes
|
|
||||||
let mut state = [ZERO; STATE_WIDTH];
|
|
||||||
|
|
||||||
// set the capacity (first element) to a flag on whether or not the input length is evenly
|
|
||||||
// divided by the rate. this will prevent collisions between padded and non-padded inputs,
|
|
||||||
// and will rule out the need to perform an extra permutation in case of evenly divided
|
|
||||||
// inputs.
|
|
||||||
let is_rate_multiple = bytes.len() % RATE_WIDTH == 0;
|
|
||||||
if !is_rate_multiple {
|
|
||||||
state[CAPACITY_RANGE.start] = ONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
// initialize a buffer to receive the little-endian elements.
|
|
||||||
let mut buf = [0_u8; 8];
|
|
||||||
|
|
||||||
// iterate the chunks of bytes, creating a field element from each chunk and copying it
|
|
||||||
// into the state.
|
|
||||||
//
|
|
||||||
// every time the rate range is filled, a permutation is performed. if the final value of
|
|
||||||
// `i` is not zero, then the chunks count wasn't enough to fill the state range, and an
|
|
||||||
// additional permutation must be performed.
|
|
||||||
let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| {
|
|
||||||
// the last element of the iteration may or may not be a full chunk. if it's not, then
|
|
||||||
// we need to pad the remainder bytes of the chunk with zeroes, separated by a `1`.
|
|
||||||
// this will avoid collisions.
|
|
||||||
if chunk.len() == BINARY_CHUNK_SIZE {
|
|
||||||
buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
|
|
||||||
} else {
|
|
||||||
buf.fill(0);
|
|
||||||
buf[..chunk.len()].copy_from_slice(chunk);
|
|
||||||
buf[chunk.len()] = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// set the current rate element to the input. since we take at most 7 bytes, we are
|
|
||||||
// guaranteed that the inputs data will fit into a single field element.
|
|
||||||
state[RATE_RANGE.start + i] = Felt::new(u64::from_le_bytes(buf));
|
|
||||||
|
|
||||||
// proceed filling the range. if it's full, then we apply a permutation and reset the
|
|
||||||
// counter to the beginning of the range.
|
|
||||||
if i == RATE_WIDTH - 1 {
|
|
||||||
Self::apply_permutation(&mut state);
|
|
||||||
0
|
|
||||||
} else {
|
|
||||||
i + 1
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// if we absorbed some elements but didn't apply a permutation to them (would happen when
|
|
||||||
// the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation. we
|
|
||||||
// don't need to apply any extra padding because the first capacity element contains a
|
|
||||||
// flag indicating whether the input is evenly divisible by the rate.
|
|
||||||
if i != 0 {
|
|
||||||
state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO);
|
|
||||||
state[RATE_RANGE.start + i] = ONE;
|
|
||||||
Self::apply_permutation(&mut state);
|
|
||||||
}
|
|
||||||
|
|
||||||
// return the first 4 elements of the rate as hash result.
|
|
||||||
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
|
|
||||||
// initialize the state by copying the digest elements into the rate portion of the state
|
|
||||||
// (8 total elements), and set the capacity elements to 0.
|
|
||||||
let mut state = [ZERO; STATE_WIDTH];
|
|
||||||
let it = Self::Digest::digests_as_elements(values.iter());
|
|
||||||
for (i, v) in it.enumerate() {
|
|
||||||
state[RATE_RANGE.start + i] = *v;
|
|
||||||
}
|
|
||||||
|
|
||||||
// apply the RPO permutation and return the first four elements of the state
|
|
||||||
Self::apply_permutation(&mut state);
|
|
||||||
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
|
|
||||||
// initialize the state as follows:
|
|
||||||
// - seed is copied into the first 4 elements of the rate portion of the state.
|
|
||||||
// - if the value fits into a single field element, copy it into the fifth rate element
|
|
||||||
// and set the sixth rate element to 1.
|
|
||||||
// - if the value doesn't fit into a single field element, split it into two field
|
|
||||||
// elements, copy them into rate elements 5 and 6, and set the seventh rate element
|
|
||||||
// to 1.
|
|
||||||
// - set the first capacity element to 1
|
|
||||||
let mut state = [ZERO; STATE_WIDTH];
|
|
||||||
state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
|
|
||||||
state[INPUT2_RANGE.start] = Felt::new(value);
|
|
||||||
if value < Felt::MODULUS {
|
|
||||||
state[INPUT2_RANGE.start + 1] = ONE;
|
|
||||||
} else {
|
|
||||||
state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
|
|
||||||
state[INPUT2_RANGE.start + 2] = ONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
// common padding for both cases
|
|
||||||
state[CAPACITY_RANGE.start] = ONE;
|
|
||||||
|
|
||||||
// apply the RPO permutation and return the first four elements of the state
|
|
||||||
Self::apply_permutation(&mut state);
|
|
||||||
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ElementHasher for Rpo256 {
|
|
||||||
type BaseField = Felt;
|
|
||||||
|
|
||||||
fn hash_elements<E: FieldElement<BaseField = Self::BaseField>>(elements: &[E]) -> Self::Digest {
|
|
||||||
// convert the elements into a list of base field elements
|
|
||||||
let elements = E::slice_as_base_elements(elements);
|
|
||||||
|
|
||||||
// initialize state to all zeros, except for the first element of the capacity part, which
|
|
||||||
// is set to 1 if the number of elements is not a multiple of RATE_WIDTH.
|
|
||||||
let mut state = [ZERO; STATE_WIDTH];
|
|
||||||
if elements.len() % RATE_WIDTH != 0 {
|
|
||||||
state[CAPACITY_RANGE.start] = ONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
// absorb elements into the state one by one until the rate portion of the state is filled
|
|
||||||
// up; then apply the Rescue permutation and start absorbing again; repeat until all
|
|
||||||
// elements have been absorbed
|
|
||||||
let mut i = 0;
|
|
||||||
for &element in elements.iter() {
|
|
||||||
state[RATE_RANGE.start + i] = element;
|
|
||||||
i += 1;
|
|
||||||
if i % RATE_WIDTH == 0 {
|
|
||||||
Self::apply_permutation(&mut state);
|
|
||||||
i = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// if we absorbed some elements but didn't apply a permutation to them (would happen when
|
|
||||||
// the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation after
|
|
||||||
// padding by appending a 1 followed by as many 0 as necessary to make the input length a
|
|
||||||
// multiple of the RATE_WIDTH.
|
|
||||||
if i > 0 {
|
|
||||||
state[RATE_RANGE.start + i] = ONE;
|
|
||||||
i += 1;
|
|
||||||
while i != RATE_WIDTH {
|
|
||||||
state[RATE_RANGE.start + i] = ZERO;
|
|
||||||
i += 1;
|
|
||||||
}
|
|
||||||
Self::apply_permutation(&mut state);
|
|
||||||
}
|
|
||||||
|
|
||||||
// return the first 4 elements of the state as hash result
|
|
||||||
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// HASH FUNCTION IMPLEMENTATION
|
|
||||||
// ================================================================================================
|
|
||||||
|
|
||||||
impl Rpo256 {
|
|
||||||
// CONSTANTS
|
|
||||||
// --------------------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
/// The number of rounds is set to 7 to target 128-bit security level.
|
|
||||||
pub const NUM_ROUNDS: usize = NUM_ROUNDS;
|
|
||||||
|
|
||||||
/// Sponge state is set to 12 field elements or 768 bytes; 8 elements are reserved for rate and
|
|
||||||
/// the remaining 4 elements are reserved for capacity.
|
|
||||||
pub const STATE_WIDTH: usize = STATE_WIDTH;
|
|
||||||
|
|
||||||
/// The rate portion of the state is located in elements 4 through 11 (inclusive).
|
|
||||||
pub const RATE_RANGE: Range<usize> = RATE_RANGE;
|
|
||||||
|
|
||||||
/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
|
|
||||||
pub const CAPACITY_RANGE: Range<usize> = CAPACITY_RANGE;
|
|
||||||
|
|
||||||
/// The output of the hash function can be read from state elements 4, 5, 6, and 7.
|
|
||||||
pub const DIGEST_RANGE: Range<usize> = DIGEST_RANGE;
|
|
||||||
|
|
||||||
/// MDS matrix used for computing the linear layer in a RPO round.
|
|
||||||
pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS;
|
|
||||||
|
|
||||||
/// Round constants added to the hasher state in the first half of the RPO round.
|
|
||||||
pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1;
|
|
||||||
|
|
||||||
/// Round constants added to the hasher state in the second half of the RPO round.
|
|
||||||
pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2;
|
|
||||||
|
|
||||||
// TRAIT PASS-THROUGH FUNCTIONS
|
|
||||||
// --------------------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
/// Returns a hash of the provided sequence of bytes.
|
|
||||||
#[inline(always)]
|
|
||||||
pub fn hash(bytes: &[u8]) -> RpoDigest {
|
|
||||||
<Self as Hasher>::hash(bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a hash of two digests. This method is intended for use in construction of
|
|
||||||
/// Merkle trees and verification of Merkle paths.
|
|
||||||
#[inline(always)]
|
|
||||||
pub fn merge(values: &[RpoDigest; 2]) -> RpoDigest {
|
|
||||||
<Self as Hasher>::merge(values)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a hash of the provided field elements.
|
|
||||||
#[inline(always)]
|
|
||||||
pub fn hash_elements<E: FieldElement<BaseField = Felt>>(elements: &[E]) -> RpoDigest {
|
|
||||||
<Self as ElementHasher>::hash_elements(elements)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DOMAIN IDENTIFIER
|
|
||||||
// --------------------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
/// Returns a hash of two digests and a domain identifier.
|
|
||||||
pub fn merge_in_domain(values: &[RpoDigest; 2], domain: Felt) -> RpoDigest {
|
|
||||||
// initialize the state by copying the digest elements into the rate portion of the state
|
|
||||||
// (8 total elements), and set the capacity elements to 0.
|
|
||||||
let mut state = [ZERO; STATE_WIDTH];
|
|
||||||
let it = RpoDigest::digests_as_elements(values.iter());
|
|
||||||
for (i, v) in it.enumerate() {
|
|
||||||
state[RATE_RANGE.start + i] = *v;
|
|
||||||
}
|
|
||||||
|
|
||||||
// set the second capacity element to the domain value. The first capacity element is used
|
|
||||||
// for padding purposes.
|
|
||||||
state[CAPACITY_RANGE.start + 1] = domain;
|
|
||||||
|
|
||||||
// apply the RPO permutation and return the first four elements of the state
|
|
||||||
Self::apply_permutation(&mut state);
|
|
||||||
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
|
||||||
}
|
|
||||||
|
|
||||||
// RESCUE PERMUTATION
|
|
||||||
// --------------------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
/// Applies RPO permutation to the provided state.
|
|
||||||
#[inline(always)]
|
|
||||||
pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) {
|
|
||||||
for i in 0..NUM_ROUNDS {
|
|
||||||
Self::apply_round(state, i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// RPO round function.
|
|
||||||
#[inline(always)]
|
|
||||||
pub fn apply_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
|
|
||||||
// apply first half of RPO round
|
|
||||||
Self::apply_mds(state);
|
|
||||||
if !Self::optimized_add_constants_and_apply_sbox(state, &ARK1[round]) {
|
|
||||||
Self::add_constants(state, &ARK1[round]);
|
|
||||||
Self::apply_sbox(state);
|
|
||||||
}
|
|
||||||
|
|
||||||
// apply second half of RPO round
|
|
||||||
Self::apply_mds(state);
|
|
||||||
if !Self::optimized_add_constants_and_apply_inv_sbox(state, &ARK2[round]) {
|
|
||||||
Self::add_constants(state, &ARK2[round]);
|
|
||||||
Self::apply_inv_sbox(state);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// HELPER FUNCTIONS
|
|
||||||
// --------------------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
#[inline(always)]
|
|
||||||
#[cfg(all(target_feature = "sve", feature = "sve"))]
|
|
||||||
fn optimized_add_constants_and_apply_sbox(
|
|
||||||
state: &mut [Felt; STATE_WIDTH],
|
|
||||||
ark: &[Felt; STATE_WIDTH],
|
|
||||||
) -> bool {
|
|
||||||
unsafe {
|
|
||||||
add_constants_and_apply_sbox(state.as_mut_ptr() as *mut u64, ark.as_ptr() as *const u64)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline(always)]
|
|
||||||
#[cfg(not(all(target_feature = "sve", feature = "sve")))]
|
|
||||||
fn optimized_add_constants_and_apply_sbox(
|
|
||||||
_state: &mut [Felt; STATE_WIDTH],
|
|
||||||
_ark: &[Felt; STATE_WIDTH],
|
|
||||||
) -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline(always)]
|
|
||||||
#[cfg(all(target_feature = "sve", feature = "sve"))]
|
|
||||||
fn optimized_add_constants_and_apply_inv_sbox(
|
|
||||||
state: &mut [Felt; STATE_WIDTH],
|
|
||||||
ark: &[Felt; STATE_WIDTH],
|
|
||||||
) -> bool {
|
|
||||||
unsafe {
|
|
||||||
add_constants_and_apply_inv_sbox(
|
|
||||||
state.as_mut_ptr() as *mut u64,
|
|
||||||
ark.as_ptr() as *const u64,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline(always)]
|
|
||||||
#[cfg(not(all(target_feature = "sve", feature = "sve")))]
|
|
||||||
fn optimized_add_constants_and_apply_inv_sbox(
|
|
||||||
_state: &mut [Felt; STATE_WIDTH],
|
|
||||||
_ark: &[Felt; STATE_WIDTH],
|
|
||||||
) -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline(always)]
|
|
||||||
fn apply_mds(state: &mut [Felt; STATE_WIDTH]) {
|
|
||||||
let mut result = [ZERO; STATE_WIDTH];
|
|
||||||
|
|
||||||
// Using the linearity of the operations we can split the state into a low||high decomposition
|
|
||||||
// and operate on each with no overflow and then combine/reduce the result to a field element.
|
|
||||||
// The no overflow is guaranteed by the fact that the MDS matrix is a small powers of two in
|
|
||||||
// frequency domain.
|
|
||||||
let mut state_l = [0u64; STATE_WIDTH];
|
|
||||||
let mut state_h = [0u64; STATE_WIDTH];
|
|
||||||
|
|
||||||
for r in 0..STATE_WIDTH {
|
|
||||||
let s = state[r].inner();
|
|
||||||
state_h[r] = s >> 32;
|
|
||||||
state_l[r] = (s as u32) as u64;
|
|
||||||
}
|
|
||||||
|
|
||||||
let state_h = mds_multiply_freq(state_h);
|
|
||||||
let state_l = mds_multiply_freq(state_l);
|
|
||||||
|
|
||||||
for r in 0..STATE_WIDTH {
|
|
||||||
let s = state_l[r] as u128 + ((state_h[r] as u128) << 32);
|
|
||||||
let s_hi = (s >> 64) as u64;
|
|
||||||
let s_lo = s as u64;
|
|
||||||
let z = (s_hi << 32) - s_hi;
|
|
||||||
let (res, over) = s_lo.overflowing_add(z);
|
|
||||||
|
|
||||||
result[r] = Felt::from_mont(res.wrapping_add(0u32.wrapping_sub(over as u32) as u64));
|
|
||||||
}
|
|
||||||
*state = result;
|
|
||||||
}
|
|
||||||
|
|
||||||
    #[inline(always)]
    fn add_constants(state: &mut [Felt; STATE_WIDTH], ark: &[Felt; STATE_WIDTH]) {
        state.iter_mut().zip(ark).for_each(|(s, &k)| *s += k);
    }

    #[inline(always)]
    fn apply_sbox(state: &mut [Felt; STATE_WIDTH]) {
        state[0] = state[0].exp7();
        state[1] = state[1].exp7();
        state[2] = state[2].exp7();
        state[3] = state[3].exp7();
        state[4] = state[4].exp7();
        state[5] = state[5].exp7();
        state[6] = state[6].exp7();
        state[7] = state[7].exp7();
        state[8] = state[8].exp7();
        state[9] = state[9].exp7();
        state[10] = state[10].exp7();
        state[11] = state[11].exp7();
    }

    #[inline(always)]
    fn apply_inv_sbox(state: &mut [Felt; STATE_WIDTH]) {
        // compute base^10540996611094048183 using 72 multiplications per array element
        // 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111

        // compute base^10
        let mut t1 = *state;
        t1.iter_mut().for_each(|t| *t = t.square());

        // compute base^100
        let mut t2 = t1;
        t2.iter_mut().for_each(|t| *t = t.square());

        // compute base^100100
        let t3 = Self::exp_acc::<Felt, STATE_WIDTH, 3>(t2, t2);

        // compute base^100100100100
        let t4 = Self::exp_acc::<Felt, STATE_WIDTH, 6>(t3, t3);

        // compute base^100100100100100100100100
        let t5 = Self::exp_acc::<Felt, STATE_WIDTH, 12>(t4, t4);

        // compute base^100100100100100100100100100100
        let t6 = Self::exp_acc::<Felt, STATE_WIDTH, 6>(t5, t3);

        // compute base^1001001001001001001001001001000100100100100100100100100100100
        let t7 = Self::exp_acc::<Felt, STATE_WIDTH, 31>(t6, t6);

        // compute base^1001001001001001001001001001000110110110110110110110110110110111
        for (i, s) in state.iter_mut().enumerate() {
            let a = (t7[i].square() * t6[i]).square().square();
            let b = t1[i] * t2[i] * *s;
            *s = a * b;
        }
    }

    #[inline(always)]
    fn exp_acc<B: StarkField, const N: usize, const M: usize>(
        base: [B; N],
        tail: [B; N],
    ) -> [B; N] {
        let mut result = base;
        for _ in 0..M {
            result.iter_mut().for_each(|r| *r = r.square());
        }
        result.iter_mut().zip(tail).for_each(|(r, t)| *r *= t);
        result
    }
}
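
A quick sanity check on the exponent used by `apply_inv_sbox`: the map x → x^10540996611094048183 inverts the x^7 S-box because the two exponents multiply to 1 modulo p - 1, where p = 2^64 - 2^32 + 1 is the modulus of `Felt`. The arithmetic can be verified standalone:

```rust
// 7 * 10540996611094048183 = 4 * (p - 1) + 1, hence x^7 and
// x^10540996611094048183 are inverse permutations of the nonzero field elements.
fn main() {
    let p_minus_1: u128 = (1u128 << 64) - (1 << 32);
    assert_eq!(7 * 10540996611094048183u128 % p_minus_1, 1);
}
```
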
// MDS
// ================================================================================================

/// RPO MDS matrix
const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = [
    [
        Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10),
        Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8),
    ],
    [
        Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13),
        Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21),
    ],
    [
        Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26),
        Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22),
    ],
    [
        Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8),
        Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6),
    ],
    [
        Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23),
        Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7),
    ],
    [
        Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7),
        Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9),
    ],
    [
        Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8),
        Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10),
    ],
    [
        Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21),
        Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13),
    ],
    [
        Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22),
        Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26),
    ],
    [
        Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6),
        Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8),
    ],
    [
        Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7),
        Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23),
    ],
    [
        Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9),
        Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7),
    ],
];
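
Note that the matrix above is circulant: row i is the first row rotated right by i positions, which is the property that lets `apply_mds` evaluate the product in the frequency domain via `mds_multiply_freq`. A standalone spot-check of the rotation (plain integers, independent of `Felt`):

```rust
// Every MDS row is the first row rotated right by the row index.
const ROW0: [u64; 12] = [7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8];

fn mds_entry(i: usize, j: usize) -> u64 {
    ROW0[(j + 12 - i) % 12]
}

fn main() {
    // row 1 of the table above
    let row1 = [8, 7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21];
    assert!((0..12).all(|j| mds_entry(1, j) == row1[j]));
}
```
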
// ROUND CONSTANTS
// ================================================================================================

/// Rescue round constants;
/// computed as in the [specifications](https://github.com/ASDiscreteMathematics/rpo).
///
/// The constants are broken up into two arrays, ARK1 and ARK2; ARK1 contains the constants for
/// the first half of the RPO round, and ARK2 contains the constants for the second half of the
/// RPO round.
const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
    [
        Felt::new(5789762306288267392), Felt::new(6522564764413701783),
        Felt::new(17809893479458208203), Felt::new(107145243989736508),
        Felt::new(6388978042437517382), Felt::new(15844067734406016715),
        Felt::new(9975000513555218239), Felt::new(3344984123768313364),
        Felt::new(9959189626657347191), Felt::new(12960773468763563665),
        Felt::new(9602914297752488475), Felt::new(16657542370200465908),
    ],
    [
        Felt::new(12987190162843096997), Felt::new(653957632802705281),
        Felt::new(4441654670647621225), Felt::new(4038207883745915761),
        Felt::new(5613464648874830118), Felt::new(13222989726778338773),
        Felt::new(3037761201230264149), Felt::new(16683759727265180203),
        Felt::new(8337364536491240715), Felt::new(3227397518293416448),
        Felt::new(8110510111539674682), Felt::new(2872078294163232137),
    ],
    [
        Felt::new(18072785500942327487), Felt::new(6200974112677013481),
        Felt::new(17682092219085884187), Felt::new(10599526828986756440),
        Felt::new(975003873302957338), Felt::new(8264241093196931281),
        Felt::new(10065763900435475170), Felt::new(2181131744534710197),
        Felt::new(6317303992309418647), Felt::new(1401440938888741532),
        Felt::new(8884468225181997494), Felt::new(13066900325715521532),
    ],
    [
        Felt::new(5674685213610121970), Felt::new(5759084860419474071),
        Felt::new(13943282657648897737), Felt::new(1352748651966375394),
        Felt::new(17110913224029905221), Felt::new(1003883795902368422),
        Felt::new(4141870621881018291), Felt::new(8121410972417424656),
        Felt::new(14300518605864919529), Felt::new(13712227150607670181),
        Felt::new(17021852944633065291), Felt::new(6252096473787587650),
    ],
    [
        Felt::new(4887609836208846458), Felt::new(3027115137917284492),
        Felt::new(9595098600469470675), Felt::new(10528569829048484079),
        Felt::new(7864689113198939815), Felt::new(17533723827845969040),
        Felt::new(5781638039037710951), Felt::new(17024078752430719006),
        Felt::new(109659393484013511), Felt::new(7158933660534805869),
        Felt::new(2955076958026921730), Felt::new(7433723648458773977),
    ],
    [
        Felt::new(16308865189192447297), Felt::new(11977192855656444890),
        Felt::new(12532242556065780287), Felt::new(14594890931430968898),
        Felt::new(7291784239689209784), Felt::new(5514718540551361949),
        Felt::new(10025733853830934803), Felt::new(7293794580341021693),
        Felt::new(6728552937464861756), Felt::new(6332385040983343262),
        Felt::new(13277683694236792804), Felt::new(2600778905124452676),
    ],
    [
        Felt::new(7123075680859040534), Felt::new(1034205548717903090),
        Felt::new(7717824418247931797), Felt::new(3019070937878604058),
        Felt::new(11403792746066867460), Felt::new(10280580802233112374),
        Felt::new(337153209462421218), Felt::new(13333398568519923717),
        Felt::new(3596153696935337464), Felt::new(8104208463525993784),
        Felt::new(14345062289456085693), Felt::new(17036731477169661256),
    ],
];

const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
    [
        Felt::new(6077062762357204287), Felt::new(15277620170502011191),
        Felt::new(5358738125714196705), Felt::new(14233283787297595718),
        Felt::new(13792579614346651365), Felt::new(11614812331536767105),
        Felt::new(14871063686742261166), Felt::new(10148237148793043499),
        Felt::new(4457428952329675767), Felt::new(15590786458219172475),
        Felt::new(10063319113072092615), Felt::new(14200078843431360086),
    ],
    [
        Felt::new(6202948458916099932), Felt::new(17690140365333231091),
        Felt::new(3595001575307484651), Felt::new(373995945117666487),
        Felt::new(1235734395091296013), Felt::new(14172757457833931602),
        Felt::new(707573103686350224), Felt::new(15453217512188187135),
        Felt::new(219777875004506018), Felt::new(17876696346199469008),
        Felt::new(17731621626449383378), Felt::new(2897136237748376248),
    ],
    [
        Felt::new(8023374565629191455), Felt::new(15013690343205953430),
        Felt::new(4485500052507912973), Felt::new(12489737547229155153),
        Felt::new(9500452585969030576), Felt::new(2054001340201038870),
        Felt::new(12420704059284934186), Felt::new(355990932618543755),
        Felt::new(9071225051243523860), Felt::new(12766199826003448536),
        Felt::new(9045979173463556963), Felt::new(12934431667190679898),
    ],
    [
        Felt::new(18389244934624494276), Felt::new(16731736864863925227),
        Felt::new(4440209734760478192), Felt::new(17208448209698888938),
        Felt::new(8739495587021565984), Felt::new(17000774922218161967),
        Felt::new(13533282547195532087), Felt::new(525402848358706231),
        Felt::new(16987541523062161972), Felt::new(5466806524462797102),
        Felt::new(14512769585918244983), Felt::new(10973956031244051118),
    ],
    [
        Felt::new(6982293561042362913), Felt::new(14065426295947720331),
        Felt::new(16451845770444974180), Felt::new(7139138592091306727),
        Felt::new(9012006439959783127), Felt::new(14619614108529063361),
        Felt::new(1394813199588124371), Felt::new(4635111139507788575),
        Felt::new(16217473952264203365), Felt::new(10782018226466330683),
        Felt::new(6844229992533662050), Felt::new(7446486531695178711),
    ],
    [
        Felt::new(3736792340494631448), Felt::new(577852220195055341),
        Felt::new(6689998335515779805), Felt::new(13886063479078013492),
        Felt::new(14358505101923202168), Felt::new(7744142531772274164),
        Felt::new(16135070735728404443), Felt::new(12290902521256031137),
        Felt::new(12059913662657709804), Felt::new(16456018495793751911),
        Felt::new(4571485474751953524), Felt::new(17200392109565783176),
    ],
    [
        Felt::new(17130398059294018733), Felt::new(519782857322261988),
        Felt::new(9625384390925085478), Felt::new(1664893052631119222),
        Felt::new(7629576092524553570), Felt::new(3485239601103661425),
        Felt::new(9755891797164033838), Felt::new(15218148195153269027),
        Felt::new(16460604813734957368), Felt::new(9643968136937729763),
        Felt::new(3611348709641382851), Felt::new(18256379591337759196),
    ],
];
10
src/lib.rs

@@ -1,7 +1,7 @@
 #![cfg_attr(not(feature = "std"), no_std)]
 
-#[cfg(not(feature = "std"))]
-#[cfg_attr(test, macro_use)]
+//#[cfg(not(feature = "std"))]
+//#[cfg_attr(test, macro_use)]
 extern crate alloc;
 
 pub mod dsa;
@@ -9,11 +9,15 @@ pub mod hash;
 pub mod merkle;
 pub mod rand;
 pub mod utils;
+pub mod gkr;
 
 // RE-EXPORTS
 // ================================================================================================
 
-pub use winter_math::{fields::f64::BaseElement as Felt, FieldElement, StarkField};
+pub use winter_math::{
+    fields::{f64::BaseElement as Felt, CubeExtension, QuadExtension},
+    FieldElement, StarkField,
+};
 
 // TYPE ALIASES
 // ================================================================================================
@@ -1,9 +1,8 @@
 use clap::Parser;
 use miden_crypto::{
-    hash::rpo::RpoDigest,
-    merkle::MerkleError,
+    hash::rpo::{Rpo256, RpoDigest},
+    merkle::{MerkleError, TieredSmt},
     Felt, Word, ONE,
-    {hash::rpo::Rpo256, merkle::TieredSmt},
 };
 use rand_utils::rand_value;
 use std::time::Instant;
@@ -10,12 +10,19 @@ pub struct EmptySubtreeRoots;
 impl EmptySubtreeRoots {
     /// Returns a static slice with roots of empty subtrees of a Merkle tree starting at the
     /// specified depth.
-    pub const fn empty_hashes(depth: u8) -> &'static [RpoDigest] {
-        let ptr = &EMPTY_SUBTREES[255 - depth as usize] as *const RpoDigest;
+    pub const fn empty_hashes(tree_depth: u8) -> &'static [RpoDigest] {
+        let ptr = &EMPTY_SUBTREES[255 - tree_depth as usize] as *const RpoDigest;
         // Safety: this is a static/constant array, so it will never be outlived. If we attempt to
         // use regular slices, this wouldn't be a `const` function, meaning we won't be able to use
         // the returned value for static/constant definitions.
-        unsafe { slice::from_raw_parts(ptr, depth as usize + 1) }
+        unsafe { slice::from_raw_parts(ptr, tree_depth as usize + 1) }
     }
+
+    /// Returns the node's digest for a sub-tree with all its leaves set to the empty word.
+    pub const fn entry(tree_depth: u8, node_depth: u8) -> &'static RpoDigest {
+        assert!(node_depth <= tree_depth);
+        let pos = 255 - tree_depth + node_depth;
+        &EMPTY_SUBTREES[pos as usize]
+    }
 }
@@ -1583,3 +1590,16 @@ fn all_depths_opens_to_zero() {
         .for_each(|(x, computed)| assert_eq!(x, computed));
     }
 }
+
+#[test]
+fn test_entry() {
+    // check that the leaf is always the empty word
+    for depth in 0..255 {
+        assert_eq!(EmptySubtreeRoots::entry(depth, depth), &RpoDigest::new(EMPTY_WORD));
+    }
+
+    // check that the root matches the first element of empty_hashes
+    for depth in 0..255 {
+        assert_eq!(EmptySubtreeRoots::entry(depth, 0), &EmptySubtreeRoots::empty_hashes(depth)[0]);
+    }
+}
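
The index arithmetic behind `entry` can be checked in isolation: `EMPTY_SUBTREES` stores 256 digests laid out from the deepest level up, so for a tree of depth t the node at depth d lives at offset 255 - t + d. A small standalone sketch of just that arithmetic (the helper name is illustrative; only the offset formula is taken from the diff):

```rust
// Mirrors the offset computation in EmptySubtreeRoots::entry.
fn entry_offset(tree_depth: u8, node_depth: u8) -> usize {
    assert!(node_depth <= tree_depth);
    255 - tree_depth as usize + node_depth as usize
}

fn main() {
    assert_eq!(entry_offset(64, 0), 191); // root of an empty depth-64 tree
    assert_eq!(entry_offset(64, 64), 255); // the leaf entry is always the last slot
    assert_eq!(entry_offset(255, 0), 0); // root of the deepest supported tree
}
```
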
@@ -13,8 +13,9 @@ pub enum MerkleError {
     DuplicateValuesForKey(RpoDigest),
     InvalidIndex { depth: u8, value: u64 },
     InvalidDepth { expected: u8, provided: u8 },
+    InvalidSubtreeDepth { subtree_depth: u8, tree_depth: u8 },
     InvalidPath(MerklePath),
-    InvalidNumEntries(usize, usize),
+    InvalidNumEntries(usize),
     NodeNotInSet(NodeIndex),
     NodeNotInStore(RpoDigest, NodeIndex),
     NumLeavesNotPowerOfTwo(usize),
@@ -30,18 +31,21 @@ impl fmt::Display for MerkleError {
             DepthTooBig(depth) => write!(f, "the provided depth {depth} is too big"),
             DuplicateValuesForIndex(key) => write!(f, "multiple values provided for key {key}"),
             DuplicateValuesForKey(key) => write!(f, "multiple values provided for key {key}"),
-            InvalidIndex{ depth, value} => write!(
-                f,
-                "the index value {value} is not valid for the depth {depth}"
-            ),
-            InvalidDepth { expected, provided } => write!(
-                f,
-                "the provided depth {provided} is not valid for {expected}"
-            ),
+            InvalidIndex { depth, value } => {
+                write!(f, "the index value {value} is not valid for the depth {depth}")
+            }
+            InvalidDepth { expected, provided } => {
+                write!(f, "the provided depth {provided} is not valid for {expected}")
+            }
+            InvalidSubtreeDepth { subtree_depth, tree_depth } => {
+                write!(f, "tried inserting a subtree of depth {subtree_depth} into a tree of depth {tree_depth}")
+            }
             InvalidPath(_path) => write!(f, "the provided path is not valid"),
-            InvalidNumEntries(max, provided) => write!(f, "the provided number of entries is {provided}, but the maximum for the given depth is {max}"),
+            InvalidNumEntries(max) => write!(f, "number of entries exceeded the maximum: {max}"),
             NodeNotInSet(index) => write!(f, "the node with index ({index}) is not in the set"),
-            NodeNotInStore(hash, index) => write!(f, "the node {hash:?} with index ({index}) is not in the store"),
+            NodeNotInStore(hash, index) => {
+                write!(f, "the node {hash:?} with index ({index}) is not in the store")
+            }
             NumLeavesNotPowerOfTwo(leaves) => {
                 write!(f, "the leaves count {leaves} is not a power of 2")
             }
@@ -187,13 +187,20 @@ mod tests {
     #[test]
     fn test_node_index_value_too_high() {
         assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 });
-        match NodeIndex::new(0, 1) {
-            Err(MerkleError::InvalidIndex { depth, value }) => {
-                assert_eq!(depth, 0);
-                assert_eq!(value, 1);
-            }
-            _ => unreachable!(),
-        }
+        let err = NodeIndex::new(0, 1).unwrap_err();
+        assert_eq!(err, MerkleError::InvalidIndex { depth: 0, value: 1 });
+
+        assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 });
+        let err = NodeIndex::new(1, 2).unwrap_err();
+        assert_eq!(err, MerkleError::InvalidIndex { depth: 1, value: 2 });
+
+        assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 });
+        let err = NodeIndex::new(2, 4).unwrap_err();
+        assert_eq!(err, MerkleError::InvalidIndex { depth: 2, value: 4 });
+
+        assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 });
+        let err = NodeIndex::new(3, 8).unwrap_err();
+        assert_eq!(err, MerkleError::InvalidIndex { depth: 3, value: 8 });
     }
 
     #[test]
@@ -9,11 +9,12 @@
 //! least number of leaves. The structure preserves the invariant that each tree has a different
 //! depth, i.e. as part of adding a new element to the forest the trees with the same depth are
 //! merged, creating a new tree with depth d+1; this process is continued until the property is
-//! restabilished.
+//! reestablished.
 use super::{
-    super::{InnerNodeInfo, MerklePath, RpoDigest, Vec},
+    super::{InnerNodeInfo, MerklePath, Vec},
     bit::TrueBitPositionIterator,
     leaf_to_corresponding_tree, nodes_in_forest, MmrDelta, MmrError, MmrPeaks, MmrProof, Rpo256,
+    RpoDigest,
 };
 
 // MMR
@@ -76,13 +77,13 @@ impl Mmr {
     /// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were
     /// added, this corresponds to the MMR size _prior_ to adding the element. So the 1st element
     /// has position 0, the second position 1, and so on.
-    pub fn open(&self, pos: usize) -> Result<MmrProof, MmrError> {
+    pub fn open(&self, pos: usize, target_forest: usize) -> Result<MmrProof, MmrError> {
         // find the target tree responsible for the MMR position
         let tree_bit =
-            leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::InvalidPosition(pos))?;
+            leaf_to_corresponding_tree(pos, target_forest).ok_or(MmrError::InvalidPosition(pos))?;
 
         // isolate the trees before the target
-        let forest_before = self.forest & high_bitmask(tree_bit + 1);
+        let forest_before = target_forest & high_bitmask(tree_bit + 1);
         let index_offset = nodes_in_forest(forest_before);
 
         // update the value position from global to the target tree
@@ -92,7 +93,7 @@ impl Mmr {
         let (_, path) = self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset);
 
         Ok(MmrProof {
-            forest: self.forest,
+            forest: target_forest,
             position: pos,
             merkle_path: MerklePath::new(path),
         })
@@ -143,9 +144,13 @@ impl Mmr {
         self.forest += 1;
     }
 
-    /// Returns an accumulator representing the current state of the MMR.
-    pub fn accumulator(&self) -> MmrPeaks {
-        let peaks: Vec<RpoDigest> = TrueBitPositionIterator::new(self.forest)
+    /// Returns the peaks of the MMR for the version specified by `forest`.
+    pub fn peaks(&self, forest: usize) -> Result<MmrPeaks, MmrError> {
+        if forest > self.forest {
+            return Err(MmrError::InvalidPeaks);
+        }
+
+        let peaks: Vec<RpoDigest> = TrueBitPositionIterator::new(forest)
             .rev()
             .map(|bit| nodes_in_forest(1 << bit))
             .scan(0, |offset, el| {
@@ -156,39 +161,41 @@ impl Mmr {
             .collect();
 
         // Safety: the invariant is maintained by the [Mmr]
-        MmrPeaks::new(self.forest, peaks).unwrap()
+        let peaks = MmrPeaks::new(forest, peaks).unwrap();
+
+        Ok(peaks)
     }
 
     /// Compute the required update to `original_forest`.
     ///
     /// The result is a packed sequence of the authentication elements required to update the trees
     /// that have been merged together, followed by the new peaks of the [Mmr].
-    pub fn get_delta(&self, original_forest: usize) -> Result<MmrDelta, MmrError> {
-        if original_forest > self.forest {
+    pub fn get_delta(&self, from_forest: usize, to_forest: usize) -> Result<MmrDelta, MmrError> {
+        if to_forest > self.forest || from_forest > to_forest {
             return Err(MmrError::InvalidPeaks);
         }
 
-        if original_forest == self.forest {
-            return Ok(MmrDelta { forest: self.forest, data: Vec::new() });
+        if from_forest == to_forest {
+            return Ok(MmrDelta { forest: to_forest, data: Vec::new() });
         }
 
         let mut result = Vec::new();
 
-        // Find the largest tree in this [Mmr] which is new to `original_forest`.
-        let candidate_trees = self.forest ^ original_forest;
+        // Find the largest tree in this [Mmr] which is new to `from_forest`.
+        let candidate_trees = to_forest ^ from_forest;
         let mut new_high = 1 << candidate_trees.ilog2();
 
         // Collect authentication nodes used for tree merges
        // ----------------------------------------------------------------------------------------
 
-        // Find the trees from `original_forest` that have been merged into `new_high`.
-        let mut merges = original_forest & (new_high - 1);
+        // Find the trees from `from_forest` that have been merged into `new_high`.
+        let mut merges = from_forest & (new_high - 1);
 
-        // Find the peaks that are common to `original_forest` and this [Mmr]
-        let common_trees = original_forest ^ merges;
+        // Find the peaks that are common to `from_forest` and this [Mmr]
+        let common_trees = from_forest ^ merges;
 
         if merges != 0 {
-            // Skip the smallest trees unknown to `original_forest`.
+            // Skip the smallest trees unknown to `from_forest`.
             let mut target = 1 << merges.trailing_zeros();
 
             // Collect siblings required to compute the merged tree's peak
@@ -213,15 +220,15 @@ impl Mmr {
             }
         } else {
             // The new high tree may not be the result of any merges, if it is smaller than all the
-            // trees of `original_forest`.
+            // trees of `from_forest`.
             new_high = 0;
         }
 
         // Collect the new [Mmr] peaks
         // ----------------------------------------------------------------------------------------
 
-        let mut new_peaks = self.forest ^ common_trees ^ new_high;
-        let old_peaks = self.forest ^ new_peaks;
+        let mut new_peaks = to_forest ^ common_trees ^ new_high;
+        let old_peaks = to_forest ^ new_peaks;
         let mut offset = nodes_in_forest(old_peaks);
         while new_peaks != 0 {
             let target = 1 << new_peaks.ilog2();
@@ -230,7 +237,7 @@ impl Mmr {
             new_peaks ^= target;
         }
 
-        Ok(MmrDelta { forest: self.forest, data: result })
+        Ok(MmrDelta { forest: to_forest, data: result })
     }
 
     /// An iterator over inner nodes in the MMR. The order of iteration is unspecified.
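
To make the forest arithmetic in `get_delta` concrete, here is a standalone walkthrough with plain integers (illustrative only; the values are chosen for the example): growing an MMR from 5 leaves (trees of 4 and 1) to 8 leaves (one tree of 8) merges every old tree into the new highest tree.

```rust
fn main() {
    let (from_forest, to_forest) = (0b101usize, 0b1000usize);

    // the largest tree that is new to `from_forest`
    let candidate_trees = to_forest ^ from_forest; // 0b1101
    let new_high = 1 << candidate_trees.ilog2(); // 0b1000
    // old trees merged into `new_high`, and peaks common to both versions
    let merges = from_forest & (new_high - 1); // 0b0101: both old trees
    let common_trees = from_forest ^ merges; // 0b0000: no old peak survives

    assert_eq!((new_high, merges, common_trees), (0b1000, 0b0101, 0b0000));
}
```
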
@@ -6,6 +6,9 @@
 //! leaves count.
 use core::num::NonZeroUsize;
 
+// IN-ORDER INDEX
+// ================================================================================================
+
 /// Index of nodes in a perfectly balanced binary tree based on an in-order tree walk.
 #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
 pub struct InOrderIndex {
@@ -13,15 +16,17 @@ pub struct InOrderIndex {
 }
 
 impl InOrderIndex {
-    /// Constructor for a new [InOrderIndex].
+    // CONSTRUCTORS
+    // --------------------------------------------------------------------------------------------
+
+    /// Returns a new [InOrderIndex] instantiated from the provided value.
     pub fn new(idx: NonZeroUsize) -> InOrderIndex {
         InOrderIndex { idx: idx.get() }
     }
 
-    /// Constructs an index from a leaf position.
+    /// Returns a new [InOrderIndex] instantiated from the specified leaf position.
     ///
-    /// Panics:
-    ///
+    /// # Panics:
     /// If `leaf` is higher than or equal to `usize::MAX / 2`.
     pub fn from_leaf_pos(leaf: usize) -> InOrderIndex {
         // Convert the position from 0-indexed to 1-indexed, since the bit manipulation in this
@@ -30,6 +35,9 @@ impl InOrderIndex {
         InOrderIndex { idx: pos * 2 - 1 }
     }
 
+    // PUBLIC ACCESSORS
+    // --------------------------------------------------------------------------------------------
+
     /// True if the index is pointing at a leaf.
     ///
     /// Every odd number represents a leaf.
@@ -37,6 +45,11 @@ impl InOrderIndex {
         self.idx & 1 == 1
     }
 
+    /// Returns true if this node is a left child of its parent.
+    pub fn is_left_child(&self) -> bool {
+        self.parent().left_child() == *self
+    }
+
     /// Returns the level of the index.
     ///
     /// Starts at level zero for leaves and increases by one for each parent.
@@ -46,8 +59,7 @@ impl InOrderIndex {
 
     /// Returns the index of the left child.
     ///
-    /// Panics:
-    ///
+    /// # Panics:
     /// If the index corresponds to a leaf.
     pub fn left_child(&self) -> InOrderIndex {
         // The left child is itself a parent, with an index that splits its left/right subtrees. To
@@ -59,8 +71,7 @@ impl InOrderIndex {
 
     /// Returns the index of the right child.
     ///
-    /// Panics:
-    ///
+    /// # Panics:
     /// If the index corresponds to a leaf.
     pub fn right_child(&self) -> InOrderIndex {
         // To compute the index of the parent of the right subtree it is sufficient to add the size
@@ -94,8 +105,25 @@ impl InOrderIndex {
             parent.right_child()
         }
     }
+
+    /// Returns the inner value of this [InOrderIndex].
+    pub fn inner(&self) -> u64 {
+        self.idx as u64
+    }
 }
+
+// CONVERSIONS FROM IN-ORDER INDEX
+// ------------------------------------------------------------------------------------------------
+
+impl From<InOrderIndex> for u64 {
+    fn from(index: InOrderIndex) -> Self {
+        index.idx as u64
+    }
+}
+
+// TESTS
+// ================================================================================================
+
 #[cfg(test)]
 mod test {
     use super::InOrderIndex;
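
For reference, the in-order numbering used throughout: a 4-leaf tree walked in order assigns odd indices to leaves and even indices to inner nodes, with the root in the middle. The `from_leaf_pos` mapping reduces to (leaf + 1) * 2 - 1, which is easy to check standalone (the free-function form is illustrative):

```rust
// In-order indices of a 4-leaf tree:      4
//                                       /   \
//                                      2     6
//                                     / \   / \
//                                    1   3 5   7
fn from_leaf_pos(leaf: usize) -> usize {
    (leaf + 1) * 2 - 1
}

fn main() {
    assert_eq!(from_leaf_pos(0), 1);
    assert_eq!(from_leaf_pos(3), 7);
    // every leaf index is odd, matching InOrderIndex::is_leaf
    assert!((0..4).all(|leaf| from_leaf_pos(leaf) % 2 == 1));
}
```
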
@@ -10,7 +10,7 @@ mod proof;
 #[cfg(test)]
 mod tests;
 
-use super::{Felt, Rpo256, Word};
+use super::{Felt, Rpo256, RpoDigest, Word};
 
 // REEXPORTS
 // ================================================================================================
@@ -40,10 +40,10 @@ const fn leaf_to_corresponding_tree(pos: usize, forest: usize) -> Option<u32> {
     // - each bit in the forest is a unique tree and the bit position its power-of-two size
     // - each tree owns a consecutive range of positions equal to its size from left-to-right
     // - this means the first tree owns from `0` up to the `2^k_0` first positions, where `k_0`
     //   is the highest true bit position, the second tree from `2^k_0 + 1` up to `2^k_1` where
     //   `k_1` is the second highest bit, and so on.
     // - this means the highest bits work as a category marker, and the position is owned by
     //   the first tree which doesn't share a high bit with the position
     let before = forest & pos;
     let after = forest ^ before;
     let tree = after.ilog2();
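
A worked instance of the bit trick above (ignoring the bounds check the real function performs before this point): in forest 0b1010 the size-8 tree owns positions 0..=7 and the size-2 tree owns positions 8..=9.

```rust
// The owning tree is the highest forest bit not shared with the position.
fn leaf_to_tree(pos: usize, forest: usize) -> u32 {
    let before = forest & pos;
    let after = forest ^ before;
    after.ilog2()
}

fn main() {
    assert_eq!(leaf_to_tree(0, 0b1010), 3); // size-8 tree (bit 3)
    assert_eq!(leaf_to_tree(7, 0b1010), 3);
    assert_eq!(leaf_to_tree(9, 0b1010), 1); // size-2 tree (bit 1)
}
```
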
@@ -1,57 +1,60 @@
+use super::{MmrDelta, MmrProof, Rpo256, RpoDigest};
 use crate::{
-    hash::rpo::{Rpo256, RpoDigest},
     merkle::{
         mmr::{leaf_to_corresponding_tree, nodes_in_forest},
-        InOrderIndex, MerklePath, MmrError, MmrPeaks,
+        InOrderIndex, InnerNodeInfo, MerklePath, MmrError, MmrPeaks,
+    },
+    utils::{
+        collections::{BTreeMap, BTreeSet, Vec},
+        vec,
     },
-    utils::collections::{BTreeMap, Vec},
 };
 
-use super::{MmrDelta, MmrProof};
+// PARTIAL MERKLE MOUNTAIN RANGE
+// ================================================================================================
 
-/// Partially materialized [Mmr], used to efficiently store and update the authentication paths for
-/// a subset of the elements in a full [Mmr].
+/// Partially materialized Merkle Mountain Range (MMR), used to efficiently store and update the
+/// authentication paths for a subset of the elements in a full MMR.
 ///
 /// This structure stores only the authentication path for a value, the value itself is stored
 /// separately.
-#[derive(Debug)]
+#[derive(Debug, Clone, PartialEq, Eq)]
 pub struct PartialMmr {
-    /// The version of the [Mmr].
+    /// The version of the MMR.
     ///
     /// This value serves the following purposes:
     ///
-    /// - The forest is a counter for the total number of elements in the [Mmr].
-    /// - Since the [Mmr] is an append-only structure, every change to it causes a change to the
+    /// - The forest is a counter for the total number of elements in the MMR.
+    /// - Since the MMR is an append-only structure, every change to it causes a change to the
     ///   `forest`, so this value has a dual purpose as a version tag.
     /// - The bits in the forest also correspond to the count and size of every perfect binary
-    ///   tree that composes the [Mmr] structure, which serve to compute indexes and perform
+    ///   tree that composes the MMR structure, which serve to compute indexes and perform
     ///   validation.
     pub(crate) forest: usize,
 
-    /// The [Mmr] peaks.
+    /// The MMR peaks.
     ///
     /// The peaks are used for two reasons:
     ///
     /// 1. It authenticates the addition of an element to the [PartialMmr], ensuring only valid
     ///    elements are tracked.
-    /// 2. During a [Mmr] update peaks can be merged by hashing the left and right hand sides. The
+    /// 2. During an MMR update peaks can be merged by hashing the left and right hand sides. The
     ///    peaks are used as the left hand.
     ///
-    /// All the peaks of every tree in the [Mmr] forest. The peaks are always ordered by number of
+    /// All the peaks of every tree in the MMR forest. The peaks are always ordered by number of
     /// leaves, starting from the peak with most children, to the one with least.
     pub(crate) peaks: Vec<RpoDigest>,
 
-    /// Authentication nodes used to construct merkle paths for a subset of the [Mmr]'s leaves.
+    /// Authentication nodes used to construct merkle paths for a subset of the MMR's leaves.
     ///
-    /// This does not include the [Mmr]'s peaks nor the tracked nodes, only the elements required
+    /// This does not include the MMR's peaks nor the tracked nodes, only the elements required
     /// to construct their authentication paths. This property is used to detect when elements can
     /// be safely removed, because they are no longer required to authenticate any element in
     /// the [PartialMmr].
     ///
-    /// The elements in the [Mmr] are referenced using a in-order tree index. This indexing scheme
+    /// The elements in the MMR are referenced using an in-order tree index. This indexing scheme
     /// permits for easy computation of the relative nodes (left/right children, sibling, parent),
     /// which is useful for traversal. The indexing is also stable, meaning that merges to the
-    /// trees in the [Mmr] can be represented without rewrites of the indexes.
+    /// trees in the MMR can be represented without rewrites of the indexes.
     pub(crate) nodes: BTreeMap<InOrderIndex, RpoDigest>,
 
     /// Flag indicating if the odd element should be tracked.
@@ -66,33 +69,42 @@ impl PartialMmr {
     // --------------------------------------------------------------------------------------------
 
     /// Constructs a [PartialMmr] from the given [MmrPeaks].
-    pub fn from_peaks(accumulator: MmrPeaks) -> Self {
-        let forest = accumulator.num_leaves();
-        let peaks = accumulator.peaks().to_vec();
+    pub fn from_peaks(peaks: MmrPeaks) -> Self {
+        let forest = peaks.num_leaves();
+        let peaks = peaks.peaks().to_vec();
         let nodes = BTreeMap::new();
         let track_latest = false;
 
         Self { forest, peaks, nodes, track_latest }
     }
 
-    // ACCESSORS
+    // PUBLIC ACCESSORS
     // --------------------------------------------------------------------------------------------
 
-    // Gets the current `forest`.
-    //
-    // This value corresponds to the version of the [PartialMmr] and the number of leaves in it.
+    /// Returns the current `forest` of this [PartialMmr].
+    ///
+    /// This value corresponds to the version of the [PartialMmr] and the number of leaves in the
+    /// underlying MMR.
     pub fn forest(&self) -> usize {
        self.forest
     }
 
-    // Returns a reference to the current peaks in the [PartialMmr]
-    pub fn peaks(&self) -> &[RpoDigest] {
-        &self.peaks
+    /// Returns the number of leaves in the underlying MMR for this [PartialMmr].
+    pub fn num_leaves(&self) -> usize {
+        self.forest
     }
 
-    /// Given a leaf position, returns the Merkle path to its corresponding peak. If the position
-    /// is greater-or-equal than the tree size an error is returned. If the requested value is not
-    /// tracked returns `None`.
+    /// Returns the peaks of the MMR for this [PartialMmr].
+    pub fn peaks(&self) -> MmrPeaks {
+        // expect() is OK here because the constructor ensures that MMR peaks can be constructed
+        // correctly
+        MmrPeaks::new(self.forest, self.peaks.clone()).expect("invalid MMR peaks")
+    }
+
+    /// Given a leaf position, returns the Merkle path to its corresponding peak.
+    ///
+    /// If the position is greater-or-equal than the tree size an error is returned. If the
+    /// requested value is not tracked returns `None`.
     ///
     /// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were
     /// added, this corresponds to the MMR size _prior_ to adding the element. So the 1st element
@@ -125,14 +137,45 @@ impl PartialMmr {
         }
     }
 
-    // MODIFIERS
+    // ITERATORS
+    // --------------------------------------------------------------------------------------------
+
+    /// Returns an iterator over the nodes of all authentication paths of this [PartialMmr].
+    pub fn nodes(&self) -> impl Iterator<Item = (&InOrderIndex, &RpoDigest)> {
+        self.nodes.iter()
+    }
+
+    /// Returns an iterator over inner nodes of this [PartialMmr] for the specified leaves.
+    ///
+    /// The order of iteration is not defined. If a leaf is not present in this partial MMR it
+    /// is silently ignored.
+    pub fn inner_nodes<'a, I: Iterator<Item = &'a (usize, RpoDigest)> + 'a>(
+        &'a self,
+        mut leaves: I,
+    ) -> impl Iterator<Item = InnerNodeInfo> + '_ {
+        let stack = if let Some((pos, leaf)) = leaves.next() {
+            let idx = InOrderIndex::from_leaf_pos(*pos);
+            vec![(idx, *leaf)]
+        } else {
+            Vec::new()
+        };
+
+        InnerNodeIterator {
+            nodes: &self.nodes,
+            leaves,
+            stack,
+            seen_nodes: BTreeSet::new(),
+        }
+    }
+
+    // STATE MUTATORS
     // --------------------------------------------------------------------------------------------
 
     /// Add the authentication path represented by [MerklePath] if it is valid.
     ///
-    /// The `index` refers to the global position of the leaf in the [Mmr], these are 0-indexed
-    /// values assigned in a strictly monotonic fashion as elements are inserted into the [Mmr],
-    /// this value corresponds to the values used in the [Mmr] structure.
+    /// The `index` refers to the global position of the leaf in the MMR, these are 0-indexed
+    /// values assigned in a strictly monotonic fashion as elements are inserted into the MMR,
+    /// this value corresponds to the values used in the MMR structure.
     ///
     /// The `node` corresponds to the value at `index`, and `path` is the authentication path for
     /// that element up to its corresponding Mmr peak. The `node` is only used to compute the root
@@ -185,7 +228,7 @@ impl PartialMmr {
 
     /// Remove a leaf of the [PartialMmr] and the unused nodes from the authentication path.
     ///
-    /// Note: `leaf_pos` corresponds to the position the [Mmr] and not on an individual tree.
+    /// Note: `leaf_pos` corresponds to the position in the MMR and not on an individual tree.
     pub fn remove(&mut self, leaf_pos: usize) {
         let mut idx = InOrderIndex::from_leaf_pos(leaf_pos);
 
@@ -202,18 +245,21 @@ impl PartialMmr {
         }
     }
 
-    /// Applies updates to the [PartialMmr].
-    pub fn apply(&mut self, delta: MmrDelta) -> Result<(), MmrError> {
+    /// Applies updates to this [PartialMmr] and returns a vector of new authentication nodes
+    /// inserted into the partial MMR.
+    pub fn apply(&mut self, delta: MmrDelta) -> Result<Vec<(InOrderIndex, RpoDigest)>, MmrError> {
         if delta.forest < self.forest {
             return Err(MmrError::InvalidPeaks);
         }
 
+        let mut inserted_nodes = Vec::new();
+
         if delta.forest == self.forest {
             if !delta.data.is_empty() {
                 return Err(MmrError::InvalidUpdate);
             }
 
-            return Ok(());
+            return Ok(inserted_nodes);
         }
 
         // find the tree merges
@@ -268,16 +314,21 @@ impl PartialMmr {
             // check if either the left or right subtree has saved authentication paths.
             // If so, turn tracking on to update those paths.
             if target != 1 && !track {
-                let left_child = peak_idx.left_child();
-                let right_child = peak_idx.right_child();
-                track = self.nodes.contains_key(&left_child)
-                    | self.nodes.contains_key(&right_child);
+                track = self.is_tracked_node(&peak_idx);
             }
 
             // update data only contains the nodes from the right subtrees, left nodes are
             // either previously known peaks or computed values
             let (left, right) = if target & merges != 0 {
                 let peak = self.peaks[peak_count];
+                let sibling_idx = peak_idx.sibling();
+
+                // if the sibling peak is tracked, add this peak to the set of
+                // authentication nodes
+                if self.is_tracked_node(&sibling_idx) {
+                    self.nodes.insert(peak_idx, new);
+                    inserted_nodes.push((peak_idx, new));
+                }
                 peak_count += 1;
                 (peak, new)
             } else {
@@ -287,7 +338,14 @@ impl PartialMmr {
             };
 
             if track {
-                self.nodes.insert(peak_idx.sibling(), right);
+                let sibling_idx = peak_idx.sibling();
+                if peak_idx.is_left_child() {
+                    self.nodes.insert(sibling_idx, right);
+                    inserted_nodes.push((sibling_idx, right));
+                } else {
+                    self.nodes.insert(sibling_idx, left);
+                    inserted_nodes.push((sibling_idx, left));
+                }
             }
 
             peak_idx = peak_idx.parent();
@@ -313,7 +371,22 @@ impl PartialMmr {
 
         debug_assert!(self.peaks.len() == (self.forest.count_ones() as usize));
 
-        Ok(())
+        Ok(inserted_nodes)
+    }
+
+    // HELPER METHODS
+    // --------------------------------------------------------------------------------------------
+
+    /// Returns true if this [PartialMmr] tracks the authentication path for the node at the
+    /// specified index.
+    fn is_tracked_node(&self, node_index: &InOrderIndex) -> bool {
+        if node_index.is_leaf() {
+            self.nodes.contains_key(&node_index.sibling())
+        } else {
+            let left_child = node_index.left_child();
+            let right_child = node_index.right_child();
+            self.nodes.contains_key(&left_child) | self.nodes.contains_key(&right_child)
+        }
     }
 }
 
@@ -348,12 +421,59 @@ impl From<&PartialMmr> for MmrPeaks {
     }
 }
 
+// ITERATORS
+// ================================================================================================
+
+/// An iterator over every inner node of the [PartialMmr].
+pub struct InnerNodeIterator<'a, I: Iterator<Item = &'a (usize, RpoDigest)>> {
+    nodes: &'a BTreeMap<InOrderIndex, RpoDigest>,
+    leaves: I,
+    stack: Vec<(InOrderIndex, RpoDigest)>,
+    seen_nodes: BTreeSet<InOrderIndex>,
+}
+
+impl<'a, I: Iterator<Item = &'a (usize, RpoDigest)>> Iterator for InnerNodeIterator<'a, I> {
+    type Item = InnerNodeInfo;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        while let Some((idx, node)) = self.stack.pop() {
+            let parent_idx = idx.parent();
+            let new_node = self.seen_nodes.insert(parent_idx);
+
+            // if we haven't seen this node's parent before, and the node has a sibling, return
+            // the inner node defined by the parent of this node, and move up the branch
+            if new_node {
+                if let Some(sibling) = self.nodes.get(&idx.sibling()) {
+                    let (left, right) = if parent_idx.left_child() == idx {
+                        (node, *sibling)
+                    } else {
+                        (*sibling, node)
+                    };
+                    let parent = Rpo256::merge(&[left, right]);
+                    let inner_node = InnerNodeInfo { value: parent, left, right };
+
+                    self.stack.push((parent_idx, parent));
+                    return Some(inner_node);
+                }
+            }
+
+            // the previous leaf has been processed, try to process the next leaf
+            if let Some((pos, leaf)) = self.leaves.next() {
+                let idx = InOrderIndex::from_leaf_pos(*pos);
+                self.stack.push((idx, *leaf));
+            }
+        }
+
+        None
+    }
+}
|
||||||
|
|
||||||
// UTILS
|
// UTILS
|
||||||
// ================================================================================================
|
// ================================================================================================
|
||||||
|
|
||||||
/// Given the description of a `forest`, returns the index of the root element of the smallest tree
|
/// Given the description of a `forest`, returns the index of the root element of the smallest tree
|
||||||
/// in it.
|
/// in it.
|
||||||
pub fn forest_to_root_index(forest: usize) -> InOrderIndex {
|
fn forest_to_root_index(forest: usize) -> InOrderIndex {
|
||||||
// Count total size of all trees in the forest.
|
// Count total size of all trees in the forest.
|
||||||
let nodes = nodes_in_forest(forest);
|
let nodes = nodes_in_forest(forest);
|
||||||
|
|
||||||
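A usage sketch for the new iterator, mirroring the tests added further below (`LEAVES`, `MerkleStore`, and the open/add calls are as used there): the inner nodes reachable from tracked leaves are enough to rebuild authentication paths in a Merkle store.

    let mmr: Mmr = LEAVES.into();
    let node1 = mmr.get(1).unwrap();
    let proof1 = mmr.open(1, mmr.forest()).unwrap();

    // track leaf 1 in a partial MMR, then export its inner nodes into a store
    let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();
    partial_mmr.add(1, node1, &proof1.merkle_path).unwrap();
    let mut store = MerkleStore::new();
    store.extend(partial_mmr.inner_nodes([(1, node1)].iter()));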
@@ -370,10 +490,23 @@ pub fn forest_to_root_index(forest: usize) -> InOrderIndex {
     InOrderIndex::new(idx.try_into().unwrap())
 }
 
+// TESTS
+// ================================================================================================
+
 #[cfg(test)]
-mod test {
-    use super::forest_to_root_index;
-    use crate::merkle::InOrderIndex;
+mod tests {
+    use super::{forest_to_root_index, BTreeSet, InOrderIndex, PartialMmr, RpoDigest, Vec};
+    use crate::merkle::{int_to_node, MerkleStore, Mmr, NodeIndex};
+
+    const LEAVES: [RpoDigest; 7] = [
+        int_to_node(0),
+        int_to_node(1),
+        int_to_node(2),
+        int_to_node(3),
+        int_to_node(4),
+        int_to_node(5),
+        int_to_node(6),
+    ];
 
     #[test]
     fn test_forest_to_root_index() {
@@ -400,4 +533,171 @@ mod test {
         assert_eq!(forest_to_root_index(0b1100), idx(20));
         assert_eq!(forest_to_root_index(0b1110), idx(26));
     }
+
+    #[test]
+    fn test_partial_mmr_apply_delta() {
+        // build an MMR with 10 nodes (2 peaks) and a partial MMR based on it
+        let mut mmr = Mmr::default();
+        (0..10).for_each(|i| mmr.add(int_to_node(i)));
+        let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();
+
+        // add authentication path for position 1 and 8
+        {
+            let node = mmr.get(1).unwrap();
+            let proof = mmr.open(1, mmr.forest()).unwrap();
+            partial_mmr.add(1, node, &proof.merkle_path).unwrap();
+        }
+
+        {
+            let node = mmr.get(8).unwrap();
+            let proof = mmr.open(8, mmr.forest()).unwrap();
+            partial_mmr.add(8, node, &proof.merkle_path).unwrap();
+        }
+
+        // add 2 more nodes into the MMR and validate apply_delta()
+        (10..12).for_each(|i| mmr.add(int_to_node(i)));
+        validate_apply_delta(&mmr, &mut partial_mmr);
+
+        // add 1 more node to the MMR, validate apply_delta() and start tracking the node
+        mmr.add(int_to_node(12));
+        validate_apply_delta(&mmr, &mut partial_mmr);
+        {
+            let node = mmr.get(12).unwrap();
+            let proof = mmr.open(12, mmr.forest()).unwrap();
+            partial_mmr.add(12, node, &proof.merkle_path).unwrap();
+            assert!(partial_mmr.track_latest);
+        }
+
+        // by this point we are tracking authentication paths for positions: 1, 8, and 12
+
+        // add 3 more nodes to the MMR (collapses to 1 peak) and validate apply_delta()
+        (13..16).for_each(|i| mmr.add(int_to_node(i)));
+        validate_apply_delta(&mmr, &mut partial_mmr);
+    }
+
+    fn validate_apply_delta(mmr: &Mmr, partial: &mut PartialMmr) {
+        let tracked_leaves = partial
+            .nodes
+            .iter()
+            .filter_map(|(index, _)| if index.is_leaf() { Some(index.sibling()) } else { None })
+            .collect::<Vec<_>>();
+        let nodes_before = partial.nodes.clone();
+
+        // compute and apply delta
+        let delta = mmr.get_delta(partial.forest(), mmr.forest()).unwrap();
+        let nodes_delta = partial.apply(delta).unwrap();
+
+        // new peaks were computed correctly
+        assert_eq!(mmr.peaks(mmr.forest()).unwrap(), partial.peaks());
+
+        let mut expected_nodes = nodes_before;
+        for (key, value) in nodes_delta {
+            // nodes should not be duplicated
+            assert!(expected_nodes.insert(key, value).is_none());
+        }
+
+        // new nodes should be a combination of original nodes and delta
+        assert_eq!(expected_nodes, partial.nodes);
+
+        // make sure tracked leaves open to the same proofs as in the underlying MMR
+        for index in tracked_leaves {
+            let index_value: u64 = index.into();
+            let pos = index_value / 2;
+            let proof1 = partial.open(pos as usize).unwrap().unwrap();
+            let proof2 = mmr.open(pos as usize, mmr.forest()).unwrap();
+            assert_eq!(proof1, proof2);
+        }
+    }
+
+    #[test]
+    fn test_partial_mmr_inner_nodes_iterator() {
+        // build the MMR
+        let mmr: Mmr = LEAVES.into();
+        let first_peak = mmr.peaks(mmr.forest).unwrap().peaks()[0];
+
+        // -- test single tree ----------------------------
+
+        // get path and node for position 1
+        let node1 = mmr.get(1).unwrap();
+        let proof1 = mmr.open(1, mmr.forest()).unwrap();
+
+        // create partial MMR and add authentication path to node at position 1
+        let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();
+        partial_mmr.add(1, node1, &proof1.merkle_path).unwrap();
+
+        // empty iterator should have no nodes
+        assert_eq!(partial_mmr.inner_nodes([].iter()).next(), None);
+
+        // build Merkle store from authentication paths in partial MMR
+        let mut store: MerkleStore = MerkleStore::new();
+        store.extend(partial_mmr.inner_nodes([(1, node1)].iter()));
+
+        let index1 = NodeIndex::new(2, 1).unwrap();
+        let path1 = store.get_path(first_peak, index1).unwrap().path;
+
+        assert_eq!(path1, proof1.merkle_path);
+
+        // -- test no duplicates --------------------------
+
+        // build the partial MMR
+        let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();
+
+        let node0 = mmr.get(0).unwrap();
+        let proof0 = mmr.open(0, mmr.forest()).unwrap();
+
+        let node2 = mmr.get(2).unwrap();
+        let proof2 = mmr.open(2, mmr.forest()).unwrap();
+
+        partial_mmr.add(0, node0, &proof0.merkle_path).unwrap();
+        partial_mmr.add(1, node1, &proof1.merkle_path).unwrap();
+        partial_mmr.add(2, node2, &proof2.merkle_path).unwrap();
+
+        // make sure there are no duplicates
+        let leaves = [(0, node0), (1, node1), (2, node2)];
+        let mut nodes = BTreeSet::new();
+        for node in partial_mmr.inner_nodes(leaves.iter()) {
+            assert!(nodes.insert(node.value));
+        }
+
+        // and also that the store is still be built correctly
+        store.extend(partial_mmr.inner_nodes(leaves.iter()));
+
+        let index0 = NodeIndex::new(2, 0).unwrap();
+        let index1 = NodeIndex::new(2, 1).unwrap();
+        let index2 = NodeIndex::new(2, 2).unwrap();
+
+        let path0 = store.get_path(first_peak, index0).unwrap().path;
+        let path1 = store.get_path(first_peak, index1).unwrap().path;
+        let path2 = store.get_path(first_peak, index2).unwrap().path;
+
+        assert_eq!(path0, proof0.merkle_path);
+        assert_eq!(path1, proof1.merkle_path);
+        assert_eq!(path2, proof2.merkle_path);
+
+        // -- test multiple trees -------------------------
+
+        // build the partial MMR
+        let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();
+
+        let node5 = mmr.get(5).unwrap();
+        let proof5 = mmr.open(5, mmr.forest()).unwrap();
+
+        partial_mmr.add(1, node1, &proof1.merkle_path).unwrap();
+        partial_mmr.add(5, node5, &proof5.merkle_path).unwrap();
+
+        // build Merkle store from authentication paths in partial MMR
+        let mut store: MerkleStore = MerkleStore::new();
+        store.extend(partial_mmr.inner_nodes([(1, node1), (5, node5)].iter()));
+
+        let index1 = NodeIndex::new(2, 1).unwrap();
+        let index5 = NodeIndex::new(1, 1).unwrap();
+
+        let second_peak = mmr.peaks(mmr.forest).unwrap().peaks()[1];
+
+        let path1 = store.get_path(first_peak, index1).unwrap().path;
+        let path5 = store.get_path(second_peak, index5).unwrap().path;
+
+        assert_eq!(path1, proof1.merkle_path);
+        assert_eq!(path5, proof5.merkle_path);
+    }
 }
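One detail worth calling out in `validate_apply_delta` above: it recovers leaf positions by halving an in-order index value. That works if, as `from_leaf_pos` suggests, the leaf at position `p` sits at (1-indexed) in-order index `2*p + 1`; a quick check under that assumption:

    let idx = InOrderIndex::from_leaf_pos(3);
    let value: u64 = idx.into();
    assert_eq!(value, 7); // 2 * 3 + 1
    assert_eq!(value / 2, 3); // integer division recovers the position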
@@ -3,17 +3,20 @@ use super::{
     Felt, MmrError, MmrProof, Rpo256, Word,
 };
 
-#[derive(Debug, Clone, PartialEq)]
+// MMR PEAKS
+// ================================================================================================
+
+#[derive(Debug, Clone, PartialEq, Eq)]
 #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
 pub struct MmrPeaks {
-    /// The number of leaves is used to differentiate accumulators that have the same number of
-    /// peaks. This happens because the number of peaks goes up-and-down as the structure is used
-    /// causing existing trees to be merged and new ones to be created. As an example, every time
-    /// the [Mmr] has a power-of-two number of leaves there is a single peak.
+    /// The number of leaves is used to differentiate MMRs that have the same number of peaks. This
+    /// happens because the number of peaks goes up-and-down as the structure is used causing
+    /// existing trees to be merged and new ones to be created. As an example, every time the MMR
+    /// has a power-of-two number of leaves there is a single peak.
     ///
-    /// Every tree in the [Mmr] forest has a distinct power-of-two size, this means only the right
-    /// most tree can have an odd number of elements (e.g. `1`). Additionally this means that the bits in
-    /// `num_leaves` conveniently encode the size of each individual tree.
+    /// Every tree in the MMR forest has a distinct power-of-two size, this means only the right-
+    /// most tree can have an odd number of elements (e.g. `1`). Additionally this means that the
+    /// bits in `num_leaves` conveniently encode the size of each individual tree.
     ///
     /// Examples:
    ///
@@ -25,7 +28,7 @@ pub struct MmrPeaks {
     /// leftmost tree has `2**3=8` elements, and the right most has `2**2=4` elements.
     num_leaves: usize,
 
-    /// All the peaks of every tree in the [Mmr] forest. The peaks are always ordered by number of
+    /// All the peaks of every tree in the MMR forest. The peaks are always ordered by number of
     /// leaves, starting from the peak with most children, to the one with least.
     ///
     /// Invariant: The length of `peaks` must be equal to the number of true bits in `num_leaves`.
@@ -33,6 +36,14 @@ pub struct MmrPeaks {
 }
 
 impl MmrPeaks {
+    // CONSTRUCTOR
+    // --------------------------------------------------------------------------------------------
+
+    /// Returns new [MmrPeaks] instantiated from the provided vector of peaks and the number of
+    /// leaves in the underlying MMR.
+    ///
+    /// # Errors
+    /// Returns an error if the number of leaves and the number of peaks are inconsistent.
     pub fn new(num_leaves: usize, peaks: Vec<RpoDigest>) -> Result<Self, MmrError> {
         if num_leaves.count_ones() as usize != peaks.len() {
             return Err(MmrError::InvalidPeaks);
@@ -44,23 +55,34 @@ impl MmrPeaks {
     // ACCESSORS
     // --------------------------------------------------------------------------------------------
 
-    /// Returns a count of the [Mmr]'s leaves.
+    /// Returns a count of leaves in the underlying MMR.
     pub fn num_leaves(&self) -> usize {
         self.num_leaves
     }
 
-    /// Returns the current peaks of the [Mmr].
+    /// Returns the number of peaks of the underlying MMR.
+    pub fn num_peaks(&self) -> usize {
+        self.peaks.len()
+    }
+
+    /// Returns the list of peaks of the underlying MMR.
     pub fn peaks(&self) -> &[RpoDigest] {
         &self.peaks
     }
 
+    /// Converts this [MmrPeaks] into its components: number of leaves and a vector of peaks of
+    /// the underlying MMR.
+    pub fn into_parts(self) -> (usize, Vec<RpoDigest>) {
+        (self.num_leaves, self.peaks)
+    }
+
     /// Hashes the peaks.
     ///
     /// The procedure will:
     /// - Flatten and pad the peaks to a vector of Felts.
     /// - Hash the vector of Felts.
-    pub fn hash_peaks(&self) -> Word {
-        Rpo256::hash_elements(&self.flatten_and_pad_peaks()).into()
+    pub fn hash_peaks(&self) -> RpoDigest {
+        Rpo256::hash_elements(&self.flatten_and_pad_peaks())
     }
 
     pub fn verify(&self, value: RpoDigest, opening: MmrProof) -> bool {
@@ -1,11 +1,10 @@
 use super::{
-    super::{InnerNodeInfo, Vec},
+    super::{InnerNodeInfo, Rpo256, RpoDigest, Vec},
     bit::TrueBitPositionIterator,
     full::high_bitmask,
-    leaf_to_corresponding_tree, nodes_in_forest, Mmr, MmrPeaks, PartialMmr, Rpo256,
+    leaf_to_corresponding_tree, nodes_in_forest, Mmr, MmrPeaks, PartialMmr,
 };
 use crate::{
-    hash::rpo::RpoDigest,
     merkle::{int_to_node, InOrderIndex, MerklePath, MerkleTree, MmrProof, NodeIndex},
     Felt, Word,
 };
@@ -137,7 +136,7 @@ fn test_mmr_simple() {
     assert_eq!(mmr.nodes.len(), 1);
     assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
 
-    let acc = mmr.accumulator();
+    let acc = mmr.peaks(mmr.forest()).unwrap();
     assert_eq!(acc.num_leaves(), 1);
     assert_eq!(acc.peaks(), &[postorder[0]]);
 
@@ -146,7 +145,7 @@ fn test_mmr_simple() {
     assert_eq!(mmr.nodes.len(), 3);
     assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
 
-    let acc = mmr.accumulator();
+    let acc = mmr.peaks(mmr.forest()).unwrap();
     assert_eq!(acc.num_leaves(), 2);
     assert_eq!(acc.peaks(), &[postorder[2]]);
 
@@ -155,7 +154,7 @@ fn test_mmr_simple() {
     assert_eq!(mmr.nodes.len(), 4);
     assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
 
-    let acc = mmr.accumulator();
+    let acc = mmr.peaks(mmr.forest()).unwrap();
     assert_eq!(acc.num_leaves(), 3);
     assert_eq!(acc.peaks(), &[postorder[2], postorder[3]]);
 
@@ -164,7 +163,7 @@ fn test_mmr_simple() {
     assert_eq!(mmr.nodes.len(), 7);
     assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
 
-    let acc = mmr.accumulator();
+    let acc = mmr.peaks(mmr.forest()).unwrap();
     assert_eq!(acc.num_leaves(), 4);
     assert_eq!(acc.peaks(), &[postorder[6]]);
 
@@ -173,7 +172,7 @@ fn test_mmr_simple() {
     assert_eq!(mmr.nodes.len(), 8);
     assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
 
-    let acc = mmr.accumulator();
+    let acc = mmr.peaks(mmr.forest()).unwrap();
     assert_eq!(acc.num_leaves(), 5);
     assert_eq!(acc.peaks(), &[postorder[6], postorder[7]]);
 
@@ -182,7 +181,7 @@ fn test_mmr_simple() {
     assert_eq!(mmr.nodes.len(), 10);
     assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
 
-    let acc = mmr.accumulator();
+    let acc = mmr.peaks(mmr.forest()).unwrap();
     assert_eq!(acc.num_leaves(), 6);
     assert_eq!(acc.peaks(), &[postorder[6], postorder[9]]);
 
@@ -191,7 +190,7 @@ fn test_mmr_simple() {
     assert_eq!(mmr.nodes.len(), 11);
     assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
 
-    let acc = mmr.accumulator();
+    let acc = mmr.peaks(mmr.forest()).unwrap();
     assert_eq!(acc.num_leaves(), 7);
     assert_eq!(acc.peaks(), &[postorder[6], postorder[9], postorder[10]]);
 }
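The progression above is a single rule applied repeatedly: after `n` adds, the peak list has `n.count_ones()` entries (the same invariant the `debug_assert!` in `PartialMmr::apply` checks). For the final state:

    assert_eq!(7usize.count_ones(), 3); // 7 leaves = 0b0111 -> 3 peaks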
@@ -203,96 +202,139 @@ fn test_mmr_open() {
     let h23 = merge(LEAVES[2], LEAVES[3]);
 
     // node at pos 7 is the root
-    assert!(mmr.open(7).is_err(), "Element 7 is not in the tree, result should be None");
+    assert!(
+        mmr.open(7, mmr.forest()).is_err(),
+        "Element 7 is not in the tree, result should be None"
+    );
 
     // node at pos 6 is the root
     let empty: MerklePath = MerklePath::new(vec![]);
     let opening = mmr
-        .open(6)
+        .open(6, mmr.forest())
         .expect("Element 6 is contained in the tree, expected an opening result.");
     assert_eq!(opening.merkle_path, empty);
     assert_eq!(opening.forest, mmr.forest);
     assert_eq!(opening.position, 6);
     assert!(
-        mmr.accumulator().verify(LEAVES[6], opening),
+        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[6], opening),
         "MmrProof should be valid for the current accumulator."
     );
 
     // nodes 4,5 are depth 1
     let root_to_path = MerklePath::new(vec![LEAVES[4]]);
     let opening = mmr
-        .open(5)
+        .open(5, mmr.forest())
         .expect("Element 5 is contained in the tree, expected an opening result.");
     assert_eq!(opening.merkle_path, root_to_path);
     assert_eq!(opening.forest, mmr.forest);
     assert_eq!(opening.position, 5);
     assert!(
-        mmr.accumulator().verify(LEAVES[5], opening),
+        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[5], opening),
         "MmrProof should be valid for the current accumulator."
     );
 
     let root_to_path = MerklePath::new(vec![LEAVES[5]]);
     let opening = mmr
-        .open(4)
+        .open(4, mmr.forest())
         .expect("Element 4 is contained in the tree, expected an opening result.");
     assert_eq!(opening.merkle_path, root_to_path);
     assert_eq!(opening.forest, mmr.forest);
     assert_eq!(opening.position, 4);
     assert!(
-        mmr.accumulator().verify(LEAVES[4], opening),
+        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[4], opening),
         "MmrProof should be valid for the current accumulator."
     );
 
     // nodes 0,1,2,3 are detph 2
     let root_to_path = MerklePath::new(vec![LEAVES[2], h01]);
     let opening = mmr
-        .open(3)
+        .open(3, mmr.forest())
         .expect("Element 3 is contained in the tree, expected an opening result.");
     assert_eq!(opening.merkle_path, root_to_path);
     assert_eq!(opening.forest, mmr.forest);
     assert_eq!(opening.position, 3);
     assert!(
-        mmr.accumulator().verify(LEAVES[3], opening),
+        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[3], opening),
         "MmrProof should be valid for the current accumulator."
     );
 
     let root_to_path = MerklePath::new(vec![LEAVES[3], h01]);
     let opening = mmr
-        .open(2)
+        .open(2, mmr.forest())
         .expect("Element 2 is contained in the tree, expected an opening result.");
     assert_eq!(opening.merkle_path, root_to_path);
     assert_eq!(opening.forest, mmr.forest);
     assert_eq!(opening.position, 2);
     assert!(
-        mmr.accumulator().verify(LEAVES[2], opening),
+        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[2], opening),
         "MmrProof should be valid for the current accumulator."
     );
 
     let root_to_path = MerklePath::new(vec![LEAVES[0], h23]);
     let opening = mmr
-        .open(1)
+        .open(1, mmr.forest())
         .expect("Element 1 is contained in the tree, expected an opening result.");
     assert_eq!(opening.merkle_path, root_to_path);
     assert_eq!(opening.forest, mmr.forest);
     assert_eq!(opening.position, 1);
     assert!(
-        mmr.accumulator().verify(LEAVES[1], opening),
+        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[1], opening),
         "MmrProof should be valid for the current accumulator."
     );
 
     let root_to_path = MerklePath::new(vec![LEAVES[1], h23]);
     let opening = mmr
-        .open(0)
+        .open(0, mmr.forest())
         .expect("Element 0 is contained in the tree, expected an opening result.");
     assert_eq!(opening.merkle_path, root_to_path);
     assert_eq!(opening.forest, mmr.forest);
     assert_eq!(opening.position, 0);
     assert!(
-        mmr.accumulator().verify(LEAVES[0], opening),
+        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[0], opening),
         "MmrProof should be valid for the current accumulator."
     );
 }
 
+#[test]
+fn test_mmr_open_older_version() {
+    let mmr: Mmr = LEAVES.into();
+
+    fn is_even(v: &usize) -> bool {
+        v & 1 == 0
+    }
+
+    // merkle path of a node is empty if there are no elements to pair with it
+    for pos in (0..mmr.forest()).filter(is_even) {
+        let forest = pos + 1;
+        let proof = mmr.open(pos, forest).unwrap();
+        assert_eq!(proof.forest, forest);
+        assert_eq!(proof.merkle_path.nodes(), []);
+        assert_eq!(proof.position, pos);
+    }
+
+    // openings match that of a merkle tree
+    let mtree: MerkleTree = LEAVES[..4].try_into().unwrap();
+    for forest in 4..=LEAVES.len() {
+        for pos in 0..4 {
+            let idx = NodeIndex::new(2, pos).unwrap();
+            let path = mtree.get_path(idx).unwrap();
+            let proof = mmr.open(pos as usize, forest).unwrap();
+            assert_eq!(path, proof.merkle_path);
+        }
+    }
+    let mtree: MerkleTree = LEAVES[4..6].try_into().unwrap();
+    for forest in 6..=LEAVES.len() {
+        for pos in 0..2 {
+            let idx = NodeIndex::new(1, pos).unwrap();
+            let path = mtree.get_path(idx).unwrap();
+            // account for the bigger tree with 4 elements
+            let mmr_pos = (pos + 4) as usize;
+            let proof = mmr.open(mmr_pos, forest).unwrap();
+            assert_eq!(path, proof.merkle_path);
+        }
+    }
+}
 
 /// Tests the openings of a simple Mmr with a single tree of depth 8.
 #[test]
 fn test_mmr_open_eight() {
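The new second argument to `open` selects which historical forest to prove against, which is what `test_mmr_open_older_version` exercises. A compressed sketch of the same idea, under the same `LEAVES` fixture:

    let mmr: Mmr = LEAVES.into();
    // against forest = 1 the first leaf is its own peak: empty path
    assert_eq!(mmr.open(0, 1).unwrap().merkle_path.nodes(), []);
    // against the full 7-leaf forest it sits at depth 2 of the 4-leaf tree
    assert_eq!(mmr.open(0, mmr.forest()).unwrap().merkle_path.nodes().len(), 2);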
@@ -313,49 +355,49 @@ fn test_mmr_open_eight() {
     let root = mtree.root();
 
     let position = 0;
-    let proof = mmr.open(position).unwrap();
+    let proof = mmr.open(position, mmr.forest()).unwrap();
     let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
 
     let position = 1;
-    let proof = mmr.open(position).unwrap();
+    let proof = mmr.open(position, mmr.forest()).unwrap();
     let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
 
     let position = 2;
-    let proof = mmr.open(position).unwrap();
+    let proof = mmr.open(position, mmr.forest()).unwrap();
     let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
 
     let position = 3;
-    let proof = mmr.open(position).unwrap();
+    let proof = mmr.open(position, mmr.forest()).unwrap();
     let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
 
     let position = 4;
-    let proof = mmr.open(position).unwrap();
+    let proof = mmr.open(position, mmr.forest()).unwrap();
     let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
 
     let position = 5;
-    let proof = mmr.open(position).unwrap();
+    let proof = mmr.open(position, mmr.forest()).unwrap();
     let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
 
     let position = 6;
-    let proof = mmr.open(position).unwrap();
+    let proof = mmr.open(position, mmr.forest()).unwrap();
     let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
 
     let position = 7;
-    let proof = mmr.open(position).unwrap();
+    let proof = mmr.open(position, mmr.forest()).unwrap();
     let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
@@ -371,47 +413,47 @@ fn test_mmr_open_seven() {
     let mmr: Mmr = LEAVES.into();
 
     let position = 0;
-    let proof = mmr.open(position).unwrap();
+    let proof = mmr.open(position, mmr.forest()).unwrap();
     let merkle_path: MerklePath =
         mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(0, LEAVES[0]).unwrap(), mtree1.root());
 
     let position = 1;
-    let proof = mmr.open(position).unwrap();
+    let proof = mmr.open(position, mmr.forest()).unwrap();
     let merkle_path: MerklePath =
         mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(1, LEAVES[1]).unwrap(), mtree1.root());
 
     let position = 2;
-    let proof = mmr.open(position).unwrap();
+    let proof = mmr.open(position, mmr.forest()).unwrap();
     let merkle_path: MerklePath =
         mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(2, LEAVES[2]).unwrap(), mtree1.root());
 
     let position = 3;
-    let proof = mmr.open(position).unwrap();
+    let proof = mmr.open(position, mmr.forest()).unwrap();
     let merkle_path: MerklePath =
         mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(3, LEAVES[3]).unwrap(), mtree1.root());
 
     let position = 4;
-    let proof = mmr.open(position).unwrap();
+    let proof = mmr.open(position, mmr.forest()).unwrap();
     let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 0u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(0, LEAVES[4]).unwrap(), mtree2.root());
 
     let position = 5;
-    let proof = mmr.open(position).unwrap();
+    let proof = mmr.open(position, mmr.forest()).unwrap();
     let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 1u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(1, LEAVES[5]).unwrap(), mtree2.root());
 
     let position = 6;
-    let proof = mmr.open(position).unwrap();
+    let proof = mmr.open(position, mmr.forest()).unwrap();
     let merkle_path: MerklePath = [].as_ref().into();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(0, LEAVES[6]).unwrap(), LEAVES[6]);
@@ -435,7 +477,7 @@ fn test_mmr_invariants() {
     let mut mmr = Mmr::new();
     for v in 1..=1028 {
         mmr.add(int_to_node(v));
-        let accumulator = mmr.accumulator();
+        let accumulator = mmr.peaks(mmr.forest()).unwrap();
         assert_eq!(v as usize, mmr.forest(), "MMR leaf count must increase by one on every add");
         assert_eq!(
             v as usize,
@@ -516,10 +558,50 @@ fn test_mmr_inner_nodes() {
     assert_eq!(postorder, nodes);
 }
 
+#[test]
+fn test_mmr_peaks() {
+    let mmr: Mmr = LEAVES.into();
+
+    let forest = 0b0001;
+    let acc = mmr.peaks(forest).unwrap();
+    assert_eq!(acc.num_leaves(), forest);
+    assert_eq!(acc.peaks(), &[mmr.nodes[0]]);
+
+    let forest = 0b0010;
+    let acc = mmr.peaks(forest).unwrap();
+    assert_eq!(acc.num_leaves(), forest);
+    assert_eq!(acc.peaks(), &[mmr.nodes[2]]);
+
+    let forest = 0b0011;
+    let acc = mmr.peaks(forest).unwrap();
+    assert_eq!(acc.num_leaves(), forest);
+    assert_eq!(acc.peaks(), &[mmr.nodes[2], mmr.nodes[3]]);
+
+    let forest = 0b0100;
+    let acc = mmr.peaks(forest).unwrap();
+    assert_eq!(acc.num_leaves(), forest);
+    assert_eq!(acc.peaks(), &[mmr.nodes[6]]);
+
+    let forest = 0b0101;
+    let acc = mmr.peaks(forest).unwrap();
+    assert_eq!(acc.num_leaves(), forest);
+    assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[7]]);
+
+    let forest = 0b0110;
+    let acc = mmr.peaks(forest).unwrap();
+    assert_eq!(acc.num_leaves(), forest);
+    assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9]]);
+
+    let forest = 0b0111;
+    let acc = mmr.peaks(forest).unwrap();
+    assert_eq!(acc.num_leaves(), forest);
+    assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9], mmr.nodes[10]]);
+}
+
 #[test]
 fn test_mmr_hash_peaks() {
     let mmr: Mmr = LEAVES.into();
-    let peaks = mmr.accumulator();
+    let peaks = mmr.peaks(mmr.forest()).unwrap();
 
     let first_peak = Rpo256::merge(&[
         Rpo256::merge(&[LEAVES[0], LEAVES[1]]),
@@ -531,10 +613,7 @@ fn test_mmr_hash_peaks() {
     // minimum length is 16
     let mut expected_peaks = [first_peak, second_peak, third_peak].to_vec();
     expected_peaks.resize(16, RpoDigest::default());
-    assert_eq!(
-        peaks.hash_peaks(),
-        *Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
-    );
+    assert_eq!(peaks.hash_peaks(), Rpo256::hash_elements(&digests_to_elements(&expected_peaks)));
 }
 
 #[test]
@@ -552,7 +631,7 @@ fn test_mmr_peaks_hash_less_than_16() {
     expected_peaks.resize(16, RpoDigest::default());
     assert_eq!(
         accumulator.hash_peaks(),
-        *Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
+        Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
     );
 }
 }
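These two hunks can drop the deref (`*`) because `hash_peaks` now returns `RpoDigest` directly; the padding rule itself is unchanged. A sketch of the rule the tests encode via their `digests_to_elements` helper:

    let mmr: Mmr = LEAVES.into();
    let peaks = mmr.peaks(mmr.forest()).unwrap();
    let mut padded = peaks.peaks().to_vec();
    padded.resize(16, RpoDigest::default()); // fewer than 16 peaks are padded
    // hashing `padded` as field elements reproduces peaks.hash_peaks()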
@@ -569,47 +648,47 @@ fn test_mmr_peaks_hash_odd() {
     expected_peaks.resize(18, RpoDigest::default());
     assert_eq!(
         accumulator.hash_peaks(),
-        *Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
+        Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
     );
 }
 
 #[test]
-fn test_mmr_updates() {
+fn test_mmr_delta() {
     let mmr: Mmr = LEAVES.into();
-    let acc = mmr.accumulator();
+    let acc = mmr.peaks(mmr.forest()).unwrap();
 
     // original_forest can't have more elements
     assert!(
-        mmr.get_delta(LEAVES.len() + 1).is_err(),
+        mmr.get_delta(LEAVES.len() + 1, mmr.forest()).is_err(),
         "Can not provide updates for a newer Mmr"
     );
 
     // if the number of elements is the same there is no change
     assert!(
-        mmr.get_delta(LEAVES.len()).unwrap().data.is_empty(),
+        mmr.get_delta(LEAVES.len(), mmr.forest()).unwrap().data.is_empty(),
         "There are no updates for the same Mmr version"
     );
 
     // missing the last element added, which is itself a tree peak
-    assert_eq!(mmr.get_delta(6).unwrap().data, vec![acc.peaks()[2]], "one peak");
+    assert_eq!(mmr.get_delta(6, mmr.forest()).unwrap().data, vec![acc.peaks()[2]], "one peak");
 
     // missing the sibling to complete the tree of depth 2, and the last element
     assert_eq!(
-        mmr.get_delta(5).unwrap().data,
+        mmr.get_delta(5, mmr.forest()).unwrap().data,
         vec![LEAVES[5], acc.peaks()[2]],
         "one sibling, one peak"
     );
 
     // missing the whole last two trees, only send the peaks
     assert_eq!(
-        mmr.get_delta(4).unwrap().data,
+        mmr.get_delta(4, mmr.forest()).unwrap().data,
         vec![acc.peaks()[1], acc.peaks()[2]],
         "two peaks"
     );
 
     // missing the sibling to complete the first tree, and the two last trees
     assert_eq!(
-        mmr.get_delta(3).unwrap().data,
+        mmr.get_delta(3, mmr.forest()).unwrap().data,
         vec![LEAVES[3], acc.peaks()[1], acc.peaks()[2]],
         "one sibling, two peaks"
     );
@@ -617,35 +696,77 @@ fn test_mmr_updates() {
     // missing half of the first tree, only send the computed element (not the leaves), and the new
     // peaks
     assert_eq!(
-        mmr.get_delta(2).unwrap().data,
+        mmr.get_delta(2, mmr.forest()).unwrap().data,
         vec![mmr.nodes[5], acc.peaks()[1], acc.peaks()[2]],
         "one sibling, two peaks"
     );
 
     assert_eq!(
-        mmr.get_delta(1).unwrap().data,
+        mmr.get_delta(1, mmr.forest()).unwrap().data,
         vec![LEAVES[1], mmr.nodes[5], acc.peaks()[1], acc.peaks()[2]],
         "one sibling, two peaks"
     );
 
-    assert_eq!(&mmr.get_delta(0).unwrap().data, acc.peaks(), "all peaks");
+    assert_eq!(&mmr.get_delta(0, mmr.forest()).unwrap().data, acc.peaks(), "all peaks");
+}
+
+#[test]
+fn test_mmr_delta_old_forest() {
+    let mmr: Mmr = LEAVES.into();
+
+    // from_forest must be smaller-or-equal to to_forest
+    for version in 1..=mmr.forest() {
+        assert!(mmr.get_delta(version + 1, version).is_err());
+    }
+
+    // when from_forest and to_forest are equal, there are no updates
+    for version in 1..=mmr.forest() {
+        let delta = mmr.get_delta(version, version).unwrap();
+        assert!(delta.data.is_empty());
+        assert_eq!(delta.forest, version);
+    }
+
+    // test update which merges the odd peak to the right
+    for count in 0..(mmr.forest() / 2) {
+        // *2 because every iteration tests a pair
+        // +1 because the Mmr is 1-indexed
+        let from_forest = (count * 2) + 1;
+        let to_forest = (count * 2) + 2;
+        let delta = mmr.get_delta(from_forest, to_forest).unwrap();
+
+        // *2 because every iteration tests a pair
+        // +1 because sibling is the odd element
+        let sibling = (count * 2) + 1;
+        assert_eq!(delta.data, [LEAVES[sibling]]);
+        assert_eq!(delta.forest, to_forest);
+    }
+
+    let version = 4;
+    let delta = mmr.get_delta(1, version).unwrap();
+    assert_eq!(delta.data, [mmr.nodes[1], mmr.nodes[5]]);
+    assert_eq!(delta.forest, version);
+
+    let version = 5;
+    let delta = mmr.get_delta(1, version).unwrap();
+    assert_eq!(delta.data, [mmr.nodes[1], mmr.nodes[5], mmr.nodes[7]]);
+    assert_eq!(delta.forest, version);
 }
 
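Distilling the two delta tests above: a delta always carries the target version in its `forest` field, and a client already at the target version gets an empty payload. Sketch, reusing the same fixture:

    let mmr: Mmr = LEAVES.into();
    let delta = mmr.get_delta(1, mmr.forest()).unwrap();
    assert_eq!(delta.forest, mmr.forest());
    assert!(mmr.get_delta(mmr.forest(), mmr.forest()).unwrap().data.is_empty());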
 #[test]
 fn test_partial_mmr_simple() {
     let mmr: Mmr = LEAVES.into();
-    let acc = mmr.accumulator();
-    let mut partial: PartialMmr = acc.clone().into();
+    let peaks = mmr.peaks(mmr.forest()).unwrap();
+    let mut partial: PartialMmr = peaks.clone().into();
 
     // check initial state of the partial mmr
-    assert_eq!(partial.peaks(), acc.peaks());
-    assert_eq!(partial.forest(), acc.num_leaves());
+    assert_eq!(partial.peaks(), peaks);
+    assert_eq!(partial.forest(), peaks.num_leaves());
     assert_eq!(partial.forest(), LEAVES.len());
-    assert_eq!(partial.peaks().len(), 3);
+    assert_eq!(partial.peaks().num_peaks(), 3);
     assert_eq!(partial.nodes.len(), 0);
 
     // check state after adding tracking one element
-    let proof1 = mmr.open(0).unwrap();
+    let proof1 = mmr.open(0, mmr.forest()).unwrap();
     let el1 = mmr.get(proof1.position).unwrap();
     partial.add(proof1.position, el1, &proof1.merkle_path).unwrap();
 
@@ -657,7 +778,7 @@ fn test_partial_mmr_simple() {
     let idx = idx.parent();
     assert_eq!(partial.nodes[&idx.sibling()], proof1.merkle_path[1]);
 
-    let proof2 = mmr.open(1).unwrap();
+    let proof2 = mmr.open(1, mmr.forest()).unwrap();
     let el2 = mmr.get(proof2.position).unwrap();
     partial.add(proof2.position, el2, &proof2.merkle_path).unwrap();
 
@@ -675,21 +796,21 @@ fn test_partial_mmr_update_single() {
     let mut full = Mmr::new();
     let zero = int_to_node(0);
     full.add(zero);
-    let mut partial: PartialMmr = full.accumulator().into();
+    let mut partial: PartialMmr = full.peaks(full.forest()).unwrap().into();
 
-    let proof = full.open(0).unwrap();
+    let proof = full.open(0, full.forest()).unwrap();
     partial.add(proof.position, zero, &proof.merkle_path).unwrap();
 
     for i in 1..100 {
         let node = int_to_node(i);
         full.add(node);
-        let delta = full.get_delta(partial.forest()).unwrap();
+        let delta = full.get_delta(partial.forest(), full.forest()).unwrap();
         partial.apply(delta).unwrap();
 
         assert_eq!(partial.forest(), full.forest());
-        assert_eq!(partial.peaks(), full.accumulator().peaks());
+        assert_eq!(partial.peaks(), full.peaks(full.forest()).unwrap());
 
-        let proof1 = full.open(i as usize).unwrap();
+        let proof1 = full.open(i as usize, full.forest()).unwrap();
         partial.add(proof1.position, node, &proof1.merkle_path).unwrap();
         let proof2 = partial.open(proof1.position).unwrap().unwrap();
         assert_eq!(proof1.merkle_path, proof2.merkle_path);
@@ -699,7 +820,7 @@ fn test_partial_mmr_update_single() {
 #[test]
 fn test_mmr_add_invalid_odd_leaf() {
     let mmr: Mmr = LEAVES.into();
-    let acc = mmr.accumulator();
+    let acc = mmr.peaks(mmr.forest()).unwrap();
     let mut partial: PartialMmr = acc.clone().into();
 
     let empty = MerklePath::new(Vec::new());
@@ -31,7 +31,7 @@ mod tiered_smt;
 pub use tiered_smt::{TieredSmt, TieredSmtProof, TieredSmtProofError};
 
 mod mmr;
-pub use mmr::{InOrderIndex, Mmr, MmrError, MmrPeaks, MmrProof, PartialMmr};
+pub use mmr::{InOrderIndex, Mmr, MmrDelta, MmrError, MmrPeaks, MmrProof, PartialMmr};
 
 mod store;
 pub use store::{DefaultMerkleStore, MerkleStore, RecordingMerkleStore, StoreNode};
@@ -1,4 +1,4 @@
-use crate::hash::rpo::RpoDigest;
+use super::RpoDigest;
 
 /// Representation of a node with two children used for iterating over containers.
 #[derive(Debug, Clone, PartialEq, Eq)]
@@ -109,9 +109,9 @@ impl PartialMerkleTree {
|
|||||||

         // check if the number of leaves can be accommodated by the tree's depth; we use a min
         // depth of 63 because we consider passing in a vector of size 2^64 infeasible.
-        let max = (1_u64 << 63) as usize;
+        let max = 2usize.pow(63);
         if layers.len() > max {
-            return Err(MerkleError::InvalidNumEntries(max, layers.len()));
+            return Err(MerkleError::InvalidNumEntries(max));
         }

         // Get maximum depth
@@ -1,5 +1,6 @@
 use super::{vec, InnerNodeInfo, MerkleError, NodeIndex, Rpo256, RpoDigest, Vec};
 use core::ops::{Deref, DerefMut};
+use winter_utils::{ByteReader, Deserializable, DeserializationError, Serializable};

 // MERKLE PATH
 // ================================================================================================
@@ -17,6 +18,7 @@ impl MerklePath {

     /// Creates a new Merkle path from a list of nodes.
     pub fn new(nodes: Vec<RpoDigest>) -> Self {
+        assert!(nodes.len() <= u8::MAX.into(), "MerklePath may have at most 256 items");
         Self { nodes }
     }

@@ -189,6 +191,55 @@ pub struct RootPath {
     pub path: MerklePath,
 }

+// SERIALIZATION
+// ================================================================================================
+
+impl Serializable for MerklePath {
+    fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
+        assert!(self.nodes.len() <= u8::MAX.into(), "Length enforced in the constructor");
+        target.write_u8(self.nodes.len() as u8);
+        self.nodes.write_into(target);
+    }
+}
+
+impl Deserializable for MerklePath {
+    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
+        let count = source.read_u8()?.into();
+        let nodes = RpoDigest::read_batch_from(source, count)?;
+        Ok(Self { nodes })
+    }
+}
+
+impl Serializable for ValuePath {
+    fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
+        self.value.write_into(target);
+        self.path.write_into(target);
+    }
+}
+
+impl Deserializable for ValuePath {
+    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
+        let value = RpoDigest::read_from(source)?;
+        let path = MerklePath::read_from(source)?;
+        Ok(Self { value, path })
+    }
+}
+
+impl Serializable for RootPath {
+    fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
+        self.root.write_into(target);
+        self.path.write_into(target);
+    }
+}
+
+impl Deserializable for RootPath {
+    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
+        let root = RpoDigest::read_from(source)?;
+        let path = MerklePath::read_from(source)?;
+        Ok(Self { root, path })
+    }
+}
+
 // TESTS
 // ================================================================================================
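
With these impls in place, a MerklePath round-trips through the standard winter_utils helpers. A minimal sketch (not part of the diff), relying on to_bytes/read_from_bytes, the provided methods of Serializable/Deserializable:

    use winter_utils::{Deserializable, Serializable};

    let path = MerklePath::new(Vec::new());
    let bytes = path.to_bytes(); // one length byte followed by the node digests
    let restored = MerklePath::read_from_bytes(&bytes).unwrap();
    assert_eq!(path, restored);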
@@ -19,7 +19,6 @@ pub struct SimpleSmt {
     root: RpoDigest,
     leaves: BTreeMap<u64, Word>,
     branches: BTreeMap<NodeIndex, BranchNode>,
-    empty_hashes: Vec<RpoDigest>,
 }

 impl SimpleSmt {
@@ -52,13 +51,11 @@ impl SimpleSmt {
             return Err(MerkleError::DepthTooBig(depth as u64));
         }

-        let empty_hashes = EmptySubtreeRoots::empty_hashes(depth).to_vec();
-        let root = empty_hashes[0];
+        let root = *EmptySubtreeRoots::entry(depth, 0);

         Ok(Self {
             root,
             depth,
-            empty_hashes,
             leaves: BTreeMap::new(),
             branches: BTreeMap::new(),
         })
@@ -74,39 +71,54 @@ impl SimpleSmt {
     /// - If the depth is 0 or is greater than 64.
     /// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
     /// - The provided entries contain multiple values for the same key.
-    pub fn with_leaves<R, I>(depth: u8, entries: R) -> Result<Self, MerkleError>
-    where
-        R: IntoIterator<IntoIter = I>,
-        I: Iterator<Item = (u64, Word)> + ExactSizeIterator,
-    {
+    pub fn with_leaves(
+        depth: u8,
+        entries: impl IntoIterator<Item = (u64, Word)>,
+    ) -> Result<Self, MerkleError> {
         // create an empty tree
         let mut tree = Self::new(depth)?;

-        // check if the number of leaves can be accommodated by the tree's depth; we use a min
-        // depth of 63 because we consider passing in a vector of size 2^64 infeasible.
-        let entries = entries.into_iter();
-        let max = 1 << tree.depth.min(63);
-        if entries.len() > max {
-            return Err(MerkleError::InvalidNumEntries(max, entries.len()));
-        }
+        // compute the max number of entries. We use an upper bound of depth 63 because we consider
+        // passing in a vector of size 2^64 infeasible.
+        let max_num_entries = 2_usize.pow(tree.depth.min(63).into());

-        // append leaves to the tree returning an error if a duplicate entry for the same key
-        // is found
-        let mut empty_entries = BTreeSet::new();
-        for (key, value) in entries {
+        // This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so
+        // entries with the empty value need additional tracking.
+        let mut key_set_to_zero = BTreeSet::new();
+
+        for (idx, (key, value)) in entries.into_iter().enumerate() {
+            if idx >= max_num_entries {
+                return Err(MerkleError::InvalidNumEntries(max_num_entries));
+            }
+
             let old_value = tree.update_leaf(key, value)?;
-            if old_value != Self::EMPTY_VALUE || empty_entries.contains(&key) {
-                return Err(MerkleError::DuplicateValuesForIndex(key));
-            }
-            // if we've processed an empty entry, add the key to the set of empty entry keys, and
-            // if this key was already in the set, return an error
-            if value == Self::EMPTY_VALUE && !empty_entries.insert(key) {
+
+            if old_value != Self::EMPTY_VALUE || key_set_to_zero.contains(&key) {
                 return Err(MerkleError::DuplicateValuesForIndex(key));
             }
+
+            if value == Self::EMPTY_VALUE {
+                key_set_to_zero.insert(key);
+            };
         }
         Ok(tree)
     }

+    /// Wrapper around [`SimpleSmt::with_leaves`] which inserts leaves at contiguous indices
+    /// starting at index 0.
+    pub fn with_contiguous_leaves(
+        depth: u8,
+        entries: impl IntoIterator<Item = Word>,
+    ) -> Result<Self, MerkleError> {
+        Self::with_leaves(
+            depth,
+            entries
+                .into_iter()
+                .enumerate()
+                .map(|(idx, word)| (idx.try_into().expect("tree max depth is 2^8"), word)),
+        )
+    }
+
     // PUBLIC ACCESSORS
     // --------------------------------------------------------------------------------------------
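
Two behaviors of the rewritten constructor are worth calling out: the entry count is now checked lazily, so an ExactSizeIterator is no longer required, and keys explicitly set to the empty word are tracked so that repeating them is rejected. A sketch of the latter (not part of the diff; EMPTY_WORD as used in the tests below):

    // duplicate detection now also covers entries that reset a key to the empty word
    let entries = [(1_u64, EMPTY_WORD), (1_u64, EMPTY_WORD)];
    assert!(SimpleSmt::with_leaves(64, entries).is_err());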
@@ -133,10 +145,12 @@ impl SimpleSmt {
         } else if index.depth() == self.depth() {
             // the lookup in empty_hashes could fail only if empty_hashes were not built correctly
             // by the constructor as we check the depth of the lookup above.
-            Ok(RpoDigest::from(
-                self.get_leaf_node(index.value())
-                    .unwrap_or_else(|| *self.empty_hashes[index.depth() as usize]),
-            ))
+            let leaf_pos = index.value();
+            let leaf = match self.get_leaf_node(leaf_pos) {
+                Some(word) => word.into(),
+                None => *EmptySubtreeRoots::entry(self.depth, index.depth()),
+            };
+            Ok(leaf)
         } else {
             Ok(self.get_branch_node(&index).parent())
         }
@@ -214,6 +228,9 @@ impl SimpleSmt {
     /// # Errors
     /// Returns an error if the index is greater than the maximum tree capacity, that is 2^{depth}.
     pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<Word, MerkleError> {
+        // validate the index before modifying the structure
+        let idx = NodeIndex::new(self.depth(), index)?;
+
         let old_value = self.insert_leaf_node(index, value).unwrap_or(Self::EMPTY_VALUE);

         // if the old value and new value are the same, there is nothing to update
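
A consequence of validating the index up front is that a failed update leaves the tree untouched. A sketch (not part of the diff; int_to_leaf is the test helper used in this crate's tests):

    let mut smt = SimpleSmt::new(2).unwrap(); // depth 2, valid leaf positions 0..=3
    assert!(smt.update_leaf(4, int_to_leaf(1)).is_err()); // rejected before any state is written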
@@ -221,8 +238,82 @@ impl SimpleSmt {
             return Ok(value);
         }

-        let mut index = NodeIndex::new(self.depth(), index)?;
-        let mut value = RpoDigest::from(value);
+        self.recompute_nodes_from_index_to_root(idx, RpoDigest::from(value));
+
+        Ok(old_value)
+    }
+
+    /// Inserts a subtree at the specified index. The depth at which the subtree is inserted is
+    /// computed as `self.depth() - subtree.depth()`.
+    ///
+    /// Returns the new root.
+    pub fn set_subtree(
+        &mut self,
+        subtree_insertion_index: u64,
+        subtree: SimpleSmt,
+    ) -> Result<RpoDigest, MerkleError> {
+        if subtree.depth() > self.depth() {
+            return Err(MerkleError::InvalidSubtreeDepth {
+                subtree_depth: subtree.depth(),
+                tree_depth: self.depth(),
+            });
+        }
+
+        // Verify that `subtree_insertion_index` is valid.
+        let subtree_root_insertion_depth = self.depth() - subtree.depth();
+        let subtree_root_index =
+            NodeIndex::new(subtree_root_insertion_depth, subtree_insertion_index)?;
+
+        // add leaves
+        // --------------
+
+        // The subtree's leaf indices live in their own context - i.e. a subtree of depth `d`. If we
+        // insert the subtree at `subtree_insertion_index = 0`, then the subtree leaf indices are
+        // valid as they are. However, consider what happens when we insert at
+        // `subtree_insertion_index = 1`. The first leaf of our subtree now will have index `2^d`;
+        // you can see it as there's a full subtree sitting on its left. In general, for
+        // `subtree_insertion_index = i`, there are `i` subtrees sitting before the subtree we want
+        // to insert, so we need to adjust all its leaves by `i * 2^d`.
+        let leaf_index_shift: u64 = subtree_insertion_index * 2_u64.pow(subtree.depth().into());
+        for (subtree_leaf_idx, leaf_value) in subtree.leaves() {
+            let new_leaf_idx = leaf_index_shift + subtree_leaf_idx;
+            debug_assert!(new_leaf_idx < 2_u64.pow(self.depth().into()));
+
+            self.insert_leaf_node(new_leaf_idx, *leaf_value);
+        }
+
+        // add subtree's branch nodes (which includes the root)
+        // --------------
+        for (branch_idx, branch_node) in subtree.branches {
+            let new_branch_idx = {
+                let new_depth = subtree_root_insertion_depth + branch_idx.depth();
+                let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into())
+                    + branch_idx.value();
+
+                NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid")
+            };
+
+            self.branches.insert(new_branch_idx, branch_node);
+        }
+
+        // recompute nodes starting from subtree root
+        // --------------
+        self.recompute_nodes_from_index_to_root(subtree_root_index, subtree.root);
+
+        Ok(self.root)
+    }
+
+    // HELPER METHODS
+    // --------------------------------------------------------------------------------------------
+
+    /// Recomputes the branch nodes (including the root) from `index` all the way to the root.
+    /// `node_hash_at_index` is the hash of the node stored at index.
+    fn recompute_nodes_from_index_to_root(
+        &mut self,
+        mut index: NodeIndex,
+        node_hash_at_index: RpoDigest,
+    ) {
+        let mut value = node_hash_at_index;
         for _ in 0..index.depth() {
             let is_right = index.is_value_odd();
             index.move_up();
@@ -232,12 +323,8 @@ impl SimpleSmt {
             value = Rpo256::merge(&[left, right]);
         }
         self.root = value;
-        Ok(old_value)
     }

-    // HELPER METHODS
-    // --------------------------------------------------------------------------------------------
-
     fn get_leaf_node(&self, key: u64) -> Option<Word> {
         self.leaves.get(&key).copied()
     }
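
As a worked instance of the index arithmetic described in the set_subtree comment above: inserting a depth-1 subtree into a depth-3 tree at subtree_insertion_index = 2 places the subtree root at depth 3 - 1 = 2 and shifts subtree leaf 0 to tree leaf 2 * 2^1 = 4, which is exactly what test_simplesmt_set_subtree below asserts.

    let subtree_insertion_index = 2u64;
    let subtree_depth = 1u32;
    let leaf_index_shift = subtree_insertion_index * 2u64.pow(subtree_depth);
    assert_eq!(leaf_index_shift, 4); // subtree leaf 0 lands at tree leaf 4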
@@ -248,8 +335,8 @@ impl SimpleSmt {

     fn get_branch_node(&self, index: &NodeIndex) -> BranchNode {
         self.branches.get(index).cloned().unwrap_or_else(|| {
-            let node = self.empty_hashes[index.depth() as usize + 1];
-            BranchNode { left: node, right: node }
+            let node = EmptySubtreeRoots::entry(self.depth, index.depth() + 1);
+            BranchNode { left: *node, right: *node }
         })
     }

@@ -3,7 +3,7 @@ use super::{
     NodeIndex, Rpo256, Vec,
 };
 use crate::{
-    merkle::{digests_to_words, int_to_leaf, int_to_node},
+    merkle::{digests_to_words, int_to_leaf, int_to_node, EmptySubtreeRoots},
     Word,
 };
@@ -71,6 +71,21 @@ fn build_sparse_tree() {
     assert_eq!(old_value, EMPTY_WORD);
 }

+/// Tests that [`SimpleSmt::with_contiguous_leaves`] works as expected
+#[test]
+fn build_contiguous_tree() {
+    let tree_with_leaves = SimpleSmt::with_leaves(
+        2,
+        [0, 1, 2, 3].into_iter().zip(digests_to_words(&VALUES4).into_iter()),
+    )
+    .unwrap();
+
+    let tree_with_contiguous_leaves =
+        SimpleSmt::with_contiguous_leaves(2, digests_to_words(&VALUES4).into_iter()).unwrap();
+
+    assert_eq!(tree_with_leaves, tree_with_contiguous_leaves);
+}
+
 #[test]
 fn test_depth2_tree() {
     let tree =
@@ -214,22 +229,31 @@ fn small_tree_opening_is_consistent() {
 }

 #[test]
-fn fail_on_duplicates() {
-    let entries = [(1_u64, int_to_leaf(1)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(3))];
-    let smt = SimpleSmt::with_leaves(64, entries);
-    assert!(smt.is_err());
+fn test_simplesmt_fail_on_duplicates() {
+    let values = [
+        // same key, same value
+        (int_to_leaf(1), int_to_leaf(1)),
+        // same key, different values
+        (int_to_leaf(1), int_to_leaf(2)),
+        // same key, set to zero
+        (EMPTY_WORD, int_to_leaf(1)),
+        // same key, re-set to zero
+        (int_to_leaf(1), EMPTY_WORD),
+        // same key, set to zero twice
+        (EMPTY_WORD, EMPTY_WORD),
+    ];

-    let entries = [(1_u64, int_to_leaf(0)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(0))];
-    let smt = SimpleSmt::with_leaves(64, entries);
-    assert!(smt.is_err());
-
-    let entries = [(1_u64, int_to_leaf(0)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(1))];
-    let smt = SimpleSmt::with_leaves(64, entries);
-    assert!(smt.is_err());
-
-    let entries = [(1_u64, int_to_leaf(1)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(0))];
-    let smt = SimpleSmt::with_leaves(64, entries);
-    assert!(smt.is_err());
+    for (first, second) in values.iter() {
+        // consecutive
+        let entries = [(1, *first), (1, *second)];
+        let smt = SimpleSmt::with_leaves(64, entries);
+        assert_eq!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));
+
+        // not consecutive
+        let entries = [(1, *first), (5, int_to_leaf(5)), (1, *second)];
+        let smt = SimpleSmt::with_leaves(64, entries);
+        assert_eq!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));
+    }
 }

 #[test]
@@ -239,6 +263,227 @@ fn with_no_duplicates_empty_node() {
     assert!(smt.is_ok());
 }

+#[test]
+fn test_simplesmt_update_nonexisting_leaf_with_zero() {
+    // TESTING WITH EMPTY WORD
+    // --------------------------------------------------------------------------------------------
+
+    // Depth 1 has 2 leaf. Position is 0-indexed, position 2 doesn't exist.
+    let mut smt = SimpleSmt::new(1).unwrap();
+    let result = smt.update_leaf(2, EMPTY_WORD);
+    assert!(!smt.leaves.contains_key(&2));
+    assert!(result.is_err());
+
+    // Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
+    let mut smt = SimpleSmt::new(2).unwrap();
+    let result = smt.update_leaf(4, EMPTY_WORD);
+    assert!(!smt.leaves.contains_key(&4));
+    assert!(result.is_err());
+
+    // Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
+    let mut smt = SimpleSmt::new(3).unwrap();
+    let result = smt.update_leaf(8, EMPTY_WORD);
+    assert!(!smt.leaves.contains_key(&8));
+    assert!(result.is_err());
+
+    // TESTING WITH A VALUE
+    // --------------------------------------------------------------------------------------------
+    let value = int_to_node(1);
+
+    // Depth 1 has 2 leaves. Position is 0-indexed, position 1 doesn't exist.
+    let mut smt = SimpleSmt::new(1).unwrap();
+    let result = smt.update_leaf(2, *value);
+    assert!(!smt.leaves.contains_key(&2));
+    assert!(result.is_err());
+
+    // Depth 2 has 4 leaves. Position is 0-indexed, position 2 doesn't exist.
+    let mut smt = SimpleSmt::new(2).unwrap();
+    let result = smt.update_leaf(4, *value);
+    assert!(!smt.leaves.contains_key(&4));
+    assert!(result.is_err());
+
+    // Depth 3 has 8 leaves. Position is 0-indexed, position 4 doesn't exist.
+    let mut smt = SimpleSmt::new(3).unwrap();
+    let result = smt.update_leaf(8, *value);
+    assert!(!smt.leaves.contains_key(&8));
+    assert!(result.is_err());
+}
+
+#[test]
+fn test_simplesmt_with_leaves_nonexisting_leaf() {
+    // TESTING WITH EMPTY WORD
+    // --------------------------------------------------------------------------------------------
+
+    // Depth 1 has 2 leaf. Position is 0-indexed, position 2 doesn't exist.
+    let leaves = [(2, EMPTY_WORD)];
+    let result = SimpleSmt::with_leaves(1, leaves);
+    assert!(result.is_err());
+
+    // Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
+    let leaves = [(4, EMPTY_WORD)];
+    let result = SimpleSmt::with_leaves(2, leaves);
+    assert!(result.is_err());
+
+    // Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
+    let leaves = [(8, EMPTY_WORD)];
+    let result = SimpleSmt::with_leaves(3, leaves);
+    assert!(result.is_err());
+
+    // TESTING WITH A VALUE
+    // --------------------------------------------------------------------------------------------
+    let value = int_to_node(1);
+
+    // Depth 1 has 2 leaves. Position is 0-indexed, position 2 doesn't exist.
+    let leaves = [(2, *value)];
+    let result = SimpleSmt::with_leaves(1, leaves);
+    assert!(result.is_err());
+
+    // Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
+    let leaves = [(4, *value)];
+    let result = SimpleSmt::with_leaves(2, leaves);
+    assert!(result.is_err());
+
+    // Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
+    let leaves = [(8, *value)];
+    let result = SimpleSmt::with_leaves(3, leaves);
+    assert!(result.is_err());
+}
+
+#[test]
+fn test_simplesmt_set_subtree() {
+    // Final Tree:
+    //
+    //        ____k____
+    //       /         \
+    //     _i_         _j_
+    //    /   \       /   \
+    //   e     f     g     h
+    //  / \   / \   / \   / \
+    // a   b 0   0 c   0 0   d
+
+    let z = EMPTY_WORD;
+
+    let a = Word::from(Rpo256::merge(&[z.into(); 2]));
+    let b = Word::from(Rpo256::merge(&[a.into(); 2]));
+    let c = Word::from(Rpo256::merge(&[b.into(); 2]));
+    let d = Word::from(Rpo256::merge(&[c.into(); 2]));
+
+    let e = Rpo256::merge(&[a.into(), b.into()]);
+    let f = Rpo256::merge(&[z.into(), z.into()]);
+    let g = Rpo256::merge(&[c.into(), z.into()]);
+    let h = Rpo256::merge(&[z.into(), d.into()]);
+
+    let i = Rpo256::merge(&[e, f]);
+    let j = Rpo256::merge(&[g, h]);
+
+    let k = Rpo256::merge(&[i, j]);
+
+    // subtree:
+    //   g
+    //  / \
+    // c   0
+    let subtree = {
+        let depth = 1;
+        let entries = vec![(0, c)];
+        SimpleSmt::with_leaves(depth, entries).unwrap()
+    };
+
+    // insert subtree
+    let tree = {
+        let depth = 3;
+        let entries = vec![(0, a), (1, b), (7, d)];
+        let mut tree = SimpleSmt::with_leaves(depth, entries).unwrap();
+
+        tree.set_subtree(2, subtree).unwrap();
+
+        tree
+    };
+
+    assert_eq!(tree.root(), k);
+    assert_eq!(tree.get_leaf(4).unwrap(), c);
+    assert_eq!(tree.get_branch_node(&NodeIndex::new_unchecked(2, 2)).parent(), g);
+}
+
+/// Ensures that an invalid input node index into `set_subtree()` incurs no mutation of the tree
+#[test]
+fn test_simplesmt_set_subtree_unchanged_for_wrong_index() {
+    // Final Tree:
+    //
+    //        ____k____
+    //       /         \
+    //     _i_         _j_
+    //    /   \       /   \
+    //   e     f     g     h
+    //  / \   / \   / \   / \
+    // a   b 0   0 c   0 0   d
+
+    let z = EMPTY_WORD;
+
+    let a = Word::from(Rpo256::merge(&[z.into(); 2]));
+    let b = Word::from(Rpo256::merge(&[a.into(); 2]));
+    let c = Word::from(Rpo256::merge(&[b.into(); 2]));
+    let d = Word::from(Rpo256::merge(&[c.into(); 2]));
+
+    // subtree:
+    //   g
+    //  / \
+    // c   0
+    let subtree = {
+        let depth = 1;
+        let entries = vec![(0, c)];
+        SimpleSmt::with_leaves(depth, entries).unwrap()
+    };
+
+    let mut tree = {
+        let depth = 3;
+        let entries = vec![(0, a), (1, b), (7, d)];
+        SimpleSmt::with_leaves(depth, entries).unwrap()
+    };
+    let tree_root_before_insertion = tree.root();
+
+    // insert subtree
+    assert!(tree.set_subtree(500, subtree).is_err());
+
+    assert_eq!(tree.root(), tree_root_before_insertion);
+}
+
+/// We insert an empty subtree that has the same depth as the original tree
+#[test]
+fn test_simplesmt_set_subtree_entire_tree() {
+    // Initial Tree:
+    //
+    //        ____k____
+    //       /         \
+    //     _i_         _j_
+    //    /   \       /   \
+    //   e     f     g     h
+    //  / \   / \   / \   / \
+    // a   b 0   0 c   0 0   d
+
+    let z = EMPTY_WORD;
+
+    let a = Word::from(Rpo256::merge(&[z.into(); 2]));
+    let b = Word::from(Rpo256::merge(&[a.into(); 2]));
+    let c = Word::from(Rpo256::merge(&[b.into(); 2]));
+    let d = Word::from(Rpo256::merge(&[c.into(); 2]));
+
+    let depth = 3;
+
+    // subtree: E3
+    let subtree = { SimpleSmt::with_leaves(depth, Vec::new()).unwrap() };
+    assert_eq!(subtree.root(), *EmptySubtreeRoots::entry(depth, 0));
+
+    // insert subtree
+    let mut tree = {
+        let entries = vec![(0, a), (1, b), (4, c), (7, d)];
+        SimpleSmt::with_leaves(depth, entries).unwrap()
+    };
+
+    tree.set_subtree(0, subtree).unwrap();
+
+    assert_eq!(tree.root(), *EmptySubtreeRoots::entry(depth, 0));
+}
+
 // HELPER FUNCTIONS
 // --------------------------------------------------------------------------------------------

@@ -1,9 +1,8 @@
 use super::{
     DefaultMerkleStore as MerkleStore, EmptySubtreeRoots, MerkleError, MerklePath, NodeIndex,
-    PartialMerkleTree, RecordingMerkleStore, RpoDigest,
+    PartialMerkleTree, RecordingMerkleStore, Rpo256, RpoDigest,
 };
 use crate::{
-    hash::rpo::Rpo256,
     merkle::{digests_to_words, int_to_leaf, int_to_node, MerkleTree, SimpleSmt},
     Felt, Word, ONE, WORD_SIZE, ZERO,
 };
@@ -85,18 +85,26 @@ impl TieredSmtProof {
     /// Note: this method cannot be used to assert non-membership. That is, if false is returned,
     /// it does not mean that the provided key-value pair is not in the tree.
     pub fn verify_membership(&self, key: &RpoDigest, value: &Word, root: &RpoDigest) -> bool {
-        if self.is_value_empty() {
-            if value != &EMPTY_VALUE {
-                return false;
-            }
-            // if the proof is for an empty value, we can verify it against any key which has a
-            // common prefix with the key storied in entries, but the prefix must be greater than
-            // the path length
-            let common_prefix_tier = get_common_prefix_tier_depth(key, &self.entries[0].0);
-            if common_prefix_tier < self.path.depth() {
-                return false;
-            }
-        } else if !self.entries.contains(&(*key, *value)) {
+        // Handles the following scenarios:
+        // - the value is set
+        // - empty leaf, there is an explicit entry for the key with the empty value
+        // - shared 64-bit prefix, the target key is not included in the entries list, the value is implicitly the empty word
+        let v = match self.entries.iter().find(|(k, _)| k == key) {
+            Some((_, v)) => v,
+            None => &EMPTY_VALUE,
+        };
+
+        // The value must match for the proof to be valid
+        if v != value {
+            return false;
+        }
+
+        // If the proof is for an empty value, we can verify it against any key which has a common
+        // prefix with the key storied in entries, but the prefix must be greater than the path
+        // length
+        if self.is_value_empty()
+            && get_common_prefix_tier_depth(key, &self.entries[0].0) < self.path.depth()
+        {
             return false;
         }

@@ -715,6 +715,38 @@ fn tsmt_bottom_tier_two() {
 // GET PROOF TESTS
 // ================================================================================================

+/// Tests the membership and non-membership proof for a single at depth 64
+#[test]
+fn tsmt_get_proof_single_element_64() {
+    let mut smt = TieredSmt::default();
+
+    let raw_a = 0b_00000000_00000001_00000000_00000001_00000000_00000001_00000000_00000001_u64;
+    let key_a = [ONE, ONE, ONE, raw_a.into()].into();
+    let value_a = [ONE, ONE, ONE, ONE];
+    smt.insert(key_a, value_a);
+
+    // push element `a` to depth 64, by inserting another value that shares the 48-bit prefix
+    let raw_b = 0b_00000000_00000001_00000000_00000001_00000000_00000001_00000000_00000000_u64;
+    let key_b = [ONE, ONE, ONE, raw_b.into()].into();
+    smt.insert(key_b, [ONE, ONE, ONE, ONE]);
+
+    // verify the proof for element `a`
+    let proof = smt.prove(key_a);
+    assert!(proof.verify_membership(&key_a, &value_a, &smt.root()));
+
+    // check that a value that is not inserted in the tree produces a valid membership proof for the
+    // empty word
+    let key = [ZERO, ZERO, ZERO, ZERO].into();
+    let proof = smt.prove(key);
+    assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root()));
+
+    // check that a key that shared the 64-bit prefix with `a`, but is not inserted, also has a
+    // valid membership proof for the empty word
+    let key = [ONE, ONE, ZERO, raw_a.into()].into();
+    let proof = smt.prove(key);
+    assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root()));
+}
+
 #[test]
 fn tsmt_get_proof() {
     let mut smt = TieredSmt::default();
@@ -1,3 +1,19 @@
 //! Pseudo-random element generation.

-pub use winter_crypto::{RandomCoin, RandomCoinError};
+pub use winter_crypto::{DefaultRandomCoin as WinterRandomCoin, RandomCoin, RandomCoinError};
+
+use crate::{Felt, FieldElement, StarkField, Word, ZERO};
+
+mod rpo;
+pub use rpo::RpoRandomCoin;
+
+/// Pseudo-random element generator.
+///
+/// An instance can be used to draw, uniformly at random, base field elements as well as [Word]s.
+pub trait FeltRng {
+    /// Draw, uniformly at random, a base field element.
+    fn draw_element(&mut self) -> Felt;
+
+    /// Draw, uniformly at random, a [Word].
+    fn draw_word(&mut self) -> Word;
+}
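
A minimal sketch of drawing randomness through the new trait (not part of the diff; seed chosen arbitrarily, RpoRandomCoin is defined in the new file below):

    use miden_crypto::rand::{FeltRng, RpoRandomCoin};
    use miden_crypto::{Word, ZERO};

    let mut rng = RpoRandomCoin::new([ZERO; 4]);
    let element = rng.draw_element(); // one base field element
    let word: Word = rng.draw_word(); // four base field elements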
src/rand/rpo.rs (new file, 267 lines)
@@ -0,0 +1,267 @@
+use super::{Felt, FeltRng, FieldElement, StarkField, Word, ZERO};
+use crate::{
+    hash::rpo::{Rpo256, RpoDigest},
+    utils::{
+        collections::Vec, string::ToString, vec, ByteReader, ByteWriter, Deserializable,
+        DeserializationError, Serializable,
+    },
+};
+pub use winter_crypto::{RandomCoin, RandomCoinError};
+
+// CONSTANTS
+// ================================================================================================
+
+const STATE_WIDTH: usize = Rpo256::STATE_WIDTH;
+const RATE_START: usize = Rpo256::RATE_RANGE.start;
+const RATE_END: usize = Rpo256::RATE_RANGE.end;
+const HALF_RATE_WIDTH: usize = (Rpo256::RATE_RANGE.end - Rpo256::RATE_RANGE.start) / 2;
+
+// RPO RANDOM COIN
+// ================================================================================================
+/// A simplified version of the `SPONGE_PRG` reseedable pseudo-random number generator algorithm
+/// described in https://eprint.iacr.org/2011/499.pdf.
+///
+/// The simplification is related to the following facts:
+/// 1. A call to the reseed method implies one and only one call to the permutation function.
+/// This is possible because in our case we never reseed with more than 4 field elements.
+/// 2. As a result of the previous point, we don't make use of an input buffer to accumulate seed
+/// material.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct RpoRandomCoin {
+    state: [Felt; STATE_WIDTH],
+    current: usize,
+}
+
+impl RpoRandomCoin {
+    /// Returns a new [RpoRandomCoin] initialize with the specified seed.
+    pub fn new(seed: Word) -> Self {
+        let mut state = [ZERO; STATE_WIDTH];
+
+        for i in 0..HALF_RATE_WIDTH {
+            state[RATE_START + i] += seed[i];
+        }
+
+        // Absorb
+        Rpo256::apply_permutation(&mut state);
+
+        RpoRandomCoin { state, current: RATE_START }
+    }
+
+    /// Returns an [RpoRandomCoin] instantiated from the provided components.
+    ///
+    /// # Panics
+    /// Panics if `current` is smaller than 4 or greater than or equal to 12.
+    pub fn from_parts(state: [Felt; STATE_WIDTH], current: usize) -> Self {
+        assert!(
+            (RATE_START..RATE_END).contains(&current),
+            "current value outside of valid range"
+        );
+        Self { state, current }
+    }
+
+    /// Returns components of this random coin.
+    pub fn into_parts(self) -> ([Felt; STATE_WIDTH], usize) {
+        (self.state, self.current)
+    }
+
+    fn draw_basefield(&mut self) -> Felt {
+        if self.current == RATE_END {
+            Rpo256::apply_permutation(&mut self.state);
+            self.current = RATE_START;
+        }
+
+        self.current += 1;
+        self.state[self.current - 1]
+    }
+}
+
+// RANDOM COIN IMPLEMENTATION
+// ------------------------------------------------------------------------------------------------
+
+impl RandomCoin for RpoRandomCoin {
+    type BaseField = Felt;
+    type Hasher = Rpo256;
+
+    fn new(seed: &[Self::BaseField]) -> Self {
+        let digest: Word = Rpo256::hash_elements(seed).into();
+        Self::new(digest)
+    }
+
+    fn reseed(&mut self, data: RpoDigest) {
+        // Reset buffer
+        self.current = RATE_START;
+
+        // Add the new seed material to the first half of the rate portion of the RPO state
+        let data: Word = data.into();
+
+        self.state[RATE_START] += data[0];
+        self.state[RATE_START + 1] += data[1];
+        self.state[RATE_START + 2] += data[2];
+        self.state[RATE_START + 3] += data[3];
+
+        // Absorb
+        Rpo256::apply_permutation(&mut self.state);
+    }
+
+    fn check_leading_zeros(&self, value: u64) -> u32 {
+        let value = Felt::new(value);
+        let mut state_tmp = self.state;
+
+        state_tmp[RATE_START] += value;
+
+        Rpo256::apply_permutation(&mut state_tmp);
+
+        let first_rate_element = state_tmp[RATE_START].as_int();
+        first_rate_element.trailing_zeros()
+    }
+
+    fn draw<E: FieldElement<BaseField = Felt>>(&mut self) -> Result<E, RandomCoinError> {
+        let ext_degree = E::EXTENSION_DEGREE;
+        let mut result = vec![ZERO; ext_degree];
+        for r in result.iter_mut().take(ext_degree) {
+            *r = self.draw_basefield();
+        }
+
+        let result = E::slice_from_base_elements(&result);
+        Ok(result[0])
+    }
+
+    fn draw_integers(
+        &mut self,
+        num_values: usize,
+        domain_size: usize,
+        nonce: u64,
+    ) -> Result<Vec<usize>, RandomCoinError> {
+        assert!(domain_size.is_power_of_two(), "domain size must be a power of two");
+        assert!(num_values < domain_size, "number of values must be smaller than domain size");
+
+        // absorb the nonce
+        let nonce = Felt::new(nonce);
+        self.state[RATE_START] += nonce;
+        Rpo256::apply_permutation(&mut self.state);
+
+        // reset the buffer
+        self.current = RATE_START;
+
+        // determine how many bits are needed to represent valid values in the domain
+        let v_mask = (domain_size - 1) as u64;
+
+        // draw values from PRNG until we get as many unique values as specified by num_queries
+        let mut values = Vec::new();
+        for _ in 0..1000 {
+            // get the next pseudo-random field element
+            let value = self.draw_basefield().as_int();
+
+            // use the mask to get a value within the range
+            let value = (value & v_mask) as usize;
+
+            values.push(value);
+            if values.len() == num_values {
+                break;
+            }
+        }
+
+        if values.len() < num_values {
+            return Err(RandomCoinError::FailedToDrawIntegers(num_values, values.len(), 1000));
+        }
+
+        Ok(values)
+    }
+}
+
+// FELT RNG IMPLEMENTATION
+// ------------------------------------------------------------------------------------------------
+
+impl FeltRng for RpoRandomCoin {
+    fn draw_element(&mut self) -> Felt {
+        self.draw_basefield()
+    }
+
+    fn draw_word(&mut self) -> Word {
+        let mut output = [ZERO; 4];
+        for o in output.iter_mut() {
+            *o = self.draw_basefield();
+        }
+        output
+    }
+}
+
+// SERIALIZATION
+// ------------------------------------------------------------------------------------------------
+
+impl Serializable for RpoRandomCoin {
+    fn write_into<W: ByteWriter>(&self, target: &mut W) {
+        self.state.iter().for_each(|v| v.write_into(target));
+        // casting to u8 is OK because `current` is always between 4 and 12.
+        target.write_u8(self.current as u8);
+    }
+}
+
+impl Deserializable for RpoRandomCoin {
+    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
+        let state = [
+            Felt::read_from(source)?,
+            Felt::read_from(source)?,
+            Felt::read_from(source)?,
+            Felt::read_from(source)?,
+            Felt::read_from(source)?,
+            Felt::read_from(source)?,
+            Felt::read_from(source)?,
+            Felt::read_from(source)?,
+            Felt::read_from(source)?,
+            Felt::read_from(source)?,
+            Felt::read_from(source)?,
+            Felt::read_from(source)?,
+        ];
+        let current = source.read_u8()? as usize;
+        if !(RATE_START..RATE_END).contains(&current) {
+            return Err(DeserializationError::InvalidValue(
+                "current value outside of valid range".to_string(),
+            ));
+        }
+        Ok(Self { state, current })
+    }
+}
+
+// TESTS
+// ================================================================================================
+
+#[cfg(test)]
+mod tests {
+    use super::{Deserializable, FeltRng, RpoRandomCoin, Serializable, ZERO};
+    use crate::ONE;
+
+    #[test]
+    fn test_feltrng_felt() {
+        let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
+        let output = rpocoin.draw_element();
+
+        let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
+        let expected = rpocoin.draw_basefield();
+
+        assert_eq!(output, expected);
+    }
+
+    #[test]
+    fn test_feltrng_word() {
+        let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
+        let output = rpocoin.draw_word();
+
+        let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
+        let mut expected = [ZERO; 4];
+        for o in expected.iter_mut() {
+            *o = rpocoin.draw_basefield();
+        }
+
+        assert_eq!(output, expected);
+    }
+
+    #[test]
+    fn test_feltrng_serialization() {
+        let coin1 = RpoRandomCoin::from_parts([ONE; 12], 5);
+
+        let bytes = coin1.to_bytes();
+        let coin2 = RpoRandomCoin::read_from_bytes(&bytes).unwrap();
+        assert_eq!(coin1, coin2);
+    }
+}
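
Beyond the FeltRng methods, the coin also serves as a STARK-style random coin. A sketch of drawing query positions (not part of the diff; assumes the RandomCoin trait is in scope via the re-export in src/rand/mod.rs). Note that, as implemented above, the drawn positions are masked into range but not deduplicated:

    use miden_crypto::rand::{RandomCoin, RpoRandomCoin};
    use miden_crypto::ZERO;

    let mut coin = RpoRandomCoin::new([ZERO; 4]);
    // 8 positions in a domain of size 1024 (the domain size must be a power of two)
    let positions = coin.draw_integers(8, 1024, 0).unwrap();
    assert_eq!(positions.len(), 8);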