Mirror of https://github.com/arnaucube/miden-crypto.git (synced 2026-01-11 16:41:29 +01:00)

Compare commits: v0.4.0 ... al-gkr-bas (142 commits)
The comparison spans 142 commits. Only the abbreviated SHA1 of each commit was captured by the mirror (the author, date, and message columns were empty):

0f06fa30a9, 0b074a795d, 862ccf54dd, 88bcdfd576, 290894f497, 4aac00884c, 2ef6f79656, 5142e2fd31, 9fb41337ec, 0296e05ccd, 499f97046d, 600feafe53, 9d854f1fcb, af76cb10d0, 4758e0672f, 8bb080a91d, e5f3b28645, 29e0d07129, 81a94ecbe7, 223fbf887d, 9e77a7c9b7, 894e20fe0c, 7ec7b06574, 2499a8a2dd, 800994c69b, 26560605bf, 672340d0c2, 8083b02aef, ecb8719d45, 4144f98560, c726050957, 9239340888, 97ee9298a4, bfae06e128, b4e2d63c10, 9679329746, 2bbea37dbe, 83000940da, f44175e7a9, 4cf8eebff5, c86bdc6d51, 650508cbc9, 012ad5ae93, bde20f9752, 7f3d4b8966, 1a00c7035f, 7ddcdc5e39, bfd05e3d38, 9235a78afd, 78aa714b89, aeadc96b05, 0fb1ef837d, cf91c89845, 025c25fdd9, 8078021aff, b1dbcee21d, 396418659d, 01be4d6b9d, 701a187e7f, 1fa2895724, 90dd3acb12, 2f09410e87, 51d527b568, 9f54c82d62, c7f1535974, c1d0612115, 2214ff2425, 85034af1df, f7e6922bff, 7780a50dad, 6d0c7567f0, 854ade1bfc, fb649df1e7, d9e85230a6, 8cf5e9fd2c, 03f89f0aff, 2fa1b9768a, f71d98970b, b3e7578ab2, 5c6a20cb60, bc364b72c0, 83b6946432, 3dfcc0810f, 33ef78f8f5, b6eb1f9134, 92bb3ac462, 1ac30f8989, 6810b5e3ab, a03f2b5d5e, 1bb75e85dd, 1578a9ee1f, e49bccd7b7, 71b04d0734, 8c749e473a, 809b572a40, da2d08714d, aaf1788228, 44e60e7228, 08aec4443c, 813fe24b88, 18302d68e0, 858f95d4a1, b2d6866d41, f52ac29a02, f08644e4df, 679a30e02e, cede2e57da, 4215e83ae5, fe5cac9edc, 53d52b8adc, 1be64fc43d, 049ae32cbf, b9def61e28, 0e0a3fda4f, fe9aa8c28c, 766702e37a, 218a64b5c7, 2708a23649, 43f1a4cb64, 55cc71dadf, ebf71c2dc7, b4324475b6, 23f448fb33, 59f7723221, 2ed880d976, daa27f49f2, dcda57f71a, d9e3211418, 21e7a5c07d, 02673ff87e, b768eade4d, 51ce07cc34, 550738bd94, 629494b601, 13aeda5a27, e5aba870a2, fcf03478ba, 0ddd0db89b, 2100d6c861, 52409ac039, 4555fc918f, 52db23cd42
.editorconfig (new file, 20 lines)

@@ -0,0 +1,20 @@
# Documentation available at editorconfig.org

root=true

[*]
indent_style = space
indent_size = 4
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true

[*.rs]
max_line_length = 100

[*.md]
trim_trailing_whitespace = false

[*.yml]
indent_size = 2
@@ -1,2 +0,0 @@
-# initial run of pre-commit
-956e4c6fad779ef15eaa27702b26f05f65d31494
.github/workflows/ci.yml (vendored, 10 changed lines)

@@ -19,6 +19,8 @@ jobs:
          args: [--no-default-features --target wasm32-unknown-unknown]
    steps:
      - uses: actions/checkout@main
+        with:
+          submodules: recursive
      - name: Install rust
        uses: actions-rs/toolchain@v1
        with:
@@ -39,9 +41,11 @@ jobs:
      matrix:
        toolchain: [stable, nightly]
        os: [ubuntu]
-        features: [--all-features, --no-default-features]
+        features: ["--features default,std,serde", --no-default-features]
    steps:
      - uses: actions/checkout@main
+        with:
+          submodules: recursive
      - name: Install rust
        uses: actions-rs/toolchain@v1
        with:
@@ -59,9 +63,11 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
-        features: [--all-features, --no-default-features]
+        features: ["--features default,std,serde", --no-default-features]
    steps:
      - uses: actions/checkout@main
+        with:
+          submodules: recursive
      - name: Install minimal nightly with clippy
        uses: actions-rs/toolchain@v1
        with:
.gitignore (vendored, 3 changed lines)

@@ -8,3 +8,6 @@ Cargo.lock

 # These are backup files generated by rustfmt
 **/*.rs.bk
+
+# Generated by cmake
+cmake-build-*
.gitmodules (vendored, new file, 3 lines)

@@ -0,0 +1,3 @@
[submodule "PQClean"]
	path = PQClean
	url = https://github.com/PQClean/PQClean.git
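Because the Falcon C sources live in the PQClean submodule, a fresh checkout has to initialize it before the build script (shown later in this diff) can compile them; the `submodules: recursive` checkout option added to CI above serves the same purpose. The local equivalent (standard git usage, not part of this diff) is:

```shell
git clone https://github.com/arnaucube/miden-crypto.git
cd miden-crypto
git submodule update --init --recursive
cargo build
```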
@@ -35,8 +35,8 @@ repos:
        name: Cargo check --all-targets --no-default-features
        args: ["+stable", "check", "--all-targets", "--no-default-features"]
      - id: cargo
-        name: Cargo check --all-targets --all-features
-        args: ["+stable", "check", "--all-targets", "--all-features"]
+        name: Cargo check --all-targets --features default,std,serde
+        args: ["+stable", "check", "--all-targets", "--features", "default,std,serde"]
      # Unlike fmt, clippy will not be automatically applied
      - id: cargo
        name: Cargo clippy
CHANGELOG.md (37 changed lines)

@@ -1,3 +1,40 @@
+## 0.8.0 (TBD)
+
+* Implemented the `PartialMmr` data structure (#195).
+* Updated Winterfell dependency to v0.7 (#200).
+* Implemented RPX hash function (#201).
+* Added `FeltRng` and `RpoRandomCoin` (#237).
+* Added `inner_nodes()` method to `PartialMmr` (#238).
+
+## 0.7.1 (2023-10-10)
+
+* Fixed RPO Falcon signature build on Windows.
+
+## 0.7.0 (2023-10-05)
+
+* Replaced `MerklePathSet` with `PartialMerkleTree` (#165).
+* Implemented clearing of nodes in `TieredSmt` (#173).
+* Added ability to generate inclusion proofs for `TieredSmt` (#174).
+* Implemented Falcon DSA (#179).
+* Added conditional `serde` support for various structs (#180).
+* Implemented benchmarking for `TieredSmt` (#182).
+* Added more leaf traversal methods for `MerkleStore` (#185).
+* Added SVE acceleration for RPO hash function (#189).
+
+## 0.6.0 (2023-06-25)
+
+* [BREAKING] Added support for recording capabilities for `MerkleStore` (#162).
+* [BREAKING] Refactored Merkle struct APIs to use `RpoDigest` instead of `Word` (#157).
+* Added initial implementation of `PartialMerkleTree` (#156).
+
+## 0.5.0 (2023-05-26)
+
+* Implemented `TieredSmt` (#152, #153).
+* Implemented ability to extract a subset of a `MerkleStore` (#151).
+* Cleaned up `SimpleSmt` interface (#149).
+* Decoupled hashing and padding of peaks in `Mmr` (#148).
+* Added `inner_nodes()` to `MerkleStore` (#146).
+
 ## 0.4.0 (2023-04-21)

 - Exported `MmrProof` from the crate (#137).
Cargo.toml (45 changed lines)

@@ -1,16 +1,23 @@
 [package]
 name = "miden-crypto"
-version = "0.4.0"
+version = "0.8.0"
 description = "Miden Cryptographic primitives"
 authors = ["miden contributors"]
 readme = "README.md"
 license = "MIT"
 repository = "https://github.com/0xPolygonMiden/crypto"
-documentation = "https://docs.rs/miden-crypto/0.4.0"
+documentation = "https://docs.rs/miden-crypto/0.8.0"
 categories = ["cryptography", "no-std"]
 keywords = ["miden", "crypto", "hash", "merkle"]
 edition = "2021"
-rust-version = "1.67"
+rust-version = "1.73"
+
+[[bin]]
+name = "miden-crypto"
+path = "src/main.rs"
+bench = false
+doctest = false
+required-features = ["executable"]

 [[bench]]
 name = "hash"
@@ -25,16 +32,30 @@ name = "store"
 harness = false

 [features]
-default = ["blake3/default", "std", "winter_crypto/default", "winter_math/default", "winter_utils/default"]
-std = ["blake3/std", "winter_crypto/std", "winter_math/std", "winter_utils/std"]
+default = ["std"]
+executable = ["dep:clap", "dep:rand_utils", "std"]
+serde = ["dep:serde", "serde?/alloc", "winter_math/serde"]
+std = ["blake3/std", "dep:cc", "dep:libc", "winter_crypto/std", "winter_math/std", "winter_utils/std"]
+sve = ["std"]

 [dependencies]
-blake3 = { version = "1.3", default-features = false }
-winter_crypto = { version = "0.6", package = "winter-crypto", default-features = false }
-winter_math = { version = "0.6", package = "winter-math", default-features = false }
-winter_utils = { version = "0.6", package = "winter-utils", default-features = false }
+blake3 = { version = "1.5", default-features = false }
+clap = { version = "4.4", features = ["derive"], optional = true }
+libc = { version = "0.2", default-features = false, optional = true }
+rand_utils = { version = "0.7", package = "winter-rand-utils", optional = true }
+serde = { version = "1.0", features = [ "derive" ], default-features = false, optional = true }
+winter_crypto = { version = "0.7", package = "winter-crypto", default-features = false }
+winter_math = { version = "0.7", package = "winter-math", default-features = false }
+winter_utils = { version = "0.7", package = "winter-utils", default-features = false }
+rayon = "1.8.0"
+rand = "0.8.4"
+rand_core = { version = "0.5", default-features = false }

 [dev-dependencies]
-criterion = { version = "0.4", features = ["html_reports"] }
-proptest = "1.1.0"
-rand_utils = { version = "0.6", package = "winter-rand-utils" }
+criterion = { version = "0.5", features = ["html_reports"] }
+proptest = "1.3"
+rand_utils = { version = "0.7", package = "winter-rand-utils" }
+
+[build-dependencies]
+cc = { version = "1.0", features = ["parallel"], optional = true }
+glob = "0.3"
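The feature rework above replaces the long `default`/`std` feature lists with a small lattice (`std`, `serde`, `sve`, `executable`), so downstream crates choose their footprint explicitly. For illustration only, a hypothetical consumer manifest selecting a `no_std` build with serde support could look like this:

```toml
[dependencies]
miden-crypto = { version = "0.8", default-features = false, features = ["serde"] }
```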
PQClean (submodule)

Submodule PQClean added at c3abebf4ab
README.md (29 changed lines)

@@ -6,23 +6,36 @@ This crate contains cryptographic primitives used in Polygon Miden.

 * [BLAKE3](https://github.com/BLAKE3-team/BLAKE3) hash function with 256-bit, 192-bit, or 160-bit output. The 192-bit and 160-bit outputs are obtained by truncating the 256-bit output of the standard BLAKE3.
 * [RPO](https://eprint.iacr.org/2022/1577) hash function with 256-bit output. This hash function is an algebraic hash function suitable for recursive STARKs.
+* [RPX](https://eprint.iacr.org/2023/1045) hash function with 256-bit output. Similar to RPO, this hash function is suitable for recursive STARKs, but it is about 2x faster than RPO.

 For performance benchmarks of these hash functions and their comparison to other popular hash functions, please see [here](./benches/).

 ## Merkle
 [Merkle module](./src/merkle/) provides a set of data structures related to Merkle trees. All these data structures are implemented using the RPO hash function described above. The data structures are:

+* `MerkleStore`: a collection of Merkle trees of different heights designed to efficiently store trees with common subtrees. When instantiated with `RecordingMap`, a Merkle store records all accesses to the original data.
 * `MerkleTree`: a regular fully-balanced binary Merkle tree. The depth of this tree can be at most 64.
-* `SimpleSmt`: a Sparse Merkle Tree, mapping 64-bit keys to 4-element leaf values.
-* `MerklePathSet`: a collection of Merkle authentication paths all resolving to the same root. The length of the paths can be at most 64.
-* `MerkleStore`: a collection of Merkle trees of different heights designed to efficiently store trees with common subtrees.
 * `Mmr`: a Merkle mountain range structure designed to function as an append-only log.
+* `PartialMerkleTree`: a partial view of a Merkle tree where some sub-trees may not be known. This is similar to a collection of Merkle paths all resolving to the same root. The length of the paths can be at most 64.
+* `PartialMmr`: a partial view of a Merkle mountain range structure.
+* `SimpleSmt`: a Sparse Merkle Tree (with no compaction), mapping 64-bit keys to 4-element values.
+* `TieredSmt`: a Sparse Merkle tree (with compaction), mapping 4-element keys to 4-element values.

 The module also contains additional supporting components such as `NodeIndex`, `MerklePath`, and `MerkleError` to assist with tree indexation, opening proofs, and reporting inconsistent arguments/state.

-## Extra
-[Root module](./src/lib.rs) provides a set of constants, types, aliases, and utils required to use the primitives of this library.
+## Signatures
+[DSA module](./src/dsa) provides a set of digital signature schemes supported by default in the Miden VM. Currently, these schemes are:
+
+* `RPO Falcon512`: a variant of the [Falcon](https://falcon-sign.info/) signature scheme. This variant differs from the standard in that, instead of using the SHAKE256 hash function in the *hash-to-point* algorithm, we use RPO256. This makes the signature more efficient to verify in the Miden VM.
+
+For the above signatures, key generation and signing are available only in the `std` context (see [crate features](#crate-features) below), while signature verification is available in the `no_std` context as well.
+
+## Pseudo-Random Element Generator
+[Pseudo-random element generator module](./src/rand/) provides a set of traits and data structures that facilitate generating pseudo-random elements in the context of the Miden VM and the Miden rollup. The module currently includes:
+
+* `FeltRng`: a trait for generating random field elements and random 4-element words of field elements.
+* `RpoRandomCoin`: a struct implementing `FeltRng` as well as the [`RandomCoin`](https://github.com/facebook/winterfell/blob/main/crypto/src/random/mod.rs) trait.
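A minimal usage sketch of the generator described above. The method names `draw_element`/`draw_word` and the `new(seed)` constructor are assumptions inferred from the trait description; they are not shown in this diff:

```rust
use miden_crypto::{
    rand::{FeltRng, RpoRandomCoin},
    Felt, Word, ZERO,
};

fn sample_randomness() {
    // Seed the coin with a 4-element word; a real caller would derive the
    // seed from entropy or a protocol transcript. (Hypothetical API names.)
    let seed: Word = [Felt::new(1), Felt::new(2), Felt::new(3), ZERO];
    let mut coin = RpoRandomCoin::new(seed);

    let _element: Felt = coin.draw_element(); // one random field element
    let _word: Word = coin.draw_word();       // four random field elements
}
```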
 ## Crate features
 This crate can be compiled with the following features:

@@ -33,6 +46,12 @@ Both of these features imply the use of [alloc](https://doc.rust-lang.org/alloc/

 To compile with `no_std`, disable default features via the `--no-default-features` flag.

+### SVE acceleration
+On platforms with [SVE](https://en.wikipedia.org/wiki/AArch64#Scalable_Vector_Extension_(SVE)) support, the RPO hash function can be accelerated by using the vector processing unit. To enable SVE acceleration, the code needs to be compiled with the `sve` feature enabled. This feature has an effect only if the platform exposes the `target-feature=sve` flag. On some platforms (e.g., Graviton 3), for this flag to be set, the compilation must be done in "native" mode. For example, to enable SVE acceleration on Graviton 3, we can execute the following:
+```shell
+RUSTFLAGS="-C target-cpu=native" cargo build --release --features sve
+```

 ## Testing

 You can use cargo defaults to test the library:
arch/arm64-sve/rpo/library.c (new file, 78 lines)

@@ -0,0 +1,78 @@
#include <stddef.h>
#include <arm_sve.h>
#include "library.h"
#include "rpo_hash.h"

// The STATE_WIDTH of the RPO hash is 12x u64 elements.
// The current generation of SVE-enabled processors (Neoverse V1,
// e.g. AWS Graviton3) has 256-bit vector registers (4x u64).
// This allows us to split the state into 3 vectors of 4 elements
// and process all 3 independently of each other.
//
// We see the biggest performance gains by leveraging both
// vector and scalar operations on parts of the state array.
// Due to the high latency of vector operations, the processor is able
// to reorder and pipeline scalar instructions while we wait for
// vector results. This effectively gives us some 'free' scalar
// operations and masks vector latency.
//
// This also means that we can fully saturate all four arithmetic
// units of the processor (2x scalar, 2x SIMD).
//
// THIS ANALYSIS NEEDS TO BE PERFORMED AGAIN ONCE PROCESSORS
// GAIN WIDER REGISTERS. It's quite possible that with 8x u64
// vectors, processing 2 partially filled vectors might
// be easier and faster than dealing with scalar operations
// on the remainder of the array.
//
// FOR NOW THIS IS ONLY ENABLED ON 4x u64 VECTORS! It falls back
// to the regular, already highly-optimized scalar version
// if the conditions are not met.

bool add_constants_and_apply_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
    const uint64_t vl = svcntd();   // number of u64 numbers in one SVE vector

    if (vl != 4) {
        return false;
    }

    svbool_t ptrue = svptrue_b64();

    // The first 8 state elements go into two SVE vectors; the last 4
    // elements (state + 8) are handled by the scalar code paths.
    svuint64_t state1 = svld1(ptrue, state + 0*vl);
    svuint64_t state2 = svld1(ptrue, state + 1*vl);

    svuint64_t const1 = svld1(ptrue, constants + 0*vl);
    svuint64_t const2 = svld1(ptrue, constants + 1*vl);

    add_constants(ptrue, &state1, &const1, &state2, &const2, state + 8, constants + 8);
    apply_sbox(ptrue, &state1, &state2, state + 8);

    svst1(ptrue, state + 0*vl, state1);
    svst1(ptrue, state + 1*vl, state2);

    return true;
}

bool add_constants_and_apply_inv_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
    const uint64_t vl = svcntd();   // number of u64 numbers in one SVE vector

    if (vl != 4) {
        return false;
    }

    svbool_t ptrue = svptrue_b64();

    svuint64_t state1 = svld1(ptrue, state + 0 * vl);
    svuint64_t state2 = svld1(ptrue, state + 1 * vl);

    svuint64_t const1 = svld1(ptrue, constants + 0 * vl);
    svuint64_t const2 = svld1(ptrue, constants + 1 * vl);

    add_constants(ptrue, &state1, &const1, &state2, &const2, state + 8, constants + 8);
    apply_inv_sbox(ptrue, &state1, &state2, state + 8);

    svst1(ptrue, state + 0 * vl, state1);
    svst1(ptrue, state + 1 * vl, state2);

    return true;
}
arch/arm64-sve/rpo/library.h (new file, 12 lines)

@@ -0,0 +1,12 @@
#ifndef CRYPTO_LIBRARY_H
#define CRYPTO_LIBRARY_H

#include <stdint.h>
#include <stdbool.h>

#define STATE_WIDTH 12

bool add_constants_and_apply_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]);
bool add_constants_and_apply_inv_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]);

#endif //CRYPTO_LIBRARY_H
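On the Rust side these two entry points would be bound over FFI; each returns `false` when the runtime vector length is not 4x u64, so the caller needs a scalar fallback. A minimal sketch, assuming the crate links the `rpo_sve` library produced by the build script shown later in this diff (the wrapper shape is hypothetical):

```rust
// Bindings for the C entry points declared in arch/arm64-sve/rpo/library.h.
#[cfg(all(target_feature = "sve", feature = "sve"))]
extern "C" {
    fn add_constants_and_apply_sbox(state: *mut u64, constants: *const u64) -> bool;
    fn add_constants_and_apply_inv_sbox(state: *mut u64, constants: *const u64) -> bool;
}

// Hypothetical wrapper: returns true if the SVE path ran; the caller falls
// back to the portable scalar implementation when it returns false.
#[cfg(all(target_feature = "sve", feature = "sve"))]
fn try_sve_sbox(state: &mut [u64; 12], constants: &[u64; 12]) -> bool {
    // Safety: both arrays are exactly STATE_WIDTH = 12 elements, matching
    // the C signatures above.
    unsafe { add_constants_and_apply_sbox(state.as_mut_ptr(), constants.as_ptr()) }
}
```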
arch/arm64-sve/rpo/rpo_hash.h (new file, 221 lines)

@@ -0,0 +1,221 @@
#ifndef RPO_SVE_RPO_HASH_H
#define RPO_SVE_RPO_HASH_H

#include <arm_sve.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

// Copy a state triple (two SVE vectors + 4 scalar limbs) into fresh locals.
#define COPY(NAME, VIN1, VIN2, SIN3)                \
    svuint64_t NAME ## _1 = VIN1;                   \
    svuint64_t NAME ## _2 = VIN2;                   \
    uint64_t NAME ## _3[4];                         \
    memcpy(NAME ## _3, SIN3, 4 * sizeof(uint64_t))

#define MULTIPLY(PRED, DEST, OP) \
    mul(PRED, &DEST ## _1, &OP ## _1, &DEST ## _2, &OP ## _2, DEST ## _3, OP ## _3)

#define SQUARE(PRED, NAME) \
    sq(PRED, &NAME ## _1, &NAME ## _2, NAME ## _3)

#define SQUARE_DEST(PRED, DEST, SRC)                  \
    COPY(DEST, SRC ## _1, SRC ## _2, SRC ## _3);      \
    SQUARE(PRED, DEST);

// CNT squarings followed by one multiplication by TAIL (addition-chain step).
#define POW_ACC(PRED, NAME, CNT, TAIL)  \
    for (size_t i = 0; i < CNT; i++) {  \
        SQUARE(PRED, NAME);             \
    }                                   \
    MULTIPLY(PRED, NAME, TAIL);

#define POW_ACC_DEST(PRED, DEST, CNT, HEAD, TAIL)     \
    COPY(DEST, HEAD ## _1, HEAD ## _2, HEAD ## _3);   \
    POW_ACC(PRED, DEST, CNT, TAIL)

// Adds the round constants to the state modulo p = 2^64 - 2^32 + 1. The first
// 8 elements use SVE vectors; the last 4 are handled with scalar intrinsics.
extern inline void add_constants(
    svbool_t pg,
    svuint64_t *state1,
    svuint64_t *const1,
    svuint64_t *state2,
    svuint64_t *const2,
    uint64_t *state3,
    uint64_t *const3
) {
    uint64_t Ms = 0xFFFFFFFF00000001ull;
    svuint64_t Mv = svindex_u64(Ms, 0);

    uint64_t p_1 = Ms - const3[0];
    uint64_t p_2 = Ms - const3[1];
    uint64_t p_3 = Ms - const3[2];
    uint64_t p_4 = Ms - const3[3];

    uint64_t x_1, x_2, x_3, x_4;
    uint32_t adj_1 = -__builtin_sub_overflow(state3[0], p_1, &x_1);
    uint32_t adj_2 = -__builtin_sub_overflow(state3[1], p_2, &x_2);
    uint32_t adj_3 = -__builtin_sub_overflow(state3[2], p_3, &x_3);
    uint32_t adj_4 = -__builtin_sub_overflow(state3[3], p_4, &x_4);

    state3[0] = x_1 - (uint64_t)adj_1;
    state3[1] = x_2 - (uint64_t)adj_2;
    state3[2] = x_3 - (uint64_t)adj_3;
    state3[3] = x_4 - (uint64_t)adj_4;

    svuint64_t p1 = svsub_x(pg, Mv, *const1);
    svuint64_t p2 = svsub_x(pg, Mv, *const2);

    svuint64_t x1 = svsub_x(pg, *state1, p1);
    svuint64_t x2 = svsub_x(pg, *state2, p2);

    svbool_t pt1 = svcmplt_u64(pg, *state1, p1);
    svbool_t pt2 = svcmplt_u64(pg, *state2, p2);

    *state1 = svsub_m(pt1, x1, (uint32_t)-1);
    *state2 = svsub_m(pt2, x2, (uint32_t)-1);
}

// Element-wise multiplication of two state triples modulo p, exploiting the
// special form of the prime for the 128-bit -> 64-bit reduction.
extern inline void mul(
    svbool_t pg,
    svuint64_t *r1,
    const svuint64_t *op1,
    svuint64_t *r2,
    const svuint64_t *op2,
    uint64_t *r3,
    const uint64_t *op3
) {
    __uint128_t x_1 = r3[0];
    __uint128_t x_2 = r3[1];
    __uint128_t x_3 = r3[2];
    __uint128_t x_4 = r3[3];

    x_1 *= (__uint128_t) op3[0];
    x_2 *= (__uint128_t) op3[1];
    x_3 *= (__uint128_t) op3[2];
    x_4 *= (__uint128_t) op3[3];

    uint64_t x0_1 = x_1;
    uint64_t x0_2 = x_2;
    uint64_t x0_3 = x_3;
    uint64_t x0_4 = x_4;

    svuint64_t l1 = svmul_x(pg, *r1, *op1);
    svuint64_t l2 = svmul_x(pg, *r2, *op2);

    uint64_t x1_1 = (x_1 >> 64);
    uint64_t x1_2 = (x_2 >> 64);
    uint64_t x1_3 = (x_3 >> 64);
    uint64_t x1_4 = (x_4 >> 64);

    uint64_t a_1, a_2, a_3, a_4;
    uint64_t e_1 = __builtin_add_overflow(x0_1, (x0_1 << 32), &a_1);
    uint64_t e_2 = __builtin_add_overflow(x0_2, (x0_2 << 32), &a_2);
    uint64_t e_3 = __builtin_add_overflow(x0_3, (x0_3 << 32), &a_3);
    uint64_t e_4 = __builtin_add_overflow(x0_4, (x0_4 << 32), &a_4);

    svuint64_t ls1 = svlsl_x(pg, l1, 32);
    svuint64_t ls2 = svlsl_x(pg, l2, 32);

    svuint64_t a1 = svadd_x(pg, l1, ls1);
    svuint64_t a2 = svadd_x(pg, l2, ls2);

    svbool_t e1 = svcmplt(pg, a1, l1);
    svbool_t e2 = svcmplt(pg, a2, l2);

    svuint64_t as1 = svlsr_x(pg, a1, 32);
    svuint64_t as2 = svlsr_x(pg, a2, 32);

    svuint64_t b1 = svsub_x(pg, a1, as1);
    svuint64_t b2 = svsub_x(pg, a2, as2);

    b1 = svsub_m(e1, b1, 1);
    b2 = svsub_m(e2, b2, 1);

    uint64_t b_1 = a_1 - (a_1 >> 32) - e_1;
    uint64_t b_2 = a_2 - (a_2 >> 32) - e_2;
    uint64_t b_3 = a_3 - (a_3 >> 32) - e_3;
    uint64_t b_4 = a_4 - (a_4 >> 32) - e_4;

    uint64_t r_1, r_2, r_3, r_4;
    uint32_t c_1 = __builtin_sub_overflow(x1_1, b_1, &r_1);
    uint32_t c_2 = __builtin_sub_overflow(x1_2, b_2, &r_2);
    uint32_t c_3 = __builtin_sub_overflow(x1_3, b_3, &r_3);
    uint32_t c_4 = __builtin_sub_overflow(x1_4, b_4, &r_4);

    svuint64_t h1 = svmulh_x(pg, *r1, *op1);
    svuint64_t h2 = svmulh_x(pg, *r2, *op2);

    svuint64_t tr1 = svsub_x(pg, h1, b1);
    svuint64_t tr2 = svsub_x(pg, h2, b2);

    svbool_t c1 = svcmplt_u64(pg, h1, b1);
    svbool_t c2 = svcmplt_u64(pg, h2, b2);

    *r1 = svsub_m(c1, tr1, (uint32_t) -1);
    *r2 = svsub_m(c2, tr2, (uint32_t) -1);

    uint32_t minus1_1 = 0 - c_1;
    uint32_t minus1_2 = 0 - c_2;
    uint32_t minus1_3 = 0 - c_3;
    uint32_t minus1_4 = 0 - c_4;

    r3[0] = r_1 - (uint64_t)minus1_1;
    r3[1] = r_2 - (uint64_t)minus1_2;
    r3[2] = r_3 - (uint64_t)minus1_3;
    r3[3] = r_4 - (uint64_t)minus1_4;
}

extern inline void sq(svbool_t pg, svuint64_t *a, svuint64_t *b, uint64_t *c) {
    mul(pg, a, a, b, b, c, c);
}

extern inline void apply_sbox(
    svbool_t pg,
    svuint64_t *state1,
    svuint64_t *state2,
    uint64_t *state3
) {
    COPY(x, *state1, *state2, state3);                  // copy input to x
    SQUARE(pg, x);                                      // x contains input^2
    mul(pg, state1, &x_1, state2, &x_2, state3, x_3);   // state contains input^3
    SQUARE(pg, x);                                      // x contains input^4
    mul(pg, state1, &x_1, state2, &x_2, state3, x_3);   // state contains input^7
}

extern inline void apply_inv_sbox(
    svbool_t pg,
    svuint64_t *state_1,
    svuint64_t *state_2,
    uint64_t *state_3
) {
    // base^10
    COPY(t1, *state_1, *state_2, state_3);
    SQUARE(pg, t1);

    // base^100
    SQUARE_DEST(pg, t2, t1);

    // base^100100
    POW_ACC_DEST(pg, t3, 3, t2, t2);

    // base^100100100100
    POW_ACC_DEST(pg, t4, 6, t3, t3);

    // compute base^100100100100100100100100
    POW_ACC_DEST(pg, t5, 12, t4, t4);

    // compute base^100100100100100100100100100100
    POW_ACC_DEST(pg, t6, 6, t5, t3);

    // compute base^1001001001001001001001001001000100100100100100100100100100100
    POW_ACC_DEST(pg, t7, 31, t6, t6);

    // compute base^1001001001001001001001001001000110110110110110110110110110110111
    SQUARE(pg, t7);
    MULTIPLY(pg, t7, t6);
    SQUARE(pg, t7);
    SQUARE(pg, t7);
    MULTIPLY(pg, t7, t1);
    MULTIPLY(pg, t7, t2);
    mul(pg, state_1, &t7_1, state_2, &t7_2, state_3, t7_3);
}

#endif //RPO_SVE_RPO_HASH_H
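For reference, the addition chain in `apply_inv_sbox` evaluates the inverse of the degree-7 s-box. With the Goldilocks prime used throughout this crate (a standard fact about RPO, not spelled out in the diff itself):

```latex
p = 2^{64} - 2^{32} + 1, \qquad S(x) = x^{7}, \qquad
S^{-1}(x) = x^{e} \text{ where } 7e \equiv 1 \pmod{p - 1}.
```

The binary expansion of that exponent e is exactly the bit string in the final comment above (1001001001001001001001001001000110110110110110110110110110110111): each `POW_ACC(..., CNT, TAIL)` step shifts the accumulated exponent left by `CNT` bits (CNT squarings) and then adds `TAIL`'s exponent (one multiplication).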
@@ -6,6 +6,7 @@ In the Miden VM, we make use of different hash functions. Some of these are "tra
 * **Poseidon** as specified [here](https://eprint.iacr.org/2019/458.pdf) and implemented [here](https://github.com/mir-protocol/plonky2/blob/806b88d7d6e69a30dc0b4775f7ba275c45e8b63b/plonky2/src/hash/poseidon_goldilocks.rs) (but in pure Rust, without vectorized instructions).
 * **Rescue Prime (RP)** as specified [here](https://eprint.iacr.org/2020/1143) and implemented [here](https://github.com/novifinancial/winterfell/blob/46dce1adf0/crypto/src/hash/rescue/rp64_256/mod.rs).
 * **Rescue Prime Optimized (RPO)** as specified [here](https://eprint.iacr.org/2022/1577) and implemented in this crate.
+* **Rescue Prime Extended (RPX)**, a variant of the [xHash](https://eprint.iacr.org/2023/1045) hash function, as implemented in this crate.

 ## Comparison and Instructions

@@ -15,25 +16,28 @@ The second scenario is that of sequential hashing where we take a sequence of le

 #### Scenario 1: 2-to-1 hashing `h(a,b)`

-| Function            | BLAKE3 | SHA3    | Poseidon  | Rp64_256  | RPO_256 |
-| ------------------- | ------ | --------| --------- | --------- | ------- |
-| Apple M1 Pro        | 80 ns  | 245 ns  | 1.5 us    | 9.1 us    | 5.4 us  |
-| Apple M2            | 76 ns  | 233 ns  | 1.3 us    | 7.9 us    | 5.0 us  |
-| Amazon Graviton 3   | 116 ns |         |           |           | 8.8 us  |
-| AMD Ryzen 9 5950X   | 64 ns  | 273 ns  | 1.2 us    | 9.1 us    | 5.5 us  |
-| Intel Core i5-8279U | 80 ns  |         |           |           | 8.7 us  |
-| Intel Xeon 8375C    | 67 ns  |         |           |           | 8.2 us  |
+| Function            | BLAKE3 | SHA3    | Poseidon  | Rp64_256  | RPO_256 | RPX_256 |
+| ------------------- | ------ | ------- | --------- | --------- | ------- | ------- |
+| Apple M1 Pro        | 76 ns  | 245 ns  | 1.5 µs    | 9.1 µs    | 5.2 µs  | 2.7 µs  |
+| Apple M2 Max        | 71 ns  | 233 ns  | 1.3 µs    | 7.9 µs    | 4.6 µs  | 2.4 µs  |
+| Amazon Graviton 3   | 108 ns |         |           |           | 5.3 µs  | 3.1 µs  |
+| AMD Ryzen 9 5950X   | 64 ns  | 273 ns  | 1.2 µs    | 9.1 µs    | 5.5 µs  |         |
+| Intel Core i5-8279U | 68 ns  | 536 ns  | 2.0 µs    | 13.6 µs   | 8.5 µs  | 4.4 µs  |
+| Intel Xeon 8375C    | 67 ns  |         |           |           | 8.2 µs  |         |

 #### Scenario 2: Sequential hashing of 100 elements `h([a_0,...,a_99])`

-| Function            | BLAKE3 | SHA3    | Poseidon  | Rp64_256  | RPO_256 |
-| ------------------- | -------| ------- | --------- | --------- | ------- |
-| Apple M1 Pro        | 1.0 us | 1.5 us  | 19.4 us   | 118 us    | 70 us   |
-| Apple M2            | 1.0 us | 1.5 us  | 17.4 us   | 103 us    | 65 us   |
-| Amazon Graviton 3   | 1.4 us |         |           |           | 114 us  |
-| AMD Ryzen 9 5950X   | 0.8 us | 1.7 us  | 15.7 us   | 120 us    | 72 us   |
-| Intel Core i5-8279U | 1.0 us |         |           |           | 116 us  |
-| Intel Xeon 8375C    | 0.8 ns |         |           |           | 110 us  |
+| Function            | BLAKE3 | SHA3    | Poseidon  | Rp64_256  | RPO_256 | RPX_256 |
+| ------------------- | -------| ------- | --------- | --------- | ------- | ------- |
+| Apple M1 Pro        | 1.0 µs | 1.5 µs  | 19.4 µs   | 118 µs    | 69 µs   | 35 µs   |
+| Apple M2 Max        | 0.9 µs | 1.5 µs  | 17.4 µs   | 103 µs    | 60 µs   | 31 µs   |
+| Amazon Graviton 3   | 1.4 µs |         |           |           | 69 µs   | 41 µs   |
+| AMD Ryzen 9 5950X   | 0.8 µs | 1.7 µs  | 15.7 µs   | 120 µs    | 72 µs   |         |
+| Intel Core i5-8279U | 0.9 µs |         |           |           | 107 µs  | 56 µs   |
+| Intel Xeon 8375C    | 0.8 µs |         |           |           | 110 µs  |         |

+Notes:
+- On Graviton 3, RPO256 and RPX256 are run with SVE acceleration enabled.
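The two scenarios correspond directly to the hasher API exercised by the benchmarks later in this diff; a minimal sketch for RPO (mirroring the RPX benchmark code shown below) is:

```rust
use miden_crypto::{
    hash::rpo::{Rpo256, RpoDigest},
    Felt,
};

fn scenarios() {
    // Scenario 1: 2-to-1 hashing h(a, b) merges two 256-bit digests into one.
    let digests: [RpoDigest; 2] = [Rpo256::hash(&[1_u8]), Rpo256::hash(&[2_u8])];
    let _merged = Rpo256::merge(&digests);

    // Scenario 2: sequential hashing of 100 field elements.
    let elements: Vec<Felt> = (0..100).map(Felt::new).collect();
    let _digest = Rpo256::hash_elements(&elements);
}
```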
 ### Instructions
 Before you can run the benchmarks, you'll need to make sure you have Rust [installed](https://www.rust-lang.org/tools/install). After that, to run the benchmarks for RPO and BLAKE3, clone the current repository, and from the root directory of the repo run the following:
@@ -3,6 +3,7 @@ use miden_crypto::{
     hash::{
         blake::Blake3_256,
         rpo::{Rpo256, RpoDigest},
+        rpx::{Rpx256, RpxDigest},
     },
     Felt,
 };
@@ -57,6 +58,54 @@ fn rpo256_sequential(c: &mut Criterion) {
     });
 }

+fn rpx256_2to1(c: &mut Criterion) {
+    let v: [RpxDigest; 2] = [Rpx256::hash(&[1_u8]), Rpx256::hash(&[2_u8])];
+    c.bench_function("RPX256 2-to-1 hashing (cached)", |bench| {
+        bench.iter(|| Rpx256::merge(black_box(&v)))
+    });
+
+    c.bench_function("RPX256 2-to-1 hashing (random)", |bench| {
+        bench.iter_batched(
+            || {
+                [
+                    Rpx256::hash(&rand_value::<u64>().to_le_bytes()),
+                    Rpx256::hash(&rand_value::<u64>().to_le_bytes()),
+                ]
+            },
+            |state| Rpx256::merge(&state),
+            BatchSize::SmallInput,
+        )
+    });
+}
+
+fn rpx256_sequential(c: &mut Criterion) {
+    let v: [Felt; 100] = (0..100)
+        .into_iter()
+        .map(Felt::new)
+        .collect::<Vec<Felt>>()
+        .try_into()
+        .expect("should not fail");
+    c.bench_function("RPX256 sequential hashing (cached)", |bench| {
+        bench.iter(|| Rpx256::hash_elements(black_box(&v)))
+    });
+
+    c.bench_function("RPX256 sequential hashing (random)", |bench| {
+        bench.iter_batched(
+            || {
+                let v: [Felt; 100] = (0..100)
+                    .into_iter()
+                    .map(|_| Felt::new(rand_value()))
+                    .collect::<Vec<Felt>>()
+                    .try_into()
+                    .expect("should not fail");
+                v
+            },
+            |state| Rpx256::hash_elements(&state),
+            BatchSize::SmallInput,
+        )
+    });
+}
+
 fn blake3_2to1(c: &mut Criterion) {
     let v: [<Blake3_256 as Hasher>::Digest; 2] =
         [Blake3_256::hash(&[1_u8]), Blake3_256::hash(&[2_u8])];
@@ -106,5 +155,13 @@ fn blake3_sequential(c: &mut Criterion) {
     });
 }

-criterion_group!(hash_group, rpo256_2to1, rpo256_sequential, blake3_2to1, blake3_sequential);
+criterion_group!(
+    hash_group,
+    rpx256_2to1,
+    rpx256_sequential,
+    rpo256_2to1,
+    rpo256_sequential,
+    blake3_2to1,
+    blake3_sequential
+);
 criterion_main!(hash_group);
@@ -18,8 +18,8 @@ fn smt_rpo(c: &mut Criterion) {
             (i, word)
         })
         .collect();
-        let tree = SimpleSmt::new(depth).unwrap().with_leaves(entries).unwrap();
-        trees.push(tree);
+        let tree = SimpleSmt::with_leaves(depth, entries).unwrap();
+        trees.push((tree, count));
     }
 }

@@ -29,10 +29,9 @@ fn smt_rpo(c: &mut Criterion) {

     let mut insert = c.benchmark_group(format!("smt update_leaf"));

-    for tree in trees.iter_mut() {
+    for (tree, count) in trees.iter_mut() {
         let depth = tree.depth();
-        let count = tree.leaves_count() as u64;
-        let key = count >> 2;
+        let key = *count >> 2;
         insert.bench_with_input(
             format!("simple smt(depth:{depth},count:{count})"),
             &(key, leaf),
@@ -48,10 +47,9 @@ fn smt_rpo(c: &mut Criterion) {

     let mut path = c.benchmark_group(format!("smt get_leaf_path"));

-    for tree in trees.iter_mut() {
+    for (tree, count) in trees.iter_mut() {
         let depth = tree.depth();
-        let count = tree.leaves_count() as u64;
-        let key = count >> 2;
+        let key = *count >> 2;
         path.bench_with_input(
             format!("simple smt(depth:{depth},count:{count})"),
             &key,
@@ -1,5 +1,5 @@
 use criterion::{black_box, criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
-use miden_crypto::merkle::{MerkleStore, MerkleTree, NodeIndex, SimpleSmt};
+use miden_crypto::merkle::{DefaultMerkleStore as MerkleStore, MerkleTree, NodeIndex, SimpleSmt};
 use miden_crypto::Word;
 use miden_crypto::{hash::rpo::RpoDigest, Felt};
 use rand_utils::{rand_array, rand_value};
@@ -104,10 +104,7 @@ fn get_leaf_simplesmt(c: &mut Criterion) {
         .enumerate()
         .map(|(c, v)| (c.try_into().unwrap(), v.into()))
         .collect::<Vec<(u64, Word)>>();
-    let smt = SimpleSmt::new(SimpleSmt::MAX_DEPTH)
-        .unwrap()
-        .with_leaves(smt_leaves.clone())
-        .unwrap();
+    let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap();
     let store = MerkleStore::from(&smt);
     let depth = smt.depth();
     let root = smt.root();
@@ -215,10 +212,7 @@ fn get_node_simplesmt(c: &mut Criterion) {
         .enumerate()
         .map(|(c, v)| (c.try_into().unwrap(), v.into()))
         .collect::<Vec<(u64, Word)>>();
-    let smt = SimpleSmt::new(SimpleSmt::MAX_DEPTH)
-        .unwrap()
-        .with_leaves(smt_leaves.clone())
-        .unwrap();
+    let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap();
     let store = MerkleStore::from(&smt);
     let root = smt.root();
     let half_depth = smt.depth() / 2;
@@ -292,10 +286,7 @@ fn get_leaf_path_simplesmt(c: &mut Criterion) {
         .enumerate()
         .map(|(c, v)| (c.try_into().unwrap(), v.into()))
         .collect::<Vec<(u64, Word)>>();
-    let smt = SimpleSmt::new(SimpleSmt::MAX_DEPTH)
-        .unwrap()
-        .with_leaves(smt_leaves.clone())
-        .unwrap();
+    let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap();
     let store = MerkleStore::from(&smt);
     let depth = smt.depth();
     let root = smt.root();
@@ -361,7 +352,7 @@ fn new(c: &mut Criterion) {
                 .map(|(c, v)| (c.try_into().unwrap(), v.into()))
                 .collect::<Vec<(u64, Word)>>()
             },
-            |l| black_box(SimpleSmt::new(SimpleSmt::MAX_DEPTH).unwrap().with_leaves(l)),
+            |l| black_box(SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, l)),
             BatchSize::SmallInput,
         )
     });
@@ -376,7 +367,7 @@ fn new(c: &mut Criterion) {
                 .collect::<Vec<(u64, Word)>>()
             },
            |l| {
-                let smt = SimpleSmt::new(SimpleSmt::MAX_DEPTH).unwrap().with_leaves(l).unwrap();
+                let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, l).unwrap();
                black_box(MerkleStore::from(&smt));
            },
            BatchSize::SmallInput,
@@ -418,7 +409,7 @@ fn update_leaf_merkletree(c: &mut Criterion) {
                // The MerkleTree automatically updates its internal root, the Store maintains
                // the old root and adds the new one. Here we update the root to have a fair
                // comparison
-                store_root = store.set_node(root, index, value).unwrap().root;
+                store_root = store.set_node(root, index, value.into()).unwrap().root;
                black_box(store_root)
            },
            BatchSize::SmallInput,
@@ -442,10 +433,7 @@ fn update_leaf_simplesmt(c: &mut Criterion) {
         .enumerate()
         .map(|(c, v)| (c.try_into().unwrap(), v.into()))
         .collect::<Vec<(u64, Word)>>();
-    let mut smt = SimpleSmt::new(SimpleSmt::MAX_DEPTH)
-        .unwrap()
-        .with_leaves(smt_leaves.clone())
-        .unwrap();
+    let mut smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap();
     let mut store = MerkleStore::from(&smt);
     let depth = smt.depth();
     let root = smt.root();
@@ -467,7 +455,7 @@ fn update_leaf_simplesmt(c: &mut Criterion) {
                // The MerkleTree automatically updates its internal root, the Store maintains
                // the old root and adds the new one. Here we update the root to have a fair
                // comparison
-                store_root = store.set_node(root, index, value).unwrap().root;
+                store_root = store.set_node(root, index, value.into()).unwrap().root;
                black_box(store_root)
            },
            BatchSize::SmallInput,
build.rs (new file, 50 lines)

@@ -0,0 +1,50 @@
fn main() {
    #[cfg(feature = "std")]
    compile_rpo_falcon();

    #[cfg(all(target_feature = "sve", feature = "sve"))]
    compile_arch_arm64_sve();
}

#[cfg(feature = "std")]
fn compile_rpo_falcon() {
    use std::path::PathBuf;

    const RPO_FALCON_PATH: &str = "src/dsa/rpo_falcon512/falcon_c";

    println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/falcon.h");
    println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/falcon.c");
    println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/rpo.h");
    println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/rpo.c");

    let target_dir: PathBuf = ["PQClean", "crypto_sign", "falcon-512", "clean"].iter().collect();
    let common_dir: PathBuf = ["PQClean", "common"].iter().collect();

    let scheme_files = glob::glob(target_dir.join("*.c").to_str().unwrap()).unwrap();
    let common_files = glob::glob(common_dir.join("*.c").to_str().unwrap()).unwrap();

    cc::Build::new()
        .include(&common_dir)
        .include(target_dir)
        .files(scheme_files.into_iter().map(|p| p.unwrap().to_string_lossy().into_owned()))
        .files(common_files.into_iter().map(|p| p.unwrap().to_string_lossy().into_owned()))
        .file(format!("{RPO_FALCON_PATH}/falcon.c"))
        .file(format!("{RPO_FALCON_PATH}/rpo.c"))
        .flag("-O3")
        .compile("rpo_falcon512");
}

#[cfg(all(target_feature = "sve", feature = "sve"))]
fn compile_arch_arm64_sve() {
    const RPO_SVE_PATH: &str = "arch/arm64-sve/rpo";

    println!("cargo:rerun-if-changed={RPO_SVE_PATH}/library.c");
    println!("cargo:rerun-if-changed={RPO_SVE_PATH}/library.h");
    println!("cargo:rerun-if-changed={RPO_SVE_PATH}/rpo_hash.h");

    cc::Build::new()
        .file(format!("{RPO_SVE_PATH}/library.c"))
        .flag("-march=armv8-a+sve")
        .flag("-O3")
        .compile("rpo_sve");
}
@@ -16,5 +16,6 @@ newline_style = "Unix"
 #normalize_doc_attributes = true
 #reorder_impl_items = true
+single_line_if_else_max_width = 60
 struct_lit_width = 40
 use_field_init_shorthand = true
 use_try_shorthand = true
src/dsa/mod.rs (new file, 3 lines)

@@ -0,0 +1,3 @@
//! Digital signature schemes supported by default in the Miden VM.

pub mod rpo_falcon512;
src/dsa/rpo_falcon512/error.rs (new file, 55 lines)

@@ -0,0 +1,55 @@
use super::{LOG_N, MODULUS, PK_LEN};
use core::fmt;

#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FalconError {
    KeyGenerationFailed,
    PubKeyDecodingExtraData,
    PubKeyDecodingInvalidCoefficient(u32),
    PubKeyDecodingInvalidLength(usize),
    PubKeyDecodingInvalidTag(u8),
    SigDecodingTooBigHighBits(u32),
    SigDecodingInvalidRemainder,
    SigDecodingNonZeroUnusedBitsLastByte,
    SigDecodingMinusZero,
    SigDecodingIncorrectEncodingAlgorithm,
    SigDecodingNotSupportedDegree(u8),
    SigGenerationFailed,
}

impl fmt::Display for FalconError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use FalconError::*;
        match self {
            KeyGenerationFailed => write!(f, "Failed to generate a private-public key pair"),
            PubKeyDecodingExtraData => {
                write!(f, "Failed to decode public key: input not fully consumed")
            }
            PubKeyDecodingInvalidCoefficient(val) => {
                write!(f, "Failed to decode public key: coefficient {val} is greater than or equal to the field modulus {MODULUS}")
            }
            PubKeyDecodingInvalidLength(len) => {
                write!(f, "Failed to decode public key: expected {PK_LEN} bytes but received {len}")
            }
            PubKeyDecodingInvalidTag(byte) => {
                write!(f, "Failed to decode public key: expected the first byte to be {LOG_N} but was {byte}")
            }
            SigDecodingTooBigHighBits(m) => {
                write!(f, "Failed to decode signature: high bits {m} exceed 2048")
            }
            SigDecodingInvalidRemainder => {
                write!(f, "Failed to decode signature: incorrect remaining data")
            }
            SigDecodingNonZeroUnusedBitsLastByte => {
                write!(f, "Failed to decode signature: Non-zero unused bits in the last byte")
            }
            SigDecodingMinusZero => write!(f, "Failed to decode signature: -0 is forbidden"),
            SigDecodingIncorrectEncodingAlgorithm => write!(f, "Failed to decode signature: not supported encoding algorithm"),
            SigDecodingNotSupportedDegree(log_n) => write!(f, "Failed to decode signature: only supported irreducible polynomial degree is 512, 2^{log_n} was provided"),
            SigGenerationFailed => write!(f, "Failed to generate a signature"),
        }
    }
}

#[cfg(feature = "std")]
impl std::error::Error for FalconError {}
src/dsa/rpo_falcon512/falcon_c/falcon.c (new file, 402 lines)

@@ -0,0 +1,402 @@
/*
 * Wrapper for implementing the PQClean API.
 */

#include <string.h>
#include "randombytes.h"
#include "falcon.h"
#include "inner.h"
#include "rpo.h"

#define NONCELEN 40

/*
 * Encoding formats (nnnn = log of degree, 9 for Falcon-512, 10 for Falcon-1024)
 *
 * private key:
 *   header byte: 0101nnnn
 *   private f (6 or 5 bits by element, depending on degree)
 *   private g (6 or 5 bits by element, depending on degree)
 *   private F (8 bits by element)
 *
 * public key:
 *   header byte: 0000nnnn
 *   public h (14 bits by element)
 *
 * signature:
 *   header byte: 0011nnnn
 *   nonce  40 bytes
 *   value (12 bits by element)
 *
 * message + signature:
 *   signature length (2 bytes, big-endian)
 *   nonce  40 bytes
 *   message
 *   header byte: 0010nnnn
 *   value (12 bits by element)
 *   (signature length is 1+len(value), not counting the nonce)
 */

/* see falcon.h */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
    uint8_t *pk,
    uint8_t *sk,
    unsigned char *seed
) {
    union
    {
        uint8_t b[FALCON_KEYGEN_TEMP_9];
        uint64_t dummy_u64;
        fpr dummy_fpr;
    } tmp;
    int8_t f[512], g[512], F[512];
    uint16_t h[512];
    inner_shake256_context rng;
    size_t u, v;

    /*
     * Generate key pair.
     */
    inner_shake256_init(&rng);
    /* NOTE: `seed` is a pointer parameter here, so `sizeof seed` is the size
     * of the pointer (8 on 64-bit targets), not the 48-byte buffer length. */
    inner_shake256_inject(&rng, seed, sizeof seed);
    inner_shake256_flip(&rng);
    PQCLEAN_FALCON512_CLEAN_keygen(&rng, f, g, F, NULL, h, 9, tmp.b);
    inner_shake256_ctx_release(&rng);

    /*
     * Encode private key.
     */
    sk[0] = 0x50 + 9;
    u = 1;
    v = PQCLEAN_FALCON512_CLEAN_trim_i8_encode(
        sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u,
        f, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9]);
    if (v == 0)
    {
        return -1;
    }
    u += v;
    v = PQCLEAN_FALCON512_CLEAN_trim_i8_encode(
        sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u,
        g, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9]);
    if (v == 0)
    {
        return -1;
    }
    u += v;
    v = PQCLEAN_FALCON512_CLEAN_trim_i8_encode(
        sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u,
        F, 9, PQCLEAN_FALCON512_CLEAN_max_FG_bits[9]);
    if (v == 0)
    {
        return -1;
    }
    u += v;
    if (u != PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES)
    {
        return -1;
    }

    /*
     * Encode public key.
     */
    pk[0] = 0x00 + 9;
    v = PQCLEAN_FALCON512_CLEAN_modq_encode(
        pk + 1, PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1,
        h, 9);
    if (v != PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1)
    {
        return -1;
    }

    return 0;
}

int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
    uint8_t *pk,
    uint8_t *sk
) {
    unsigned char seed[48];

    /*
     * Generate a random seed.
     */
    randombytes(seed, sizeof seed);

    return PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(pk, sk, seed);
}

/*
 * Compute the signature. nonce[] receives the nonce and must have length
 * NONCELEN bytes. sigbuf[] receives the signature value (without nonce
 * or header byte), with *sigbuflen providing the maximum value length and
 * receiving the actual value length.
 *
 * If a signature could be computed but not encoded because it would
 * exceed the output buffer size, then a new signature is computed. If
 * the provided buffer size is too low, this could loop indefinitely, so
 * the caller must provide a size that can accommodate signatures with a
 * large enough probability.
 *
 * Return value: 0 on success, -1 on error.
 */
static int do_sign(
    uint8_t *nonce,
    uint8_t *sigbuf,
    size_t *sigbuflen,
    const uint8_t *m,
    size_t mlen,
    const uint8_t *sk
) {
    union
    {
        uint8_t b[72 * 512];
        uint64_t dummy_u64;
        fpr dummy_fpr;
    } tmp;
    int8_t f[512], g[512], F[512], G[512];
    struct
    {
        int16_t sig[512];
        uint16_t hm[512];
    } r;
    unsigned char seed[48];
    inner_shake256_context sc;
    rpo128_context rc;
    size_t u, v;

    /*
     * Decode the private key.
     */
    if (sk[0] != 0x50 + 9)
    {
        return -1;
    }
    u = 1;
    v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode(
        f, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9],
        sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u);
    if (v == 0)
    {
        return -1;
    }
    u += v;
    v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode(
        g, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9],
        sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u);
    if (v == 0)
    {
        return -1;
    }
    u += v;
    v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode(
        F, 9, PQCLEAN_FALCON512_CLEAN_max_FG_bits[9],
        sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u);
    if (v == 0)
    {
        return -1;
    }
    u += v;
    if (u != PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES)
    {
        return -1;
    }
    if (!PQCLEAN_FALCON512_CLEAN_complete_private(G, f, g, F, 9, tmp.b))
    {
        return -1;
    }

    /*
     * Create a random nonce (40 bytes).
     */
    randombytes(nonce, NONCELEN);

    /* ==== Start: Deviation from the reference implementation ================================= */

    // Transform the nonce into 8 chunks each of size 5 bytes. We do this in order to be sure that
    // the conversion to field elements succeeds
    uint8_t buffer[64];
    memset(buffer, 0, 64);
    for (size_t i = 0; i < 8; i++)
    {
        buffer[8 * i] = nonce[5 * i];
        buffer[8 * i + 1] = nonce[5 * i + 1];
        buffer[8 * i + 2] = nonce[5 * i + 2];
        buffer[8 * i + 3] = nonce[5 * i + 3];
        buffer[8 * i + 4] = nonce[5 * i + 4];
    }

    /*
     * Hash nonce + message into a vector.
     */
    rpo128_init(&rc);
    rpo128_absorb(&rc, buffer, NONCELEN + 24);
    rpo128_absorb(&rc, m, mlen);
    rpo128_finalize(&rc);
    PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(&rc, r.hm, 9);
    rpo128_release(&rc);

    /* ==== End: Deviation from the reference implementation =================================== */

    /*
     * Initialize an RNG.
     */
    randombytes(seed, sizeof seed);
    inner_shake256_init(&sc);
    inner_shake256_inject(&sc, seed, sizeof seed);
    inner_shake256_flip(&sc);

    /*
     * Compute and return the signature. This loops until a signature
     * value is found that fits in the provided buffer.
     */
    for (;;)
    {
        PQCLEAN_FALCON512_CLEAN_sign_dyn(r.sig, &sc, f, g, F, G, r.hm, 9, tmp.b);
        v = PQCLEAN_FALCON512_CLEAN_comp_encode(sigbuf, *sigbuflen, r.sig, 9);
        if (v != 0)
        {
            inner_shake256_ctx_release(&sc);
            *sigbuflen = v;
            return 0;
        }
    }
}

/*
 * Verify a signature. The nonce has size NONCELEN bytes. sigbuf[]
 * (of size sigbuflen) contains the signature value, not including the
 * header byte or nonce. Return value is 0 on success, -1 on error.
 */
static int do_verify(
    const uint8_t *nonce,
    const uint8_t *sigbuf,
    size_t sigbuflen,
    const uint8_t *m,
    size_t mlen,
    const uint8_t *pk
) {
    union
    {
        uint8_t b[2 * 512];
        uint64_t dummy_u64;
        fpr dummy_fpr;
    } tmp;
    uint16_t h[512], hm[512];
    int16_t sig[512];
    rpo128_context rc;

    /*
     * Decode public key.
     */
    if (pk[0] != 0x00 + 9)
    {
        return -1;
    }
    if (PQCLEAN_FALCON512_CLEAN_modq_decode(h, 9,
                                            pk + 1, PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1)
        != PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1)
    {
        return -1;
    }
    PQCLEAN_FALCON512_CLEAN_to_ntt_monty(h, 9);

    /*
     * Decode signature.
     */
    if (sigbuflen == 0)
    {
        return -1;
    }
    if (PQCLEAN_FALCON512_CLEAN_comp_decode(sig, 9, sigbuf, sigbuflen) != sigbuflen)
    {
        return -1;
    }

    /* ==== Start: Deviation from the reference implementation ================================= */

    /*
     * Hash nonce + message into a vector.
     */

    // Transform the nonce into 8 chunks each of size 5 bytes. We do this in order to be sure that
    // the conversion to field elements succeeds
    uint8_t buffer[64];
    memset(buffer, 0, 64);
    for (size_t i = 0; i < 8; i++)
    {
        buffer[8 * i] = nonce[5 * i];
        buffer[8 * i + 1] = nonce[5 * i + 1];
        buffer[8 * i + 2] = nonce[5 * i + 2];
        buffer[8 * i + 3] = nonce[5 * i + 3];
        buffer[8 * i + 4] = nonce[5 * i + 4];
    }

    rpo128_init(&rc);
    rpo128_absorb(&rc, buffer, NONCELEN + 24);
    rpo128_absorb(&rc, m, mlen);
    rpo128_finalize(&rc);
    PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(&rc, hm, 9);
    rpo128_release(&rc);

    /* === End: Deviation from the reference implementation ==================================== */

    /*
     * Verify signature.
     */
    if (!PQCLEAN_FALCON512_CLEAN_verify_raw(hm, sig, h, 9, tmp.b))
    {
        return -1;
    }
    return 0;
}

/* see falcon.h */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
    uint8_t *sig,
    size_t *siglen,
    const uint8_t *m,
    size_t mlen,
    const uint8_t *sk
) {
    /*
     * The PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES constant is used for
     * the signed message object (as produced by crypto_sign())
     * and includes a two-byte length value, so we take care here
     * to only generate signatures that are two bytes shorter than
     * the maximum. This is done to ensure that crypto_sign()
     * and crypto_sign_signature() produce the exact same signature
     * value, if used on the same message, with the same private key,
     * and using the same output from randombytes() (this is for
     * reproducibility of tests).
     */
    size_t vlen;

    vlen = PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES - NONCELEN - 3;
    if (do_sign(sig + 1, sig + 1 + NONCELEN, &vlen, m, mlen, sk) < 0)
    {
        return -1;
    }
    sig[0] = 0x30 + 9;
    *siglen = 1 + NONCELEN + vlen;
    return 0;
}

/* see falcon.h */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
    const uint8_t *sig,
    size_t siglen,
    const uint8_t *m,
    size_t mlen,
    const uint8_t *pk
) {
    if (siglen < 1 + NONCELEN)
    {
        return -1;
    }
    if (sig[0] != 0x30 + 9)
    {
        return -1;
    }
    return do_verify(sig + 1, sig + 1 + NONCELEN, siglen - 1 - NONCELEN, m, mlen, pk);
}
src/dsa/rpo_falcon512/falcon_c/falcon.h (new file, 66 lines)

@@ -0,0 +1,66 @@
#include <stddef.h>
#include <stdint.h>

#define PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES 1281
#define PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES 897
#define PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES 666

/*
 * Generate a new key pair. Public key goes into pk[], private key in sk[].
 * Key sizes are exact (in bytes):
 *   public (pk):  PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES
 *   private (sk): PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES
 *
 * Return value: 0 on success, -1 on error.
 *
 * Note: This implementation follows the reference implementation in PQClean
 * https://github.com/PQClean/PQClean/tree/master/crypto_sign/falcon-512
 * verbatim except for the sections that are marked otherwise.
 */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
    uint8_t *pk, uint8_t *sk);

/*
 * Generate a new key pair from seed. Public key goes into pk[], private key in sk[].
 * Key sizes are exact (in bytes):
 *   public (pk):  PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES
 *   private (sk): PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES
 *
 * Return value: 0 on success, -1 on error.
 */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
    uint8_t *pk, uint8_t *sk, unsigned char *seed);

/*
 * Compute a signature on a provided message (m, mlen), with a given
 * private key (sk). Signature is written in sig[], with length written
 * into *siglen. Signature length is variable; maximum signature length
 * (in bytes) is PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES.
 *
 * sig[], m[] and sk[] may overlap each other arbitrarily.
 *
 * Return value: 0 on success, -1 on error.
 *
 * Note: This implementation follows the reference implementation in PQClean
 * https://github.com/PQClean/PQClean/tree/master/crypto_sign/falcon-512
 * verbatim except for the sections that are marked otherwise.
 */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
    uint8_t *sig, size_t *siglen,
    const uint8_t *m, size_t mlen, const uint8_t *sk);

/*
 * Verify a signature (sig, siglen) on a message (m, mlen) with a given
 * public key (pk).
 *
 * sig[], m[] and pk[] may overlap each other arbitrarily.
 *
 * Return value: 0 on success, -1 on error.
 *
 * Note: This implementation follows the reference implementation in PQClean
 * https://github.com/PQClean/PQClean/tree/master/crypto_sign/falcon-512
 * verbatim except for the sections that are marked otherwise.
 */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
    const uint8_t *sig, size_t siglen,
    const uint8_t *m, size_t mlen, const uint8_t *pk);
src/dsa/rpo_falcon512/falcon_c/rpo.c (new file, 582 lines)
@@ -0,0 +1,582 @@
/*
 * RPO implementation.
 */

#include <stdint.h>
#include <string.h>
#include <stdlib.h>

/* ================================================================================================
 * Modular Arithmetic
 */

#define P 0xFFFFFFFF00000001
#define M 12289

// From https://github.com/ncw/iprime/blob/master/mod_math_noasm.go
static uint64_t add_mod_p(uint64_t a, uint64_t b)
{
    a = P - a;
    uint64_t res = b - a;
    if (b < a)
        res += P;
    return res;
}

static uint64_t sub_mod_p(uint64_t a, uint64_t b)
{
    uint64_t r = a - b;
    if (a < b)
        r += P;
    return r;
}

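/*
 * Reduces the 128-bit value b * 2^64 + a modulo P = 2^64 - 2^32 + 1. Writing b = d * 2^32 + c,
 * the identity 2^64 = 2^32 - 1 (mod P) gives b * 2^64 = c * 2^32 - c - d (mod P), so the
 * result below is a - c - d + c * 2^32 (mod P).
 */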
static uint64_t reduce_mod_p(uint64_t b, uint64_t a)
{
    uint32_t d = b >> 32,
             c = b;
    if (a >= P)
        a -= P;
    a = sub_mod_p(a, c);
    a = sub_mod_p(a, d);
    a = add_mod_p(a, ((uint64_t)c) << 32);
    return a;
}

static uint64_t mult_mod_p(uint64_t x, uint64_t y)
{
    uint32_t a = x,
             b = x >> 32,
             c = y,
             d = y >> 32;

    /* first synthesize the product using 32*32 -> 64 bit multiplies */
    x = b * (uint64_t)c; /* b*c */
    y = a * (uint64_t)d; /* a*d */
    uint64_t e = a * (uint64_t)c, /* a*c */
             f = b * (uint64_t)d, /* b*d */
             t;

    x += y; /* b*c + a*d */
    /* carry? */
    if (x < y)
        f += 1LL << 32; /* carry into upper 32 bits - can't overflow */

    t = x << 32;
    e += t; /* a*c + LSW(b*c + a*d) */
    /* carry? */
    if (e < t)
        f += 1; /* carry into upper 64 bits - can't overflow */
    t = x >> 32;
    f += t; /* b*d + MSW(b*c + a*d) */
    /* can't overflow */

    /* now reduce: (b*d + MSW(b*c + a*d), a*c + LSW(b*c + a*d)) */
    return reduce_mod_p(f, e);
}

/* ================================================================================================
 * RPO128 Permutation
 */

#define STATE_WIDTH 12
#define NUM_ROUNDS 7

/*
 * MDS matrix
 */
static const uint64_t MDS[12][12] = {
    {  7, 23,  8, 26, 13, 10,  9,  7,  6, 22, 21,  8 },
    {  8,  7, 23,  8, 26, 13, 10,  9,  7,  6, 22, 21 },
    { 21,  8,  7, 23,  8, 26, 13, 10,  9,  7,  6, 22 },
    { 22, 21,  8,  7, 23,  8, 26, 13, 10,  9,  7,  6 },
    {  6, 22, 21,  8,  7, 23,  8, 26, 13, 10,  9,  7 },
    {  7,  6, 22, 21,  8,  7, 23,  8, 26, 13, 10,  9 },
    {  9,  7,  6, 22, 21,  8,  7, 23,  8, 26, 13, 10 },
    { 10,  9,  7,  6, 22, 21,  8,  7, 23,  8, 26, 13 },
    { 13, 10,  9,  7,  6, 22, 21,  8,  7, 23,  8, 26 },
    { 26, 13, 10,  9,  7,  6, 22, 21,  8,  7, 23,  8 },
    {  8, 26, 13, 10,  9,  7,  6, 22, 21,  8,  7, 23 },
    { 23,  8, 26, 13, 10,  9,  7,  6, 22, 21,  8,  7 },
};

/*
 * Round constants.
 */
static const uint64_t ARK1[7][12] = {
    { 5789762306288267392ULL, 6522564764413701783ULL, 17809893479458208203ULL,
      107145243989736508ULL, 6388978042437517382ULL, 15844067734406016715ULL,
      9975000513555218239ULL, 3344984123768313364ULL, 9959189626657347191ULL,
      12960773468763563665ULL, 9602914297752488475ULL, 16657542370200465908ULL },
    { 12987190162843096997ULL, 653957632802705281ULL, 4441654670647621225ULL,
      4038207883745915761ULL, 5613464648874830118ULL, 13222989726778338773ULL,
      3037761201230264149ULL, 16683759727265180203ULL, 8337364536491240715ULL,
      3227397518293416448ULL, 8110510111539674682ULL, 2872078294163232137ULL },
    { 18072785500942327487ULL, 6200974112677013481ULL, 17682092219085884187ULL,
      10599526828986756440ULL, 975003873302957338ULL, 8264241093196931281ULL,
      10065763900435475170ULL, 2181131744534710197ULL, 6317303992309418647ULL,
      1401440938888741532ULL, 8884468225181997494ULL, 13066900325715521532ULL },
    { 5674685213610121970ULL, 5759084860419474071ULL, 13943282657648897737ULL,
      1352748651966375394ULL, 17110913224029905221ULL, 1003883795902368422ULL,
      4141870621881018291ULL, 8121410972417424656ULL, 14300518605864919529ULL,
      13712227150607670181ULL, 17021852944633065291ULL, 6252096473787587650ULL },
    { 4887609836208846458ULL, 3027115137917284492ULL, 9595098600469470675ULL,
      10528569829048484079ULL, 7864689113198939815ULL, 17533723827845969040ULL,
      5781638039037710951ULL, 17024078752430719006ULL, 109659393484013511ULL,
      7158933660534805869ULL, 2955076958026921730ULL, 7433723648458773977ULL },
    { 16308865189192447297ULL, 11977192855656444890ULL, 12532242556065780287ULL,
      14594890931430968898ULL, 7291784239689209784ULL, 5514718540551361949ULL,
      10025733853830934803ULL, 7293794580341021693ULL, 6728552937464861756ULL,
      6332385040983343262ULL, 13277683694236792804ULL, 2600778905124452676ULL },
    { 7123075680859040534ULL, 1034205548717903090ULL, 7717824418247931797ULL,
      3019070937878604058ULL, 11403792746066867460ULL, 10280580802233112374ULL,
      337153209462421218ULL, 13333398568519923717ULL, 3596153696935337464ULL,
      8104208463525993784ULL, 14345062289456085693ULL, 17036731477169661256ULL },
};

static const uint64_t ARK2[7][12] = {
    { 6077062762357204287ULL, 15277620170502011191ULL, 5358738125714196705ULL,
      14233283787297595718ULL, 13792579614346651365ULL, 11614812331536767105ULL,
      14871063686742261166ULL, 10148237148793043499ULL, 4457428952329675767ULL,
      15590786458219172475ULL, 10063319113072092615ULL, 14200078843431360086ULL },
    { 6202948458916099932ULL, 17690140365333231091ULL, 3595001575307484651ULL,
      373995945117666487ULL, 1235734395091296013ULL, 14172757457833931602ULL,
      707573103686350224ULL, 15453217512188187135ULL, 219777875004506018ULL,
      17876696346199469008ULL, 17731621626449383378ULL, 2897136237748376248ULL },
    { 8023374565629191455ULL, 15013690343205953430ULL, 4485500052507912973ULL,
      12489737547229155153ULL, 9500452585969030576ULL, 2054001340201038870ULL,
      12420704059284934186ULL, 355990932618543755ULL, 9071225051243523860ULL,
      12766199826003448536ULL, 9045979173463556963ULL, 12934431667190679898ULL },
    { 18389244934624494276ULL, 16731736864863925227ULL, 4440209734760478192ULL,
      17208448209698888938ULL, 8739495587021565984ULL, 17000774922218161967ULL,
      13533282547195532087ULL, 525402848358706231ULL, 16987541523062161972ULL,
      5466806524462797102ULL, 14512769585918244983ULL, 10973956031244051118ULL },
    { 6982293561042362913ULL, 14065426295947720331ULL, 16451845770444974180ULL,
      7139138592091306727ULL, 9012006439959783127ULL, 14619614108529063361ULL,
      1394813199588124371ULL, 4635111139507788575ULL, 16217473952264203365ULL,
      10782018226466330683ULL, 6844229992533662050ULL, 7446486531695178711ULL },
    { 3736792340494631448ULL, 577852220195055341ULL, 6689998335515779805ULL,
      13886063479078013492ULL, 14358505101923202168ULL, 7744142531772274164ULL,
      16135070735728404443ULL, 12290902521256031137ULL, 12059913662657709804ULL,
      16456018495793751911ULL, 4571485474751953524ULL, 17200392109565783176ULL },
    { 17130398059294018733ULL, 519782857322261988ULL, 9625384390925085478ULL,
      1664893052631119222ULL, 7629576092524553570ULL, 3485239601103661425ULL,
      9755891797164033838ULL, 15218148195153269027ULL, 16460604813734957368ULL,
      9643968136937729763ULL, 3611348709641382851ULL, 18256379591337759196ULL },
};

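/* S-box: x -> x^7. Since 7 is coprime to P - 1, this is a permutation of the field. */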
static void apply_sbox(uint64_t *const state)
{
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        uint64_t t2 = mult_mod_p(*(state + i), *(state + i));
        uint64_t t4 = mult_mod_p(t2, t2);

        *(state + i) = mult_mod_p(*(state + i), mult_mod_p(t2, t4));
    }
}

static void apply_mds(uint64_t *state)
{
    uint64_t res[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        res[i] = 0;
    }
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        for (uint64_t j = 0; j < STATE_WIDTH; j++)
        {
            res[i] = add_mod_p(res[i], mult_mod_p(MDS[i][j], *(state + j)));
        }
    }

    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        *(state + i) = res[i];
    }
}

static void apply_constants(uint64_t *const state, const uint64_t *ark)
{
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        *(state + i) = add_mod_p(*(state + i), *(ark + i));
    }
}

static void exp_acc(const uint64_t m, const uint64_t *base, const uint64_t *tail, uint64_t *const res)
{
    for (uint64_t i = 0; i < m; i++)
    {
        for (uint64_t j = 0; j < STATE_WIDTH; j++)
        {
            if (i == 0)
            {
                *(res + j) = mult_mod_p(*(base + j), *(base + j));
            }
            else
            {
                *(res + j) = mult_mod_p(*(res + j), *(res + j));
            }
        }
    }

    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        *(res + i) = mult_mod_p(*(res + i), *(tail + i));
    }
}

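/*
 * Inverse S-box: x -> x^(1/7) = x^10540996611094048183, computed via an addition chain built
 * from exp_acc(m, base, tail, res), which sets res = base^(2^m) * tail elementwise.
 */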
static void apply_inv_sbox(uint64_t *const state)
{
    uint64_t t1[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t1[i] = 0;
    }
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t1[i] = mult_mod_p(*(state + i), *(state + i));
    }

    uint64_t t2[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t2[i] = 0;
    }
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t2[i] = mult_mod_p(t1[i], t1[i]);
    }

    uint64_t t3[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t3[i] = 0;
    }
    exp_acc(3, t2, t2, t3);

    uint64_t t4[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t4[i] = 0;
    }
    exp_acc(6, t3, t3, t4);

    uint64_t tmp[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        tmp[i] = 0;
    }
    exp_acc(12, t4, t4, tmp);

    uint64_t t5[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t5[i] = 0;
    }
    exp_acc(6, tmp, t3, t5);

    uint64_t t6[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t6[i] = 0;
    }
    exp_acc(31, t5, t5, t6);

    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        uint64_t a = mult_mod_p(mult_mod_p(t6[i], t6[i]), t5[i]);
        a = mult_mod_p(a, a);
        a = mult_mod_p(a, a);
        uint64_t b = mult_mod_p(mult_mod_p(t1[i], t2[i]), *(state + i));

        *(state + i) = mult_mod_p(a, b);
    }
}

static void apply_round(uint64_t *const state, const uint64_t round)
{
    apply_mds(state);
    apply_constants(state, ARK1[round]);
    apply_sbox(state);

    apply_mds(state);
    apply_constants(state, ARK2[round]);
    apply_inv_sbox(state);
}

static void apply_permutation(uint64_t *state)
{
    for (uint64_t i = 0; i < NUM_ROUNDS; i++)
    {
        apply_round(state, i);
    }
}

/* ================================================================================================
 * RPO128 implementation. It replaces SHAKE256 in the hash-to-point algorithm.
 */

#include "rpo.h"

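/*
 * dptr starts at 32: bytes 0..31 of the buffer (state elements 0..3) hold the capacity,
 * while the rate occupies bytes 32..95, so absorbing and squeezing never touch the capacity.
 */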
void rpo128_init(rpo128_context *rc)
{
    rc->dptr = 32;

    memset(rc->st.A, 0, sizeof rc->st.A);
}

void rpo128_absorb(rpo128_context *rc, const uint8_t *in, size_t len)
{
    size_t dptr;

    dptr = (size_t)rc->dptr;
    while (len > 0)
    {
        size_t clen, u;

        /* 136 * 8 = 1088 bits for the rate portion in the case of SHAKE256.
         * For RPO, this is 64 * 8 = 512 bits.
         * The capacity for SHAKE256 is at the end while for RPO128 it is at the beginning.
         */
        clen = 96 - dptr;
        if (clen > len)
        {
            clen = len;
        }

        for (u = 0; u < clen; u++)
        {
            rc->st.dbuf[dptr + u] = in[u];
        }

        dptr += clen;
        in += clen;
        len -= clen;
        if (dptr == 96)
        {
            apply_permutation(rc->st.A);
            dptr = 32;
        }
    }
    rc->dptr = dptr;
}

void rpo128_finalize(rpo128_context *rc)
{
    // Set dptr to the end of the buffer, so that the first call to squeeze will apply the
    // permutation.
    rc->dptr = 96;
}

void rpo128_squeeze(rpo128_context *rc, uint8_t *out, size_t len)
{
    size_t dptr;

    dptr = (size_t)rc->dptr;
    while (len > 0)
    {
        size_t clen;

        if (dptr == 96)
        {
            apply_permutation(rc->st.A);
            dptr = 32;
        }
        clen = 96 - dptr;
        if (clen > len)
        {
            clen = len;
        }
        len -= clen;

        memcpy(out, rc->st.dbuf + dptr, clen);
        dptr += clen;
        out += clen;
    }
    rc->dptr = dptr;
}

void rpo128_release(rpo128_context *rc)
{
    memset(rc->st.A, 0, sizeof rc->st.A);
    rc->dptr = 32;
}

/* ================================================================================================
 * Hash-to-Point algorithm implementation based on RPO128
 */

void PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(rpo128_context *rc, uint16_t *x, unsigned logn)
{
    /*
     * This implementation avoids the rejection sampling step needed in the
     * per-the-spec implementation. It uses a remark in https://falcon-sign.info/falcon.pdf
     * page 31, which argues that the current variant is secure for the parameters set by NIST.
     * Avoiding the rejection-sampling step leads to an implementation that is constant-time.
     * TODO: Check that the current implementation is indeed constant-time.
     */
    size_t n;

    n = (size_t)1 << logn;
    while (n > 0)
    {
        uint8_t buf[8];
        uint64_t w;

        rpo128_squeeze(rc, (void *)buf, sizeof buf);
        w = ((uint64_t)(buf[7]) << 56) |
            ((uint64_t)(buf[6]) << 48) |
            ((uint64_t)(buf[5]) << 40) |
            ((uint64_t)(buf[4]) << 32) |
            ((uint64_t)(buf[3]) << 24) |
            ((uint64_t)(buf[2]) << 16) |
            ((uint64_t)(buf[1]) << 8) |
            ((uint64_t)(buf[0]));

        w %= M;

        *x++ = (uint16_t)w;
        n--;
    }
}

src/dsa/rpo_falcon512/falcon_c/rpo.h (new file, 83 lines)
@@ -0,0 +1,83 @@
#include <stdint.h>
#include <string.h>

/* ================================================================================================
 * RPO hashing algorithm related structs and methods.
 */

/*
 * RPO128 context.
 *
 * This structure is used by the hashing API. It is composed of an internal state that can be
 * viewed as either:
 * 1. 12 field elements in the Miden VM.
 * 2. 96 bytes.
 *
 * The first view is used for the internal state in the context of the RPO hashing algorithm. The
 * second view is used as the buffer into which the data to be hashed is absorbed.
 *
 * The pointer into the buffer is updated as the data is absorbed.
 *
 * 'rpo128_context' must be initialized with rpo128_init() before first use.
 */
typedef struct
{
    union
    {
        uint64_t A[12];
        uint8_t dbuf[96];
    } st;
    uint64_t dptr;
} rpo128_context;

/*
 * Initializes an RPO state.
 */
void rpo128_init(rpo128_context *rc);

/*
 * Absorbs an array of bytes of length 'len' into the state.
 */
void rpo128_absorb(rpo128_context *rc, const uint8_t *in, size_t len);

/*
 * Squeezes an array of bytes of length 'len' from the state.
 */
void rpo128_squeeze(rpo128_context *rc, uint8_t *out, size_t len);

/*
 * Finalizes the state in preparation for squeezing.
 *
 * This function should be called after all the data has been absorbed.
 *
 * Note that the current implementation does not perform any sort of padding for domain separation
 * purposes. The reason is that, for our purposes, we always perform the following sequence:
 * 1. Absorb a nonce (which is always 40 bytes packed as 8 field elements).
 * 2. Absorb the message (which is always 4 field elements).
 * 3. Call finalize.
 * 4. Squeeze the output.
 * 5. Call release.
 */
void rpo128_finalize(rpo128_context *rc);

/*
 * Releases the state.
 *
 * This function should be called after the squeeze operation is finished.
 */
void rpo128_release(rpo128_context *rc);

/* ================================================================================================
 * Hash-to-Point algorithm for signature generation and signature verification.
 */

/*
 * Hash-to-Point algorithm.
 *
 * This function generates a point in Z_q[x]/(phi) from a given message.
 *
 * It takes a finalized rpo128_context as input and generates the coefficients of the polynomial
 * representing the point. The coefficients are stored in the array 'x'. The number of
 * coefficients is 2^logn, which in our case must be 512 (i.e., logn = 9).
 */
void PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(rpo128_context *rc, uint16_t *x, unsigned logn);

src/dsa/rpo_falcon512/ffi.rs (new file, 189 lines)
@@ -0,0 +1,189 @@
use libc::c_int;

// C IMPLEMENTATION INTERFACE
// ================================================================================================

#[link(name = "rpo_falcon512", kind = "static")]
extern "C" {
    /// Generate a new key pair. Public key goes into pk[], private key in sk[].
    /// Key sizes are exact (in bytes):
    /// - public (pk): 897
    /// - private (sk): 1281
    ///
    /// Return value: 0 on success, -1 on error.
    pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(pk: *mut u8, sk: *mut u8) -> c_int;

    /// Generate a new key pair from a seed. Public key goes into pk[], private key in sk[].
    /// Key sizes are exact (in bytes):
    /// - public (pk): 897
    /// - private (sk): 1281
    ///
    /// Return value: 0 on success, -1 on error.
    pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
        pk: *mut u8,
        sk: *mut u8,
        seed: *const u8,
    ) -> c_int;

    /// Compute a signature on a provided message (m, mlen), with a given private key (sk).
    /// Signature is written in sig[], with length written into *siglen. Signature length is
    /// variable; maximum signature length (in bytes) is 666.
    ///
    /// sig[], m[] and sk[] may overlap each other arbitrarily.
    ///
    /// Return value: 0 on success, -1 on error.
    pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
        sig: *mut u8,
        siglen: *mut usize,
        m: *const u8,
        mlen: usize,
        sk: *const u8,
    ) -> c_int;

    // TEST HELPERS
    // --------------------------------------------------------------------------------------------

    /// Verify a signature (sig, siglen) on a message (m, mlen) with a given public key (pk).
    ///
    /// sig[], m[] and pk[] may overlap each other arbitrarily.
    ///
    /// Return value: 0 on success, -1 on error.
    #[cfg(test)]
    pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
        sig: *const u8,
        siglen: usize,
        m: *const u8,
        mlen: usize,
        pk: *const u8,
    ) -> c_int;

    /// Hash-to-Point algorithm.
    ///
    /// This function generates a point in Z_q[x]/(phi) from a given message.
    ///
    /// It takes a finalized rpo128_context as input and generates the coefficients of the
    /// polynomial representing the point. The coefficients are stored in the array 'x'. The
    /// number of coefficients is 2^logn, which in our case must be 512 (i.e., logn = 9).
    #[cfg(test)]
    pub fn PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(
        rc: *mut Rpo128Context,
        x: *mut u16,
        // the C declaration uses `unsigned logn`, so c_uint matches the ABI
        logn: libc::c_uint,
    );

    #[cfg(test)]
    pub fn rpo128_init(sc: *mut Rpo128Context);

    #[cfg(test)]
    pub fn rpo128_absorb(
        sc: *mut Rpo128Context,
        data: *const ::std::os::raw::c_void,
        len: libc::size_t,
    );

    #[cfg(test)]
    pub fn rpo128_finalize(sc: *mut Rpo128Context);
}

#[repr(C)]
#[cfg(test)]
pub struct Rpo128Context {
    pub content: [u64; 13usize],
}

// TESTS
// ================================================================================================

#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::dsa::rpo_falcon512::{NONCE_LEN, PK_LEN, SIG_LEN, SK_LEN};
    use rand_utils::{rand_array, rand_value, rand_vector};

    #[test]
    fn falcon_ffi() {
        unsafe {
            // --- generate a key pair from a seed ----------------------------

            let mut pk = [0u8; PK_LEN];
            let mut sk = [0u8; SK_LEN];
            let seed: [u8; NONCE_LEN] = rand_array();

            assert_eq!(
                0,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
                    pk.as_mut_ptr(),
                    sk.as_mut_ptr(),
                    seed.as_ptr()
                )
            );

            // --- sign a message and make sure it verifies -------------------

            let mlen: usize = rand_value::<u16>() as usize;
            let msg: Vec<u8> = rand_vector(mlen);
            let mut detached_sig = [0u8; NONCE_LEN + SIG_LEN];
            let mut siglen = 0;

            assert_eq!(
                0,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
                    detached_sig.as_mut_ptr(),
                    &mut siglen as *mut usize,
                    msg.as_ptr(),
                    msg.len(),
                    sk.as_ptr()
                )
            );

            assert_eq!(
                0,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
                    detached_sig.as_ptr(),
                    siglen,
                    msg.as_ptr(),
                    msg.len(),
                    pk.as_ptr()
                )
            );

            // --- check verification of different signature ------------------

            assert_eq!(
                -1,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
                    detached_sig.as_ptr(),
                    siglen,
                    msg.as_ptr(),
                    msg.len() - 1,
                    pk.as_ptr()
                )
            );

            // --- check verification against a different pub key -------------

            let mut pk_alt = [0u8; PK_LEN];
            let mut sk_alt = [0u8; SK_LEN];
            assert_eq!(
                0,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
                    pk_alt.as_mut_ptr(),
                    sk_alt.as_mut_ptr()
                )
            );

            assert_eq!(
                -1,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
                    detached_sig.as_ptr(),
                    siglen,
                    msg.as_ptr(),
                    msg.len(),
                    pk_alt.as_ptr()
                )
            );
        }
    }
}

src/dsa/rpo_falcon512/keys.rs (new file, 232 lines)
@@ -0,0 +1,232 @@
use super::{
    ByteReader, ByteWriter, Deserializable, DeserializationError, FalconError, Polynomial,
    PublicKeyBytes, Rpo256, SecretKeyBytes, Serializable, Signature, Word,
};

#[cfg(feature = "std")]
use super::{ffi, NonceBytes, StarkField, NONCE_LEN, PK_LEN, SIG_LEN, SK_LEN};

// PUBLIC KEY
// ================================================================================================

/// A public key for verifying signatures.
///
/// The public key is a [Word] (i.e., 4 field elements) that is the hash of the coefficients of
/// the polynomial representing the raw bytes of the expanded public key.
///
/// For Falcon-512, the first byte of the expanded public key is always equal to log2(512), i.e., 9.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PublicKey(Word);

impl PublicKey {
    /// Returns a new [PublicKey] which is a commitment to the provided expanded public key.
    ///
    /// # Errors
    /// Returns an error if the decoding of the public key fails.
    pub fn new(pk: PublicKeyBytes) -> Result<Self, FalconError> {
        let h = Polynomial::from_pub_key(&pk)?;
        let pk_felts = h.to_elements();
        let pk_digest = Rpo256::hash_elements(&pk_felts).into();
        Ok(Self(pk_digest))
    }

    /// Verifies the provided signature against the provided message and this public key.
    pub fn verify(&self, message: Word, signature: &Signature) -> bool {
        signature.verify(message, self.0)
    }
}

impl From<PublicKey> for Word {
    fn from(key: PublicKey) -> Self {
        key.0
    }
}

// KEY PAIR
// ================================================================================================

/// A key pair (public and secret keys) for signing messages.
///
/// The secret key is a byte array of length [SK_LEN].
/// The public key is a byte array of length [PK_LEN].
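///
/// Basic usage (a sketch using the APIs in this module; requires the `std` feature):
/// ```ignore
/// let keys = KeyPair::new()?;              // generate a key pair from OS randomness
/// let signature = keys.sign(message)?;     // `message` is a Word, i.e., [Felt; 4]
/// assert!(keys.public_key().verify(message, &signature));
/// ```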
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct KeyPair {
    public_key: PublicKeyBytes,
    secret_key: SecretKeyBytes,
}

#[allow(clippy::new_without_default)]
impl KeyPair {
    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Generates a (public_key, secret_key) key pair from OS-provided randomness.
    ///
    /// # Errors
    /// Returns an error if key generation fails.
    #[cfg(feature = "std")]
    pub fn new() -> Result<Self, FalconError> {
        let mut public_key = [0u8; PK_LEN];
        let mut secret_key = [0u8; SK_LEN];

        let res = unsafe {
            ffi::PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
                public_key.as_mut_ptr(),
                secret_key.as_mut_ptr(),
            )
        };

        if res == 0 {
            Ok(Self { public_key, secret_key })
        } else {
            Err(FalconError::KeyGenerationFailed)
        }
    }

    /// Generates a (public_key, secret_key) key pair from the provided seed.
    ///
    /// # Errors
    /// Returns an error if key generation fails.
    #[cfg(feature = "std")]
    pub fn from_seed(seed: &NonceBytes) -> Result<Self, FalconError> {
        let mut public_key = [0u8; PK_LEN];
        let mut secret_key = [0u8; SK_LEN];

        let res = unsafe {
            ffi::PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
                public_key.as_mut_ptr(),
                secret_key.as_mut_ptr(),
                seed.as_ptr(),
            )
        };

        if res == 0 {
            Ok(Self { public_key, secret_key })
        } else {
            Err(FalconError::KeyGenerationFailed)
        }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the public key corresponding to this key pair.
    pub fn public_key(&self) -> PublicKey {
        // TODO: memoize public key commitment as computing it requires quite a bit of hashing.
        // expect() is fine here because we assume that the key pair was constructed correctly.
        PublicKey::new(self.public_key).expect("invalid key pair")
    }

    /// Returns the expanded public key corresponding to this key pair.
    pub fn expanded_public_key(&self) -> PublicKeyBytes {
        self.public_key
    }

    // SIGNATURE GENERATION
    // --------------------------------------------------------------------------------------------

    /// Signs a message with the secret key of this key pair.
    ///
    /// # Errors
    /// Returns an error if signature generation fails.
    #[cfg(feature = "std")]
    pub fn sign(&self, message: Word) -> Result<Signature, FalconError> {
        let msg = message.iter().flat_map(|e| e.as_int().to_le_bytes()).collect::<Vec<_>>();
        let msg_len = msg.len();
        let mut sig = [0_u8; SIG_LEN + NONCE_LEN];
        let mut sig_len: usize = 0;

        let res = unsafe {
            ffi::PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
                sig.as_mut_ptr(),
                &mut sig_len as *mut usize,
                msg.as_ptr(),
                msg_len,
                self.secret_key.as_ptr(),
            )
        };

        if res == 0 {
            Ok(Signature {
                sig,
                pk: self.public_key,
                pk_polynomial: Default::default(),
                sig_polynomial: Default::default(),
            })
        } else {
            Err(FalconError::SigGenerationFailed)
        }
    }
}

// SERIALIZATION / DESERIALIZATION
// ================================================================================================

impl Serializable for KeyPair {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_bytes(&self.public_key);
        target.write_bytes(&self.secret_key);
    }
}

impl Deserializable for KeyPair {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let public_key: PublicKeyBytes = source.read_array()?;
        let secret_key: SecretKeyBytes = source.read_array()?;
        Ok(Self { public_key, secret_key })
    }
}

// TESTS
// ================================================================================================

#[cfg(all(test, feature = "std"))]
mod tests {
    use super::{super::Felt, KeyPair, NonceBytes, Word};
    use rand_utils::{rand_array, rand_vector};

    #[test]
    fn test_falcon_verification() {
        // generate random keys
        let keys = KeyPair::new().unwrap();
        let pk = keys.public_key();

        // sign a random message
        let message: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
        let signature = keys.sign(message);

        // make sure the signature verifies correctly
        assert!(pk.verify(message, signature.as_ref().unwrap()));

        // a signature should not verify against a wrong message
        let message2: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
        assert!(!pk.verify(message2, signature.as_ref().unwrap()));

        // a signature should not verify against a wrong public key
        let keys2 = KeyPair::new().unwrap();
        assert!(!keys2.public_key().verify(message, signature.as_ref().unwrap()))
    }

    #[test]
    fn test_falcon_verification_from_seed() {
        // generate keys from a random seed
        let seed: NonceBytes = rand_array();
        let keys = KeyPair::from_seed(&seed).unwrap();
        let pk = keys.public_key();

        // sign a random message
        let message: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
        let signature = keys.sign(message);

        // make sure the signature verifies correctly
        assert!(pk.verify(message, signature.as_ref().unwrap()));

        // a signature should not verify against a wrong message
        let message2: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
        assert!(!pk.verify(message2, signature.as_ref().unwrap()));

        // a signature should not verify against a wrong public key
        let keys2 = KeyPair::new().unwrap();
        assert!(!keys2.public_key().verify(message, signature.as_ref().unwrap()))
    }
}

src/dsa/rpo_falcon512/mod.rs (new file, 60 lines)
@@ -0,0 +1,60 @@
use crate::{
    hash::rpo::Rpo256,
    utils::{
        collections::Vec, ByteReader, ByteWriter, Deserializable, DeserializationError,
        Serializable,
    },
    Felt, StarkField, Word, ZERO,
};

#[cfg(feature = "std")]
mod ffi;

mod error;
mod keys;
mod polynomial;
mod signature;

pub use error::FalconError;
pub use keys::{KeyPair, PublicKey};
pub use polynomial::Polynomial;
pub use signature::Signature;

// CONSTANTS
// ================================================================================================

// The Falcon modulus.
const MODULUS: u16 = 12289;
const MODULUS_MINUS_1_OVER_TWO: u16 = 6144;

// The Falcon-512 parameter N. This is the degree of the polynomial `phi := x^N + 1` defining the
// ring Z_p[x]/(phi).
const N: usize = 512;
const LOG_N: usize = 9;

/// Length of the nonce (in bytes); also used as the seed length for key-pair generation.
const NONCE_LEN: usize = 40;

/// Number of field elements used to encode a nonce.
const NONCE_ELEMENTS: usize = 8;

/// Public key length as a u8 vector.
const PK_LEN: usize = 897;

/// Secret key length as a u8 vector.
const SK_LEN: usize = 1281;

/// Signature length as a u8 vector.
const SIG_LEN: usize = 626;

/// Bound on the squared-norm of the signature.
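/// For Falcon-512, this is the constant floor(beta^2) from the Falcon specification.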
const SIG_L2_BOUND: u64 = 34034726;

// TYPE ALIASES
// ================================================================================================

type SignatureBytes = [u8; NONCE_LEN + SIG_LEN];
type PublicKeyBytes = [u8; PK_LEN];
type SecretKeyBytes = [u8; SK_LEN];
type NonceBytes = [u8; NONCE_LEN];
type NonceElements = [Felt; NONCE_ELEMENTS];

src/dsa/rpo_falcon512/polynomial.rs (new file, 277 lines)
@@ -0,0 +1,277 @@
use super::{FalconError, Felt, Vec, LOG_N, MODULUS, MODULUS_MINUS_1_OVER_TWO, N, PK_LEN};
use core::ops::{Add, Mul, Sub};

// FALCON POLYNOMIAL
// ================================================================================================

/// A polynomial over Z_p[x]/(phi) where phi := x^512 + 1.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Polynomial([u16; N]);

impl Polynomial {
    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Constructs a new polynomial from a list of coefficients.
    ///
    /// # Safety
    /// This constructor validates that the coefficients are in the valid range only in debug mode.
    pub unsafe fn new(data: [u16; N]) -> Self {
        for value in data {
            debug_assert!(value < MODULUS);
        }

        Self(data)
    }

    /// Decodes raw bytes representing a public key into a polynomial in Z_p[x]/(phi).
    ///
    /// # Errors
    /// Returns an error if:
    /// - The provided input is not exactly 897 bytes long.
    /// - The first byte of the input is not equal to log2(512), i.e., 9.
    /// - Any of the coefficients encoded in the provided input is greater than or equal to the
    ///   Falcon field modulus.
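    ///
    /// Each coefficient is packed in 14 bits, so the payload is 512 * 14 / 8 = 896 bytes, which
    /// together with the leading header byte gives PK_LEN = 897.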
    pub fn from_pub_key(input: &[u8]) -> Result<Self, FalconError> {
        if input.len() != PK_LEN {
            return Err(FalconError::PubKeyDecodingInvalidLength(input.len()));
        }

        if input[0] != LOG_N as u8 {
            return Err(FalconError::PubKeyDecodingInvalidTag(input[0]));
        }

        let mut acc = 0_u32;
        let mut acc_len = 0;

        let mut output = [0_u16; N];
        let mut output_idx = 0;

        for &byte in input.iter().skip(1) {
            acc = (acc << 8) | (byte as u32);
            acc_len += 8;

            if acc_len >= 14 {
                acc_len -= 14;
                let w = (acc >> acc_len) & 0x3FFF;
                if w >= MODULUS as u32 {
                    return Err(FalconError::PubKeyDecodingInvalidCoefficient(w));
                }
                output[output_idx] = w as u16;
                output_idx += 1;
            }
        }

        if (acc & ((1u32 << acc_len) - 1)) == 0 {
            Ok(Self(output))
        } else {
            Err(FalconError::PubKeyDecodingExtraData)
        }
    }

    /// Decodes the signature into the coefficients of a polynomial in Z_p[x]/(phi). It assumes
    /// that the signature has been encoded using the reference compressed encoding algorithm.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The signature has been encoded using a different algorithm than the reference compressed
    ///   encoding algorithm.
    /// - The encoded signature polynomial is in Z_p[x]/(phi') where phi' = x^N' + 1 and N' != 512.
    /// - While decoding the high bits of a coefficient, the current accumulated value of its
    ///   high bits is larger than 2048.
    /// - The decoded coefficient is -0.
    /// - The remaining unused bits in the last byte of `input` are non-zero.
    pub fn from_signature(input: &[u8]) -> Result<Self, FalconError> {
        let (encoding, log_n) = (input[0] >> 4, input[0] & 0b00001111);

        if encoding != 0b0011 {
            return Err(FalconError::SigDecodingIncorrectEncodingAlgorithm);
        }
        if log_n != 0b1001 {
            return Err(FalconError::SigDecodingNotSupportedDegree(log_n));
        }

        let input = &input[41..];
        let mut input_idx = 0;
        let mut acc = 0u32;
        let mut acc_len = 0;
        let mut output = [0_u16; N];

        for e in output.iter_mut() {
            acc = (acc << 8) | (input[input_idx] as u32);
            input_idx += 1;
            let b = acc >> acc_len;
            let s = b & 128;
            let mut m = b & 127;

            loop {
                if acc_len == 0 {
                    acc = (acc << 8) | (input[input_idx] as u32);
                    input_idx += 1;
                    acc_len = 8;
                }
                acc_len -= 1;
                if ((acc >> acc_len) & 1) != 0 {
                    break;
                }
                m += 128;
                if m >= 2048 {
                    return Err(FalconError::SigDecodingTooBigHighBits(m));
                }
            }
            if s != 0 && m == 0 {
                return Err(FalconError::SigDecodingMinusZero);
            }

            *e = if s != 0 { (MODULUS as u32 - m) as u16 } else { m as u16 };
        }

        if (acc & ((1 << acc_len) - 1)) != 0 {
            return Err(FalconError::SigDecodingNonZeroUnusedBitsLastByte);
        }

        Ok(Self(output))
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the coefficients of this polynomial as integers.
    pub fn inner(&self) -> [u16; N] {
        self.0
    }

    /// Returns the coefficients of this polynomial as field elements.
    pub fn to_elements(&self) -> Vec<Felt> {
        self.0.iter().map(|&a| Felt::from(a)).collect()
    }

    // POLYNOMIAL OPERATIONS
    // --------------------------------------------------------------------------------------------

    /// Multiplies two polynomials over Z_p[x] without reducing modulo p. Given that the degrees
    /// of the input polynomials are less than 512 and their coefficients are less than the modulus
    /// q equal to 12289, the resulting product polynomial is guaranteed to have coefficients less
    /// than the Miden prime.
    ///
    /// Note that this multiplication is not over Z_p[x]/(phi).
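    /// (Each partial product is at most (q - 1)^2 and at most 512 of them are accumulated per
    /// coefficient, so every entry is below 512 * q^2 < 2^37, far below the Miden prime.)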
    pub fn mul_modulo_p(a: &Self, b: &Self) -> [u64; 1024] {
        let mut c = [0; 2 * N];
        for i in 0..N {
            for j in 0..N {
                c[i + j] += a.0[i] as u64 * b.0[j] as u64;
            }
        }

        c
    }

    /// Reduces a polynomial, that is the product of two polynomials over Z_p[x], modulo
    /// the irreducible polynomial phi. This results in an element in Z_p[x]/(phi).
    pub fn reduce_negacyclic(a: &[u64; 1024]) -> Self {
        let mut c = [0; N];
        for i in 0..N {
            let ai = a[N + i] % MODULUS as u64;
            let neg_ai = (MODULUS - ai as u16) % MODULUS;

            let bi = (a[i] % MODULUS as u64) as u16;
            c[i] = (neg_ai + bi) % MODULUS;
        }

        Self(c)
    }

    /// Computes the norm squared of a polynomial in Z_p[x]/(phi) after normalizing its
    /// coefficients to be in the interval (-p/2, p/2].
    pub fn sq_norm(&self) -> u64 {
        let mut res = 0;
        for e in self.0 {
            if e > MODULUS_MINUS_1_OVER_TWO {
                res += (MODULUS - e) as u64 * (MODULUS - e) as u64
            } else {
                res += e as u64 * e as u64
            }
        }
        res
    }
}

// The default element is the zero polynomial.
impl Default for Polynomial {
    fn default() -> Self {
        Self([0_u16; N])
    }
}

/// Multiplication over Z_p[x]/(phi)
impl Mul for Polynomial {
    type Output = Self;

    fn mul(self, other: Self) -> <Self as Mul<Self>>::Output {
        let mut result = [0_u16; N];
        for j in 0..N {
            for k in 0..N {
                let i = (j + k) % N;
                let a = self.0[j] as usize;
                let b = other.0[k] as usize;
                let q = MODULUS as usize;
                let mut prod = a * b % q;
                if (N - 1) < (j + k) {
                    prod = (q - prod) % q;
                }
                result[i] = ((result[i] as usize + prod) % q) as u16;
            }
        }

        Polynomial(result)
    }
}

/// Addition over Z_p[x]/(phi)
impl Add for Polynomial {
    type Output = Self;

    fn add(self, other: Self) -> <Self as Add<Self>>::Output {
        let mut res = self;
        res.0.iter_mut().zip(other.0.iter()).for_each(|(x, y)| *x = (*x + *y) % MODULUS);

        res
    }
}

/// Subtraction over Z_p[x]/(phi)
impl Sub for Polynomial {
    type Output = Self;

    fn sub(self, other: Self) -> <Self as Sub<Self>>::Output {
        let mut res = self;
        res.0
            .iter_mut()
            .zip(other.0.iter())
            .for_each(|(x, y)| *x = (*x + MODULUS - *y) % MODULUS);

        res
    }
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::{Polynomial, N};

    #[test]
    fn test_negacyclic_reduction() {
        let coef1: [u16; N] = rand_utils::rand_array();
        let coef2: [u16; N] = rand_utils::rand_array();

        let poly1 = Polynomial(coef1);
        let poly2 = Polynomial(coef2);

        assert_eq!(
            poly1 * poly2,
            Polynomial::reduce_negacyclic(&Polynomial::mul_modulo_p(&poly1, &poly2))
        );
    }
}

src/dsa/rpo_falcon512/signature.rs (new file, 271 lines)
@@ -0,0 +1,271 @@
use super::{
    ByteReader, ByteWriter, Deserializable, DeserializationError, NonceBytes, NonceElements,
    Polynomial, PublicKeyBytes, Rpo256, Serializable, SignatureBytes, StarkField, Word, MODULUS, N,
    SIG_L2_BOUND, ZERO,
};
use crate::utils::string::ToString;
use core::cell::OnceCell;

// FALCON SIGNATURE
// ================================================================================================

/// An RPO Falcon512 signature over a message.
///
/// The signature is a pair of polynomials (s1, s2) in (Z_p[x]/(phi))^2, where:
/// - p := 12289
/// - phi := x^512 + 1
/// - s1 = c - s2 * h
/// - h is a polynomial representing the public key and c is a polynomial that is the hash-to-point
///   of the message being signed.
///
/// The signature verifies if and only if:
/// 1. s1 = c - s2 * h
/// 2. |s1|^2 + |s2|^2 <= SIG_L2_BOUND
///
/// where |.| is the norm.
///
/// [Signature] also includes the extended public key which is serialized as:
/// 1. 1 byte representing log2(512), i.e., 9.
/// 2. 896 bytes for the public key. This is decoded into the `h` polynomial above.
///
/// The actual signature is serialized as:
/// 1. A header byte specifying the algorithm used to encode the coefficients of the `s2`
///    polynomial together with the degree of the irreducible polynomial phi.
///    The general format of this byte is 0b0cc1nnnn where:
///    a. cc is either 01 when the compressed encoding algorithm is used and 10 when the
///       uncompressed algorithm is used.
///    b. nnnn is log2(N) where N is the degree of the irreducible polynomial phi.
///    The current implementation always works with cc equal to 0b01 and nnnn equal to 0b1001,
///    and thus the header byte is always equal to 0b00111001.
/// 2. 40 bytes for the nonce.
/// 3. 625 bytes encoding the `s2` polynomial above.
///
/// The total size of the signature (including the extended public key) is 1563 bytes.
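/// That is, 897 bytes for the extended public key plus 666 bytes for the actual signature
/// (1 header byte + 40 nonce bytes + 625 bytes encoding `s2`).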
pub struct Signature {
    pub(super) pk: PublicKeyBytes,
    pub(super) sig: SignatureBytes,

    // Cached polynomial decoding for public key and signatures
    pub(super) pk_polynomial: OnceCell<Polynomial>,
    pub(super) sig_polynomial: OnceCell<Polynomial>,
}

impl Signature {
    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the public key polynomial h.
    pub fn pub_key_poly(&self) -> Polynomial {
        *self.pk_polynomial.get_or_init(|| {
            // we assume that the signature was constructed with a valid public key, and thus
            // expect() is OK here.
            Polynomial::from_pub_key(&self.pk).expect("invalid public key")
        })
    }

    /// Returns the nonce component of the signature represented as field elements.
    ///
    /// Nonce bytes are converted to field elements by taking consecutive 5-byte chunks
    /// of the nonce and interpreting them as field elements.
    pub fn nonce(&self) -> NonceElements {
        // we assume that the signature was constructed from valid signature bytes, and thus
        // expect() is OK here.
        let nonce = self.sig[1..41].try_into().expect("invalid signature");
        decode_nonce(nonce)
    }

    /// Returns the polynomial representation of the signature in Z_p[x]/(phi).
    pub fn sig_poly(&self) -> Polynomial {
        *self.sig_polynomial.get_or_init(|| {
            // we assume that the signature was constructed from valid signature bytes, and thus
            // expect() is OK here.
            Polynomial::from_signature(&self.sig).expect("invalid signature")
        })
    }

    // HASH-TO-POINT
    // --------------------------------------------------------------------------------------------

    /// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message.
    pub fn hash_to_point(&self, message: Word) -> Polynomial {
        hash_to_point(message, &self.nonce())
    }

    // SIGNATURE VERIFICATION
    // --------------------------------------------------------------------------------------------

    /// Returns true if this signature is a valid signature for the specified message, generated
    /// by the key pair matching the specified public key commitment.
    pub fn verify(&self, message: Word, pubkey_com: Word) -> bool {
        // Make sure the expanded public key matches the provided public key commitment
        let h = self.pub_key_poly();
        let h_digest: Word = Rpo256::hash_elements(&h.to_elements()).into();
        if h_digest != pubkey_com {
            return false;
        }

        // Make sure the signature is valid
        let s2 = self.sig_poly();
        let c = self.hash_to_point(message);

        let s1 = c - s2 * h;

        let sq_norm = s1.sq_norm() + s2.sq_norm();
        sq_norm <= SIG_L2_BOUND
    }
}

// SERIALIZATION / DESERIALIZATION
// ================================================================================================

impl Serializable for Signature {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_bytes(&self.pk);
        target.write_bytes(&self.sig);
    }
}

impl Deserializable for Signature {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let pk: PublicKeyBytes = source.read_array()?;
        let sig: SignatureBytes = source.read_array()?;

        // make sure public key and signature can be decoded correctly; from_signature() expects
        // the full signature bytes (header and nonce included), matching sig_poly() above.
        let pk_polynomial = Polynomial::from_pub_key(&pk)
            .map_err(|err| DeserializationError::InvalidValue(err.to_string()))?
            .into();
        let sig_polynomial = Polynomial::from_signature(&sig)
            .map_err(|err| DeserializationError::InvalidValue(err.to_string()))?
            .into();

        Ok(Self { pk, sig, pk_polynomial, sig_polynomial })
    }
}

// HELPER FUNCTIONS
// ================================================================================================

/// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message and
/// nonce.
fn hash_to_point(message: Word, nonce: &NonceElements) -> Polynomial {
    let mut state = [ZERO; Rpo256::STATE_WIDTH];

    // absorb the nonce into the state
    for (&n, s) in nonce.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) {
        *s = n;
    }
    Rpo256::apply_permutation(&mut state);

    // absorb message into the state
    for (&m, s) in message.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) {
        *s = m;
    }

    // squeeze the coefficients of the polynomial
    let mut i = 0;
    let mut res = [0_u16; N];
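    // 64 permutations x 8 rate elements per squeeze yield the 512 coefficients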
    for _ in 0..64 {
        Rpo256::apply_permutation(&mut state);
        for a in &state[Rpo256::RATE_RANGE] {
            res[i] = (a.as_int() % MODULUS as u64) as u16;
            i += 1;
        }
    }

    // using the raw constructor is OK here because we reduce all coefficients by the modulus above
    unsafe { Polynomial::new(res) }
}

/// Converts byte representation of the nonce into field element representation.
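/// Each 5-byte chunk is less than 2^40 and thus already a canonical field element.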
fn decode_nonce(nonce: &NonceBytes) -> NonceElements {
|
||||
let mut buffer = [0_u8; 8];
|
||||
let mut result = [ZERO; 8];
|
||||
for (i, bytes) in nonce.chunks(5).enumerate() {
|
||||
buffer[..5].copy_from_slice(bytes);
|
||||
result[i] = u64::from_le_bytes(buffer).into();
|
||||
}
|
||||
|
||||
result
|
||||
}

// TESTS
// ================================================================================================

#[cfg(all(test, feature = "std"))]
mod tests {
    use super::{
        super::{ffi::*, Felt},
        *,
    };
    use libc::c_void;
    use rand_utils::rand_vector;

    // Wrappers for unsafe functions
    impl Rpo128Context {
        /// Initializes the RPO state.
        pub fn init() -> Self {
            let mut ctx = Rpo128Context { content: [0u64; 13] };
            unsafe {
                rpo128_init(&mut ctx as *mut Rpo128Context);
            }
            ctx
        }

        /// Absorbs data into the RPO state.
        pub fn absorb(&mut self, data: &[u8]) {
            unsafe {
                rpo128_absorb(
                    self as *mut Rpo128Context,
                    data.as_ptr() as *const c_void,
                    data.len(),
                )
            }
        }

        /// Finalizes the RPO state to prepare for squeezing.
        pub fn finalize(&mut self) {
            unsafe { rpo128_finalize(self as *mut Rpo128Context) }
        }
    }

    #[test]
    fn test_hash_to_point() {
        // Create a random message and transform it into a u8 vector
        let msg_felts: Word = rand_vector::<Felt>(4).try_into().unwrap();
        let msg_bytes = msg_felts.iter().flat_map(|e| e.as_int().to_le_bytes()).collect::<Vec<_>>();

        // Create a nonce, i.e., a [u8; 40] array, and spread its 5-byte chunks over
        // the 8-byte words of a 64-byte buffer.
        let nonce: [u8; 40] = rand_vector::<u8>(40).try_into().unwrap();

        let mut buffer = [0_u8; 64];
        for i in 0..8 {
            buffer[8 * i] = nonce[5 * i];
            buffer[8 * i + 1] = nonce[5 * i + 1];
            buffer[8 * i + 2] = nonce[5 * i + 2];
            buffer[8 * i + 3] = nonce[5 * i + 3];
            buffer[8 * i + 4] = nonce[5 * i + 4];
        }

        // Initialize the RPO state
        let mut rng = Rpo128Context::init();

        // Absorb the nonce and message into the RPO state
        rng.absorb(&buffer);
        rng.absorb(&msg_bytes);
        rng.finalize();

        // Generate the coefficients of the hash-to-point polynomial.
        let mut res: [u16; N] = [0; N];

        unsafe {
            PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(
                &mut rng as *mut Rpo128Context,
                res.as_mut_ptr(),
                9,
            );
        }

        // Check that the coefficients are correct
        let nonce = decode_nonce(&nonce);
        assert_eq!(res, hash_to_point(msg_felts, &nonce).inner());
    }
}

src/gkr/circuit/mod.rs (new file, 977 lines)
@@ -0,0 +1,977 @@
use alloc::sync::Arc;

use winter_crypto::{ElementHasher, RandomCoin};
use winter_math::fields::f64::BaseElement;
use winter_math::FieldElement;

use crate::gkr::multivariate::{
    ComposedMultiLinearsOracle, EqPolynomial, GkrCompositionVanilla, MultiLinearOracle,
};
use crate::gkr::sumcheck::{sum_check_verify, Claim};

use super::multivariate::{
    gen_plain_gkr_oracle, gkr_composition_from_composition_polys, ComposedMultiLinears,
    CompositionPolynomial, MultiLinear,
};
use super::sumcheck::{
    sum_check_prove, sum_check_verify_and_reduce, FinalEvaluationClaim,
    PartialProof as SumcheckInstanceProof, RoundProof as SumCheckRoundProof, Witness,
};

/// Layered circuit for computing a sum of fractions.
///
/// The circuit computes a sum of fractions based on the formula
/// a / c + b / d = (a * d + b * c) / (c * d), which defines a "gate"
/// ((a, b), (c, d)) --> (a * d + b * c, c * d) upon which the `FractionalSumCircuit`
/// is built.
/// TODO: Swap 1 and 0
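/// For example (editorial illustration), the gate maps ((1, 1), (2, 3)) to
/// (1 * 3 + 1 * 2, 2 * 3) = (5, 6), i.e., it folds 1/2 + 1/3 into the single
/// fraction 5/6.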
#[derive(Debug)]
pub struct FractionalSumCircuit<E: FieldElement> {
    p_1_vec: Vec<MultiLinear<E>>,
    p_0_vec: Vec<MultiLinear<E>>,
    q_1_vec: Vec<MultiLinear<E>>,
    q_0_vec: Vec<MultiLinear<E>>,
}

impl<E: FieldElement> FractionalSumCircuit<E> {
    /// Computes the values of the gate outputs for each layer of the fractional sum circuit.
    pub fn new_(num_den: &[MultiLinear<E>]) -> Self {
        let mut p_1_vec: Vec<MultiLinear<E>> = Vec::new();
        let mut p_0_vec: Vec<MultiLinear<E>> = Vec::new();
        let mut q_1_vec: Vec<MultiLinear<E>> = Vec::new();
        let mut q_0_vec: Vec<MultiLinear<E>> = Vec::new();

        let num_layers = num_den[0].len().ilog2() as usize;

        p_1_vec.push(num_den[0].to_owned());
        p_0_vec.push(num_den[1].to_owned());
        q_1_vec.push(num_den[2].to_owned());
        q_0_vec.push(num_den[3].to_owned());

        for i in 0..num_layers {
            let (output_p_1, output_p_0, output_q_1, output_q_0) =
                FractionalSumCircuit::compute_layer(
                    &p_1_vec[i],
                    &p_0_vec[i],
                    &q_1_vec[i],
                    &q_0_vec[i],
                );
            p_1_vec.push(output_p_1);
            p_0_vec.push(output_p_0);
            q_1_vec.push(output_q_1);
            q_0_vec.push(output_q_0);
        }

        FractionalSumCircuit { p_1_vec, p_0_vec, q_1_vec, q_0_vec }
    }

    /// Computes the output values of a layer given its input values.
    fn compute_layer(
        inp_p_1: &MultiLinear<E>,
        inp_p_0: &MultiLinear<E>,
        inp_q_1: &MultiLinear<E>,
        inp_q_0: &MultiLinear<E>,
    ) -> (MultiLinear<E>, MultiLinear<E>, MultiLinear<E>, MultiLinear<E>) {
        let len = inp_q_1.len();
        let outp_p_1 = (0..len / 2)
            .map(|i| inp_p_1[i] * inp_q_0[i] + inp_p_0[i] * inp_q_1[i])
            .collect::<Vec<E>>();
        let outp_p_0 = (len / 2..len)
            .map(|i| inp_p_1[i] * inp_q_0[i] + inp_p_0[i] * inp_q_1[i])
            .collect::<Vec<E>>();
        let outp_q_1 = (0..len / 2).map(|i| inp_q_1[i] * inp_q_0[i]).collect::<Vec<E>>();
        let outp_q_0 = (len / 2..len).map(|i| inp_q_1[i] * inp_q_0[i]).collect::<Vec<E>>();

        (
            MultiLinear::new(outp_p_1),
            MultiLinear::new(outp_p_0),
            MultiLinear::new(outp_q_1),
            MultiLinear::new(outp_q_0),
        )
    }

    /// Computes the values of the gate outputs for each layer of the fractional sum circuit.
    pub fn new(poly: &MultiLinear<E>) -> Self {
        let mut p_1_vec: Vec<MultiLinear<E>> = Vec::new();
        let mut p_0_vec: Vec<MultiLinear<E>> = Vec::new();
        let mut q_1_vec: Vec<MultiLinear<E>> = Vec::new();
        let mut q_0_vec: Vec<MultiLinear<E>> = Vec::new();

        let num_layers = poly.len().ilog2() as usize - 1;
        let (output_p, output_q) = poly.split(poly.len() / 2);
        let (output_p_1, output_p_0) = output_p.split(output_p.len() / 2);
        let (output_q_1, output_q_0) = output_q.split(output_q.len() / 2);

        p_1_vec.push(output_p_1);
        p_0_vec.push(output_p_0);
        q_1_vec.push(output_q_1);
        q_0_vec.push(output_q_0);

        for i in 0..num_layers - 1 {
            let (output_p_1, output_p_0, output_q_1, output_q_0) =
                FractionalSumCircuit::compute_layer(
                    &p_1_vec[i],
                    &p_0_vec[i],
                    &q_1_vec[i],
                    &q_0_vec[i],
                );
            p_1_vec.push(output_p_1);
            p_0_vec.push(output_p_0);
            q_1_vec.push(output_q_1);
            q_0_vec.push(output_q_0);
        }

        FractionalSumCircuit { p_1_vec, p_0_vec, q_1_vec, q_0_vec }
    }

    /// Given a value r, computes the evaluation of the last layer at r when interpreted as (two)
    /// multilinear polynomials.
    pub fn evaluate(&self, r: E) -> (E, E) {
        let len = self.p_1_vec.len();
        assert_eq!(self.p_1_vec[len - 1].num_variables(), 0);
        assert_eq!(self.p_0_vec[len - 1].num_variables(), 0);
        assert_eq!(self.q_1_vec[len - 1].num_variables(), 0);
        assert_eq!(self.q_0_vec[len - 1].num_variables(), 0);

        let mut p_1 = self.p_1_vec[len - 1].clone();
        p_1.extend(&self.p_0_vec[len - 1]);
        let mut q_1 = self.q_1_vec[len - 1].clone();
        q_1.extend(&self.q_0_vec[len - 1]);

        (p_1.evaluate(&[r]), q_1.evaluate(&[r]))
    }
}

/// A proof for reducing a claim on the correctness of the output of a layer to that of:
///
/// 1. Correctness of a sumcheck proof on the claimed output.
/// 2. Correctness of the evaluation of the input (to the said layer) at a random point when
/// interpreted as a multilinear polynomial.
///
/// The verifier will then have to work backward and:
///
/// 1. Verify that the sumcheck proof is valid.
/// 2. Recurse on the claimed evaluations using the same approach as above.
///
/// Note that the following struct batches proofs for many circuits of the same type that
/// are independent, i.e., parallel.
#[derive(Debug)]
pub struct LayerProof<E: FieldElement> {
    pub proof: SumcheckInstanceProof<E>,
    pub claims_sum_p1: E,
    pub claims_sum_p0: E,
    pub claims_sum_q1: E,
    pub claims_sum_q0: E,
}

#[allow(dead_code)]
impl<E: FieldElement<BaseField = BaseElement> + 'static> LayerProof<E> {
    /// Checks the validity of a `LayerProof`.
    ///
    /// It first reduces the 2 claims to 1 claim using randomness and then checks that the sumcheck
    /// protocol was correctly executed.
    ///
    /// The method outputs:
    ///
    /// 1. A vector containing the randomness sent by the verifier throughout the course of the
    /// sum-check protocol.
    /// 2. The (claimed) evaluation of the inner polynomial (i.e., the one being summed) at this
    /// random vector.
    /// 3. The random value used in the 2-to-1 reduction of the 2 sumchecks.
    pub fn verify_sum_check_before_last<
        C: RandomCoin<Hasher = H, BaseField = BaseElement>,
        H: ElementHasher<BaseField = BaseElement>,
    >(
        &self,
        claim: (E, E),
        num_rounds: usize,
        transcript: &mut C,
    ) -> ((E, Vec<E>), E) {
        // Absorb the claims
        let data = vec![claim.0, claim.1];
        transcript.reseed(H::hash_elements(&data));

        // Squeeze challenge to reduce two sumchecks to one
        let r_sum_check: E = transcript.draw().unwrap();

        // Run the sumcheck protocol

        // Given r_sum_check and claim, we create a Claim with the GKR composer and then call the
        // generic sum-check verifier
        let reduced_claim = claim.0 + claim.1 * r_sum_check;

        // Create vanilla oracle
        let oracle = gen_plain_gkr_oracle(num_rounds, r_sum_check);

        // Create sum-check claim
        let transformed_claim = Claim {
            sum_value: reduced_claim,
            polynomial: oracle,
        };

        let reduced_gkr_claim =
            sum_check_verify_and_reduce(&transformed_claim, self.proof.clone(), transcript);

        (reduced_gkr_claim, r_sum_check)
    }
}
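
// Editorial note (not part of the diff): the 2-to-1 reduction above folds the
// numerator and denominator claims (p, q) into the single sum-check claim
// p + r * q for a random r, so one sum-check run covers both. For example,
// claims (p, q) = (3, 4) with r = 2 reduce to the single claim 3 + 2 * 4 = 11.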

#[derive(Debug)]
pub struct GkrClaim<E: FieldElement + 'static> {
    evaluation_point: Vec<E>,
    claimed_evaluation: (E, E),
}

#[derive(Debug)]
pub struct CircuitProof<E: FieldElement + 'static> {
    pub proof: Vec<LayerProof<E>>,
}

impl<E: FieldElement<BaseField = BaseElement> + 'static> CircuitProof<E> {
    pub fn prove<
        C: RandomCoin<Hasher = H, BaseField = BaseElement>,
        H: ElementHasher<BaseField = BaseElement>,
    >(
        circuit: &mut FractionalSumCircuit<E>,
        transcript: &mut C,
    ) -> (Self, Vec<E>, Vec<Vec<E>>) {
        let mut proof_layers: Vec<LayerProof<E>> = Vec::new();
        let num_layers = circuit.p_0_vec.len();

        let data = vec![
            circuit.p_1_vec[num_layers - 1][0],
            circuit.p_0_vec[num_layers - 1][0],
            circuit.q_1_vec[num_layers - 1][0],
            circuit.q_0_vec[num_layers - 1][0],
        ];
        transcript.reseed(H::hash_elements(&data));

        // Challenge to reduce p1, p0, q1, q0 to pr, qr
        let r_cord = transcript.draw().unwrap();

        // Compute the (2-to-1 folded) claim
        let mut claim = circuit.evaluate(r_cord);
        let mut all_rand = Vec::new();

        let mut rand = Vec::new();
        rand.push(r_cord);
        for layer_id in (0..num_layers - 1).rev() {
            let len = circuit.p_0_vec[layer_id].len();

            // Construct the Lagrange kernel evaluated at previous GKR round randomness.
            // TODO: Treat the direction of doing sum-check more robustly.
            let mut rand_reversed = rand.clone();
            rand_reversed.reverse();
            let eq_evals = EqPolynomial::new(rand_reversed.clone()).evaluations();
            let mut poly_x = MultiLinear::from_values(&eq_evals);
            assert_eq!(poly_x.len(), len);

            let num_rounds = poly_x.len().ilog2() as usize;

            // 1. A is a polynomial containing the evaluations `p_1`.
            // 2. B is a polynomial containing the evaluations `p_0`.
            // 3. C is a polynomial containing the evaluations `q_1`.
            // 4. D is a polynomial containing the evaluations `q_0`.
            let poly_a = &mut circuit.p_1_vec[layer_id];
            let poly_b = &mut circuit.p_0_vec[layer_id];
            let poly_c = &mut circuit.q_1_vec[layer_id];
            let poly_d = &mut circuit.q_0_vec[layer_id];

            let poly_vec_par = (poly_a, poly_b, poly_c, poly_d, &mut poly_x);

            // The (non-linear) polynomial combining the multilinear polynomials
            let comb_func = |a: &E, b: &E, c: &E, d: &E, x: &E, rho: &E| -> E {
                (*a * *d + *b * *c + *rho * *c * *d) * *x
            };

            // Run the sumcheck protocol
            let (proof, rand_sumcheck, claims_sum) = sum_check_prover_gkr_before_last::<E, _, _>(
                claim,
                num_rounds,
                poly_vec_par,
                comb_func,
                transcript,
            );

            let (claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0, _claims_eq) =
                claims_sum;

            let data = vec![claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0];
            transcript.reseed(H::hash_elements(&data));

            // Produce a random challenge to condense claims into a single claim
            let r_layer = transcript.draw().unwrap();

            claim = (
                claims_sum_p1 + r_layer * (claims_sum_p0 - claims_sum_p1),
                claims_sum_q1 + r_layer * (claims_sum_q0 - claims_sum_q1),
            );

            // Collect the randomness used for the current layer in order to construct the random
            // point where the input multilinear polynomials were evaluated.
            let mut ext = rand_sumcheck;
            ext.push(r_layer);
            all_rand.push(rand);
            rand = ext;

            proof_layers.push(LayerProof {
                proof,
                claims_sum_p1,
                claims_sum_p0,
                claims_sum_q1,
                claims_sum_q0,
            });
        }

        (CircuitProof { proof: proof_layers }, rand, all_rand)
    }

    pub fn prove_virtual_bus<
        C: RandomCoin<Hasher = H, BaseField = BaseElement>,
        H: ElementHasher<BaseField = BaseElement>,
    >(
        composition_polys: Vec<Vec<Arc<dyn CompositionPolynomial<E>>>>,
        mls: &mut Vec<MultiLinear<E>>,
        transcript: &mut C,
    ) -> (Vec<E>, Self, super::sumcheck::FullProof<E>) {
        let num_evaluations = 1 << mls[0].num_variables();

        // I) Evaluate the numerators and denominators over the boolean hyper-cube
        let mut num_den: Vec<Vec<E>> = vec![vec![]; 4];
        for i in 0..num_evaluations {
            let query: Vec<E> = mls.iter().map(|ml| ml[i]).collect();
            for j in 0..4 {
                composition_polys[j].iter().for_each(|c| {
                    let evaluation = c.as_ref().evaluate(&query);
                    num_den[j].push(evaluation);
                });
            }
        }

        // II) Evaluate the GKR fractional sum circuit
        let input: Vec<MultiLinear<E>> =
            (0..4).map(|i| MultiLinear::from_values(&num_den[i])).collect();
        let mut circuit = FractionalSumCircuit::new_(&input);

        // III) Run the GKR prover for all layers except the last one
        let (gkr_proofs, GkrClaim { evaluation_point, claimed_evaluation }) =
            CircuitProof::prove_before_final(&mut circuit, transcript);

        // IV) Run the sum-check prover for the last GKR layer counting backwards, i.e., the first
        // layer in the circuit.

        // 1) Build the EQ polynomial (Lagrange kernel) at the randomness sampled during the
        // previous sum-check protocol run
        let mut rand_reversed = evaluation_point.clone();
        rand_reversed.reverse();
        let eq_evals = EqPolynomial::new(rand_reversed.clone()).evaluations();
        let poly_x = MultiLinear::from_values(&eq_evals);

        // 2) Add the Lagrange kernel to the list of MLs
        mls.push(poly_x);

        // 3) Absorb the final sum-check claims and generate randomness for the 2-to-1 sum-check
        // reduction
        let data = vec![claimed_evaluation.0, claimed_evaluation.1];
        transcript.reseed(H::hash_elements(&data));
        // Squeeze challenge to reduce two sumchecks to one
        let r_sum_check = transcript.draw().unwrap();
        let reduced_claim = claimed_evaluation.0 + claimed_evaluation.1 * r_sum_check;

        // 4) Create the composed ML representing the numerators and denominators of the topmost
        // GKR layer
        let gkr_final_composed_ml = gkr_composition_from_composition_polys(
            &composition_polys,
            r_sum_check,
            1 << mls[0].num_variables,
        );
        let composed_ml =
            ComposedMultiLinears::new(Arc::new(gkr_final_composed_ml.clone()), mls.to_vec());

        // 5) Create the composed ML oracle. This will be used for verifying the
        // FinalEvaluationClaim downstream.
        // TODO: This should be an input to the current function.
        // TODO: Make MultiLinearOracle a variant in an enum so that it is possible to capture
        // other types of oracles. For example, shifts of polynomials, Lagrange kernels at a random
        // point or periodic (transparent) polynomials.
        let left_num_oracle = MultiLinearOracle { id: 0 };
        let right_num_oracle = MultiLinearOracle { id: 1 };
        let left_denom_oracle = MultiLinearOracle { id: 2 };
        let right_denom_oracle = MultiLinearOracle { id: 3 };
        let eq_oracle = MultiLinearOracle { id: 4 };
        let composed_ml_oracle = ComposedMultiLinearsOracle {
            composer: Arc::new(gkr_final_composed_ml.clone()),
            multi_linears: vec![
                eq_oracle,
                left_num_oracle,
                right_num_oracle,
                left_denom_oracle,
                right_denom_oracle,
            ],
        };

        // 6) Create the claim for the final sum-check protocol.
        let claim = Claim {
            sum_value: reduced_claim,
            polynomial: composed_ml_oracle.clone(),
        };

        // 7) Create the witness for the sum-check claim.
        let witness = Witness { polynomial: composed_ml };
        let output = sum_check_prove(&claim, composed_ml_oracle, witness, transcript);

        // 8) Create the claimed output of the circuit.
        let circuit_outputs = vec![
            circuit.p_1_vec.last().unwrap()[0],
            circuit.p_0_vec.last().unwrap()[0],
            circuit.q_1_vec.last().unwrap()[0],
            circuit.q_0_vec.last().unwrap()[0],
        ];

        // 9) Return:
        // 1. The claimed circuit outputs.
        // 2. GKR proofs of all circuit layers except the initial layer.
        // 3. Output of the final sum-check protocol.
        (circuit_outputs, gkr_proofs, output)
    }

    pub fn prove_before_final<
        C: RandomCoin<Hasher = H, BaseField = BaseElement>,
        H: ElementHasher<BaseField = BaseElement>,
    >(
        sum_circuits: &mut FractionalSumCircuit<E>,
        transcript: &mut C,
    ) -> (Self, GkrClaim<E>) {
        let mut proof_layers: Vec<LayerProof<E>> = Vec::new();
        let num_layers = sum_circuits.p_0_vec.len();

        let data = vec![
            sum_circuits.p_1_vec[num_layers - 1][0],
            sum_circuits.p_0_vec[num_layers - 1][0],
            sum_circuits.q_1_vec[num_layers - 1][0],
            sum_circuits.q_0_vec[num_layers - 1][0],
        ];
        transcript.reseed(H::hash_elements(&data));

        // Challenge to reduce p1, p0, q1, q0 to pr, qr
        let r_cord = transcript.draw().unwrap();

        // Compute the (2-to-1 folded) claim
        let mut claims_to_verify = sum_circuits.evaluate(r_cord);
        let mut all_rand = Vec::new();

        let mut rand = Vec::new();
        rand.push(r_cord);
        for layer_id in (1..num_layers - 1).rev() {
            let len = sum_circuits.p_0_vec[layer_id].len();

            // Construct the Lagrange kernel evaluated at previous GKR round randomness.
            // TODO: Treat the direction of doing sum-check more robustly.
            let mut rand_reversed = rand.clone();
            rand_reversed.reverse();
            let eq_evals = EqPolynomial::new(rand_reversed.clone()).evaluations();
            let mut poly_x = MultiLinear::from_values(&eq_evals);
            assert_eq!(poly_x.len(), len);

            let num_rounds = poly_x.len().ilog2() as usize;

            // 1. A is a polynomial containing the evaluations `p_1`.
            // 2. B is a polynomial containing the evaluations `p_0`.
            // 3. C is a polynomial containing the evaluations `q_1`.
            // 4. D is a polynomial containing the evaluations `q_0`.
            let poly_a = &mut sum_circuits.p_1_vec[layer_id];
            let poly_b = &mut sum_circuits.p_0_vec[layer_id];
            let poly_c = &mut sum_circuits.q_1_vec[layer_id];
            let poly_d = &mut sum_circuits.q_0_vec[layer_id];

            let poly_vec = (poly_a, poly_b, poly_c, poly_d, &mut poly_x);

            let claim = claims_to_verify;

            // The (non-linear) polynomial combining the multilinear polynomials
            let comb_func = |a: &E, b: &E, c: &E, d: &E, x: &E, rho: &E| -> E {
                (*a * *d + *b * *c + *rho * *c * *d) * *x
            };

            // Run the sumcheck protocol
            let (proof, rand_sumcheck, claims_sum) = sum_check_prover_gkr_before_last::<E, _, _>(
                claim, num_rounds, poly_vec, comb_func, transcript,
            );

            let (claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0, _claims_eq) =
                claims_sum;

            let data = vec![claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0];
            transcript.reseed(H::hash_elements(&data));

            // Produce a random challenge to condense claims into a single claim
            let r_layer = transcript.draw().unwrap();

            claims_to_verify = (
                claims_sum_p1 + r_layer * (claims_sum_p0 - claims_sum_p1),
                claims_sum_q1 + r_layer * (claims_sum_q0 - claims_sum_q1),
            );

            // Collect the randomness used for the current layer in order to construct the random
            // point where the input multilinear polynomials were evaluated.
            let mut ext = rand_sumcheck;
            ext.push(r_layer);
            all_rand.push(rand);
            rand = ext;

            proof_layers.push(LayerProof {
                proof,
                claims_sum_p1,
                claims_sum_p0,
                claims_sum_q1,
                claims_sum_q0,
            });
        }
        let gkr_claim = GkrClaim {
            evaluation_point: rand.clone(),
            claimed_evaluation: claims_to_verify,
        };

        (CircuitProof { proof: proof_layers }, gkr_claim)
    }

    pub fn verify<
        C: RandomCoin<Hasher = H, BaseField = BaseElement>,
        H: ElementHasher<BaseField = BaseElement>,
    >(
        &self,
        claims_sum_vec: &[E],
        transcript: &mut C,
    ) -> ((E, E), Vec<E>) {
        let num_layers = self.proof.len() - 1;
        let mut rand: Vec<E> = Vec::new();

        transcript.reseed(H::hash_elements(claims_sum_vec));

        let r_cord = transcript.draw().unwrap();

        let p_poly_coef = vec![claims_sum_vec[0], claims_sum_vec[1]];
        let q_poly_coef = vec![claims_sum_vec[2], claims_sum_vec[3]];

        let p_poly = MultiLinear::new(p_poly_coef);
        let q_poly = MultiLinear::new(q_poly_coef);
        let p_eval = p_poly.evaluate(&[r_cord]);
        let q_eval = q_poly.evaluate(&[r_cord]);

        let mut reduced_claim = (p_eval, q_eval);

        rand.push(r_cord);
        for (num_rounds, i) in (0..num_layers).enumerate() {
            let ((claim_last, rand_sumcheck), r_two_sumchecks) = self.proof[i]
                .verify_sum_check_before_last::<_, _>(reduced_claim, num_rounds + 1, transcript);

            let claims_sum_p1 = &self.proof[i].claims_sum_p1;
            let claims_sum_p0 = &self.proof[i].claims_sum_p0;
            let claims_sum_q1 = &self.proof[i].claims_sum_q1;
            let claims_sum_q0 = &self.proof[i].claims_sum_q0;

            let data = vec![
                claims_sum_p1.clone(),
                claims_sum_p0.clone(),
                claims_sum_q1.clone(),
                claims_sum_q0.clone(),
            ];
            transcript.reseed(H::hash_elements(&data));

            assert_eq!(rand.len(), rand_sumcheck.len());

            let eq: E = (0..rand.len())
                .map(|i| {
                    rand[i] * rand_sumcheck[i] + (E::ONE - rand[i]) * (E::ONE - rand_sumcheck[i])
                })
                .fold(E::ONE, |acc, term| acc * term);

            let claim_expected: E = (*claims_sum_p1 * *claims_sum_q0
                + *claims_sum_p0 * *claims_sum_q1
                + r_two_sumchecks * *claims_sum_q1 * *claims_sum_q0)
                * eq;

            assert_eq!(claim_expected, claim_last);

            // Produce a random challenge to condense claims into a single claim
            let r_layer = transcript.draw().unwrap();

            reduced_claim = (
                *claims_sum_p1 + r_layer * (*claims_sum_p0 - *claims_sum_p1),
                *claims_sum_q1 + r_layer * (*claims_sum_q0 - *claims_sum_q1),
            );

            // Collect the randomness used for the current layer in order to construct the random
            // point where the input multilinear polynomials were evaluated.
            let mut ext = rand_sumcheck;
            ext.push(r_layer);
            rand = ext;
        }
        (reduced_claim, rand)
    }

    pub fn verify_virtual_bus<
        C: RandomCoin<Hasher = H, BaseField = BaseElement>,
        H: ElementHasher<BaseField = BaseElement>,
    >(
        &self,
        composition_polys: Vec<Vec<Arc<dyn CompositionPolynomial<E>>>>,
        final_layer_proof: super::sumcheck::FullProof<E>,
        claims_sum_vec: &[E],
        transcript: &mut C,
    ) -> (FinalEvaluationClaim<E>, Vec<E>) {
        let num_layers = self.proof.len();
        let mut rand: Vec<E> = Vec::new();

        // Check that the claimed circuit output sums to zero, i.e., that
        // p1 * q0 + p0 * q1 == 0 with both denominators q1 and q0 nonzero.
        assert_ne!(claims_sum_vec[2], E::ZERO);
        assert_ne!(claims_sum_vec[3], E::ZERO);
        assert_eq!(
            claims_sum_vec[0] * claims_sum_vec[3] + claims_sum_vec[1] * claims_sum_vec[2],
            E::ZERO
        );

        transcript.reseed(H::hash_elements(claims_sum_vec));

        let r_cord = transcript.draw().unwrap();

        let p_poly_coef = vec![claims_sum_vec[0], claims_sum_vec[1]];
        let q_poly_coef = vec![claims_sum_vec[2], claims_sum_vec[3]];

        let p_poly = MultiLinear::new(p_poly_coef);
        let q_poly = MultiLinear::new(q_poly_coef);
        let p_eval = p_poly.evaluate(&[r_cord]);
        let q_eval = q_poly.evaluate(&[r_cord]);

        let mut reduced_claim = (p_eval, q_eval);

        // I) Verify all GKR layers except the last one, counting backwards.
        rand.push(r_cord);
        for (num_rounds, i) in (0..num_layers).enumerate() {
            let ((claim_last, rand_sumcheck), r_two_sumchecks) = self.proof[i]
                .verify_sum_check_before_last::<_, _>(reduced_claim, num_rounds + 1, transcript);

            let claims_sum_p1 = &self.proof[i].claims_sum_p1;
            let claims_sum_p0 = &self.proof[i].claims_sum_p0;
            let claims_sum_q1 = &self.proof[i].claims_sum_q1;
            let claims_sum_q0 = &self.proof[i].claims_sum_q0;

            let data = vec![
                claims_sum_p1.clone(),
                claims_sum_p0.clone(),
                claims_sum_q1.clone(),
                claims_sum_q0.clone(),
            ];
            transcript.reseed(H::hash_elements(&data));

            assert_eq!(rand.len(), rand_sumcheck.len());

            let eq: E = (0..rand.len())
                .map(|i| {
                    rand[i] * rand_sumcheck[i] + (E::ONE - rand[i]) * (E::ONE - rand_sumcheck[i])
                })
                .fold(E::ONE, |acc, term| acc * term);

            let claim_expected: E = (*claims_sum_p1 * *claims_sum_q0
                + *claims_sum_p0 * *claims_sum_q1
                + r_two_sumchecks * *claims_sum_q1 * *claims_sum_q0)
                * eq;

            assert_eq!(claim_expected, claim_last);

            // Produce a random challenge to condense claims into a single claim
            let r_layer = transcript.draw().unwrap();

            reduced_claim = (
                *claims_sum_p1 + r_layer * (*claims_sum_p0 - *claims_sum_p1),
                *claims_sum_q1 + r_layer * (*claims_sum_q0 - *claims_sum_q1),
            );

            let mut ext = rand_sumcheck;
            ext.push(r_layer);
            rand = ext;
        }

        // II) Verify the final GKR layer counting backwards.

        // Absorb the claims
        let data = vec![reduced_claim.0, reduced_claim.1];
        transcript.reseed(H::hash_elements(&data));

        // Squeeze challenge to reduce two sumchecks to one
        let r_sum_check = transcript.draw().unwrap();
        let reduced_claim = reduced_claim.0 + reduced_claim.1 * r_sum_check;

        let gkr_final_composed_ml = gkr_composition_from_composition_polys(
            &composition_polys,
            r_sum_check,
            1 << (num_layers + 1),
        );

        // TODO: refactor
        let composed_ml_oracle = {
            let left_num_oracle = MultiLinearOracle { id: 0 };
            let right_num_oracle = MultiLinearOracle { id: 1 };
            let left_denom_oracle = MultiLinearOracle { id: 2 };
            let right_denom_oracle = MultiLinearOracle { id: 3 };
            let eq_oracle = MultiLinearOracle { id: 4 };
            ComposedMultiLinearsOracle {
                composer: Arc::new(gkr_final_composed_ml.clone()),
                multi_linears: vec![
                    eq_oracle,
                    left_num_oracle,
                    right_num_oracle,
                    left_denom_oracle,
                    right_denom_oracle,
                ],
            }
        };

        let claim = Claim {
            sum_value: reduced_claim,
            polynomial: composed_ml_oracle.clone(),
        };

        let final_eval_claim = sum_check_verify(&claim, final_layer_proof, transcript);

        (final_eval_claim, rand)
    }
}

fn sum_check_prover_gkr_before_last<
    E: FieldElement<BaseField = BaseElement>,
    C: RandomCoin<Hasher = H, BaseField = BaseElement>,
    H: ElementHasher<BaseField = BaseElement>,
>(
    claim: (E, E),
    num_rounds: usize,
    ml_polys: (
        &mut MultiLinear<E>,
        &mut MultiLinear<E>,
        &mut MultiLinear<E>,
        &mut MultiLinear<E>,
        &mut MultiLinear<E>,
    ),
    comb_func: impl Fn(&E, &E, &E, &E, &E, &E) -> E,
    transcript: &mut C,
) -> (SumcheckInstanceProof<E>, Vec<E>, (E, E, E, E, E)) {
    // Absorb the claims
    let data = vec![claim.0, claim.1];
    transcript.reseed(H::hash_elements(&data));

    // Squeeze challenge to reduce two sumchecks to one
    let r_sum_check = transcript.draw().unwrap();

    let (poly_a, poly_b, poly_c, poly_d, poly_x) = ml_polys;

    let mut e = claim.0 + claim.1 * r_sum_check;

    let mut r: Vec<E> = Vec::new();
    let mut round_proofs: Vec<SumCheckRoundProof<E>> = Vec::new();

    for _j in 0..num_rounds {
        let evals: (E, E, E) = {
            let mut eval_point_0 = E::ZERO;
            let mut eval_point_2 = E::ZERO;
            let mut eval_point_3 = E::ZERO;

            let len = poly_a.len() / 2;
            for i in 0..len {
                // In its current free variable z, each polynomial is linear, so it is
                // determined by its values at 0 and 1:
                //     A(z) = z * A(1) + (1 - z) * A(0)

                // eval at z = 0: A(0)
                eval_point_0 += comb_func(
                    &poly_a[i << 1],
                    &poly_b[i << 1],
                    &poly_c[i << 1],
                    &poly_d[i << 1],
                    &poly_x[i << 1],
                    &r_sum_check,
                );

                let poly_a_u = poly_a[(i << 1) + 1];
                let poly_a_v = poly_a[i << 1];
                let poly_b_u = poly_b[(i << 1) + 1];
                let poly_b_v = poly_b[i << 1];
                let poly_c_u = poly_c[(i << 1) + 1];
                let poly_c_v = poly_c[i << 1];
                let poly_d_u = poly_d[(i << 1) + 1];
                let poly_d_v = poly_d[i << 1];
                let poly_x_u = poly_x[(i << 1) + 1];
                let poly_x_v = poly_x[i << 1];

                // eval at z = 2: 2 * A(1) - A(0)
                let poly_a_extrapolated_point = poly_a_u + poly_a_u - poly_a_v;
                let poly_b_extrapolated_point = poly_b_u + poly_b_u - poly_b_v;
                let poly_c_extrapolated_point = poly_c_u + poly_c_u - poly_c_v;
                let poly_d_extrapolated_point = poly_d_u + poly_d_u - poly_d_v;
                let poly_x_extrapolated_point = poly_x_u + poly_x_u - poly_x_v;
                eval_point_2 += comb_func(
                    &poly_a_extrapolated_point,
                    &poly_b_extrapolated_point,
                    &poly_c_extrapolated_point,
                    &poly_d_extrapolated_point,
                    &poly_x_extrapolated_point,
                    &r_sum_check,
                );

                // eval at z = 3: 3 * A(1) - 2 * A(0) = (2 * A(1) - A(0)) + (A(1) - A(0)),
                // hence we can compute the evaluation at z + 1 from that of z for z > 1
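                // Editorial example: for a linear A with A(0) = 2 and A(1) = 5,
                // the step A(1) - A(0) = 3 gives A(2) = 8 and A(3) = 11.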
                let poly_a_extrapolated_point = poly_a_extrapolated_point + poly_a_u - poly_a_v;
                let poly_b_extrapolated_point = poly_b_extrapolated_point + poly_b_u - poly_b_v;
                let poly_c_extrapolated_point = poly_c_extrapolated_point + poly_c_u - poly_c_v;
                let poly_d_extrapolated_point = poly_d_extrapolated_point + poly_d_u - poly_d_v;
                let poly_x_extrapolated_point = poly_x_extrapolated_point + poly_x_u - poly_x_v;

                eval_point_3 += comb_func(
                    &poly_a_extrapolated_point,
                    &poly_b_extrapolated_point,
                    &poly_c_extrapolated_point,
                    &poly_d_extrapolated_point,
                    &poly_x_extrapolated_point,
                    &r_sum_check,
                );
            }

            (eval_point_0, eval_point_2, eval_point_3)
        };

        let eval_0 = evals.0;
        let eval_2 = evals.1;
        let eval_3 = evals.2;

        // the prover sends s(1) = e - s(0) together with s(2) and s(3); the verifier
        // recovers s(0) from s(0) + s(1) = e
        let evals = vec![e - eval_0, eval_2, eval_3];
        let compressed_poly = SumCheckRoundProof { poly_evals: evals };

        // append the prover's message to the transcript
        transcript.reseed(H::hash_elements(&compressed_poly.poly_evals));

        // derive the verifier's challenge for the next round
        let r_j = transcript.draw().unwrap();
        r.push(r_j);

        poly_a.bind_assign(r_j);
        poly_b.bind_assign(r_j);
        poly_c.bind_assign(r_j);
        poly_d.bind_assign(r_j);

        poly_x.bind_assign(r_j);

        e = compressed_poly.evaluate(e, r_j);

        round_proofs.push(compressed_poly);
    }
    let claims_sum = (poly_a[0], poly_b[0], poly_c[0], poly_d[0], poly_x[0]);

    (SumcheckInstanceProof { round_proofs }, r, claims_sum)
}

#[cfg(test)]
mod sum_circuit_tests {
    use crate::rand::RpoRandomCoin;

    use super::*;
    use rand::Rng;
    use rand_utils::rand_value;
    use BaseElement as Felt;

    /// Tests the fractional sum circuit by checking that
    /// \sum_{i = 0}^{log2(m) - 1} m / 2^i = 2 * (m - 1) for m = 2^n.
    #[test]
    fn sum_circuit_example() {
        let n = 4; // n := log2(m)
        let mut inp: Vec<Felt> = (0..n).map(|_| Felt::from(1_u64 << n)).collect();
        let inp_: Vec<Felt> = (0..n).map(|i| Felt::from(1_u64 << i)).collect();
        inp.extend(inp_.iter());

        let summation = MultiLinear::new(inp);

        let expected_output = Felt::from(2 * ((1_u64 << n) - 1));

        let mut circuit = FractionalSumCircuit::new(&summation);

        let seed = [BaseElement::ZERO; 4];
        let mut transcript = RpoRandomCoin::new(seed.into());

        let (proof, _evals, _) = CircuitProof::prove(&mut circuit, &mut transcript);

        let (p1, q1) = circuit.evaluate(Felt::from(1_u8));
        let (p0, q0) = circuit.evaluate(Felt::from(0_u8));
        assert_eq!(expected_output, (p1 * q0 + q1 * p0) / (q1 * q0));

        let seed = [BaseElement::ZERO; 4];
        let mut transcript = RpoRandomCoin::new(seed.into());
        let claims = vec![p0, p1, q0, q1];
        proof.verify(&claims, &mut transcript);
    }

    // Test the fractional sum GKR circuit in the context of LogUp.
    #[test]
    fn log_up() {
        use rand::distributions::Slice;

        let n: usize = 16;
        let num_w: usize = 31; // This should be of the form 2^k - 1
        let rng = rand::thread_rng();

        let t_table: Vec<u32> = (0..(1 << n)).collect();
        let mut m_table: Vec<u32> = vec![0; 1 << n];

        let t_table_slice = Slice::new(&t_table).unwrap();

        // Construct the witness columns. Uses sampling with replacement in order to have
        // multiplicities different from 1.
        let mut w_tables = Vec::new();
        for _ in 0..num_w {
            let wi_table: Vec<u32> =
                rng.clone().sample_iter(&t_table_slice).cloned().take(1 << n).collect();

            // Construct the multiplicities
            wi_table.iter().for_each(|w| {
                m_table[*w as usize] += 1;
            });
            w_tables.push(wi_table)
        }

        // The numerators
        let mut p: Vec<Felt> = m_table.iter().map(|m| Felt::from(*m)).collect();
        p.extend((0..(num_w * (1 << n))).map(|_| Felt::from(1_u32)));

        // Sample the challenge alpha to construct the denominators.
        let alpha = rand_value();

        // Construct the denominators
        let mut q: Vec<Felt> = t_table.iter().map(|t| Felt::from(*t) - alpha).collect();
        for w_table in w_tables {
            q.extend(w_table.iter().map(|w| alpha - Felt::from(*w)));
        }

        // Build the input to the fractional sum GKR circuit
        p.extend(q);
        let input = p;

        let summation = MultiLinear::new(input);

        let expected_output = Felt::from(0_u8);

        let mut circuit = FractionalSumCircuit::new(&summation);

        let seed = [BaseElement::ZERO; 4];
        let mut transcript = RpoRandomCoin::new(seed.into());

        let (proof, _evals, _) = CircuitProof::prove(&mut circuit, &mut transcript);

        let (p1, q1) = circuit.evaluate(Felt::from(1_u8));
        let (p0, q0) = circuit.evaluate(Felt::from(0_u8));
        assert_eq!(expected_output, (p1 * q0 + q1 * p0) / (q1 * q0)); // This check should be part of verification

        let seed = [BaseElement::ZERO; 4];
        let mut transcript = RpoRandomCoin::new(seed.into());
        let claims = vec![p0, p1, q0, q1];
        proof.verify(&claims, &mut transcript);
    }
}

src/gkr/mod.rs (new file, 7 lines)
@@ -0,0 +1,7 @@
#![allow(unused_imports)]
#![allow(dead_code)]

mod sumcheck;
mod multivariate;
mod utils;
mod circuit;

src/gkr/multivariate/eq_poly.rs (new file, 34 lines)
@@ -0,0 +1,34 @@
use super::FieldElement;

/// The Lagrange kernel (equality) polynomial
/// eq(r, x) = prod_i (r_i * x_i + (1 - r_i) * (1 - x_i)), fixed at the point `r`.
pub struct EqPolynomial<E> {
    r: Vec<E>,
}

impl<E: FieldElement> EqPolynomial<E> {
    pub fn new(r: Vec<E>) -> Self {
        EqPolynomial { r }
    }

    /// Evaluates eq(r, rho) at an arbitrary point `rho`.
    pub fn evaluate(&self, rho: &[E]) -> E {
        assert_eq!(self.r.len(), rho.len());
        (0..rho.len())
            .map(|i| self.r[i] * rho[i] + (E::ONE - self.r[i]) * (E::ONE - rho[i]))
            .fold(E::ONE, |acc, term| acc * term)
    }

    /// Computes the evaluations of eq(r, .) over the whole Boolean hyper-cube.
    pub fn evaluations(&self) -> Vec<E> {
        let nu = self.r.len();

        let mut evals: Vec<E> = vec![E::ONE; 1 << nu];
        let mut size = 1;
        for j in 0..nu {
            size *= 2;
            for i in (0..size).rev().step_by(2) {
                let scalar = evals[i / 2];
                evals[i] = scalar * self.r[j];
                evals[i - 1] = scalar - evals[i];
            }
        }
        evals
    }
}
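
// Editorial sketch (not part of the diff): the i-th entry of `evaluations()`
// equals `evaluate` at the Boolean point whose bits are read MSB-first, i.e.,
// r[0] corresponds to the most significant bit of the index.
fn _eq_poly_consistency<E: FieldElement>(r: &[E]) {
    let eq = EqPolynomial::new(r.to_vec());
    let nu = r.len();
    for (i, v) in eq.evaluations().iter().enumerate() {
        let point: Vec<E> = (0..nu)
            .map(|j| if (i >> (nu - 1 - j)) & 1 == 1 { E::ONE } else { E::ZERO })
            .collect();
        assert_eq!(*v, eq.evaluate(&point));
    }
}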

src/gkr/multivariate/mod.rs (new file, 543 lines)
@@ -0,0 +1,543 @@
use core::ops::Index;

use alloc::sync::Arc;
use winter_math::{fields::f64::BaseElement, log2, FieldElement, StarkField};

mod eq_poly;
pub use eq_poly::EqPolynomial;

#[derive(Clone, Debug)]
pub struct MultiLinear<E: FieldElement> {
    pub num_variables: usize,
    pub evaluations: Vec<E>,
}

impl<E: FieldElement> MultiLinear<E> {
    pub fn new(values: Vec<E>) -> Self {
        Self {
            num_variables: log2(values.len()) as usize,
            evaluations: values,
        }
    }

    pub fn from_values(values: &[E]) -> Self {
        Self {
            num_variables: log2(values.len()) as usize,
            evaluations: values.to_owned(),
        }
    }

    pub fn num_variables(&self) -> usize {
        self.num_variables
    }

    pub fn evaluations(&self) -> &[E] {
        &self.evaluations
    }

    pub fn len(&self) -> usize {
        self.evaluations.len()
    }

    pub fn evaluate(&self, query: &[E]) -> E {
        let tensored_query = tensorize(query);
        inner_product(&self.evaluations, &tensored_query)
    }

    pub fn bind(&self, round_challenge: E) -> Self {
        let mut result = vec![E::ZERO; 1 << (self.num_variables() - 1)];
        for i in 0..(1 << (self.num_variables() - 1)) {
            result[i] = self.evaluations[i << 1]
                + round_challenge * (self.evaluations[(i << 1) + 1] - self.evaluations[i << 1]);
        }
        Self::from_values(&result)
    }

    pub fn bind_assign(&mut self, round_challenge: E) {
        let mut result = vec![E::ZERO; 1 << (self.num_variables() - 1)];
        for i in 0..(1 << (self.num_variables() - 1)) {
            result[i] = self.evaluations[i << 1]
                + round_challenge * (self.evaluations[(i << 1) + 1] - self.evaluations[i << 1]);
        }
        *self = Self::from_values(&result);
    }

    pub fn split(&self, at: usize) -> (Self, Self) {
        assert!(at < self.len());
        (
            Self::new(self.evaluations[..at].to_vec()),
            Self::new(self.evaluations[at..2 * at].to_vec()),
        )
    }

    pub fn extend(&mut self, other: &MultiLinear<E>) {
        let other_vec = other.evaluations.to_vec();
        assert_eq!(other_vec.len(), self.len());
        self.evaluations.extend(other_vec);
        self.num_variables += 1;
    }
}
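
// Editorial note (not part of the diff): `bind` fixes the *first* variable
// (bit 0 of the evaluation index) to the round challenge. A minimal sketch:
//
//     // f(x0, x1) with evaluations [f(0,0), f(1,0), f(0,1), f(1,1)];
//     // bind(r) returns g(x1) = f(r, x1), with
//     //     g[j] = f(0, j) + r * (f(1, j) - f(0, j))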

impl<E: FieldElement> Index<usize> for MultiLinear<E> {
    type Output = E;

    fn index(&self, index: usize) -> &E {
        &self.evaluations[index]
    }
}

/// A multi-variate polynomial for composing individual multi-linear polynomials
pub trait CompositionPolynomial<E: FieldElement>: Sync + Send {
    /// The number of variables when interpreted as a multi-variate polynomial.
    fn num_variables(&self) -> usize;

    /// Maximum degree in all variables.
    fn max_degree(&self) -> usize;

    /// Given a query of length equal to the number of variables, evaluates [Self] at this query.
    fn evaluate(&self, query: &[E]) -> E;
}

pub struct ComposedMultiLinears<E: FieldElement> {
    pub composer: Arc<dyn CompositionPolynomial<E>>,
    pub multi_linears: Vec<MultiLinear<E>>,
}

impl<E: FieldElement> ComposedMultiLinears<E> {
    pub fn new(
        composer: Arc<dyn CompositionPolynomial<E>>,
        multi_linears: Vec<MultiLinear<E>>,
    ) -> Self {
        Self { composer, multi_linears }
    }

    pub fn num_ml(&self) -> usize {
        self.multi_linears.len()
    }

    pub fn num_variables(&self) -> usize {
        self.composer.num_variables()
    }

    pub fn num_variables_ml(&self) -> usize {
        self.multi_linears[0].num_variables
    }

    pub fn degree(&self) -> usize {
        self.composer.max_degree()
    }

    pub fn bind(&self, round_challenge: E) -> ComposedMultiLinears<E> {
        let result: Vec<MultiLinear<E>> =
            self.multi_linears.iter().map(|f| f.bind(round_challenge)).collect();

        Self {
            composer: self.composer.clone(),
            multi_linears: result,
        }
    }
}

#[derive(Clone)]
pub struct ComposedMultiLinearsOracle<E: FieldElement> {
    pub composer: Arc<dyn CompositionPolynomial<E>>,
    pub multi_linears: Vec<MultiLinearOracle>,
}

#[derive(Debug, Clone)]
pub struct MultiLinearOracle {
    pub id: usize,
}

// Composition polynomials

pub struct IdentityComposition {
    num_variables: usize,
}

impl IdentityComposition {
    pub fn new() -> Self {
        Self { num_variables: 1 }
    }
}

impl<E> CompositionPolynomial<E> for IdentityComposition
where
    E: FieldElement,
{
    fn num_variables(&self) -> usize {
        self.num_variables
    }

    fn max_degree(&self) -> usize {
        self.num_variables
    }

    fn evaluate(&self, query: &[E]) -> E {
        assert_eq!(query.len(), 1);
        query[0]
    }
}

pub struct ProjectionComposition {
    coordinate: usize,
}

impl ProjectionComposition {
    pub fn new(coordinate: usize) -> Self {
        Self { coordinate }
    }
}

impl<E> CompositionPolynomial<E> for ProjectionComposition
where
    E: FieldElement,
{
    fn num_variables(&self) -> usize {
        1
    }

    fn max_degree(&self) -> usize {
        1
    }

    fn evaluate(&self, query: &[E]) -> E {
        query[self.coordinate]
    }
}

pub struct LogUpDenominatorTableComposition<E>
where
    E: FieldElement,
{
    projection_coordinate: usize,
    alpha: E,
}

impl<E> LogUpDenominatorTableComposition<E>
where
    E: FieldElement,
{
    pub fn new(projection_coordinate: usize, alpha: E) -> Self {
        Self { projection_coordinate, alpha }
    }
}

impl<E> CompositionPolynomial<E> for LogUpDenominatorTableComposition<E>
where
    E: FieldElement,
{
    fn num_variables(&self) -> usize {
        1
    }

    fn max_degree(&self) -> usize {
        1
    }

    fn evaluate(&self, query: &[E]) -> E {
        query[self.projection_coordinate] + self.alpha
    }
}

pub struct LogUpDenominatorWitnessComposition<E>
where
    E: FieldElement,
{
    projection_coordinate: usize,
    alpha: E,
}

impl<E> LogUpDenominatorWitnessComposition<E>
where
    E: FieldElement,
{
    pub fn new(projection_coordinate: usize, alpha: E) -> Self {
        Self { projection_coordinate, alpha }
    }
}

impl<E> CompositionPolynomial<E> for LogUpDenominatorWitnessComposition<E>
where
    E: FieldElement,
{
    fn num_variables(&self) -> usize {
        1
    }

    fn max_degree(&self) -> usize {
        1
    }

    fn evaluate(&self, query: &[E]) -> E {
        -(query[self.projection_coordinate] + self.alpha)
    }
}

pub struct ProductComposition {
    num_variables: usize,
}

impl ProductComposition {
    pub fn new(num_variables: usize) -> Self {
        Self { num_variables }
    }
}

impl<E> CompositionPolynomial<E> for ProductComposition
where
    E: FieldElement,
{
    fn num_variables(&self) -> usize {
        self.num_variables
    }

    fn max_degree(&self) -> usize {
        self.num_variables
    }

    fn evaluate(&self, query: &[E]) -> E {
        query.iter().fold(E::ONE, |acc, x| acc * *x)
    }
}

pub struct SumComposition {
    num_variables: usize,
}

impl SumComposition {
    pub fn new(num_variables: usize) -> Self {
        Self { num_variables }
    }
}

impl<E> CompositionPolynomial<E> for SumComposition
where
    E: FieldElement,
{
    fn num_variables(&self) -> usize {
        self.num_variables
    }

    fn max_degree(&self) -> usize {
        self.num_variables
    }

    fn evaluate(&self, query: &[E]) -> E {
        query.iter().fold(E::ZERO, |acc, x| acc + *x)
    }
}

pub struct GkrCompositionVanilla<E: 'static>
where
    E: FieldElement,
{
    num_variables_ml: usize,
    num_variables_merge: usize,
    combining_randomness: E,
    gkr_randomness: Vec<E>,
}

impl<E> GkrCompositionVanilla<E>
where
    E: FieldElement,
{
    pub fn new(
        num_variables_ml: usize,
        num_variables_merge: usize,
        combining_randomness: E,
        gkr_randomness: Vec<E>,
    ) -> Self {
        Self {
            num_variables_ml,
            num_variables_merge,
            combining_randomness,
            gkr_randomness,
        }
    }
}

impl<E> CompositionPolynomial<E> for GkrCompositionVanilla<E>
where
    E: FieldElement,
{
    fn num_variables(&self) -> usize {
        self.num_variables_ml // + TODO
    }

    fn max_degree(&self) -> usize {
        self.num_variables_ml // TODO
    }

    fn evaluate(&self, query: &[E]) -> E {
        let eval_left_numerator = query[0];
        let eval_right_numerator = query[1];
        let eval_left_denominator = query[2];
        let eval_right_denominator = query[3];
        let eq_eval = query[4];

        eq_eval
            * ((eval_left_numerator * eval_right_denominator
                + eval_right_numerator * eval_left_denominator)
                + eval_left_denominator * eval_right_denominator * self.combining_randomness)
    }
}
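
// Editorial note (not part of the diff): `GkrCompositionVanilla::evaluate`
// computes the same degree-3 combination as the `comb_func` closure in
// src/gkr/circuit/mod.rs: with query = (p_left, p_right, q_left, q_right, eq),
// both return eq * (p_left * q_right + p_right * q_left + rho * q_left * q_right),
// where rho is the 2-to-1 sum-check combining randomness.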

#[derive(Clone)]
pub struct GkrComposition<E>
where
    E: FieldElement<BaseField = BaseElement>,
{
    pub num_variables_ml: usize,
    pub combining_randomness: E,

    eq_composer: Arc<dyn CompositionPolynomial<E>>,
    right_numerator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
    left_numerator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
    right_denominator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
    left_denominator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
}

impl<E> GkrComposition<E>
where
    E: FieldElement<BaseField = BaseElement>,
{
    pub fn new(
        num_variables_ml: usize,
        combining_randomness: E,
        eq_composer: Arc<dyn CompositionPolynomial<E>>,
        right_numerator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
        left_numerator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
        right_denominator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
        left_denominator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
    ) -> Self {
        Self {
            num_variables_ml,
            combining_randomness,
            eq_composer,
            right_numerator_composer,
            left_numerator_composer,
            right_denominator_composer,
            left_denominator_composer,
        }
    }
}

impl<E> CompositionPolynomial<E> for GkrComposition<E>
where
    E: FieldElement<BaseField = BaseElement>,
{
    fn num_variables(&self) -> usize {
        self.num_variables_ml // + TODO
    }

    fn max_degree(&self) -> usize {
        3 // TODO
    }

    fn evaluate(&self, query: &[E]) -> E {
        let eval_right_numerator = self.right_numerator_composer[0].evaluate(query);
        let eval_left_numerator = self.left_numerator_composer[0].evaluate(query);
        let eval_right_denominator = self.right_denominator_composer[0].evaluate(query);
        let eval_left_denominator = self.left_denominator_composer[0].evaluate(query);
        let eq_eval = self.eq_composer.evaluate(query);

        eq_eval
            * ((eval_left_numerator * eval_right_denominator
                + eval_right_numerator * eval_left_denominator)
                + eval_left_denominator * eval_right_denominator * self.combining_randomness)
    }
}

/// Generates a composed ML polynomial for the initial GKR layer from a vector of composition
/// polynomials.
/// The composition polynomials are divided into LeftNumerator, RightNumerator, LeftDenominator
/// and RightDenominator.
/// TODO: Generalize this to the case where each numerator/denominator contains more than one
/// composition polynomial, i.e., a merged composed ML polynomial.
pub fn gkr_composition_from_composition_polys<
    E: FieldElement<BaseField = BaseElement> + 'static,
>(
    composition_polys: &[Vec<Arc<dyn CompositionPolynomial<E>>>],
    combining_randomness: E,
    num_variables: usize,
) -> GkrComposition<E> {
    let eq_composer = Arc::new(ProjectionComposition::new(4));
    let left_numerator = composition_polys[0].to_owned();
    let right_numerator = composition_polys[1].to_owned();
    let left_denominator = composition_polys[2].to_owned();
    let right_denominator = composition_polys[3].to_owned();
    GkrComposition::new(
        num_variables,
        combining_randomness,
        eq_composer,
        right_numerator,
        left_numerator,
        right_denominator,
        left_denominator,
    )
}

/// Generates a plain oracle for all runs of the sum-check protocol except the final one.
pub fn gen_plain_gkr_oracle<E: FieldElement<BaseField = BaseElement> + 'static>(
    num_rounds: usize,
    r_sum_check: E,
) -> ComposedMultiLinearsOracle<E> {
    let gkr_composer = Arc::new(GkrCompositionVanilla::new(num_rounds, 0, r_sum_check, vec![]));

    let ml_oracles = vec![
        MultiLinearOracle { id: 0 },
        MultiLinearOracle { id: 1 },
        MultiLinearOracle { id: 2 },
        MultiLinearOracle { id: 3 },
        MultiLinearOracle { id: 4 },
    ];

    ComposedMultiLinearsOracle {
        composer: gkr_composer,
        multi_linears: ml_oracles,
    }
}

fn to_index<E: FieldElement<BaseField = BaseElement>>(index: &[E]) -> usize {
    let res = index.iter().fold(E::ZERO, |acc, term| acc * E::ONE.double() + (*term));
    let res = res.base_element(0);
    res.as_int() as usize
}

fn inner_product<E: FieldElement>(evaluations: &[E], tensored_query: &[E]) -> E {
    assert_eq!(evaluations.len(), tensored_query.len());
    evaluations
        .iter()
        .zip(tensored_query.iter())
        .fold(E::ZERO, |acc, (x_i, y_i)| acc + *x_i * *y_i)
}

pub fn tensorize<E: FieldElement>(query: &[E]) -> Vec<E> {
    let nu = query.len();
    let n = 1 << nu;

    (0..n).map(|i| lagrange_basis_eval(query, i)).collect()
}

fn lagrange_basis_eval<E: FieldElement>(query: &[E], i: usize) -> E {
    query
        .iter()
        .enumerate()
        .map(|(j, x_j)| if i & (1 << j) == 0 { E::ONE - *x_j } else { *x_j })
        .fold(E::ONE, |acc, v| acc * v)
}
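
// Editorial example (not part of the diff): for query = [x0, x1], `tensorize`
// returns the Lagrange basis evaluations in index order, with bit j of the
// index selecting x_j:
//
//     [(1 - x0) * (1 - x1), x0 * (1 - x1), (1 - x0) * x1, x0 * x1]
//
// so `inner_product(evals, &tensorize(q))` is exactly `MultiLinear::evaluate(q)`.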

pub fn compute_claim<E: FieldElement>(poly: &ComposedMultiLinears<E>) -> E {
    let cube_size = 1 << poly.num_variables_ml();
    let mut res = E::ZERO;

    for i in 0..cube_size {
        let eval_point: Vec<E> =
            poly.multi_linears.iter().map(|poly| poly.evaluations[i]).collect();
        res += poly.composer.evaluate(&eval_point);
    }
    res
}

src/gkr/sumcheck/mod.rs (new file, 108 lines)
@@ -0,0 +1,108 @@
use super::{
    multivariate::{ComposedMultiLinears, ComposedMultiLinearsOracle},
    utils::{barycentric_weights, evaluate_barycentric},
};
use winter_math::FieldElement;

mod prover;
pub use prover::sum_check_prove;
mod verifier;
pub use verifier::{sum_check_verify, sum_check_verify_and_reduce};
mod tests;

#[derive(Debug, Clone)]
pub struct RoundProof<E> {
    pub poly_evals: Vec<E>,
}

impl<E: FieldElement> RoundProof<E> {
    pub fn to_evals(&self, claim: E) -> Vec<E> {
        let mut result = vec![];

        // the first stored evaluation is s(1); s(0) is recovered from s(0) + s(1) = claim
        let c0 = claim - self.poly_evals[0];

        result.push(c0);
        result.extend_from_slice(&self.poly_evals);
        result
    }

    // TODO: refactor once we move to coefficient form
    pub(crate) fn evaluate(&self, claim: E, r: E) -> E {
        let poly_evals = self.to_evals(claim);

        let points: Vec<E> = (0..poly_evals.len()).map(|i| E::from(i as u8)).collect();
        let evalss: Vec<(E, E)> =
            points.iter().zip(poly_evals.iter()).map(|(x, y)| (*x, *y)).collect();
        let weights = barycentric_weights(&evalss);
        evaluate_barycentric(&evalss, r, &weights)
    }
}
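
// Editorial example (not part of the diff): if the round claim is 5 and the
// prover sends poly_evals = [s(1), s(2), s(3)] = [3, 7, 13], then `to_evals`
// recovers s(0) = 5 - 3 = 2 and returns [2, 3, 7, 13], the evaluations of the
// round polynomial at 0, 1, 2, 3 used for barycentric interpolation.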
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PartialProof<E> {
|
||||
pub round_proofs: Vec<RoundProof<E>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FinalEvaluationClaim<E: FieldElement> {
|
||||
pub evaluation_point: Vec<E>,
|
||||
pub claimed_evaluation: E,
|
||||
pub polynomial: ComposedMultiLinearsOracle<E>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FullProof<E: FieldElement> {
|
||||
pub sum_check_proof: PartialProof<E>,
|
||||
pub final_evaluation_claim: FinalEvaluationClaim<E>,
|
||||
}
|
||||
|
||||
pub struct Claim<E: FieldElement> {
|
||||
pub sum_value: E,
|
||||
pub polynomial: ComposedMultiLinearsOracle<E>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RoundClaim<E: FieldElement> {
|
||||
pub partial_eval_point: Vec<E>,
|
||||
pub current_claim: E,
|
||||
}
|
||||
|
||||
pub struct RoundOutput<E: FieldElement> {
|
||||
proof: PartialProof<E>,
|
||||
witness: Witness<E>,
|
||||
}
|
||||
|
||||
impl<E: FieldElement> From<Claim<E>> for RoundClaim<E> {
|
||||
fn from(value: Claim<E>) -> Self {
|
||||
Self {
|
||||
partial_eval_point: vec![],
|
||||
current_claim: value.sum_value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Witness<E: FieldElement> {
|
||||
pub(crate) polynomial: ComposedMultiLinears<E>,
|
||||
}
|
||||
|
||||
pub fn reduce_claim<E: FieldElement>(
|
||||
current_poly: RoundProof<E>,
|
||||
current_round_claim: RoundClaim<E>,
|
||||
round_challenge: E,
|
||||
) -> RoundClaim<E> {
|
||||
let poly_evals = current_poly.to_evals(current_round_claim.current_claim);
|
||||
let points: Vec<E> = (0..poly_evals.len()).map(|i| E::from(i as u8)).collect();
|
||||
let evalss: Vec<(E, E)> = points.iter().zip(poly_evals.iter()).map(|(x, y)| (*x, *y)).collect();
|
||||
let weights = barycentric_weights(&evalss);
|
||||
let new_claim = evaluate_barycentric(&evalss, round_challenge, &weights);
|
||||
|
||||
let mut new_partial_eval_point = current_round_claim.partial_eval_point;
|
||||
new_partial_eval_point.push(round_challenge);
|
||||
|
||||
RoundClaim {
|
||||
partial_eval_point: new_partial_eval_point,
|
||||
current_claim: new_claim,
|
||||
}
|
||||
}
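
// A minimal illustrative sketch of how the pieces above fit together on the
// verifier side: absorb the round polynomial into the transcript, draw a
// challenge, and fold the claim with `reduce_claim`. This mirrors the loop
// bodies in prover.rs and verifier.rs and is not itself part of the API.
#[allow(dead_code)]
fn verify_one_round<E, C, H>(proof: RoundProof<E>, claim: RoundClaim<E>, coin: &mut C) -> RoundClaim<E>
where
    E: FieldElement<BaseField = winter_math::fields::f64::BaseElement>,
    C: winter_crypto::RandomCoin<Hasher = H, BaseField = winter_math::fields::f64::BaseElement>,
    H: winter_crypto::ElementHasher<BaseField = winter_math::fields::f64::BaseElement>,
{
    coin.reseed(H::hash_elements(&proof.poly_evals));
    let r: E = coin.draw().unwrap();
    reduce_claim(proof, claim, r)
}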
109  src/gkr/sumcheck/prover.rs  Normal file
@@ -0,0 +1,109 @@
use super::{Claim, FullProof, RoundProof, Witness};
use crate::gkr::{
    multivariate::{ComposedMultiLinears, ComposedMultiLinearsOracle},
    sumcheck::{reduce_claim, FinalEvaluationClaim, PartialProof, RoundClaim, RoundOutput},
};
use winter_crypto::{ElementHasher, RandomCoin};
use winter_math::{fields::f64::BaseElement, FieldElement};

pub fn sum_check_prove<
    E: FieldElement<BaseField = BaseElement>,
    C: RandomCoin<Hasher = H, BaseField = BaseElement>,
    H: ElementHasher<BaseField = BaseElement>,
>(
    claim: &Claim<E>,
    oracle: ComposedMultiLinearsOracle<E>,
    witness: Witness<E>,
    coin: &mut C,
) -> FullProof<E> {
    // Setup first round
    let mut prev_claim = RoundClaim {
        partial_eval_point: vec![],
        current_claim: claim.sum_value,
    };
    let prev_proof = PartialProof { round_proofs: vec![] };
    let num_vars = witness.polynomial.num_variables_ml();
    let prev_output = RoundOutput { proof: prev_proof, witness };

    let mut output = sumcheck_round(prev_output);
    let poly_evals = &output.proof.round_proofs[0].poly_evals;
    coin.reseed(H::hash_elements(poly_evals));

    for i in 1..num_vars {
        let round_challenge = coin.draw().unwrap();
        let new_claim = reduce_claim(
            output.proof.round_proofs.last().unwrap().clone(),
            prev_claim,
            round_challenge,
        );
        output.witness.polynomial = output.witness.polynomial.bind(round_challenge);

        output = sumcheck_round(output);
        prev_claim = new_claim;

        let poly_evals = &output.proof.round_proofs[i].poly_evals;
        coin.reseed(H::hash_elements(poly_evals));
    }

    let round_challenge = coin.draw().unwrap();
    let RoundClaim { partial_eval_point, current_claim } = reduce_claim(
        output.proof.round_proofs.last().unwrap().clone(),
        prev_claim,
        round_challenge,
    );
    let final_eval_claim = FinalEvaluationClaim {
        evaluation_point: partial_eval_point,
        claimed_evaluation: current_claim,
        polynomial: oracle,
    };

    FullProof {
        sum_check_proof: output.proof,
        final_evaluation_claim: final_eval_claim,
    }
}

fn sumcheck_round<E: FieldElement>(prev_proof: RoundOutput<E>) -> RoundOutput<E> {
    let RoundOutput { mut proof, witness } = prev_proof;

    let polynomial = witness.polynomial;
    let num_ml = polynomial.num_ml();
    let num_vars = polynomial.num_variables_ml();
    let num_rounds = num_vars - 1;

    let mut evals_zero = vec![E::ZERO; num_ml];
    let mut evals_one = vec![E::ZERO; num_ml];
    let mut deltas = vec![E::ZERO; num_ml];
    let mut evals_x = vec![E::ZERO; num_ml];

    let total_evals = (0..1 << num_rounds).map(|i: usize| {
        for (j, ml) in polynomial.multi_linears.iter().enumerate() {
            evals_zero[j] = ml.evaluations[i << 1];
            evals_one[j] = ml.evaluations[(i << 1) + 1];
        }
        let mut total_evals = vec![E::ZERO; polynomial.degree()];
        total_evals[0] = polynomial.composer.evaluate(&evals_one);
        evals_zero
            .iter()
            .zip(evals_one.iter().zip(deltas.iter_mut().zip(evals_x.iter_mut())))
            .for_each(|(a0, (a1, (delta, evx)))| {
                *delta = *a1 - *a0;
                *evx = *a1;
            });
        total_evals.iter_mut().skip(1).for_each(|e| {
            evals_x.iter_mut().zip(deltas.iter()).for_each(|(evx, delta)| {
                *evx += *delta;
            });
            *e = polynomial.composer.evaluate(&evals_x);
        });
        total_evals
    });
    let evaluations = total_evals.fold(vec![E::ZERO; polynomial.degree()], |mut acc, evals| {
        acc.iter_mut().zip(evals.iter()).for_each(|(a, ev)| *a += *ev);
        acc
    });
    let proof_update = RoundProof { poly_evals: evaluations };
    proof.round_proofs.push(proof_update);
    RoundOutput { proof, witness: Witness { polynomial } }
}
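
// A minimal illustrative sketch of the delta trick used above: fixing all
// variables of a multilinear except the first leaves an affine function of x,
// ml(x, w) = ml(0, w) + x * (ml(1, w) - ml(0, w)), so evaluations at
// x = 2, 3, ... follow by repeatedly adding the same delta (assumes
// num_points >= 2).
#[allow(dead_code)]
fn affine_extend<E: FieldElement>(eval_at_0: E, eval_at_1: E, num_points: usize) -> Vec<E> {
    let delta = eval_at_1 - eval_at_0;
    let mut evals = vec![eval_at_0, eval_at_1];
    let mut cur = eval_at_1;
    for _ in 2..num_points {
        cur += delta;
        evals.push(cur);
    }
    evals
}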
199  src/gkr/sumcheck/tests.rs  Normal file
@@ -0,0 +1,199 @@
use alloc::sync::Arc;
use rand::SeedableRng;
use winter_crypto::RandomCoin;
use winter_math::{fields::f64::BaseElement, FieldElement};

use crate::{
    gkr::{
        circuit::{CircuitProof, FractionalSumCircuit},
        multivariate::{
            compute_claim, gkr_composition_from_composition_polys, ComposedMultiLinears,
            ComposedMultiLinearsOracle, CompositionPolynomial, EqPolynomial, GkrComposition,
            GkrCompositionVanilla, LogUpDenominatorTableComposition,
            LogUpDenominatorWitnessComposition, MultiLinear, MultiLinearOracle,
            ProjectionComposition, SumComposition,
        },
        sumcheck::{
            prover::sum_check_prove, verifier::sum_check_verify, Claim, FinalEvaluationClaim,
            FullProof, Witness,
        },
    },
    hash::rpo::Rpo256,
    rand::RpoRandomCoin,
};

#[test]
fn gkr_workflow() {
    // generate the data witness for the LogUp argument
    let mut mls = generate_logup_witness::<BaseElement>(3);

    // alpha is sampled after receiving the main trace commitment
    let alpha = rand_utils::rand_value();

    // the composition polynomials defining the numerators/denominators
    let composition_polys: Vec<Vec<Arc<dyn CompositionPolynomial<BaseElement>>>> = vec![
        // left numerator
        vec![Arc::new(ProjectionComposition::new(0))],
        // right numerator
        vec![Arc::new(ProjectionComposition::new(1))],
        // left denominator
        vec![Arc::new(LogUpDenominatorTableComposition::new(2, alpha))],
        // right denominator
        vec![Arc::new(LogUpDenominatorWitnessComposition::new(3, alpha))],
    ];

    // run the GKR prover to obtain:
    // 1. The fractional sum circuit output.
    // 2. GKR proofs up to the last circuit layer counting backwards.
    // 3. GKR proof (i.e., a sum-check proof) for the last circuit layer counting backwards.
    let seed = [BaseElement::ZERO; 4];
    let mut transcript = RpoRandomCoin::new(seed.into());
    let (circuit_outputs, gkr_before_last_proof, final_layer_proof) =
        CircuitProof::prove_virtual_bus(composition_polys.clone(), &mut mls, &mut transcript);

    let seed = [BaseElement::ZERO; 4];
    let mut transcript = RpoRandomCoin::new(seed.into());

    // run the GKR verifier to obtain:
    // 1. A final evaluation claim.
    // 2. Randomness defining the Lagrange kernel in the final sum-check protocol. Note that this
    //    Lagrange kernel is different from the one used by the STARK (outer) prover to open the
    //    MLs at the evaluation point.
    let (final_eval_claim, gkr_lagrange_kernel_rand) = gkr_before_last_proof.verify_virtual_bus(
        composition_polys.clone(),
        final_layer_proof,
        &circuit_outputs,
        &mut transcript,
    );

    // the final verification step is composed of:
    // 1. Querying the oracles for the openings at the evaluation point. This will be done by the
    //    (outer) STARK prover using:
    //    a. The Lagrange kernel (auxiliary) column at the evaluation point.
    //    b. An extra (auxiliary) column to compute an inner product between two vectors. The
    //       first being the Lagrange kernel and the second being
    //       (\sum_{j=0}^3 mls[j][i] * \lambda_j)_{i \in \{0, .., n\}}.
    // 2. Evaluating the composition polynomial at the previous openings and checking equality
    //    with the claimed evaluation.

    // 1. Querying the oracles

    let FinalEvaluationClaim {
        evaluation_point,
        claimed_evaluation,
        polynomial,
    } = final_eval_claim;

    // The evaluation of the EQ polynomial can be done by the verifier directly
    let eq = (0..gkr_lagrange_kernel_rand.len())
        .map(|i| {
            gkr_lagrange_kernel_rand[i] * evaluation_point[i]
                + (BaseElement::ONE - gkr_lagrange_kernel_rand[i])
                    * (BaseElement::ONE - evaluation_point[i])
        })
        .fold(BaseElement::ONE, |acc, term| acc * term);

    // These are the queries to the oracles.
    // They should be provided by the prover non-deterministically.
    let left_num_eval = mls[0].evaluate(&evaluation_point);
    let right_num_eval = mls[1].evaluate(&evaluation_point);
    let left_den_eval = mls[2].evaluate(&evaluation_point);
    let right_den_eval = mls[3].evaluate(&evaluation_point);

    // The verifier absorbs the claimed openings and generates the batching randomness lambda
    let mut query = vec![left_num_eval, right_num_eval, left_den_eval, right_den_eval];
    transcript.reseed(Rpo256::hash_elements(&query));
    let lambdas: Vec<BaseElement> = vec![
        transcript.draw().unwrap(),
        transcript.draw().unwrap(),
        transcript.draw().unwrap(),
    ];
    let batched_query =
        query[0] + query[1] * lambdas[0] + query[2] * lambdas[1] + query[3] * lambdas[2];

    // The prover generates the Lagrange kernel as an auxiliary column
    let mut rev_evaluation_point = evaluation_point;
    rev_evaluation_point.reverse();
    let lagrange_kernel = EqPolynomial::new(rev_evaluation_point).evaluations();

    // The prover generates the additional auxiliary column for the inner product
    let tmp_col: Vec<BaseElement> = (0..mls[0].len())
        .map(|i| {
            mls[0][i] + mls[1][i] * lambdas[0] + mls[2][i] * lambdas[1] + mls[3][i] * lambdas[2]
        })
        .collect();
    let mut running_sum_col = vec![BaseElement::ZERO; tmp_col.len() + 1];
    running_sum_col[0] = BaseElement::ZERO;
    for i in 1..(tmp_col.len() + 1) {
        running_sum_col[i] = running_sum_col[i - 1] + tmp_col[i - 1] * lagrange_kernel[i - 1];
    }

    // Boundary constraint to check correctness of the openings
    assert_eq!(batched_query, *running_sum_col.last().unwrap());

    // 2. Final evaluation and check
    query.push(eq);
    let verifier_computed = polynomial.composer.evaluate(&query);

    assert_eq!(verifier_computed, claimed_evaluation);
}

pub fn generate_logup_witness<E: FieldElement>(trace_len: usize) -> Vec<MultiLinear<E>> {
    let num_variables_ml = trace_len;
    let num_evaluations = 1 << num_variables_ml;
    let num_witnesses = 1;
    let (p, q) = generate_logup_data::<E>(num_variables_ml, num_witnesses);
    let numerators: Vec<Vec<E>> = p.chunks(num_evaluations).map(|x| x.into()).collect();
    let denominators: Vec<Vec<E>> = q.chunks(num_evaluations).map(|x| x.into()).collect();

    let mut mls = vec![];
    for i in 0..2 {
        let ml = MultiLinear::from_values(&numerators[i]);
        mls.push(ml);
    }
    for i in 0..2 {
        let ml = MultiLinear::from_values(&denominators[i]);
        mls.push(ml);
    }
    mls
}

pub fn generate_logup_data<E: FieldElement>(
    trace_len: usize,
    num_witnesses: usize,
) -> (Vec<E>, Vec<E>) {
    use rand::distributions::Slice;
    use rand::Rng;
    let n: usize = trace_len;
    let num_w: usize = num_witnesses; // This should be of the form 2^k - 1
    let rng = rand::rngs::StdRng::seed_from_u64(0);

    let t_table: Vec<u32> = (0..(1 << n)).collect();
    let mut m_table: Vec<u32> = (0..(1 << n)).map(|_| 0).collect();

    let t_table_slice = Slice::new(&t_table).unwrap();

    // Construct the witness columns. Uses sampling with replacement in order to have
    // multiplicities different from 1.
    let mut w_tables = Vec::new();
    for _ in 0..num_w {
        let wi_table: Vec<u32> =
            rng.clone().sample_iter(&t_table_slice).cloned().take(1 << n).collect();

        // Construct the multiplicities
        wi_table.iter().for_each(|w| {
            m_table[*w as usize] += 1;
        });
        w_tables.push(wi_table)
    }

    // The numerators
    let mut p: Vec<E> = m_table.iter().map(|m| E::from(*m)).collect();
    p.extend((0..(num_w * (1 << n))).map(|_| E::from(1_u32)).collect::<Vec<E>>());

    // Construct the denominators
    let mut q: Vec<E> = t_table.iter().map(|t| E::from(*t)).collect();
    for w_table in w_tables {
        q.extend(w_table.iter().map(|w| E::from(*w)).collect::<Vec<E>>());
    }
    (p, q)
}
71  src/gkr/sumcheck/verifier.rs  Normal file
@@ -0,0 +1,71 @@
use winter_crypto::{ElementHasher, RandomCoin};
use winter_math::{fields::f64::BaseElement, FieldElement};

use crate::gkr::utils::{barycentric_weights, evaluate_barycentric};

use super::{Claim, FinalEvaluationClaim, FullProof, PartialProof};

pub fn sum_check_verify_and_reduce<
    E: FieldElement<BaseField = BaseElement>,
    C: RandomCoin<Hasher = H, BaseField = BaseElement>,
    H: ElementHasher<BaseField = BaseElement>,
>(
    claim: &Claim<E>,
    proofs: PartialProof<E>,
    coin: &mut C,
) -> (E, Vec<E>) {
    let degree = 3;
    let points: Vec<E> = (0..degree + 1).map(|x| E::from(x as u8)).collect();
    let mut sum_value = claim.sum_value;
    let mut randomness = vec![];

    for proof in proofs.round_proofs {
        let partial_evals = proof.poly_evals.clone();
        coin.reseed(H::hash_elements(&partial_evals));

        // get the round challenge r
        let r: E = coin.draw().unwrap();
        randomness.push(r);
        let evals = proof.to_evals(sum_value);

        let point_evals: Vec<_> = points.iter().zip(evals.iter()).map(|(x, y)| (*x, *y)).collect();
        let weights = barycentric_weights(&point_evals);
        sum_value = evaluate_barycentric(&point_evals, r, &weights);
    }
    (sum_value, randomness)
}

pub fn sum_check_verify<
    E: FieldElement<BaseField = BaseElement>,
    C: RandomCoin<Hasher = H, BaseField = BaseElement>,
    H: ElementHasher<BaseField = BaseElement>,
>(
    claim: &Claim<E>,
    proofs: FullProof<E>,
    coin: &mut C,
) -> FinalEvaluationClaim<E> {
    let FullProof {
        sum_check_proof: proofs,
        final_evaluation_claim,
    } = proofs;
    let mut sum_value = claim.sum_value;
    let degree = claim.polynomial.composer.max_degree();
    let points: Vec<E> = (0..degree + 1).map(|x| E::from(x as u8)).collect();

    for proof in proofs.round_proofs {
        let partial_evals = proof.poly_evals.clone();
        coin.reseed(H::hash_elements(&partial_evals));

        // get the round challenge r
        let r: E = coin.draw().unwrap();
        let evals = proof.to_evals(sum_value);

        let point_evals: Vec<_> = points.iter().zip(evals.iter()).map(|(x, y)| (*x, *y)).collect();
        let weights = barycentric_weights(&point_evals);
        sum_value = evaluate_barycentric(&point_evals, r, &weights);
    }

    assert_eq!(final_evaluation_claim.claimed_evaluation, sum_value);

    final_evaluation_claim
}
33  src/gkr/utils/mod.rs  Normal file
@@ -0,0 +1,33 @@
use winter_math::{batch_inversion, FieldElement};

pub fn barycentric_weights<E: FieldElement>(points: &[(E, E)]) -> Vec<E> {
    let n = points.len();
    let tmp = (0..n)
        .map(|i| (0..n).filter(|&j| j != i).fold(E::ONE, |acc, j| acc * (points[i].0 - points[j].0)))
        .collect::<Vec<_>>();
    batch_inversion(&tmp)
}

pub fn evaluate_barycentric<E: FieldElement>(
    points: &[(E, E)],
    x: E,
    barycentric_weights: &[E],
) -> E {
    for &(x_i, y_i) in points {
        if x_i == x {
            return y_i;
        }
    }

    let l_x: E = points.iter().fold(E::ONE, |acc, &(x_i, _y_i)| acc * (x - x_i));

    let sum = (0..points.len()).fold(E::ZERO, |acc, i| {
        let x_i = points[i].0;
        let y_i = points[i].1;
        let w_i = barycentric_weights[i];
        acc + (w_i / (x - x_i) * y_i)
    });

    l_x * sum
}
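
// A minimal illustrative smoke test (not part of the module's API):
// barycentric evaluation returns the unique interpolant through `points`, so
// sampling p(x) = x^2 at x = 0, 1, 2 and evaluating at x = 3 must give 9.
#[allow(dead_code)]
fn barycentric_interpolates_square<E: FieldElement>() {
    let points: Vec<(E, E)> =
        (0u8..3).map(|x| (E::from(x), E::from(x) * E::from(x))).collect();
    let weights = barycentric_weights(&points);
    let three = E::from(3u8);
    assert_eq!(evaluate_barycentric(&points, three, &weights), three * three);
}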
@@ -1,5 +1,8 @@
 use super::{Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField};
-use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
+use crate::utils::{
+    bytes_to_hex_string, hex_to_bytes, string::String, ByteReader, ByteWriter, Deserializable,
+    DeserializationError, HexParseError, Serializable,
+};
 use core::{
     mem::{size_of, transmute, transmute_copy},
     ops::Deref,
@@ -23,7 +26,9 @@ const DIGEST20_BYTES: usize = 20;
 ///
 /// Note: `N` can't be greater than `32` because [`Digest::as_bytes`] currently supports only 32
 /// bytes.
-#[derive(Debug, Copy, Clone, Eq, PartialEq)]
+#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
+#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
+#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))]
 pub struct Blake3Digest<const N: usize>([u8; N]);

 impl<const N: usize> Default for Blake3Digest<N> {
@@ -52,6 +57,20 @@ impl<const N: usize> From<[u8; N]> for Blake3Digest<N> {
     }
 }

+impl<const N: usize> From<Blake3Digest<N>> for String {
+    fn from(value: Blake3Digest<N>) -> Self {
+        bytes_to_hex_string(value.as_bytes())
+    }
+}
+
+impl<const N: usize> TryFrom<&str> for Blake3Digest<N> {
+    type Error = HexParseError;
+
+    fn try_from(value: &str) -> Result<Self, Self::Error> {
+        hex_to_bytes(value).map(|v| v.into())
+    }
+}
+
 impl<const N: usize> Serializable for Blake3Digest<N> {
     fn write_into<W: ByteWriter>(&self, target: &mut W) {
         target.write_bytes(&self.0);
@@ -1,7 +1,17 @@
-use super::{Felt, FieldElement, StarkField, ONE, ZERO};
+//! Cryptographic hash functions used by the Miden VM and the Miden rollup.
+
+use super::{CubeExtension, Felt, FieldElement, StarkField, ONE, ZERO};

 pub mod blake;
-pub mod rpo;
+
+mod rescue;
+pub mod rpo {
+    pub use super::rescue::{Rpo256, RpoDigest};
+}
+
+pub mod rpx {
+    pub use super::rescue::{Rpx256, RpxDigest};
+}

 // RE-EXPORTS
 // ================================================================================================
101  src/hash/rescue/arch/mod.rs  Normal file
@@ -0,0 +1,101 @@
#[cfg(all(target_feature = "sve", feature = "sve"))]
pub mod optimized {
    use crate::hash::rescue::STATE_WIDTH;
    use crate::Felt;

    mod ffi {
        #[link(name = "rpo_sve", kind = "static")]
        extern "C" {
            pub fn add_constants_and_apply_sbox(
                state: *mut std::ffi::c_ulong,
                constants: *const std::ffi::c_ulong,
            ) -> bool;
            pub fn add_constants_and_apply_inv_sbox(
                state: *mut std::ffi::c_ulong,
                constants: *const std::ffi::c_ulong,
            ) -> bool;
        }
    }

    #[inline(always)]
    pub fn add_constants_and_apply_sbox(
        state: &mut [Felt; STATE_WIDTH],
        ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        unsafe {
            ffi::add_constants_and_apply_sbox(
                state.as_mut_ptr() as *mut u64,
                ark.as_ptr() as *const u64,
            )
        }
    }

    #[inline(always)]
    pub fn add_constants_and_apply_inv_sbox(
        state: &mut [Felt; STATE_WIDTH],
        ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        unsafe {
            ffi::add_constants_and_apply_inv_sbox(
                state.as_mut_ptr() as *mut u64,
                ark.as_ptr() as *const u64,
            )
        }
    }
}

#[cfg(target_feature = "avx2")]
mod x86_64_avx2;

#[cfg(target_feature = "avx2")]
pub mod optimized {
    use super::x86_64_avx2::{apply_inv_sbox, apply_sbox};
    use crate::hash::rescue::{add_constants, STATE_WIDTH};
    use crate::Felt;

    #[inline(always)]
    pub fn add_constants_and_apply_sbox(
        state: &mut [Felt; STATE_WIDTH],
        ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        add_constants(state, ark);
        unsafe {
            apply_sbox(std::mem::transmute(state));
        }
        true
    }

    #[inline(always)]
    pub fn add_constants_and_apply_inv_sbox(
        state: &mut [Felt; STATE_WIDTH],
        ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        add_constants(state, ark);
        unsafe {
            apply_inv_sbox(std::mem::transmute(state));
        }
        true
    }
}

#[cfg(not(any(target_feature = "avx2", all(target_feature = "sve", feature = "sve"))))]
pub mod optimized {
    use crate::hash::rescue::STATE_WIDTH;
    use crate::Felt;

    #[inline(always)]
    pub fn add_constants_and_apply_sbox(
        _state: &mut [Felt; STATE_WIDTH],
        _ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        false
    }

    #[inline(always)]
    pub fn add_constants_and_apply_inv_sbox(
        _state: &mut [Felt; STATE_WIDTH],
        _ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        false
    }
}
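
// A minimal illustrative sketch (not part of the module's API): the `bool`
// return value signals whether an accelerated path actually ran, so callers
// fall back to the portable routines when it is `false`. This assumes the
// parent module's private `add_constants` and `apply_sbox` from
// src/hash/rescue/mod.rs.
#[allow(dead_code)]
fn sbox_with_fallback(
    state: &mut [crate::Felt; super::STATE_WIDTH],
    ark: &[crate::Felt; super::STATE_WIDTH],
) {
    if !optimized::add_constants_and_apply_sbox(state, ark) {
        super::add_constants(state, ark);
        super::apply_sbox(state);
    }
}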
325  src/hash/rescue/arch/x86_64_avx2.rs  Normal file
@@ -0,0 +1,325 @@
use core::arch::x86_64::*;

// The following AVX2 implementation has been copied from plonky2:
// https://github.com/0xPolygonZero/plonky2/blob/main/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs

// Preliminary notes:
// 1. AVX does not support addition with carry but 128-bit (2-word) addition can be easily
// emulated. The method recognizes that a + b overflowed iff (a + b) < a:
//     i. res_lo = a_lo + b_lo
//    ii. carry_mask = res_lo < a_lo
//   iii. res_hi = a_hi + b_hi - carry_mask
// Notice that carry_mask is subtracted, not added. This is because AVX comparison instructions
// return -1 (all bits 1) for true and 0 for false.
//
// 2. AVX does not have unsigned 64-bit comparisons. Those can be emulated with signed comparisons
// by recognizing that a <u b iff a + (1 << 63) <s b + (1 << 63), where the addition wraps around
// and the comparisons are unsigned and signed respectively. The shift function adds/subtracts
// 1 << 63 to enable this trick.
// Example: addition with carry.
//     i. a_lo_s = shift(a_lo)
//    ii. res_lo_s = a_lo_s + b_lo
//   iii. carry_mask = res_lo_s <s a_lo_s
//    iv. res_lo = shift(res_lo_s)
//     v. res_hi = a_hi + b_hi - carry_mask
// The suffix _s denotes a value that has been shifted by 1 << 63. The result of addition is
// shifted if exactly one of the operands is shifted, as is the case on line ii. Line iii.
// performs a signed comparison res_lo_s <s a_lo_s on shifted values to emulate unsigned
// comparison res_lo <u a_lo on unshifted values. Finally, line iv. reverses the shift so the
// result can be returned.
// When performing a chain of calculations, we can often save instructions by letting the shift
// propagate through and only undoing it when necessary. For example, to compute the addition of
// three two-word (128-bit) numbers we can do:
//     i. a_lo_s = shift(a_lo)
//    ii. tmp_lo_s = a_lo_s + b_lo
//   iii. tmp_carry_mask = tmp_lo_s <s a_lo_s
//    iv. tmp_hi = a_hi + b_hi - tmp_carry_mask
//     v. res_lo_s = tmp_lo_s + c_lo
//    vi. res_carry_mask = res_lo_s <s tmp_lo_s
//   vii. res_lo = shift(res_lo_s)
//  viii. res_hi = tmp_hi + c_hi - res_carry_mask
// Notice that the above 3-value addition still only requires two calls to shift, just like our
// 2-value addition.
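
// A minimal illustrative sketch of the "shift" trick above, in scalar form:
// flipping the top bit is a wrapping addition of 1 << 63, and it turns an
// unsigned comparison into a signed one (and vice versa).
#[allow(dead_code)]
fn lt_unsigned_via_signed(a: u64, b: u64) -> bool {
    let shift = |x: u64| (x ^ (1 << 63)) as i64; // wrapping add of 1 << 63
    shift(a) < shift(b) // agrees with a < b on the unshifted unsigned values
}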

#[inline(always)]
pub fn branch_hint() {
    // NOTE: These are the currently supported assembly architectures. See the
    // [nightly reference](https://doc.rust-lang.org/nightly/reference/inline-assembly.html) for
    // the most up-to-date list.
    #[cfg(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "riscv32",
        target_arch = "riscv64",
        target_arch = "x86",
        target_arch = "x86_64",
    ))]
    unsafe {
        core::arch::asm!("", options(nomem, nostack, preserves_flags));
    }
}

macro_rules! map3 {
    ($f:ident::<$l:literal>, $v:ident) => {
        ($f::<$l>($v.0), $f::<$l>($v.1), $f::<$l>($v.2))
    };
    ($f:ident::<$l:literal>, $v1:ident, $v2:ident) => {
        ($f::<$l>($v1.0, $v2.0), $f::<$l>($v1.1, $v2.1), $f::<$l>($v1.2, $v2.2))
    };
    ($f:ident, $v:ident) => {
        ($f($v.0), $f($v.1), $f($v.2))
    };
    ($f:ident, $v0:ident, $v1:ident) => {
        ($f($v0.0, $v1.0), $f($v0.1, $v1.1), $f($v0.2, $v1.2))
    };
    ($f:ident, rep $v0:ident, $v1:ident) => {
        ($f($v0, $v1.0), $f($v0, $v1.1), $f($v0, $v1.2))
    };

    ($f:ident, $v0:ident, rep $v1:ident) => {
        ($f($v0.0, $v1), $f($v0.1, $v1), $f($v0.2, $v1))
    };
}

#[inline(always)]
unsafe fn square3(
    x: (__m256i, __m256i, __m256i),
) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
    let x_hi = {
        // Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster
        // than bitshift. This instruction only has a floating-point flavor, so we cast to/from
        // float. This is safe and free.
        let x_ps = map3!(_mm256_castsi256_ps, x);
        let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps);
        map3!(_mm256_castps_si256, x_hi_ps)
    };

    // All pairwise multiplications.
    let mul_ll = map3!(_mm256_mul_epu32, x, x);
    let mul_lh = map3!(_mm256_mul_epu32, x, x_hi);
    let mul_hh = map3!(_mm256_mul_epu32, x_hi, x_hi);

    // Bignum addition, but mul_lh is shifted by 33 bits (not 32).
    let mul_ll_hi = map3!(_mm256_srli_epi64::<33>, mul_ll);
    let t0 = map3!(_mm256_add_epi64, mul_lh, mul_ll_hi);
    let t0_hi = map3!(_mm256_srli_epi64::<31>, t0);
    let res_hi = map3!(_mm256_add_epi64, mul_hh, t0_hi);

    // Form low result by adding the mul_ll and the low 31 bits of mul_lh (shifted to the high
    // position).
    let mul_lh_lo = map3!(_mm256_slli_epi64::<33>, mul_lh);
    let res_lo = map3!(_mm256_add_epi64, mul_ll, mul_lh_lo);

    (res_lo, res_hi)
}

#[inline(always)]
unsafe fn mul3(
    x: (__m256i, __m256i, __m256i),
    y: (__m256i, __m256i, __m256i),
) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
    let epsilon = _mm256_set1_epi64x(0xffffffff);
    let x_hi = {
        // Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster
        // than bitshift. This instruction only has a floating-point flavor, so we cast to/from
        // float. This is safe and free.
        let x_ps = map3!(_mm256_castsi256_ps, x);
        let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps);
        map3!(_mm256_castps_si256, x_hi_ps)
    };
    let y_hi = {
        let y_ps = map3!(_mm256_castsi256_ps, y);
        let y_hi_ps = map3!(_mm256_movehdup_ps, y_ps);
        map3!(_mm256_castps_si256, y_hi_ps)
    };

    // All four pairwise multiplications
    let mul_ll = map3!(_mm256_mul_epu32, x, y);
    let mul_lh = map3!(_mm256_mul_epu32, x, y_hi);
    let mul_hl = map3!(_mm256_mul_epu32, x_hi, y);
    let mul_hh = map3!(_mm256_mul_epu32, x_hi, y_hi);

    // Bignum addition
    // Extract high 32 bits of mul_ll and add to mul_hl. This cannot overflow.
    let mul_ll_hi = map3!(_mm256_srli_epi64::<32>, mul_ll);
    let t0 = map3!(_mm256_add_epi64, mul_hl, mul_ll_hi);
    // Extract low 32 bits of t0 and add to mul_lh. Again, this cannot overflow.
    // Also, extract high 32 bits of t0 and add to mul_hh.
    let t0_lo = map3!(_mm256_and_si256, t0, rep epsilon);
    let t0_hi = map3!(_mm256_srli_epi64::<32>, t0);
    let t1 = map3!(_mm256_add_epi64, mul_lh, t0_lo);
    let t2 = map3!(_mm256_add_epi64, mul_hh, t0_hi);
    // Lastly, extract the high 32 bits of t1 and add to t2.
    let t1_hi = map3!(_mm256_srli_epi64::<32>, t1);
    let res_hi = map3!(_mm256_add_epi64, t2, t1_hi);

    // Form res_lo by combining the low half of mul_ll with the low half of t1 (shifted into high
    // position).
    let t1_lo = {
        let t1_ps = map3!(_mm256_castsi256_ps, t1);
        let t1_lo_ps = map3!(_mm256_moveldup_ps, t1_ps);
        map3!(_mm256_castps_si256, t1_lo_ps)
    };
    let res_lo = map3!(_mm256_blend_epi32::<0xaa>, mul_ll, t1_lo);

    (res_lo, res_hi)
}
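
// A minimal illustrative sketch of the carry propagation in `mul3`: the
// vectorized code is the schoolbook 64 x 64 -> 128 multiplication built from
// four 32 x 32 -> 64 products, whose scalar equivalent is:
#[allow(dead_code)]
fn widening_mul_from_u32_products(x: u64, y: u64) -> (u64, u64) {
    let (x_lo, x_hi) = (x & 0xffff_ffff, x >> 32);
    let (y_lo, y_hi) = (y & 0xffff_ffff, y >> 32);
    let mul_ll = x_lo * y_lo;
    let mul_lh = x_lo * y_hi;
    let mul_hl = x_hi * y_lo;
    let mul_hh = x_hi * y_hi;
    let t0 = mul_hl + (mul_ll >> 32); // cannot overflow
    let t1 = mul_lh + (t0 & 0xffff_ffff); // cannot overflow
    let t2 = mul_hh + (t0 >> 32);
    let hi = t2 + (t1 >> 32);
    let lo = (mul_ll & 0xffff_ffff) | (t1 << 32);
    (lo, hi) // equals (x as u128 * y as u128) split into its low and high words
}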

/// Addition, where the second operand is `0 <= y < 0xffffffff00000001`.
#[inline(always)]
unsafe fn add_small(
    x_s: (__m256i, __m256i, __m256i),
    y: (__m256i, __m256i, __m256i),
) -> (__m256i, __m256i, __m256i) {
    let res_wrapped_s = map3!(_mm256_add_epi64, x_s, y);
    let mask = map3!(_mm256_cmpgt_epi32, x_s, res_wrapped_s);
    let wrapback_amt = map3!(_mm256_srli_epi64::<32>, mask); // EPSILON if overflowed else 0.
    let res_s = map3!(_mm256_add_epi64, res_wrapped_s, wrapback_amt);
    res_s
}

#[inline(always)]
unsafe fn maybe_adj_sub(res_wrapped_s: __m256i, mask: __m256i) -> __m256i {
    // The subtraction is very unlikely to overflow so we're best off branching.
    // The even u32s in `mask` are meaningless, so we want to ignore them. `_mm256_testz_pd`
    // branches depending on the sign bit of double-precision (64-bit) floats. Bit cast `mask` to
    // floating-point (this is free).
    let mask_pd = _mm256_castsi256_pd(mask);
    // `_mm256_testz_pd(mask_pd, mask_pd) == 1` iff all sign bits are 0, meaning that underflow
    // did not occur for any of the vector elements.
    if _mm256_testz_pd(mask_pd, mask_pd) == 1 {
        res_wrapped_s
    } else {
        branch_hint();
        // Highly unlikely: underflow did occur. Find adjustment per element and apply it.
        let adj_amount = _mm256_srli_epi64::<32>(mask); // EPSILON if underflow.
        _mm256_sub_epi64(res_wrapped_s, adj_amount)
    }
}

/// Subtraction, where the second operand is much smaller than `0xffffffff00000001`.
#[inline(always)]
unsafe fn sub_tiny(
    x_s: (__m256i, __m256i, __m256i),
    y: (__m256i, __m256i, __m256i),
) -> (__m256i, __m256i, __m256i) {
    let res_wrapped_s = map3!(_mm256_sub_epi64, x_s, y);
    let mask = map3!(_mm256_cmpgt_epi32, res_wrapped_s, x_s);
    let res_s = map3!(maybe_adj_sub, res_wrapped_s, mask);
    res_s
}

#[inline(always)]
unsafe fn reduce3(
    (lo0, hi0): ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)),
) -> (__m256i, __m256i, __m256i) {
    let sign_bit = _mm256_set1_epi64x(i64::MIN);
    let epsilon = _mm256_set1_epi64x(0xffffffff);
    let lo0_s = map3!(_mm256_xor_si256, lo0, rep sign_bit);
    let hi_hi0 = map3!(_mm256_srli_epi64::<32>, hi0);
    let lo1_s = sub_tiny(lo0_s, hi_hi0);
    let t1 = map3!(_mm256_mul_epu32, hi0, rep epsilon);
    let lo2_s = add_small(lo1_s, t1);
    let lo2 = map3!(_mm256_xor_si256, lo2_s, rep sign_bit);
    lo2
}

#[inline(always)]
unsafe fn mul_reduce(
    a: (__m256i, __m256i, __m256i),
    b: (__m256i, __m256i, __m256i),
) -> (__m256i, __m256i, __m256i) {
    reduce3(mul3(a, b))
}

#[inline(always)]
unsafe fn square_reduce(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
    reduce3(square3(state))
}

#[inline(always)]
unsafe fn exp_acc(
    high: (__m256i, __m256i, __m256i),
    low: (__m256i, __m256i, __m256i),
    exp: usize,
) -> (__m256i, __m256i, __m256i) {
    let mut result = high;
    for _ in 0..exp {
        result = square_reduce(result);
    }
    mul_reduce(result, low)
}

#[inline(always)]
unsafe fn do_apply_sbox(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
    let state2 = square_reduce(state);
    let state4_unreduced = square3(state2);
    let state3_unreduced = mul3(state2, state);
    let state4 = reduce3(state4_unreduced);
    let state3 = reduce3(state3_unreduced);
    let state7_unreduced = mul3(state3, state4);
    let state7 = reduce3(state7_unreduced);
    state7
}

#[inline(always)]
unsafe fn do_apply_inv_sbox(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
    // compute base^10540996611094048183 using 72 multiplications per array element
    // 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111

    // compute base^10
    let t1 = square_reduce(state);

    // compute base^100
    let t2 = square_reduce(t1);

    // compute base^100100
    let t3 = exp_acc(t2, t2, 3);

    // compute base^100100100100
    let t4 = exp_acc(t3, t3, 6);

    // compute base^100100100100100100100100
    let t5 = exp_acc(t4, t4, 12);

    // compute base^100100100100100100100100100100
    let t6 = exp_acc(t5, t3, 6);

    // compute base^1001001001001001001001001001000100100100100100100100100100100
    let t7 = exp_acc(t6, t6, 31);

    // compute base^1001001001001001001001001001000110110110110110110110110110110111
    let a = square_reduce(square_reduce(mul_reduce(square_reduce(t7), t6)));
    let b = mul_reduce(t1, mul_reduce(t2, state));
    mul_reduce(a, b)
}

#[inline(always)]
unsafe fn avx2_load(state: &[u64; 12]) -> (__m256i, __m256i, __m256i) {
    (
        _mm256_loadu_si256((&state[0..4]).as_ptr().cast::<__m256i>()),
        _mm256_loadu_si256((&state[4..8]).as_ptr().cast::<__m256i>()),
        _mm256_loadu_si256((&state[8..12]).as_ptr().cast::<__m256i>()),
    )
}

#[inline(always)]
unsafe fn avx2_store(buf: &mut [u64; 12], state: (__m256i, __m256i, __m256i)) {
    _mm256_storeu_si256((&mut buf[0..4]).as_mut_ptr().cast::<__m256i>(), state.0);
    _mm256_storeu_si256((&mut buf[4..8]).as_mut_ptr().cast::<__m256i>(), state.1);
    _mm256_storeu_si256((&mut buf[8..12]).as_mut_ptr().cast::<__m256i>(), state.2);
}

#[inline(always)]
pub unsafe fn apply_sbox(buffer: &mut [u64; 12]) {
    let mut state = avx2_load(buffer);
    state = do_apply_sbox(state);
    avx2_store(buffer, state);
}

#[inline(always)]
pub unsafe fn apply_inv_sbox(buffer: &mut [u64; 12]) {
    let mut state = avx2_load(buffer);
    state = do_apply_inv_sbox(state);
    avx2_store(buffer, state);
}
@@ -11,7 +11,8 @@
 /// divisions by 2 and repeated modular reductions. This is because of our explicit choice of
 /// an MDS matrix that has small powers of 2 entries in frequency domain.
 /// The following implementation has benefited greatly from the discussions and insights of
-/// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero.
+/// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero and is based on Nabaglo's Plonky2
+/// implementation.

 // Rescue MDS matrix in frequency domain.
 // More precisely, this is the output of the three 4-point (real) FFTs of the first column of
@@ -26,7 +27,7 @@ const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-8, 1, 1];

 // We use split 3 x 4 FFT transform in order to transform our vectors into the frequency domain.
 #[inline(always)]
-pub(crate) const fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] {
+pub const fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] {
     let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] = state;

     let (u0, u1, u2) = fft4_real([s0, s3, s6, s9]);
@@ -156,14 +157,14 @@ const fn block3(x: [i64; 3], y: [i64; 3]) -> [i64; 3] {

 #[cfg(test)]
 mod tests {
-    use super::super::{Felt, FieldElement, Rpo256, MDS};
+    use super::super::{apply_mds, Felt, MDS, ZERO};
     use proptest::prelude::*;

     const STATE_WIDTH: usize = 12;

     #[inline(always)]
     fn apply_mds_naive(state: &mut [Felt; STATE_WIDTH]) {
-        let mut result = [Felt::ZERO; STATE_WIDTH];
+        let mut result = [ZERO; STATE_WIDTH];
         result.iter_mut().zip(MDS).for_each(|(r, mds_row)| {
             state.iter().zip(mds_row).for_each(|(&s, m)| {
                 *r += m * s;
@@ -174,9 +175,9 @@ mod tests {

     proptest! {
         #[test]
-        fn mds_freq_proptest(a in any::<[u64;STATE_WIDTH]>()) {
+        fn mds_freq_proptest(a in any::<[u64; STATE_WIDTH]>()) {

-            let mut v1 = [Felt::ZERO;STATE_WIDTH];
+            let mut v1 = [ZERO; STATE_WIDTH];
             let mut v2;

             for i in 0..STATE_WIDTH {
@@ -185,7 +186,7 @@ mod tests {
             v2 = v1;

             apply_mds_naive(&mut v1);
-            Rpo256::apply_mds(&mut v2);
+            apply_mds(&mut v2);

             prop_assert_eq!(v1, v2);
         }
214  src/hash/rescue/mds/mod.rs  Normal file
@@ -0,0 +1,214 @@
use super::{Felt, STATE_WIDTH, ZERO};

mod freq;
pub use freq::mds_multiply_freq;

// MDS MULTIPLICATION
// ================================================================================================

#[inline(always)]
pub fn apply_mds(state: &mut [Felt; STATE_WIDTH]) {
    let mut result = [ZERO; STATE_WIDTH];

    // Using the linearity of the operations we can split the state into a low||high decomposition
    // and operate on each with no overflow and then combine/reduce the result to a field element.
    // The absence of overflow is guaranteed by the fact that the MDS matrix entries are small
    // powers of two in the frequency domain.
    let mut state_l = [0u64; STATE_WIDTH];
    let mut state_h = [0u64; STATE_WIDTH];

    for r in 0..STATE_WIDTH {
        let s = state[r].inner();
        state_h[r] = s >> 32;
        state_l[r] = (s as u32) as u64;
    }

    let state_h = mds_multiply_freq(state_h);
    let state_l = mds_multiply_freq(state_l);

    for r in 0..STATE_WIDTH {
        let s = state_l[r] as u128 + ((state_h[r] as u128) << 32);
        let s_hi = (s >> 64) as u64;
        let s_lo = s as u64;
        let z = (s_hi << 32) - s_hi;
        let (res, over) = s_lo.overflowing_add(z);

        result[r] = Felt::from_mont(res.wrapping_add(0u32.wrapping_sub(over as u32) as u64));
    }
    *state = result;
}
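
// A minimal illustrative sketch of the recombination step above: it reduces a
// 96-bit value s = s_lo + 2^64 * s_hi modulo p = 2^64 - 2^32 + 1 using
// 2^64 = 2^32 - 1 (mod p). The result is a valid (not necessarily canonical)
// u64 representative, as accepted by `Felt::from_mont`.
#[allow(dead_code)]
fn reduce_96_bit(s: u128) -> u64 {
    let s_lo = s as u64;
    let s_hi = (s >> 64) as u64; // < 2^32 for a 96-bit input
    let z = (s_hi << 32) - s_hi; // (2^32 - 1) * s_hi
    let (res, over) = s_lo.overflowing_add(z);
    // if the addition wrapped, compensate for the lost 2^64 by adding 2^32 - 1
    res.wrapping_add(0u32.wrapping_sub(over as u32) as u64)
}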

// MDS MATRIX
// ================================================================================================

/// RPO MDS matrix
pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = [
    [
        Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10),
        Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8),
    ],
    [
        Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13),
        Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21),
    ],
    [
        Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26),
        Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22),
    ],
    [
        Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8),
        Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6),
    ],
    [
        Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23),
        Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7),
    ],
    [
        Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7),
        Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9),
    ],
    [
        Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8),
        Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10),
    ],
    [
        Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21),
        Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13),
    ],
    [
        Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22),
        Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26),
    ],
    [
        Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6),
        Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8),
    ],
    [
        Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7),
        Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23),
    ],
    [
        Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9),
        Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7),
    ],
];
348  src/hash/rescue/mod.rs  Normal file
@@ -0,0 +1,348 @@
use super::{
    CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ONE, ZERO,
};
use core::ops::Range;

mod arch;
pub use arch::optimized::{add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox};

mod mds;
use mds::{apply_mds, MDS};

mod rpo;
pub use rpo::{Rpo256, RpoDigest};

mod rpx;
pub use rpx::{Rpx256, RpxDigest};

#[cfg(test)]
mod tests;

// CONSTANTS
// ================================================================================================

/// The number of rounds is set to 7. For the RPO hash function all rounds are uniform. For the
/// RPX hash function, there are 3 different types of rounds.
const NUM_ROUNDS: usize = 7;

/// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
/// the remaining 4 elements are reserved for capacity.
const STATE_WIDTH: usize = 12;

/// The rate portion of the state is located in elements 4 through 11.
const RATE_RANGE: Range<usize> = 4..12;
const RATE_WIDTH: usize = RATE_RANGE.end - RATE_RANGE.start;

const INPUT1_RANGE: Range<usize> = 4..8;
const INPUT2_RANGE: Range<usize> = 8..12;

/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
const CAPACITY_RANGE: Range<usize> = 0..4;

/// The output of the hash function is a digest which consists of 4 field elements or 32 bytes.
///
/// The digest is returned from state elements 4, 5, 6, and 7 (the first four elements of the
/// rate portion).
const DIGEST_RANGE: Range<usize> = 4..8;
const DIGEST_SIZE: usize = DIGEST_RANGE.end - DIGEST_RANGE.start;

/// The number of bytes needed to encode a digest.
const DIGEST_BYTES: usize = 32;

/// The number of byte chunks defining a field element when hashing a sequence of bytes.
const BINARY_CHUNK_SIZE: usize = 7;

/// S-Box and Inverse S-Box powers;
///
/// The constants are defined for tests only because the exponentiations in the code are unrolled
/// for efficiency reasons.
#[cfg(test)]
const ALPHA: u64 = 7;
#[cfg(test)]
const INV_ALPHA: u64 = 10540996611094048183;

// SBOX FUNCTION
// ================================================================================================

#[inline(always)]
fn apply_sbox(state: &mut [Felt; STATE_WIDTH]) {
    state[0] = state[0].exp7();
    state[1] = state[1].exp7();
    state[2] = state[2].exp7();
    state[3] = state[3].exp7();
    state[4] = state[4].exp7();
    state[5] = state[5].exp7();
    state[6] = state[6].exp7();
    state[7] = state[7].exp7();
    state[8] = state[8].exp7();
    state[9] = state[9].exp7();
    state[10] = state[10].exp7();
    state[11] = state[11].exp7();
}
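
// A minimal illustrative sketch of what the unrolled calls above compute: the
// S-box raises every state element to the power ALPHA = 7, which `exp7` can do
// with two squarings and two multiplications.
#[allow(dead_code)]
fn pow7(x: Felt) -> Felt {
    let x2 = x.square(); // x^2
    let x4 = x2.square(); // x^4
    x4 * x2 * x // x^7
}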

// INVERSE SBOX FUNCTION
// ================================================================================================

#[inline(always)]
fn apply_inv_sbox(state: &mut [Felt; STATE_WIDTH]) {
    // compute base^10540996611094048183 using 72 multiplications per array element
    // 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111

    // compute base^10
    let mut t1 = *state;
    t1.iter_mut().for_each(|t| *t = t.square());

    // compute base^100
    let mut t2 = t1;
    t2.iter_mut().for_each(|t| *t = t.square());

    // compute base^100100
    let t3 = exp_acc::<Felt, STATE_WIDTH, 3>(t2, t2);

    // compute base^100100100100
    let t4 = exp_acc::<Felt, STATE_WIDTH, 6>(t3, t3);

    // compute base^100100100100100100100100
    let t5 = exp_acc::<Felt, STATE_WIDTH, 12>(t4, t4);

    // compute base^100100100100100100100100100100
    let t6 = exp_acc::<Felt, STATE_WIDTH, 6>(t5, t3);

    // compute base^1001001001001001001001001001000100100100100100100100100100100
    let t7 = exp_acc::<Felt, STATE_WIDTH, 31>(t6, t6);

    // compute base^1001001001001001001001001001000110110110110110110110110110110111
    for (i, s) in state.iter_mut().enumerate() {
        let a = (t7[i].square() * t6[i]).square().square();
        let b = t1[i] * t2[i] * *s;
        *s = a * b;
    }

    #[inline(always)]
    fn exp_acc<B: StarkField, const N: usize, const M: usize>(
        base: [B; N],
        tail: [B; N],
    ) -> [B; N] {
        let mut result = base;
        for _ in 0..M {
            result.iter_mut().for_each(|r| *r = r.square());
        }
        result.iter_mut().zip(tail).for_each(|(r, t)| *r *= t);
        result
    }
}
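
// A quick arithmetic check of the addition chain above: INV_ALPHA is the
// inverse of ALPHA = 7 modulo p - 1 (p = 2^64 - 2^32 + 1), so the inverse
// S-box undoes the S-box on all nonzero field elements.
#[allow(dead_code)]
fn inv_alpha_is_inverse_of_alpha() -> bool {
    const P_MINUS_1: u128 = 0xffff_ffff_0000_0001_u128 - 1;
    (7 * 10540996611094048183_u128) % P_MINUS_1 == 1
}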

#[inline(always)]
fn add_constants(state: &mut [Felt; STATE_WIDTH], ark: &[Felt; STATE_WIDTH]) {
    state.iter_mut().zip(ark).for_each(|(s, &k)| *s += k);
}

// ROUND CONSTANTS
// ================================================================================================

/// Rescue round constants;
/// computed as in [specifications](https://github.com/ASDiscreteMathematics/rpo)
///
/// The constants are broken up into two arrays ARK1 and ARK2; ARK1 contains the constants for the
/// first half of an RPO round, and ARK2 contains the constants for the second half of an RPO
/// round.
const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
    [
        Felt::new(5789762306288267392), Felt::new(6522564764413701783),
        Felt::new(17809893479458208203), Felt::new(107145243989736508),
        Felt::new(6388978042437517382), Felt::new(15844067734406016715),
        Felt::new(9975000513555218239), Felt::new(3344984123768313364),
        Felt::new(9959189626657347191), Felt::new(12960773468763563665),
        Felt::new(9602914297752488475), Felt::new(16657542370200465908),
    ],
    [
        Felt::new(12987190162843096997), Felt::new(653957632802705281),
        Felt::new(4441654670647621225), Felt::new(4038207883745915761),
        Felt::new(5613464648874830118), Felt::new(13222989726778338773),
        Felt::new(3037761201230264149), Felt::new(16683759727265180203),
        Felt::new(8337364536491240715), Felt::new(3227397518293416448),
        Felt::new(8110510111539674682), Felt::new(2872078294163232137),
    ],
    [
        Felt::new(18072785500942327487), Felt::new(6200974112677013481),
        Felt::new(17682092219085884187), Felt::new(10599526828986756440),
        Felt::new(975003873302957338), Felt::new(8264241093196931281),
        Felt::new(10065763900435475170), Felt::new(2181131744534710197),
        Felt::new(6317303992309418647), Felt::new(1401440938888741532),
        Felt::new(8884468225181997494), Felt::new(13066900325715521532),
    ],
    [
        Felt::new(5674685213610121970), Felt::new(5759084860419474071),
        Felt::new(13943282657648897737), Felt::new(1352748651966375394),
        Felt::new(17110913224029905221), Felt::new(1003883795902368422),
        Felt::new(4141870621881018291), Felt::new(8121410972417424656),
        Felt::new(14300518605864919529), Felt::new(13712227150607670181),
        Felt::new(17021852944633065291), Felt::new(6252096473787587650),
    ],
    [
        Felt::new(4887609836208846458), Felt::new(3027115137917284492),
        Felt::new(9595098600469470675), Felt::new(10528569829048484079),
        Felt::new(7864689113198939815), Felt::new(17533723827845969040),
        Felt::new(5781638039037710951), Felt::new(17024078752430719006),
        Felt::new(109659393484013511), Felt::new(7158933660534805869),
        Felt::new(2955076958026921730), Felt::new(7433723648458773977),
    ],
    [
        Felt::new(16308865189192447297), Felt::new(11977192855656444890),
        Felt::new(12532242556065780287), Felt::new(14594890931430968898),
        Felt::new(7291784239689209784), Felt::new(5514718540551361949),
        Felt::new(10025733853830934803), Felt::new(7293794580341021693),
        Felt::new(6728552937464861756), Felt::new(6332385040983343262),
        Felt::new(13277683694236792804), Felt::new(2600778905124452676),
    ],
    [
        Felt::new(7123075680859040534), Felt::new(1034205548717903090),
        Felt::new(7717824418247931797), Felt::new(3019070937878604058),
        Felt::new(11403792746066867460), Felt::new(10280580802233112374),
        Felt::new(337153209462421218), Felt::new(13333398568519923717),
        Felt::new(3596153696935337464), Felt::new(8104208463525993784),
        Felt::new(14345062289456085693), Felt::new(17036731477169661256),
    ],
];

const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
    [
        Felt::new(6077062762357204287), Felt::new(15277620170502011191),
        Felt::new(5358738125714196705), Felt::new(14233283787297595718),
        Felt::new(13792579614346651365), Felt::new(11614812331536767105),
        Felt::new(14871063686742261166), Felt::new(10148237148793043499),
        Felt::new(4457428952329675767), Felt::new(15590786458219172475),
        Felt::new(10063319113072092615), Felt::new(14200078843431360086),
    ],
    [
        Felt::new(6202948458916099932), Felt::new(17690140365333231091),
        Felt::new(3595001575307484651), Felt::new(373995945117666487),
        Felt::new(1235734395091296013), Felt::new(14172757457833931602),
        Felt::new(707573103686350224), Felt::new(15453217512188187135),
        Felt::new(219777875004506018), Felt::new(17876696346199469008),
        Felt::new(17731621626449383378), Felt::new(2897136237748376248),
    ],
    [
        Felt::new(8023374565629191455), Felt::new(15013690343205953430),
        Felt::new(4485500052507912973), Felt::new(12489737547229155153),
        Felt::new(9500452585969030576), Felt::new(2054001340201038870),
        Felt::new(12420704059284934186), Felt::new(355990932618543755),
        Felt::new(9071225051243523860), Felt::new(12766199826003448536),
        Felt::new(9045979173463556963), Felt::new(12934431667190679898),
    ],
    [
        Felt::new(18389244934624494276), Felt::new(16731736864863925227),
        Felt::new(4440209734760478192), Felt::new(17208448209698888938),
        Felt::new(8739495587021565984), Felt::new(17000774922218161967),
        Felt::new(13533282547195532087), Felt::new(525402848358706231),
        Felt::new(16987541523062161972), Felt::new(5466806524462797102),
        Felt::new(14512769585918244983), Felt::new(10973956031244051118),
    ],
    [
        Felt::new(6982293561042362913), Felt::new(14065426295947720331),
        Felt::new(16451845770444974180), Felt::new(7139138592091306727),
        Felt::new(9012006439959783127), Felt::new(14619614108529063361),
        Felt::new(1394813199588124371), Felt::new(4635111139507788575),
        Felt::new(16217473952264203365), Felt::new(10782018226466330683),
        Felt::new(6844229992533662050), Felt::new(7446486531695178711),
    ],
    [
        Felt::new(3736792340494631448), Felt::new(577852220195055341),
        Felt::new(6689998335515779805), Felt::new(13886063479078013492),
        Felt::new(14358505101923202168), Felt::new(7744142531772274164),
        Felt::new(16135070735728404443), Felt::new(12290902521256031137),
        Felt::new(12059913662657709804), Felt::new(16456018495793751911),
        Felt::new(4571485474751953524), Felt::new(17200392109565783176),
    ],
    [
        Felt::new(17130398059294018733), Felt::new(519782857322261988),
        Felt::new(9625384390925085478), Felt::new(1664893052631119222),
        Felt::new(7629576092524553570), Felt::new(3485239601103661425),
        Felt::new(9755891797164033838), Felt::new(15218148195153269027),
        Felt::new(16460604813734957368), Felt::new(9643968136937729763),
        Felt::new(3611348709641382851), Felt::new(18256379591337759196),
    ],
];
|
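For illustration (not part of the diff): each round of the permutation adds one row of ARK1 or ARK2 to the sponge state element-wise. A minimal sketch of that step, with a hypothetical name (the diff's real `add_constants` helper lives in the shared rescue module and may differ):

// sketch: add one round's constants to the state, element-wise
fn add_round_constants(state: &mut [Felt; STATE_WIDTH], ark: &[Felt; STATE_WIDTH]) {
    // Felt implements AddAssign in winterfell's 64-bit field
    for (s, k) in state.iter_mut().zip(ark.iter()) {
        *s += *k;
    }
}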
408
src/hash/rescue/rpo/digest.rs
Normal file
@@ -0,0 +1,408 @@
use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO};
use crate::utils::{
    bytes_to_hex_string, hex_to_bytes, string::String, ByteReader, ByteWriter, Deserializable,
    DeserializationError, HexParseError, Serializable,
};
use core::{cmp::Ordering, fmt::Display, ops::Deref};
use winter_utils::Randomizable;

// DIGEST TRAIT IMPLEMENTATIONS
// ================================================================================================

#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))]
pub struct RpoDigest([Felt; DIGEST_SIZE]);

impl RpoDigest {
    pub const fn new(value: [Felt; DIGEST_SIZE]) -> Self {
        Self(value)
    }

    pub fn as_elements(&self) -> &[Felt] {
        self.as_ref()
    }

    pub fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
        <Self as Digest>::as_bytes(self)
    }

    pub fn digests_as_elements<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt>
    where
        I: Iterator<Item = &'a Self>,
    {
        digests.flat_map(|d| d.0.iter())
    }
}

impl Digest for RpoDigest {
    fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
        let mut result = [0; DIGEST_BYTES];

        result[..8].copy_from_slice(&self.0[0].as_int().to_le_bytes());
        result[8..16].copy_from_slice(&self.0[1].as_int().to_le_bytes());
        result[16..24].copy_from_slice(&self.0[2].as_int().to_le_bytes());
        result[24..].copy_from_slice(&self.0[3].as_int().to_le_bytes());

        result
    }
}

impl Deref for RpoDigest {
    type Target = [Felt; DIGEST_SIZE];

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl Ord for RpoDigest {
    fn cmp(&self, other: &Self) -> Ordering {
        // compare the inner u64 of both elements.
        //
        // it will iterate over the elements and return the first comparison different from
        // `Equal`. Otherwise, the ordering is equal.
        //
        // the endianness is irrelevant here because, this being a cryptographically secure
        // hash computation, the digest shouldn't have any ordered property of its input.
        //
        // finally, we use `Felt::inner` instead of `Felt::as_int` so we avoid performing a
        // montgomery reduction for every limb. that is safe because every inner element of the
        // digest is guaranteed to be in its canonical form (that is, `x in [0,p)`).
        self.0.iter().map(Felt::inner).zip(other.0.iter().map(Felt::inner)).fold(
            Ordering::Equal,
            |ord, (a, b)| match ord {
                Ordering::Equal => a.cmp(&b),
                _ => ord,
            },
        )
    }
}

impl PartialOrd for RpoDigest {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Display for RpoDigest {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let encoded: String = self.into();
        write!(f, "{}", encoded)?;
        Ok(())
    }
}

impl Randomizable for RpoDigest {
    const VALUE_SIZE: usize = DIGEST_BYTES;

    fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
        let bytes_array: Option<[u8; 32]> = bytes.try_into().ok();
        if let Some(bytes_array) = bytes_array {
            Self::try_from(bytes_array).ok()
        } else {
            None
        }
    }
}

// CONVERSIONS: FROM RPO DIGEST
// ================================================================================================

impl From<&RpoDigest> for [Felt; DIGEST_SIZE] {
    fn from(value: &RpoDigest) -> Self {
        value.0
    }
}

impl From<RpoDigest> for [Felt; DIGEST_SIZE] {
    fn from(value: RpoDigest) -> Self {
        value.0
    }
}

impl From<&RpoDigest> for [u64; DIGEST_SIZE] {
    fn from(value: &RpoDigest) -> Self {
        [
            value.0[0].as_int(),
            value.0[1].as_int(),
            value.0[2].as_int(),
            value.0[3].as_int(),
        ]
    }
}

impl From<RpoDigest> for [u64; DIGEST_SIZE] {
    fn from(value: RpoDigest) -> Self {
        [
            value.0[0].as_int(),
            value.0[1].as_int(),
            value.0[2].as_int(),
            value.0[3].as_int(),
        ]
    }
}

impl From<&RpoDigest> for [u8; DIGEST_BYTES] {
    fn from(value: &RpoDigest) -> Self {
        value.as_bytes()
    }
}

impl From<RpoDigest> for [u8; DIGEST_BYTES] {
    fn from(value: RpoDigest) -> Self {
        value.as_bytes()
    }
}

impl From<RpoDigest> for String {
    /// The returned string starts with `0x`.
    fn from(value: RpoDigest) -> Self {
        bytes_to_hex_string(value.as_bytes())
    }
}

impl From<&RpoDigest> for String {
    /// The returned string starts with `0x`.
    fn from(value: &RpoDigest) -> Self {
        (*value).into()
    }
}

// CONVERSIONS: TO RPO DIGEST
// ================================================================================================

#[derive(Copy, Clone, Debug)]
pub enum RpoDigestError {
    /// The provided u64 integer does not fit in the field's modulus.
    InvalidInteger,
}

impl From<&[Felt; DIGEST_SIZE]> for RpoDigest {
    fn from(value: &[Felt; DIGEST_SIZE]) -> Self {
        Self(*value)
    }
}

impl From<[Felt; DIGEST_SIZE]> for RpoDigest {
    fn from(value: [Felt; DIGEST_SIZE]) -> Self {
        Self(value)
    }
}

impl TryFrom<[u8; DIGEST_BYTES]> for RpoDigest {
    type Error = HexParseError;

    fn try_from(value: [u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
        // Note: the input length is known, the conversion from slice to array must succeed so the
        // `unwrap`s below are safe
        let a = u64::from_le_bytes(value[0..8].try_into().unwrap());
        let b = u64::from_le_bytes(value[8..16].try_into().unwrap());
        let c = u64::from_le_bytes(value[16..24].try_into().unwrap());
        let d = u64::from_le_bytes(value[24..32].try_into().unwrap());

        if [a, b, c, d].iter().any(|v| *v >= Felt::MODULUS) {
            return Err(HexParseError::OutOfRange);
        }

        Ok(RpoDigest([Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)]))
    }
}

impl TryFrom<&[u8; DIGEST_BYTES]> for RpoDigest {
    type Error = HexParseError;

    fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<&[u8]> for RpoDigest {
    type Error = HexParseError;

    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<[u64; DIGEST_SIZE]> for RpoDigest {
    type Error = RpoDigestError;

    fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
        if value[0] >= Felt::MODULUS
            || value[1] >= Felt::MODULUS
            || value[2] >= Felt::MODULUS
            || value[3] >= Felt::MODULUS
        {
            return Err(RpoDigestError::InvalidInteger);
        }

        Ok(Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()]))
    }
}

impl TryFrom<&[u64; DIGEST_SIZE]> for RpoDigest {
    type Error = RpoDigestError;

    fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
        (*value).try_into()
    }
}

impl TryFrom<&str> for RpoDigest {
    type Error = HexParseError;

    /// Expects the string to start with `0x`.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        hex_to_bytes(value).and_then(|v| v.try_into())
    }
}

impl TryFrom<String> for RpoDigest {
    type Error = HexParseError;

    /// Expects the string to start with `0x`.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        value.as_str().try_into()
    }
}

impl TryFrom<&String> for RpoDigest {
    type Error = HexParseError;

    /// Expects the string to start with `0x`.
    fn try_from(value: &String) -> Result<Self, Self::Error> {
        value.as_str().try_into()
    }
}

// SERIALIZATION / DESERIALIZATION
// ================================================================================================

impl Serializable for RpoDigest {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_bytes(&self.as_bytes());
    }
}

impl Deserializable for RpoDigest {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let mut inner: [Felt; DIGEST_SIZE] = [ZERO; DIGEST_SIZE];
        for inner in inner.iter_mut() {
            let e = source.read_u64()?;
            if e >= Felt::MODULUS {
                return Err(DeserializationError::InvalidValue(String::from(
                    "Value not in the appropriate range",
                )));
            }
            *inner = Felt::new(e);
        }

        Ok(Self(inner))
    }
}

// ITERATORS
// ================================================================================================
impl IntoIterator for RpoDigest {
    type Item = Felt;
    type IntoIter = <[Felt; 4] as IntoIterator>::IntoIter;

    fn into_iter(self) -> Self::IntoIter {
        self.0.into_iter()
    }
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::{Deserializable, Felt, RpoDigest, Serializable, DIGEST_BYTES, DIGEST_SIZE};
    use crate::utils::{string::String, SliceReader};
    use rand_utils::rand_value;

    #[test]
    fn digest_serialization() {
        let e1 = Felt::new(rand_value());
        let e2 = Felt::new(rand_value());
        let e3 = Felt::new(rand_value());
        let e4 = Felt::new(rand_value());

        let d1 = RpoDigest([e1, e2, e3, e4]);

        let mut bytes = vec![];
        d1.write_into(&mut bytes);
        assert_eq!(DIGEST_BYTES, bytes.len());

        let mut reader = SliceReader::new(&bytes);
        let d2 = RpoDigest::read_from(&mut reader).unwrap();

        assert_eq!(d1, d2);
    }

    #[test]
    fn digest_encoding() {
        let digest = RpoDigest([
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
        ]);

        let string: String = digest.into();
        let round_trip: RpoDigest = string.try_into().expect("decoding failed");

        assert_eq!(digest, round_trip);
    }

    #[test]
    fn test_conversions() {
        let digest = RpoDigest([
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
        ]);

        let v: [Felt; DIGEST_SIZE] = digest.into();
        let v2: RpoDigest = v.into();
        assert_eq!(digest, v2);

        let v: [Felt; DIGEST_SIZE] = (&digest).into();
        let v2: RpoDigest = v.into();
        assert_eq!(digest, v2);

        let v: [u64; DIGEST_SIZE] = digest.into();
        let v2: RpoDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u64; DIGEST_SIZE] = (&digest).into();
        let v2: RpoDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = digest.into();
        let v2: RpoDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = (&digest).into();
        let v2: RpoDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: String = digest.into();
        let v2: RpoDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: String = (&digest).into();
        let v2: RpoDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = digest.into();
        let v2: RpoDigest = (&v).try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = (&digest).into();
        let v2: RpoDigest = (&v).try_into().unwrap();
        assert_eq!(digest, v2);
    }
}
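For illustration (not part of the diff), the conversions above allow hex-string and byte-array round-trips; the import path is an assumption based on the crate's public re-exports:

// sketch: round-trip an RPO digest through a hex string and a byte array
use miden_crypto::hash::rpo::{Rpo256, RpoDigest};

let d: RpoDigest = Rpo256::hash(b"hello");

let s: String = d.into();                  // starts with "0x"
let d2: RpoDigest = s.as_str().try_into().expect("valid hex");
assert_eq!(d, d2);

let bytes: [u8; 32] = d.into();            // 32 little-endian bytes
let d3: RpoDigest = bytes.try_into().expect("canonical field elements");
assert_eq!(d, d3);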
323
src/hash/rescue/rpo/mod.rs
Normal file
@@ -0,0 +1,323 @@
use super::{
    add_constants, add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox, apply_inv_sbox,
    apply_mds, apply_sbox, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ARK1,
    ARK2, BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE, DIGEST_SIZE, INPUT1_RANGE,
    INPUT2_RANGE, MDS, NUM_ROUNDS, ONE, RATE_RANGE, RATE_WIDTH, STATE_WIDTH, ZERO,
};
use core::{convert::TryInto, ops::Range};

mod digest;
pub use digest::RpoDigest;

#[cfg(test)]
mod tests;

// HASHER IMPLEMENTATION
// ================================================================================================

/// Implementation of the Rescue Prime Optimized hash function with 256-bit output.
///
/// The hash function is implemented according to the Rescue Prime Optimized
/// [specifications](https://eprint.iacr.org/2022/1577).
///
/// The parameters used to instantiate the function are:
/// * Field: 64-bit prime field with modulus 2^64 - 2^32 + 1.
/// * State width: 12 field elements.
/// * Capacity size: 4 field elements.
/// * Number of rounds: 7.
/// * S-Box degree: 7.
///
/// The above parameters target a 128-bit security level. The digest consists of four field
/// elements and it can be serialized into 32 bytes (256 bits).
///
/// ## Hash output consistency
/// Functions [hash_elements()](Rpo256::hash_elements), [merge()](Rpo256::merge), and
/// [merge_with_int()](Rpo256::merge_with_int) are internally consistent. That is, computing
/// a hash for the same set of elements using these functions will always produce the same
/// result. For example, merging two digests using [merge()](Rpo256::merge) will produce the
/// same result as hashing 8 elements which make up these digests using
/// [hash_elements()](Rpo256::hash_elements) function.
///
/// However, [hash()](Rpo256::hash) function is not consistent with the functions mentioned
/// above. For example, if we take two field elements, serialize them to bytes, and hash them
/// using [hash()](Rpo256::hash), the result will differ from the result obtained by hashing
/// these elements directly using [hash_elements()](Rpo256::hash_elements) function. The reason
/// for this difference is that [hash()](Rpo256::hash) function needs to be able to handle
/// arbitrary binary strings, which may or may not encode valid field elements - and thus, the
/// deserialization procedure used by this function is different from the procedure used to
/// deserialize valid field elements.
///
/// Thus, if the underlying data consists of valid field elements, it might make more sense
/// to deserialize them into field elements and then hash them using
/// [hash_elements()](Rpo256::hash_elements) function rather than hashing the serialized bytes
/// using [hash()](Rpo256::hash) function.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Rpo256();

impl Hasher for Rpo256 {
    /// Rpo256 collision resistance is the same as the security level, that is 128-bits.
    ///
    /// #### Collision resistance
    ///
    /// However, our setup of the capacity registers might drop it to 126.
    ///
    /// Related issue: [#69](https://github.com/0xPolygonMiden/crypto/issues/69)
    const COLLISION_RESISTANCE: u32 = 128;

    type Digest = RpoDigest;

    fn hash(bytes: &[u8]) -> Self::Digest {
        // initialize the state with zeroes
        let mut state = [ZERO; STATE_WIDTH];

        // set the capacity (first element) to a flag on whether or not the input length is evenly
        // divided by the rate. this will prevent collisions between padded and non-padded inputs,
        // and will rule out the need to perform an extra permutation in case of evenly divided
        // inputs.
        let is_rate_multiple = bytes.len() % RATE_WIDTH == 0;
        if !is_rate_multiple {
            state[CAPACITY_RANGE.start] = ONE;
        }

        // initialize a buffer to receive the little-endian elements.
        let mut buf = [0_u8; 8];

        // iterate the chunks of bytes, creating a field element from each chunk and copying it
        // into the state.
        //
        // every time the rate range is filled, a permutation is performed. if the final value of
        // `i` is not zero, then the chunks count wasn't enough to fill the state range, and an
        // additional permutation must be performed.
        let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| {
            // the last element of the iteration may or may not be a full chunk. if it's not, then
            // we need to pad the remainder bytes of the chunk with zeroes, separated by a `1`.
            // this will avoid collisions.
            if chunk.len() == BINARY_CHUNK_SIZE {
                buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
            } else {
                buf.fill(0);
                buf[..chunk.len()].copy_from_slice(chunk);
                buf[chunk.len()] = 1;
            }

            // set the current rate element to the input. since we take at most 7 bytes, we are
            // guaranteed that the input data will fit into a single field element.
            state[RATE_RANGE.start + i] = Felt::new(u64::from_le_bytes(buf));

            // proceed filling the range. if it's full, then we apply a permutation and reset the
            // counter to the beginning of the range.
            if i == RATE_WIDTH - 1 {
                Self::apply_permutation(&mut state);
                0
            } else {
                i + 1
            }
        });

        // if we absorbed some elements but didn't apply a permutation to them (would happen when
        // the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation. we
        // don't need to apply any extra padding because the first capacity element contains a
        // flag indicating whether the input is evenly divisible by the rate.
        if i != 0 {
            state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO);
            state[RATE_RANGE.start + i] = ONE;
            Self::apply_permutation(&mut state);
        }

        // return the first 4 elements of the rate as hash result.
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
        // initialize the state by copying the digest elements into the rate portion of the state
        // (8 total elements), and set the capacity elements to 0.
        let mut state = [ZERO; STATE_WIDTH];
        let it = Self::Digest::digests_as_elements(values.iter());
        for (i, v) in it.enumerate() {
            state[RATE_RANGE.start + i] = *v;
        }

        // apply the RPO permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
        // initialize the state as follows:
        // - seed is copied into the first 4 elements of the rate portion of the state.
        // - if the value fits into a single field element, copy it into the fifth rate element
        //   and set the sixth rate element to 1.
        // - if the value doesn't fit into a single field element, split it into two field
        //   elements, copy them into rate elements 5 and 6, and set the seventh rate element
        //   to 1.
        // - set the first capacity element to 1
        let mut state = [ZERO; STATE_WIDTH];
        state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
        state[INPUT2_RANGE.start] = Felt::new(value);
        if value < Felt::MODULUS {
            state[INPUT2_RANGE.start + 1] = ONE;
        } else {
            state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
            state[INPUT2_RANGE.start + 2] = ONE;
        }

        // common padding for both cases
        state[CAPACITY_RANGE.start] = ONE;

        // apply the RPO permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
}

impl ElementHasher for Rpo256 {
    type BaseField = Felt;

    fn hash_elements<E: FieldElement<BaseField = Self::BaseField>>(elements: &[E]) -> Self::Digest {
        // convert the elements into a list of base field elements
        let elements = E::slice_as_base_elements(elements);

        // initialize state to all zeros, except for the first element of the capacity part, which
        // is set to 1 if the number of elements is not a multiple of RATE_WIDTH.
        let mut state = [ZERO; STATE_WIDTH];
        if elements.len() % RATE_WIDTH != 0 {
            state[CAPACITY_RANGE.start] = ONE;
        }

        // absorb elements into the state one by one until the rate portion of the state is filled
        // up; then apply the Rescue permutation and start absorbing again; repeat until all
        // elements have been absorbed
        let mut i = 0;
        for &element in elements.iter() {
            state[RATE_RANGE.start + i] = element;
            i += 1;
            if i % RATE_WIDTH == 0 {
                Self::apply_permutation(&mut state);
                i = 0;
            }
        }

        // if we absorbed some elements but didn't apply a permutation to them (would happen when
        // the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation after
        // padding by appending a 1 followed by as many 0 as necessary to make the input length a
        // multiple of the RATE_WIDTH.
        if i > 0 {
            state[RATE_RANGE.start + i] = ONE;
            i += 1;
            while i != RATE_WIDTH {
                state[RATE_RANGE.start + i] = ZERO;
                i += 1;
            }
            Self::apply_permutation(&mut state);
        }

        // return the first 4 elements of the state as hash result
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
}

// HASH FUNCTION IMPLEMENTATION
// ================================================================================================

impl Rpo256 {
    // CONSTANTS
    // --------------------------------------------------------------------------------------------

    /// The number of rounds is set to 7 to target a 128-bit security level.
    pub const NUM_ROUNDS: usize = NUM_ROUNDS;

    /// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
    /// the remaining 4 elements are reserved for capacity.
    pub const STATE_WIDTH: usize = STATE_WIDTH;

    /// The rate portion of the state is located in elements 4 through 11 (inclusive).
    pub const RATE_RANGE: Range<usize> = RATE_RANGE;

    /// The capacity portion of the state is located in elements 0, 1, 2, and 3.
    pub const CAPACITY_RANGE: Range<usize> = CAPACITY_RANGE;

    /// The output of the hash function can be read from state elements 4, 5, 6, and 7.
    pub const DIGEST_RANGE: Range<usize> = DIGEST_RANGE;

    /// MDS matrix used for computing the linear layer in a RPO round.
    pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS;

    /// Round constants added to the hasher state in the first half of the RPO round.
    pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1;

    /// Round constants added to the hasher state in the second half of the RPO round.
    pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2;

    // TRAIT PASS-THROUGH FUNCTIONS
    // --------------------------------------------------------------------------------------------

    /// Returns a hash of the provided sequence of bytes.
    #[inline(always)]
    pub fn hash(bytes: &[u8]) -> RpoDigest {
        <Self as Hasher>::hash(bytes)
    }

    /// Returns a hash of two digests. This method is intended for use in construction of
    /// Merkle trees and verification of Merkle paths.
    #[inline(always)]
    pub fn merge(values: &[RpoDigest; 2]) -> RpoDigest {
        <Self as Hasher>::merge(values)
    }

    /// Returns a hash of the provided field elements.
    #[inline(always)]
    pub fn hash_elements<E: FieldElement<BaseField = Felt>>(elements: &[E]) -> RpoDigest {
        <Self as ElementHasher>::hash_elements(elements)
    }

    // DOMAIN IDENTIFIER
    // --------------------------------------------------------------------------------------------

    /// Returns a hash of two digests and a domain identifier.
    pub fn merge_in_domain(values: &[RpoDigest; 2], domain: Felt) -> RpoDigest {
        // initialize the state by copying the digest elements into the rate portion of the state
        // (8 total elements), and set the capacity elements to 0.
        let mut state = [ZERO; STATE_WIDTH];
        let it = RpoDigest::digests_as_elements(values.iter());
        for (i, v) in it.enumerate() {
            state[RATE_RANGE.start + i] = *v;
        }

        // set the second capacity element to the domain value. The first capacity element is used
        // for padding purposes.
        state[CAPACITY_RANGE.start + 1] = domain;

        // apply the RPO permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    // RESCUE PERMUTATION
    // --------------------------------------------------------------------------------------------

    /// Applies RPO permutation to the provided state.
    #[inline(always)]
    pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) {
        for i in 0..NUM_ROUNDS {
            Self::apply_round(state, i);
        }
    }

    /// RPO round function.
    #[inline(always)]
    pub fn apply_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
        // apply first half of RPO round
        apply_mds(state);
        if !add_constants_and_apply_sbox(state, &ARK1[round]) {
            add_constants(state, &ARK1[round]);
            apply_sbox(state);
        }

        // apply second half of RPO round
        apply_mds(state);
        if !add_constants_and_apply_inv_sbox(state, &ARK2[round]) {
            add_constants(state, &ARK2[round]);
            apply_inv_sbox(state);
        }
    }
}
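For illustration (not part of the diff), the hash-output-consistency property documented above can be checked directly; the import paths are assumptions based on the crate's public re-exports:

// sketch: merging two digests equals hashing their 8 underlying elements
use miden_crypto::hash::rpo::Rpo256;
use miden_crypto::Felt;

let a = Rpo256::hash_elements(&[Felt::new(1), Felt::new(2)]);
let b = Rpo256::hash_elements(&[Felt::new(3), Felt::new(4)]);

let merged = Rpo256::merge(&[a, b]);

let mut elements = a.as_elements().to_vec();
elements.extend_from_slice(b.as_elements());
assert_eq!(merged, Rpo256::hash_elements(&elements));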
@@ -1,19 +1,15 @@
use super::{
    Felt, FieldElement, Hasher, Rpo256, RpoDigest, StarkField, ALPHA, INV_ALPHA, ONE, STATE_WIDTH,
    ZERO,
    super::{apply_inv_sbox, apply_sbox, ALPHA, INV_ALPHA},
    Felt, FieldElement, Hasher, Rpo256, RpoDigest, StarkField, ONE, STATE_WIDTH, ZERO,
};
use crate::{
    utils::collections::{BTreeSet, Vec},
    Word,
};
use crate::utils::collections::{BTreeSet, Vec};
use core::convert::TryInto;
use proptest::prelude::*;
use rand_utils::rand_value;

#[test]
fn test_alphas() {
    let e: Felt = Felt::new(rand_value());
    let e_exp = e.exp(ALPHA);
    assert_eq!(e, e_exp.exp(INV_ALPHA));
}

#[test]
fn test_sbox() {
    let state = [Felt::new(rand_value()); STATE_WIDTH];
@@ -22,7 +18,7 @@ fn test_sbox() {
    expected.iter_mut().for_each(|v| *v = v.exp(ALPHA));

    let mut actual = state;
    Rpo256::apply_sbox(&mut actual);
    apply_sbox(&mut actual);

    assert_eq!(expected, actual);
}
@@ -35,7 +31,7 @@ fn test_inv_sbox() {
    expected.iter_mut().for_each(|v| *v = v.exp(INV_ALPHA));

    let mut actual = state;
    Rpo256::apply_inv_sbox(&mut actual);
    apply_inv_sbox(&mut actual);

    assert_eq!(expected, actual);
}
@@ -102,7 +98,7 @@ fn hash_elements_vs_merge_with_int() {

    let mut elements = seed.as_elements().to_vec();
    elements.push(Felt::new(val));
    elements.push(Felt::new(1));
    elements.push(ONE);
    let h_result = Rpo256::hash_elements(&elements);

    assert_eq!(m_result, h_result);
@@ -144,8 +140,8 @@ fn hash_elements_padding() {
#[test]
fn hash_elements() {
    let elements = [
        Felt::new(0),
        Felt::new(1),
        ZERO,
        ONE,
        Felt::new(2),
        Felt::new(3),
        Felt::new(4),
@@ -167,8 +163,8 @@ fn hash_elements() {
#[test]
fn hash_test_vectors() {
    let elements = [
        Felt::new(0),
        Felt::new(1),
        ZERO,
        ONE,
        Felt::new(2),
        Felt::new(3),
        Felt::new(4),
@@ -203,7 +199,7 @@ fn sponge_bytes_with_remainder_length_wont_panic() {
    // size.
    //
    // this is a preliminary test to the fuzzy-stress of proptest.
    Rpo256::hash(&vec![0; 113]);
    Rpo256::hash(&[0; 113]);
}

#[test]
@@ -227,12 +223,12 @@ fn sponge_zeroes_collision() {

proptest! {
    #[test]
    fn rpo256_wont_panic_with_arbitrary_input(ref vec in any::<Vec<u8>>()) {
        Rpo256::hash(&vec);
    fn rpo256_wont_panic_with_arbitrary_input(ref bytes in any::<Vec<u8>>()) {
        Rpo256::hash(bytes);
    }
}

const EXPECTED: [[Felt; 4]; 19] = [
const EXPECTED: [Word; 19] = [
    [
        Felt::new(1502364727743950833),
        Felt::new(5880949717274681448),
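The `hash_elements_vs_merge_with_int` hunk above reflects the padding rule in `merge_with_int`: for a value below `Felt::MODULUS`, the absorbed input is `seed || value || 1`. A sketch of the equivalence being tested (assuming the `Hasher` trait, `Felt`, and `ONE` are in scope):

// sketch: merge_with_int matches hash_elements on seed || value || 1
let seed = Rpo256::hash(b"seed");
let val: u64 = 2; // any value below Felt::MODULUS
let m_result = Rpo256::merge_with_int(seed, val);

let mut elements = seed.as_elements().to_vec();
elements.push(Felt::new(val));
elements.push(ONE); // marker: the value fits into a single field element
assert_eq!(m_result, Rpo256::hash_elements(&elements));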
398
src/hash/rescue/rpx/digest.rs
Normal file
@@ -0,0 +1,398 @@
use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO};
use crate::utils::{
    bytes_to_hex_string, hex_to_bytes, string::String, ByteReader, ByteWriter, Deserializable,
    DeserializationError, HexParseError, Serializable,
};
use core::{cmp::Ordering, fmt::Display, ops::Deref};
use winter_utils::Randomizable;

// DIGEST TRAIT IMPLEMENTATIONS
// ================================================================================================

#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))]
pub struct RpxDigest([Felt; DIGEST_SIZE]);

impl RpxDigest {
    pub const fn new(value: [Felt; DIGEST_SIZE]) -> Self {
        Self(value)
    }

    pub fn as_elements(&self) -> &[Felt] {
        self.as_ref()
    }

    pub fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
        <Self as Digest>::as_bytes(self)
    }

    pub fn digests_as_elements<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt>
    where
        I: Iterator<Item = &'a Self>,
    {
        digests.flat_map(|d| d.0.iter())
    }
}

impl Digest for RpxDigest {
    fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
        let mut result = [0; DIGEST_BYTES];

        result[..8].copy_from_slice(&self.0[0].as_int().to_le_bytes());
        result[8..16].copy_from_slice(&self.0[1].as_int().to_le_bytes());
        result[16..24].copy_from_slice(&self.0[2].as_int().to_le_bytes());
        result[24..].copy_from_slice(&self.0[3].as_int().to_le_bytes());

        result
    }
}

impl Deref for RpxDigest {
    type Target = [Felt; DIGEST_SIZE];

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl Ord for RpxDigest {
    fn cmp(&self, other: &Self) -> Ordering {
        // compare the inner u64 of both elements.
        //
        // it will iterate over the elements and return the first comparison different from
        // `Equal`. Otherwise, the ordering is equal.
        //
        // the endianness is irrelevant here because, this being a cryptographically secure
        // hash computation, the digest shouldn't have any ordered property of its input.
        //
        // finally, we use `Felt::inner` instead of `Felt::as_int` so we avoid performing a
        // montgomery reduction for every limb. that is safe because every inner element of the
        // digest is guaranteed to be in its canonical form (that is, `x in [0,p)`).
        self.0.iter().map(Felt::inner).zip(other.0.iter().map(Felt::inner)).fold(
            Ordering::Equal,
            |ord, (a, b)| match ord {
                Ordering::Equal => a.cmp(&b),
                _ => ord,
            },
        )
    }
}

impl PartialOrd for RpxDigest {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Display for RpxDigest {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let encoded: String = self.into();
        write!(f, "{}", encoded)?;
        Ok(())
    }
}

impl Randomizable for RpxDigest {
    const VALUE_SIZE: usize = DIGEST_BYTES;

    fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
        let bytes_array: Option<[u8; 32]> = bytes.try_into().ok();
        if let Some(bytes_array) = bytes_array {
            Self::try_from(bytes_array).ok()
        } else {
            None
        }
    }
}

// CONVERSIONS: FROM RPX DIGEST
// ================================================================================================

impl From<&RpxDigest> for [Felt; DIGEST_SIZE] {
    fn from(value: &RpxDigest) -> Self {
        value.0
    }
}

impl From<RpxDigest> for [Felt; DIGEST_SIZE] {
    fn from(value: RpxDigest) -> Self {
        value.0
    }
}

impl From<&RpxDigest> for [u64; DIGEST_SIZE] {
    fn from(value: &RpxDigest) -> Self {
        [
            value.0[0].as_int(),
            value.0[1].as_int(),
            value.0[2].as_int(),
            value.0[3].as_int(),
        ]
    }
}

impl From<RpxDigest> for [u64; DIGEST_SIZE] {
    fn from(value: RpxDigest) -> Self {
        [
            value.0[0].as_int(),
            value.0[1].as_int(),
            value.0[2].as_int(),
            value.0[3].as_int(),
        ]
    }
}

impl From<&RpxDigest> for [u8; DIGEST_BYTES] {
    fn from(value: &RpxDigest) -> Self {
        value.as_bytes()
    }
}

impl From<RpxDigest> for [u8; DIGEST_BYTES] {
    fn from(value: RpxDigest) -> Self {
        value.as_bytes()
    }
}

impl From<RpxDigest> for String {
    /// The returned string starts with `0x`.
    fn from(value: RpxDigest) -> Self {
        bytes_to_hex_string(value.as_bytes())
    }
}

impl From<&RpxDigest> for String {
    /// The returned string starts with `0x`.
    fn from(value: &RpxDigest) -> Self {
        (*value).into()
    }
}

// CONVERSIONS: TO RPX DIGEST
// ================================================================================================

#[derive(Copy, Clone, Debug)]
pub enum RpxDigestError {
    /// The provided u64 integer does not fit in the field's modulus.
    InvalidInteger,
}

impl From<&[Felt; DIGEST_SIZE]> for RpxDigest {
    fn from(value: &[Felt; DIGEST_SIZE]) -> Self {
        Self(*value)
    }
}

impl From<[Felt; DIGEST_SIZE]> for RpxDigest {
    fn from(value: [Felt; DIGEST_SIZE]) -> Self {
        Self(value)
    }
}

impl TryFrom<[u8; DIGEST_BYTES]> for RpxDigest {
    type Error = HexParseError;

    fn try_from(value: [u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
        // Note: the input length is known, the conversion from slice to array must succeed so the
        // `unwrap`s below are safe
        let a = u64::from_le_bytes(value[0..8].try_into().unwrap());
        let b = u64::from_le_bytes(value[8..16].try_into().unwrap());
        let c = u64::from_le_bytes(value[16..24].try_into().unwrap());
        let d = u64::from_le_bytes(value[24..32].try_into().unwrap());

        if [a, b, c, d].iter().any(|v| *v >= Felt::MODULUS) {
            return Err(HexParseError::OutOfRange);
        }

        Ok(RpxDigest([Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)]))
    }
}

impl TryFrom<&[u8; DIGEST_BYTES]> for RpxDigest {
    type Error = HexParseError;

    fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<&[u8]> for RpxDigest {
    type Error = HexParseError;

    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<[u64; DIGEST_SIZE]> for RpxDigest {
    type Error = RpxDigestError;

    fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
        if value[0] >= Felt::MODULUS
            || value[1] >= Felt::MODULUS
            || value[2] >= Felt::MODULUS
            || value[3] >= Felt::MODULUS
        {
            return Err(RpxDigestError::InvalidInteger);
        }

        Ok(Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()]))
    }
}

impl TryFrom<&[u64; DIGEST_SIZE]> for RpxDigest {
    type Error = RpxDigestError;

    fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
        (*value).try_into()
    }
}

impl TryFrom<&str> for RpxDigest {
    type Error = HexParseError;

    /// Expects the string to start with `0x`.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        hex_to_bytes(value).and_then(|v| v.try_into())
    }
}

impl TryFrom<String> for RpxDigest {
    type Error = HexParseError;

    /// Expects the string to start with `0x`.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        value.as_str().try_into()
    }
}

impl TryFrom<&String> for RpxDigest {
    type Error = HexParseError;

    /// Expects the string to start with `0x`.
    fn try_from(value: &String) -> Result<Self, Self::Error> {
        value.as_str().try_into()
    }
}

// SERIALIZATION / DESERIALIZATION
// ================================================================================================

impl Serializable for RpxDigest {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_bytes(&self.as_bytes());
    }
}

impl Deserializable for RpxDigest {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let mut inner: [Felt; DIGEST_SIZE] = [ZERO; DIGEST_SIZE];
        for inner in inner.iter_mut() {
            let e = source.read_u64()?;
            if e >= Felt::MODULUS {
                return Err(DeserializationError::InvalidValue(String::from(
                    "Value not in the appropriate range",
                )));
            }
            *inner = Felt::new(e);
        }

        Ok(Self(inner))
    }
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::{Deserializable, Felt, RpxDigest, Serializable, DIGEST_BYTES, DIGEST_SIZE};
    use crate::utils::{string::String, SliceReader};
    use rand_utils::rand_value;

    #[test]
    fn digest_serialization() {
        let e1 = Felt::new(rand_value());
        let e2 = Felt::new(rand_value());
        let e3 = Felt::new(rand_value());
        let e4 = Felt::new(rand_value());

        let d1 = RpxDigest([e1, e2, e3, e4]);

        let mut bytes = vec![];
        d1.write_into(&mut bytes);
        assert_eq!(DIGEST_BYTES, bytes.len());

        let mut reader = SliceReader::new(&bytes);
        let d2 = RpxDigest::read_from(&mut reader).unwrap();

        assert_eq!(d1, d2);
    }

    #[cfg(feature = "std")]
    #[test]
    fn digest_encoding() {
        let digest = RpxDigest([
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
        ]);

        let string: String = digest.into();
        let round_trip: RpxDigest = string.try_into().expect("decoding failed");

        assert_eq!(digest, round_trip);
    }

    #[test]
    fn test_conversions() {
        let digest = RpxDigest([
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
        ]);

        let v: [Felt; DIGEST_SIZE] = digest.into();
        let v2: RpxDigest = v.into();
        assert_eq!(digest, v2);

        let v: [Felt; DIGEST_SIZE] = (&digest).into();
        let v2: RpxDigest = v.into();
        assert_eq!(digest, v2);

        let v: [u64; DIGEST_SIZE] = digest.into();
        let v2: RpxDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u64; DIGEST_SIZE] = (&digest).into();
        let v2: RpxDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = digest.into();
        let v2: RpxDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = (&digest).into();
        let v2: RpxDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: String = digest.into();
        let v2: RpxDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: String = (&digest).into();
        let v2: RpxDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = digest.into();
        let v2: RpxDigest = (&v).try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = (&digest).into();
        let v2: RpxDigest = (&v).try_into().unwrap();
        assert_eq!(digest, v2);
    }
}
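With the optional `serde` feature, the `serde(into = "String", try_from = "&str")` attributes above make the digest serialize as a `0x`-prefixed hex string. A sketch of what that enables, assuming `serde_json` is available (not part of the diff):

// sketch: JSON round-trip of an RPX digest as a hex string
let d = Rpx256::hash(b"data");
let json = serde_json::to_string(&d).unwrap();          // "\"0x...\""
let back: RpxDigest = serde_json::from_str(&json).unwrap();
assert_eq!(d, back);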
366
src/hash/rescue/rpx/mod.rs
Normal file
@@ -0,0 +1,366 @@
use super::{
|
||||
add_constants, add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox, apply_inv_sbox,
|
||||
apply_mds, apply_sbox, CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher,
|
||||
StarkField, ARK1, ARK2, BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE,
|
||||
DIGEST_SIZE, INPUT1_RANGE, INPUT2_RANGE, MDS, NUM_ROUNDS, ONE, RATE_RANGE, RATE_WIDTH,
|
||||
STATE_WIDTH, ZERO,
|
||||
};
|
||||
use core::{convert::TryInto, ops::Range};
|
||||
|
||||
mod digest;
|
||||
pub use digest::RpxDigest;
|
||||
|
||||
pub type CubicExtElement = CubeExtension<Felt>;
|
||||
|
||||
// HASHER IMPLEMENTATION
|
||||
// ================================================================================================
|
||||
|
||||
/// Implementation of the Rescue Prime eXtension hash function with 256-bit output.
|
||||
///
|
||||
/// The hash function is based on the XHash12 construction in [specifications](https://eprint.iacr.org/2023/1045)
|
||||
///
|
||||
/// The parameters used to instantiate the function are:
|
||||
/// * Field: 64-bit prime field with modulus 2^64 - 2^32 + 1.
|
||||
/// * State width: 12 field elements.
|
||||
/// * Capacity size: 4 field elements.
|
||||
/// * S-Box degree: 7.
|
||||
/// * Rounds: There are 3 different types of rounds:
|
||||
/// - (FB): `apply_mds` → `add_constants` → `apply_sbox` → `apply_mds` → `add_constants` → `apply_inv_sbox`.
|
||||
/// - (E): `add_constants` → `ext_sbox` (which is raising to power 7 in the degree 3 extension field).
|
||||
/// - (M): `apply_mds` → `add_constants`.
|
||||
/// * Permutation: (FB) (E) (FB) (E) (FB) (E) (M).
|
||||
///
|
||||
/// The above parameters target 128-bit security level. The digest consists of four field elements
|
||||
/// and it can be serialized into 32 bytes (256 bits).
|
||||
///
|
||||
/// ## Hash output consistency
|
||||
/// Functions [hash_elements()](Rpx256::hash_elements), [merge()](Rpx256::merge), and
|
||||
/// [merge_with_int()](Rpx256::merge_with_int) are internally consistent. That is, computing
|
||||
/// a hash for the same set of elements using these functions will always produce the same
|
||||
/// result. For example, merging two digests using [merge()](Rpx256::merge) will produce the
|
||||
/// same result as hashing 8 elements which make up these digests using
|
||||
/// [hash_elements()](Rpx256::hash_elements) function.
|
||||
///
|
||||
/// However, [hash()](Rpx256::hash) function is not consistent with functions mentioned above.
|
||||
/// For example, if we take two field elements, serialize them to bytes and hash them using
|
||||
/// [hash()](Rpx256::hash), the result will differ from the result obtained by hashing these
|
||||
/// elements directly using [hash_elements()](Rpx256::hash_elements) function. The reason for
|
||||
/// this difference is that [hash()](Rpx256::hash) function needs to be able to handle
|
||||
/// arbitrary binary strings, which may or may not encode valid field elements - and thus,
|
||||
/// deserialization procedure used by this function is different from the procedure used to
|
||||
/// deserialize valid field elements.
|
||||
///
|
||||
/// Thus, if the underlying data consists of valid field elements, it might make more sense
|
||||
/// to deserialize them into field elements and then hash them using
|
||||
/// [hash_elements()](Rpx256::hash_elements) function rather then hashing the serialized bytes
|
||||
/// using [hash()](Rpx256::hash) function.
|
||||
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
|
||||
pub struct Rpx256();
|
||||
|
||||
impl Hasher for Rpx256 {
|
||||
/// Rpx256 collision resistance is the same as the security level, that is 128-bits.
|
||||
///
|
||||
/// #### Collision resistance
|
||||
///
|
||||
/// However, our setup of the capacity registers might drop it to 126.
|
||||
///
|
||||
/// Related issue: [#69](https://github.com/0xPolygonMiden/crypto/issues/69)
|
||||
const COLLISION_RESISTANCE: u32 = 128;
|
||||
|
||||
type Digest = RpxDigest;
|
||||
|
||||
fn hash(bytes: &[u8]) -> Self::Digest {
|
||||
// initialize the state with zeroes
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
|
||||
// set the capacity (first element) to a flag on whether or not the input length is evenly
|
||||
// divided by the rate. this will prevent collisions between padded and non-padded inputs,
|
||||
// and will rule out the need to perform an extra permutation in case of evenly divided
|
||||
// inputs.
|
||||
let is_rate_multiple = bytes.len() % RATE_WIDTH == 0;
|
||||
if !is_rate_multiple {
|
||||
state[CAPACITY_RANGE.start] = ONE;
|
||||
}
|
||||
|
||||
// initialize a buffer to receive the little-endian elements.
|
||||
let mut buf = [0_u8; 8];
|
||||
|
||||
// iterate the chunks of bytes, creating a field element from each chunk and copying it
|
||||
// into the state.
|
||||
//
|
||||
// every time the rate range is filled, a permutation is performed. if the final value of
|
||||
// `i` is not zero, then the chunks count wasn't enough to fill the state range, and an
|
||||
// additional permutation must be performed.
|
||||
let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| {
|
||||
// the last element of the iteration may or may not be a full chunk. if it's not, then
|
||||
// we need to pad the remainder bytes of the chunk with zeroes, separated by a `1`.
|
||||
// this will avoid collisions.
|
||||
if chunk.len() == BINARY_CHUNK_SIZE {
|
||||
buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
|
||||
} else {
|
||||
buf.fill(0);
|
||||
buf[..chunk.len()].copy_from_slice(chunk);
|
||||
buf[chunk.len()] = 1;
|
||||
}
|
||||
|
||||
// set the current rate element to the input. since we take at most 7 bytes, we are
|
||||
// guaranteed that the inputs data will fit into a single field element.
|
||||
state[RATE_RANGE.start + i] = Felt::new(u64::from_le_bytes(buf));
|
||||
|
||||
// proceed filling the range. if it's full, then we apply a permutation and reset the
|
||||
// counter to the beginning of the range.
|
||||
if i == RATE_WIDTH - 1 {
|
||||
Self::apply_permutation(&mut state);
|
||||
0
|
||||
} else {
|
||||
i + 1
|
||||
}
|
||||
});
|
||||
|
||||
// if we absorbed some elements but didn't apply a permutation to them (would happen when
|
||||
// the number of elements is not a multiple of RATE_WIDTH), apply the RPX permutation. we
|
||||
// don't need to apply any extra padding because the first capacity element contains a
|
||||
// flag indicating whether the input is evenly divisible by the rate.
|
||||
if i != 0 {
|
||||
state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO);
|
||||
state[RATE_RANGE.start + i] = ONE;
|
||||
Self::apply_permutation(&mut state);
|
||||
}
|
||||
|
||||
// return the first 4 elements of the rate as hash result.
|
||||
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
||||
}
|
||||
|
||||
fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
|
||||
// initialize the state by copying the digest elements into the rate portion of the state
|
||||
// (8 total elements), and set the capacity elements to 0.
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
let it = Self::Digest::digests_as_elements(values.iter());
|
||||
for (i, v) in it.enumerate() {
|
||||
state[RATE_RANGE.start + i] = *v;
|
||||
}
|
||||
|
||||
// apply the RPX permutation and return the first four elements of the state
|
||||
Self::apply_permutation(&mut state);
|
||||
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
||||
}
|
||||
|
||||
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
|
||||
// initialize the state as follows:
|
||||
// - seed is copied into the first 4 elements of the rate portion of the state.
|
||||
// - if the value fits into a single field element, copy it into the fifth rate element
|
||||
// and set the sixth rate element to 1.
|
||||
// - if the value doesn't fit into a single field element, split it into two field
|
||||
// elements, copy them into rate elements 5 and 6, and set the seventh rate element
|
||||
// to 1.
|
||||
// - set the first capacity element to 1
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
|
||||
state[INPUT2_RANGE.start] = Felt::new(value);
|
||||
if value < Felt::MODULUS {
|
||||
state[INPUT2_RANGE.start + 1] = ONE;
|
||||
} else {
|
||||
state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
|
||||
state[INPUT2_RANGE.start + 2] = ONE;
|
||||
}
|
||||
|
||||
// common padding for both cases
|
||||
state[CAPACITY_RANGE.start] = ONE;
|
||||
|
||||
// apply the RPX permutation and return the first four elements of the state
|
||||
Self::apply_permutation(&mut state);
|
||||
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl ElementHasher for Rpx256 {
|
||||
type BaseField = Felt;
|
||||
|
||||
fn hash_elements<E: FieldElement<BaseField = Self::BaseField>>(elements: &[E]) -> Self::Digest {
|
||||
// convert the elements into a list of base field elements
|
||||
let elements = E::slice_as_base_elements(elements);
|
||||
|
||||
// initialize state to all zeros, except for the first element of the capacity part, which
|
||||
// is set to 1 if the number of elements is not a multiple of RATE_WIDTH.
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
if elements.len() % RATE_WIDTH != 0 {
|
||||
state[CAPACITY_RANGE.start] = ONE;
|
||||
}
|
||||
|
||||
// absorb elements into the state one by one until the rate portion of the state is filled
|
||||
// up; then apply the Rescue permutation and start absorbing again; repeat until all
|
||||
// elements have been absorbed
|
||||
let mut i = 0;
|
||||
for &element in elements.iter() {
|
||||
state[RATE_RANGE.start + i] = element;
|
||||
i += 1;
|
||||
if i % RATE_WIDTH == 0 {
|
||||
Self::apply_permutation(&mut state);
|
||||
i = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// if we absorbed some elements but didn't apply a permutation to them (would happen when
|
||||
// the number of elements is not a multiple of RATE_WIDTH), apply the RPX permutation after
|
||||
// padding by appending a 1 followed by as many 0 as necessary to make the input length a
|
||||
// multiple of the RATE_WIDTH.
|
||||
if i > 0 {
|
||||
state[RATE_RANGE.start + i] = ONE;
|
||||
i += 1;
|
||||
while i != RATE_WIDTH {
|
||||
state[RATE_RANGE.start + i] = ZERO;
|
||||
i += 1;
|
||||
}
|
||||
Self::apply_permutation(&mut state);
|
||||
}
|
||||
|
||||
// return the first 4 elements of the state as hash result
|
||||
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
||||
}
|
||||
}

// HASH FUNCTION IMPLEMENTATION
// ================================================================================================

impl Rpx256 {
    // CONSTANTS
    // --------------------------------------------------------------------------------------------

    /// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
    /// the remaining 4 elements are reserved for capacity.
    pub const STATE_WIDTH: usize = STATE_WIDTH;

    /// The rate portion of the state is located in elements 4 through 11 (inclusive).
    pub const RATE_RANGE: Range<usize> = RATE_RANGE;

    /// The capacity portion of the state is located in elements 0, 1, 2, and 3.
    pub const CAPACITY_RANGE: Range<usize> = CAPACITY_RANGE;

    /// The output of the hash function can be read from state elements 4, 5, 6, and 7.
    pub const DIGEST_RANGE: Range<usize> = DIGEST_RANGE;

    /// MDS matrix used for computing the linear layer in the (FB) and (E) rounds.
    pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS;

    /// Round constants added to the hasher state in the first half of the round.
    pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1;

    /// Round constants added to the hasher state in the second half of the round.
    pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2;

    // TRAIT PASS-THROUGH FUNCTIONS
    // --------------------------------------------------------------------------------------------

    /// Returns a hash of the provided sequence of bytes.
    #[inline(always)]
    pub fn hash(bytes: &[u8]) -> RpxDigest {
        <Self as Hasher>::hash(bytes)
    }

    /// Returns a hash of two digests. This method is intended for use in construction of
    /// Merkle trees and verification of Merkle paths.
    #[inline(always)]
    pub fn merge(values: &[RpxDigest; 2]) -> RpxDigest {
        <Self as Hasher>::merge(values)
    }

    /// Returns a hash of the provided field elements.
    #[inline(always)]
    pub fn hash_elements<E: FieldElement<BaseField = Felt>>(elements: &[E]) -> RpxDigest {
        <Self as ElementHasher>::hash_elements(elements)
    }

    // DOMAIN IDENTIFIER
    // --------------------------------------------------------------------------------------------

    /// Returns a hash of two digests and a domain identifier.
    pub fn merge_in_domain(values: &[RpxDigest; 2], domain: Felt) -> RpxDigest {
        // initialize the state by copying the digest elements into the rate portion of the state
        // (8 total elements), and set the capacity elements to 0
        let mut state = [ZERO; STATE_WIDTH];
        let it = RpxDigest::digests_as_elements(values.iter());
        for (i, v) in it.enumerate() {
            state[RATE_RANGE.start + i] = *v;
        }

        // set the second capacity element to the domain value; the first capacity element is used
        // for padding purposes
        state[CAPACITY_RANGE.start + 1] = domain;

        // apply the RPX permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
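A hedged usage sketch (hypothetical, not part of the diff, assuming the same `hash::rpx` export as above): because the domain value lands in the capacity portion, the same pair of children merges to different parents under different domains, so two tree types cannot be confused.

use miden_crypto::{hash::rpx::{Rpx256, RpxDigest}, Felt};

fn domain_separation(children: &[RpxDigest; 2]) {
    let parent_a = Rpx256::merge_in_domain(children, Felt::new(1));
    let parent_b = Rpx256::merge_in_domain(children, Felt::new(2));
    assert_ne!(parent_a, parent_b); // distinct domains yield distinct parents
}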

    // RPX PERMUTATION
    // --------------------------------------------------------------------------------------------

    /// Applies the RPX permutation to the provided state.
    #[inline(always)]
    pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) {
        Self::apply_fb_round(state, 0);
        Self::apply_ext_round(state, 1);
        Self::apply_fb_round(state, 2);
        Self::apply_ext_round(state, 3);
        Self::apply_fb_round(state, 4);
        Self::apply_ext_round(state, 5);
        Self::apply_final_round(state, 6);
    }

    // RPX PERMUTATION ROUND FUNCTIONS
    // --------------------------------------------------------------------------------------------

    /// (FB) round function.
    #[inline(always)]
    pub fn apply_fb_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
        apply_mds(state);
        if !add_constants_and_apply_sbox(state, &ARK1[round]) {
            add_constants(state, &ARK1[round]);
            apply_sbox(state);
        }

        apply_mds(state);
        if !add_constants_and_apply_inv_sbox(state, &ARK2[round]) {
            add_constants(state, &ARK2[round]);
            apply_inv_sbox(state);
        }
    }

    /// (E) round function.
    #[inline(always)]
    pub fn apply_ext_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
        // add constants
        add_constants(state, &ARK1[round]);

        // decompose the state into 4 elements in the cubic extension field and apply the power 7
        // map to each of the elements
        let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] = *state;
        let ext0 = Self::exp7(CubicExtElement::new(s0, s1, s2));
        let ext1 = Self::exp7(CubicExtElement::new(s3, s4, s5));
        let ext2 = Self::exp7(CubicExtElement::new(s6, s7, s8));
        let ext3 = Self::exp7(CubicExtElement::new(s9, s10, s11));

        // recombine the 4 extension field elements back into 12 base field elements
        let arr_ext = [ext0, ext1, ext2, ext3];
        *state = CubicExtElement::slice_as_base_elements(&arr_ext)
            .try_into()
            .expect("shouldn't fail");
    }

    /// (M) round function.
    #[inline(always)]
    pub fn apply_final_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
        apply_mds(state);
        add_constants(state, &ARK1[round]);
    }

    /// Computes an exponentiation to the power 7 in the cubic extension field.
    #[inline(always)]
    pub fn exp7(x: CubeExtension<Felt>) -> CubeExtension<Felt> {
        let x2 = x.square();
        let x4 = x2.square();

        let x3 = x2 * x;
        x3 * x4
    }
}
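The addition chain in `exp7` can be sanity-checked by exponent arithmetic (a quick derivation, not part of the diff): with $x_2 = x^2$, $x_4 = x_2^2 = x^4$, and $x_3 = x_2 \cdot x = x^3$, the returned value is $x_3 \cdot x_4 = x^{3+4} = x^7$, using two squarings and two multiplications in the extension field.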

src/hash/rescue/tests.rs (new file, 9 lines)
@@ -0,0 +1,9 @@
use super::{Felt, FieldElement, ALPHA, INV_ALPHA};
use rand_utils::rand_value;

#[test]
fn test_alphas() {
    let e: Felt = Felt::new(rand_value());
    let e_exp = e.exp(ALPHA);
    assert_eq!(e, e_exp.exp(INV_ALPHA));
}
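The test above relies on ALPHA and INV_ALPHA being inverse exponents modulo $p - 1$; for the Miden base field this can be verified directly (a quick check, not from the diff): with $p = 2^{64} - 2^{32} + 1$, we have $7 \cdot 10540996611094048183 = 73786976277658337281 = 4(p - 1) + 1 \equiv 1 \pmod{p-1}$, so $(x^7)^{INV\_ALPHA} = x$ for every nonzero $x$ (and trivially for $x = 0$).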

@@ -1,165 +0,0 @@
use super::{Digest, Felt, StarkField, DIGEST_SIZE, ZERO};
use crate::utils::{
    string::String, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
};
use core::{cmp::Ordering, ops::Deref};

// DIGEST TRAIT IMPLEMENTATIONS
// ================================================================================================

#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
pub struct RpoDigest([Felt; DIGEST_SIZE]);

impl RpoDigest {
    pub const fn new(value: [Felt; DIGEST_SIZE]) -> Self {
        Self(value)
    }

    pub fn as_elements(&self) -> &[Felt] {
        self.as_ref()
    }

    pub fn as_bytes(&self) -> [u8; 32] {
        <Self as Digest>::as_bytes(self)
    }

    pub fn digests_as_elements<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt>
    where
        I: Iterator<Item = &'a Self>,
    {
        digests.flat_map(|d| d.0.iter())
    }
}

impl Digest for RpoDigest {
    fn as_bytes(&self) -> [u8; 32] {
        let mut result = [0; 32];

        result[..8].copy_from_slice(&self.0[0].as_int().to_le_bytes());
        result[8..16].copy_from_slice(&self.0[1].as_int().to_le_bytes());
        result[16..24].copy_from_slice(&self.0[2].as_int().to_le_bytes());
        result[24..].copy_from_slice(&self.0[3].as_int().to_le_bytes());

        result
    }
}

impl Serializable for RpoDigest {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_bytes(&self.as_bytes());
    }
}

impl Deserializable for RpoDigest {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let mut inner: [Felt; DIGEST_SIZE] = [ZERO; DIGEST_SIZE];
        for inner in inner.iter_mut() {
            let e = source.read_u64()?;
            if e >= Felt::MODULUS {
                return Err(DeserializationError::InvalidValue(String::from(
                    "Value not in the appropriate range",
                )));
            }
            *inner = Felt::new(e);
        }

        Ok(Self(inner))
    }
}
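A hedged sketch of the range check above (hypothetical test in this module's context, not part of the diff): any u64 limb at or above `Felt::MODULUS` is rejected rather than silently reduced, which keeps every digest's byte encoding unique.

use crate::utils::SliceReader;

fn rejects_non_canonical() {
    let bytes = [0xFF_u8; 32]; // four u64 limbs equal to u64::MAX >= Felt::MODULUS
    let mut reader = SliceReader::new(&bytes);
    assert!(RpoDigest::read_from(&mut reader).is_err());
}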

impl From<[Felt; DIGEST_SIZE]> for RpoDigest {
    fn from(value: [Felt; DIGEST_SIZE]) -> Self {
        Self(value)
    }
}

impl From<&RpoDigest> for [Felt; DIGEST_SIZE] {
    fn from(value: &RpoDigest) -> Self {
        value.0
    }
}

impl From<RpoDigest> for [Felt; DIGEST_SIZE] {
    fn from(value: RpoDigest) -> Self {
        value.0
    }
}

impl From<&RpoDigest> for [u8; 32] {
    fn from(value: &RpoDigest) -> Self {
        value.as_bytes()
    }
}

impl From<RpoDigest> for [u8; 32] {
    fn from(value: RpoDigest) -> Self {
        value.as_bytes()
    }
}

impl Deref for RpoDigest {
    type Target = [Felt; DIGEST_SIZE];

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl Ord for RpoDigest {
    fn cmp(&self, other: &Self) -> Ordering {
        // compare the inner u64 of both elements.
        //
        // this iterates over the elements and returns the first comparison different from
        // `Equal`; otherwise, the ordering is equal.
        //
        // the endianness is irrelevant here because, this being a cryptographically secure hash
        // computation, the digest shouldn't preserve any ordering property of its input.
        //
        // finally, we use `Felt::inner` instead of `Felt::as_int` so we avoid performing a
        // Montgomery reduction for every limb. that is safe because every inner element of the
        // digest is guaranteed to be in its canonical form (that is, `x in [0,p)`).
        self.0.iter().map(Felt::inner).zip(other.0.iter().map(Felt::inner)).fold(
            Ordering::Equal,
            |ord, (a, b)| match ord {
                Ordering::Equal => a.cmp(&b),
                _ => ord,
            },
        )
    }
}

impl PartialOrd for RpoDigest {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {

    use super::{Deserializable, Felt, RpoDigest, Serializable};
    use crate::utils::SliceReader;
    use rand_utils::rand_value;

    #[test]
    fn digest_serialization() {
        let e1 = Felt::new(rand_value());
        let e2 = Felt::new(rand_value());
        let e3 = Felt::new(rand_value());
        let e4 = Felt::new(rand_value());

        let d1 = RpoDigest([e1, e2, e3, e4]);

        let mut bytes = vec![];
        d1.write_into(&mut bytes);
        assert_eq!(32, bytes.len());

        let mut reader = SliceReader::new(&bytes);
        let d2 = RpoDigest::read_from(&mut reader).unwrap();

        assert_eq!(d1, d2);
    }
}

@@ -1,845 +0,0 @@
use super::{Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ONE, ZERO};
use core::{convert::TryInto, ops::Range};

mod digest;
pub use digest::RpoDigest;

mod mds_freq;
use mds_freq::mds_multiply_freq;

#[cfg(test)]
mod tests;

// CONSTANTS
// ================================================================================================

/// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
/// the remaining 4 elements are reserved for capacity.
const STATE_WIDTH: usize = 12;

/// The rate portion of the state is located in elements 4 through 11.
const RATE_RANGE: Range<usize> = 4..12;
const RATE_WIDTH: usize = RATE_RANGE.end - RATE_RANGE.start;

const INPUT1_RANGE: Range<usize> = 4..8;
const INPUT2_RANGE: Range<usize> = 8..12;

/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
const CAPACITY_RANGE: Range<usize> = 0..4;

/// The output of the hash function is a digest which consists of 4 field elements or 32 bytes.
///
/// The digest is returned from state elements 4, 5, 6, and 7 (the first four elements of the
/// rate portion).
const DIGEST_RANGE: Range<usize> = 4..8;
const DIGEST_SIZE: usize = DIGEST_RANGE.end - DIGEST_RANGE.start;

/// The number of rounds is set to 7 to target a 128-bit security level.
const NUM_ROUNDS: usize = 7;

/// The number of byte chunks defining a field element when hashing a sequence of bytes.
const BINARY_CHUNK_SIZE: usize = 7;

/// S-Box and Inverse S-Box powers;
///
/// The constants are defined for tests only because the exponentiations in the code are unrolled
/// for efficiency reasons.
#[cfg(test)]
const ALPHA: u64 = 7;
#[cfg(test)]
const INV_ALPHA: u64 = 10540996611094048183;

// HASHER IMPLEMENTATION
// ================================================================================================

/// Implementation of the Rescue Prime Optimized hash function with 256-bit output.
///
/// The hash function is implemented according to the Rescue Prime Optimized
/// [specifications](https://eprint.iacr.org/2022/1577).
///
/// The parameters used to instantiate the function are:
/// * Field: 64-bit prime field with modulus 2^64 - 2^32 + 1.
/// * State width: 12 field elements.
/// * Capacity size: 4 field elements.
/// * Number of rounds: 7.
/// * S-Box degree: 7.
///
/// The above parameters target a 128-bit security level. The digest consists of four field
/// elements and it can be serialized into 32 bytes (256 bits).
///
/// ## Hash output consistency
/// Functions [hash_elements()](Rpo256::hash_elements), [merge()](Rpo256::merge), and
/// [merge_with_int()](Rpo256::merge_with_int) are internally consistent. That is, computing
/// a hash for the same set of elements using these functions will always produce the same
/// result. For example, merging two digests using [merge()](Rpo256::merge) will produce the
/// same result as hashing 8 elements which make up these digests using the
/// [hash_elements()](Rpo256::hash_elements) function.
///
/// However, the [hash()](Rpo256::hash) function is not consistent with the functions mentioned
/// above. For example, if we take two field elements, serialize them to bytes, and hash them
/// using [hash()](Rpo256::hash), the result will differ from the result obtained by hashing
/// these elements directly using the [hash_elements()](Rpo256::hash_elements) function. The
/// reason for this difference is that the [hash()](Rpo256::hash) function needs to be able to
/// handle arbitrary binary strings, which may or may not encode valid field elements - and thus,
/// the deserialization procedure used by this function is different from the procedure used to
/// deserialize valid field elements.
///
/// Thus, if the underlying data consists of valid field elements, it might make more sense
/// to deserialize them into field elements and then hash them using the
/// [hash_elements()](Rpo256::hash_elements) function rather than hashing the serialized bytes
/// using the [hash()](Rpo256::hash) function.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Rpo256();

impl Hasher for Rpo256 {
    /// Rpo256 collision resistance is the same as the security level, that is, 128 bits.
    ///
    /// #### Collision resistance
    ///
    /// However, our setup of the capacity registers might drop it to 126 bits.
    ///
    /// Related issue: [#69](https://github.com/0xPolygonMiden/crypto/issues/69)
    const COLLISION_RESISTANCE: u32 = 128;

    type Digest = RpoDigest;

    fn hash(bytes: &[u8]) -> Self::Digest {
        // initialize the state with zeroes
        let mut state = [ZERO; STATE_WIDTH];

        // set the capacity (first element) to a flag indicating whether or not the input length
        // is evenly divisible by the rate. this will prevent collisions between padded and
        // non-padded inputs, and will remove the need to perform an extra permutation in case of
        // evenly divided inputs.
        let is_rate_multiple = bytes.len() % RATE_WIDTH == 0;
        if !is_rate_multiple {
            state[CAPACITY_RANGE.start] = ONE;
        }

        // initialize a buffer to receive the little-endian elements.
        let mut buf = [0_u8; 8];

        // iterate over the chunks of bytes, creating a field element from each chunk and copying
        // it into the state.
        //
        // every time the rate range is filled, a permutation is performed. if the final value of
        // `i` is not zero, then the chunks count wasn't enough to fill the state range, and an
        // additional permutation must be performed.
        let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| {
            // the last element of the iteration may or may not be a full chunk. if it's not, then
            // we need to pad the remainder bytes of the chunk with zeroes, separated by a `1`.
            // this will avoid collisions.
            if chunk.len() == BINARY_CHUNK_SIZE {
                buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
            } else {
                buf.fill(0);
                buf[..chunk.len()].copy_from_slice(chunk);
                buf[chunk.len()] = 1;
            }

            // set the current rate element to the input. since we take at most 7 bytes, we are
            // guaranteed that the input data will fit into a single field element.
            state[RATE_RANGE.start + i] = Felt::new(u64::from_le_bytes(buf));

            // proceed with filling the range. if it's full, then we apply a permutation and reset
            // the counter to the beginning of the range.
            if i == RATE_WIDTH - 1 {
                Self::apply_permutation(&mut state);
                0
            } else {
                i + 1
            }
        });

        // if we absorbed some elements but didn't apply a permutation to them (which happens when
        // the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation. we
        // don't need to apply any extra padding because the first capacity element contains a
        // flag indicating whether the input is evenly divisible by the rate.
        if i != 0 {
            state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO);
            state[RATE_RANGE.start + i] = ONE;
            Self::apply_permutation(&mut state);
        }

        // return the first 4 elements of the rate as the hash result.
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
        // initialize the state by copying the digest elements into the rate portion of the state
        // (8 total elements), and set the capacity elements to 0.
        let mut state = [ZERO; STATE_WIDTH];
        let it = Self::Digest::digests_as_elements(values.iter());
        for (i, v) in it.enumerate() {
            state[RATE_RANGE.start + i] = *v;
        }

        // apply the RPO permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
        // initialize the state as follows:
        // - seed is copied into the first 4 elements of the rate portion of the state.
        // - if the value fits into a single field element, copy it into the fifth rate element
        //   and set the sixth rate element to 1.
        // - if the value doesn't fit into a single field element, split it into two field
        //   elements, copy them into rate elements 5 and 6, and set the seventh rate element
        //   to 1.
        // - set the first capacity element to 1
        let mut state = [ZERO; STATE_WIDTH];
        state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
        state[INPUT2_RANGE.start] = Felt::new(value);
        if value < Felt::MODULUS {
            state[INPUT2_RANGE.start + 1] = ONE;
        } else {
            state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
            state[INPUT2_RANGE.start + 2] = ONE;
        }

        // common padding for both cases
        state[CAPACITY_RANGE.start] = ONE;

        // apply the RPO permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
}
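A hedged sketch of the consistency property described in the doc comment on Rpo256 (hypothetical, not part of the diff): merging two digests gives the same result as hashing their 8 underlying field elements, because 8 elements fill the rate exactly and need no padding.

fn merge_matches_hash_elements(d1: RpoDigest, d2: RpoDigest) {
    let mut elements = [ZERO; 8];
    elements[..4].copy_from_slice(d1.as_elements());
    elements[4..].copy_from_slice(d2.as_elements());
    assert_eq!(Rpo256::merge(&[d1, d2]), Rpo256::hash_elements(&elements));
}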

impl ElementHasher for Rpo256 {
    type BaseField = Felt;

    fn hash_elements<E: FieldElement<BaseField = Self::BaseField>>(elements: &[E]) -> Self::Digest {
        // convert the elements into a list of base field elements
        let elements = E::slice_as_base_elements(elements);

        // initialize state to all zeros, except for the first element of the capacity part, which
        // is set to 1 if the number of elements is not a multiple of RATE_WIDTH.
        let mut state = [ZERO; STATE_WIDTH];
        if elements.len() % RATE_WIDTH != 0 {
            state[CAPACITY_RANGE.start] = ONE;
        }

        // absorb elements into the state one by one until the rate portion of the state is filled
        // up; then apply the Rescue permutation and start absorbing again; repeat until all
        // elements have been absorbed
        let mut i = 0;
        for &element in elements.iter() {
            state[RATE_RANGE.start + i] = element;
            i += 1;
            if i % RATE_WIDTH == 0 {
                Self::apply_permutation(&mut state);
                i = 0;
            }
        }

        // if we absorbed some elements but didn't apply a permutation to them (which happens when
        // the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation after
        // padding by appending a 1 followed by as many 0s as necessary to make the input length a
        // multiple of the RATE_WIDTH.
        if i > 0 {
            state[RATE_RANGE.start + i] = ONE;
            i += 1;
            while i != RATE_WIDTH {
                state[RATE_RANGE.start + i] = ZERO;
                i += 1;
            }
            Self::apply_permutation(&mut state);
        }

        // return the first 4 elements of the state as the hash result
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
}

// HASH FUNCTION IMPLEMENTATION
// ================================================================================================

impl Rpo256 {
    // CONSTANTS
    // --------------------------------------------------------------------------------------------

    /// The number of rounds is set to 7 to target a 128-bit security level.
    pub const NUM_ROUNDS: usize = NUM_ROUNDS;

    /// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
    /// the remaining 4 elements are reserved for capacity.
    pub const STATE_WIDTH: usize = STATE_WIDTH;

    /// The rate portion of the state is located in elements 4 through 11 (inclusive).
    pub const RATE_RANGE: Range<usize> = RATE_RANGE;

    /// The capacity portion of the state is located in elements 0, 1, 2, and 3.
    pub const CAPACITY_RANGE: Range<usize> = CAPACITY_RANGE;

    /// The output of the hash function can be read from state elements 4, 5, 6, and 7.
    pub const DIGEST_RANGE: Range<usize> = DIGEST_RANGE;

    /// MDS matrix used for computing the linear layer in an RPO round.
    pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS;

    /// Round constants added to the hasher state in the first half of the RPO round.
    pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1;

    /// Round constants added to the hasher state in the second half of the RPO round.
    pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2;

    // TRAIT PASS-THROUGH FUNCTIONS
    // --------------------------------------------------------------------------------------------

    /// Returns a hash of the provided sequence of bytes.
    #[inline(always)]
    pub fn hash(bytes: &[u8]) -> RpoDigest {
        <Self as Hasher>::hash(bytes)
    }

    /// Returns a hash of two digests. This method is intended for use in construction of
    /// Merkle trees and verification of Merkle paths.
    #[inline(always)]
    pub fn merge(values: &[RpoDigest; 2]) -> RpoDigest {
        <Self as Hasher>::merge(values)
    }

    /// Returns a hash of the provided field elements.
    #[inline(always)]
    pub fn hash_elements<E: FieldElement<BaseField = Felt>>(elements: &[E]) -> RpoDigest {
        <Self as ElementHasher>::hash_elements(elements)
    }

    // DOMAIN IDENTIFIER
    // --------------------------------------------------------------------------------------------

    /// Returns a hash of two digests and a domain identifier.
    pub fn merge_in_domain(values: &[RpoDigest; 2], domain: Felt) -> RpoDigest {
        // initialize the state by copying the digest elements into the rate portion of the state
        // (8 total elements), and set the capacity elements to 0.
        let mut state = [ZERO; STATE_WIDTH];
        let it = RpoDigest::digests_as_elements(values.iter());
        for (i, v) in it.enumerate() {
            state[RATE_RANGE.start + i] = *v;
        }

        // set the second capacity element to the domain value. The first capacity element is used
        // for padding purposes.
        state[CAPACITY_RANGE.start + 1] = domain;

        // apply the RPO permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    // RESCUE PERMUTATION
    // --------------------------------------------------------------------------------------------

    /// Applies the RPO permutation to the provided state.
    #[inline(always)]
    pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) {
        for i in 0..NUM_ROUNDS {
            Self::apply_round(state, i);
        }
    }

    /// RPO round function.
    #[inline(always)]
    pub fn apply_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
        // apply the first half of the RPO round
        Self::apply_mds(state);
        Self::add_constants(state, &ARK1[round]);
        Self::apply_sbox(state);

        // apply the second half of the RPO round
        Self::apply_mds(state);
        Self::add_constants(state, &ARK2[round]);
        Self::apply_inv_sbox(state);
    }
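For reference, one full RPO round as implemented above can be written as a pair of half-round maps (a sketch of the update rule; $M$ is the MDS matrix, $c_1^{(r)}$ and $c_2^{(r)}$ the ARK1/ARK2 constants for round $r$, $\alpha = 7$, and exponentiation is applied element-wise): $s \leftarrow (M s + c_1^{(r)})^{\alpha}$, then $s \leftarrow (M s + c_2^{(r)})^{1/\alpha}$.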

    // HELPER FUNCTIONS
    // --------------------------------------------------------------------------------------------

    #[inline(always)]
    fn apply_mds(state: &mut [Felt; STATE_WIDTH]) {
        let mut result = [ZERO; STATE_WIDTH];

        // Using the linearity of the operations, we can split the state into a low||high
        // decomposition, operate on each half with no overflow, and then combine/reduce the
        // result into a field element. The absence of overflow is guaranteed by the fact that the
        // entries of the MDS matrix are small powers of two in the frequency domain.
        let mut state_l = [0u64; STATE_WIDTH];
        let mut state_h = [0u64; STATE_WIDTH];

        for r in 0..STATE_WIDTH {
            let s = state[r].inner();
            state_h[r] = s >> 32;
            state_l[r] = (s as u32) as u64;
        }

        let state_h = mds_multiply_freq(state_h);
        let state_l = mds_multiply_freq(state_l);

        for r in 0..STATE_WIDTH {
            let s = state_l[r] as u128 + ((state_h[r] as u128) << 32);
            let s_hi = (s >> 64) as u64;
            let s_lo = s as u64;
            let z = (s_hi << 32) - s_hi;
            let (res, over) = s_lo.overflowing_add(z);

            result[r] = Felt::from_mont(res.wrapping_add(0u32.wrapping_sub(over as u32) as u64));
        }
        *state = result;
    }
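The carry-folding at the end of `apply_mds` relies on the special form of the Miden prime (a quick derivation, not part of the diff): since $p = 2^{64} - 2^{32} + 1$, we have $2^{64} \equiv 2^{32} - 1 \pmod p$, so a 128-bit value $s = s_{hi} \cdot 2^{64} + s_{lo}$ reduces to $s_{lo} + s_{hi}(2^{32} - 1) \bmod p$. The term $s_{hi}(2^{32} - 1)$ is exactly `z = (s_hi << 32) - s_hi`, and an overflow of the final addition folds back in via the same identity.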

    #[inline(always)]
    fn add_constants(state: &mut [Felt; STATE_WIDTH], ark: &[Felt; STATE_WIDTH]) {
        state.iter_mut().zip(ark).for_each(|(s, &k)| *s += k);
    }

    #[inline(always)]
    fn apply_sbox(state: &mut [Felt; STATE_WIDTH]) {
        state[0] = state[0].exp7();
        state[1] = state[1].exp7();
        state[2] = state[2].exp7();
        state[3] = state[3].exp7();
        state[4] = state[4].exp7();
        state[5] = state[5].exp7();
        state[6] = state[6].exp7();
        state[7] = state[7].exp7();
        state[8] = state[8].exp7();
        state[9] = state[9].exp7();
        state[10] = state[10].exp7();
        state[11] = state[11].exp7();
    }

    #[inline(always)]
    fn apply_inv_sbox(state: &mut [Felt; STATE_WIDTH]) {
        // compute base^10540996611094048183 using 72 multiplications per array element
        // 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111

        // compute base^10
        let mut t1 = *state;
        t1.iter_mut().for_each(|t| *t = t.square());

        // compute base^100
        let mut t2 = t1;
        t2.iter_mut().for_each(|t| *t = t.square());

        // compute base^100100
        let t3 = Self::exp_acc::<Felt, STATE_WIDTH, 3>(t2, t2);

        // compute base^100100100100
        let t4 = Self::exp_acc::<Felt, STATE_WIDTH, 6>(t3, t3);

        // compute base^100100100100100100100100
        let t5 = Self::exp_acc::<Felt, STATE_WIDTH, 12>(t4, t4);

        // compute base^100100100100100100100100100100
        let t6 = Self::exp_acc::<Felt, STATE_WIDTH, 6>(t5, t3);

        // compute base^1001001001001001001001001001000100100100100100100100100100100
        let t7 = Self::exp_acc::<Felt, STATE_WIDTH, 31>(t6, t6);

        // compute base^1001001001001001001001001001000110110110110110110110110110110111
        for (i, s) in state.iter_mut().enumerate() {
            let a = (t7[i].square() * t6[i]).square().square();
            let b = t1[i] * t2[i] * *s;
            *s = a * b;
        }
    }

    #[inline(always)]
    fn exp_acc<B: StarkField, const N: usize, const M: usize>(
        base: [B; N],
        tail: [B; N],
    ) -> [B; N] {
        let mut result = base;
        for _ in 0..M {
            result.iter_mut().for_each(|r| *r = r.square());
        }
        result.iter_mut().zip(tail).for_each(|(r, t)| *r *= t);
        result
    }
}
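`exp_acc` is a square-and-multiply accumulator: for the const parameter `M` it computes, element-wise, $\mathrm{exp\_acc}_M(b, t) = b^{2^M} \cdot t$. For example (a worked check, not from the diff), `t3 = exp_acc::<_, _, 3>(t2, t2)` with $t_2 = b^{100_2}$ yields $(b^{100_2})^{2^3} \cdot b^{100_2} = b^{100100_2}$, matching the binary exponents in the comments of `apply_inv_sbox`.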

// MDS
// ================================================================================================

/// RPO MDS matrix
const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = [
    [
        Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10),
        Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8),
    ],
    [
        Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13),
        Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21),
    ],
    [
        Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26),
        Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22),
    ],
    [
        Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8),
        Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6),
    ],
    [
        Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23),
        Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7),
    ],
    [
        Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7),
        Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9),
    ],
    [
        Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8),
        Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10),
    ],
    [
        Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21),
        Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13),
    ],
    [
        Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22),
        Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26),
    ],
    [
        Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6),
        Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8),
    ],
    [
        Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7),
        Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23),
    ],
    [
        Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9),
        Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7),
    ],
];
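One property worth noting: the MDS matrix above is circulant, i.e. every row is the previous row rotated right by one position, which is what enables the frequency-domain multiplication in `apply_mds`. A minimal sanity check, assuming it is placed inside this module where `MDS` and `STATE_WIDTH` are visible (hypothetical test, not part of the diff):

#[test]
fn mds_is_circulant() {
    // each row must equal the previous row rotated right by one position
    for r in 1..STATE_WIDTH {
        for c in 0..STATE_WIDTH {
            assert_eq!(MDS[r][c], MDS[r - 1][(c + STATE_WIDTH - 1) % STATE_WIDTH]);
        }
    }
}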

// ROUND CONSTANTS
// ================================================================================================

/// Rescue round constants;
/// computed as in the [specifications](https://github.com/ASDiscreteMathematics/rpo)
///
/// The constants are broken up into two arrays, ARK1 and ARK2; ARK1 contains the constants for
/// the first half of an RPO round, and ARK2 contains the constants for the second half of an RPO
/// round.
const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
    [
        Felt::new(5789762306288267392), Felt::new(6522564764413701783),
        Felt::new(17809893479458208203), Felt::new(107145243989736508),
        Felt::new(6388978042437517382), Felt::new(15844067734406016715),
        Felt::new(9975000513555218239), Felt::new(3344984123768313364),
        Felt::new(9959189626657347191), Felt::new(12960773468763563665),
        Felt::new(9602914297752488475), Felt::new(16657542370200465908),
    ],
    [
        Felt::new(12987190162843096997), Felt::new(653957632802705281),
        Felt::new(4441654670647621225), Felt::new(4038207883745915761),
        Felt::new(5613464648874830118), Felt::new(13222989726778338773),
        Felt::new(3037761201230264149), Felt::new(16683759727265180203),
        Felt::new(8337364536491240715), Felt::new(3227397518293416448),
        Felt::new(8110510111539674682), Felt::new(2872078294163232137),
    ],
    [
        Felt::new(18072785500942327487), Felt::new(6200974112677013481),
        Felt::new(17682092219085884187), Felt::new(10599526828986756440),
        Felt::new(975003873302957338), Felt::new(8264241093196931281),
        Felt::new(10065763900435475170), Felt::new(2181131744534710197),
        Felt::new(6317303992309418647), Felt::new(1401440938888741532),
        Felt::new(8884468225181997494), Felt::new(13066900325715521532),
    ],
    [
        Felt::new(5674685213610121970), Felt::new(5759084860419474071),
        Felt::new(13943282657648897737), Felt::new(1352748651966375394),
        Felt::new(17110913224029905221), Felt::new(1003883795902368422),
        Felt::new(4141870621881018291), Felt::new(8121410972417424656),
        Felt::new(14300518605864919529), Felt::new(13712227150607670181),
        Felt::new(17021852944633065291), Felt::new(6252096473787587650),
    ],
    [
        Felt::new(4887609836208846458), Felt::new(3027115137917284492),
        Felt::new(9595098600469470675), Felt::new(10528569829048484079),
        Felt::new(7864689113198939815), Felt::new(17533723827845969040),
        Felt::new(5781638039037710951), Felt::new(17024078752430719006),
        Felt::new(109659393484013511), Felt::new(7158933660534805869),
        Felt::new(2955076958026921730), Felt::new(7433723648458773977),
    ],
    [
        Felt::new(16308865189192447297), Felt::new(11977192855656444890),
        Felt::new(12532242556065780287), Felt::new(14594890931430968898),
        Felt::new(7291784239689209784), Felt::new(5514718540551361949),
        Felt::new(10025733853830934803), Felt::new(7293794580341021693),
        Felt::new(6728552937464861756), Felt::new(6332385040983343262),
        Felt::new(13277683694236792804), Felt::new(2600778905124452676),
    ],
    [
        Felt::new(7123075680859040534), Felt::new(1034205548717903090),
        Felt::new(7717824418247931797), Felt::new(3019070937878604058),
        Felt::new(11403792746066867460), Felt::new(10280580802233112374),
        Felt::new(337153209462421218), Felt::new(13333398568519923717),
        Felt::new(3596153696935337464), Felt::new(8104208463525993784),
        Felt::new(14345062289456085693), Felt::new(17036731477169661256),
    ],
];

const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
    [
        Felt::new(6077062762357204287), Felt::new(15277620170502011191),
        Felt::new(5358738125714196705), Felt::new(14233283787297595718),
        Felt::new(13792579614346651365), Felt::new(11614812331536767105),
        Felt::new(14871063686742261166), Felt::new(10148237148793043499),
        Felt::new(4457428952329675767), Felt::new(15590786458219172475),
        Felt::new(10063319113072092615), Felt::new(14200078843431360086),
    ],
    [
        Felt::new(6202948458916099932), Felt::new(17690140365333231091),
        Felt::new(3595001575307484651), Felt::new(373995945117666487),
        Felt::new(1235734395091296013), Felt::new(14172757457833931602),
        Felt::new(707573103686350224), Felt::new(15453217512188187135),
        Felt::new(219777875004506018), Felt::new(17876696346199469008),
        Felt::new(17731621626449383378), Felt::new(2897136237748376248),
    ],
    [
        Felt::new(8023374565629191455), Felt::new(15013690343205953430),
        Felt::new(4485500052507912973), Felt::new(12489737547229155153),
        Felt::new(9500452585969030576), Felt::new(2054001340201038870),
        Felt::new(12420704059284934186), Felt::new(355990932618543755),
        Felt::new(9071225051243523860), Felt::new(12766199826003448536),
        Felt::new(9045979173463556963), Felt::new(12934431667190679898),
    ],
    [
        Felt::new(18389244934624494276), Felt::new(16731736864863925227),
        Felt::new(4440209734760478192), Felt::new(17208448209698888938),
        Felt::new(8739495587021565984), Felt::new(17000774922218161967),
        Felt::new(13533282547195532087), Felt::new(525402848358706231),
        Felt::new(16987541523062161972), Felt::new(5466806524462797102),
        Felt::new(14512769585918244983), Felt::new(10973956031244051118),
    ],
    [
        Felt::new(6982293561042362913), Felt::new(14065426295947720331),
        Felt::new(16451845770444974180), Felt::new(7139138592091306727),
        Felt::new(9012006439959783127), Felt::new(14619614108529063361),
        Felt::new(1394813199588124371), Felt::new(4635111139507788575),
        Felt::new(16217473952264203365), Felt::new(10782018226466330683),
        Felt::new(6844229992533662050), Felt::new(7446486531695178711),
    ],
    [
        Felt::new(3736792340494631448), Felt::new(577852220195055341),
        Felt::new(6689998335515779805), Felt::new(13886063479078013492),
        Felt::new(14358505101923202168), Felt::new(7744142531772274164),
        Felt::new(16135070735728404443), Felt::new(12290902521256031137),
        Felt::new(12059913662657709804), Felt::new(16456018495793751911),
        Felt::new(4571485474751953524), Felt::new(17200392109565783176),
    ],
    [
        Felt::new(17130398059294018733), Felt::new(519782857322261988),
        Felt::new(9625384390925085478), Felt::new(1664893052631119222),
        Felt::new(7629576092524553570), Felt::new(3485239601103661425),
        Felt::new(9755891797164033838), Felt::new(15218148195153269027),
        Felt::new(16460604813734957368), Felt::new(9643968136937729763),
        Felt::new(3611348709641382851), Felt::new(18256379591337759196),
    ],
];

src/lib.rs (16 lines changed)
@@ -1,18 +1,23 @@
 #![cfg_attr(not(feature = "std"), no_std)]

-#[cfg(not(feature = "std"))]
-#[cfg_attr(test, macro_use)]
+//#[cfg(not(feature = "std"))]
+//#[cfg_attr(test, macro_use)]
 extern crate alloc;

 pub mod dsa;
 pub mod hash;
 pub mod merkle;
 pub mod rand;
 pub mod utils;
+pub mod gkr;

 // RE-EXPORTS
 // ================================================================================================

 pub use winter_crypto::{RandomCoin, RandomCoinError};
-pub use winter_math::{fields::f64::BaseElement as Felt, FieldElement, StarkField};
+pub use winter_math::{
+    fields::{f64::BaseElement as Felt, CubeExtension, QuadExtension},
+    FieldElement, StarkField,
+};

 // TYPE ALIASES
 // ================================================================================================

@@ -32,6 +37,9 @@ pub const ZERO: Felt = Felt::ZERO;
 /// Field element representing ONE in the Miden base field.
 pub const ONE: Felt = Felt::ONE;

+/// Array of field elements representing a word of ZEROs in the Miden base field.
+pub const EMPTY_WORD: [Felt; 4] = [ZERO; WORD_SIZE];

 // TESTS
 // ================================================================================================

src/main.rs (new file, 128 lines)
@@ -0,0 +1,128 @@
use clap::Parser;
use miden_crypto::{
    hash::rpo::{Rpo256, RpoDigest},
    merkle::{MerkleError, TieredSmt},
    Felt, Word, ONE,
};
use rand_utils::rand_value;
use std::time::Instant;

#[derive(Parser, Debug)]
#[clap(
    name = "Benchmark",
    about = "Tiered SMT benchmark",
    version,
    rename_all = "kebab-case"
)]
pub struct BenchmarkCmd {
    /// Size of the tree
    #[clap(short = 's', long = "size")]
    size: u64,
}

fn main() {
    benchmark_tsmt();
}

/// Runs a benchmark for the Tiered SMT.
pub fn benchmark_tsmt() {
    let args = BenchmarkCmd::parse();
    let tree_size = args.size;

    // prepare the `entries` vector for tree creation
    let mut entries = Vec::new();
    for i in 0..tree_size {
        let key = rand_value::<RpoDigest>();
        let value = [ONE, ONE, ONE, Felt::new(i)];
        entries.push((key, value));
    }

    let mut tree = construction(entries, tree_size).unwrap();
    insertion(&mut tree, tree_size).unwrap();
    proof_generation(&mut tree, tree_size).unwrap();
}

/// Runs the construction benchmark for the Tiered SMT, returning the constructed tree.
pub fn construction(entries: Vec<(RpoDigest, Word)>, size: u64) -> Result<TieredSmt, MerkleError> {
    println!("Running a construction benchmark:");
    let now = Instant::now();
    let tree = TieredSmt::with_entries(entries)?;
    let elapsed = now.elapsed();
    println!(
        "Constructed a TSMT with {} key-value pairs in {:.3} seconds",
        size,
        elapsed.as_secs_f32(),
    );

    // count how many nodes end up at each tier
    let mut nodes_num_16_32_48 = (0, 0, 0);

    tree.upper_leaf_nodes().for_each(|(index, _)| match index.depth() {
        16 => nodes_num_16_32_48.0 += 1,
        32 => nodes_num_16_32_48.1 += 1,
        48 => nodes_num_16_32_48.2 += 1,
        _ => unreachable!(),
    });

    println!("Number of nodes on depth 16: {}", nodes_num_16_32_48.0);
    println!("Number of nodes on depth 32: {}", nodes_num_16_32_48.1);
    println!("Number of nodes on depth 48: {}", nodes_num_16_32_48.2);
    println!("Number of nodes on depth 64: {}\n", tree.bottom_leaves().count());

    Ok(tree)
}

/// Runs the insertion benchmark for the Tiered SMT.
pub fn insertion(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleError> {
    println!("Running an insertion benchmark:");

    let mut insertion_times = Vec::new();

    for i in 0..20 {
        let test_key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
        let test_value = [ONE, ONE, ONE, Felt::new(size + i)];

        let now = Instant::now();
        tree.insert(test_key, test_value);
        let elapsed = now.elapsed();
        insertion_times.push(elapsed.as_secs_f32());
    }

    println!(
        "The average insertion time over 20 inserts into a TSMT with {} key-value pairs is {:.3} milliseconds\n",
        size,
        // calculate the average by dividing by 20 and convert to milliseconds by multiplying by
        // 1000; as a result, we can simply multiply by 50
        insertion_times.iter().sum::<f32>() * 50f32,
    );

    Ok(())
}
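The constants 50 here and 50000 in the proof benchmark below fold the averaging and the unit conversion into a single factor (a quick check, not from the diff): with the timings $t_i$ in seconds, $\frac{1}{20}\sum_i t_i \cdot 10^3 = 50 \sum_i t_i$ milliseconds, and $\frac{1}{20}\sum_i t_i \cdot 10^6 = 50000 \sum_i t_i$ microseconds.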

/// Runs the proof generation benchmark for the Tiered SMT.
pub fn proof_generation(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleError> {
    println!("Running a proof generation benchmark:");

    let mut proving_times = Vec::new();

    for i in 0..20 {
        let test_key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
        let test_value = [ONE, ONE, ONE, Felt::new(size + i)];
        tree.insert(test_key, test_value);

        let now = Instant::now();
        let _proof = tree.prove(test_key);
        let elapsed = now.elapsed();
        proving_times.push(elapsed.as_secs_f32());
    }

    println!(
        "The average proving time over 20 value proofs in a TSMT with {} key-value pairs is {:.3} microseconds",
        size,
        // calculate the average by dividing by 20 and convert to microseconds by multiplying by
        // 1000000; as a result, we can simply multiply by 50000
        proving_times.iter().sum::<f32>() * 50000f32,
    );

    Ok(())
}

src/merkle/delta.rs (new file, 156 lines)
@@ -0,0 +1,156 @@
use super::{
    BTreeMap, KvMap, MerkleError, MerkleStore, NodeIndex, RpoDigest, StoreNode, Vec, Word,
};
use crate::utils::collections::Diff;

#[cfg(test)]
use super::{super::ONE, Felt, SimpleSmt, EMPTY_WORD, ZERO};

// MERKLE STORE DELTA
// ================================================================================================

/// [MerkleStoreDelta] stores a vector of ([RpoDigest], [MerkleTreeDelta]) tuples where the
/// [RpoDigest] represents the root of the Merkle tree and [MerkleTreeDelta] represents the
/// differences between the initial and final Merkle tree states.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleStoreDelta(pub Vec<(RpoDigest, MerkleTreeDelta)>);

// MERKLE TREE DELTA
// ================================================================================================

/// [MerkleTreeDelta] stores the differences between the initial and final Merkle tree states.
///
/// The differences are represented as follows:
/// - depth: the depth of the Merkle tree.
/// - cleared_slots: indexes of slots where values were set to [ZERO; 4].
/// - updated_slots: index-value pairs of slots where values were set to non-[ZERO; 4] values.
#[cfg(not(test))]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleTreeDelta {
    depth: u8,
    cleared_slots: Vec<u64>,
    updated_slots: Vec<(u64, Word)>,
}

impl MerkleTreeDelta {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------
    pub fn new(depth: u8) -> Self {
        Self {
            depth,
            cleared_slots: Vec::new(),
            updated_slots: Vec::new(),
        }
    }

    // ACCESSORS
    // --------------------------------------------------------------------------------------------
    /// Returns the depth of the Merkle tree the [MerkleTreeDelta] is associated with.
    pub fn depth(&self) -> u8 {
        self.depth
    }

    /// Returns the indexes of slots where values were set to [ZERO; 4].
    pub fn cleared_slots(&self) -> &[u64] {
        &self.cleared_slots
    }

    /// Returns the index-value pairs of slots where values were set to non-[ZERO; 4] values.
    pub fn updated_slots(&self) -> &[(u64, Word)] {
        &self.updated_slots
    }

    // MODIFIERS
    // --------------------------------------------------------------------------------------------
    /// Adds a slot index to the list of cleared slots.
    pub fn add_cleared_slot(&mut self, index: u64) {
        self.cleared_slots.push(index);
    }

    /// Adds a slot index and a value to the list of updated slots.
    pub fn add_updated_slot(&mut self, index: u64, value: Word) {
        self.updated_slots.push((index, value));
    }
}

/// Extracts a [MerkleTreeDelta] object by comparing the leaves of two Merkle trees specified by
/// their roots and depth.
pub fn merkle_tree_delta<T: KvMap<RpoDigest, StoreNode>>(
    tree_root_1: RpoDigest,
    tree_root_2: RpoDigest,
    depth: u8,
    merkle_store: &MerkleStore<T>,
) -> Result<MerkleTreeDelta, MerkleError> {
    if tree_root_1 == tree_root_2 {
        return Ok(MerkleTreeDelta::new(depth));
    }

    let tree_1_leaves: BTreeMap<NodeIndex, RpoDigest> =
        merkle_store.non_empty_leaves(tree_root_1, depth).collect();
    let tree_2_leaves: BTreeMap<NodeIndex, RpoDigest> =
        merkle_store.non_empty_leaves(tree_root_2, depth).collect();
    let diff = tree_1_leaves.diff(&tree_2_leaves);

    // TODO: Refactor this diff implementation to prevent allocation of both BTree and Vec.
    Ok(MerkleTreeDelta {
        depth,
        cleared_slots: diff.removed.into_iter().map(|index| index.value()).collect(),
        updated_slots: diff
            .updated
            .into_iter()
            .map(|(index, leaf)| (index.value(), *leaf))
            .collect(),
    })
}

// INTERNALS
// --------------------------------------------------------------------------------------------
#[cfg(test)]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleTreeDelta {
    pub depth: u8,
    pub cleared_slots: Vec<u64>,
    pub updated_slots: Vec<(u64, Word)>,
}

// MERKLE DELTA
// ================================================================================================
#[test]
fn test_compute_merkle_delta() {
    let entries = vec![
        (10, [ZERO, ONE, Felt::new(2), Felt::new(3)]),
        (15, [Felt::new(4), Felt::new(5), Felt::new(6), Felt::new(7)]),
        (20, [Felt::new(8), Felt::new(9), Felt::new(10), Felt::new(11)]),
        (31, [Felt::new(12), Felt::new(13), Felt::new(14), Felt::new(15)]),
    ];
    let simple_smt = SimpleSmt::with_leaves(30, entries.clone()).unwrap();
    let mut store: MerkleStore = (&simple_smt).into();
    let root = simple_smt.root();

    // add a new node
    let new_value = [Felt::new(16), Felt::new(17), Felt::new(18), Felt::new(19)];
    let new_index = NodeIndex::new(simple_smt.depth(), 32).unwrap();
    let root = store.set_node(root, new_index, new_value.into()).unwrap().root;

    // update an existing node
    let update_value = [Felt::new(20), Felt::new(21), Felt::new(22), Felt::new(23)];
    let update_idx = NodeIndex::new(simple_smt.depth(), entries[0].0).unwrap();
    let root = store.set_node(root, update_idx, update_value.into()).unwrap().root;

    // remove a node
    let remove_idx = NodeIndex::new(simple_smt.depth(), entries[1].0).unwrap();
    let root = store.set_node(root, remove_idx, EMPTY_WORD.into()).unwrap().root;

    let merkle_delta =
        merkle_tree_delta(simple_smt.root(), root, simple_smt.depth(), &store).unwrap();
    let expected_merkle_delta = MerkleTreeDelta {
        depth: simple_smt.depth(),
        cleared_slots: vec![remove_idx.value()],
        updated_slots: vec![(update_idx.value(), update_value), (new_index.value(), new_value)],
    };

    assert_eq!(merkle_delta, expected_merkle_delta);
}

@@ -1,4 +1,4 @@
-use super::{Felt, RpoDigest, WORD_SIZE, ZERO};
+use super::{Felt, RpoDigest, EMPTY_WORD};
 use core::slice;

 // EMPTY NODES SUBTREES

@@ -10,12 +10,19 @@ pub struct EmptySubtreeRoots;
 impl EmptySubtreeRoots {
     /// Returns a static slice with roots of empty subtrees of a Merkle tree starting at the
     /// specified depth.
-    pub const fn empty_hashes(depth: u8) -> &'static [RpoDigest] {
-        let ptr = &EMPTY_SUBTREES[255 - depth as usize] as *const RpoDigest;
+    pub const fn empty_hashes(tree_depth: u8) -> &'static [RpoDigest] {
+        let ptr = &EMPTY_SUBTREES[255 - tree_depth as usize] as *const RpoDigest;
         // Safety: this is a static/constant array, so it will never be outlived. If we attempt to
         // use regular slices, this wouldn't be a `const` function, meaning we won't be able to use
         // the returned value for static/constant definitions.
-        unsafe { slice::from_raw_parts(ptr, depth as usize + 1) }
+        unsafe { slice::from_raw_parts(ptr, tree_depth as usize + 1) }
     }

+    /// Returns the node's digest for a sub-tree with all its leaves set to the empty word.
+    pub const fn entry(tree_depth: u8, node_depth: u8) -> &'static RpoDigest {
+        assert!(node_depth <= tree_depth);
+        let pos = 255 - tree_depth + node_depth;
+        &EMPTY_SUBTREES[pos as usize]
+    }
 }

@@ -1550,7 +1557,7 @@ const EMPTY_SUBTREES: [RpoDigest; 256] = [
         Felt::new(0xd3ad9fb0cea61624),
         Felt::new(0x66ab5c684fbb8597),
     ]),
-    RpoDigest::new([ZERO; WORD_SIZE]),
+    RpoDigest::new(EMPTY_WORD),
 ];

 #[test]
@@ -1570,7 +1577,7 @@ fn all_depths_opens_to_zero() {
     assert_eq!(depth as usize + 1, subtree.len());

     // assert the opening is zero
-    let initial = RpoDigest::new([ZERO; WORD_SIZE]);
+    let initial = RpoDigest::new(EMPTY_WORD);
     assert_eq!(initial, subtree.remove(0));

     // compute every node of the path manually and compare with the output
@@ -1583,3 +1590,16 @@
         .for_each(|(x, computed)| assert_eq!(x, computed));
     }
 }
+
+#[test]
+fn test_entry() {
+    // check that the leaf is always the empty word
+    for depth in 0..255 {
+        assert_eq!(EmptySubtreeRoots::entry(depth, depth), &RpoDigest::new(EMPTY_WORD));
+    }
+
+    // check that the root matches the first element of empty_hashes
+    for depth in 0..255 {
+        assert_eq!(EmptySubtreeRoots::entry(depth, 0), &EmptySubtreeRoots::empty_hashes(depth)[0]);
+    }
+}
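The `255 - tree_depth + node_depth` indexing in `entry` can be traced with a hedged example (hypothetical values, not part of the diff); `EMPTY_SUBTREES` stores roots from the deepest subtree up, with the empty word in the last slot:

// for a depth-64 tree:
let leaf = EmptySubtreeRoots::entry(64, 64); // pos = 255 - 64 + 64 = 255: the empty word
let root = EmptySubtreeRoots::entry(64, 0);  // pos = 255 - 64 = 191: root of an all-empty tree
assert_eq!(root, &EmptySubtreeRoots::empty_hashes(64)[0]);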

src/merkle/error.rs (new file, 58 lines)
@@ -0,0 +1,58 @@
|
||||
use crate::{
|
||||
merkle::{MerklePath, NodeIndex, RpoDigest},
|
||||
utils::collections::Vec,
|
||||
};
|
||||
use core::fmt;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum MerkleError {
|
||||
ConflictingRoots(Vec<RpoDigest>),
|
||||
DepthTooSmall(u8),
|
||||
DepthTooBig(u64),
|
||||
DuplicateValuesForIndex(u64),
|
||||
DuplicateValuesForKey(RpoDigest),
|
||||
InvalidIndex { depth: u8, value: u64 },
|
||||
InvalidDepth { expected: u8, provided: u8 },
|
||||
InvalidSubtreeDepth { subtree_depth: u8, tree_depth: u8 },
|
||||
InvalidPath(MerklePath),
|
||||
InvalidNumEntries(usize),
|
||||
NodeNotInSet(NodeIndex),
|
||||
NodeNotInStore(RpoDigest, NodeIndex),
|
||||
NumLeavesNotPowerOfTwo(usize),
|
||||
RootNotInStore(RpoDigest),
|
||||
}
|
||||
|
||||
impl fmt::Display for MerkleError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
use MerkleError::*;
|
||||
match self {
|
||||
ConflictingRoots(roots) => write!(f, "the merkle paths roots do not match {roots:?}"),
|
||||
DepthTooSmall(depth) => write!(f, "the provided depth {depth} is too small"),
|
||||
DepthTooBig(depth) => write!(f, "the provided depth {depth} is too big"),
|
||||
DuplicateValuesForIndex(key) => write!(f, "multiple values provided for key {key}"),
|
||||
DuplicateValuesForKey(key) => write!(f, "multiple values provided for key {key}"),
|
||||
InvalidIndex { depth, value } => {
|
||||
write!(f, "the index value {value} is not valid for the depth {depth}")
|
||||
}
|
||||
InvalidDepth { expected, provided } => {
|
||||
write!(f, "the provided depth {provided} is not valid for {expected}")
|
||||
}
|
||||
InvalidSubtreeDepth { subtree_depth, tree_depth } => {
|
||||
write!(f, "tried inserting a subtree of depth {subtree_depth} into a tree of depth {tree_depth}")
|
||||
}
|
||||
InvalidPath(_path) => write!(f, "the provided path is not valid"),
|
||||
InvalidNumEntries(max) => write!(f, "number of entries exceeded the maximum: {max}"),
|
||||
NodeNotInSet(index) => write!(f, "the node with index ({index}) is not in the set"),
|
||||
NodeNotInStore(hash, index) => {
|
||||
write!(f, "the node {hash:?} with index ({index}) is not in the store")
|
||||
}
|
||||
NumLeavesNotPowerOfTwo(leaves) => {
|
||||
write!(f, "the leaves count {leaves} is not a power of 2")
|
||||
}
|
||||
RootNotInStore(root) => write!(f, "the root {:?} is not in the store", root),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl std::error::Error for MerkleError {}
|
||||
@@ -1,4 +1,6 @@
|
||||
use super::{Felt, MerkleError, RpoDigest, StarkField};
|
||||
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
|
||||
use core::fmt::Display;
|
||||
|
||||
// NODE INDEX
|
||||
// ================================================================================================
|
||||
@@ -19,6 +21,7 @@ use super::{Felt, MerkleError, RpoDigest, StarkField};
|
||||
/// The root is represented by the pair $(0, 0)$, its left child is $(1, 0)$ and its right child
|
||||
/// $(1, 1)$.
|
||||
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct NodeIndex {
|
||||
depth: u8,
|
||||
value: u64,
|
||||
@@ -40,6 +43,12 @@ impl NodeIndex {
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new node index without checking its validity.
|
||||
pub const fn new_unchecked(depth: u8, value: u64) -> Self {
|
||||
debug_assert!((64 - value.leading_zeros()) <= depth as u32);
|
||||
Self { depth, value }
|
||||
}
|
||||
|
||||
/// Creates a new node index for testing purposes.
|
||||
///
|
||||
/// # Panics
|
||||
@@ -67,12 +76,26 @@ impl NodeIndex {
|
||||
Self { depth: 0, value: 0 }
|
||||
}
|
||||
|
||||
/// Computes the value of the sibling of the current node.
|
||||
pub fn sibling(mut self) -> Self {
|
||||
/// Computes sibling index of the current node.
|
||||
pub const fn sibling(mut self) -> Self {
|
||||
self.value ^= 1;
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns left child index of the current node.
|
||||
pub const fn left_child(mut self) -> Self {
|
||||
self.depth += 1;
|
||||
self.value <<= 1;
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns right child index of the current node.
|
||||
pub const fn right_child(mut self) -> Self {
|
||||
self.depth += 1;
|
||||
self.value = (self.value << 1) + 1;
|
||||
self
|
||||
}
|
||||
|
||||
// PROVIDERS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
@@ -117,11 +140,42 @@ impl NodeIndex {
|
||||
// STATE MUTATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Traverse one level towards the root, decrementing the depth by `1`.
|
||||
pub fn move_up(&mut self) -> &mut Self {
|
||||
/// Traverses one level towards the root, decrementing the depth by `1`.
|
||||
pub fn move_up(&mut self) {
|
||||
self.depth = self.depth.saturating_sub(1);
|
||||
self.value >>= 1;
|
||||
self
|
||||
}
|
||||
|
||||
/// Traverses towards the root until the specified depth is reached.
|
||||
///
|
||||
/// Assumes that the specified depth is smaller than the current depth.
|
||||
pub fn move_up_to(&mut self, depth: u8) {
|
||||
debug_assert!(depth < self.depth);
|
||||
let delta = self.depth.saturating_sub(depth);
|
||||
self.depth = self.depth.saturating_sub(delta);
|
||||
self.value >>= delta as u32;
|
||||
}
|
||||
}
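The navigation methods above are plain bit arithmetic on the `(depth, value)` pair; a minimal standalone sketch with bare integers (no crate types assumed):

fn main() {
    // start at the root's left child: (depth, value) = (1, 0)
    let (depth, value) = (1u8, 0u64);
    assert_eq!((depth + 1, value << 1), (2, 0)); // left_child
    assert_eq!((depth + 1, (value << 1) + 1), (2, 1)); // right_child
    assert_eq!((depth, value ^ 1), (1, 1)); // sibling
    assert_eq!((depth - 1, value >> 1), (0, 0)); // move_up
}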
|
||||
|
||||
impl Display for NodeIndex {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
write!(f, "depth={}, value={}", self.depth, self.value)
|
||||
}
|
||||
}
|
||||
|
||||
impl Serializable for NodeIndex {
|
||||
fn write_into<W: ByteWriter>(&self, target: &mut W) {
|
||||
target.write_u8(self.depth);
|
||||
target.write_u64(self.value);
|
||||
}
|
||||
}
|
||||
|
||||
impl Deserializable for NodeIndex {
|
||||
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
|
||||
let depth = source.read_u8()?;
|
||||
let value = source.read_u64()?;
|
||||
NodeIndex::new(depth, value)
|
||||
.map_err(|_| DeserializationError::InvalidValue("Invalid index".into()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,13 +187,20 @@ mod tests {
|
||||
#[test]
|
||||
fn test_node_index_value_too_high() {
|
||||
assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 });
|
||||
match NodeIndex::new(0, 1) {
|
||||
Err(MerkleError::InvalidIndex { depth, value }) => {
|
||||
assert_eq!(depth, 0);
|
||||
assert_eq!(value, 1);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
let err = NodeIndex::new(0, 1).unwrap_err();
|
||||
assert_eq!(err, MerkleError::InvalidIndex { depth: 0, value: 1 });
|
||||
|
||||
assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 });
|
||||
let err = NodeIndex::new(1, 2).unwrap_err();
|
||||
assert_eq!(err, MerkleError::InvalidIndex { depth: 1, value: 2 });
|
||||
|
||||
assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 });
|
||||
let err = NodeIndex::new(2, 4).unwrap_err();
|
||||
assert_eq!(err, MerkleError::InvalidIndex { depth: 2, value: 4 });
|
||||
|
||||
assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 });
|
||||
let err = NodeIndex::new(3, 8).unwrap_err();
|
||||
assert_eq!(err, MerkleError::InvalidIndex { depth: 3, value: 8 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -154,7 +215,7 @@ mod tests {
|
||||
if value > (1 << depth) { // round up
|
||||
depth += 1;
|
||||
}
|
||||
NodeIndex::new(depth, value.into()).unwrap()
|
||||
NodeIndex::new(depth, value).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
use super::{
|
||||
Felt, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, Word,
|
||||
};
|
||||
use crate::{
|
||||
utils::{string::String, uninit_vector, word_to_hex},
|
||||
FieldElement,
|
||||
};
|
||||
use core::{fmt, slice};
|
||||
use super::{InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, Word};
|
||||
use crate::utils::{string::String, uninit_vector, word_to_hex};
|
||||
use core::{fmt, ops::Deref, slice};
|
||||
use winter_math::log2;
|
||||
|
||||
// MERKLE TREE
|
||||
@@ -13,8 +8,9 @@ use winter_math::log2;
|
||||
|
||||
/// A fully-balanced binary Merkle tree (i.e., a tree where the number of leaves is a power of two).
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct MerkleTree {
|
||||
nodes: Vec<Word>,
|
||||
nodes: Vec<RpoDigest>,
|
||||
}
|
||||
|
||||
impl MerkleTree {
|
||||
@@ -24,7 +20,11 @@ impl MerkleTree {
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the number of leaves is smaller than two or is not a power of two.
|
||||
pub fn new(leaves: Vec<Word>) -> Result<Self, MerkleError> {
|
||||
pub fn new<T>(leaves: T) -> Result<Self, MerkleError>
|
||||
where
|
||||
T: AsRef<[Word]>,
|
||||
{
|
||||
let leaves = leaves.as_ref();
|
||||
let n = leaves.len();
|
||||
if n <= 1 {
|
||||
return Err(MerkleError::DepthTooSmall(n as u8));
|
||||
@@ -34,10 +34,12 @@ impl MerkleTree {
|
||||
|
||||
// create un-initialized vector to hold all tree nodes
|
||||
let mut nodes = unsafe { uninit_vector(2 * n) };
|
||||
nodes[0] = [Felt::ZERO; 4];
|
||||
nodes[0] = RpoDigest::default();
|
||||
|
||||
// copy leaves into the second part of the nodes vector
|
||||
nodes[n..].copy_from_slice(&leaves);
|
||||
nodes[n..].iter_mut().zip(leaves).for_each(|(node, leaf)| {
|
||||
*node = RpoDigest::from(*leaf);
|
||||
});
|
||||
|
||||
// re-interpret nodes as an array of two nodes fused together
|
||||
// Safety: `nodes` will never move here as it is not bound to an external lifetime (i.e.
|
||||
@@ -47,7 +49,7 @@ impl MerkleTree {
|
||||
|
||||
// calculate all internal tree nodes
|
||||
for i in (1..n).rev() {
|
||||
nodes[i] = Rpo256::merge(&pairs[i]).into();
|
||||
nodes[i] = Rpo256::merge(&pairs[i]);
|
||||
}
|
||||
|
||||
Ok(Self { nodes })
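The resulting vector is the classic implicit binary-heap layout: `nodes[0]` is padding, `nodes[1]` is the root, the leaves occupy `nodes[n..2n]`, and node `i` has children `2i` and `2i + 1`. A toy standalone sketch (u64 stand-ins for digests and a placeholder `merge`, both assumptions for illustration):

// Toy version of the construction above, without the unsafe pair trick.
fn build(leaves: &[u64], merge: impl Fn(u64, u64) -> u64) -> Vec<u64> {
    let n = leaves.len(); // assumed to be a power of two, n >= 2
    let mut nodes = vec![0u64; 2 * n];
    nodes[n..].copy_from_slice(leaves);
    for i in (1..n).rev() {
        nodes[i] = merge(nodes[2 * i], nodes[2 * i + 1]);
    }
    nodes
}

fn main() {
    let tree = build(&[1, 2, 3, 4], |l, r| l.wrapping_mul(31) ^ r);
    assert_eq!(tree.len(), 8); // 2n slots for n = 4 leaves
}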
|
||||
@@ -57,7 +59,7 @@ impl MerkleTree {
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns the root of this Merkle tree.
|
||||
pub fn root(&self) -> Word {
|
||||
pub fn root(&self) -> RpoDigest {
|
||||
self.nodes[1]
|
||||
}
|
||||
|
||||
@@ -74,7 +76,7 @@ impl MerkleTree {
|
||||
/// Returns an error if:
|
||||
/// * The specified depth is greater than the depth of the tree.
|
||||
/// * The specified index is not valid for the specified depth.
|
||||
pub fn get_node(&self, index: NodeIndex) -> Result<Word, MerkleError> {
|
||||
pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
|
||||
if index.is_root() {
|
||||
return Err(MerkleError::DepthTooSmall(index.depth()));
|
||||
} else if index.depth() > self.depth() {
|
||||
@@ -114,6 +116,32 @@ impl MerkleTree {
|
||||
Ok(path.into())
|
||||
}
|
||||
|
||||
// ITERATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns an iterator over the leaves of this [MerkleTree].
|
||||
pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
|
||||
let leaves_start = self.nodes.len() / 2;
|
||||
self.nodes
|
||||
.iter()
|
||||
.skip(leaves_start)
|
||||
.enumerate()
|
||||
.map(|(i, v)| (i as u64, v.deref()))
|
||||
}
|
||||
|
||||
/// Returns an iterator over every inner node of this [MerkleTree].
|
||||
///
|
||||
/// The iterator order is unspecified.
|
||||
pub fn inner_nodes(&self) -> InnerNodeIterator {
|
||||
InnerNodeIterator {
|
||||
nodes: &self.nodes,
|
||||
index: 1, // index 0 is just padding, start at 1
|
||||
}
|
||||
}
|
||||
|
||||
// STATE MUTATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Replaces the leaf at the specified index with the provided value.
|
||||
///
|
||||
/// # Errors
|
||||
@@ -137,27 +165,37 @@ impl MerkleTree {
|
||||
|
||||
// update the current node
|
||||
let pos = index.to_scalar_index() as usize;
|
||||
self.nodes[pos] = value;
|
||||
self.nodes[pos] = value.into();
|
||||
|
||||
// traverse to the root, updating each node with the merged values of its parents
|
||||
for _ in 0..index.depth() {
|
||||
index.move_up();
|
||||
let pos = index.to_scalar_index() as usize;
|
||||
let value = Rpo256::merge(&pairs[pos]).into();
|
||||
let value = Rpo256::merge(&pairs[pos]);
|
||||
self.nodes[pos] = value;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
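A standalone sketch of the update walk above (same toy heap layout and placeholder `merge` as before, both illustrative assumptions): rewrite the leaf, then re-merge pairs up to the root at `nodes[1]`.

fn update(nodes: &mut [u64], leaf: usize, value: u64, merge: &impl Fn(u64, u64) -> u64) {
    let n = nodes.len() / 2; // number of leaves; nodes[n..2n] are the leaves
    let mut pos = n + leaf;
    nodes[pos] = value;
    while pos > 1 {
        pos /= 2; // move to the parent
        nodes[pos] = merge(nodes[2 * pos], nodes[2 * pos + 1]);
    }
}

fn main() {
    let merge = |l: u64, r: u64| l.wrapping_mul(31) ^ r; // toy stand-in for Rpo256::merge
    let mut nodes = vec![0, 0, 0, 0, 1, 2, 3, 4]; // 4 leaves in heap layout
    for i in (1..4).rev() {
        nodes[i] = merge(nodes[2 * i], nodes[2 * i + 1]);
    }
    let old_root = nodes[1];
    update(&mut nodes, 0, 9, &merge); // replace leaf 0 and re-hash the spine
    assert_ne!(nodes[1], old_root);
}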
|
||||
|
||||
/// Returns an iterator over every inner node of this [MerkleTree].
|
||||
///
|
||||
/// The iterator order is unspecified.
|
||||
pub fn inner_nodes(&self) -> InnerNodeIterator<'_> {
|
||||
InnerNodeIterator {
|
||||
nodes: &self.nodes,
|
||||
index: 1, // index 0 is just padding, start at 1
|
||||
}
|
||||
// CONVERSIONS
|
||||
// ================================================================================================
|
||||
|
||||
impl TryFrom<&[Word]> for MerkleTree {
|
||||
type Error = MerkleError;
|
||||
|
||||
fn try_from(value: &[Word]) -> Result<Self, Self::Error> {
|
||||
MerkleTree::new(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&[RpoDigest]> for MerkleTree {
|
||||
type Error = MerkleError;
|
||||
|
||||
fn try_from(value: &[RpoDigest]) -> Result<Self, Self::Error> {
|
||||
let value: Vec<Word> = value.iter().map(|v| *v.deref()).collect();
|
||||
MerkleTree::new(value)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -168,7 +206,7 @@ impl MerkleTree {
|
||||
///
|
||||
/// Use this to extract the data of the tree; there is no guarantee on the order of the elements.
|
||||
pub struct InnerNodeIterator<'a> {
|
||||
nodes: &'a Vec<Word>,
|
||||
nodes: &'a Vec<RpoDigest>,
|
||||
index: usize,
|
||||
}
|
||||
|
||||
@@ -246,13 +284,17 @@ pub fn path_to_text(path: &MerklePath) -> Result<String, fmt::Error> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::merkle::{int_to_node, InnerNodeInfo};
|
||||
use crate::{
|
||||
merkle::{digests_to_words, int_to_leaf, int_to_node, InnerNodeInfo},
|
||||
Felt, Word, WORD_SIZE,
|
||||
};
|
||||
use core::mem::size_of;
|
||||
use proptest::prelude::*;
|
||||
|
||||
const LEAVES4: [Word; 4] = [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];
|
||||
const LEAVES4: [RpoDigest; WORD_SIZE] =
|
||||
[int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];
|
||||
|
||||
const LEAVES8: [Word; 8] = [
|
||||
const LEAVES8: [RpoDigest; 8] = [
|
||||
int_to_node(1),
|
||||
int_to_node(2),
|
||||
int_to_node(3),
|
||||
@@ -265,7 +307,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn build_merkle_tree() {
|
||||
let tree = super::MerkleTree::new(LEAVES4.to_vec()).unwrap();
|
||||
let tree = super::MerkleTree::new(digests_to_words(&LEAVES4)).unwrap();
|
||||
assert_eq!(8, tree.nodes.len());
|
||||
|
||||
// leaves were copied correctly
|
||||
@@ -284,7 +326,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn get_leaf() {
|
||||
let tree = super::MerkleTree::new(LEAVES4.to_vec()).unwrap();
|
||||
let tree = super::MerkleTree::new(digests_to_words(&LEAVES4)).unwrap();
|
||||
|
||||
// check depth 2
|
||||
assert_eq!(LEAVES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap());
|
||||
@@ -301,7 +343,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn get_path() {
|
||||
let tree = super::MerkleTree::new(LEAVES4.to_vec()).unwrap();
|
||||
let tree = super::MerkleTree::new(digests_to_words(&LEAVES4)).unwrap();
|
||||
|
||||
let (_, node2, node3) = compute_internal_nodes();
|
||||
|
||||
@@ -318,12 +360,12 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn update_leaf() {
|
||||
let mut tree = super::MerkleTree::new(LEAVES8.to_vec()).unwrap();
|
||||
let mut tree = super::MerkleTree::new(digests_to_words(&LEAVES8)).unwrap();
|
||||
|
||||
// update one leaf
|
||||
let value = 3;
|
||||
let new_node = int_to_node(9);
|
||||
let mut expected_leaves = LEAVES8.to_vec();
|
||||
let new_node = int_to_leaf(9);
|
||||
let mut expected_leaves = digests_to_words(&LEAVES8);
|
||||
expected_leaves[value as usize] = new_node;
|
||||
let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap();
|
||||
|
||||
@@ -332,7 +374,7 @@ mod tests {
|
||||
|
||||
// update another leaf
|
||||
let value = 6;
|
||||
let new_node = int_to_node(10);
|
||||
let new_node = int_to_leaf(10);
|
||||
expected_leaves[value as usize] = new_node;
|
||||
let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap();
|
||||
|
||||
@@ -342,7 +384,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn nodes() -> Result<(), MerkleError> {
|
||||
let tree = super::MerkleTree::new(LEAVES4.to_vec()).unwrap();
|
||||
let tree = super::MerkleTree::new(digests_to_words(&LEAVES4)).unwrap();
|
||||
let root = tree.root();
|
||||
let l1n0 = tree.get_node(NodeIndex::make(1, 0))?;
|
||||
let l1n1 = tree.get_node(NodeIndex::make(1, 1))?;
|
||||
@@ -353,21 +395,9 @@ mod tests {
|
||||
|
||||
let nodes: Vec<InnerNodeInfo> = tree.inner_nodes().collect();
|
||||
let expected = vec![
|
||||
InnerNodeInfo {
|
||||
value: root,
|
||||
left: l1n0,
|
||||
right: l1n1,
|
||||
},
|
||||
InnerNodeInfo {
|
||||
value: l1n0,
|
||||
left: l2n0,
|
||||
right: l2n1,
|
||||
},
|
||||
InnerNodeInfo {
|
||||
value: l1n1,
|
||||
left: l2n2,
|
||||
right: l2n3,
|
||||
},
|
||||
InnerNodeInfo { value: root, left: l1n0, right: l1n1 },
|
||||
InnerNodeInfo { value: l1n0, left: l2n0, right: l2n1 },
|
||||
InnerNodeInfo { value: l1n1, left: l2n2, right: l2n3 },
|
||||
];
|
||||
assert_eq!(nodes, expected);
|
||||
|
||||
@@ -391,8 +421,8 @@ mod tests {
|
||||
let digest = RpoDigest::from(word);
|
||||
|
||||
// assert the addresses are different
|
||||
let word_ptr = (&word).as_ptr() as *const u8;
|
||||
let digest_ptr = (&digest).as_ptr() as *const u8;
|
||||
let word_ptr = word.as_ptr() as *const u8;
|
||||
let digest_ptr = digest.as_ptr() as *const u8;
|
||||
assert_ne!(word_ptr, digest_ptr);
|
||||
|
||||
// compare the bytes representation
|
||||
@@ -405,11 +435,13 @@ mod tests {
|
||||
// HELPER FUNCTIONS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
fn compute_internal_nodes() -> (Word, Word, Word) {
|
||||
let node2 = Rpo256::hash_elements(&[LEAVES4[0], LEAVES4[1]].concat());
|
||||
let node3 = Rpo256::hash_elements(&[LEAVES4[2], LEAVES4[3]].concat());
|
||||
fn compute_internal_nodes() -> (RpoDigest, RpoDigest, RpoDigest) {
|
||||
let node2 =
|
||||
Rpo256::hash_elements(&[Word::from(LEAVES4[0]), Word::from(LEAVES4[1])].concat());
|
||||
let node3 =
|
||||
Rpo256::hash_elements(&[Word::from(LEAVES4[2]), Word::from(LEAVES4[3])].concat());
|
||||
let root = Rpo256::merge(&[node2, node3]);
|
||||
|
||||
(root.into(), node2.into(), node3.into())
|
||||
(root, node2, node3)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
use super::{
|
||||
super::Vec,
|
||||
super::{WORD_SIZE, ZERO},
|
||||
MmrProof, Rpo256, Word,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct MmrPeaks {
|
||||
/// The number of leaves is used to differentiate accumulators that have the same number of
|
||||
/// peaks. This happens because the number of peaks goes up-and-down as the structure is used
|
||||
/// causing existing trees to be merged and new ones to be created. As an example, every time
|
||||
/// the MMR has a power-of-two number of leaves there is a single peak.
|
||||
///
|
||||
/// Every tree in the MMR forest has a distinct power-of-two size; this means only the
/// rightmost tree can have an odd number of elements (e.g. `1`). Additionally, the bits in
|
||||
/// `num_leaves` conveniently encode the size of each individual tree.
|
||||
///
|
||||
/// Examples:
|
||||
///
|
||||
/// - With 5 leaves, the binary representation is `0b101`. The number of set bits equals the number
|
||||
/// of peaks, in this case there are 2 peaks. The 0-indexed least-significant position of
|
||||
/// the bit determines the number of elements of a tree, so the rightmost tree has `2**0`
|
||||
/// elements and the leftmost has `2**2`.
|
||||
/// - With 12 leaves, the binary is `0b1100`, this case also has 2 peaks, the
|
||||
/// leftmost tree has `2**3=8` elements, and the right most has `2**2=4` elements.
|
||||
pub num_leaves: usize,
|
||||
|
||||
/// All the peaks of every tree in the MMR forest. The peaks are always ordered by number of
|
||||
/// leaves, starting from the peak with most children, to the one with least.
|
||||
///
|
||||
/// Invariant: The length of `peaks` must be equal to the number of true bits in `num_leaves`.
|
||||
pub peaks: Vec<Word>,
|
||||
}
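The bit encoding described above can be decoded with a few lines of plain Rust; a standalone sketch:

fn main() {
    // decode num_leaves = 0b1100 (12 leaves): one peak per set bit, ordered
    // from the tree with the most leaves to the one with the least
    let num_leaves: usize = 0b1100;
    let peak_sizes: Vec<usize> = (0..usize::BITS)
        .rev()
        .filter(|&b| num_leaves & (1usize << b) != 0)
        .map(|b| 1usize << b)
        .collect();
    assert_eq!(peak_sizes, vec![8, 4]);
    assert_eq!(peak_sizes.len(), num_leaves.count_ones() as usize);
}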
|
||||
|
||||
impl MmrPeaks {
|
||||
/// Hashes the peaks.
|
||||
///
|
||||
/// The hashing is optimized to work with the Miden VM, the procedure will:
|
||||
///
|
||||
/// - Pad the peaks with ZERO to an even number of words; this removes the need to handle RPO padding.
|
||||
/// - Pad the peaks to a minimum length of 16 words, which reduces the constant cost of
|
||||
/// hashing.
|
||||
pub fn hash_peaks(&self) -> Word {
|
||||
let mut copy = self.peaks.clone();
|
||||
|
||||
if copy.len() < 16 {
|
||||
copy.resize(16, [ZERO; WORD_SIZE])
|
||||
} else if copy.len() % 2 == 1 {
|
||||
copy.push([ZERO; WORD_SIZE])
|
||||
}
|
||||
|
||||
Rpo256::hash_elements(&copy.as_slice().concat()).into()
|
||||
}
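A standalone restatement of just the padding rule above (the closure is illustrative; it computes the padded word count only):

fn main() {
    // at least 16 words, and always an even count
    let padded = |len: usize| if len < 16 { 16 } else { len + len % 2 };
    assert_eq!(padded(3), 16);
    assert_eq!(padded(17), 18);
    assert_eq!(padded(20), 20);
}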
|
||||
|
||||
pub fn verify(&self, value: Word, opening: MmrProof) -> bool {
|
||||
let root = &self.peaks[opening.peak_index()];
|
||||
opening.merkle_path.verify(opening.relative_pos() as u64, value, root)
|
||||
}
|
||||
}
|
||||
16
src/merkle/mmr/delta.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
use super::super::{RpoDigest, Vec};
|
||||
|
||||
/// Container for the update data of a [PartialMmr]
|
||||
#[derive(Debug)]
|
||||
pub struct MmrDelta {
|
||||
/// The new version of the [Mmr]
|
||||
pub forest: usize,
|
||||
|
||||
/// Update data.
|
||||
///
|
||||
/// The data is packed as follows:
|
||||
/// 1. All the elements needed to perform authentication path updates. These are the right
|
||||
/// siblings required to perform tree merges on the [PartialMmr].
|
||||
/// 2. The new peaks.
|
||||
pub data: Vec<RpoDigest>,
|
||||
}
|
||||
35
src/merkle/mmr/error.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
use crate::merkle::MerkleError;
|
||||
use core::fmt::{Display, Formatter};
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
use std::error::Error;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub enum MmrError {
|
||||
InvalidPosition(usize),
|
||||
InvalidPeaks,
|
||||
InvalidPeak,
|
||||
InvalidUpdate,
|
||||
UnknownPeak,
|
||||
MerkleError(MerkleError),
|
||||
}
|
||||
|
||||
impl Display for MmrError {
|
||||
fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), core::fmt::Error> {
|
||||
match self {
|
||||
MmrError::InvalidPosition(pos) => write!(fmt, "Mmr does not contain position {pos}"),
|
||||
MmrError::InvalidPeaks => write!(fmt, "Invalid peaks count"),
|
||||
MmrError::InvalidPeak => {
|
||||
write!(fmt, "Peak values does not match merkle path computed root")
|
||||
}
|
||||
MmrError::InvalidUpdate => write!(fmt, "Invalid mmr update"),
|
||||
MmrError::UnknownPeak => {
|
||||
write!(fmt, "Peak not in Mmr")
|
||||
}
|
||||
MmrError::MerkleError(err) => write!(fmt, "{}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl Error for MmrError {}
|
||||
@@ -9,16 +9,13 @@
|
||||
//! least number of leaves. The structure preserves the invariant that each tree has different
|
||||
//! depths, i.e. as part of adding a new element to the forest, the trees with the same depth are
//! merged, creating a new tree of depth d+1; this process continues until the property is
|
||||
//! restabilished.
|
||||
use super::bit::TrueBitPositionIterator;
|
||||
//! reestablished.
|
||||
use super::{
|
||||
super::{InnerNodeInfo, MerklePath, Vec},
|
||||
MmrPeaks, MmrProof, Rpo256, Word,
|
||||
bit::TrueBitPositionIterator,
|
||||
leaf_to_corresponding_tree, nodes_in_forest, MmrDelta, MmrError, MmrPeaks, MmrProof, Rpo256,
|
||||
RpoDigest,
|
||||
};
|
||||
use core::fmt::{Display, Formatter};
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
use std::error::Error;
|
||||
|
||||
// MMR
|
||||
// ===============================================================================================
|
||||
@@ -28,6 +25,8 @@ use std::error::Error;
|
||||
///
|
||||
/// Since this is a full representation of the MMR, elements are never removed and the MMR will
|
||||
/// grow roughly `O(2n)` in number of leaf elements.
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct Mmr {
|
||||
/// Refer to the `forest` method documentation for details of the semantics of this value.
|
||||
pub(super) forest: usize,
|
||||
@@ -38,25 +37,9 @@ pub struct Mmr {
|
||||
/// the elements of every tree in the forest to be stored in the same sequential buffer. It
|
||||
/// also means new elements can be added to the forest, and merging of trees is very cheap with
|
||||
/// no need to copy elements.
|
||||
pub(super) nodes: Vec<Word>,
|
||||
pub(super) nodes: Vec<RpoDigest>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
|
||||
pub enum MmrError {
|
||||
InvalidPosition(usize),
|
||||
}
|
||||
|
||||
impl Display for MmrError {
|
||||
fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), core::fmt::Error> {
|
||||
match self {
|
||||
MmrError::InvalidPosition(pos) => write!(fmt, "Mmr does not contain position {pos}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl Error for MmrError {}
|
||||
|
||||
impl Default for Mmr {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
@@ -69,10 +52,7 @@ impl Mmr {
|
||||
|
||||
/// Constructor for an empty `Mmr`.
|
||||
pub fn new() -> Mmr {
|
||||
Mmr {
|
||||
forest: 0,
|
||||
nodes: Vec::new(),
|
||||
}
|
||||
Mmr { forest: 0, nodes: Vec::new() }
|
||||
}
|
||||
|
||||
// ACCESSORS
|
||||
@@ -97,28 +77,23 @@ impl Mmr {
|
||||
/// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were
|
||||
/// added; this corresponds to the MMR size _prior_ to adding the element. So the 1st element
|
||||
/// has position 0, the second position 1, and so on.
|
||||
pub fn open(&self, pos: usize) -> Result<MmrProof, MmrError> {
|
||||
pub fn open(&self, pos: usize, target_forest: usize) -> Result<MmrProof, MmrError> {
|
||||
// find the target tree responsible for the MMR position
|
||||
let tree_bit =
|
||||
leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::InvalidPosition(pos))?;
|
||||
let forest_target = 1usize << tree_bit;
|
||||
leaf_to_corresponding_tree(pos, target_forest).ok_or(MmrError::InvalidPosition(pos))?;
|
||||
|
||||
// isolate the trees before the target
|
||||
let forest_before = self.forest & high_bitmask(tree_bit + 1);
|
||||
let forest_before = target_forest & high_bitmask(tree_bit + 1);
|
||||
let index_offset = nodes_in_forest(forest_before);
|
||||
|
||||
// find the root
|
||||
let index = nodes_in_forest(forest_target) - 1;
|
||||
|
||||
// update the value position from global to the target tree
|
||||
let relative_pos = pos - forest_before;
|
||||
|
||||
// collect the path and the final index of the target value
|
||||
let (_, path) =
|
||||
self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset, index);
|
||||
let (_, path) = self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset);
|
||||
|
||||
Ok(MmrProof {
|
||||
forest: self.forest,
|
||||
forest: target_forest,
|
||||
position: pos,
|
||||
merkle_path: MerklePath::new(path),
|
||||
})
|
||||
@@ -129,31 +104,26 @@ impl Mmr {
|
||||
/// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were
|
||||
/// added; this corresponds to the MMR size _prior_ to adding the element. So the 1st element
|
||||
/// has position 0, the second position 1, and so on.
|
||||
pub fn get(&self, pos: usize) -> Result<Word, MmrError> {
|
||||
pub fn get(&self, pos: usize) -> Result<RpoDigest, MmrError> {
|
||||
// find the target tree responsible for the MMR position
|
||||
let tree_bit =
|
||||
leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::InvalidPosition(pos))?;
|
||||
let forest_target = 1usize << tree_bit;
|
||||
|
||||
// isolate the trees before the target
|
||||
let forest_before = self.forest & high_bitmask(tree_bit + 1);
|
||||
let index_offset = nodes_in_forest(forest_before);
|
||||
|
||||
// find the root
|
||||
let index = nodes_in_forest(forest_target) - 1;
|
||||
|
||||
// update the value position from global to the target tree
|
||||
let relative_pos = pos - forest_before;
|
||||
|
||||
// collect the path and the final index of the target value
|
||||
let (value, _) =
|
||||
self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset, index);
|
||||
let (value, _) = self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset);
|
||||
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
/// Adds a new element to the MMR.
|
||||
pub fn add(&mut self, el: Word) {
|
||||
pub fn add(&mut self, el: RpoDigest) {
|
||||
// Note: every node is also a tree of size 1, adding an element to the forest creates a new
|
||||
// rooted-tree of size 1. This may temporarily break the invariant that every tree in the
|
||||
// forest has different sizes, the loop below will eagerly merge trees of same size and
|
||||
@@ -164,7 +134,7 @@ impl Mmr {
|
||||
let mut right = el;
|
||||
let mut left_tree = 1;
|
||||
while self.forest & left_tree != 0 {
|
||||
right = *Rpo256::merge(&[self.nodes[left_offset].into(), right.into()]);
|
||||
right = Rpo256::merge(&[self.nodes[left_offset], right]);
|
||||
self.nodes.push(right);
|
||||
|
||||
left_offset = left_offset.saturating_sub(nodes_in_forest(left_tree));
|
||||
@@ -174,9 +144,13 @@ impl Mmr {
|
||||
self.forest += 1;
|
||||
}
|
||||
|
||||
/// Returns an accumulator representing the current state of the MMR.
|
||||
pub fn accumulator(&self) -> MmrPeaks {
|
||||
let peaks: Vec<Word> = TrueBitPositionIterator::new(self.forest)
|
||||
/// Returns the peaks of the MMR for the version specified by `forest`.
|
||||
pub fn peaks(&self, forest: usize) -> Result<MmrPeaks, MmrError> {
|
||||
if forest > self.forest {
|
||||
return Err(MmrError::InvalidPeaks);
|
||||
}
|
||||
|
||||
let peaks: Vec<RpoDigest> = TrueBitPositionIterator::new(forest)
|
||||
.rev()
|
||||
.map(|bit| nodes_in_forest(1 << bit))
|
||||
.scan(0, |offset, el| {
|
||||
@@ -186,10 +160,84 @@ impl Mmr {
|
||||
.map(|offset| self.nodes[offset - 1])
|
||||
.collect();
|
||||
|
||||
MmrPeaks {
|
||||
num_leaves: self.forest,
|
||||
peaks,
|
||||
// Safety: the invariant is maintained by the [Mmr]
|
||||
let peaks = MmrPeaks::new(forest, peaks).unwrap();
|
||||
|
||||
Ok(peaks)
|
||||
}
|
||||
|
||||
/// Computes the update required to get from `from_forest` to `to_forest`.
|
||||
///
|
||||
/// The result is a packed sequence of the authentication elements required to update the trees
|
||||
/// that have been merged together, followed by the new peaks of the [Mmr].
|
||||
pub fn get_delta(&self, from_forest: usize, to_forest: usize) -> Result<MmrDelta, MmrError> {
|
||||
if to_forest > self.forest || from_forest > to_forest {
|
||||
return Err(MmrError::InvalidPeaks);
|
||||
}
|
||||
|
||||
if from_forest == to_forest {
|
||||
return Ok(MmrDelta { forest: to_forest, data: Vec::new() });
|
||||
}
|
||||
|
||||
let mut result = Vec::new();
|
||||
|
||||
// Find the largest tree in this [Mmr] which is new to `from_forest`.
|
||||
let candidate_trees = to_forest ^ from_forest;
|
||||
let mut new_high = 1 << candidate_trees.ilog2();
|
||||
|
||||
// Collect authentication nodes used for tree merges
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
// Find the trees from `from_forest` that have been merged into `new_high`.
|
||||
let mut merges = from_forest & (new_high - 1);
|
||||
|
||||
// Find the peaks that are common to `from_forest` and this [Mmr]
|
||||
let common_trees = from_forest ^ merges;
|
||||
|
||||
if merges != 0 {
|
||||
// Skip the smallest trees unknown to `from_forest`.
|
||||
let mut target = 1 << merges.trailing_zeros();
|
||||
|
||||
// Collect siblings required to compute the merged tree's peak
|
||||
while target < new_high {
|
||||
// Computes the offset to the smallest known peak
|
||||
// - common_trees: peaks unchanged in the current update, target comes after these.
|
||||
// - merges: peaks that have not been merged so far, target comes after these.
|
||||
// - target: tree from which to load the sibling. On the first iteration this is a
|
||||
// value known by the partial mmr, on subsequent iterations this value is to be
|
||||
// computed from the known peaks and provided authentication nodes.
|
||||
let known = nodes_in_forest(common_trees | merges | target);
|
||||
let sibling = nodes_in_forest(target);
|
||||
result.push(self.nodes[known + sibling - 1]);
|
||||
|
||||
// Update the target and account for tree merges
|
||||
target <<= 1;
|
||||
while merges & target != 0 {
|
||||
target <<= 1;
|
||||
}
|
||||
// Remove the merges done so far
|
||||
merges ^= merges & (target - 1);
|
||||
}
|
||||
} else {
|
||||
// The new high tree may not be the result of any merges, if it is smaller than all the
|
||||
// trees of `from_forest`.
|
||||
new_high = 0;
|
||||
}
|
||||
|
||||
// Collect the new [Mmr] peaks
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
let mut new_peaks = to_forest ^ common_trees ^ new_high;
|
||||
let old_peaks = to_forest ^ new_peaks;
|
||||
let mut offset = nodes_in_forest(old_peaks);
|
||||
while new_peaks != 0 {
|
||||
let target = 1 << new_peaks.ilog2();
|
||||
offset += nodes_in_forest(target);
|
||||
result.push(self.nodes[offset - 1]);
|
||||
new_peaks ^= target;
|
||||
}
|
||||
|
||||
Ok(MmrDelta { forest: to_forest, data: result })
|
||||
}
|
||||
|
||||
/// An iterator over inner nodes in the MMR. The order of iteration is unspecified.
|
||||
@@ -206,36 +254,52 @@ impl Mmr {
|
||||
// ============================================================================================
|
||||
|
||||
/// Internal function used to collect the Merkle path of a value.
|
||||
///
|
||||
/// The arguments are relative to the target tree. To compute the opening of the second leaf
|
||||
/// for a tree with depth 2 in the forest `0b110`:
|
||||
///
|
||||
/// - `tree_bit`: Depth of the target tree, e.g. 2 for the smallest tree.
|
||||
/// - `relative_pos`: 0-indexed leaf position in the target tree, e.g. 1 for the second leaf.
|
||||
/// - `index_offset`: Node count prior to the target tree, e.g. 7 for the tree of depth 3.
|
||||
fn collect_merkle_path_and_value(
|
||||
&self,
|
||||
tree_bit: u32,
|
||||
relative_pos: usize,
|
||||
index_offset: usize,
|
||||
mut index: usize,
|
||||
) -> (Word, Vec<Word>) {
|
||||
// collect the Merkle path
|
||||
let mut tree_depth = tree_bit as usize;
|
||||
let mut path = Vec::with_capacity(tree_depth + 1);
|
||||
while tree_depth > 0 {
|
||||
let bit = relative_pos & tree_depth;
|
||||
let right_offset = index - 1;
|
||||
let left_offset = right_offset - nodes_in_forest(tree_depth);
|
||||
) -> (RpoDigest, Vec<RpoDigest>) {
|
||||
// see documentation of `leaf_to_corresponding_tree` for details
|
||||
let tree_depth = (tree_bit + 1) as usize;
|
||||
let mut path = Vec::with_capacity(tree_depth);
|
||||
|
||||
// Elements to the right have a higher position because they were
|
||||
// added later. Therefore when the bit is true the node's path is
|
||||
// to the right, and its sibling to the left.
|
||||
let sibling = if bit != 0 {
|
||||
// The tree walk below goes from the root to the leaf, compute the root index to start
|
||||
let mut forest_target = 1usize << tree_bit;
|
||||
let mut index = nodes_in_forest(forest_target) - 1;
|
||||
|
||||
// Loop until the leaf is reached
|
||||
while forest_target > 1 {
|
||||
// Update the depth of the tree to correspond to a subtree
|
||||
forest_target >>= 1;
|
||||
|
||||
// compute the indices of the right and left subtrees based on the post-order
|
||||
let right_offset = index - 1;
|
||||
let left_offset = right_offset - nodes_in_forest(forest_target);
|
||||
|
||||
let left_or_right = relative_pos & forest_target;
|
||||
let sibling = if left_or_right != 0 {
|
||||
// going down the right subtree, the right child becomes the new root
|
||||
index = right_offset;
|
||||
// and the left child is the authentication
|
||||
self.nodes[index_offset + left_offset]
|
||||
} else {
|
||||
index = left_offset;
|
||||
self.nodes[index_offset + right_offset]
|
||||
};
|
||||
|
||||
tree_depth >>= 1;
|
||||
path.push(sibling);
|
||||
}
|
||||
|
||||
debug_assert!(path.len() == tree_depth - 1);
|
||||
|
||||
// the rest of the codebase has the elements going from leaf to root; adjust it here for
// ease of use/consistency's sake
|
||||
path.reverse();
|
||||
@@ -245,9 +309,12 @@ impl Mmr {
|
||||
}
|
||||
}
|
||||
|
||||
// CONVERSIONS
|
||||
// ================================================================================================
|
||||
|
||||
impl<T> From<T> for Mmr
|
||||
where
|
||||
T: IntoIterator<Item = Word>,
|
||||
T: IntoIterator<Item = RpoDigest>,
|
||||
{
|
||||
fn from(values: T) -> Self {
|
||||
let mut mmr = Mmr::new();
|
||||
@@ -339,32 +406,6 @@ impl<'a> Iterator for MmrNodes<'a> {
|
||||
// UTILITIES
|
||||
// ===============================================================================================
|
||||
|
||||
/// Given a 0-indexed leaf position and the current forest, return the tree number responsible for
|
||||
/// the position.
|
||||
///
|
||||
/// Note:
|
||||
/// The result is a tree position `p`, it has the following interpretations. $p+1$ is the depth of
|
||||
/// the tree, which corresponds to the size of a Merkle proof for that tree. $2^p$ is equal to the
|
||||
/// number of leaves in this particular tree. and $2^(p+1)-1$ corresponds to size of the tree.
|
||||
pub(crate) const fn leaf_to_corresponding_tree(pos: usize, forest: usize) -> Option<u32> {
|
||||
if pos >= forest {
|
||||
None
|
||||
} else {
|
||||
// - each bit in the forest is a unique tree and the bit position its power-of-two size
|
||||
// - each tree owns a consecutive range of positions equal to its size from left-to-right
|
||||
// - this means the first tree owns from `0` up to the `2^k_0` first positions, where `k_0`
|
||||
// is the highest true bit position, the second tree from `2^k_0 + 1` up to `2^k_1` where
|
||||
// `k_1` is the second higest bit, so on.
|
||||
// - this means the highest bits work as a category marker, and the position is owned by
|
||||
// the first tree which doesn't share a high bit with the position
|
||||
let before = forest & pos;
|
||||
let after = forest ^ before;
|
||||
let tree = after.ilog2();
|
||||
|
||||
Some(tree)
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a bitmask for the bits including and above the given position.
|
||||
pub(crate) const fn high_bitmask(bit: u32) -> usize {
|
||||
if bit > usize::BITS - 1 {
|
||||
@@ -373,17 +414,3 @@ pub(crate) const fn high_bitmask(bit: u32) -> usize {
|
||||
usize::MAX << bit
|
||||
}
|
||||
}
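A minimal check of how `open`/`get` use this mask to isolate the trees before the target (in-range inputs only, since the diff elides the out-of-range branch):

fn main() {
    let high_bitmask = |bit: u32| usize::MAX << bit;
    // in forest 0b110, the tree before tree bit 1 is the tree at bit 2
    let forest: usize = 0b110;
    let tree_bit: u32 = 1;
    assert_eq!(forest & high_bitmask(tree_bit + 1), 0b100);
}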
|
||||
|
||||
/// Return the total number of nodes of a given forest
|
||||
///
|
||||
/// Panics:
|
||||
///
|
||||
/// This will panic if the forest has size greater than `usize::MAX / 2`
|
||||
pub(crate) const fn nodes_in_forest(forest: usize) -> usize {
|
||||
// - the size of a perfect binary tree is $2^{k+1}-1$ or $2*2^k-1$
|
||||
// - the forest represents the sum of $2^k$ so a single multiplication is necessary
|
||||
// - the number of `-1` is the same as the number of trees, which is the same as the number
|
||||
// bits set
|
||||
let tree_count = forest.count_ones() as usize;
|
||||
forest * 2 - tree_count
|
||||
}
|
||||
|
||||
164
src/merkle/mmr/inorder.rs
Normal file
@@ -0,0 +1,164 @@
|
||||
//! Index for nodes of a binary tree based on an in-order tree walk.
|
||||
//!
|
||||
//! In-order walks have the parent node index split its left and right subtrees. All the
//! indexes in the left subtree are lower than the parent's, while all the indexes in the
//! right subtree are higher. This property makes it easy to compute changes to the index by
//! adding or subtracting the leaves count.
|
||||
use core::num::NonZeroUsize;
|
||||
|
||||
// IN-ORDER INDEX
|
||||
// ================================================================================================
|
||||
|
||||
/// Index of nodes in a perfectly balanced binary tree based on an in-order tree walk.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub struct InOrderIndex {
|
||||
idx: usize,
|
||||
}
|
||||
|
||||
impl InOrderIndex {
|
||||
// CONSTRUCTORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a new [InOrderIndex] instantiated from the provided value.
|
||||
pub fn new(idx: NonZeroUsize) -> InOrderIndex {
|
||||
InOrderIndex { idx: idx.get() }
|
||||
}
|
||||
|
||||
/// Return a new [InOrderIndex] instantiated from the specified leaf position.
|
||||
///
|
||||
/// # Panics:
|
||||
/// If `leaf` is higher than or equal to `usize::MAX / 2`.
|
||||
pub fn from_leaf_pos(leaf: usize) -> InOrderIndex {
|
||||
// Convert the position from 0-indexed to 1-indexed, since the bit manipulation in this
|
||||
// implementation only works with 1-indexed counting.
|
||||
let pos = leaf + 1;
|
||||
InOrderIndex { idx: pos * 2 - 1 }
|
||||
}
|
||||
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// True if the index is pointing at a leaf.
|
||||
///
|
||||
/// Every odd number represents a leaf.
|
||||
pub fn is_leaf(&self) -> bool {
|
||||
self.idx & 1 == 1
|
||||
}
|
||||
|
||||
/// Returns true if this node is a left child of its parent.
|
||||
pub fn is_left_child(&self) -> bool {
|
||||
self.parent().left_child() == *self
|
||||
}
|
||||
|
||||
/// Returns the level of the index.
|
||||
///
|
||||
/// Starts at level zero for leaves and increases by one for each parent.
|
||||
pub fn level(&self) -> u32 {
|
||||
self.idx.trailing_zeros()
|
||||
}
|
||||
|
||||
/// Returns the index of the left child.
|
||||
///
|
||||
/// # Panics:
|
||||
/// If the index corresponds to a leaf.
|
||||
pub fn left_child(&self) -> InOrderIndex {
|
||||
// The left child is itself a parent, with an index that splits its left/right subtrees. To
|
||||
// go from the parent index to its left child, it is only necessary to subtract the count
|
||||
// of elements on the child's right subtree + 1.
|
||||
let els = 1 << (self.level() - 1);
|
||||
InOrderIndex { idx: self.idx - els }
|
||||
}
|
||||
|
||||
/// Returns the index of the right child.
|
||||
///
|
||||
/// # Panics:
|
||||
/// If the index corresponds to a leaf.
|
||||
pub fn right_child(&self) -> InOrderIndex {
|
||||
// To compute the index of the parent of the right subtree it is sufficient to add the size
|
||||
// of its left subtree + 1.
|
||||
let els = 1 << (self.level() - 1);
|
||||
InOrderIndex { idx: self.idx + els }
|
||||
}
|
||||
|
||||
/// Returns the index of the parent node.
|
||||
pub fn parent(&self) -> InOrderIndex {
|
||||
// If the current index corresponds to a node in a left tree, to go up a level it is
|
||||
// required to add the number of nodes of the right sibling, analogously if the node is a
|
||||
// right child, going up requires subtracting the number of nodes in its left subtree.
|
||||
//
|
||||
// Both of the above operations can be performed by bitwise manipulation. Below the mask
|
||||
// sets the number of trailing zeros to be equal the new level of the index, and the bit
|
||||
// marks the parent.
|
||||
let target = self.level() + 1;
|
||||
let bit = 1 << target;
|
||||
let mask = bit - 1;
|
||||
let idx = self.idx ^ (self.idx & mask);
|
||||
InOrderIndex { idx: idx | bit }
|
||||
}
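A standalone restatement of the index arithmetic above on bare usize values (illustrative only; mirrors `level` and `parent`):

fn level(idx: usize) -> u32 {
    idx.trailing_zeros()
}

fn parent(idx: usize) -> usize {
    let bit = 1 << (level(idx) + 1);
    let mask = bit - 1;
    (idx ^ (idx & mask)) | bit
}

fn main() {
    // leaves 0 and 1 map to in-order indexes 1 and 3; their parent is 2
    assert_eq!(parent(1), 2);
    assert_eq!(parent(3), 2);
    // the parents of nodes 2 and 6 meet one level higher, at 4
    assert_eq!(parent(2), 4);
    assert_eq!(parent(6), 4);
    assert_eq!(level(4), 2);
}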
|
||||
|
||||
/// Returns the index of the sibling node.
|
||||
pub fn sibling(&self) -> InOrderIndex {
|
||||
let parent = self.parent();
|
||||
if *self > parent {
|
||||
parent.left_child()
|
||||
} else {
|
||||
parent.right_child()
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the inner value of this [InOrderIndex].
|
||||
pub fn inner(&self) -> u64 {
|
||||
self.idx as u64
|
||||
}
|
||||
}
|
||||
|
||||
// CONVERSIONS FROM IN-ORDER INDEX
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
|
||||
impl From<InOrderIndex> for u64 {
|
||||
fn from(index: InOrderIndex) -> Self {
|
||||
index.idx as u64
|
||||
}
|
||||
}
|
||||
|
||||
// TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::InOrderIndex;
|
||||
use proptest::prelude::*;
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn proptest_inorder_index_random(count in 1..1000usize) {
|
||||
let left_pos = count * 2;
|
||||
let right_pos = count * 2 + 1;
|
||||
|
||||
let left = InOrderIndex::from_leaf_pos(left_pos);
|
||||
let right = InOrderIndex::from_leaf_pos(right_pos);
|
||||
|
||||
assert!(left.is_leaf());
|
||||
assert!(right.is_leaf());
|
||||
assert_eq!(left.parent(), right.parent());
|
||||
assert_eq!(left.parent().right_child(), right);
|
||||
assert_eq!(left, right.parent().left_child());
|
||||
assert_eq!(left.sibling(), right);
|
||||
assert_eq!(left, right.sibling());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inorder_index_basic() {
|
||||
let left = InOrderIndex::from_leaf_pos(0);
|
||||
let right = InOrderIndex::from_leaf_pos(1);
|
||||
|
||||
assert!(left.is_leaf());
|
||||
assert!(right.is_leaf());
|
||||
assert_eq!(left.parent(), right.parent());
|
||||
assert_eq!(left.parent().right_child(), right);
|
||||
assert_eq!(left, right.parent().left_child());
|
||||
assert_eq!(left.sibling(), right);
|
||||
assert_eq!(left, right.sibling());
|
||||
}
|
||||
}
|
||||
@@ -1,15 +1,67 @@
|
||||
mod accumulator;
|
||||
mod bit;
|
||||
mod delta;
|
||||
mod error;
|
||||
mod full;
|
||||
mod inorder;
|
||||
mod partial;
|
||||
mod peaks;
|
||||
mod proof;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use super::{Rpo256, Word};
|
||||
use super::{Felt, Rpo256, RpoDigest, Word};
|
||||
|
||||
// REEXPORTS
|
||||
// ================================================================================================
|
||||
pub use accumulator::MmrPeaks;
|
||||
pub use delta::MmrDelta;
|
||||
pub use error::MmrError;
|
||||
pub use full::Mmr;
|
||||
pub use inorder::InOrderIndex;
|
||||
pub use partial::PartialMmr;
|
||||
pub use peaks::MmrPeaks;
|
||||
pub use proof::MmrProof;
|
||||
|
||||
// UTILITIES
|
||||
// ===============================================================================================
|
||||
|
||||
/// Given a 0-indexed leaf position and the current forest, return the tree number responsible for
|
||||
/// the position.
|
||||
///
|
||||
/// Note:
|
||||
/// The result is a tree position `p`, which has the following interpretations: $p+1$ is the
/// depth of the tree. Because the root element is not part of the proof, $p$ is the length of
/// the authentication path. $2^p$ is equal to the number of leaves in this particular tree,
/// and $2^{p+1}-1$ corresponds to the size of the tree.
|
||||
const fn leaf_to_corresponding_tree(pos: usize, forest: usize) -> Option<u32> {
|
||||
if pos >= forest {
|
||||
None
|
||||
} else {
|
||||
// - each bit in the forest is a unique tree and the bit position its power-of-two size
|
||||
// - each tree owns a consecutive range of positions equal to its size from left-to-right
|
||||
// - this means the first tree owns from `0` up to the `2^k_0` first positions, where `k_0`
|
||||
// is the highest true bit position, the second tree from `2^k_0 + 1` up to `2^k_1` where
|
||||
// `k_1` is the second highest bit, so on.
|
||||
// - this means the highest bits work as a category marker, and the position is owned by
|
||||
// the first tree which doesn't share a high bit with the position
|
||||
let before = forest & pos;
|
||||
let after = forest ^ before;
|
||||
let tree = after.ilog2();
|
||||
|
||||
Some(tree)
|
||||
}
|
||||
}
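A standalone check of the ownership rule above, with values chosen for the forest `0b110` (trees of 4 and 2 leaves):

fn leaf_to_tree(pos: usize, forest: usize) -> Option<u32> {
    if pos >= forest {
        None
    } else {
        // same computation as above: drop the shared high bits, take the
        // highest remaining bit
        Some((forest ^ (forest & pos)).ilog2())
    }
}

fn main() {
    let forest: usize = 0b110;
    assert_eq!(leaf_to_tree(0, forest), Some(2)); // first 4 positions -> tree bit 2
    assert_eq!(leaf_to_tree(3, forest), Some(2));
    assert_eq!(leaf_to_tree(4, forest), Some(1)); // next 2 positions -> tree bit 1
    assert_eq!(leaf_to_tree(6, forest), None); // out of range
}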
|
||||
|
||||
/// Return the total number of nodes of a given forest
|
||||
///
|
||||
/// Panics:
|
||||
///
|
||||
/// This will panic if the forest has size greater than `usize::MAX / 2`
|
||||
const fn nodes_in_forest(forest: usize) -> usize {
|
||||
// - the size of a perfect binary tree is $2^{k+1}-1$ or $2*2^k-1$
|
||||
// - the forest represents the sum of $2^k$ so a single multiplication is necessary
|
||||
// - the number of `-1` is the same as the number of trees, which is the same as the number
|
||||
// bits set
|
||||
let tree_count = forest.count_ones() as usize;
|
||||
forest * 2 - tree_count
|
||||
}
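To make the formula concrete, a standalone restatement with two worked forests:

const fn nodes_in_forest(forest: usize) -> usize {
    // each tree of 2^k leaves has 2^(k+1) - 1 nodes, so summing over all
    // set bits gives forest * 2 - popcount(forest)
    forest * 2 - forest.count_ones() as usize
}

fn main() {
    // forest 0b101: trees of 4 and 1 leaves, (2*4 - 1) + (2*1 - 1) = 8
    assert_eq!(nodes_in_forest(0b101), 8);
    // forest 0b1100: trees of 8 and 4 leaves, 15 + 7 = 22
    assert_eq!(nodes_in_forest(0b1100), 22);
}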
|
||||
|
||||
703
src/merkle/mmr/partial.rs
Normal file
@@ -0,0 +1,703 @@
|
||||
use super::{MmrDelta, MmrProof, Rpo256, RpoDigest};
|
||||
use crate::{
|
||||
merkle::{
|
||||
mmr::{leaf_to_corresponding_tree, nodes_in_forest},
|
||||
InOrderIndex, InnerNodeInfo, MerklePath, MmrError, MmrPeaks,
|
||||
},
|
||||
utils::{
|
||||
collections::{BTreeMap, BTreeSet, Vec},
|
||||
vec,
|
||||
},
|
||||
};
|
||||
|
||||
// PARTIAL MERKLE MOUNTAIN RANGE
|
||||
// ================================================================================================
|
||||
/// Partially materialized Merkle Mountain Range (MMR), used to efficiently store and update the
|
||||
/// authentication paths for a subset of the elements in a full MMR.
|
||||
///
|
||||
/// This structure stores only the authentication path for a value; the value itself is stored
|
||||
/// separately.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct PartialMmr {
|
||||
/// The version of the MMR.
|
||||
///
|
||||
/// This value serves the following purposes:
|
||||
///
|
||||
/// - The forest is a counter for the total number of elements in the MMR.
|
||||
/// - Since the MMR is an append-only structure, every change to it causes a change to the
|
||||
/// `forest`, so this value has a dual purpose as a version tag.
|
||||
/// - The bits in the forest also correspond to the count and size of every perfect binary
/// tree that composes the MMR structure, which serve to compute indexes and perform
|
||||
/// validation.
|
||||
pub(crate) forest: usize,
|
||||
|
||||
/// The MMR peaks.
|
||||
///
|
||||
/// The peaks are used for two reasons:
|
||||
///
|
||||
/// 1. They authenticate the addition of an element to the [PartialMmr], ensuring only valid
|
||||
/// elements are tracked.
|
||||
/// 2. During an MMR update, peaks can be merged by hashing the left and right hand sides. The
|
||||
/// peaks are used as the left hand.
|
||||
///
|
||||
/// All the peaks of every tree in the MMR forest. The peaks are always ordered by number of
|
||||
/// leaves, starting from the peak with most children, to the one with least.
|
||||
pub(crate) peaks: Vec<RpoDigest>,
|
||||
|
||||
/// Authentication nodes used to construct merkle paths for a subset of the MMR's leaves.
|
||||
///
|
||||
/// This does not include the MMR's peaks nor the tracked nodes, only the elements required
|
||||
/// to construct their authentication paths. This property is used to detect when elements can
/// be safely removed from the structure, because they are no longer required to authenticate
/// any element in the [PartialMmr].
|
||||
///
|
||||
/// The elements in the MMR are referenced using an in-order tree index. This indexing scheme
|
||||
/// permits for easy computation of the relative nodes (left/right children, sibling, parent),
|
||||
/// which is useful for traversal. The indexing is also stable, meaning that merges to the
|
||||
/// trees in the MMR can be represented without rewrites of the indexes.
|
||||
pub(crate) nodes: BTreeMap<InOrderIndex, RpoDigest>,
|
||||
|
||||
/// Flag indicating if the odd element should be tracked.
|
||||
///
|
||||
/// This flag is necessary because the sibling of the odd element doesn't exist yet, so it
/// cannot be added into `nodes` to signal that the value is being tracked.
|
||||
pub(crate) track_latest: bool,
|
||||
}
|
||||
|
||||
impl PartialMmr {
|
||||
// CONSTRUCTORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Constructs a [PartialMmr] from the given [MmrPeaks].
|
||||
pub fn from_peaks(peaks: MmrPeaks) -> Self {
|
||||
let forest = peaks.num_leaves();
|
||||
let peaks = peaks.peaks().to_vec();
|
||||
let nodes = BTreeMap::new();
|
||||
let track_latest = false;
|
||||
|
||||
Self { forest, peaks, nodes, track_latest }
|
||||
}
|
||||
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns the current `forest` of this [PartialMmr].
|
||||
///
|
||||
/// This value corresponds to the version of the [PartialMmr] and the number of leaves in the
|
||||
/// underlying MMR.
|
||||
pub fn forest(&self) -> usize {
|
||||
self.forest
|
||||
}
|
||||
|
||||
/// Returns the number of leaves in the underlying MMR for this [PartialMmr].
|
||||
pub fn num_leaves(&self) -> usize {
|
||||
self.forest
|
||||
}
|
||||
|
||||
/// Returns the peaks of the MMR for this [PartialMmr].
|
||||
pub fn peaks(&self) -> MmrPeaks {
|
||||
// expect() is OK here because the constructor ensures that MMR peaks can be constructed
|
||||
// correctly
|
||||
MmrPeaks::new(self.forest, self.peaks.clone()).expect("invalid MMR peaks")
|
||||
}
|
||||
|
||||
/// Given a leaf position, returns the Merkle path to its corresponding peak.
|
||||
///
|
||||
/// If the position is greater than or equal to the tree size, an error is returned. If the
/// requested value is not tracked, `None` is returned.
|
||||
///
|
||||
/// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were
|
||||
/// added; this corresponds to the MMR size _prior_ to adding the element. So the 1st element
|
||||
/// has position 0, the second position 1, and so on.
|
||||
pub fn open(&self, pos: usize) -> Result<Option<MmrProof>, MmrError> {
|
||||
let tree_bit =
|
||||
leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::InvalidPosition(pos))?;
|
||||
let depth = tree_bit as usize;
|
||||
|
||||
let mut nodes = Vec::with_capacity(depth);
|
||||
let mut idx = InOrderIndex::from_leaf_pos(pos);
|
||||
|
||||
while let Some(node) = self.nodes.get(&idx.sibling()) {
|
||||
nodes.push(*node);
|
||||
idx = idx.parent();
|
||||
}
|
||||
|
||||
// If there are nodes then the path must be complete, otherwise it is a bug
|
||||
debug_assert!(nodes.is_empty() || nodes.len() == depth);
|
||||
|
||||
if nodes.len() != depth {
|
||||
// The requested `pos` is not being tracked.
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(MmrProof {
|
||||
forest: self.forest,
|
||||
position: pos,
|
||||
merkle_path: MerklePath::new(nodes),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
// ITERATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns an iterator over the nodes of all authentication paths of this [PartialMmr].
|
||||
pub fn nodes(&self) -> impl Iterator<Item = (&InOrderIndex, &RpoDigest)> {
|
||||
self.nodes.iter()
|
||||
}
|
||||
|
||||
/// Returns an iterator over inner nodes of this [PartialMmr] for the specified leaves.
|
||||
///
|
||||
/// The order of iteration is not defined. If a leaf is not present in this partial MMR, it
|
||||
/// is silently ignored.
|
||||
pub fn inner_nodes<'a, I: Iterator<Item = &'a (usize, RpoDigest)> + 'a>(
|
||||
&'a self,
|
||||
mut leaves: I,
|
||||
) -> impl Iterator<Item = InnerNodeInfo> + '_ {
|
||||
let stack = if let Some((pos, leaf)) = leaves.next() {
|
||||
let idx = InOrderIndex::from_leaf_pos(*pos);
|
||||
vec![(idx, *leaf)]
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
InnerNodeIterator {
|
||||
nodes: &self.nodes,
|
||||
leaves,
|
||||
stack,
|
||||
seen_nodes: BTreeSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
// STATE MUTATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Add the authentication path represented by [MerklePath] if it is valid.
|
||||
///
|
||||
/// The `index` refers to the global position of the leaf in the MMR. These are 0-indexed
/// values assigned in a strictly monotonic fashion as elements are inserted into the MMR;
/// this value corresponds to the values used in the MMR structure.
|
||||
///
|
||||
/// The `node` corresponds to the value at `index`, and `path` is the authentication path for
|
||||
/// that element up to its corresponding Mmr peak. The `node` is only used to compute the root
/// from the authentication path in order to validate the data; only the authentication data
/// is saved in the structure. If the value is required, it should be stored out-of-band.
|
||||
pub fn add(
|
||||
&mut self,
|
||||
index: usize,
|
||||
node: RpoDigest,
|
||||
path: &MerklePath,
|
||||
) -> Result<(), MmrError> {
|
||||
// Checks that there is a tree with the same depth as the authentication path; if not, the
// path is invalid.
|
||||
let tree = 1 << path.depth();
|
||||
if tree & self.forest == 0 {
|
||||
return Err(MmrError::UnknownPeak);
|
||||
};
|
||||
|
||||
if index + 1 == self.forest
|
||||
&& path.depth() == 0
|
||||
&& self.peaks.last().map_or(false, |v| *v == node)
|
||||
{
|
||||
self.track_latest = true;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// ignore the trees smaller than the target (these elements are position after the current
|
||||
// target and don't affect the target index)
|
||||
let target_forest = self.forest ^ (self.forest & (tree - 1));
|
||||
let peak_pos = (target_forest.count_ones() - 1) as usize;
|
||||
|
||||
// translate from mmr index to merkle path
|
||||
let path_idx = index - (target_forest ^ tree);
|
||||
|
||||
// Compute the root of the authentication path, and check it matches the current version of
|
||||
// the PartialMmr.
|
||||
let computed = path.compute_root(path_idx as u64, node).map_err(MmrError::MerkleError)?;
|
||||
if self.peaks[peak_pos] != computed {
|
||||
return Err(MmrError::InvalidPeak);
|
||||
}
|
||||
|
||||
let mut idx = InOrderIndex::from_leaf_pos(index);
|
||||
for node in path.nodes() {
|
||||
self.nodes.insert(idx.sibling(), *node);
|
||||
idx = idx.parent();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
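
    // A minimal sketch of the intended call pattern (assuming a full `Mmr` named `mmr`),
    // matching the tests below: fetch the value and proof from the full MMR, then track it:
    //
    //     let node = mmr.get(pos).unwrap();
    //     let proof = mmr.open(pos, mmr.forest()).unwrap();
    //     partial_mmr.add(pos, node, &proof.merkle_path).unwrap();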

    /// Removes a leaf of the [PartialMmr] and the unused nodes from the authentication path.
    ///
    /// Note: `leaf_pos` corresponds to the position in the MMR and not on an individual tree.
    pub fn remove(&mut self, leaf_pos: usize) {
        let mut idx = InOrderIndex::from_leaf_pos(leaf_pos);

        self.nodes.remove(&idx.sibling());

        // `idx` represents the element that can be computed by the authentication path. Because
        // these elements can be computed, they are not saved for the authentication of the
        // current target. In other words, if the idx is present it was added for the
        // authentication of another element, and no more elements should be removed, otherwise
        // it would remove that element's authentication data.
        while !self.nodes.contains_key(&idx) {
            idx = idx.parent();
            self.nodes.remove(&idx.sibling());
        }
    }

    /// Applies updates to this [PartialMmr] and returns a vector of new authentication nodes
    /// inserted into the partial MMR.
    pub fn apply(&mut self, delta: MmrDelta) -> Result<Vec<(InOrderIndex, RpoDigest)>, MmrError> {
        if delta.forest < self.forest {
            return Err(MmrError::InvalidPeaks);
        }

        let mut inserted_nodes = Vec::new();

        if delta.forest == self.forest {
            if !delta.data.is_empty() {
                return Err(MmrError::InvalidUpdate);
            }

            return Ok(inserted_nodes);
        }

        // find the tree merges
        let changes = self.forest ^ delta.forest;
        let largest = 1 << changes.ilog2();
        let merges = self.forest & (largest - 1);

        debug_assert!(
            !self.track_latest || (merges & 1) == 1,
            "if there is an odd element, a merge is required"
        );

        // count the number of elements needed to produce `largest` from the current state
        let (merge_count, new_peaks) = if merges != 0 {
            let depth = largest.trailing_zeros();
            let skipped = merges.trailing_zeros();
            let computed = merges.count_ones() - 1;
            let merge_count = depth - skipped - computed;

            let new_peaks = delta.forest & (largest - 1);

            (merge_count, new_peaks)
        } else {
            (0, changes)
        };

        // verify the delta size
        if (delta.data.len() as u32) != merge_count + new_peaks.count_ones() {
            return Err(MmrError::InvalidUpdate);
        }

        // keeps track of how many data elements from the update have been consumed
        let mut update_count = 0;

        if merges != 0 {
            // starts at the smallest peak and follows the merged peaks
            let mut peak_idx = forest_to_root_index(self.forest);

            // match the order of the update data while applying it
            self.peaks.reverse();

            // set to true when the data is needed for authentication path updates
            let mut track = self.track_latest;
            self.track_latest = false;

            let mut peak_count = 0;
            let mut target = 1 << merges.trailing_zeros();
            let mut new = delta.data[0];
            update_count += 1;

            while target < largest {
                // check if either the left or right subtree has nodes saved for authentication
                // paths. If so, turn tracking on to update those paths.
                if target != 1 && !track {
                    track = self.is_tracked_node(&peak_idx);
                }

                // the update data only contains the nodes from the right subtrees; left nodes
                // are either previously known peaks or computed values
                let (left, right) = if target & merges != 0 {
                    let peak = self.peaks[peak_count];
                    let sibling_idx = peak_idx.sibling();

                    // if the sibling peak is tracked, add this peak to the set of
                    // authentication nodes
                    if self.is_tracked_node(&sibling_idx) {
                        self.nodes.insert(peak_idx, new);
                        inserted_nodes.push((peak_idx, new));
                    }
                    peak_count += 1;
                    (peak, new)
                } else {
                    let update = delta.data[update_count];
                    update_count += 1;
                    (new, update)
                };

                if track {
                    let sibling_idx = peak_idx.sibling();
                    if peak_idx.is_left_child() {
                        self.nodes.insert(sibling_idx, right);
                        inserted_nodes.push((sibling_idx, right));
                    } else {
                        self.nodes.insert(sibling_idx, left);
                        inserted_nodes.push((sibling_idx, left));
                    }
                }

                peak_idx = peak_idx.parent();
                new = Rpo256::merge(&[left, right]);
                target <<= 1;
            }

            debug_assert!(peak_count == (merges.count_ones() as usize));

            // restore the peaks order
            self.peaks.reverse();
            // remove the merged peaks
            self.peaks.truncate(self.peaks.len() - peak_count);
            // add the newly computed peak, the result of the merges
            self.peaks.push(new);
        }

        // The rest of the update data is composed of peaks. None of these elements can contain
        // tracked elements because the peaks were unknown, and it is not possible to add
        // elements for tracking without authenticating them to a peak.
        self.peaks.extend_from_slice(&delta.data[update_count..]);
        self.forest = delta.forest;

        debug_assert!(self.peaks.len() == (self.forest.count_ones() as usize));

        Ok(inserted_nodes)
    }
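
    // A minimal sync-loop sketch (assuming a full `Mmr` named `mmr` as the source of truth),
    // as exercised by `validate_apply_delta` in the tests below:
    //
    //     let delta = mmr.get_delta(partial.forest(), mmr.forest()).unwrap();
    //     let new_nodes = partial.apply(delta).unwrap();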

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Returns true if this [PartialMmr] tracks an authentication path for the node at the
    /// specified index.
    fn is_tracked_node(&self, node_index: &InOrderIndex) -> bool {
        if node_index.is_leaf() {
            self.nodes.contains_key(&node_index.sibling())
        } else {
            let left_child = node_index.left_child();
            let right_child = node_index.right_child();
            self.nodes.contains_key(&left_child) | self.nodes.contains_key(&right_child)
        }
    }
}

// CONVERSIONS
// ================================================================================================

impl From<MmrPeaks> for PartialMmr {
    fn from(peaks: MmrPeaks) -> Self {
        Self::from_peaks(peaks)
    }
}

impl From<PartialMmr> for MmrPeaks {
    fn from(partial_mmr: PartialMmr) -> Self {
        // Safety: the [PartialMmr] maintains the invariant that the number of true bits in the
        // forest matches the number of peaks, as required by [MmrPeaks]
        MmrPeaks::new(partial_mmr.forest, partial_mmr.peaks).unwrap()
    }
}

impl From<&MmrPeaks> for PartialMmr {
    fn from(peaks: &MmrPeaks) -> Self {
        Self::from_peaks(peaks.clone())
    }
}

impl From<&PartialMmr> for MmrPeaks {
    fn from(partial_mmr: &PartialMmr) -> Self {
        // Safety: the [PartialMmr] maintains the invariant that the number of true bits in the
        // forest matches the number of peaks, as required by [MmrPeaks]
        MmrPeaks::new(partial_mmr.forest, partial_mmr.peaks.clone()).unwrap()
    }
}

// ITERATORS
// ================================================================================================

/// An iterator over every inner node of the [PartialMmr].
pub struct InnerNodeIterator<'a, I: Iterator<Item = &'a (usize, RpoDigest)>> {
    nodes: &'a BTreeMap<InOrderIndex, RpoDigest>,
    leaves: I,
    stack: Vec<(InOrderIndex, RpoDigest)>,
    seen_nodes: BTreeSet<InOrderIndex>,
}

impl<'a, I: Iterator<Item = &'a (usize, RpoDigest)>> Iterator for InnerNodeIterator<'a, I> {
    type Item = InnerNodeInfo;

    fn next(&mut self) -> Option<Self::Item> {
        while let Some((idx, node)) = self.stack.pop() {
            let parent_idx = idx.parent();
            let new_node = self.seen_nodes.insert(parent_idx);

            // if we haven't seen this node's parent before, and the node has a sibling, return
            // the inner node defined by the parent of this node, and move up the branch
            if new_node {
                if let Some(sibling) = self.nodes.get(&idx.sibling()) {
                    let (left, right) = if parent_idx.left_child() == idx {
                        (node, *sibling)
                    } else {
                        (*sibling, node)
                    };
                    let parent = Rpo256::merge(&[left, right]);
                    let inner_node = InnerNodeInfo { value: parent, left, right };

                    self.stack.push((parent_idx, parent));
                    return Some(inner_node);
                }
            }

            // the previous leaf has been processed, try to process the next leaf
            if let Some((pos, leaf)) = self.leaves.next() {
                let idx = InOrderIndex::from_leaf_pos(*pos);
                self.stack.push((idx, *leaf));
            }
        }

        None
    }
}

// UTILS
// ================================================================================================

/// Given the description of a `forest`, returns the index of the root element of the smallest
/// tree in it.
fn forest_to_root_index(forest: usize) -> InOrderIndex {
    // Count the total size of all trees in the forest.
    let nodes = nodes_in_forest(forest);

    // Add the count for the parent nodes that separate each tree. These are allocated but
    // currently empty, and correspond to the nodes that will be used once the trees are merged.
    let open_trees = (forest.count_ones() - 1) as usize;

    // Remove the count of the right subtree of the target tree; the target tree's root index
    // comes before the subtree in the in-order tree walk.
    let right_subtree_count = ((1u32 << forest.trailing_zeros()) - 1) as usize;

    let idx = nodes + open_trees - right_subtree_count;

    InOrderIndex::new(idx.try_into().unwrap())
}
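
// Worked example for the arithmetic above, using `forest = 0b0110` (trees of 4 and 2 leaves):
// `nodes_in_forest` counts 7 + 3 = 10 nodes, there is `2 - 1 = 1` open tree separator, and the
// smallest tree's right subtree holds `(1 << 1) - 1 = 1` node, so the returned index is
// `10 + 1 - 1 = 10`, matching `test_forest_to_root_index` below.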

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::{forest_to_root_index, BTreeSet, InOrderIndex, PartialMmr, RpoDigest, Vec};
    use crate::merkle::{int_to_node, MerkleStore, Mmr, NodeIndex};

    const LEAVES: [RpoDigest; 7] = [
        int_to_node(0),
        int_to_node(1),
        int_to_node(2),
        int_to_node(3),
        int_to_node(4),
        int_to_node(5),
        int_to_node(6),
    ];

    #[test]
    fn test_forest_to_root_index() {
        fn idx(pos: usize) -> InOrderIndex {
            InOrderIndex::new(pos.try_into().unwrap())
        }

        // When there is a single tree in the forest, the index is equivalent to the number of
        // leaves in that tree, which is `2^n`.
        assert_eq!(forest_to_root_index(0b0001), idx(1));
        assert_eq!(forest_to_root_index(0b0010), idx(2));
        assert_eq!(forest_to_root_index(0b0100), idx(4));
        assert_eq!(forest_to_root_index(0b1000), idx(8));

        assert_eq!(forest_to_root_index(0b0011), idx(5));
        assert_eq!(forest_to_root_index(0b0101), idx(9));
        assert_eq!(forest_to_root_index(0b1001), idx(17));
        assert_eq!(forest_to_root_index(0b0111), idx(13));
        assert_eq!(forest_to_root_index(0b1011), idx(21));
        assert_eq!(forest_to_root_index(0b1111), idx(29));

        assert_eq!(forest_to_root_index(0b0110), idx(10));
        assert_eq!(forest_to_root_index(0b1010), idx(18));
        assert_eq!(forest_to_root_index(0b1100), idx(20));
        assert_eq!(forest_to_root_index(0b1110), idx(26));
    }

    #[test]
    fn test_partial_mmr_apply_delta() {
        // build an MMR with 10 nodes (2 peaks) and a partial MMR based on it
        let mut mmr = Mmr::default();
        (0..10).for_each(|i| mmr.add(int_to_node(i)));
        let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();

        // add authentication paths for positions 1 and 8
        {
            let node = mmr.get(1).unwrap();
            let proof = mmr.open(1, mmr.forest()).unwrap();
            partial_mmr.add(1, node, &proof.merkle_path).unwrap();
        }

        {
            let node = mmr.get(8).unwrap();
            let proof = mmr.open(8, mmr.forest()).unwrap();
            partial_mmr.add(8, node, &proof.merkle_path).unwrap();
        }

        // add 2 more nodes into the MMR and validate apply_delta()
        (10..12).for_each(|i| mmr.add(int_to_node(i)));
        validate_apply_delta(&mmr, &mut partial_mmr);

        // add 1 more node to the MMR, validate apply_delta() and start tracking the node
        mmr.add(int_to_node(12));
        validate_apply_delta(&mmr, &mut partial_mmr);
        {
            let node = mmr.get(12).unwrap();
            let proof = mmr.open(12, mmr.forest()).unwrap();
            partial_mmr.add(12, node, &proof.merkle_path).unwrap();
            assert!(partial_mmr.track_latest);
        }

        // by this point we are tracking authentication paths for positions: 1, 8, and 12

        // add 3 more nodes to the MMR (collapses to 1 peak) and validate apply_delta()
        (13..16).for_each(|i| mmr.add(int_to_node(i)));
        validate_apply_delta(&mmr, &mut partial_mmr);
    }

    fn validate_apply_delta(mmr: &Mmr, partial: &mut PartialMmr) {
        let tracked_leaves = partial
            .nodes
            .iter()
            .filter_map(|(index, _)| if index.is_leaf() { Some(index.sibling()) } else { None })
            .collect::<Vec<_>>();
        let nodes_before = partial.nodes.clone();

        // compute and apply delta
        let delta = mmr.get_delta(partial.forest(), mmr.forest()).unwrap();
        let nodes_delta = partial.apply(delta).unwrap();

        // new peaks were computed correctly
        assert_eq!(mmr.peaks(mmr.forest()).unwrap(), partial.peaks());

        let mut expected_nodes = nodes_before;
        for (key, value) in nodes_delta {
            // nodes should not be duplicated
            assert!(expected_nodes.insert(key, value).is_none());
        }

        // new nodes should be a combination of original nodes and delta
        assert_eq!(expected_nodes, partial.nodes);

        // make sure tracked leaves open to the same proofs as in the underlying MMR
        for index in tracked_leaves {
            let index_value: u64 = index.into();
            let pos = index_value / 2;
            let proof1 = partial.open(pos as usize).unwrap().unwrap();
            let proof2 = mmr.open(pos as usize, mmr.forest()).unwrap();
            assert_eq!(proof1, proof2);
        }
    }

    #[test]
    fn test_partial_mmr_inner_nodes_iterator() {
        // build the MMR
        let mmr: Mmr = LEAVES.into();
        let first_peak = mmr.peaks(mmr.forest).unwrap().peaks()[0];

        // -- test single tree ----------------------------

        // get the path and node for position 1
        let node1 = mmr.get(1).unwrap();
        let proof1 = mmr.open(1, mmr.forest()).unwrap();

        // create a partial MMR and add the authentication path of the node at position 1
        let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();
        partial_mmr.add(1, node1, &proof1.merkle_path).unwrap();

        // an empty iterator should have no nodes
        assert_eq!(partial_mmr.inner_nodes([].iter()).next(), None);

        // build a Merkle store from the authentication paths in the partial MMR
        let mut store: MerkleStore = MerkleStore::new();
        store.extend(partial_mmr.inner_nodes([(1, node1)].iter()));

        let index1 = NodeIndex::new(2, 1).unwrap();
        let path1 = store.get_path(first_peak, index1).unwrap().path;

        assert_eq!(path1, proof1.merkle_path);

        // -- test no duplicates --------------------------

        // build the partial MMR
        let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();

        let node0 = mmr.get(0).unwrap();
        let proof0 = mmr.open(0, mmr.forest()).unwrap();

        let node2 = mmr.get(2).unwrap();
        let proof2 = mmr.open(2, mmr.forest()).unwrap();

        partial_mmr.add(0, node0, &proof0.merkle_path).unwrap();
        partial_mmr.add(1, node1, &proof1.merkle_path).unwrap();
        partial_mmr.add(2, node2, &proof2.merkle_path).unwrap();

        // make sure there are no duplicates
        let leaves = [(0, node0), (1, node1), (2, node2)];
        let mut nodes = BTreeSet::new();
        for node in partial_mmr.inner_nodes(leaves.iter()) {
            assert!(nodes.insert(node.value));
        }

        // and also that the store is still built correctly
        store.extend(partial_mmr.inner_nodes(leaves.iter()));

        let index0 = NodeIndex::new(2, 0).unwrap();
        let index1 = NodeIndex::new(2, 1).unwrap();
        let index2 = NodeIndex::new(2, 2).unwrap();

        let path0 = store.get_path(first_peak, index0).unwrap().path;
        let path1 = store.get_path(first_peak, index1).unwrap().path;
        let path2 = store.get_path(first_peak, index2).unwrap().path;

        assert_eq!(path0, proof0.merkle_path);
        assert_eq!(path1, proof1.merkle_path);
        assert_eq!(path2, proof2.merkle_path);

        // -- test multiple trees -------------------------

        // build the partial MMR
        let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();

        let node5 = mmr.get(5).unwrap();
        let proof5 = mmr.open(5, mmr.forest()).unwrap();

        partial_mmr.add(1, node1, &proof1.merkle_path).unwrap();
        partial_mmr.add(5, node5, &proof5.merkle_path).unwrap();

        // build a Merkle store from the authentication paths in the partial MMR
        let mut store: MerkleStore = MerkleStore::new();
        store.extend(partial_mmr.inner_nodes([(1, node1), (5, node5)].iter()));

        let index1 = NodeIndex::new(2, 1).unwrap();
        let index5 = NodeIndex::new(1, 1).unwrap();

        let second_peak = mmr.peaks(mmr.forest).unwrap().peaks()[1];

        let path1 = store.get_path(first_peak, index1).unwrap().path;
        let path5 = store.get_path(second_peak, index5).unwrap().path;

        assert_eq!(path1, proof1.merkle_path);
        assert_eq!(path5, proof5.merkle_path);
    }
}
src/merkle/mmr/peaks.rs (new file, 134 lines)
@@ -0,0 +1,134 @@
use super::{
    super::{RpoDigest, Vec, ZERO},
    Felt, MmrError, MmrProof, Rpo256, Word,
};

// MMR PEAKS
// ================================================================================================

#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MmrPeaks {
    /// The number of leaves is used to differentiate MMRs that have the same number of peaks.
    /// This happens because the number of peaks goes up and down as the structure is used,
    /// causing existing trees to be merged and new ones to be created. As an example, every
    /// time the MMR has a power-of-two number of leaves there is a single peak.
    ///
    /// Every tree in the MMR forest has a distinct power-of-two size, which means only the
    /// rightmost tree can have an odd number of elements (i.e. `1`). Additionally, this means
    /// that the bits in `num_leaves` conveniently encode the size of each individual tree.
    ///
    /// Examples:
    ///
    /// - With 5 leaves, the binary representation is `0b101`. The number of set bits is equal
    ///   to the number of peaks, in this case 2. The 0-indexed position of each set bit,
    ///   counted from the least-significant end, determines the number of elements of the
    ///   corresponding tree, so the rightmost tree has `2**0` elements and the leftmost has
    ///   `2**2`.
    /// - With 12 leaves, the binary representation is `0b1100`. This case also has 2 peaks;
    ///   the leftmost tree has `2**3 = 8` elements, and the rightmost has `2**2 = 4` elements.
    num_leaves: usize,

    /// All the peaks of every tree in the MMR forest. The peaks are always ordered by number
    /// of leaves, starting from the peak with the most children to the one with the least.
    ///
    /// Invariant: the length of `peaks` must be equal to the number of true bits in
    /// `num_leaves`.
    peaks: Vec<RpoDigest>,
}

impl MmrPeaks {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------

    /// Returns a new [MmrPeaks] instantiated from the provided vector of peaks and the number
    /// of leaves in the underlying MMR.
    ///
    /// # Errors
    /// Returns an error if the number of leaves and the number of peaks are inconsistent.
    pub fn new(num_leaves: usize, peaks: Vec<RpoDigest>) -> Result<Self, MmrError> {
        if num_leaves.count_ones() as usize != peaks.len() {
            return Err(MmrError::InvalidPeaks);
        }

        Ok(Self { num_leaves, peaks })
    }

    // ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns a count of leaves in the underlying MMR.
    pub fn num_leaves(&self) -> usize {
        self.num_leaves
    }

    /// Returns the number of peaks of the underlying MMR.
    pub fn num_peaks(&self) -> usize {
        self.peaks.len()
    }

    /// Returns the list of peaks of the underlying MMR.
    pub fn peaks(&self) -> &[RpoDigest] {
        &self.peaks
    }

    /// Converts this [MmrPeaks] into its components: number of leaves and a vector of peaks of
    /// the underlying MMR.
    pub fn into_parts(self) -> (usize, Vec<RpoDigest>) {
        (self.num_leaves, self.peaks)
    }

    /// Hashes the peaks.
    ///
    /// The procedure will:
    /// - Flatten and pad the peaks to a vector of Felts.
    /// - Hash the vector of Felts.
    pub fn hash_peaks(&self) -> RpoDigest {
        Rpo256::hash_elements(&self.flatten_and_pad_peaks())
    }

    /// Verifies the Merkle opening `opening` of `value` against the corresponding peak.
    pub fn verify(&self, value: RpoDigest, opening: MmrProof) -> bool {
        let root = &self.peaks[opening.peak_index()];
        opening.merkle_path.verify(opening.relative_pos() as u64, value, root)
    }
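
    // A minimal verification sketch (assuming an `Mmr` named `mmr` built from the `LEAVES`
    // used in the tests): open a position against the current forest and check it against
    // the matching peak:
    //
    //     let proof = mmr.open(5, mmr.forest()).unwrap();
    //     assert!(mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[5], proof));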

    /// Flattens and pads the peaks to make hashing inside of the Miden VM easier.
    ///
    /// The procedure will:
    /// - Flatten the vector of Words into a vector of Felts.
    /// - Pad the peaks with ZERO to an even number of words; this removes the need to handle
    ///   RPO padding.
    /// - Pad the peaks to a minimum length of 16 words, which reduces the constant cost of
    ///   hashing.
    pub fn flatten_and_pad_peaks(&self) -> Vec<Felt> {
        let num_peaks = self.peaks.len();

        // To achieve the padding rules above we calculate the length of the final vector.
        // This is calculated as the number of field elements. Each peak is 4 field elements.
        // The length is calculated as follows:
        // - If there are fewer than 16 peaks, the data is padded to 16 peaks and as such
        //   requires 64 field elements.
        // - If there are more than 16 peaks and the number of peaks is odd, the data is padded
        //   to an even number of peaks and as such requires `(num_peaks + 1) * 4` field
        //   elements.
        // - If there are more than 16 peaks and the number of peaks is even, the data is not
        //   padded and as such requires `num_peaks * 4` field elements.
        let len = if num_peaks < 16 {
            64
        } else if num_peaks % 2 == 1 {
            (num_peaks + 1) * 4
        } else {
            num_peaks * 4
        };

        let mut elements = Vec::with_capacity(len);
        elements.extend_from_slice(
            &self
                .peaks
                .as_slice()
                .iter()
                .map(|digest| digest.into())
                .collect::<Vec<Word>>()
                .concat(),
        );
        elements.resize(len, ZERO);
        elements
    }
}
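
// Sanity check of the padding rule above (values in field elements, 4 per peak):
// 3 peaks -> 64 (padded up to the 16-peak minimum); 17 peaks -> (17 + 1) * 4 = 72 (padded to
// an even number of peaks); 18 peaks -> 18 * 4 = 72 (no padding needed).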

@@ -1,8 +1,9 @@
/// The representation of a single Merkle path.
use super::super::MerklePath;
use super::full::{high_bitmask, leaf_to_corresponding_tree};
use super::{full::high_bitmask, leaf_to_corresponding_tree};

#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MmrProof {
    /// The state of the MMR when the MmrProof was created.
    pub forest: usize,

@@ -1,10 +1,13 @@
use super::bit::TrueBitPositionIterator;
use super::full::{high_bitmask, leaf_to_corresponding_tree, nodes_in_forest};
use super::{
    super::{InnerNodeInfo, Vec, WORD_SIZE, ZERO},
    Mmr, MmrPeaks, Rpo256, Word,
    super::{InnerNodeInfo, Rpo256, RpoDigest, Vec},
    bit::TrueBitPositionIterator,
    full::high_bitmask,
    leaf_to_corresponding_tree, nodes_in_forest, Mmr, MmrPeaks, PartialMmr,
};
use crate::{
    merkle::{int_to_node, InOrderIndex, MerklePath, MerkleTree, MmrProof, NodeIndex},
    Felt, Word,
};
use crate::merkle::{int_to_node, MerklePath};

#[test]
fn test_position_equal_or_higher_than_leafs_is_never_contained() {
@@ -99,7 +102,7 @@ fn test_nodes_in_forest_single_bit() {
    }
}

const LEAVES: [Word; 7] = [
const LEAVES: [RpoDigest; 7] = [
    int_to_node(0),
    int_to_node(1),
    int_to_node(2),
@@ -114,14 +117,14 @@ fn test_mmr_simple() {
    let mut postorder = Vec::new();
    postorder.push(LEAVES[0]);
    postorder.push(LEAVES[1]);
    postorder.push(*Rpo256::hash_elements(&[LEAVES[0], LEAVES[1]].concat()));
    postorder.push(merge(LEAVES[0], LEAVES[1]));
    postorder.push(LEAVES[2]);
    postorder.push(LEAVES[3]);
    postorder.push(*Rpo256::hash_elements(&[LEAVES[2], LEAVES[3]].concat()));
    postorder.push(*Rpo256::hash_elements(&[postorder[2], postorder[5]].concat()));
    postorder.push(merge(LEAVES[2], LEAVES[3]));
    postorder.push(merge(postorder[2], postorder[5]));
    postorder.push(LEAVES[4]);
    postorder.push(LEAVES[5]);
    postorder.push(*Rpo256::hash_elements(&[LEAVES[4], LEAVES[5]].concat()));
    postorder.push(merge(LEAVES[4], LEAVES[5]));
    postorder.push(LEAVES[6]);

    let mut mmr = Mmr::new();
@@ -133,162 +136,329 @@ fn test_mmr_simple() {
    assert_eq!(mmr.nodes.len(), 1);
    assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);

    let acc = mmr.accumulator();
    assert_eq!(acc.num_leaves, 1);
    assert_eq!(acc.peaks, &[postorder[0]]);
    let acc = mmr.peaks(mmr.forest()).unwrap();
    assert_eq!(acc.num_leaves(), 1);
    assert_eq!(acc.peaks(), &[postorder[0]]);

    mmr.add(LEAVES[1]);
    assert_eq!(mmr.forest(), 2);
    assert_eq!(mmr.nodes.len(), 3);
    assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);

    let acc = mmr.accumulator();
    assert_eq!(acc.num_leaves, 2);
    assert_eq!(acc.peaks, &[postorder[2]]);
    let acc = mmr.peaks(mmr.forest()).unwrap();
    assert_eq!(acc.num_leaves(), 2);
    assert_eq!(acc.peaks(), &[postorder[2]]);

    mmr.add(LEAVES[2]);
    assert_eq!(mmr.forest(), 3);
    assert_eq!(mmr.nodes.len(), 4);
    assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);

    let acc = mmr.accumulator();
    assert_eq!(acc.num_leaves, 3);
    assert_eq!(acc.peaks, &[postorder[2], postorder[3]]);
    let acc = mmr.peaks(mmr.forest()).unwrap();
    assert_eq!(acc.num_leaves(), 3);
    assert_eq!(acc.peaks(), &[postorder[2], postorder[3]]);

    mmr.add(LEAVES[3]);
    assert_eq!(mmr.forest(), 4);
    assert_eq!(mmr.nodes.len(), 7);
    assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);

    let acc = mmr.accumulator();
    assert_eq!(acc.num_leaves, 4);
    assert_eq!(acc.peaks, &[postorder[6]]);
    let acc = mmr.peaks(mmr.forest()).unwrap();
    assert_eq!(acc.num_leaves(), 4);
    assert_eq!(acc.peaks(), &[postorder[6]]);

    mmr.add(LEAVES[4]);
    assert_eq!(mmr.forest(), 5);
    assert_eq!(mmr.nodes.len(), 8);
    assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);

    let acc = mmr.accumulator();
    assert_eq!(acc.num_leaves, 5);
    assert_eq!(acc.peaks, &[postorder[6], postorder[7]]);
    let acc = mmr.peaks(mmr.forest()).unwrap();
    assert_eq!(acc.num_leaves(), 5);
    assert_eq!(acc.peaks(), &[postorder[6], postorder[7]]);

    mmr.add(LEAVES[5]);
    assert_eq!(mmr.forest(), 6);
    assert_eq!(mmr.nodes.len(), 10);
    assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);

    let acc = mmr.accumulator();
    assert_eq!(acc.num_leaves, 6);
    assert_eq!(acc.peaks, &[postorder[6], postorder[9]]);
    let acc = mmr.peaks(mmr.forest()).unwrap();
    assert_eq!(acc.num_leaves(), 6);
    assert_eq!(acc.peaks(), &[postorder[6], postorder[9]]);

    mmr.add(LEAVES[6]);
    assert_eq!(mmr.forest(), 7);
    assert_eq!(mmr.nodes.len(), 11);
    assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);

    let acc = mmr.accumulator();
    assert_eq!(acc.num_leaves, 7);
    assert_eq!(acc.peaks, &[postorder[6], postorder[9], postorder[10]]);
    let acc = mmr.peaks(mmr.forest()).unwrap();
    assert_eq!(acc.num_leaves(), 7);
    assert_eq!(acc.peaks(), &[postorder[6], postorder[9], postorder[10]]);
}

#[test]
fn test_mmr_open() {
    let mmr: Mmr = LEAVES.into();
    let h01: Word = Rpo256::hash_elements(&LEAVES[0..2].concat()).into();
    let h23: Word = Rpo256::hash_elements(&LEAVES[2..4].concat()).into();
    let h01 = merge(LEAVES[0], LEAVES[1]);
    let h23 = merge(LEAVES[2], LEAVES[3]);

    // node at pos 7 is the root
    assert!(mmr.open(7).is_err(), "Element 7 is not in the tree, result should be None");
    assert!(
        mmr.open(7, mmr.forest()).is_err(),
        "Element 7 is not in the tree, result should be None"
    );

    // node at pos 6 is the root
    let empty: MerklePath = MerklePath::new(vec![]);
    let opening = mmr
        .open(6)
        .open(6, mmr.forest())
        .expect("Element 6 is contained in the tree, expected an opening result.");
    assert_eq!(opening.merkle_path, empty);
    assert_eq!(opening.forest, mmr.forest);
    assert_eq!(opening.position, 6);
    assert!(
        mmr.accumulator().verify(LEAVES[6], opening),
        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[6], opening),
        "MmrProof should be valid for the current accumulator."
    );

    // nodes 4,5 are detph 1
    // nodes 4,5 are depth 1
    let root_to_path = MerklePath::new(vec![LEAVES[4]]);
    let opening = mmr
        .open(5)
        .open(5, mmr.forest())
        .expect("Element 5 is contained in the tree, expected an opening result.");
    assert_eq!(opening.merkle_path, root_to_path);
    assert_eq!(opening.forest, mmr.forest);
    assert_eq!(opening.position, 5);
    assert!(
        mmr.accumulator().verify(LEAVES[5], opening),
        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[5], opening),
        "MmrProof should be valid for the current accumulator."
    );

    let root_to_path = MerklePath::new(vec![LEAVES[5]]);
    let opening = mmr
        .open(4)
        .open(4, mmr.forest())
        .expect("Element 4 is contained in the tree, expected an opening result.");
    assert_eq!(opening.merkle_path, root_to_path);
    assert_eq!(opening.forest, mmr.forest);
    assert_eq!(opening.position, 4);
    assert!(
        mmr.accumulator().verify(LEAVES[4], opening),
        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[4], opening),
        "MmrProof should be valid for the current accumulator."
    );

    // nodes 0,1,2,3 are depth 2
    let root_to_path = MerklePath::new(vec![LEAVES[2], h01]);
    let opening = mmr
        .open(3)
        .open(3, mmr.forest())
        .expect("Element 3 is contained in the tree, expected an opening result.");
    assert_eq!(opening.merkle_path, root_to_path);
    assert_eq!(opening.forest, mmr.forest);
    assert_eq!(opening.position, 3);
    assert!(
        mmr.accumulator().verify(LEAVES[3], opening),
        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[3], opening),
        "MmrProof should be valid for the current accumulator."
    );

    let root_to_path = MerklePath::new(vec![LEAVES[3], h01]);
    let opening = mmr
        .open(2)
        .open(2, mmr.forest())
        .expect("Element 2 is contained in the tree, expected an opening result.");
    assert_eq!(opening.merkle_path, root_to_path);
    assert_eq!(opening.forest, mmr.forest);
    assert_eq!(opening.position, 2);
    assert!(
        mmr.accumulator().verify(LEAVES[2], opening),
        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[2], opening),
        "MmrProof should be valid for the current accumulator."
    );

    let root_to_path = MerklePath::new(vec![LEAVES[0], h23]);
    let opening = mmr
        .open(1)
        .open(1, mmr.forest())
        .expect("Element 1 is contained in the tree, expected an opening result.");
    assert_eq!(opening.merkle_path, root_to_path);
    assert_eq!(opening.forest, mmr.forest);
    assert_eq!(opening.position, 1);
    assert!(
        mmr.accumulator().verify(LEAVES[1], opening),
        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[1], opening),
        "MmrProof should be valid for the current accumulator."
    );

    let root_to_path = MerklePath::new(vec![LEAVES[1], h23]);
    let opening = mmr
        .open(0)
        .open(0, mmr.forest())
        .expect("Element 0 is contained in the tree, expected an opening result.");
    assert_eq!(opening.merkle_path, root_to_path);
    assert_eq!(opening.forest, mmr.forest);
    assert_eq!(opening.position, 0);
    assert!(
        mmr.accumulator().verify(LEAVES[0], opening),
        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[0], opening),
        "MmrProof should be valid for the current accumulator."
    );
}

#[test]
fn test_mmr_open_older_version() {
    let mmr: Mmr = LEAVES.into();

    fn is_even(v: &usize) -> bool {
        v & 1 == 0
    }

    // merkle path of a node is empty if there are no elements to pair with it
    for pos in (0..mmr.forest()).filter(is_even) {
        let forest = pos + 1;
        let proof = mmr.open(pos, forest).unwrap();
        assert_eq!(proof.forest, forest);
        assert_eq!(proof.merkle_path.nodes(), []);
        assert_eq!(proof.position, pos);
    }

    // openings match that of a merkle tree
    let mtree: MerkleTree = LEAVES[..4].try_into().unwrap();
    for forest in 4..=LEAVES.len() {
        for pos in 0..4 {
            let idx = NodeIndex::new(2, pos).unwrap();
            let path = mtree.get_path(idx).unwrap();
            let proof = mmr.open(pos as usize, forest).unwrap();
            assert_eq!(path, proof.merkle_path);
        }
    }
    let mtree: MerkleTree = LEAVES[4..6].try_into().unwrap();
    for forest in 6..=LEAVES.len() {
        for pos in 0..2 {
            let idx = NodeIndex::new(1, pos).unwrap();
            let path = mtree.get_path(idx).unwrap();
            // account for the bigger tree with 4 elements
            let mmr_pos = (pos + 4) as usize;
            let proof = mmr.open(mmr_pos, forest).unwrap();
            assert_eq!(path, proof.merkle_path);
        }
    }
}

/// Tests the openings of a simple Mmr with a single tree of 8 leaves.
#[test]
fn test_mmr_open_eight() {
    let leaves = [
        int_to_node(0),
        int_to_node(1),
        int_to_node(2),
        int_to_node(3),
        int_to_node(4),
        int_to_node(5),
        int_to_node(6),
        int_to_node(7),
    ];

    let mtree: MerkleTree = leaves.as_slice().try_into().unwrap();
    let forest = leaves.len();
    let mmr: Mmr = leaves.into();
    let root = mtree.root();

    let position = 0;
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

    let position = 1;
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

    let position = 2;
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

    let position = 3;
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

    let position = 4;
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

    let position = 5;
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

    let position = 6;
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

    let position = 7;
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
}

/// Tests the openings of an Mmr with 3 trees of 4, 2, and 1 leaves.
#[test]
fn test_mmr_open_seven() {
    let mtree1: MerkleTree = LEAVES[..4].try_into().unwrap();
    let mtree2: MerkleTree = LEAVES[4..6].try_into().unwrap();

    let forest = LEAVES.len();
    let mmr: Mmr = LEAVES.into();

    let position = 0;
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path: MerklePath =
        mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(0, LEAVES[0]).unwrap(), mtree1.root());

    let position = 1;
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path: MerklePath =
        mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(1, LEAVES[1]).unwrap(), mtree1.root());

    let position = 2;
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path: MerklePath =
        mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(2, LEAVES[2]).unwrap(), mtree1.root());

    let position = 3;
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path: MerklePath =
        mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(3, LEAVES[3]).unwrap(), mtree1.root());

    let position = 4;
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 0u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(0, LEAVES[4]).unwrap(), mtree2.root());

    let position = 5;
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 1u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(1, LEAVES[5]).unwrap(), mtree2.root());

    let position = 6;
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path: MerklePath = [].as_ref().into();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(0, LEAVES[6]).unwrap(), LEAVES[6]);
}

#[test]
fn test_mmr_get() {
    let mmr: Mmr = LEAVES.into();
@@ -307,15 +477,16 @@ fn test_mmr_invariants() {
    let mut mmr = Mmr::new();
    for v in 1..=1028 {
        mmr.add(int_to_node(v));
        let accumulator = mmr.accumulator();
        let accumulator = mmr.peaks(mmr.forest()).unwrap();
        assert_eq!(v as usize, mmr.forest(), "MMR leaf count must increase by one on every add");
        assert_eq!(
            v as usize, accumulator.num_leaves,
            v as usize,
            accumulator.num_leaves(),
            "MMR and its accumulator must match leaves count"
        );
        assert_eq!(
            accumulator.num_leaves.count_ones() as usize,
            accumulator.peaks.len(),
            accumulator.num_leaves().count_ones() as usize,
            accumulator.peaks().len(),
            "bits on leaves must match the number of peaks"
        );

@@ -361,10 +532,10 @@ fn test_mmr_inner_nodes() {
    let mmr: Mmr = LEAVES.into();
    let nodes: Vec<InnerNodeInfo> = mmr.inner_nodes().collect();

    let h01 = *Rpo256::hash_elements(&[LEAVES[0], LEAVES[1]].concat());
    let h23 = *Rpo256::hash_elements(&[LEAVES[2], LEAVES[3]].concat());
    let h0123 = *Rpo256::hash_elements(&[h01, h23].concat());
    let h45 = *Rpo256::hash_elements(&[LEAVES[4], LEAVES[5]].concat());
    let h01 = Rpo256::merge(&[LEAVES[0], LEAVES[1]]);
    let h23 = Rpo256::merge(&[LEAVES[2], LEAVES[3]]);
    let h0123 = Rpo256::merge(&[h01, h23]);
    let h45 = Rpo256::merge(&[LEAVES[4], LEAVES[5]]);
    let postorder = vec![
        InnerNodeInfo {
            value: h01,
@@ -376,11 +547,7 @@ fn test_mmr_inner_nodes() {
            left: LEAVES[2],
            right: LEAVES[3],
        },
        InnerNodeInfo {
            value: h0123,
            left: h01,
            right: h23,
        },
        InnerNodeInfo { value: h0123, left: h01, right: h23 },
        InnerNodeInfo {
            value: h45,
            left: LEAVES[4],
@@ -391,22 +558,62 @@
    assert_eq!(postorder, nodes);
}

#[test]
fn test_mmr_peaks() {
    let mmr: Mmr = LEAVES.into();

    let forest = 0b0001;
    let acc = mmr.peaks(forest).unwrap();
    assert_eq!(acc.num_leaves(), forest);
    assert_eq!(acc.peaks(), &[mmr.nodes[0]]);

    let forest = 0b0010;
    let acc = mmr.peaks(forest).unwrap();
    assert_eq!(acc.num_leaves(), forest);
    assert_eq!(acc.peaks(), &[mmr.nodes[2]]);

    let forest = 0b0011;
    let acc = mmr.peaks(forest).unwrap();
    assert_eq!(acc.num_leaves(), forest);
    assert_eq!(acc.peaks(), &[mmr.nodes[2], mmr.nodes[3]]);

    let forest = 0b0100;
    let acc = mmr.peaks(forest).unwrap();
    assert_eq!(acc.num_leaves(), forest);
    assert_eq!(acc.peaks(), &[mmr.nodes[6]]);

    let forest = 0b0101;
    let acc = mmr.peaks(forest).unwrap();
    assert_eq!(acc.num_leaves(), forest);
    assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[7]]);

    let forest = 0b0110;
    let acc = mmr.peaks(forest).unwrap();
    assert_eq!(acc.num_leaves(), forest);
    assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9]]);

    let forest = 0b0111;
    let acc = mmr.peaks(forest).unwrap();
    assert_eq!(acc.num_leaves(), forest);
    assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9], mmr.nodes[10]]);
}

#[test]
fn test_mmr_hash_peaks() {
    let mmr: Mmr = LEAVES.into();
    let peaks = mmr.accumulator();
    let peaks = mmr.peaks(mmr.forest()).unwrap();

    let first_peak = *Rpo256::merge(&[
        Rpo256::hash_elements(&[LEAVES[0], LEAVES[1]].concat()),
        Rpo256::hash_elements(&[LEAVES[2], LEAVES[3]].concat()),
    let first_peak = Rpo256::merge(&[
        Rpo256::merge(&[LEAVES[0], LEAVES[1]]),
        Rpo256::merge(&[LEAVES[2], LEAVES[3]]),
    ]);
    let second_peak = *Rpo256::hash_elements(&[LEAVES[4], LEAVES[5]].concat());
    let second_peak = Rpo256::merge(&[LEAVES[4], LEAVES[5]]);
    let third_peak = LEAVES[6];

    // minimum length is 16
    let mut expected_peaks = [first_peak, second_peak, third_peak].to_vec();
    expected_peaks.resize(16, [ZERO; WORD_SIZE]);
    assert_eq!(peaks.hash_peaks(), *Rpo256::hash_elements(&expected_peaks.as_slice().concat()));
    expected_peaks.resize(16, RpoDigest::default());
    assert_eq!(peaks.hash_peaks(), Rpo256::hash_elements(&digests_to_elements(&expected_peaks)));
}

#[test]
@@ -415,39 +622,219 @@ fn test_mmr_peaks_hash_less_than_16() {

    for i in 0..16 {
        peaks.push(int_to_node(i));
        let accumulator = MmrPeaks {
            num_leaves: (1 << peaks.len()) - 1,
            peaks: peaks.clone(),
        };

        let num_leaves = (1 << peaks.len()) - 1;
        let accumulator = MmrPeaks::new(num_leaves, peaks.clone()).unwrap();

        // minimum length is 16
        let mut expected_peaks = peaks.clone();
        expected_peaks.resize(16, [ZERO; WORD_SIZE]);
        expected_peaks.resize(16, RpoDigest::default());
        assert_eq!(
            accumulator.hash_peaks(),
            *Rpo256::hash_elements(&expected_peaks.as_slice().concat())
            Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
        );
    }
}

#[test]
fn test_mmr_peaks_hash_odd() {
    let peaks: Vec<_> = (0..=17).map(|i| int_to_node(i)).collect();
    let peaks: Vec<_> = (0..=17).map(int_to_node).collect();

    let accumulator = MmrPeaks {
        num_leaves: (1 << peaks.len()) - 1,
        peaks: peaks.clone(),
    };
    let num_leaves = (1 << peaks.len()) - 1;
    let accumulator = MmrPeaks::new(num_leaves, peaks.clone()).unwrap();

    // odd length bigger than 16 is padded to the next even nubmer
    let mut expected_peaks = peaks.clone();
    expected_peaks.resize(18, [ZERO; WORD_SIZE]);
    // odd length bigger than 16 is padded to the next even number
    let mut expected_peaks = peaks;
    expected_peaks.resize(18, RpoDigest::default());
    assert_eq!(
        accumulator.hash_peaks(),
        *Rpo256::hash_elements(&expected_peaks.as_slice().concat())
        Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
    );
}

#[test]
fn test_mmr_delta() {
    let mmr: Mmr = LEAVES.into();
    let acc = mmr.peaks(mmr.forest()).unwrap();

    // original_forest can't have more elements
    assert!(
        mmr.get_delta(LEAVES.len() + 1, mmr.forest()).is_err(),
        "Can not provide updates for a newer Mmr"
    );

    // if the number of elements is the same there is no change
    assert!(
        mmr.get_delta(LEAVES.len(), mmr.forest()).unwrap().data.is_empty(),
        "There are no updates for the same Mmr version"
    );

    // missing the last element added, which is itself a tree peak
    assert_eq!(mmr.get_delta(6, mmr.forest()).unwrap().data, vec![acc.peaks()[2]], "one peak");

    // missing the sibling to complete the tree of depth 2, and the last element
    assert_eq!(
        mmr.get_delta(5, mmr.forest()).unwrap().data,
        vec![LEAVES[5], acc.peaks()[2]],
        "one sibling, one peak"
    );

    // missing the whole last two trees, only send the peaks
    assert_eq!(
        mmr.get_delta(4, mmr.forest()).unwrap().data,
        vec![acc.peaks()[1], acc.peaks()[2]],
        "two peaks"
    );

    // missing the sibling to complete the first tree, and the two last trees
    assert_eq!(
        mmr.get_delta(3, mmr.forest()).unwrap().data,
        vec![LEAVES[3], acc.peaks()[1], acc.peaks()[2]],
        "one sibling, two peaks"
    );

    // missing half of the first tree, only send the computed element (not the leaves), and the
    // new peaks
    assert_eq!(
        mmr.get_delta(2, mmr.forest()).unwrap().data,
        vec![mmr.nodes[5], acc.peaks()[1], acc.peaks()[2]],
        "one sibling, two peaks"
    );

    assert_eq!(
        mmr.get_delta(1, mmr.forest()).unwrap().data,
        vec![LEAVES[1], mmr.nodes[5], acc.peaks()[1], acc.peaks()[2]],
        "one sibling, two peaks"
    );

    assert_eq!(&mmr.get_delta(0, mmr.forest()).unwrap().data, acc.peaks(), "all peaks");
}

#[test]
fn test_mmr_delta_old_forest() {
    let mmr: Mmr = LEAVES.into();

    // from_forest must be smaller-or-equal to to_forest
    for version in 1..=mmr.forest() {
        assert!(mmr.get_delta(version + 1, version).is_err());
    }

    // when from_forest and to_forest are equal, there are no updates
    for version in 1..=mmr.forest() {
        let delta = mmr.get_delta(version, version).unwrap();
        assert!(delta.data.is_empty());
        assert_eq!(delta.forest, version);
    }

    // test update which merges the odd peak to the right
    for count in 0..(mmr.forest() / 2) {
        // *2 because every iteration tests a pair
        // +1 because the Mmr is 1-indexed
        let from_forest = (count * 2) + 1;
        let to_forest = (count * 2) + 2;
        let delta = mmr.get_delta(from_forest, to_forest).unwrap();

        // *2 because every iteration tests a pair
        // +1 because sibling is the odd element
        let sibling = (count * 2) + 1;
        assert_eq!(delta.data, [LEAVES[sibling]]);
        assert_eq!(delta.forest, to_forest);
    }

    let version = 4;
    let delta = mmr.get_delta(1, version).unwrap();
    assert_eq!(delta.data, [mmr.nodes[1], mmr.nodes[5]]);
    assert_eq!(delta.forest, version);

    let version = 5;
    let delta = mmr.get_delta(1, version).unwrap();
    assert_eq!(delta.data, [mmr.nodes[1], mmr.nodes[5], mmr.nodes[7]]);
    assert_eq!(delta.forest, version);
}

#[test]
fn test_partial_mmr_simple() {
    let mmr: Mmr = LEAVES.into();
    let peaks = mmr.peaks(mmr.forest()).unwrap();
    let mut partial: PartialMmr = peaks.clone().into();

    // check the initial state of the partial mmr
    assert_eq!(partial.peaks(), peaks);
    assert_eq!(partial.forest(), peaks.num_leaves());
    assert_eq!(partial.forest(), LEAVES.len());
    assert_eq!(partial.peaks().num_peaks(), 3);
    assert_eq!(partial.nodes.len(), 0);

    // check state after tracking one element
    let proof1 = mmr.open(0, mmr.forest()).unwrap();
    let el1 = mmr.get(proof1.position).unwrap();
    partial.add(proof1.position, el1, &proof1.merkle_path).unwrap();

    // check the number of nodes increased by the number of nodes in the proof
    assert_eq!(partial.nodes.len(), proof1.merkle_path.len());
    // check the values match
    let idx = InOrderIndex::from_leaf_pos(proof1.position);
    assert_eq!(partial.nodes[&idx.sibling()], proof1.merkle_path[0]);
    let idx = idx.parent();
    assert_eq!(partial.nodes[&idx.sibling()], proof1.merkle_path[1]);

    let proof2 = mmr.open(1, mmr.forest()).unwrap();
    let el2 = mmr.get(proof2.position).unwrap();
    partial.add(proof2.position, el2, &proof2.merkle_path).unwrap();

    // check the number of nodes increased by a single element (the one that is not shared)
    assert_eq!(partial.nodes.len(), 3);
    // check the values match
    let idx = InOrderIndex::from_leaf_pos(proof2.position);
    assert_eq!(partial.nodes[&idx.sibling()], proof2.merkle_path[0]);
    let idx = idx.parent();
    assert_eq!(partial.nodes[&idx.sibling()], proof2.merkle_path[1]);
}

#[test]
fn test_partial_mmr_update_single() {
    let mut full = Mmr::new();
    let zero = int_to_node(0);
    full.add(zero);
    let mut partial: PartialMmr = full.peaks(full.forest()).unwrap().into();

    let proof = full.open(0, full.forest()).unwrap();
    partial.add(proof.position, zero, &proof.merkle_path).unwrap();

    for i in 1..100 {
        let node = int_to_node(i);
        full.add(node);
        let delta = full.get_delta(partial.forest(), full.forest()).unwrap();
        partial.apply(delta).unwrap();

        assert_eq!(partial.forest(), full.forest());
        assert_eq!(partial.peaks(), full.peaks(full.forest()).unwrap());

        let proof1 = full.open(i as usize, full.forest()).unwrap();
        partial.add(proof1.position, node, &proof1.merkle_path).unwrap();
        let proof2 = partial.open(proof1.position).unwrap().unwrap();
        assert_eq!(proof1.merkle_path, proof2.merkle_path);
    }
}

#[test]
fn test_mmr_add_invalid_odd_leaf() {
    let mmr: Mmr = LEAVES.into();
    let acc = mmr.peaks(mmr.forest()).unwrap();
    let mut partial: PartialMmr = acc.clone().into();

    let empty = MerklePath::new(Vec::new());

    // None of the other leaves should work
    for node in LEAVES.iter().cloned().rev().skip(1) {
        let result = partial.add(LEAVES.len() - 1, node, &empty);
        assert!(result.is_err());
    }

    let result = partial.add(LEAVES.len() - 1, LEAVES[6], &empty);
    assert!(result.is_ok());
}

mod property_tests {
    use super::leaf_to_corresponding_tree;
    use proptest::prelude::*;
@@ -468,11 +855,23 @@ mod property_tests {
|
||||
proptest! {
|
||||
#[test]
|
||||
fn test_contained_tree_is_always_power_of_two((leaves, pos) in any::<usize>().prop_flat_map(|v| (Just(v), 0..v))) {
|
||||
let tree = leaf_to_corresponding_tree(pos, leaves).expect("pos is smaller than leaves, there should always be a corresponding tree");
|
||||
let mask = 1usize << tree;
|
||||
let tree_bit = leaf_to_corresponding_tree(pos, leaves).expect("pos is smaller than leaves, there should always be a corresponding tree");
|
||||
let mask = 1usize << tree_bit;
|
||||
|
||||
assert!(tree < usize::BITS, "the result must be a bit in usize");
|
||||
assert!(tree_bit < usize::BITS, "the result must be a bit in usize");
|
||||
assert!(mask & leaves != 0, "the result should be a tree in leaves");
|
||||
}
|
||||
}
|
||||
}

// HELPER FUNCTIONS
// ================================================================================================

fn digests_to_elements(digests: &[RpoDigest]) -> Vec<Felt> {
    digests.iter().flat_map(Word::from).collect()
}

// shorthand for the RPO hash, used to make the test code more concise and easier to read
fn merge(l: RpoDigest, r: RpoDigest) -> RpoDigest {
    Rpo256::merge(&[l, r])
}

src/merkle/mod.rs
@@ -1,9 +1,10 @@
//! Data structures related to Merkle trees based on RPO256 hash function.

use super::{
    hash::rpo::{Rpo256, RpoDigest},
-    utils::collections::{vec, BTreeMap, Vec},
-    Felt, StarkField, Word, WORD_SIZE, ZERO,
+    utils::collections::{vec, BTreeMap, BTreeSet, KvMap, RecordingMap, TryApplyDiff, Vec},
+    Felt, StarkField, Word, EMPTY_WORD, ZERO,
};
use core::fmt;

// REEXPORTS
// ================================================================================================
@@ -11,6 +12,9 @@ use core::fmt;
mod empty_roots;
pub use empty_roots::EmptySubtreeRoots;

+mod delta;
+pub use delta::{merkle_tree_delta, MerkleStoreDelta, MerkleTreeDelta};
+
mod index;
pub use index::NodeIndex;

@@ -20,73 +24,41 @@ pub use merkle_tree::{path_to_text, tree_to_text, MerkleTree};
mod path;
pub use path::{MerklePath, RootPath, ValuePath};

-mod path_set;
-pub use path_set::MerklePathSet;

mod simple_smt;
pub use simple_smt::SimpleSmt;

mod tiered_smt;
pub use tiered_smt::{TieredSmt, TieredSmtProof, TieredSmtProofError};

mod mmr;
-pub use mmr::{Mmr, MmrPeaks, MmrProof};
+pub use mmr::{InOrderIndex, Mmr, MmrDelta, MmrError, MmrPeaks, MmrProof, PartialMmr};

mod store;
-pub use store::MerkleStore;
+pub use store::{DefaultMerkleStore, MerkleStore, RecordingMerkleStore, StoreNode};

mod node;
pub use node::InnerNodeInfo;

-// ERRORS
-// ================================================================================================
+mod partial_mt;
+pub use partial_mt::PartialMerkleTree;

-#[derive(Clone, Debug, PartialEq, Eq)]
-pub enum MerkleError {
-    ConflictingRoots(Vec<Word>),
-    DepthTooSmall(u8),
-    DepthTooBig(u64),
-    NodeNotInStore(Word, NodeIndex),
-    NumLeavesNotPowerOfTwo(usize),
-    InvalidIndex { depth: u8, value: u64 },
-    InvalidDepth { expected: u8, provided: u8 },
-    InvalidPath(MerklePath),
-    InvalidEntriesCount(usize, usize),
-    NodeNotInSet(u64),
-    RootNotInStore(Word),
-}
-
-impl fmt::Display for MerkleError {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        use MerkleError::*;
-        match self {
-            ConflictingRoots(roots) => write!(f, "the merkle paths roots do not match {roots:?}"),
-            DepthTooSmall(depth) => write!(f, "the provided depth {depth} is too small"),
-            DepthTooBig(depth) => write!(f, "the provided depth {depth} is too big"),
-            NumLeavesNotPowerOfTwo(leaves) => {
-                write!(f, "the leaves count {leaves} is not a power of 2")
-            }
-            InvalidIndex { depth, value } => write!(
-                f,
-                "the index value {value} is not valid for the depth {depth}"
-            ),
-            InvalidDepth { expected, provided } => write!(
-                f,
-                "the provided depth {provided} is not valid for {expected}"
-            ),
-            InvalidPath(_path) => write!(f, "the provided path is not valid"),
-            InvalidEntriesCount(max, provided) => write!(f, "the provided number of entries is {provided}, but the maximum for the given depth is {max}"),
-            NodeNotInSet(index) => write!(f, "the node indexed by {index} is not in the set"),
-            NodeNotInStore(hash, index) => write!(f, "the node {:?} indexed by {} and depth {} is not in the store", hash, index.value(), index.depth(),),
-            RootNotInStore(root) => write!(f, "the root {:?} is not in the store", root),
-        }
-    }
-}
-
-#[cfg(feature = "std")]
-impl std::error::Error for MerkleError {}
+mod error;
+pub use error::MerkleError;

// HELPER FUNCTIONS
// ================================================================================================

#[cfg(test)]
-const fn int_to_node(value: u64) -> Word {
+const fn int_to_node(value: u64) -> RpoDigest {
    RpoDigest::new([Felt::new(value), ZERO, ZERO, ZERO])
}

#[cfg(test)]
const fn int_to_leaf(value: u64) -> Word {
    [Felt::new(value), ZERO, ZERO, ZERO]
}

#[cfg(test)]
fn digests_to_words(digests: &[RpoDigest]) -> Vec<Word> {
    digests.iter().map(|d| d.into()).collect()
}

src/merkle/node.rs
@@ -1,9 +1,10 @@
-use super::Word;
+use super::RpoDigest;

/// Representation of a node with two children used for iterating over containers.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct InnerNodeInfo {
-    pub value: Word,
-    pub left: Word,
-    pub right: Word,
+    pub value: RpoDigest,
+    pub left: RpoDigest,
+    pub right: RpoDigest,
}

src/merkle/partial_mt/mod.rs (new file, 467 lines)
@@ -0,0 +1,467 @@
use super::{
    BTreeMap, BTreeSet, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest,
    ValuePath, Vec, Word, EMPTY_WORD,
};
use crate::utils::{
    format, string::String, vec, word_to_hex, ByteReader, ByteWriter, Deserializable,
    DeserializationError, Serializable,
};
use core::fmt;

#[cfg(test)]
mod tests;

// CONSTANTS
// ================================================================================================

/// Index of the root node.
const ROOT_INDEX: NodeIndex = NodeIndex::root();

/// An RpoDigest consisting of 4 ZERO elements.
const EMPTY_DIGEST: RpoDigest = RpoDigest::new(EMPTY_WORD);

// PARTIAL MERKLE TREE
// ================================================================================================

/// A partial Merkle tree with NodeIndex keys and 4-element RpoDigest leaf values. A partial
/// Merkle tree allows building a Merkle tree from Merkle paths of different lengths.
///
/// The root of the tree is recomputed on each new leaf update.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct PartialMerkleTree {
    max_depth: u8,
    nodes: BTreeMap<NodeIndex, RpoDigest>,
    leaves: BTreeSet<NodeIndex>,
}

impl Default for PartialMerkleTree {
    fn default() -> Self {
        Self::new()
    }
}
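
// A minimal usage sketch for the type above (illustrative only: `leaf_digest` and `merkle_path`
// stand in for values obtained elsewhere, and error handling is elided):
//
//     let mut pmt = PartialMerkleTree::new();
//     pmt.add_path(3, leaf_digest, merkle_path)?;  // track index 3 at depth = merkle_path.len()
//     let root = pmt.root();                       // the root is recomputed on every update
//     let opening = pmt.get_path(NodeIndex::make(3, 3))?;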

impl PartialMerkleTree {
    // CONSTANTS
    // --------------------------------------------------------------------------------------------

    /// Minimum supported depth.
    pub const MIN_DEPTH: u8 = 1;

    /// Maximum supported depth.
    pub const MAX_DEPTH: u8 = 64;

    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Returns a new empty [PartialMerkleTree].
    pub fn new() -> Self {
        PartialMerkleTree {
            max_depth: 0,
            nodes: BTreeMap::new(),
            leaves: BTreeSet::new(),
        }
    }

    /// Appends the provided paths iterator into the set.
    ///
    /// Analogous to [Self::add_path].
    pub fn with_paths<I>(paths: I) -> Result<Self, MerkleError>
    where
        I: IntoIterator<Item = (u64, RpoDigest, MerklePath)>,
    {
        // create an empty tree
        let tree = PartialMerkleTree::new();

        paths.into_iter().try_fold(tree, |mut tree, (index, value, path)| {
            tree.add_path(index, value, path)?;
            Ok(tree)
        })
    }

    /// Returns a new [PartialMerkleTree] instantiated with the leaves map specified by the
    /// provided entries.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The depth is 0 or greater than 64.
    /// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
    /// - The provided entries contain an insufficient set of nodes.
    pub fn with_leaves<R, I>(entries: R) -> Result<Self, MerkleError>
    where
        R: IntoIterator<IntoIter = I>,
        I: Iterator<Item = (NodeIndex, RpoDigest)> + ExactSizeIterator,
    {
        let mut layers: BTreeMap<u8, Vec<u64>> = BTreeMap::new();
        let mut leaves = BTreeSet::new();
        let mut nodes = BTreeMap::new();

        // add data to the leaves and nodes maps and also fill the layers map, where the key is
        // the depth of the node and the value is its index.
        for (node_index, hash) in entries.into_iter() {
            leaves.insert(node_index);
            nodes.insert(node_index, hash);
            layers
                .entry(node_index.depth())
                .and_modify(|layer_vec| layer_vec.push(node_index.value()))
                .or_insert(vec![node_index.value()]);
        }

        // check if the number of leaves can be accommodated by the tree's depth; we cap the
        // number of entries at 2^63 because we consider passing in a vector of size 2^64
        // infeasible.
        let max = 2usize.pow(63);
        if layers.len() > max {
            return Err(MerkleError::InvalidNumEntries(max));
        }

        // get the maximum depth
        let max_depth = *layers.keys().next_back().unwrap_or(&0);

        // fill layers without nodes with an empty vector
        for depth in 0..max_depth {
            layers.entry(depth).or_default();
        }

        let mut layer_iter = layers.into_values().rev();
        let mut parent_layer = layer_iter.next().unwrap();
        let mut current_layer;

        for depth in (1..max_depth + 1).rev() {
            // set current_layer = parent_layer and parent_layer = layer_iter.next()
            current_layer = layer_iter.next().unwrap();
            core::mem::swap(&mut current_layer, &mut parent_layer);

            for index_value in current_layer {
                // get the parent node index
                let parent_node = NodeIndex::new(depth - 1, index_value / 2)?;

                // Check if the parent hash was already calculated. In about half of the cases, we
                // don't need to do anything.
                if !parent_layer.contains(&parent_node.value()) {
                    // create the current node index
                    let index = NodeIndex::new(depth, index_value)?;

                    // get the hash of the current node
                    let node = nodes.get(&index).ok_or(MerkleError::NodeNotInSet(index))?;
                    // get the hash of the sibling node
                    let sibling = nodes
                        .get(&index.sibling())
                        .ok_or(MerkleError::NodeNotInSet(index.sibling()))?;
                    // get the parent hash
                    let parent = Rpo256::merge(&index.build_node(*node, *sibling));

                    // add the index value of the calculated node to the parent layer
                    parent_layer.push(parent_node.value());
                    // add the index and hash to the nodes map
                    nodes.insert(parent_node, parent);
                }
            }
        }

        Ok(PartialMerkleTree { max_depth, nodes, leaves })
    }
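
    // A note on sufficiency, using the (depth, index) notation from the tests: for the layer
    // loop above to succeed, every provided node must either have its sibling among the entries
    // or have that sibling computable from deeper entries. E.g., {(2,0), (2,1), (2,2), (2,3)}
    // is a valid input, while {(2,0), (2,2), (2,3)} fails with `NodeNotInSet` because the
    // sibling (2,1) can neither be found nor computed.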

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the root of this Merkle tree.
    pub fn root(&self) -> RpoDigest {
        self.nodes.get(&ROOT_INDEX).cloned().unwrap_or(EMPTY_DIGEST)
    }

    /// Returns the depth of this Merkle tree.
    pub fn max_depth(&self) -> u8 {
        self.max_depth
    }

    /// Returns a node at the specified NodeIndex.
    ///
    /// # Errors
    /// Returns an error if the specified NodeIndex is not contained in the nodes map.
    pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
        self.nodes.get(&index).ok_or(MerkleError::NodeNotInSet(index)).map(|hash| *hash)
    }

    /// Returns true if the provided index is contained in the leaves set, false otherwise.
    pub fn is_leaf(&self, index: NodeIndex) -> bool {
        self.leaves.contains(&index)
    }

    /// Returns a vector of paths from every leaf to the root.
    pub fn to_paths(&self) -> Vec<(NodeIndex, ValuePath)> {
        let mut paths = Vec::new();
        self.leaves.iter().for_each(|&leaf| {
            paths.push((
                leaf,
                ValuePath {
                    value: self.get_node(leaf).expect("Failed to get leaf node"),
                    path: self.get_path(leaf).expect("Failed to get path"),
                },
            ));
        });
        paths
    }

    /// Returns a Merkle path from the node at the specified index to the root.
    ///
    /// The node itself is not included in the path.
    ///
    /// # Errors
    /// Returns an error if:
    /// - the specified index has depth set to 0 or the depth is greater than the depth of this
    ///   Merkle tree.
    /// - the specified index is not contained in the nodes map.
    pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
        if index.is_root() {
            return Err(MerkleError::DepthTooSmall(index.depth()));
        } else if index.depth() > self.max_depth() {
            return Err(MerkleError::DepthTooBig(index.depth() as u64));
        }

        if !self.nodes.contains_key(&index) {
            return Err(MerkleError::NodeNotInSet(index));
        }

        let mut path = Vec::new();
        for _ in 0..index.depth() {
            let sibling_index = index.sibling();
            index.move_up();
            let sibling =
                self.nodes.get(&sibling_index).cloned().expect("Sibling node not in the map");
            path.push(sibling);
        }
        Ok(MerklePath::new(path))
    }

    // ITERATORS
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over the leaves of this [PartialMerkleTree].
    pub fn leaves(&self) -> impl Iterator<Item = (NodeIndex, RpoDigest)> + '_ {
        self.leaves.iter().map(|&leaf| {
            (
                leaf,
                self.get_node(leaf)
                    .unwrap_or_else(|_| panic!("Leaf with {leaf} is not in the nodes map")),
            )
        })
    }

    /// Returns an iterator over the inner nodes of this Merkle tree.
    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
        let inner_nodes = self.nodes.iter().filter(|(index, _)| !self.leaves.contains(index));
        inner_nodes.map(|(index, digest)| {
            let left_hash =
                self.nodes.get(&index.left_child()).expect("Failed to get left child hash");
            let right_hash =
                self.nodes.get(&index.right_child()).expect("Failed to get right child hash");
            InnerNodeInfo {
                value: *digest,
                left: *left_hash,
                right: *right_hash,
            }
        })
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Adds the nodes of the specified Merkle path to this [PartialMerkleTree]. The `index_value`
    /// and `value` parameters specify the leaf node at which the path starts.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The depth of the specified node_index is greater than 64 or smaller than 1.
    /// - The specified path is not consistent with other paths in the set (i.e., resolves to a
    ///   different root).
    pub fn add_path(
        &mut self,
        index_value: u64,
        value: RpoDigest,
        path: MerklePath,
    ) -> Result<(), MerkleError> {
        let index_value = NodeIndex::new(path.len() as u8, index_value)?;

        Self::check_depth(index_value.depth())?;
        self.update_depth(index_value.depth());

        // add the provided node and its sibling to the leaves set
        self.leaves.insert(index_value);
        let sibling_node_index = index_value.sibling();
        self.leaves.insert(sibling_node_index);

        // add the provided node and its sibling to the nodes map
        self.nodes.insert(index_value, value);
        self.nodes.insert(sibling_node_index, path[0]);

        // traverse to the root, updating the nodes
        let mut index_value = index_value;
        let node = Rpo256::merge(&index_value.build_node(value, path[0]));
        let root = path.iter().skip(1).copied().fold(node, |node, hash| {
            index_value.move_up();
            // insert the calculated node into the nodes map
            self.nodes.insert(index_value, node);

            // if the calculated node was a leaf, remove it from the leaves set.
            self.leaves.remove(&index_value);

            let sibling_node = index_value.sibling();

            // Insert the node from the Merkle path into the nodes map. This sibling node becomes
            // a leaf only if it is a new node (i.e., it wasn't already in the nodes map).
            // A node can be in 3 states: an internal node, a leaf of the tree, or not a tree node
            // at all.
            // - An internal node can only stay in this state -- the addition of a new path can't
            //   make it a leaf or remove it from the tree.
            // - A leaf node can stay in the same state (remain a leaf) or become an internal
            //   node. In the first case we don't need to do anything, and the second case is
            //   handled by the call to `self.leaves.remove(&index_value);`
            // - A new node can be a calculated node or a "sibling" node from a Merkle path:
            //   --- A calculated node, obviously, can never be a leaf.
            //   --- A sibling node can only be a leaf, because otherwise it would not be a new
            //       node.
            if self.nodes.insert(sibling_node, hash).is_none() {
                self.leaves.insert(sibling_node);
            }

            Rpo256::merge(&index_value.build_node(node, hash))
        });

        // if the tree is empty (the root is all ZEROs), set the root to the root of the added
        // path; otherwise, the root of the added path must be identical to the current root
        if self.root() == EMPTY_DIGEST {
            self.nodes.insert(ROOT_INDEX, root);
        } else if self.root() != root {
            return Err(MerkleError::ConflictingRoots([self.root(), root].to_vec()));
        }

        Ok(())
    }
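
    // Worked sketch of the fold above for a depth-2 path (all values hypothetical): calling
    // `add_path(0, v, [s0, s1].into())` inserts
    //   (2, 0) -> v   and   (2, 1) -> s0    // the leaf and its sibling, both tracked as leaves
    //   (1, 0) -> merge(v, s0)              // computed on the way up, never a leaf
    //   (1, 1) -> s1                        // a new sibling from the path, tracked as a leaf
    // and stores merge(merge(v, s0), s1) as the root (or checks it against the existing root).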

    /// Updates the value of the leaf at the specified index, returning the old leaf value.
    ///
    /// By default the specified index is assumed to belong to the deepest layer. If the considered
    /// node does not belong to the tree, the first node on the way to the root will be changed.
    ///
    /// This also recomputes all hashes between the leaf and the root, updating the root itself.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The specified index is greater than the maximum number of nodes on the deepest layer.
    pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<RpoDigest, MerkleError> {
        let mut node_index = NodeIndex::new(self.max_depth(), index)?;

        // proceed to the leaf
        for _ in 0..node_index.depth() {
            if !self.leaves.contains(&node_index) {
                node_index.move_up();
            }
        }

        // add the node value to the nodes map
        let old_value = self
            .nodes
            .insert(node_index, value.into())
            .ok_or(MerkleError::NodeNotInSet(node_index))?;

        // if the old value and the new value are the same, there is nothing to update
        if value == *old_value {
            return Ok(old_value);
        }

        let mut value = value.into();
        for _ in 0..node_index.depth() {
            let sibling = self.nodes.get(&node_index.sibling()).expect("sibling should exist");
            value = Rpo256::merge(&node_index.build_node(value, *sibling));
            node_index.move_up();
            self.nodes.insert(node_index, value);
        }

        Ok(old_value)
    }
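
    // Worked sketch of the walk-up above, using the tree from the `update_leaf` test in
    // `tests.rs`: with tracked leaves {(1,1), (2,0), (3,2), (3,3)}, `update_leaf(6, v)` starts
    // at index (3, 6), which is not a tracked leaf, moves up to (2, 3), still not a leaf, and
    // finally updates the tracked leaf at (1, 1), rehashing from there up to the root.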

    // UTILITY FUNCTIONS
    // --------------------------------------------------------------------------------------------

    /// Utility to visualize a [PartialMerkleTree] in text.
    pub fn print(&self) -> Result<String, fmt::Error> {
        let indent = " ";
        let mut s = String::new();
        s.push_str("root: ");
        s.push_str(&word_to_hex(&self.root())?);
        s.push('\n');
        for d in 1..=self.max_depth() {
            let entries = 2u64.pow(d.into());
            for i in 0..entries {
                let index = NodeIndex::new(d, i).expect("The index must always be valid");
                let node = self.get_node(index);
                let node = match node {
                    Err(_) => continue,
                    Ok(node) => node,
                };

                for _ in 0..d {
                    s.push_str(indent);
                }
                s.push_str(&format!("({}, {}): ", index.depth(), index.value()));
                s.push_str(&word_to_hex(&node)?);
                s.push('\n');
            }
        }

        Ok(s)
    }

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Updates the depth value with the maximum of the current and provided depths.
    fn update_depth(&mut self, new_depth: u8) {
        self.max_depth = new_depth.max(self.max_depth);
    }

    /// Returns an error if the depth is 0 or is greater than 64.
    fn check_depth(depth: u8) -> Result<(), MerkleError> {
        // validate the range of the depth.
        if depth < Self::MIN_DEPTH {
            return Err(MerkleError::DepthTooSmall(depth));
        } else if Self::MAX_DEPTH < depth {
            return Err(MerkleError::DepthTooBig(depth as u64));
        }
        Ok(())
    }
}

// SERIALIZATION
// ================================================================================================

impl Serializable for PartialMerkleTree {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        // write leaf nodes
        target.write_u64(self.leaves.len() as u64);
        for leaf_index in self.leaves.iter() {
            leaf_index.write_into(target);
            self.get_node(*leaf_index).expect("Leaf hash not found").write_into(target);
        }
    }
}

impl Deserializable for PartialMerkleTree {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let leaves_len = source.read_u64()? as usize;
        let mut leaf_nodes = Vec::with_capacity(leaves_len);

        // add leaf nodes to the vector
        for _ in 0..leaves_len {
            let index = NodeIndex::read_from(source)?;
            let hash = RpoDigest::read_from(source)?;
            leaf_nodes.push((index, hash));
        }

        let pmt = PartialMerkleTree::with_leaves(leaf_nodes).map_err(|_| {
            DeserializationError::InvalidValue("Invalid data for PartialMerkleTree creation".into())
        })?;

        Ok(pmt)
    }
}
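
// Note on the encoding above: only the leaves are written out; `read_from` rebuilds every inner
// node (and the root) by feeding the leaf set back through `with_leaves`, so a payload whose
// leaves do not form a consistent tree is rejected at deserialization time.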

src/merkle/partial_mt/tests.rs (new file, 463 lines)
@@ -0,0 +1,463 @@
use super::{
    super::{
        digests_to_words, int_to_node, BTreeMap, DefaultMerkleStore as MerkleStore, MerkleTree,
        NodeIndex, PartialMerkleTree,
    },
    Deserializable, InnerNodeInfo, RpoDigest, Serializable, ValuePath, Vec,
};

// TEST DATA
// ================================================================================================

const NODE10: NodeIndex = NodeIndex::new_unchecked(1, 0);
const NODE11: NodeIndex = NodeIndex::new_unchecked(1, 1);

const NODE20: NodeIndex = NodeIndex::new_unchecked(2, 0);
const NODE21: NodeIndex = NodeIndex::new_unchecked(2, 1);
const NODE22: NodeIndex = NodeIndex::new_unchecked(2, 2);
const NODE23: NodeIndex = NodeIndex::new_unchecked(2, 3);

const NODE30: NodeIndex = NodeIndex::new_unchecked(3, 0);
const NODE31: NodeIndex = NodeIndex::new_unchecked(3, 1);
const NODE32: NodeIndex = NodeIndex::new_unchecked(3, 2);
const NODE33: NodeIndex = NodeIndex::new_unchecked(3, 3);

const VALUES8: [RpoDigest; 8] = [
    int_to_node(30),
    int_to_node(31),
    int_to_node(32),
    int_to_node(33),
    int_to_node(34),
    int_to_node(35),
    int_to_node(36),
    int_to_node(37),
];

// TESTS
// ================================================================================================

// For the Partial Merkle Tree tests we will use parts of the Merkle Tree whose full form is
// illustrated below:
//
//                 __________ root __________
//                /                          \
//          ____ 10 ____                ____ 11 ____
//         /            \              /            \
//        20             21           22             23
//       /  \           /  \         /  \           /  \
//     (30)  (31)     (32)  (33)   (34)  (35)     (36)  (37)
//
// Where a node number is a concatenation of its depth and index. For example, the node with
// NodeIndex(3, 5) is labeled as `35`. Leaves of the tree are shown as nodes in parentheses,
// e.g. (33).

/// Checks that creation of the PMT with the `with_leaves()` constructor works correctly.
#[test]
fn with_leaves() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let leaf_nodes_vec = vec![
        (NODE20, mt.get_node(NODE20).unwrap()),
        (NODE32, mt.get_node(NODE32).unwrap()),
        (NODE33, mt.get_node(NODE33).unwrap()),
        (NODE22, mt.get_node(NODE22).unwrap()),
        (NODE23, mt.get_node(NODE23).unwrap()),
    ];

    let leaf_nodes: BTreeMap<NodeIndex, RpoDigest> = leaf_nodes_vec.into_iter().collect();

    let pmt = PartialMerkleTree::with_leaves(leaf_nodes).unwrap();

    assert_eq!(expected_root, pmt.root())
}

/// Checks that the `with_leaves()` function returns an error when using an incomplete set of
/// nodes.
#[test]
fn err_with_leaves() {
    // NODE22 is missing
    let leaf_nodes_vec = vec![
        (NODE20, int_to_node(20)),
        (NODE32, int_to_node(32)),
        (NODE33, int_to_node(33)),
        (NODE23, int_to_node(23)),
    ];

    let leaf_nodes: BTreeMap<NodeIndex, RpoDigest> = leaf_nodes_vec.into_iter().collect();

    assert!(PartialMerkleTree::with_leaves(leaf_nodes).is_err());
}

/// Checks that the root returned by the `root()` function is equal to the expected one.
#[test]
fn get_root() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let ms = MerkleStore::from(&mt);
    let path33 = ms.get_path(expected_root, NODE33).unwrap();

    let pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();

    assert_eq!(expected_root, pmt.root());
}

/// This test checks the correctness of the `add_path()` and `get_path()` functions. First it
/// creates a PMT using `add_path()` by adding Merkle paths from node 33 and node 22 to the empty
/// PMT. Then it checks that the paths returned by the `get_path()` function are equal to the
/// expected ones.
#[test]
fn add_and_get_paths() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let ms = MerkleStore::from(&mt);

    let expected_path33 = ms.get_path(expected_root, NODE33).unwrap();
    let expected_path22 = ms.get_path(expected_root, NODE22).unwrap();

    let mut pmt = PartialMerkleTree::new();
    pmt.add_path(3, expected_path33.value, expected_path33.path.clone()).unwrap();
    pmt.add_path(2, expected_path22.value, expected_path22.path.clone()).unwrap();

    let path33 = pmt.get_path(NODE33).unwrap();
    let path22 = pmt.get_path(NODE22).unwrap();
    let actual_root = pmt.root();

    assert_eq!(expected_path33.path, path33);
    assert_eq!(expected_path22.path, path22);
    assert_eq!(expected_root, actual_root);
}

/// Checks that the `get_node` function used on nodes 10 and 32 returns the expected values.
#[test]
fn get_node() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let ms = MerkleStore::from(&mt);

    let path33 = ms.get_path(expected_root, NODE33).unwrap();

    let pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();

    assert_eq!(ms.get_node(expected_root, NODE32).unwrap(), pmt.get_node(NODE32).unwrap());
    assert_eq!(ms.get_node(expected_root, NODE10).unwrap(), pmt.get_node(NODE10).unwrap());
}

/// Updates leaves of the PMT using the `update_leaf()` function and checks that the new root of
/// the tree is equal to the expected one.
#[test]
fn update_leaf() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let root = mt.root();

    let mut ms = MerkleStore::from(&mt);
    let path33 = ms.get_path(root, NODE33).unwrap();

    let mut pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();

    let new_value32 = int_to_node(132);
    let expected_root = ms.set_node(root, NODE32, new_value32).unwrap().root;

    pmt.update_leaf(2, *new_value32).unwrap();
    let actual_root = pmt.root();

    assert_eq!(expected_root, actual_root);

    let new_value20 = int_to_node(120);
    let expected_root = ms.set_node(expected_root, NODE20, new_value20).unwrap().root;

    pmt.update_leaf(0, *new_value20).unwrap();
    let actual_root = pmt.root();

    assert_eq!(expected_root, actual_root);

    let new_value11 = int_to_node(111);
    let expected_root = ms.set_node(expected_root, NODE11, new_value11).unwrap().root;

    pmt.update_leaf(6, *new_value11).unwrap();
    let actual_root = pmt.root();

    assert_eq!(expected_root, actual_root);
}

/// Checks that the paths of the PMT returned by the `to_paths()` function are equal to the
/// expected ones.
#[test]
fn get_paths() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let ms = MerkleStore::from(&mt);

    let path33 = ms.get_path(expected_root, NODE33).unwrap();
    let path22 = ms.get_path(expected_root, NODE22).unwrap();

    let mut pmt = PartialMerkleTree::new();
    pmt.add_path(3, path33.value, path33.path).unwrap();
    pmt.add_path(2, path22.value, path22.path).unwrap();
    // After PMT creation with path33 (33; 32, 20, 11) and path22 (22; 23, 10) we will have this
    // tree:
    //
    //           ______root______
    //          /                \
    //     ___10___          ___11___
    //    /         \       /        \
    //  (20)        21    (22)      (23)
    //             /  \
    //           (32) (33)
    //
    // which has leaf nodes 20, 22, 23, 32 and 33. Hence overall we will have 5 paths -- one path
    // for each leaf.

    let leaves = vec![NODE20, NODE22, NODE23, NODE32, NODE33];
    let expected_paths: Vec<(NodeIndex, ValuePath)> = leaves
        .iter()
        .map(|&leaf| {
            (
                leaf,
                ValuePath {
                    value: mt.get_node(leaf).unwrap(),
                    path: mt.get_path(leaf).unwrap(),
                },
            )
        })
        .collect();

    let actual_paths = pmt.to_paths();

    assert_eq!(expected_paths, actual_paths);
}

// Checks the correctness of leaf determination when using the `leaves()` function.
#[test]
fn leaves() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let ms = MerkleStore::from(&mt);

    let path33 = ms.get_path(expected_root, NODE33).unwrap();
    let path22 = ms.get_path(expected_root, NODE22).unwrap();

    let mut pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();
    // After PMT creation with path33 (33; 32, 20, 11) we will have this tree:
    //
    //           ______root______
    //          /                \
    //     ___10___              (11)
    //    /         \
    //  (20)        21
    //             /  \
    //           (32) (33)
    //
    // which has leaf nodes 11, 20, 32 and 33.

    let value11 = mt.get_node(NODE11).unwrap();
    let value20 = mt.get_node(NODE20).unwrap();
    let value32 = mt.get_node(NODE32).unwrap();
    let value33 = mt.get_node(NODE33).unwrap();

    let leaves = vec![(NODE11, value11), (NODE20, value20), (NODE32, value32), (NODE33, value33)];

    let expected_leaves = leaves.iter().copied();
    assert!(expected_leaves.eq(pmt.leaves()));

    pmt.add_path(2, path22.value, path22.path).unwrap();
    // After adding path22 (22; 23, 10) to the existing PMT we will have this tree:
    //
    //           ______root______
    //          /                \
    //     ___10___          ___11___
    //    /         \       /        \
    //  (20)        21    (22)      (23)
    //             /  \
    //           (32) (33)
    //
    // which has leaf nodes 20, 22, 23, 32 and 33.

    let value20 = mt.get_node(NODE20).unwrap();
    let value22 = mt.get_node(NODE22).unwrap();
    let value23 = mt.get_node(NODE23).unwrap();
    let value32 = mt.get_node(NODE32).unwrap();
    let value33 = mt.get_node(NODE33).unwrap();

    let leaves = vec![
        (NODE20, value20),
        (NODE22, value22),
        (NODE23, value23),
        (NODE32, value32),
        (NODE33, value33),
    ];

    let expected_leaves = leaves.iter().copied();
    assert!(expected_leaves.eq(pmt.leaves()));
}

/// Checks that the nodes of the PMT returned by the `inner_nodes()` function are equal to the
/// expected ones.
#[test]
fn test_inner_node_iterator() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let ms = MerkleStore::from(&mt);

    let path33 = ms.get_path(expected_root, NODE33).unwrap();
    let path22 = ms.get_path(expected_root, NODE22).unwrap();

    let mut pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();

    // get the actual inner nodes
    let actual: Vec<InnerNodeInfo> = pmt.inner_nodes().collect();

    let expected_n00 = mt.root();
    let expected_n10 = mt.get_node(NODE10).unwrap();
    let expected_n11 = mt.get_node(NODE11).unwrap();
    let expected_n20 = mt.get_node(NODE20).unwrap();
    let expected_n21 = mt.get_node(NODE21).unwrap();
    let expected_n32 = mt.get_node(NODE32).unwrap();
    let expected_n33 = mt.get_node(NODE33).unwrap();

    // create a vector of the expected inner nodes
    let mut expected = vec![
        InnerNodeInfo {
            value: expected_n00,
            left: expected_n10,
            right: expected_n11,
        },
        InnerNodeInfo {
            value: expected_n10,
            left: expected_n20,
            right: expected_n21,
        },
        InnerNodeInfo {
            value: expected_n21,
            left: expected_n32,
            right: expected_n33,
        },
    ];

    assert_eq!(actual, expected);

    // add another path to the Partial Merkle Tree
    pmt.add_path(2, path22.value, path22.path).unwrap();

    // get the new actual inner nodes
    let actual: Vec<InnerNodeInfo> = pmt.inner_nodes().collect();

    let expected_n22 = mt.get_node(NODE22).unwrap();
    let expected_n23 = mt.get_node(NODE23).unwrap();

    let info_11 = InnerNodeInfo {
        value: expected_n11,
        left: expected_n22,
        right: expected_n23,
    };

    // add the new inner node to the existing vector
    expected.insert(2, info_11);

    assert_eq!(actual, expected);
}

/// Checks that the serialization and deserialization implementations for the PMT work correctly.
#[test]
fn serialization() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let ms = MerkleStore::from(&mt);

    let path33 = ms.get_path(expected_root, NODE33).unwrap();
    let path22 = ms.get_path(expected_root, NODE22).unwrap();

    let pmt = PartialMerkleTree::with_paths([
        (3, path33.value, path33.path),
        (2, path22.value, path22.path),
    ])
    .unwrap();

    let serialized_pmt = pmt.to_bytes();
    let deserialized_pmt = PartialMerkleTree::read_from_bytes(&serialized_pmt).unwrap();

    assert_eq!(deserialized_pmt, pmt);
}

/// Checks that deserialization fails with incorrect data.
#[test]
fn err_deserialization() {
    let mut tree_bytes: Vec<u8> = vec![5];
    tree_bytes.append(&mut NODE20.to_bytes());
    tree_bytes.append(&mut int_to_node(20).to_bytes());

    tree_bytes.append(&mut NODE21.to_bytes());
    tree_bytes.append(&mut int_to_node(21).to_bytes());

    // a node with depth 1 could have index 0 or 1, but this one has 2
    tree_bytes.append(&mut vec![1, 2]);
    tree_bytes.append(&mut int_to_node(11).to_bytes());

    assert!(PartialMerkleTree::read_from_bytes(&tree_bytes).is_err());
}

/// Checks that adding a path with a different root causes an error.
#[test]
fn err_add_path() {
    let path33 = vec![int_to_node(1), int_to_node(2), int_to_node(3)].into();
    let path22 = vec![int_to_node(4), int_to_node(5)].into();

    let mut pmt = PartialMerkleTree::new();
    pmt.add_path(3, int_to_node(6), path33).unwrap();

    assert!(pmt.add_path(2, int_to_node(7), path22).is_err());
}

/// Checks that requesting a node which is not in the PMT causes an error.
#[test]
fn err_get_node() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let ms = MerkleStore::from(&mt);

    let path33 = ms.get_path(expected_root, NODE33).unwrap();

    let pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();

    assert!(pmt.get_node(NODE22).is_err());
    assert!(pmt.get_node(NODE23).is_err());
    assert!(pmt.get_node(NODE30).is_err());
    assert!(pmt.get_node(NODE31).is_err());
}

/// Checks that requesting the path from a leaf which is not in the PMT causes an error.
#[test]
fn err_get_path() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let ms = MerkleStore::from(&mt);

    let path33 = ms.get_path(expected_root, NODE33).unwrap();

    let pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();

    assert!(pmt.get_path(NODE22).is_err());
    assert!(pmt.get_path(NODE23).is_err());
    assert!(pmt.get_path(NODE30).is_err());
    assert!(pmt.get_path(NODE31).is_err());
}

#[test]
fn err_update_leaf() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let ms = MerkleStore::from(&mt);

    let path33 = ms.get_path(expected_root, NODE33).unwrap();

    let mut pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();

    assert!(pmt.update_leaf(8, *int_to_node(38)).is_err());
}

src/merkle/path.rs
@@ -1,13 +1,15 @@
-use super::{vec, InnerNodeInfo, MerkleError, NodeIndex, Rpo256, Vec, Word};
+use super::{vec, InnerNodeInfo, MerkleError, NodeIndex, Rpo256, RpoDigest, Vec};
use core::ops::{Deref, DerefMut};
use winter_utils::{ByteReader, Deserializable, DeserializationError, Serializable};

// MERKLE PATH
// ================================================================================================

/// A Merkle path container, composed of a sequence of nodes of a Merkle tree.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerklePath {
-    nodes: Vec<Word>,
+    nodes: Vec<RpoDigest>,
}

impl MerklePath {
@@ -15,7 +17,8 @@ impl MerklePath {
    // --------------------------------------------------------------------------------------------

    /// Creates a new Merkle path from a list of nodes.
-    pub fn new(nodes: Vec<Word>) -> Self {
+    pub fn new(nodes: Vec<RpoDigest>) -> Self {
        assert!(nodes.len() <= u8::MAX.into(), "MerklePath may have at most 256 items");
        Self { nodes }
    }

@@ -27,14 +30,19 @@ impl MerklePath {
        self.nodes.len() as u8
    }

    /// Returns a reference to the [MerklePath]'s nodes.
    pub fn nodes(&self) -> &[RpoDigest] {
        &self.nodes
    }

    /// Computes the merkle root for this opening.
-    pub fn compute_root(&self, index: u64, node: Word) -> Result<Word, MerkleError> {
+    pub fn compute_root(&self, index: u64, node: RpoDigest) -> Result<RpoDigest, MerkleError> {
        let mut index = NodeIndex::new(self.depth(), index)?;
        let root = self.nodes.iter().copied().fold(node, |node, sibling| {
            // compute the node and move to the next iteration.
-            let input = index.build_node(node.into(), sibling.into());
+            let input = index.build_node(node, sibling);
            index.move_up();
-            Rpo256::merge(&input).into()
+            Rpo256::merge(&input)
        });
        Ok(root)
    }
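
    // Worked sketch of the fold above (hypothetical digests `a`, `s0`, `s1`): for `index = 2`
    // with `path = [s0, s1]`, index 2 at depth 2 is a left child, so the first merge is
    // merge(a, s0); its parent, index 1 at depth 1, is a right child, so the result is
    //   root = merge(s1, merge(a, s0))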
@@ -42,7 +50,7 @@ impl MerklePath {
    /// Verifies the Merkle opening proof towards the provided root.
    ///
    /// Returns `true` if `node` exists at `index` in a Merkle tree with `root`.
-    pub fn verify(&self, index: u64, node: Word, root: &Word) -> bool {
+    pub fn verify(&self, index: u64, node: RpoDigest, root: &RpoDigest) -> bool {
        match self.compute_root(index, node) {
            Ok(computed_root) => root == &computed_root,
            Err(_) => false,
@@ -55,7 +63,11 @@ impl MerklePath {
    ///
    /// # Errors
    /// Returns an error if the specified index is not valid for this path.
-    pub fn inner_nodes(&self, index: u64, node: Word) -> Result<InnerNodeIterator, MerkleError> {
+    pub fn inner_nodes(
+        &self,
+        index: u64,
+        node: RpoDigest,
+    ) -> Result<InnerNodeIterator, MerkleError> {
        Ok(InnerNodeIterator {
            nodes: &self.nodes,
            index: NodeIndex::new(self.depth(), index)?,
@@ -64,16 +76,31 @@ impl MerklePath {
    }
}

-impl From<Vec<Word>> for MerklePath {
-    fn from(path: Vec<Word>) -> Self {
+// CONVERSIONS
+// ================================================================================================
+
+impl From<MerklePath> for Vec<RpoDigest> {
+    fn from(path: MerklePath) -> Self {
+        path.nodes
+    }
+}
+
+impl From<Vec<RpoDigest>> for MerklePath {
+    fn from(path: Vec<RpoDigest>) -> Self {
        Self::new(path)
    }
}

+impl From<&[RpoDigest]> for MerklePath {
+    fn from(path: &[RpoDigest]) -> Self {
+        Self::new(path.to_vec())
+    }
+}
+
impl Deref for MerklePath {
    // we use `Vec` here instead of slice so we can call vector mutation methods directly from the
    // merkle path (example: `Vec::remove`).
-    type Target = Vec<Word>;
+    type Target = Vec<RpoDigest>;

    fn deref(&self) -> &Self::Target {
        &self.nodes
@@ -89,15 +116,15 @@ impl DerefMut for MerklePath {
// ITERATORS
// ================================================================================================

-impl FromIterator<Word> for MerklePath {
-    fn from_iter<T: IntoIterator<Item = Word>>(iter: T) -> Self {
+impl FromIterator<RpoDigest> for MerklePath {
+    fn from_iter<T: IntoIterator<Item = RpoDigest>>(iter: T) -> Self {
        Self::new(iter.into_iter().collect())
    }
}

impl IntoIterator for MerklePath {
-    type Item = Word;
-    type IntoIter = vec::IntoIter<Word>;
+    type Item = RpoDigest;
+    type IntoIter = vec::IntoIter<RpoDigest>;

    fn into_iter(self) -> Self::IntoIter {
        self.nodes.into_iter()
@@ -106,9 +133,9 @@ impl IntoIterator for MerklePath {

/// An iterator over internal nodes of a [MerklePath].
pub struct InnerNodeIterator<'a> {
-    nodes: &'a Vec<Word>,
+    nodes: &'a Vec<RpoDigest>,
    index: NodeIndex,
-    value: Word,
+    value: RpoDigest,
}

impl<'a> Iterator for InnerNodeIterator<'a> {
@@ -123,14 +150,10 @@ impl<'a> Iterator for InnerNodeIterator<'a> {
            (self.value, self.nodes[sibling_pos])
        };

-        self.value = Rpo256::merge(&[left.into(), right.into()]).into();
+        self.value = Rpo256::merge(&[left, right]);
        self.index.move_up();

-        Some(InnerNodeInfo {
-            value: self.value,
-            left,
-            right,
-        })
+        Some(InnerNodeInfo { value: self.value, left, right })
    } else {
        None
    }
@@ -144,11 +167,18 @@ impl<'a> Iterator for InnerNodeIterator<'a> {
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ValuePath {
    /// The node value opening for `path`.
-    pub value: Word,
+    pub value: RpoDigest,
    /// The path from `value` to `root` (exclusive).
    pub path: MerklePath,
}

+impl ValuePath {
+    /// Returns a new [ValuePath] instantiated from the specified value and path.
+    pub fn new(value: RpoDigest, path: Vec<RpoDigest>) -> Self {
+        Self { value, path: MerklePath::new(path) }
+    }
+}
+
/// A container for a [MerklePath] and its [Word] root.
///
/// This structure does not provide any guarantees regarding the correctness of the path to the
@@ -156,11 +186,60 @@ pub struct ValuePath {
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct RootPath {
    /// The root associated with `path`.
-    pub root: Word,
+    pub root: RpoDigest,
    /// The path from `value` to `root` (exclusive).
    pub path: MerklePath,
}

// SERIALIZATION
// ================================================================================================

impl Serializable for MerklePath {
    fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
        assert!(self.nodes.len() <= u8::MAX.into(), "Length enforced in the constructor");
        target.write_u8(self.nodes.len() as u8);
        self.nodes.write_into(target);
    }
}

impl Deserializable for MerklePath {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let count = source.read_u8()?.into();
        let nodes = RpoDigest::read_batch_from(source, count)?;
        Ok(Self { nodes })
    }
}

impl Serializable for ValuePath {
    fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
        self.value.write_into(target);
        self.path.write_into(target);
    }
}

impl Deserializable for ValuePath {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let value = RpoDigest::read_from(source)?;
        let path = MerklePath::read_from(source)?;
        Ok(Self { value, path })
    }
}

impl Serializable for RootPath {
    fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
        self.root.write_into(target);
        self.path.write_into(target);
    }
}

impl Deserializable for RootPath {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let root = RpoDigest::read_from(source)?;
        let path = MerklePath::read_from(source)?;
        Ok(Self { root, path })
    }
}

// TESTS
// ================================================================================================
@@ -1,410 +0,0 @@
|
||||
use super::{BTreeMap, MerkleError, MerklePath, NodeIndex, Rpo256, ValuePath, Vec, Word, ZERO};
|
||||
|
||||
// MERKLE PATH SET
|
||||
// ================================================================================================
|
||||
|
||||
/// A set of Merkle paths.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct MerklePathSet {
|
||||
root: Word,
|
||||
total_depth: u8,
|
||||
paths: BTreeMap<u64, MerklePath>,
|
||||
}
|
||||
|
||||
impl MerklePathSet {
|
||||
// CONSTRUCTOR
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns an empty MerklePathSet.
|
||||
pub fn new(depth: u8) -> Self {
|
||||
let root = [ZERO; 4];
|
||||
let paths = BTreeMap::new();
|
||||
|
||||
Self {
|
||||
root,
|
||||
total_depth: depth,
|
||||
paths,
|
||||
}
|
||||
}
|
||||
|
||||
/// Appends the provided paths iterator into the set.
|
||||
///
|
||||
/// Analogous to `[Self::add_path]`.
|
||||
pub fn with_paths<I>(self, paths: I) -> Result<Self, MerkleError>
|
||||
where
|
||||
I: IntoIterator<Item = (u64, Word, MerklePath)>,
|
||||
{
|
||||
paths.into_iter().try_fold(self, |mut set, (index, value, path)| {
|
||||
set.add_path(index, value, path)?;
|
||||
Ok(set)
|
||||
})
|
||||
}
|
||||
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns the root to which all paths in this set resolve.
|
||||
pub const fn root(&self) -> Word {
|
||||
self.root
|
||||
}
|
||||
|
||||
/// Returns the depth of the Merkle tree implied by the paths stored in this set.
|
||||
///
|
||||
/// Merkle tree of depth 1 has two leaves, depth 2 has four leaves etc.
|
||||
pub const fn depth(&self) -> u8 {
|
||||
self.total_depth
|
||||
}
|
||||
|
||||
/// Returns a node at the specified index.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// * The specified index is not valid for the depth of structure.
|
||||
/// * Requested node does not exist in the set.
|
||||
pub fn get_node(&self, index: NodeIndex) -> Result<Word, MerkleError> {
|
||||
if index.depth() != self.total_depth {
|
||||
return Err(MerkleError::InvalidDepth {
|
||||
expected: self.total_depth,
|
||||
provided: index.depth(),
|
||||
});
|
||||
}
|
||||
|
||||
let parity = index.value() & 1;
|
||||
let path_key = index.value() - parity;
|
||||
self.paths
|
||||
.get(&path_key)
|
||||
.ok_or(MerkleError::NodeNotInSet(path_key))
|
||||
.map(|path| path[parity as usize])
|
||||
}
|
||||
|
||||
/// Returns a leaf at the specified index.
|
||||
///
|
||||
/// # Errors
|
||||
/// * The specified index is not valid for the depth of the structure.
|
||||
/// * Leaf with the requested path does not exist in the set.
|
||||
pub fn get_leaf(&self, index: u64) -> Result<Word, MerkleError> {
|
||||
let index = NodeIndex::new(self.depth(), index)?;
|
||||
self.get_node(index)
|
||||
}
|
||||
|
||||
/// Returns a Merkle path to the node at the specified index. The node itself is
|
||||
/// not included in the path.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// * The specified index is not valid for the depth of structure.
|
||||
/// * Node of the requested path does not exist in the set.
|
||||
pub fn get_path(&self, index: NodeIndex) -> Result<MerklePath, MerkleError> {
|
||||
if index.depth() != self.total_depth {
|
||||
return Err(MerkleError::InvalidDepth {
|
||||
expected: self.total_depth,
|
||||
provided: index.depth(),
|
||||
});
|
||||
}
|
||||
|
||||
let parity = index.value() & 1;
|
||||
let path_key = index.value() - parity;
|
||||
let mut path = self
|
||||
.paths
|
||||
.get(&path_key)
|
||||
.cloned()
|
||||
.ok_or(MerkleError::NodeNotInSet(index.value()))?;
|
||||
path.remove(parity as usize);
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
/// Returns all paths in this path set together with their indexes.
|
||||
pub fn to_paths(&self) -> Vec<(u64, ValuePath)> {
|
||||
let mut result = Vec::with_capacity(self.paths.len() * 2);
|
||||
|
||||
for (&index, path) in self.paths.iter() {
|
||||
// push path for the even index into the result
|
||||
let path1 = ValuePath {
|
||||
value: path[0],
|
||||
path: MerklePath::new(path[1..].to_vec()),
|
||||
};
|
||||
result.push((index, path1));
|
||||
|
||||
// push path for the odd index into the result
|
||||
let mut path2 = path.clone();
|
||||
let leaf2 = path2.remove(1);
|
||||
let path2 = ValuePath {
|
||||
value: leaf2,
|
||||
path: path2,
|
||||
};
|
||||
result.push((index + 1, path2));
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
// STATE MUTATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Adds the specified Merkle path to this [MerklePathSet]. The `index` and `value` parameters
|
||||
/// specify the leaf node at which the path starts.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// - The specified index is is not valid in the context of this Merkle path set (i.e., the
|
||||
/// index implies a greater depth than is specified for this set).
|
||||
/// - The specified path is not consistent with other paths in the set (i.e., resolves to a
|
||||
/// different root).
|
||||
pub fn add_path(
|
||||
&mut self,
|
||||
index_value: u64,
|
||||
value: Word,
|
||||
mut path: MerklePath,
|
||||
) -> Result<(), MerkleError> {
|
||||
let mut index = NodeIndex::new(path.len() as u8, index_value)?;
|
||||
if index.depth() != self.total_depth {
|
||||
return Err(MerkleError::InvalidDepth {
|
||||
expected: self.total_depth,
|
||||
provided: index.depth(),
|
||||
});
|
||||
}
|
||||
|
||||
// update the current path
|
||||
let parity = index_value & 1;
|
||||
path.insert(parity as usize, value);
|
||||
|
||||
// traverse to the root, updating the nodes
|
||||
let root: Word = Rpo256::merge(&[path[0].into(), path[1].into()]).into();
|
||||
let root = path.iter().skip(2).copied().fold(root, |root, hash| {
|
||||
index.move_up();
|
||||
Rpo256::merge(&index.build_node(root.into(), hash.into())).into()
|
||||
});
|
||||
|
||||
// if the path set is empty (the root is all ZEROs), set the root to the root of the added
|
||||
// path; otherwise, the root of the added path must be identical to the current root
|
||||
if self.root == [ZERO; 4] {
|
||||
self.root = root;
|
||||
} else if self.root != root {
|
||||
return Err(MerkleError::ConflictingRoots([self.root, root].to_vec()));
|
||||
}
|
||||
|
||||
// finish updating the path
|
||||
let path_key = index_value - parity;
|
||||
self.paths.insert(path_key, path);
|
||||
Ok(())
|
||||
}

    /// Replaces the leaf at the specified index with the provided value.
    ///
    /// # Errors
    /// Returns an error if:
    /// * The requested node does not exist in the set.
    pub fn update_leaf(&mut self, base_index_value: u64, value: Word) -> Result<(), MerkleError> {
        let mut index = NodeIndex::new(self.depth(), base_index_value)?;
        let parity = index.value() & 1;
        let path_key = index.value() - parity;
        let path = match self.paths.get_mut(&path_key) {
            Some(path) => path,
            None => return Err(MerkleError::NodeNotInSet(base_index_value)),
        };

        // Fill old_hashes vector -----------------------------------------------------------------
        let mut current_index = index;
        let mut old_hashes = Vec::with_capacity(path.len().saturating_sub(2));
        let mut root: Word = Rpo256::merge(&[path[0].into(), path[1].into()]).into();
        for hash in path.iter().skip(2).copied() {
            old_hashes.push(root);
            current_index.move_up();
            let input = current_index.build_node(hash.into(), root.into());
            root = Rpo256::merge(&input).into();
        }

        // Fill new_hashes vector -----------------------------------------------------------------
        path[index.is_value_odd() as usize] = value;

        let mut new_hashes = Vec::with_capacity(path.len().saturating_sub(2));
        let mut new_root: Word = Rpo256::merge(&[path[0].into(), path[1].into()]).into();
        for path_hash in path.iter().skip(2).copied() {
            new_hashes.push(new_root);
            index.move_up();
            let input = current_index.build_node(path_hash.into(), new_root.into());
            new_root = Rpo256::merge(&input).into();
        }

        self.root = new_root;

        // update paths ---------------------------------------------------------------------------
        for path in self.paths.values_mut() {
            for i in (0..old_hashes.len()).rev() {
                if path[i + 2] == old_hashes[i] {
                    path[i + 2] = new_hashes[i];
                    break;
                }
            }
        }

        Ok(())
    }
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::merkle::int_to_node;

    #[test]
    fn get_root() {
        let leaf0 = int_to_node(0);
        let leaf1 = int_to_node(1);
        let leaf2 = int_to_node(2);
        let leaf3 = int_to_node(3);

        let parent0 = calculate_parent_hash(leaf0, 0, leaf1);
        let parent1 = calculate_parent_hash(leaf2, 2, leaf3);

        let root_exp = calculate_parent_hash(parent0, 0, parent1);

        let set = super::MerklePathSet::new(2)
            .with_paths([(0, leaf0, vec![leaf1, parent1].into())])
            .unwrap();

        assert_eq!(set.root(), root_exp);
    }

    #[test]
    fn add_and_get_path() {
        let path_6 = vec![int_to_node(7), int_to_node(45), int_to_node(123)];
        let hash_6 = int_to_node(6);
        let index = 6_u64;
        let depth = 3_u8;
        let set = super::MerklePathSet::new(depth)
            .with_paths([(index, hash_6, path_6.clone().into())])
            .unwrap();
        let stored_path_6 = set.get_path(NodeIndex::make(depth, index)).unwrap();

        assert_eq!(path_6, *stored_path_6);
    }

    #[test]
    fn get_node() {
        let path_6 = vec![int_to_node(7), int_to_node(45), int_to_node(123)];
        let hash_6 = int_to_node(6);
        let index = 6_u64;
        let depth = 3_u8;
        let set = MerklePathSet::new(depth).with_paths([(index, hash_6, path_6.into())]).unwrap();

        assert_eq!(int_to_node(6u64), set.get_node(NodeIndex::make(depth, index)).unwrap());
    }

    #[test]
    fn update_leaf() {
        let hash_4 = int_to_node(4);
        let hash_5 = int_to_node(5);
        let hash_6 = int_to_node(6);
        let hash_7 = int_to_node(7);
        let hash_45 = calculate_parent_hash(hash_4, 12u64, hash_5);
        let hash_67 = calculate_parent_hash(hash_6, 14u64, hash_7);

        let hash_0123 = int_to_node(123);

        let path_6 = vec![hash_7, hash_45, hash_0123];
        let path_5 = vec![hash_4, hash_67, hash_0123];
        let path_4 = vec![hash_5, hash_67, hash_0123];

        let index_6 = 6_u64;
        let index_5 = 5_u64;
        let index_4 = 4_u64;
        let depth = 3_u8;
        let mut set = MerklePathSet::new(depth)
            .with_paths([
                (index_6, hash_6, path_6.into()),
                (index_5, hash_5, path_5.into()),
                (index_4, hash_4, path_4.into()),
            ])
            .unwrap();

        let new_hash_6 = int_to_node(100);
        let new_hash_5 = int_to_node(55);

        set.update_leaf(index_6, new_hash_6).unwrap();
        let new_path_4 = set.get_path(NodeIndex::make(depth, index_4)).unwrap();
        let new_hash_67 = calculate_parent_hash(new_hash_6, 14_u64, hash_7);
        assert_eq!(new_hash_67, new_path_4[1]);

        set.update_leaf(index_5, new_hash_5).unwrap();
        let new_path_4 = set.get_path(NodeIndex::make(depth, index_4)).unwrap();
        let new_path_6 = set.get_path(NodeIndex::make(depth, index_6)).unwrap();
        let new_hash_45 = calculate_parent_hash(new_hash_5, 13_u64, hash_4);
        assert_eq!(new_hash_45, new_path_6[1]);
        assert_eq!(new_hash_5, new_path_4[0]);
    }

    #[test]
    fn depth_3_is_correct() {
        let a = int_to_node(1);
        let b = int_to_node(2);
        let c = int_to_node(3);
        let d = int_to_node(4);
        let e = int_to_node(5);
        let f = int_to_node(6);
        let g = int_to_node(7);
        let h = int_to_node(8);

        let i = Rpo256::merge(&[a.into(), b.into()]);
        let j = Rpo256::merge(&[c.into(), d.into()]);
        let k = Rpo256::merge(&[e.into(), f.into()]);
        let l = Rpo256::merge(&[g.into(), h.into()]);

        let m = Rpo256::merge(&[i.into(), j.into()]);
        let n = Rpo256::merge(&[k.into(), l.into()]);

        let root = Rpo256::merge(&[m.into(), n.into()]);

        let mut set = MerklePathSet::new(3);

        let value = b;
        let index = 1;
        let path = MerklePath::new([a.into(), j.into(), n.into()].to_vec());
        set.add_path(index, value, path.clone()).unwrap();
        assert_eq!(value, set.get_leaf(index).unwrap());
        assert_eq!(Word::from(root), set.root());

        let value = e;
        let index = 4;
        let path = MerklePath::new([f.into(), l.into(), m.into()].to_vec());
        set.add_path(index, value, path.clone()).unwrap();
        assert_eq!(value, set.get_leaf(index).unwrap());
        assert_eq!(Word::from(root), set.root());

        let value = a;
        let index = 0;
        let path = MerklePath::new([b.into(), j.into(), n.into()].to_vec());
        set.add_path(index, value, path.clone()).unwrap();
        assert_eq!(value, set.get_leaf(index).unwrap());
        assert_eq!(Word::from(root), set.root());

        let value = h;
        let index = 7;
        let path = MerklePath::new([g.into(), k.into(), m.into()].to_vec());
        set.add_path(index, value, path.clone()).unwrap();
        assert_eq!(value, set.get_leaf(index).unwrap());
        assert_eq!(Word::from(root), set.root());
    }

    // HELPER FUNCTIONS
    // --------------------------------------------------------------------------------------------

    const fn is_even(pos: u64) -> bool {
        pos & 1 == 0
    }

    /// Calculates the hash of the parent node from two sibling ones:
    /// - `node` - the current node
    /// - `node_pos` - the position of the current node
    /// - `sibling` - the neighboring vertex in the tree
    fn calculate_parent_hash(node: Word, node_pos: u64, sibling: Word) -> Word {
        if is_even(node_pos) {
            Rpo256::merge(&[node.into(), sibling.into()]).into()
        } else {
            Rpo256::merge(&[sibling.into(), node.into()]).into()
        }
    }
}

@@ -1,6 +1,6 @@
use super::{
    BTreeMap, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256,
    RpoDigest, Vec, Word,
    BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, MerkleTreeDelta,
    NodeIndex, Rpo256, RpoDigest, StoreNode, TryApplyDiff, Vec, Word,
};

#[cfg(test)]

@@ -10,26 +10,15 @@ mod tests;
// ================================================================================================

/// A sparse Merkle tree with 64-bit keys and 4-element leaf values, without compaction.
///
/// The root of the tree is recomputed on each new leaf update.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct SimpleSmt {
    depth: u8,
    root: Word,
    root: RpoDigest,
    leaves: BTreeMap<u64, Word>,
    branches: BTreeMap<NodeIndex, BranchNode>,
    empty_hashes: Vec<RpoDigest>,
}

#[derive(Debug, Default, Clone, PartialEq, Eq)]
struct BranchNode {
    left: RpoDigest,
    right: RpoDigest,
}

impl BranchNode {
    fn parent(&self) -> RpoDigest {
        Rpo256::merge(&[self.left, self.right])
    }
}

impl SimpleSmt {

@@ -42,10 +31,18 @@ impl SimpleSmt {
    /// Maximum supported depth.
    pub const MAX_DEPTH: u8 = 64;

    /// Value of an empty leaf.
    pub const EMPTY_VALUE: Word = super::EMPTY_WORD;

    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Creates a new simple SMT with the provided depth.
    /// Returns a new [SimpleSmt] instantiated with the specified depth.
    ///
    /// All leaves in the returned tree are set to [ZERO; 4].
    ///
    /// # Errors
    /// Returns an error if the depth is 0 or is greater than 64.
    pub fn new(depth: u8) -> Result<Self, MerkleError> {
        // validate the range of the depth.
        if depth < Self::MIN_DEPTH {

@@ -54,55 +51,79 @@ impl SimpleSmt {
            return Err(MerkleError::DepthTooBig(depth as u64));
        }

        let empty_hashes = EmptySubtreeRoots::empty_hashes(depth).to_vec();
        let root = empty_hashes[0].into();
        let root = *EmptySubtreeRoots::entry(depth, 0);

        Ok(Self {
            root,
            depth,
            empty_hashes,
            leaves: BTreeMap::new(),
            branches: BTreeMap::new(),
        })
    }

    /// Appends the provided entries as leaves of the tree.
    /// Returns a new [SimpleSmt] instantiated with the specified depth and with leaves
    /// set as specified by the provided entries.
    ///
    /// All leaves omitted from the entries list are set to [ZERO; 4].
    ///
    /// # Errors
    ///
    /// The function will fail if the provided entries count exceed the maximum tree capacity, that
    /// is `2^{depth}`.
    pub fn with_leaves<R, I>(mut self, entries: R) -> Result<Self, MerkleError>
    where
        R: IntoIterator<IntoIter = I>,
        I: Iterator<Item = (u64, Word)> + ExactSizeIterator,
    {
        // check if the leaves count will fit the depth setup
        let mut entries = entries.into_iter();
        let max = 1 << self.depth.min(63);
        if entries.len() > max {
            return Err(MerkleError::InvalidEntriesCount(max, entries.len()));
        }
    /// Returns an error if:
    /// - The depth is 0 or is greater than 64.
    /// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
    /// - The provided entries contain multiple values for the same key.
    pub fn with_leaves(
        depth: u8,
        entries: impl IntoIterator<Item = (u64, Word)>,
    ) -> Result<Self, MerkleError> {
        // create an empty tree
        let mut tree = Self::new(depth)?;

        // append leaves and return
        entries.try_for_each(|(key, leaf)| self.insert_leaf(key, leaf))?;
        Ok(self)
        // compute the max number of entries. We use an upper bound of depth 63 because we consider
        // passing in a vector of size 2^64 infeasible.
        let max_num_entries = 2_usize.pow(tree.depth.min(63).into());

        // This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so
        // entries with the empty value need additional tracking.
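        // (Illustrative note, added for clarity and not part of the diff: writing
        // EMPTY_WORD at key `k` leaves no entry in `leaves`, so a later duplicate
        // entry for `k` would see `old_value == EMPTY_VALUE` and slip past the
        // duplicate check below if this set were not consulted.)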
        let mut key_set_to_zero = BTreeSet::new();

        for (idx, (key, value)) in entries.into_iter().enumerate() {
            if idx >= max_num_entries {
                return Err(MerkleError::InvalidNumEntries(max_num_entries));
            }

            let old_value = tree.update_leaf(key, value)?;

            if old_value != Self::EMPTY_VALUE || key_set_to_zero.contains(&key) {
                return Err(MerkleError::DuplicateValuesForIndex(key));
            }

            if value == Self::EMPTY_VALUE {
                key_set_to_zero.insert(key);
            };
        }
        Ok(tree)
    }

    /// Replaces the internal empty digests used when a given depth doesn't contain a node.
    pub fn with_empty_subtrees<I>(mut self, hashes: I) -> Self
    where
        I: IntoIterator<Item = RpoDigest>,
    {
        self.replace_empty_subtrees(hashes.into_iter().collect());
        self
    /// Wrapper around [`SimpleSmt::with_leaves`] which inserts leaves at contiguous indices
    /// starting at index 0.
    pub fn with_contiguous_leaves(
        depth: u8,
        entries: impl IntoIterator<Item = Word>,
    ) -> Result<Self, MerkleError> {
        Self::with_leaves(
            depth,
            entries
                .into_iter()
                .enumerate()
                .map(|(idx, word)| (idx.try_into().expect("tree max depth is 2^8"), word)),
        )
    }
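
A hedged usage sketch, not from the diff (`int_to_leaf` is the test helper used further below): the two values land at indices 0 and 1 of a depth-2 tree.

    let tree = SimpleSmt::with_contiguous_leaves(2, [int_to_leaf(1), int_to_leaf(2)]).unwrap();
    assert_eq!(tree.get_leaf(0).unwrap(), int_to_leaf(1));
    assert_eq!(tree.get_leaf(1).unwrap(), int_to_leaf(2));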

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the root of this Merkle tree.
    pub const fn root(&self) -> Word {
    pub const fn root(&self) -> RpoDigest {
        self.root
    }

@@ -111,40 +132,46 @@ impl SimpleSmt {
        self.depth
    }

    // PROVIDERS
    // --------------------------------------------------------------------------------------------

    /// Returns the set count of the keys of the leaves.
    pub fn leaves_count(&self) -> usize {
        self.leaves.len()
    }

    /// Returns a node at the specified index.
    ///
    /// # Errors
    /// Returns an error if:
    /// * The specified depth is greater than the depth of the tree.
    pub fn get_node(&self, index: NodeIndex) -> Result<Word, MerkleError> {
    /// Returns an error if the specified index has depth set to 0 or the depth is greater than
    /// the depth of this Merkle tree.
    pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
        if index.is_root() {
            Err(MerkleError::DepthTooSmall(index.depth()))
        } else if index.depth() > self.depth() {
            Err(MerkleError::DepthTooBig(index.depth() as u64))
        } else if index.depth() == self.depth() {
            self.get_leaf_node(index.value())
                .or_else(|| self.empty_hashes.get(index.depth() as usize).copied().map(Word::from))
                .ok_or(MerkleError::NodeNotInSet(index.value()))
            // the lookup in empty_hashes could fail only if empty_hashes were not built correctly
            // by the constructor as we check the depth of the lookup above.
            let leaf_pos = index.value();
            let leaf = match self.get_leaf_node(leaf_pos) {
                Some(word) => word.into(),
                None => *EmptySubtreeRoots::entry(self.depth, index.depth()),
            };
            Ok(leaf)
        } else {
            let branch_node = self.get_branch_node(&index);
            Ok(Rpo256::merge(&[branch_node.left, branch_node.right]).into())
            Ok(self.get_branch_node(&index).parent())
        }
    }

    /// Returns a Merkle path from the node at the specified key to the root. The node itself is
    /// not included in the path.
    /// Returns a value of the leaf at the specified index.
    ///
    /// # Errors
    /// Returns an error if:
    /// * The specified depth is greater than the depth of the tree.
    /// Returns an error if the index is greater than the maximum tree capacity, that is 2^{depth}.
    pub fn get_leaf(&self, index: u64) -> Result<Word, MerkleError> {
        let index = NodeIndex::new(self.depth, index)?;
        Ok(self.get_node(index)?.into())
    }

    /// Returns a Merkle path from the node at the specified index to the root.
    ///
    /// The node itself is not included in the path.
    ///
    /// # Errors
    /// Returns an error if the specified index has depth set to 0 or the depth is greater than
    /// the depth of this Merkle tree.
    pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
        if index.is_root() {
            return Err(MerkleError::DepthTooSmall(index.depth()));

@@ -158,56 +185,135 @@ impl SimpleSmt {
            index.move_up();
            let BranchNode { left, right } = self.get_branch_node(&index);
            let value = if is_right { left } else { right };
            path.push(*value);
            path.push(value);
        }
        Ok(path.into())
        Ok(MerklePath::new(path))
    }

    /// Return a Merkle path from the leaf at the specified key to the root. The leaf itself is not
    /// included in the path.
    /// Return a Merkle path from the leaf at the specified index to the root.
    ///
    /// The leaf itself is not included in the path.
    ///
    /// # Errors
    /// Returns an error if:
    /// * The specified key does not exist as a leaf node.
    pub fn get_leaf_path(&self, key: u64) -> Result<MerklePath, MerkleError> {
        let index = NodeIndex::new(self.depth(), key)?;
    /// Returns an error if the index is greater than the maximum tree capacity, that is 2^{depth}.
    pub fn get_leaf_path(&self, index: u64) -> Result<MerklePath, MerkleError> {
        let index = NodeIndex::new(self.depth(), index)?;
        self.get_path(index)
    }

    /// Iterator over the inner nodes of the [SimpleSmt].
    // ITERATORS
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over the leaves of this [SimpleSmt].
    pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
        self.leaves.iter().map(|(i, w)| (*i, w))
    }

    /// Returns an iterator over the inner nodes of this Merkle tree.
    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
        self.branches.values().map(|e| InnerNodeInfo {
            value: e.parent().into(),
            left: e.left.into(),
            right: e.right.into(),
            value: e.parent(),
            left: e.left,
            right: e.right,
        })
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Replaces the leaf located at the specified key, and recomputes hashes by walking up the
    /// tree.
    /// Updates value of the leaf at the specified index returning the old leaf value.
    ///
    /// This also recomputes all hashes between the leaf and the root, updating the root itself.
    ///
    /// # Errors
    /// Returns an error if the specified key is not a valid leaf index for this tree.
    pub fn update_leaf(&mut self, key: u64, value: Word) -> Result<(), MerkleError> {
        let index = NodeIndex::new(self.depth(), key)?;
        if !self.check_leaf_node_exists(key) {
            return Err(MerkleError::NodeNotInSet(index.value()));
        }
        self.insert_leaf(key, value)?;
    /// Returns an error if the index is greater than the maximum tree capacity, that is 2^{depth}.
    pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<Word, MerkleError> {
        // validate the index before modifying the structure
        let idx = NodeIndex::new(self.depth(), index)?;

        Ok(())
        let old_value = self.insert_leaf_node(index, value).unwrap_or(Self::EMPTY_VALUE);

        // if the old value and new value are the same, there is nothing to update
        if value == old_value {
            return Ok(value);
        }

        self.recompute_nodes_from_index_to_root(idx, RpoDigest::from(value));

        Ok(old_value)
    }
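
A hedged sketch of the changed return value, not from the diff (`int_to_leaf` is the test helper used further below): updating a previously empty slot returns `EMPTY_VALUE` as the old value.

    let mut smt = SimpleSmt::new(3).unwrap();
    let old = smt.update_leaf(5, int_to_leaf(42)).unwrap();
    assert_eq!(old, SimpleSmt::EMPTY_VALUE);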

    /// Inserts a leaf located at the specified key, and recomputes hashes by walking up the tree
    pub fn insert_leaf(&mut self, key: u64, value: Word) -> Result<(), MerkleError> {
        self.insert_leaf_node(key, value);
    /// Inserts a subtree at the specified index. The depth at which the subtree is inserted is
    /// computed as `self.depth() - subtree.depth()`.
    ///
    /// Returns the new root.
    pub fn set_subtree(
        &mut self,
        subtree_insertion_index: u64,
        subtree: SimpleSmt,
    ) -> Result<RpoDigest, MerkleError> {
        if subtree.depth() > self.depth() {
            return Err(MerkleError::InvalidSubtreeDepth {
                subtree_depth: subtree.depth(),
                tree_depth: self.depth(),
            });
        }

        // TODO consider using a map `index |-> word` instead of `index |-> (word, word)`
        let mut index = NodeIndex::new(self.depth(), key)?;
        let mut value = RpoDigest::from(value);
        // Verify that `subtree_insertion_index` is valid.
        let subtree_root_insertion_depth = self.depth() - subtree.depth();
        let subtree_root_index =
            NodeIndex::new(subtree_root_insertion_depth, subtree_insertion_index)?;

        // add leaves
        // --------------

        // The subtree's leaf indices live in their own context - i.e. a subtree of depth `d`. If we
        // insert the subtree at `subtree_insertion_index = 0`, then the subtree leaf indices are
        // valid as they are. However, consider what happens when we insert at
        // `subtree_insertion_index = 1`. The first leaf of our subtree now will have index `2^d`;
        // you can see it as there's a full subtree sitting on its left. In general, for
        // `subtree_insertion_index = i`, there are `i` subtrees sitting before the subtree we want
        // to insert, so we need to adjust all its leaves by `i * 2^d`.
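        // (Worked example, added for clarity and not part of the diff: inserting a
        // depth-2 subtree at `subtree_insertion_index = 3` shifts its leaf 1 to
        // 3 * 2^2 + 1 = 13.)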
        let leaf_index_shift: u64 = subtree_insertion_index * 2_u64.pow(subtree.depth().into());
        for (subtree_leaf_idx, leaf_value) in subtree.leaves() {
            let new_leaf_idx = leaf_index_shift + subtree_leaf_idx;
            debug_assert!(new_leaf_idx < 2_u64.pow(self.depth().into()));

            self.insert_leaf_node(new_leaf_idx, *leaf_value);
        }

        // add subtree's branch nodes (which includes the root)
        // --------------
        for (branch_idx, branch_node) in subtree.branches {
            let new_branch_idx = {
                let new_depth = subtree_root_insertion_depth + branch_idx.depth();
                let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into())
                    + branch_idx.value();

                NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid")
            };

            self.branches.insert(new_branch_idx, branch_node);
        }

        // recompute nodes starting from subtree root
        // --------------
        self.recompute_nodes_from_index_to_root(subtree_root_index, subtree.root);

        Ok(self.root)
    }

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Recomputes the branch nodes (including the root) from `index` all the way to the root.
    /// `node_hash_at_index` is the hash of the node stored at index.
    fn recompute_nodes_from_index_to_root(
        &mut self,
        mut index: NodeIndex,
        node_hash_at_index: RpoDigest,
    ) {
        let mut value = node_hash_at_index;
        for _ in 0..index.depth() {
            let is_right = index.is_value_odd();
            index.move_up();

@@ -216,36 +322,21 @@ impl SimpleSmt {
            self.insert_branch_node(index, left, right);
            value = Rpo256::merge(&[left, right]);
        }
        self.root = value.into();
        Ok(())
    }

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    fn replace_empty_subtrees(&mut self, hashes: Vec<RpoDigest>) {
        self.empty_hashes = hashes;
    }

    fn check_leaf_node_exists(&self, key: u64) -> bool {
        self.leaves.contains_key(&key)
        self.root = value;
    }

    fn get_leaf_node(&self, key: u64) -> Option<Word> {
        self.leaves.get(&key).copied()
    }

    fn insert_leaf_node(&mut self, key: u64, node: Word) {
        self.leaves.insert(key, node);
    fn insert_leaf_node(&mut self, key: u64, node: Word) -> Option<Word> {
        self.leaves.insert(key, node)
    }

    fn get_branch_node(&self, index: &NodeIndex) -> BranchNode {
        self.branches.get(index).cloned().unwrap_or_else(|| {
            let node = self.empty_hashes[index.depth() as usize + 1];
            BranchNode {
                left: node,
                right: node,
            }
            let node = EmptySubtreeRoots::entry(self.depth, index.depth() + 1);
            BranchNode { left: *node, right: *node }
        })
    }

@@ -254,3 +345,45 @@ impl SimpleSmt {
        self.branches.insert(index, branch);
    }
}

// BRANCH NODE
// ================================================================================================

#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
struct BranchNode {
    left: RpoDigest,
    right: RpoDigest,
}

impl BranchNode {
    fn parent(&self) -> RpoDigest {
        Rpo256::merge(&[self.left, self.right])
    }
}

// TRY APPLY DIFF
// ================================================================================================
impl TryApplyDiff<RpoDigest, StoreNode> for SimpleSmt {
    type Error = MerkleError;
    type DiffType = MerkleTreeDelta;

    fn try_apply(&mut self, diff: MerkleTreeDelta) -> Result<(), MerkleError> {
        if diff.depth() != self.depth() {
            return Err(MerkleError::InvalidDepth {
                expected: self.depth(),
                provided: diff.depth(),
            });
        }

        for slot in diff.cleared_slots() {
            self.update_leaf(*slot, Self::EMPTY_VALUE)?;
        }

        for (slot, value) in diff.updated_slots() {
            self.update_leaf(*slot, *value)?;
        }

        Ok(())
    }
}

@@ -1,16 +1,21 @@
use super::{
    super::{int_to_node, InnerNodeInfo, MerkleError, MerkleTree, RpoDigest, SimpleSmt},
    NodeIndex, Rpo256, Vec, Word,
    super::{InnerNodeInfo, MerkleError, MerkleTree, RpoDigest, SimpleSmt, EMPTY_WORD},
    NodeIndex, Rpo256, Vec,
};
use proptest::prelude::*;
use rand_utils::prng_array;
use crate::{
    merkle::{digests_to_words, int_to_leaf, int_to_node, EmptySubtreeRoots},
    Word,
};

// TEST DATA
// ================================================================================================

const KEYS4: [u64; 4] = [0, 1, 2, 3];
const KEYS8: [u64; 8] = [0, 1, 2, 3, 4, 5, 6, 7];

const VALUES4: [Word; 4] = [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];
const VALUES4: [RpoDigest; 4] = [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];

const VALUES8: [Word; 8] = [
const VALUES8: [RpoDigest; 8] = [
    int_to_node(1),
    int_to_node(2),
    int_to_node(3),

@@ -21,27 +26,19 @@ const VALUES8: [Word; 8] = [
    int_to_node(8),
];

const ZERO_VALUES8: [Word; 8] = [int_to_node(0); 8];
const ZERO_VALUES8: [Word; 8] = [int_to_leaf(0); 8];

// TESTS
// ================================================================================================

#[test]
fn build_empty_tree() {
    // tree of depth 3
    let smt = SimpleSmt::new(3).unwrap();
    let mt = MerkleTree::new(ZERO_VALUES8.to_vec()).unwrap();
    assert_eq!(mt.root(), smt.root());
}

#[test]
fn empty_digests_are_consistent() {
    let depth = 5;
    let root = SimpleSmt::new(depth).unwrap().root();
    let computed: [RpoDigest; 2] = (0..depth).fold([Default::default(); 2], |state, _| {
        let digest = Rpo256::merge(&state);
        [digest; 2]
    });

    assert_eq!(Word::from(computed[0]), root);
}

#[test]
fn build_sparse_tree() {
    let mut smt = SimpleSmt::new(3).unwrap();

@@ -49,82 +46,80 @@ fn build_sparse_tree() {

    // insert single value
    let key = 6;
    let new_node = int_to_node(7);
    let new_node = int_to_leaf(7);
    values[key as usize] = new_node;
    smt.insert_leaf(key, new_node).expect("Failed to insert leaf");
    let old_value = smt.update_leaf(key, new_node).expect("Failed to update leaf");
    let mt2 = MerkleTree::new(values.clone()).unwrap();
    assert_eq!(mt2.root(), smt.root());
    assert_eq!(
        mt2.get_path(NodeIndex::make(3, 6)).unwrap(),
        smt.get_path(NodeIndex::make(3, 6)).unwrap()
    );
    assert_eq!(old_value, EMPTY_WORD);

    // insert second value at distinct leaf branch
    let key = 2;
    let new_node = int_to_node(3);
    let new_node = int_to_leaf(3);
    values[key as usize] = new_node;
    smt.insert_leaf(key, new_node).expect("Failed to insert leaf");
    let old_value = smt.update_leaf(key, new_node).expect("Failed to update leaf");
    let mt3 = MerkleTree::new(values).unwrap();
    assert_eq!(mt3.root(), smt.root());
    assert_eq!(
        mt3.get_path(NodeIndex::make(3, 2)).unwrap(),
        smt.get_path(NodeIndex::make(3, 2)).unwrap()
    );
    assert_eq!(old_value, EMPTY_WORD);
}

/// Tests that [`SimpleSmt::with_contiguous_leaves`] works as expected
#[test]
fn build_contiguous_tree() {
    let tree_with_leaves = SimpleSmt::with_leaves(
        2,
        [0, 1, 2, 3].into_iter().zip(digests_to_words(&VALUES4).into_iter()),
    )
    .unwrap();

    let tree_with_contiguous_leaves =
        SimpleSmt::with_contiguous_leaves(2, digests_to_words(&VALUES4).into_iter()).unwrap();

    assert_eq!(tree_with_leaves, tree_with_contiguous_leaves);
}

#[test]
fn build_full_tree() {
    let tree = SimpleSmt::new(2)
        .unwrap()
        .with_leaves(KEYS4.into_iter().zip(VALUES4.into_iter()))
        .unwrap();
fn test_depth2_tree() {
    let tree =
        SimpleSmt::with_leaves(2, KEYS4.into_iter().zip(digests_to_words(&VALUES4).into_iter()))
            .unwrap();

    // check internal structure
    let (root, node2, node3) = compute_internal_nodes();
    assert_eq!(root, tree.root());
    assert_eq!(node2, tree.get_node(NodeIndex::make(1, 0)).unwrap());
    assert_eq!(node3, tree.get_node(NodeIndex::make(1, 1)).unwrap());
}

#[test]
fn get_values() {
    let tree = SimpleSmt::new(2)
        .unwrap()
        .with_leaves(KEYS4.into_iter().zip(VALUES4.into_iter()))
        .unwrap();

    // check depth 2
    // check get_node()
    assert_eq!(VALUES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap());
    assert_eq!(VALUES4[1], tree.get_node(NodeIndex::make(2, 1)).unwrap());
    assert_eq!(VALUES4[2], tree.get_node(NodeIndex::make(2, 2)).unwrap());
    assert_eq!(VALUES4[3], tree.get_node(NodeIndex::make(2, 3)).unwrap());
}

#[test]
fn get_path() {
    let tree = SimpleSmt::new(2)
        .unwrap()
        .with_leaves(KEYS4.into_iter().zip(VALUES4.into_iter()))
        .unwrap();

    let (_, node2, node3) = compute_internal_nodes();

    // check depth 2
    // check get_path(): depth 2
    assert_eq!(vec![VALUES4[1], node3], *tree.get_path(NodeIndex::make(2, 0)).unwrap());
    assert_eq!(vec![VALUES4[0], node3], *tree.get_path(NodeIndex::make(2, 1)).unwrap());
    assert_eq!(vec![VALUES4[3], node2], *tree.get_path(NodeIndex::make(2, 2)).unwrap());
    assert_eq!(vec![VALUES4[2], node2], *tree.get_path(NodeIndex::make(2, 3)).unwrap());

    // check depth 1
    // check get_path(): depth 1
    assert_eq!(vec![node3], *tree.get_path(NodeIndex::make(1, 0)).unwrap());
    assert_eq!(vec![node2], *tree.get_path(NodeIndex::make(1, 1)).unwrap());
}

#[test]
fn test_parent_node_iterator() -> Result<(), MerkleError> {
    let tree = SimpleSmt::new(2)
        .unwrap()
        .with_leaves(KEYS4.into_iter().zip(VALUES4.into_iter()))
        .unwrap();
fn test_inner_node_iterator() -> Result<(), MerkleError> {
    let tree =
        SimpleSmt::with_leaves(2, KEYS4.into_iter().zip(digests_to_words(&VALUES4).into_iter()))
            .unwrap();

    // check depth 2
    assert_eq!(VALUES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap());

@@ -143,21 +138,9 @@ fn test_parent_node_iterator() -> Result<(), MerkleError> {

    let nodes: Vec<InnerNodeInfo> = tree.inner_nodes().collect();
    let expected = vec![
        InnerNodeInfo {
            value: root.into(),
            left: l1n0.into(),
            right: l1n1.into(),
        },
        InnerNodeInfo {
            value: l1n0.into(),
            left: l2n0.into(),
            right: l2n1.into(),
        },
        InnerNodeInfo {
            value: l1n1.into(),
            left: l2n2.into(),
            right: l2n3.into(),
        },
        InnerNodeInfo { value: root, left: l1n0, right: l1n1 },
        InnerNodeInfo { value: l1n0, left: l2n0, right: l2n1 },
        InnerNodeInfo { value: l1n1, left: l2n2, right: l2n3 },
    ];
    assert_eq!(nodes, expected);

@@ -166,35 +149,30 @@ fn test_parent_node_iterator() -> Result<(), MerkleError> {

#[test]
fn update_leaf() {
    let mut tree = SimpleSmt::new(3)
        .unwrap()
        .with_leaves(KEYS8.into_iter().zip(VALUES8.into_iter()))
        .unwrap();
    let mut tree =
        SimpleSmt::with_leaves(3, KEYS8.into_iter().zip(digests_to_words(&VALUES8).into_iter()))
            .unwrap();

    // update one value
    let key = 3;
    let new_node = int_to_node(9);
    let mut expected_values = VALUES8.to_vec();
    let new_node = int_to_leaf(9);
    let mut expected_values = digests_to_words(&VALUES8);
    expected_values[key] = new_node;
    let expected_tree = SimpleSmt::new(3)
        .unwrap()
        .with_leaves(KEYS8.into_iter().zip(expected_values.clone().into_iter()))
        .unwrap();
    let expected_tree = MerkleTree::new(expected_values.clone()).unwrap();

    tree.update_leaf(key as u64, new_node).unwrap();
    assert_eq!(expected_tree.root, tree.root);
    let old_leaf = tree.update_leaf(key as u64, new_node).unwrap();
    assert_eq!(expected_tree.root(), tree.root);
    assert_eq!(old_leaf, *VALUES8[key]);

    // update another value
    let key = 6;
    let new_node = int_to_node(10);
    let new_node = int_to_leaf(10);
    expected_values[key] = new_node;
    let expected_tree = SimpleSmt::new(3)
        .unwrap()
        .with_leaves(KEYS8.into_iter().zip(expected_values.into_iter()))
        .unwrap();
    let expected_tree = MerkleTree::new(expected_values.clone()).unwrap();

    tree.update_leaf(key as u64, new_node).unwrap();
    assert_eq!(expected_tree.root, tree.root);
    let old_leaf = tree.update_leaf(key as u64, new_node).unwrap();
    assert_eq!(expected_tree.root(), tree.root);
    assert_eq!(old_leaf, *VALUES8[key]);
}

#[test]

@@ -207,34 +185,34 @@ fn small_tree_opening_is_consistent() {
    // / \  / \  / \  / \
    // a b  0 0  c 0  0 d

    let z = Word::from(RpoDigest::default());
    let z = EMPTY_WORD;

    let a = Word::from(Rpo256::merge(&[z.into(); 2]));
    let b = Word::from(Rpo256::merge(&[a.into(); 2]));
    let c = Word::from(Rpo256::merge(&[b.into(); 2]));
    let d = Word::from(Rpo256::merge(&[c.into(); 2]));

    let e = Word::from(Rpo256::merge(&[a.into(), b.into()]));
    let f = Word::from(Rpo256::merge(&[z.into(), z.into()]));
    let g = Word::from(Rpo256::merge(&[c.into(), z.into()]));
    let h = Word::from(Rpo256::merge(&[z.into(), d.into()]));
    let e = Rpo256::merge(&[a.into(), b.into()]);
    let f = Rpo256::merge(&[z.into(), z.into()]);
    let g = Rpo256::merge(&[c.into(), z.into()]);
    let h = Rpo256::merge(&[z.into(), d.into()]);

    let i = Word::from(Rpo256::merge(&[e.into(), f.into()]));
    let j = Word::from(Rpo256::merge(&[g.into(), h.into()]));
    let i = Rpo256::merge(&[e, f]);
    let j = Rpo256::merge(&[g, h]);

    let k = Word::from(Rpo256::merge(&[i.into(), j.into()]));
    let k = Rpo256::merge(&[i, j]);

    let depth = 3;
    let entries = vec![(0, a), (1, b), (4, c), (7, d)];
    let tree = SimpleSmt::new(depth).unwrap().with_leaves(entries).unwrap();
    let tree = SimpleSmt::with_leaves(depth, entries).unwrap();

    assert_eq!(tree.root(), Word::from(k));
    assert_eq!(tree.root(), k);

    let cases: Vec<(u8, u64, Vec<Word>)> = vec![
        (3, 0, vec![b, f, j]),
        (3, 1, vec![a, f, j]),
        (3, 4, vec![z, h, i]),
        (3, 7, vec![z, g, i]),
    let cases: Vec<(u8, u64, Vec<RpoDigest>)> = vec![
        (3, 0, vec![b.into(), f, j]),
        (3, 1, vec![a.into(), f, j]),
        (3, 4, vec![z.into(), h, i]),
        (3, 7, vec![z.into(), g, i]),
        (2, 0, vec![f, j]),
        (2, 1, vec![e, j]),
        (2, 2, vec![h, i]),

@@ -250,65 +228,269 @@ fn small_tree_opening_is_consistent() {
    }
}

proptest! {
    #[test]
    fn arbitrary_openings_single_leaf(
        depth in SimpleSmt::MIN_DEPTH..SimpleSmt::MAX_DEPTH,
        key in prop::num::u64::ANY,
        leaf in prop::num::u64::ANY,
    ) {
        let mut tree = SimpleSmt::new(depth).unwrap();
#[test]
fn test_simplesmt_fail_on_duplicates() {
    let values = [
        // same key, same value
        (int_to_leaf(1), int_to_leaf(1)),
        // same key, different values
        (int_to_leaf(1), int_to_leaf(2)),
        // same key, set to zero
        (EMPTY_WORD, int_to_leaf(1)),
        // same key, re-set to zero
        (int_to_leaf(1), EMPTY_WORD),
        // same key, set to zero twice
        (EMPTY_WORD, EMPTY_WORD),
    ];

        let key = key % (1 << depth as u64);
        let leaf = int_to_node(leaf);
    for (first, second) in values.iter() {
        // consecutive
        let entries = [(1, *first), (1, *second)];
        let smt = SimpleSmt::with_leaves(64, entries);
        assert_eq!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));

        tree.insert_leaf(key, leaf.into()).unwrap();
        tree.get_leaf_path(key).unwrap();

        // traverse to root, fetching all paths
        for d in 1..depth {
            let k = key >> (depth - d);
            tree.get_path(NodeIndex::make(d, k)).unwrap();
        }
        // not consecutive
        let entries = [(1, *first), (5, int_to_leaf(5)), (1, *second)];
        let smt = SimpleSmt::with_leaves(64, entries);
        assert_eq!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));
    }
}

    #[test]
    fn arbitrary_openings_multiple_leaves(
        depth in SimpleSmt::MIN_DEPTH..SimpleSmt::MAX_DEPTH,
        count in 2u8..10u8,
        ref seed in any::<[u8; 32]>()
    ) {
        let mut tree = SimpleSmt::new(depth).unwrap();
        let mut seed = *seed;
        let leaves = (1 << depth) - 1;
#[test]
fn with_no_duplicates_empty_node() {
    let entries = [(1_u64, int_to_leaf(0)), (5, int_to_leaf(2))];
    let smt = SimpleSmt::with_leaves(64, entries);
    assert!(smt.is_ok());
}

        for _ in 0..count {
            seed = prng_array(seed);
#[test]
fn test_simplesmt_update_nonexisting_leaf_with_zero() {
    // TESTING WITH EMPTY WORD
    // --------------------------------------------------------------------------------------------

            let mut key = [0u8; 8];
            let mut leaf = [0u8; 8];
    // Depth 1 has 2 leaves. Position is 0-indexed, position 2 doesn't exist.
    let mut smt = SimpleSmt::new(1).unwrap();
    let result = smt.update_leaf(2, EMPTY_WORD);
    assert!(!smt.leaves.contains_key(&2));
    assert!(result.is_err());

            key.copy_from_slice(&seed[..8]);
            leaf.copy_from_slice(&seed[8..16]);
    // Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
    let mut smt = SimpleSmt::new(2).unwrap();
    let result = smt.update_leaf(4, EMPTY_WORD);
    assert!(!smt.leaves.contains_key(&4));
    assert!(result.is_err());

            let key = u64::from_le_bytes(key);
            let key = key % leaves;
            let leaf = u64::from_le_bytes(leaf);
            let leaf = int_to_node(leaf);
    // Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
    let mut smt = SimpleSmt::new(3).unwrap();
    let result = smt.update_leaf(8, EMPTY_WORD);
    assert!(!smt.leaves.contains_key(&8));
    assert!(result.is_err());

            tree.insert_leaf(key, leaf).unwrap();
            tree.get_leaf_path(key).unwrap();
        }
    }
    // TESTING WITH A VALUE
    // --------------------------------------------------------------------------------------------
    let value = int_to_node(1);

    // Depth 1 has 2 leaves. Position is 0-indexed, position 2 doesn't exist.
    let mut smt = SimpleSmt::new(1).unwrap();
    let result = smt.update_leaf(2, *value);
    assert!(!smt.leaves.contains_key(&2));
    assert!(result.is_err());

    // Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
    let mut smt = SimpleSmt::new(2).unwrap();
    let result = smt.update_leaf(4, *value);
    assert!(!smt.leaves.contains_key(&4));
    assert!(result.is_err());

    // Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
    let mut smt = SimpleSmt::new(3).unwrap();
    let result = smt.update_leaf(8, *value);
    assert!(!smt.leaves.contains_key(&8));
    assert!(result.is_err());
}

#[test]
fn test_simplesmt_with_leaves_nonexisting_leaf() {
    // TESTING WITH EMPTY WORD
    // --------------------------------------------------------------------------------------------

    // Depth 1 has 2 leaves. Position is 0-indexed, position 2 doesn't exist.
    let leaves = [(2, EMPTY_WORD)];
    let result = SimpleSmt::with_leaves(1, leaves);
    assert!(result.is_err());

    // Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
    let leaves = [(4, EMPTY_WORD)];
    let result = SimpleSmt::with_leaves(2, leaves);
    assert!(result.is_err());

    // Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
    let leaves = [(8, EMPTY_WORD)];
    let result = SimpleSmt::with_leaves(3, leaves);
    assert!(result.is_err());

    // TESTING WITH A VALUE
    // --------------------------------------------------------------------------------------------
    let value = int_to_node(1);

    // Depth 1 has 2 leaves. Position is 0-indexed, position 2 doesn't exist.
    let leaves = [(2, *value)];
    let result = SimpleSmt::with_leaves(1, leaves);
    assert!(result.is_err());

    // Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
    let leaves = [(4, *value)];
    let result = SimpleSmt::with_leaves(2, leaves);
    assert!(result.is_err());

    // Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
    let leaves = [(8, *value)];
    let result = SimpleSmt::with_leaves(3, leaves);
    assert!(result.is_err());
}

#[test]
fn test_simplesmt_set_subtree() {
    // Final Tree:
    //
    // ____k____
    // /         \
    // _i_       _j_
    // /   \     /   \
    // e     f   g     h
    // / \   / \ / \   / \
    // a b   0 0 c 0   0 d

    let z = EMPTY_WORD;

    let a = Word::from(Rpo256::merge(&[z.into(); 2]));
    let b = Word::from(Rpo256::merge(&[a.into(); 2]));
    let c = Word::from(Rpo256::merge(&[b.into(); 2]));
    let d = Word::from(Rpo256::merge(&[c.into(); 2]));

    let e = Rpo256::merge(&[a.into(), b.into()]);
    let f = Rpo256::merge(&[z.into(), z.into()]);
    let g = Rpo256::merge(&[c.into(), z.into()]);
    let h = Rpo256::merge(&[z.into(), d.into()]);

    let i = Rpo256::merge(&[e, f]);
    let j = Rpo256::merge(&[g, h]);

    let k = Rpo256::merge(&[i, j]);

    // subtree:
    //   g
    //  / \
    // c   0
    let subtree = {
        let depth = 1;
        let entries = vec![(0, c)];
        SimpleSmt::with_leaves(depth, entries).unwrap()
    };

    // insert subtree
    let tree = {
        let depth = 3;
        let entries = vec![(0, a), (1, b), (7, d)];
        let mut tree = SimpleSmt::with_leaves(depth, entries).unwrap();

        tree.set_subtree(2, subtree).unwrap();

        tree
    };

    assert_eq!(tree.root(), k);
    assert_eq!(tree.get_leaf(4).unwrap(), c);
    assert_eq!(tree.get_branch_node(&NodeIndex::new_unchecked(2, 2)).parent(), g);
}

/// Ensures that an invalid input node index into `set_subtree()` incurs no mutation of the tree
#[test]
fn test_simplesmt_set_subtree_unchanged_for_wrong_index() {
    // Final Tree:
    //
    // ____k____
    // /         \
    // _i_       _j_
    // /   \     /   \
    // e     f   g     h
    // / \   / \ / \   / \
    // a b   0 0 c 0   0 d

    let z = EMPTY_WORD;

    let a = Word::from(Rpo256::merge(&[z.into(); 2]));
    let b = Word::from(Rpo256::merge(&[a.into(); 2]));
    let c = Word::from(Rpo256::merge(&[b.into(); 2]));
    let d = Word::from(Rpo256::merge(&[c.into(); 2]));

    // subtree:
    //   g
    //  / \
    // c   0
    let subtree = {
        let depth = 1;
        let entries = vec![(0, c)];
        SimpleSmt::with_leaves(depth, entries).unwrap()
    };

    let mut tree = {
        let depth = 3;
        let entries = vec![(0, a), (1, b), (7, d)];
        SimpleSmt::with_leaves(depth, entries).unwrap()
    };
    let tree_root_before_insertion = tree.root();

    // insert subtree
    assert!(tree.set_subtree(500, subtree).is_err());

    assert_eq!(tree.root(), tree_root_before_insertion);
}

/// We insert an empty subtree that has the same depth as the original tree
#[test]
fn test_simplesmt_set_subtree_entire_tree() {
    // Initial Tree:
    //
    // ____k____
    // /         \
    // _i_       _j_
    // /   \     /   \
    // e     f   g     h
    // / \   / \ / \   / \
    // a b   0 0 c 0   0 d

    let z = EMPTY_WORD;

    let a = Word::from(Rpo256::merge(&[z.into(); 2]));
    let b = Word::from(Rpo256::merge(&[a.into(); 2]));
    let c = Word::from(Rpo256::merge(&[b.into(); 2]));
    let d = Word::from(Rpo256::merge(&[c.into(); 2]));

    let depth = 3;

    // subtree: E3
    let subtree = { SimpleSmt::with_leaves(depth, Vec::new()).unwrap() };
    assert_eq!(subtree.root(), *EmptySubtreeRoots::entry(depth, 0));

    // insert subtree
    let mut tree = {
        let entries = vec![(0, a), (1, b), (4, c), (7, d)];
        SimpleSmt::with_leaves(depth, entries).unwrap()
    };

    tree.set_subtree(0, subtree).unwrap();

    assert_eq!(tree.root(), *EmptySubtreeRoots::entry(depth, 0));
}

// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------

fn compute_internal_nodes() -> (Word, Word, Word) {
    let node2 = Rpo256::hash_elements(&[VALUES4[0], VALUES4[1]].concat());
    let node3 = Rpo256::hash_elements(&[VALUES4[2], VALUES4[3]].concat());
fn compute_internal_nodes() -> (RpoDigest, RpoDigest, RpoDigest) {
    let node2 = Rpo256::merge(&[VALUES4[0], VALUES4[1]]);
    let node3 = Rpo256::merge(&[VALUES4[2], VALUES4[3]]);
    let root = Rpo256::merge(&[node2, node3]);

    (root.into(), node2.into(), node3.into())
    (root, node2, node3)
}

@@ -1,20 +1,31 @@
use super::mmr::Mmr;
use super::{
    BTreeMap, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, MerklePathSet, MerkleTree,
    NodeIndex, RootPath, Rpo256, RpoDigest, SimpleSmt, ValuePath, Vec, Word,
    mmr::Mmr, BTreeMap, EmptySubtreeRoots, InnerNodeInfo, KvMap, MerkleError, MerklePath,
    MerkleStoreDelta, MerkleTree, NodeIndex, PartialMerkleTree, RecordingMap, RootPath, Rpo256,
    RpoDigest, SimpleSmt, TieredSmt, TryApplyDiff, ValuePath, Vec, EMPTY_WORD,
};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
use core::borrow::Borrow;

#[cfg(test)]
mod tests;

// MERKLE STORE
// ================================================================================================

/// A default [MerkleStore] which uses a simple [BTreeMap] as the backing storage.
pub type DefaultMerkleStore = MerkleStore<BTreeMap<RpoDigest, StoreNode>>;

/// A [MerkleStore] with recording capabilities which uses [RecordingMap] as the backing storage.
pub type RecordingMerkleStore = MerkleStore<RecordingMap<RpoDigest, StoreNode>>;

#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
pub struct Node {
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct StoreNode {
    left: RpoDigest,
    right: RpoDigest,
}

/// An in-memory data store for Merkle-lized data.
/// An in-memory data store for Merkelized data.
///
/// This is an in-memory data store for Merkle trees; this store allows all the nodes of multiple
/// trees to live as long as necessary and without duplication, which allows the implementation of

@@ -42,7 +53,7 @@ pub struct Node {
/// # let T1 = MerkleTree::new([A, B, C, D, E, F, G, H1].to_vec()).expect("even number of leaves provided");
/// # let ROOT0 = T0.root();
/// # let ROOT1 = T1.root();
/// let mut store = MerkleStore::new();
/// let mut store: MerkleStore = MerkleStore::new();
///
/// // the store is initialized with the SMT empty nodes
/// assert_eq!(store.num_internal_nodes(), 255);

@@ -51,9 +62,8 @@ pub struct Node {
/// let tree2 = MerkleTree::new(vec![A, B, C, D, E, F, G, H1]).unwrap();
///
/// // populates the store with two merkle trees, common nodes are shared
/// store
///     .extend(tree1.inner_nodes())
///     .extend(tree2.inner_nodes());
/// store.extend(tree1.inner_nodes());
/// store.extend(tree2.inner_nodes());
///
/// // every leaf except the last are the same
/// for i in 0..7 {

@@ -78,40 +88,25 @@ pub struct Node {
/// assert_eq!(store.num_internal_nodes() - 255, 10);
/// ```
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct MerkleStore {
    nodes: BTreeMap<RpoDigest, Node>,
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleStore<T: KvMap<RpoDigest, StoreNode> = BTreeMap<RpoDigest, StoreNode>> {
    nodes: T,
}
||||
|
||||
impl Default for MerkleStore {
|
||||
impl<T: KvMap<RpoDigest, StoreNode>> Default for MerkleStore<T> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl MerkleStore {
|
||||
impl<T: KvMap<RpoDigest, StoreNode>> MerkleStore<T> {
|
||||
// CONSTRUCTORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Creates an empty `MerkleStore` instance.
|
||||
pub fn new() -> MerkleStore {
|
||||
pub fn new() -> MerkleStore<T> {
|
||||
// pre-populate the store with the empty hashes
|
||||
let subtrees = EmptySubtreeRoots::empty_hashes(255);
|
||||
let nodes = subtrees
|
||||
.iter()
|
||||
.rev()
|
||||
.copied()
|
||||
.zip(subtrees.iter().rev().skip(1).copied())
|
||||
.map(|(child, parent)| {
|
||||
(
|
||||
parent,
|
||||
Node {
|
||||
left: child,
|
||||
right: child,
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let nodes = empty_hashes().into_iter().collect();
|
||||
MerkleStore { nodes }
|
||||
}
|
||||
|
||||
@@ -126,25 +121,24 @@ impl MerkleStore {
|
||||
/// Returns the node at `index` rooted on the tree `root`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// This method can return the following errors:
|
||||
/// - `RootNotInStore` if the `root` is not present in the store.
|
||||
/// - `NodeNotInStore` if a node needed to traverse from `root` to `index` is not present in the store.
|
||||
pub fn get_node(&self, root: Word, index: NodeIndex) -> Result<Word, MerkleError> {
|
||||
let mut hash: RpoDigest = root.into();
|
||||
/// - `NodeNotInStore` if a node needed to traverse from `root` to `index` is not present in
|
||||
/// the store.
|
||||
pub fn get_node(&self, root: RpoDigest, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
|
||||
let mut hash = root;
|
||||
|
||||
// corner case: check the root is in the store when called with index `NodeIndex::root()`
|
||||
self.nodes.get(&hash).ok_or(MerkleError::RootNotInStore(hash.into()))?;
|
||||
self.nodes.get(&hash).ok_or(MerkleError::RootNotInStore(hash))?;
|
||||
|
||||
for i in (0..index.depth()).rev() {
|
||||
let node =
|
||||
self.nodes.get(&hash).ok_or(MerkleError::NodeNotInStore(hash.into(), index))?;
|
||||
let node = self.nodes.get(&hash).ok_or(MerkleError::NodeNotInStore(hash, index))?;
|
||||
|
||||
let bit = (index.value() >> i) & 1;
|
||||
hash = if bit == 0 { node.left } else { node.right }
|
||||
}
|
||||
|
||||
Ok(hash.into())
|
||||
Ok(hash)
|
||||
}
|
||||
|
||||
/// Returns the node at the specified `index` and its opening to the `root`.
@@ -152,27 +146,26 @@ impl MerkleStore {
/// The path starts at the sibling of the target leaf.
///
/// # Errors
///
/// This method can return the following errors:
/// - `RootNotInStore` if the `root` is not present in the store.
/// - `NodeNotInStore` if a node needed to traverse from `root` to `index` is not present in the store.
pub fn get_path(&self, root: Word, index: NodeIndex) -> Result<ValuePath, MerkleError> {
let mut hash: RpoDigest = root.into();
/// - `NodeNotInStore` if a node needed to traverse from `root` to `index` is not present in
///   the store.
pub fn get_path(&self, root: RpoDigest, index: NodeIndex) -> Result<ValuePath, MerkleError> {
let mut hash = root;
let mut path = Vec::with_capacity(index.depth().into());

// corner case: check the root is in the store when called with index `NodeIndex::root()`
self.nodes.get(&hash).ok_or(MerkleError::RootNotInStore(hash.into()))?;
self.nodes.get(&hash).ok_or(MerkleError::RootNotInStore(hash))?;

for i in (0..index.depth()).rev() {
let node =
self.nodes.get(&hash).ok_or(MerkleError::NodeNotInStore(hash.into(), index))?;
let node = self.nodes.get(&hash).ok_or(MerkleError::NodeNotInStore(hash, index))?;

let bit = (index.value() >> i) & 1;
hash = if bit == 0 {
path.push(node.right.into());
path.push(node.right);
node.left
} else {
path.push(node.left.into());
path.push(node.left);
node.right
}
}
@@ -180,30 +173,27 @@ impl MerkleStore {
// the path is computed from root to leaf, so it must be reversed
path.reverse();

Ok(ValuePath {
value: hash.into(),
path: MerklePath::new(path),
})
Ok(ValuePath::new(hash, path))
}

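A hedged sketch of consuming the returned `ValuePath`; the `compute_root` call follows the shape used by the tests below, so treat its exact signature as an assumption:

// hedged sketch: fetch an opening and hash it back up to the root
let opening = store.get_path(root, NodeIndex::make(depth, idx))?;
let computed = opening.path.compute_root(idx, opening.value)?;
assert_eq!(computed, root);
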
/// Reconstructs a path from the root until a leaf or empty node and returns its depth.
// LEAF TRAVERSAL
// --------------------------------------------------------------------------------------------

/// Returns the depth of the first leaf or an empty node encountered while traversing the tree
/// from the specified root down according to the provided index.
///
/// The `tree_depth` parameter defines up to which depth the tree will be traversed, starting
/// from `root`. The maximum value the argument accepts is [u64::BITS].
///
/// The traversed path from leaf to root will start at the least significant bit of `index`,
/// and will be executed for `tree_depth` bits.
/// The `tree_depth` parameter specifies the depth of the tree rooted at `root`. The
/// maximum value the argument accepts is [u64::BITS].
///
/// # Errors
/// Will return an error if:
/// - The provided root is not found.
/// - The path from the root continues to a depth greater than `tree_depth`.
/// - The provided `tree_depth` is greater than `64.
/// - The provided `index` is not valid for a depth equivalent to `tree_depth`. For more
///   information, check [NodeIndex::new].
/// - The provided `tree_depth` is greater than 64.
/// - The provided `index` is not valid for a depth equivalent to `tree_depth`.
/// - No leaf or an empty node was found while traversing the tree down to `tree_depth`.
pub fn get_leaf_depth(
&self,
root: Word,
root: RpoDigest,
tree_depth: u8,
index: u64,
) -> Result<u8, MerkleError> {
@@ -213,25 +203,18 @@ impl MerkleStore {
}
NodeIndex::new(tree_depth, index)?;

// it's not illegal to have a maximum depth of `0`; we should just return the root in that
// case. this check will simplify the implementation as we could overflow bits for depth
// `0`.
if tree_depth == 0 {
return Ok(0);
}

// check if the root exists, providing the proper error report if it doesn't
let empty = EmptySubtreeRoots::empty_hashes(tree_depth);
let mut hash: RpoDigest = root.into();
let mut hash = root;
if !self.nodes.contains_key(&hash) {
return Err(MerkleError::RootNotInStore(hash.into()));
return Err(MerkleError::RootNotInStore(hash));
}

// we traverse from root to leaf, so the path is reversed
let mut path = (index << (64 - tree_depth)).reverse_bits();

// iterate every depth and reconstruct the path from root to leaf
for depth in 0..tree_depth {
for depth in 0..=tree_depth {
// we short-circuit if an empty node has been found
if hash == empty[depth as usize] {
return Ok(depth);
@@ -248,35 +231,150 @@ impl MerkleStore {
path >>= 1;
}

// at max depth assert it doesn't have sub-trees
if self.nodes.contains_key(&hash) {
return Err(MerkleError::DepthTooBig(tree_depth as u64 + 1));
// return an error because we exhausted the index but didn't find either a leaf or an
// empty node
Err(MerkleError::DepthTooBig(tree_depth as u64 + 1))
}

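A short, hypothetical probe of the traversal above; `root` and `key` stand in for arbitrary store contents:

// hedged sketch: ask how deep the first leaf (or empty node) sits under `root`
let depth = store.get_leaf_depth(root, 64, key)?;
assert!(depth <= 64);
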
/// Returns the index and value of a leaf node which is the only leaf node in a subtree defined
/// by the provided root. If the subtree contains zero or more than one leaf node, `None` is
/// returned.
///
/// The `tree_depth` parameter specifies the depth of the parent tree such that `root` is
/// located in this tree at `root_index`. The maximum value the argument accepts is
/// [u64::BITS].
///
/// # Errors
/// Will return an error if:
/// - The provided root is not found.
/// - The provided `tree_depth` is greater than 64.
/// - The provided `root_index` has depth greater than `tree_depth`.
/// - A lone node at depth `tree_depth` is not a leaf node.
pub fn find_lone_leaf(
&self,
root: RpoDigest,
root_index: NodeIndex,
tree_depth: u8,
) -> Result<Option<(NodeIndex, RpoDigest)>, MerkleError> {
// we set max depth at u64::BITS as this is the largest meaningful value for a 64-bit index
const MAX_DEPTH: u8 = u64::BITS as u8;
if tree_depth > MAX_DEPTH {
return Err(MerkleError::DepthTooBig(tree_depth as u64));
}
let empty = EmptySubtreeRoots::empty_hashes(MAX_DEPTH);

let mut node = root;
if !self.nodes.contains_key(&node) {
return Err(MerkleError::RootNotInStore(node));
}

// depleted bits; return max depth
Ok(tree_depth)
let mut index = root_index;
if index.depth() > tree_depth {
return Err(MerkleError::DepthTooBig(index.depth() as u64));
}

// traverse down following the path of single non-empty nodes; this works because if a
// node has two empty children it cannot contain a lone leaf. similarly if a node has
// two non-empty children it must contain at least two leaves.
for depth in index.depth()..tree_depth {
// if the node is a leaf, return; otherwise, examine the node's children
let children = match self.nodes.get(&node) {
Some(node) => node,
None => return Ok(Some((index, node))),
};

let empty_node = empty[depth as usize + 1];
node = if children.left != empty_node && children.right == empty_node {
index = index.left_child();
children.left
} else if children.left == empty_node && children.right != empty_node {
index = index.right_child();
children.right
} else {
return Ok(None);
};
}

// if we are here, we got to `tree_depth`; thus, either the current node is a leaf node,
// and so we return it, or it is an internal node, and then we return an error
if self.nodes.contains_key(&node) {
Err(MerkleError::DepthTooBig(tree_depth as u64 + 1))
} else {
Ok(Some((index, node)))
}
}

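A hedged sketch of a call; `subtree_root` and `subtree_index` are assumed to describe a node inside a depth-64 tree, as in the `find_lone_leaf` test below:

// hedged sketch: Some((index, leaf)) only if exactly one leaf lives under the subtree
if let Some((index, leaf)) = store.find_lone_leaf(subtree_root, subtree_index, 64)? {
// `index` locates the leaf at depth 64; `leaf` is its digest
let _ = (index, leaf);
}
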
// DATA EXTRACTORS
// --------------------------------------------------------------------------------------------

/// Returns a subset of this Merkle store such that the returned Merkle store contains all
/// nodes which are descendants of the specified roots.
///
/// The roots for which no descendants exist in this Merkle store are ignored.
pub fn subset<I, R>(&self, roots: I) -> MerkleStore<T>
where
I: Iterator<Item = R>,
R: Borrow<RpoDigest>,
{
let mut store = MerkleStore::new();
for root in roots {
let root = *root.borrow();
store.clone_tree_from(root, self);
}
store
}

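A minimal sketch of carving out only the trees of interest, matching the `mstore_subset` test below:

// hedged sketch: keep only the nodes reachable from the listed roots
let substore = store.subset([tree1.root(), tree2.root()].iter());
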
/// Iterator over the inner nodes of the [MerkleStore].
pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
self.nodes
.iter()
.map(|(r, n)| InnerNodeInfo { value: *r, left: n.left, right: n.right })
}

/// Iterator over the non-empty leaves of the Merkle tree associated with the specified `root`
/// and `max_depth`.
pub fn non_empty_leaves(
&self,
root: RpoDigest,
max_depth: u8,
) -> impl Iterator<Item = (NodeIndex, RpoDigest)> + '_ {
let empty_roots = EmptySubtreeRoots::empty_hashes(max_depth);
let mut stack = Vec::new();
stack.push((NodeIndex::new_unchecked(0, 0), root));

core::iter::from_fn(move || {
while let Some((index, node_hash)) = stack.pop() {
// if we are at the max depth then we have reached a leaf
if index.depth() == max_depth {
return Some((index, node_hash));
}

// fetch the node's children and push them onto the stack if they are not the roots
// of empty subtrees
if let Some(node) = self.nodes.get(&node_hash) {
if !empty_roots.contains(&node.left) {
stack.push((index.left_child(), node.left));
}
if !empty_roots.contains(&node.right) {
stack.push((index.right_child(), node.right));
}

// if the node is not in the store assume it is a leaf
} else {
// assert that if we have a leaf that is not at the max depth then it must be
// at the depth of one of the tiers of a TSMT.
debug_assert!(TieredSmt::TIER_DEPTHS[..3].contains(&index.depth()));
return Some((index, node_hash));
}
}

None
})
}

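A hedged sketch of draining the iterator; depth 8 here is an arbitrary example value:

// hedged sketch: visit every non-empty leaf reachable from `root` down to depth 8
for (index, leaf) in store.non_empty_leaves(root, 8) {
// `index` locates the leaf; `leaf` is its digest
let _ = (index, leaf);
}
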
// STATE MUTATORS
// --------------------------------------------------------------------------------------------

/// Adds a sequence of nodes yielded by the provided iterator into the store.
pub fn extend<I>(&mut self, iter: I) -> &mut MerkleStore
where
I: Iterator<Item = InnerNodeInfo>,
{
for node in iter {
let value: RpoDigest = node.value.into();
let left: RpoDigest = node.left.into();
let right: RpoDigest = node.right.into();

debug_assert_eq!(Rpo256::merge(&[left, right]), value);
self.nodes.insert(value, Node { left, right });
}

self
}

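A quick sketch of bulk-loading a store, as the `mstore_subset` test below does:

// hedged sketch: copy all inner nodes of an existing tree into the store
let mut store = MerkleStore::default();
store.extend(mtree.inner_nodes());
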
/// Adds all the nodes of a Merkle path represented by `path`, opening to `node`. Returns the
/// new root.
///
@@ -285,16 +383,16 @@ impl MerkleStore {
pub fn add_merkle_path(
&mut self,
index: u64,
node: Word,
node: RpoDigest,
path: MerklePath,
) -> Result<Word, MerkleError> {
let root = path.inner_nodes(index, node)?.fold(Word::default(), |_, node| {
let value: RpoDigest = node.value.into();
let left: RpoDigest = node.left.into();
let right: RpoDigest = node.right.into();
) -> Result<RpoDigest, MerkleError> {
let root = path.inner_nodes(index, node)?.fold(RpoDigest::default(), |_, node| {
let value: RpoDigest = node.value;
let left: RpoDigest = node.left;
let right: RpoDigest = node.right;

debug_assert_eq!(Rpo256::merge(&[left, right]), value);
self.nodes.insert(value, Node { left, right });
self.nodes.insert(value, StoreNode { left, right });

node.value
});
@@ -309,7 +407,7 @@ impl MerkleStore {
/// For further reference, check [MerkleStore::add_merkle_path].
pub fn add_merkle_paths<I>(&mut self, paths: I) -> Result<(), MerkleError>
where
I: IntoIterator<Item = (u64, Word, MerklePath)>,
I: IntoIterator<Item = (u64, RpoDigest, MerklePath)>,
{
for (index_value, node, path) in paths.into_iter() {
self.add_merkle_path(index_value, node, path)?;
@@ -317,29 +415,18 @@ impl MerkleStore {
Ok(())
}

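A hedged sketch of recording openings; the tuple shape matches `test_add_merkle_paths` below, and `leaf0`/`path0` etc. are hypothetical values:

// hedged sketch: each entry is (leaf index, leaf digest, opening to the root)
let mut store = MerkleStore::default();
store.add_merkle_paths([(0, leaf0, path0), (1, leaf1, path1)])?;
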
/// Appends the provided [MerklePathSet] into the store.
///
/// For further reference, check [MerkleStore::add_merkle_path].
pub fn add_merkle_path_set(&mut self, path_set: &MerklePathSet) -> Result<Word, MerkleError> {
let root = path_set.root();
for (index, path) in path_set.to_paths() {
self.add_merkle_path(index, path.value, path.path)?;
}
Ok(root)
}

/// Sets a node to `value`.
///
/// # Errors
///
/// This method can return the following errors:
/// - `RootNotInStore` if the `root` is not present in the store.
/// - `NodeNotInStore` if a node needed to traverse from `root` to `index` is not present in the store.
/// - `NodeNotInStore` if a node needed to traverse from `root` to `index` is not present in
///   the store.
pub fn set_node(
&mut self,
mut root: Word,
mut root: RpoDigest,
index: NodeIndex,
value: Word,
value: RpoDigest,
) -> Result<RootPath, MerkleError> {
let node = value;
let ValuePath { value, path } = self.get_path(root, index)?;
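A minimal sketch of updating a leaf and adopting the returned root, as `test_set_node` below does; `new_value` is a placeholder digest:

// hedged sketch: write a new value and continue with the updated root
let new_root = store.set_node(root, NodeIndex::make(depth, 0), new_value)?.root;
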
@@ -355,80 +442,166 @@ impl MerkleStore {
/// Merges two elements and adds the resulting node into the store.
///
/// Merges arbitrary values. They may be leaves, nodes, or a mixture of both.
pub fn merge_roots(&mut self, root1: Word, root2: Word) -> Result<Word, MerkleError> {
let left: RpoDigest = root1.into();
let right: RpoDigest = root2.into();
pub fn merge_roots(
&mut self,
left_root: RpoDigest,
right_root: RpoDigest,
) -> Result<RpoDigest, MerkleError> {
let parent = Rpo256::merge(&[left_root, right_root]);
self.nodes.insert(parent, StoreNode { left: left_root, right: right_root });

let parent = Rpo256::merge(&[left, right]);
self.nodes.insert(parent, Node { left, right });
Ok(parent)
}

Ok(parent.into())
// DESTRUCTURING
// --------------------------------------------------------------------------------------------

/// Returns the inner storage of this MerkleStore while consuming `self`.
pub fn into_inner(self) -> T {
self.nodes
}

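A hedged sketch of the merge; by construction the parent is simply the RPO hash of the two roots:

// hedged sketch: combine two independent trees under a fresh parent
let parent = store.merge_roots(left_root, right_root)?;
assert_eq!(parent, Rpo256::merge(&[left_root, right_root]));
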
// HELPER METHODS
// --------------------------------------------------------------------------------------------

/// Recursively clones a tree with the specified root from the specified source into self.
///
/// If the source store does not contain a tree with the specified root, this is a noop.
fn clone_tree_from(&mut self, root: RpoDigest, source: &Self) {
// process the node only if it is in the source
if let Some(node) = source.nodes.get(&root) {
// if the node has already been inserted, no need to process it further as all of its
// descendants should be already cloned from the source store
if self.nodes.insert(root, *node).is_none() {
self.clone_tree_from(node.left, source);
self.clone_tree_from(node.right, source);
}
}
}
}

// CONVERSIONS
// ================================================================================================

impl From<&MerkleTree> for MerkleStore {
impl<T: KvMap<RpoDigest, StoreNode>> From<&MerkleTree> for MerkleStore<T> {
fn from(value: &MerkleTree) -> Self {
let mut store = MerkleStore::new();
store.extend(value.inner_nodes());
store
let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
Self { nodes }
}
}

impl From<&SimpleSmt> for MerkleStore {
impl<T: KvMap<RpoDigest, StoreNode>> From<&SimpleSmt> for MerkleStore<T> {
fn from(value: &SimpleSmt) -> Self {
let mut store = MerkleStore::new();
store.extend(value.inner_nodes());
store
let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
Self { nodes }
}
}

impl From<&Mmr> for MerkleStore {
impl<T: KvMap<RpoDigest, StoreNode>> From<&Mmr> for MerkleStore<T> {
fn from(value: &Mmr) -> Self {
let mut store = MerkleStore::new();
store.extend(value.inner_nodes());
store
let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
Self { nodes }
}
}

impl FromIterator<InnerNodeInfo> for MerkleStore {
fn from_iter<T: IntoIterator<Item = InnerNodeInfo>>(iter: T) -> Self {
let mut store = MerkleStore::new();
store.extend(iter.into_iter());
store
impl<T: KvMap<RpoDigest, StoreNode>> From<&TieredSmt> for MerkleStore<T> {
fn from(value: &TieredSmt) -> Self {
let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
Self { nodes }
}
}

impl<T: KvMap<RpoDigest, StoreNode>> From<&PartialMerkleTree> for MerkleStore<T> {
fn from(value: &PartialMerkleTree) -> Self {
let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
Self { nodes }
}
}

impl<T: KvMap<RpoDigest, StoreNode>> From<T> for MerkleStore<T> {
fn from(values: T) -> Self {
let nodes = values.into_iter().chain(empty_hashes()).collect();
Self { nodes }
}
}

impl<T: KvMap<RpoDigest, StoreNode>> FromIterator<InnerNodeInfo> for MerkleStore<T> {
fn from_iter<I: IntoIterator<Item = InnerNodeInfo>>(iter: I) -> Self {
let nodes = combine_nodes_with_empty_hashes(iter).collect();
Self { nodes }
}
}

impl<T: KvMap<RpoDigest, StoreNode>> FromIterator<(RpoDigest, StoreNode)> for MerkleStore<T> {
fn from_iter<I: IntoIterator<Item = (RpoDigest, StoreNode)>>(iter: I) -> Self {
let nodes = iter.into_iter().chain(empty_hashes()).collect();
Self { nodes }
}
}

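With these impls in place a store can be collected straight from inner nodes; a sketch matching `test_recorder` below:

// hedged sketch: collect the inner nodes of one or more trees into a store
let store: MerkleStore = mtree.inner_nodes().chain(smt.inner_nodes()).collect();
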
// ITERATORS
// ================================================================================================
impl<T: KvMap<RpoDigest, StoreNode>> Extend<InnerNodeInfo> for MerkleStore<T> {
fn extend<I: IntoIterator<Item = InnerNodeInfo>>(&mut self, iter: I) {
self.nodes.extend(
iter.into_iter()
.map(|info| (info.value, StoreNode { left: info.left, right: info.right })),
);
}
}

impl Extend<InnerNodeInfo> for MerkleStore {
fn extend<T: IntoIterator<Item = InnerNodeInfo>>(&mut self, iter: T) {
self.extend(iter.into_iter());
// DiffT & ApplyDiffT TRAIT IMPLEMENTATION
// ================================================================================================
impl<T: KvMap<RpoDigest, StoreNode>> TryApplyDiff<RpoDigest, StoreNode> for MerkleStore<T> {
type Error = MerkleError;
type DiffType = MerkleStoreDelta;

fn try_apply(&mut self, diff: Self::DiffType) -> Result<(), MerkleError> {
for (root, delta) in diff.0 {
let mut root = root;
for cleared_slot in delta.cleared_slots() {
root = self
.set_node(
root,
NodeIndex::new(delta.depth(), *cleared_slot)?,
EMPTY_WORD.into(),
)?
.root;
}
for (updated_slot, updated_value) in delta.updated_slots() {
root = self
.set_node(
root,
NodeIndex::new(delta.depth(), *updated_slot)?,
(*updated_value).into(),
)?
.root;
}
}

Ok(())
}
}

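A hedged sketch of applying a delta; `MerkleStoreDelta` comes from this same change set, so the construction of `delta` is left abstract:

// hedged sketch: replay cleared and updated slots against the store
store.try_apply(delta)?;
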
// SERIALIZATION
// ================================================================================================

impl Serializable for Node {
impl Serializable for StoreNode {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.left.write_into(target);
self.right.write_into(target);
}
}

impl Deserializable for Node {
impl Deserializable for StoreNode {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let left = RpoDigest::read_from(source)?;
let right = RpoDigest::read_from(source)?;
Ok(Node { left, right })
Ok(StoreNode { left, right })
}
}

impl Serializable for MerkleStore {
impl<T: KvMap<RpoDigest, StoreNode>> Serializable for MerkleStore<T> {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write_u64(self.nodes.len() as u64);

@@ -439,17 +612,42 @@ impl Serializable for MerkleStore {
}
}

impl Deserializable for MerkleStore {
impl<T: KvMap<RpoDigest, StoreNode>> Deserializable for MerkleStore<T> {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let len = source.read_u64()?;
let mut nodes: BTreeMap<RpoDigest, Node> = BTreeMap::new();
let mut nodes: Vec<(RpoDigest, StoreNode)> = Vec::with_capacity(len as usize);

for _ in 0..len {
let key = RpoDigest::read_from(source)?;
let value = Node::read_from(source)?;
nodes.insert(key, value);
let value = StoreNode::read_from(source)?;
nodes.push((key, value));
}

Ok(MerkleStore { nodes })
Ok(nodes.into_iter().collect())
}
}

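A byte-level round trip, as exercised by `test_serialization` below:

// hedged sketch: serialize and deserialize the whole store
let bytes = store.to_bytes();
let decoded = MerkleStore::read_from_bytes(&bytes)?;
assert_eq!(store, decoded);
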
// HELPER FUNCTIONS
// ================================================================================================

/// Creates empty hashes for all the subtrees of a tree with a max depth of 255.
fn empty_hashes() -> impl IntoIterator<Item = (RpoDigest, StoreNode)> {
let subtrees = EmptySubtreeRoots::empty_hashes(255);
subtrees
.iter()
.rev()
.copied()
.zip(subtrees.iter().rev().skip(1).copied())
.map(|(child, parent)| (parent, StoreNode { left: child, right: child }))
}

/// Consumes an iterator of [InnerNodeInfo] and returns an iterator of `(value, node)` tuples
/// which includes the nodes associated with roots of empty subtrees up to a depth of 255.
fn combine_nodes_with_empty_hashes(
nodes: impl IntoIterator<Item = InnerNodeInfo>,
) -> impl Iterator<Item = (RpoDigest, StoreNode)> {
nodes
.into_iter()
.map(|info| (info.value, StoreNode { left: info.left, right: info.right }))
.chain(empty_hashes())
}

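A small sketch of the invariant these helpers maintain: every pre-populated parent is the RPO merge of two copies of the empty child below it:

// hedged sketch: each empty-subtree parent hashes its duplicated empty child
for (parent, node) in empty_hashes() {
assert_eq!(node.left, node.right);
assert_eq!(parent, Rpo256::merge(&[node.left, node.right]));
}
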
@@ -1,29 +1,51 @@
use super::*;
use crate::{
hash::rpo::Rpo256,
merkle::{int_to_node, MerklePathSet, MerkleTree, SimpleSmt},
Felt, Word, WORD_SIZE, ZERO,
use super::{
DefaultMerkleStore as MerkleStore, EmptySubtreeRoots, MerkleError, MerklePath, NodeIndex,
PartialMerkleTree, RecordingMerkleStore, Rpo256, RpoDigest,
};
use crate::{
merkle::{digests_to_words, int_to_leaf, int_to_node, MerkleTree, SimpleSmt},
Felt, Word, ONE, WORD_SIZE, ZERO,
};

#[cfg(feature = "std")]
use super::{Deserializable, Serializable};

#[cfg(feature = "std")]
use std::error::Error;

// TEST DATA
// ================================================================================================

const KEYS4: [u64; 4] = [0, 1, 2, 3];
const LEAVES4: [Word; 4] = [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];
const EMPTY: Word = [ZERO; WORD_SIZE];
const VALUES4: [RpoDigest; 4] = [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];

const KEYS8: [u64; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
const VALUES8: [RpoDigest; 8] = [
int_to_node(1),
int_to_node(2),
int_to_node(3),
int_to_node(4),
int_to_node(5),
int_to_node(6),
int_to_node(7),
int_to_node(8),
];

// TESTS
// ================================================================================================

#[test]
fn test_root_not_in_store() -> Result<(), MerkleError> {
let mtree = MerkleTree::new(LEAVES4.to_vec())?;
let mtree = MerkleTree::new(digests_to_words(&VALUES4))?;
let store = MerkleStore::from(&mtree);
assert_eq!(
store.get_node(LEAVES4[0], NodeIndex::make(mtree.depth(), 0)),
Err(MerkleError::RootNotInStore(LEAVES4[0])),
store.get_node(VALUES4[0], NodeIndex::make(mtree.depth(), 0)),
Err(MerkleError::RootNotInStore(VALUES4[0])),
"Leaf 0 is not a root"
);
assert_eq!(
store.get_path(LEAVES4[0], NodeIndex::make(mtree.depth(), 0)),
Err(MerkleError::RootNotInStore(LEAVES4[0])),
store.get_path(VALUES4[0], NodeIndex::make(mtree.depth(), 0)),
Err(MerkleError::RootNotInStore(VALUES4[0])),
"Leaf 0 is not a root"
);

@@ -32,33 +54,33 @@ fn test_root_not_in_store() -> Result<(), MerkleError> {

#[test]
fn test_merkle_tree() -> Result<(), MerkleError> {
let mtree = MerkleTree::new(LEAVES4.to_vec())?;
let mtree = MerkleTree::new(digests_to_words(&VALUES4))?;
let store = MerkleStore::from(&mtree);

// STORE LEAVES ARE CORRECT ==============================================================
// STORE LEAVES ARE CORRECT -------------------------------------------------------------------
// checks the leaves in the store correspond to the expected values
assert_eq!(
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 0)),
Ok(LEAVES4[0]),
Ok(VALUES4[0]),
"node 0 must be in the tree"
);
assert_eq!(
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 1)),
Ok(LEAVES4[1]),
Ok(VALUES4[1]),
"node 1 must be in the tree"
);
assert_eq!(
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 2)),
Ok(LEAVES4[2]),
Ok(VALUES4[2]),
"node 2 must be in the tree"
);
assert_eq!(
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 3)),
Ok(LEAVES4[3]),
Ok(VALUES4[3]),
"node 3 must be in the tree"
);

// STORE LEAVES MATCH TREE ===============================================================
// STORE LEAVES MATCH TREE --------------------------------------------------------------------
// sanity check the values returned by the store and the tree
assert_eq!(
mtree.get_node(NodeIndex::make(mtree.depth(), 0)),
@@ -85,7 +107,7 @@ fn test_merkle_tree() -> Result<(), MerkleError> {
// assert the merkle path returned by the store is the same as the one in the tree
let result = store.get_path(mtree.root(), NodeIndex::make(mtree.depth(), 0)).unwrap();
assert_eq!(
LEAVES4[0], result.value,
VALUES4[0], result.value,
"Value for merkle path at index 0 must match leaf value"
);
assert_eq!(
@@ -96,7 +118,7 @@ fn test_merkle_tree() -> Result<(), MerkleError> {

let result = store.get_path(mtree.root(), NodeIndex::make(mtree.depth(), 1)).unwrap();
assert_eq!(
LEAVES4[1], result.value,
VALUES4[1], result.value,
"Value for merkle path at index 0 must match leaf value"
);
assert_eq!(
@@ -107,7 +129,7 @@ fn test_merkle_tree() -> Result<(), MerkleError> {

let result = store.get_path(mtree.root(), NodeIndex::make(mtree.depth(), 2)).unwrap();
assert_eq!(
LEAVES4[2], result.value,
VALUES4[2], result.value,
"Value for merkle path at index 0 must match leaf value"
);
assert_eq!(
@@ -118,7 +140,7 @@ fn test_merkle_tree() -> Result<(), MerkleError> {

let result = store.get_path(mtree.root(), NodeIndex::make(mtree.depth(), 3)).unwrap();
assert_eq!(
LEAVES4[3], result.value,
VALUES4[3], result.value,
"Value for merkle path at index 0 must match leaf value"
);
assert_eq!(
@@ -133,12 +155,12 @@ fn test_merkle_tree() -> Result<(), MerkleError> {
#[test]
fn test_empty_roots() {
let store = MerkleStore::default();
let mut root = RpoDigest::new(EMPTY);
let mut root = RpoDigest::default();

for depth in 0..255 {
root = Rpo256::merge(&[root; 2]);
assert!(
store.get_node(root.into(), NodeIndex::make(0, 0)).is_ok(),
store.get_node(root, NodeIndex::make(0, 0)).is_ok(),
"The root of the empty tree of depth {depth} must be registered"
);
}
@@ -157,13 +179,17 @@ fn test_leaf_paths_for_empty_trees() -> Result<(), MerkleError> {
let index = NodeIndex::make(depth, 0);
let store_path = store.get_path(smt.root(), index)?;
let smt_path = smt.get_path(index)?;
assert_eq!(store_path.value, EMPTY, "the leaf of an empty tree is always ZERO");
assert_eq!(
store_path.value,
RpoDigest::default(),
"the leaf of an empty tree is always ZERO"
);
assert_eq!(
store_path.path, smt_path,
"the returned merkle path does not match the computed values"
);
assert_eq!(
store_path.path.compute_root(depth.into(), EMPTY).unwrap(),
store_path.path.compute_root(depth.into(), RpoDigest::default()).unwrap(),
smt.root(),
"computed root from the path must match the empty tree root"
);
@@ -174,7 +200,8 @@ fn test_leaf_paths_for_empty_trees() -> Result<(), MerkleError> {

#[test]
fn test_get_invalid_node() {
let mtree = MerkleTree::new(LEAVES4.to_vec()).expect("creating a merkle tree must work");
let mtree =
MerkleTree::new(digests_to_words(&VALUES4)).expect("creating a merkle tree must work");
let store = MerkleStore::from(&mtree);
let _ = store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 3));
}
@@ -182,19 +209,16 @@
#[test]
fn test_add_sparse_merkle_tree_one_level() -> Result<(), MerkleError> {
let keys2: [u64; 2] = [0, 1];
let leaves2: [Word; 2] = [int_to_node(1), int_to_node(2)];
let smt = SimpleSmt::new(1)
.unwrap()
.with_leaves(keys2.into_iter().zip(leaves2.into_iter()))
.unwrap();
let leaves2: [Word; 2] = [int_to_leaf(1), int_to_leaf(2)];
let smt = SimpleSmt::with_leaves(1, keys2.into_iter().zip(leaves2.into_iter())).unwrap();
let store = MerkleStore::from(&smt);

let idx = NodeIndex::make(1, 0);
assert_eq!(smt.get_node(idx).unwrap(), leaves2[0]);
assert_eq!(smt.get_node(idx).unwrap(), leaves2[0].into());
assert_eq!(store.get_node(smt.root(), idx).unwrap(), smt.get_node(idx).unwrap());

let idx = NodeIndex::make(1, 1);
assert_eq!(smt.get_node(idx).unwrap(), leaves2[1]);
assert_eq!(smt.get_node(idx).unwrap(), leaves2[1].into());
assert_eq!(store.get_node(smt.root(), idx).unwrap(), smt.get_node(idx).unwrap());

Ok(())
@@ -202,10 +226,11 @@ fn test_add_sparse_merkle_tree_one_level() -> Result<(), MerkleError> {

#[test]
fn test_sparse_merkle_tree() -> Result<(), MerkleError> {
let smt = SimpleSmt::new(SimpleSmt::MAX_DEPTH)
.unwrap()
.with_leaves(KEYS4.into_iter().zip(LEAVES4.into_iter()))
.unwrap();
let smt = SimpleSmt::with_leaves(
SimpleSmt::MAX_DEPTH,
KEYS4.into_iter().zip(digests_to_words(&VALUES4).into_iter()),
)
.unwrap();

let store = MerkleStore::from(&smt);

@@ -213,27 +238,27 @@ fn test_sparse_merkle_tree() -> Result<(), MerkleError> {
// checks the leaves in the store correspond to the expected values
assert_eq!(
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 0)),
Ok(LEAVES4[0]),
Ok(VALUES4[0]),
"node 0 must be in the tree"
);
assert_eq!(
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 1)),
Ok(LEAVES4[1]),
Ok(VALUES4[1]),
"node 1 must be in the tree"
);
assert_eq!(
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 2)),
Ok(LEAVES4[2]),
Ok(VALUES4[2]),
"node 2 must be in the tree"
);
assert_eq!(
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 3)),
Ok(LEAVES4[3]),
Ok(VALUES4[3]),
"node 3 must be in the tree"
);
assert_eq!(
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 4)),
Ok(EMPTY),
Ok(RpoDigest::default()),
"unmodified node 4 must be ZERO"
);

@@ -269,7 +294,7 @@ fn test_sparse_merkle_tree() -> Result<(), MerkleError> {
// assert the merkle path returned by the store is the same as the one in the tree
let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 0)).unwrap();
assert_eq!(
LEAVES4[0], result.value,
VALUES4[0], result.value,
"Value for merkle path at index 0 must match leaf value"
);
assert_eq!(
@@ -280,7 +305,7 @@ fn test_sparse_merkle_tree() -> Result<(), MerkleError> {

let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 1)).unwrap();
assert_eq!(
LEAVES4[1], result.value,
VALUES4[1], result.value,
"Value for merkle path at index 1 must match leaf value"
);
assert_eq!(
@@ -291,7 +316,7 @@ fn test_sparse_merkle_tree() -> Result<(), MerkleError> {

let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 2)).unwrap();
assert_eq!(
LEAVES4[2], result.value,
VALUES4[2], result.value,
"Value for merkle path at index 2 must match leaf value"
);
assert_eq!(
@@ -302,7 +327,7 @@ fn test_sparse_merkle_tree() -> Result<(), MerkleError> {

let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 3)).unwrap();
assert_eq!(
LEAVES4[3], result.value,
VALUES4[3], result.value,
"Value for merkle path at index 3 must match leaf value"
);
assert_eq!(
@@ -312,7 +337,11 @@ fn test_sparse_merkle_tree() -> Result<(), MerkleError> {
);

let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 4)).unwrap();
assert_eq!(EMPTY, result.value, "Value for merkle path at index 4 must match leaf value");
assert_eq!(
RpoDigest::default(),
result.value,
"Value for merkle path at index 4 must match leaf value"
);
assert_eq!(
smt.get_path(NodeIndex::make(smt.depth(), 4)),
Ok(result.path),
@@ -324,7 +353,7 @@ fn test_sparse_merkle_tree() -> Result<(), MerkleError> {

#[test]
fn test_add_merkle_paths() -> Result<(), MerkleError> {
let mtree = MerkleTree::new(LEAVES4.to_vec())?;
let mtree = MerkleTree::new(digests_to_words(&VALUES4))?;

let i0 = 0;
let p0 = mtree.get_path(NodeIndex::make(2, i0)).unwrap();
@@ -339,106 +368,105 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> {
let p3 = mtree.get_path(NodeIndex::make(2, i3)).unwrap();

let paths = [
(i0, LEAVES4[i0 as usize], p0),
(i1, LEAVES4[i1 as usize], p1),
(i2, LEAVES4[i2 as usize], p2),
(i3, LEAVES4[i3 as usize], p3),
(i0, VALUES4[i0 as usize], p0),
(i1, VALUES4[i1 as usize], p1),
(i2, VALUES4[i2 as usize], p2),
(i3, VALUES4[i3 as usize], p3),
];

let mut store = MerkleStore::default();
store.add_merkle_paths(paths.clone()).expect("the valid paths must work");

let depth = 2;
let set = MerklePathSet::new(depth).with_paths(paths).unwrap();
let pmt = PartialMerkleTree::with_paths(paths).unwrap();

// STORE LEAVES ARE CORRECT ==============================================================
// checks the leaves in the store correspond to the expected values
assert_eq!(
store.get_node(set.root(), NodeIndex::make(set.depth(), 0)),
Ok(LEAVES4[0]),
"node 0 must be in the set"
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 0)),
Ok(VALUES4[0]),
"node 0 must be in the pmt"
);
assert_eq!(
store.get_node(set.root(), NodeIndex::make(set.depth(), 1)),
Ok(LEAVES4[1]),
"node 1 must be in the set"
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 1)),
Ok(VALUES4[1]),
"node 1 must be in the pmt"
);
assert_eq!(
store.get_node(set.root(), NodeIndex::make(set.depth(), 2)),
Ok(LEAVES4[2]),
"node 2 must be in the set"
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 2)),
Ok(VALUES4[2]),
"node 2 must be in the pmt"
);
assert_eq!(
store.get_node(set.root(), NodeIndex::make(set.depth(), 3)),
Ok(LEAVES4[3]),
"node 3 must be in the set"
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 3)),
Ok(VALUES4[3]),
"node 3 must be in the pmt"
);

// STORE LEAVES MATCH SET ================================================================
// sanity check the values returned by the store and the set
// STORE LEAVES MATCH PMT ================================================================
// sanity check the values returned by the store and the pmt
assert_eq!(
set.get_node(NodeIndex::make(set.depth(), 0)),
store.get_node(set.root(), NodeIndex::make(set.depth(), 0)),
"node 0 must be the same for both SparseMerkleTree and MerkleStore"
pmt.get_node(NodeIndex::make(pmt.max_depth(), 0)),
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 0)),
"node 0 must be the same for both PartialMerkleTree and MerkleStore"
);
assert_eq!(
set.get_node(NodeIndex::make(set.depth(), 1)),
store.get_node(set.root(), NodeIndex::make(set.depth(), 1)),
"node 1 must be the same for both SparseMerkleTree and MerkleStore"
pmt.get_node(NodeIndex::make(pmt.max_depth(), 1)),
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 1)),
"node 1 must be the same for both PartialMerkleTree and MerkleStore"
);
assert_eq!(
set.get_node(NodeIndex::make(set.depth(), 2)),
store.get_node(set.root(), NodeIndex::make(set.depth(), 2)),
"node 2 must be the same for both SparseMerkleTree and MerkleStore"
pmt.get_node(NodeIndex::make(pmt.max_depth(), 2)),
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 2)),
"node 2 must be the same for both PartialMerkleTree and MerkleStore"
);
assert_eq!(
set.get_node(NodeIndex::make(set.depth(), 3)),
store.get_node(set.root(), NodeIndex::make(set.depth(), 3)),
"node 3 must be the same for both SparseMerkleTree and MerkleStore"
pmt.get_node(NodeIndex::make(pmt.max_depth(), 3)),
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 3)),
"node 3 must be the same for both PartialMerkleTree and MerkleStore"
);

// STORE MERKLE PATH MATCHES =============================================================
// assert the merkle path returned by the store is the same as the one in the set
let result = store.get_path(set.root(), NodeIndex::make(set.depth(), 0)).unwrap();
// assert the merkle path returned by the store is the same as the one in the pmt
let result = store.get_path(pmt.root(), NodeIndex::make(pmt.max_depth(), 0)).unwrap();
assert_eq!(
LEAVES4[0], result.value,
VALUES4[0], result.value,
"Value for merkle path at index 0 must match leaf value"
);
assert_eq!(
set.get_path(NodeIndex::make(set.depth(), 0)),
pmt.get_path(NodeIndex::make(pmt.max_depth(), 0)),
Ok(result.path),
"merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
);

let result = store.get_path(set.root(), NodeIndex::make(set.depth(), 1)).unwrap();
let result = store.get_path(pmt.root(), NodeIndex::make(pmt.max_depth(), 1)).unwrap();
assert_eq!(
LEAVES4[1], result.value,
VALUES4[1], result.value,
"Value for merkle path at index 0 must match leaf value"
);
assert_eq!(
set.get_path(NodeIndex::make(set.depth(), 1)),
pmt.get_path(NodeIndex::make(pmt.max_depth(), 1)),
Ok(result.path),
"merkle path for index 1 must be the same for the MerkleTree and MerkleStore"
);

let result = store.get_path(set.root(), NodeIndex::make(set.depth(), 2)).unwrap();
let result = store.get_path(pmt.root(), NodeIndex::make(pmt.max_depth(), 2)).unwrap();
assert_eq!(
LEAVES4[2], result.value,
VALUES4[2], result.value,
"Value for merkle path at index 0 must match leaf value"
);
assert_eq!(
set.get_path(NodeIndex::make(set.depth(), 2)),
pmt.get_path(NodeIndex::make(pmt.max_depth(), 2)),
Ok(result.path),
"merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
);

let result = store.get_path(set.root(), NodeIndex::make(set.depth(), 3)).unwrap();
let result = store.get_path(pmt.root(), NodeIndex::make(pmt.max_depth(), 3)).unwrap();
assert_eq!(
LEAVES4[3], result.value,
VALUES4[3], result.value,
"Value for merkle path at index 0 must match leaf value"
);
assert_eq!(
set.get_path(NodeIndex::make(set.depth(), 3)),
pmt.get_path(NodeIndex::make(pmt.max_depth(), 3)),
Ok(result.path),
"merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
);
@@ -449,7 +477,7 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> {
#[test]
fn wont_open_to_different_depth_root() {
let empty = EmptySubtreeRoots::empty_hashes(64);
let a = [Felt::new(1); 4];
let a = [ONE; 4];
let b = [Felt::new(2); 4];

// Compute the root for a different depth. We cherry-pick this specific depth to prevent a
@@ -459,7 +487,6 @@ fn wont_open_to_different_depth_root() {
for depth in (1..=63).rev() {
root = Rpo256::merge(&[root, empty[depth]]);
}
let root = Word::from(root);

// For this example, the depth of the Merkle tree is 1, as we have only two leaves. Here we
// attempt to fetch a node on the maximum depth, and it should fail because the root shouldn't
@@ -473,7 +500,7 @@

#[test]
fn store_path_opens_from_leaf() {
let a = [Felt::new(1); 4];
let a = [ONE; 4];
let b = [Felt::new(2); 4];
let c = [Felt::new(3); 4];
let d = [Felt::new(4); 4];
@@ -487,22 +514,22 @@ fn store_path_opens_from_leaf() {
let k = Rpo256::merge(&[e.into(), f.into()]);
let l = Rpo256::merge(&[g.into(), h.into()]);

let m = Rpo256::merge(&[i.into(), j.into()]);
let n = Rpo256::merge(&[k.into(), l.into()]);
let m = Rpo256::merge(&[i, j]);
let n = Rpo256::merge(&[k, l]);

let root = Rpo256::merge(&[m.into(), n.into()]);
let root = Rpo256::merge(&[m, n]);

let mtree = MerkleTree::new(vec![a, b, c, d, e, f, g, h]).unwrap();
let store = MerkleStore::from(&mtree);
let path = store.get_path(root.into(), NodeIndex::make(3, 1)).unwrap().path;
let path = store.get_path(root, NodeIndex::make(3, 1)).unwrap().path;

let expected = MerklePath::new([a.into(), j.into(), n.into()].to_vec());
let expected = MerklePath::new([a.into(), j, n].to_vec());
assert_eq!(path, expected);
}

#[test]
fn test_set_node() -> Result<(), MerkleError> {
let mtree = MerkleTree::new(LEAVES4.to_vec())?;
let mtree = MerkleTree::new(digests_to_words(&VALUES4))?;
let mut store = MerkleStore::from(&mtree);
let value = int_to_node(42);
let index = NodeIndex::make(mtree.depth(), 0);
@@ -514,7 +541,7 @@ fn test_set_node() -> Result<(), MerkleError> {

#[test]
fn test_constructors() -> Result<(), MerkleError> {
let mtree = MerkleTree::new(LEAVES4.to_vec())?;
let mtree = MerkleTree::new(digests_to_words(&VALUES4))?;
let store = MerkleStore::from(&mtree);

let depth = mtree.depth();
@@ -526,10 +553,11 @@ fn test_constructors() -> Result<(), MerkleError> {
}

let depth = 32;
let smt = SimpleSmt::new(depth)
.unwrap()
.with_leaves(KEYS4.into_iter().zip(LEAVES4.into_iter()))
.unwrap();
let smt = SimpleSmt::with_leaves(
depth,
KEYS4.into_iter().zip(digests_to_words(&VALUES4).into_iter()),
)
.unwrap();
let store = MerkleStore::from(&smt);
let depth = smt.depth();

@@ -541,30 +569,30 @@ fn test_constructors() -> Result<(), MerkleError> {

let d = 2;
let paths = [
(0, LEAVES4[0], mtree.get_path(NodeIndex::make(d, 0)).unwrap()),
(1, LEAVES4[1], mtree.get_path(NodeIndex::make(d, 1)).unwrap()),
(2, LEAVES4[2], mtree.get_path(NodeIndex::make(d, 2)).unwrap()),
(3, LEAVES4[3], mtree.get_path(NodeIndex::make(d, 3)).unwrap()),
(0, VALUES4[0], mtree.get_path(NodeIndex::make(d, 0)).unwrap()),
(1, VALUES4[1], mtree.get_path(NodeIndex::make(d, 1)).unwrap()),
(2, VALUES4[2], mtree.get_path(NodeIndex::make(d, 2)).unwrap()),
(3, VALUES4[3], mtree.get_path(NodeIndex::make(d, 3)).unwrap()),
];

let mut store1 = MerkleStore::default();
store1.add_merkle_paths(paths.clone())?;

let mut store2 = MerkleStore::default();
store2.add_merkle_path(0, LEAVES4[0], mtree.get_path(NodeIndex::make(d, 0))?)?;
store2.add_merkle_path(1, LEAVES4[1], mtree.get_path(NodeIndex::make(d, 1))?)?;
store2.add_merkle_path(2, LEAVES4[2], mtree.get_path(NodeIndex::make(d, 2))?)?;
store2.add_merkle_path(3, LEAVES4[3], mtree.get_path(NodeIndex::make(d, 3))?)?;
let set = MerklePathSet::new(d).with_paths(paths).unwrap();
store2.add_merkle_path(0, VALUES4[0], mtree.get_path(NodeIndex::make(d, 0))?)?;
store2.add_merkle_path(1, VALUES4[1], mtree.get_path(NodeIndex::make(d, 1))?)?;
store2.add_merkle_path(2, VALUES4[2], mtree.get_path(NodeIndex::make(d, 2))?)?;
store2.add_merkle_path(3, VALUES4[3], mtree.get_path(NodeIndex::make(d, 3))?)?;
let pmt = PartialMerkleTree::with_paths(paths).unwrap();

for key in [0, 1, 2, 3] {
let index = NodeIndex::make(d, key);
let value_path1 = store1.get_path(set.root(), index)?;
let value_path2 = store2.get_path(set.root(), index)?;
let value_path1 = store1.get_path(pmt.root(), index)?;
let value_path2 = store2.get_path(pmt.root(), index)?;
assert_eq!(value_path1, value_path2);

let index = NodeIndex::make(d, key);
assert_eq!(set.get_path(index)?, value_path1.path);
assert_eq!(pmt.get_path(index)?, value_path1.path);
}

Ok(())
@@ -575,11 +603,11 @@ fn node_path_should_be_truncated_by_midtier_insert() {
let key = 0b11010010_11001100_11001100_11001100_11001100_11001100_11001100_11001100_u64;

let mut store = MerkleStore::new();
let root: Word = EmptySubtreeRoots::empty_hashes(64)[0].into();
let root: RpoDigest = EmptySubtreeRoots::empty_hashes(64)[0];

// insert first node - works as expected
let depth = 64;
let node = [Felt::new(key); WORD_SIZE];
let node = RpoDigest::from([Felt::new(key); WORD_SIZE]);
let index = NodeIndex::new(depth, key).unwrap();
let root = store.set_node(root, index, node).unwrap().root;
let result = store.get_node(root, index).unwrap();
@@ -592,7 +620,7 @@ fn node_path_should_be_truncated_by_midtier_insert() {
let key = key ^ (1 << 63);
let key = key >> 8;
let depth = 56;
let node = [Felt::new(key); WORD_SIZE];
let node = RpoDigest::from([Felt::new(key); WORD_SIZE]);
let index = NodeIndex::new(depth, key).unwrap();
let root = store.set_node(root, index, node).unwrap().root;
let result = store.get_node(root, index).unwrap();
@@ -608,16 +636,19 @@ fn node_path_should_be_truncated_by_midtier_insert() {
assert!(store.get_node(root, index).is_err());
}

// LEAF TRAVERSAL
// ================================================================================================

#[test]
fn get_leaf_depth_works_depth_64() {
let mut store = MerkleStore::new();
let mut root: Word = EmptySubtreeRoots::empty_hashes(64)[0].into();
let mut root: RpoDigest = EmptySubtreeRoots::empty_hashes(64)[0];
let key = u64::MAX;

// this will create a rainbow tree and test all opening to depth 64
for d in 0..64 {
let k = key & (u64::MAX >> d);
let node = [Felt::new(k); WORD_SIZE];
let node = RpoDigest::from([Felt::new(k); WORD_SIZE]);
let index = NodeIndex::new(64, k).unwrap();

// assert the leaf doesn't exist before the insert. the returned depth should always
@@ -634,14 +665,14 @@ fn get_leaf_depth_works_depth_64() {
#[test]
fn get_leaf_depth_works_with_incremental_depth() {
let mut store = MerkleStore::new();
let mut root: Word = EmptySubtreeRoots::empty_hashes(64)[0].into();
let mut root: RpoDigest = EmptySubtreeRoots::empty_hashes(64)[0];

// insert some path to the left of the root and assert it
let key = 0b01001011_10110110_00001101_01110100_00111011_10101101_00000100_01000001_u64;
assert_eq!(0, store.get_leaf_depth(root, 64, key).unwrap());
let depth = 64;
let index = NodeIndex::new(depth, key).unwrap();
let node = [Felt::new(key); WORD_SIZE];
let node = RpoDigest::from([Felt::new(key); WORD_SIZE]);
root = store.set_node(root, index, node).unwrap().root;
assert_eq!(depth, store.get_leaf_depth(root, 64, key).unwrap());

@@ -650,7 +681,7 @@ fn get_leaf_depth_works_with_incremental_depth() {
assert_eq!(1, store.get_leaf_depth(root, 64, key).unwrap());
let depth = 16;
let index = NodeIndex::new(depth, key >> (64 - depth)).unwrap();
let node = [Felt::new(key); WORD_SIZE];
let node = RpoDigest::from([Felt::new(key); WORD_SIZE]);
root = store.set_node(root, index, node).unwrap().root;
assert_eq!(depth, store.get_leaf_depth(root, 64, key).unwrap());

@@ -658,7 +689,7 @@ fn get_leaf_depth_works_with_incremental_depth() {
let key = 0b11001011_10110111_00000000_00000000_00000000_00000000_00000000_00000000_u64;
assert_eq!(16, store.get_leaf_depth(root, 64, key).unwrap());
let index = NodeIndex::new(depth, key >> (64 - depth)).unwrap();
let node = [Felt::new(key); WORD_SIZE];
let node = RpoDigest::from([Felt::new(key); WORD_SIZE]);
root = store.set_node(root, index, node).unwrap().root;
assert_eq!(depth, store.get_leaf_depth(root, 64, key).unwrap());

@@ -667,7 +698,7 @@ fn get_leaf_depth_works_with_incremental_depth() {
assert_eq!(15, store.get_leaf_depth(root, 64, key).unwrap());
let depth = 17;
let index = NodeIndex::new(depth, key >> (64 - depth)).unwrap();
let node = [Felt::new(key); WORD_SIZE];
let node = RpoDigest::from([Felt::new(key); WORD_SIZE]);
root = store.set_node(root, index, node).unwrap().root;
assert_eq!(depth, store.get_leaf_depth(root, 64, key).unwrap());
}
@@ -675,7 +706,7 @@ fn get_leaf_depth_works_with_incremental_depth() {
#[test]
fn get_leaf_depth_works_with_depth_8() {
let mut store = MerkleStore::new();
let mut root: Word = EmptySubtreeRoots::empty_hashes(8)[0].into();
let mut root: RpoDigest = EmptySubtreeRoots::empty_hashes(8)[0];

// insert some random, 8 depth keys. `a` diverges from the first bit
let a = 0b01101001_u64;
@@ -685,7 +716,7 @@ fn get_leaf_depth_works_with_depth_8() {

for k in [a, b, c, d] {
let index = NodeIndex::new(8, k).unwrap();
let node = [Felt::new(k); WORD_SIZE];
let node = RpoDigest::from([Felt::new(k); WORD_SIZE]);
root = store.set_node(root, index, node).unwrap().root;
}

@@ -718,12 +749,181 @@ fn get_leaf_depth_works_with_depth_8() {
assert_eq!(Err(MerkleError::DepthTooBig(9)), store.get_leaf_depth(root, 8, a));
}

#[test]
fn find_lone_leaf() {
let mut store = MerkleStore::new();
let empty = EmptySubtreeRoots::empty_hashes(64);
let mut root: RpoDigest = empty[0];

// insert a single leaf into the store at depth 64
let key_a = 0b01010101_10101010_00001111_01110100_00111011_10101101_00000100_01000001_u64;
let idx_a = NodeIndex::make(64, key_a);
let val_a = RpoDigest::from([ONE, ONE, ONE, ONE]);
root = store.set_node(root, idx_a, val_a).unwrap().root;

// for every ancestor of A, A should be a lone leaf
for depth in 1..64 {
let parent_index = NodeIndex::make(depth, key_a >> (64 - depth));
let parent = store.get_node(root, parent_index).unwrap();

let res = store.find_lone_leaf(parent, parent_index, 64).unwrap();
assert_eq!(res, Some((idx_a, val_a)));
}

// insert another leaf into the store such that it has the same 8 bit prefix as A
let key_b = 0b01010101_01111010_00001111_01110100_00111011_10101101_00000100_01000001_u64;
let idx_b = NodeIndex::make(64, key_b);
let val_b = RpoDigest::from([ONE, ONE, ONE, ZERO]);
root = store.set_node(root, idx_b, val_b).unwrap().root;

// for any node which is common between A and B, find_lone_leaf() should return None as the
// node has two descendants
for depth in 1..9 {
let parent_index = NodeIndex::make(depth, key_a >> (64 - depth));
let parent = store.get_node(root, parent_index).unwrap();

let res = store.find_lone_leaf(parent, parent_index, 64).unwrap();
assert_eq!(res, None);
}

// for other ancestors of A and B, A and B should be lone leaves respectively
for depth in 9..64 {
let parent_index = NodeIndex::make(depth, key_a >> (64 - depth));
let parent = store.get_node(root, parent_index).unwrap();

let res = store.find_lone_leaf(parent, parent_index, 64).unwrap();
assert_eq!(res, Some((idx_a, val_a)));
}

for depth in 9..64 {
let parent_index = NodeIndex::make(depth, key_b >> (64 - depth));
let parent = store.get_node(root, parent_index).unwrap();

let res = store.find_lone_leaf(parent, parent_index, 64).unwrap();
assert_eq!(res, Some((idx_b, val_b)));
}

// for any other node, find_lone_leaf() should return None as they have no leaf nodes
let parent_index = NodeIndex::make(16, 0b01010101_11111111);
let parent = store.get_node(root, parent_index).unwrap();
let res = store.find_lone_leaf(parent, parent_index, 64).unwrap();
assert_eq!(res, None);
}

// SUBSET EXTRACTION
// ================================================================================================

#[test]
fn mstore_subset() {
// add a Merkle tree of depth 3 to the store
let mtree = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
let mut store = MerkleStore::default();
let empty_store_num_nodes = store.nodes.len();
store.extend(mtree.inner_nodes());

// build 3 subtrees contained within the above Merkle tree; note that subtree2 is a subset
// of subtree1
let subtree1 = MerkleTree::new(digests_to_words(&VALUES8[..4])).unwrap();
let subtree2 = MerkleTree::new(digests_to_words(&VALUES8[2..4])).unwrap();
let subtree3 = MerkleTree::new(digests_to_words(&VALUES8[6..])).unwrap();

// --- extract all 3 subtrees ---------------------------------------------

let substore = store.subset([subtree1.root(), subtree2.root(), subtree3.root()].iter());

// number of nodes should increase by 4: 3 nodes from subtree1 and 1 node from subtree3
assert_eq!(substore.nodes.len(), empty_store_num_nodes + 4);

// make sure the paths of all subtrees are in the store
check_mstore_subtree(&substore, &subtree1);
check_mstore_subtree(&substore, &subtree2);
check_mstore_subtree(&substore, &subtree3);

// --- extract subtrees 1 and 3 -------------------------------------------
// this should give the same result as above as subtree2 is nested within subtree1

let substore = store.subset([subtree1.root(), subtree3.root()].iter());

// number of nodes should increase by 4: 3 nodes form subtree1 and 1 node from subtree3
|
||||
assert_eq!(substore.nodes.len(), empty_store_num_nodes + 4);
|
||||
|
||||
// make sure paths that all subtrees are in the store
|
||||
check_mstore_subtree(&substore, &subtree1);
|
||||
check_mstore_subtree(&substore, &subtree2);
|
||||
check_mstore_subtree(&substore, &subtree3);
|
||||
}
|
||||
|
||||
fn check_mstore_subtree(store: &MerkleStore, subtree: &MerkleTree) {
|
||||
for (i, value) in subtree.leaves() {
|
||||
let index = NodeIndex::new(subtree.depth(), i).unwrap();
|
||||
let path1 = store.get_path(subtree.root(), index).unwrap();
|
||||
assert_eq!(*path1.value, *value);
|
||||
|
||||
let path2 = subtree.get_path(index).unwrap();
|
||||
assert_eq!(path1.path, path2);
|
||||
}
|
||||
}
|
||||
|
||||
// SERIALIZATION
|
||||
// ================================================================================================
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
#[test]
|
||||
fn test_serialization() -> Result<(), Box<dyn Error>> {
|
||||
let mtree = MerkleTree::new(LEAVES4.to_vec())?;
|
||||
let mtree = MerkleTree::new(digests_to_words(&VALUES4))?;
|
||||
let store = MerkleStore::from(&mtree);
|
||||
let decoded = MerkleStore::read_from_bytes(&store.to_bytes()).expect("deserialization failed");
|
||||
assert_eq!(store, decoded);
|
||||
Ok(())
|
||||
}

// MERKLE RECORDER
// ================================================================================================
#[test]
fn test_recorder() {
    // instantiate recorder from MerkleTree and SimpleSmt
    let mtree = MerkleTree::new(digests_to_words(&VALUES4)).unwrap();
    let smtree = SimpleSmt::with_leaves(
        64,
        KEYS8.into_iter().zip(VALUES8.into_iter().map(|x| x.into()).rev()),
    )
    .unwrap();

    let mut recorder: RecordingMerkleStore =
        mtree.inner_nodes().chain(smtree.inner_nodes()).collect();

    // get nodes from both trees and make sure they are correct
    let index_0 = NodeIndex::new(mtree.depth(), 0).unwrap();
    let node = recorder.get_node(mtree.root(), index_0).unwrap();
    assert_eq!(node, mtree.get_node(index_0).unwrap());

    let index_1 = NodeIndex::new(smtree.depth(), 1).unwrap();
    let node = recorder.get_node(smtree.root(), index_1).unwrap();
    assert_eq!(node, smtree.get_node(index_1).unwrap());

    // insert a value and assert that when we request it next time it is accurate
    let new_value = [ZERO, ZERO, ONE, ONE].into();
    let index_2 = NodeIndex::new(smtree.depth(), 2).unwrap();
    let root = recorder.set_node(smtree.root(), index_2, new_value).unwrap().root;
    assert_eq!(recorder.get_node(root, index_2).unwrap(), new_value);

    // construct the proof
    let rec_map = recorder.into_inner();
    let (_, proof) = rec_map.finalize();
    let merkle_store: MerkleStore = proof.into();

    // make sure the proof contains all nodes from both trees
    let node = merkle_store.get_node(mtree.root(), index_0).unwrap();
    assert_eq!(node, mtree.get_node(index_0).unwrap());

    let node = merkle_store.get_node(smtree.root(), index_1).unwrap();
    assert_eq!(node, smtree.get_node(index_1).unwrap());

    let node = merkle_store.get_node(smtree.root(), index_2).unwrap();
    assert_eq!(node, smtree.get_leaf(index_2.value()).unwrap().into());

    // assert that it doesn't contain nodes that were not recorded
    let not_recorded_index = NodeIndex::new(smtree.depth(), 4).unwrap();
    assert!(merkle_store.get_node(smtree.root(), not_recorded_index).is_err());
    assert!(smtree.get_node(not_recorded_index).is_ok());
}

48  src/merkle/tiered_smt/error.rs  Normal file
@@ -0,0 +1,48 @@
use core::fmt::Display;

#[derive(Debug, PartialEq, Eq)]
pub enum TieredSmtProofError {
    EntriesEmpty,
    EmptyValueNotAllowed,
    MismatchedPrefixes(u64, u64),
    MultipleEntriesOutsideLastTier,
    NotATierPath(u8),
    PathTooLong,
}

impl Display for TieredSmtProofError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            TieredSmtProofError::EntriesEmpty => {
                write!(f, "Missing entries for tiered sparse merkle tree proof")
            }
            TieredSmtProofError::EmptyValueNotAllowed => {
                write!(
                    f,
                    "The empty value [0, 0, 0, 0] is not allowed inside a tiered sparse merkle tree"
                )
            }
            TieredSmtProofError::MismatchedPrefixes(first, second) => {
                write!(f, "Not all leaves have the same prefix. First {first} second {second}")
            }
            TieredSmtProofError::MultipleEntriesOutsideLastTier => {
                write!(f, "Multiple entries are only allowed for the last tier (depth 64)")
            }
            TieredSmtProofError::NotATierPath(got) => {
                write!(
                    f,
                    "Path length does not correspond to a tier. Got {got} Expected one of 16, 32, 48, 64"
                )
            }
            TieredSmtProofError::PathTooLong => {
                write!(
                    f,
                    "Path longer than maximum depth of 64 for tiered sparse merkle tree proof"
                )
            }
        }
    }
}

#[cfg(feature = "std")]
impl std::error::Error for TieredSmtProofError {}
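
// NOTE: the check below is an illustrative sketch added by the editor, not part of the original
// change. It shows how one of the variants above renders through the `Display` impl; the expected
// string is copied verbatim from the `fmt` match arms, and `to_string()` assumes `alloc` is
// available (as elsewhere in this crate).
#[cfg(test)]
#[test]
fn proof_error_display_sketch() {
    // a path of length 20 does not correspond to any tier (16, 32, 48, or 64)
    let err = TieredSmtProofError::NotATierPath(20);
    assert_eq!(
        err.to_string(),
        "Path length does not correspond to a tier. Got 20 Expected one of 16, 32, 48, 64"
    );
}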

509  src/merkle/tiered_smt/mod.rs  Normal file
@@ -0,0 +1,509 @@
use super::{
    BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex,
    Rpo256, RpoDigest, StarkField, Vec, Word,
};
use crate::utils::vec;
use core::{cmp, ops::Deref};

mod nodes;
use nodes::NodeStore;

mod values;
use values::ValueStore;

mod proof;
pub use proof::TieredSmtProof;

mod error;
pub use error::TieredSmtProofError;

#[cfg(test)]
mod tests;

// TIERED SPARSE MERKLE TREE
// ================================================================================================

/// Tiered (compacted) Sparse Merkle tree mapping 256-bit keys to 256-bit values. Both keys and
/// values are represented by 4 field elements.
///
/// Leaves in the tree can exist only at specific depths called "tiers". These depths are: 16, 32,
/// 48, and 64. Initially, when a tree is empty, it is equivalent to an empty Sparse Merkle tree
/// of depth 64 (i.e., leaves at depth 64 are set to [ZERO; 4]). As non-empty values are inserted
/// into the tree they are added to the first available tier.
///
/// For example, when the first key-value pair is inserted, it will be stored in a node at depth
/// 16 such that the 16 most significant bits of the key determine the position of the node at
/// depth 16. If another value with a key sharing the same 16-bit prefix is inserted, both values
/// move into the next tier (depth 32). This process is repeated until values end up at the bottom
/// tier (depth 64). If multiple values have keys with a common 64-bit prefix, such key-value pairs
/// are stored in a sorted list at the bottom tier.
///
/// To differentiate between internal and leaf nodes, node values are computed as follows:
/// - Internal nodes: hash(left_child, right_child).
/// - Leaf node at depths 16, 32, or 48: hash(key, value, domain=depth).
/// - Leaf node at depth 64: hash([key_0, value_0, ..., key_n, value_n], domain=64).
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct TieredSmt {
    root: RpoDigest,
    nodes: NodeStore,
    values: ValueStore,
}

impl TieredSmt {
    // CONSTANTS
    // --------------------------------------------------------------------------------------------

    /// The number of levels between tiers.
    pub const TIER_SIZE: u8 = 16;

    /// Depths at which leaves can exist in a tiered SMT.
    pub const TIER_DEPTHS: [u8; 4] = [16, 32, 48, 64];

    /// Maximum node depth. This is also the bottom tier of the tree.
    pub const MAX_DEPTH: u8 = 64;

    /// Value of an empty leaf.
    pub const EMPTY_VALUE: Word = super::EMPTY_WORD;

    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Returns a new [TieredSmt] instantiated with the specified key-value pairs.
    ///
    /// # Errors
    /// Returns an error if the provided entries contain multiple values for the same key.
    pub fn with_entries<R, I>(entries: R) -> Result<Self, MerkleError>
    where
        R: IntoIterator<IntoIter = I>,
        I: Iterator<Item = (RpoDigest, Word)> + ExactSizeIterator,
    {
        // create an empty tree
        let mut tree = Self::default();

        // append leaves to the tree returning an error if a duplicate entry for the same key
        // is found
        let mut empty_entries = BTreeSet::new();
        for (key, value) in entries {
            let old_value = tree.insert(key, value);
            if old_value != Self::EMPTY_VALUE || empty_entries.contains(&key) {
                return Err(MerkleError::DuplicateValuesForKey(key));
            }
            // if we've processed an empty entry, add the key to the set of empty entry keys, and
            // if this key was already in the set, return an error
            if value == Self::EMPTY_VALUE && !empty_entries.insert(key) {
                return Err(MerkleError::DuplicateValuesForKey(key));
            }
        }
        Ok(tree)
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the root of this Merkle tree.
    pub const fn root(&self) -> RpoDigest {
        self.root
    }

    /// Returns a node at the specified index.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The specified index depth is 0 or greater than 64.
    /// - The node with the specified index does not exist in the Merkle tree. This is possible
    ///   when a leaf node with the same index prefix exists at a tier higher than the requested
    ///   node.
    pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
        self.nodes.get_node(index)
    }

    /// Returns a Merkle path from the node at the specified index to the root.
    ///
    /// The node itself is not included in the path.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The specified index depth is 0 or greater than 64.
    /// - The node with the specified index does not exist in the Merkle tree. This is possible
    ///   when a leaf node with the same index prefix exists at a tier higher than the node to
    ///   which the path is requested.
    pub fn get_path(&self, index: NodeIndex) -> Result<MerklePath, MerkleError> {
        self.nodes.get_path(index)
    }

    /// Returns the value associated with the specified key.
    ///
    /// If nothing was inserted into this tree for the specified key, [ZERO; 4] is returned.
    pub fn get_value(&self, key: RpoDigest) -> Word {
        match self.values.get(&key) {
            Some(value) => *value,
            None => Self::EMPTY_VALUE,
        }
    }

    /// Returns a proof for a key-value pair defined by the specified key.
    ///
    /// The proof can be used to attest membership of this key-value pair in a Tiered Sparse Merkle
    /// Tree defined by the same root as this tree.
    pub fn prove(&self, key: RpoDigest) -> TieredSmtProof {
        let (path, index, leaf_exists) = self.nodes.get_proof(&key);

        let entries = if index.depth() == Self::MAX_DEPTH {
            match self.values.get_all(index.value()) {
                Some(entries) => entries,
                None => vec![(key, Self::EMPTY_VALUE)],
            }
        } else if leaf_exists {
            let entry =
                self.values.get_first(index_to_prefix(&index)).expect("leaf entry not found");
            debug_assert_eq!(entry.0, key);
            vec![*entry]
        } else {
            vec![(key, Self::EMPTY_VALUE)]
        };

        TieredSmtProof::new(path, entries).expect("Bug detected, TSMT produced invalid proof")
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Inserts the provided value into the tree under the specified key and returns the value
    /// previously stored under this key.
    ///
    /// If the value for the specified key was not previously set, [ZERO; 4] is returned.
    pub fn insert(&mut self, key: RpoDigest, value: Word) -> Word {
        // if an empty value is being inserted, remove the leaf node to make it look as if the
        // value was never inserted
        if value == Self::EMPTY_VALUE {
            return self.remove_leaf_node(key);
        }

        // insert the value into the value store, and if the key was already in the store, update
        // it with the new value
        if let Some(old_value) = self.values.insert(key, value) {
            if old_value != value {
                // if the new value is different from the old value, determine the location of
                // the leaf node for this key, build the node, and update the root
                let (index, leaf_exists) = self.nodes.get_leaf_index(&key);
                debug_assert!(leaf_exists);
                let node = self.build_leaf_node(index, key, value);
                self.root = self.nodes.update_leaf_node(index, node);
            }
            return old_value;
        };

        // determine the location for the leaf node; this index could have 3 different meanings:
        // - it points to a root of an empty subtree or an empty node at depth 64; in this case,
        //   we can replace the node with the value node immediately.
        // - it points to an existing leaf at the bottom tier (i.e., depth = 64); in this case,
        //   we need to process an update of the bottom leaf.
        // - it points to an existing leaf node for a different key with the same prefix (the same
        //   key case was handled above); in this case, we need to move the leaf to a lower tier
        let (index, leaf_exists) = self.nodes.get_leaf_index(&key);

        self.root = if leaf_exists && index.depth() == Self::MAX_DEPTH {
            // returned index points to a leaf at the bottom tier
            let node = self.build_leaf_node(index, key, value);
            self.nodes.update_leaf_node(index, node)
        } else if leaf_exists {
            // returned index points to a leaf for a different key with the same prefix

            // get the key-value pair for the key with the same prefix; since the key-value
            // pair has already been inserted into the value store, we need to filter it out
            // when looking for the other key-value pair
            let (other_key, other_value) = self
                .values
                .get_first_filtered(index_to_prefix(&index), &key)
                .expect("other key-value pair not found");

            // determine how far down the tree we should move the leaves
            let common_prefix_len = get_common_prefix_tier_depth(&key, other_key);
            let depth = cmp::min(common_prefix_len + Self::TIER_SIZE, Self::MAX_DEPTH);

            // compute node locations for the new and existing key-value pairs
            let new_index = LeafNodeIndex::from_key(&key, depth);
            let other_index = LeafNodeIndex::from_key(other_key, depth);

            // compute node values for the new and existing key-value pairs
            let new_node = self.build_leaf_node(new_index, key, value);
            let other_node = self.build_leaf_node(other_index, *other_key, *other_value);

            // replace the leaf located at index with a subtree containing nodes for the new and
            // existing key-value pairs
            self.nodes.replace_leaf_with_subtree(
                index,
                [(new_index, new_node), (other_index, other_node)],
            )
        } else {
            // returned index points to an empty subtree or an empty leaf at the bottom tier
            let node = self.build_leaf_node(index, key, value);
            self.nodes.insert_leaf_node(index, node)
        };

        Self::EMPTY_VALUE
    }

    // ITERATORS
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over all key-value pairs in this [TieredSmt].
    pub fn iter(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
        self.values.iter()
    }

    /// Returns an iterator over all inner nodes of this [TieredSmt] (i.e., nodes not at depths
    /// 16, 32, 48, or 64).
    ///
    /// The iterator order is unspecified.
    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
        self.nodes.inner_nodes()
    }

    /// Returns an iterator over upper leaves (i.e., depth = 16, 32, or 48) for this [TieredSmt]
    /// where each yielded item is a (node, key, value) tuple.
    ///
    /// The iterator order is unspecified.
    pub fn upper_leaves(&self) -> impl Iterator<Item = (RpoDigest, RpoDigest, Word)> + '_ {
        self.nodes.upper_leaves().map(|(index, node)| {
            let key_prefix = index_to_prefix(index);
            let (key, value) = self.values.get_first(key_prefix).expect("upper leaf not found");
            debug_assert_eq!(*index, LeafNodeIndex::from_key(key, index.depth()).into());
            (*node, *key, *value)
        })
    }

    /// Returns an iterator over upper leaves (i.e., depth = 16, 32, or 48) for this [TieredSmt]
    /// where each yielded item is a (node_index, value) tuple.
    pub fn upper_leaf_nodes(&self) -> impl Iterator<Item = (&NodeIndex, &RpoDigest)> {
        self.nodes.upper_leaves()
    }

    /// Returns an iterator over bottom leaves (i.e., depth = 64) of this [TieredSmt].
    ///
    /// Each yielded item consists of the hash of the leaf and its contents, where contents is
    /// a vector containing key-value pairs of entries stored in this leaf.
    ///
    /// The iterator order is unspecified.
    pub fn bottom_leaves(&self) -> impl Iterator<Item = (RpoDigest, Vec<(RpoDigest, Word)>)> + '_ {
        self.nodes.bottom_leaves().map(|(&prefix, node)| {
            let values = self.values.get_all(prefix).expect("bottom leaf not found");
            (*node, values)
        })
    }

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Removes the node holding the key-value pair for the specified key from this tree, and
    /// returns the value associated with the specified key.
    ///
    /// If no value was associated with the specified key, [ZERO; 4] is returned.
    fn remove_leaf_node(&mut self, key: RpoDigest) -> Word {
        // remove the key-value pair from the value store; if no value was associated with the
        // specified key, return.
        let old_value = match self.values.remove(&key) {
            Some(old_value) => old_value,
            None => return Self::EMPTY_VALUE,
        };

        // determine the location of the leaf holding the key-value pair to be removed
        let (index, leaf_exists) = self.nodes.get_leaf_index(&key);
        debug_assert!(leaf_exists);

        // if the leaf is at the bottom tier and, after removing the key-value pair from it, the
        // leaf is still not empty, we either just update it, or move it up to a higher tier (if
        // the leaf doesn't have siblings at lower tiers)
        if index.depth() == Self::MAX_DEPTH {
            if let Some(entries) = self.values.get_all(index.value()) {
                // if there is only one key-value pair left at the bottom leaf, and it can be
                // moved up to a higher tier, truncate the branch and return
                if entries.len() == 1 {
                    let new_depth = self.nodes.get_last_single_child_parent_depth(index.value());
                    if new_depth != Self::MAX_DEPTH {
                        let node = hash_upper_leaf(entries[0].0, entries[0].1, new_depth);
                        self.root = self.nodes.truncate_branch(index.value(), new_depth, node);
                        return old_value;
                    }
                }

                // otherwise just recompute the leaf hash and update the leaf node
                let node = hash_bottom_leaf(&entries);
                self.root = self.nodes.update_leaf_node(index, node);
                return old_value;
            };
        }

        // if the removed key-value pair has a lone sibling at the current tier with a root at
        // a higher tier, we need to move the sibling to a higher tier
        if let Some((sib_key, sib_val, new_sib_index)) = self.values.get_lone_sibling(index) {
            // determine the current index of the sibling node
            let sib_index = LeafNodeIndex::from_key(sib_key, index.depth());
            debug_assert!(sib_index.depth() > new_sib_index.depth());

            // compute node value for the new location of the sibling leaf and replace the subtree
            // with this leaf node
            let node = self.build_leaf_node(new_sib_index, *sib_key, *sib_val);
            let new_sib_depth = new_sib_index.depth();
            self.root = self.nodes.replace_subtree_with_leaf(index, sib_index, new_sib_depth, node);
        } else {
            // if the removed key-value pair did not have a sibling at the current tier with a
            // root at higher tiers, just clear the leaf node
            self.root = self.nodes.clear_leaf_node(index);
        }

        old_value
    }

    /// Builds and returns a leaf node value for the node located at the specified index.
    ///
    /// This method assumes that the key-value pair for the node has already been inserted into
    /// the value store; however, for depths 16, 32, and 48, the node is computed directly from
    /// the passed-in values (for depth 64, the value store is queried to get all the key-value
    /// pairs located at the specified index).
    fn build_leaf_node(&self, index: LeafNodeIndex, key: RpoDigest, value: Word) -> RpoDigest {
        let depth = index.depth();

        // insert the key into the index-key map and compute the new value of the node
        if index.depth() == Self::MAX_DEPTH {
            // for the bottom tier, we add the key-value pair to the existing leaf, or create a
            // new leaf with this key-value pair
            let values = self.values.get_all(index.value()).unwrap();
            hash_bottom_leaf(&values)
        } else {
            debug_assert_eq!(self.values.get_first(index_to_prefix(&index)), Some(&(key, value)));
            hash_upper_leaf(key, value, depth)
        }
    }
}
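
// NOTE: the round-trip below is a hypothetical sketch added by the editor, not part of the
// original change. It exercises only the public API defined above (insert a pair, read it back,
// verify the membership proof) and assumes `Felt` and `ONE` are in scope as they are in tests.rs.
#[cfg(test)]
#[test]
fn tsmt_api_roundtrip_sketch() {
    let mut smt = TieredSmt::default();
    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(0xdead_beef)]);
    let value = [ONE; 4];

    // the first insertion under a key returns the empty value
    assert_eq!(smt.insert(key, value), TieredSmt::EMPTY_VALUE);
    assert_eq!(smt.get_value(key), value);

    // a proof for the key verifies against the current root
    let proof = smt.prove(key);
    assert!(proof.verify_membership(&key, &value, &smt.root()));
}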

impl Default for TieredSmt {
    fn default() -> Self {
        let root = EmptySubtreeRoots::empty_hashes(Self::MAX_DEPTH)[0];
        Self {
            root,
            nodes: NodeStore::new(root),
            values: ValueStore::default(),
        }
    }
}

// LEAF NODE INDEX
// ================================================================================================
/// A wrapper around [NodeIndex] to provide type-safe references to nodes at depths 16, 32, 48,
/// and 64.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub struct LeafNodeIndex(NodeIndex);

impl LeafNodeIndex {
    /// Returns a new [LeafNodeIndex] instantiated from the provided [NodeIndex].
    ///
    /// In debug mode, panics if the index depth is not 16, 32, 48, or 64.
    pub fn new(index: NodeIndex) -> Self {
        // check if the depth is 16, 32, 48, or 64; this works because for a valid depth,
        // depth - 16 can be 0, 16, 32, or 48 - i.e., the value is either 0 or has only bits 4
        // and/or 5 set. We can test for this by computing a bitwise AND with a mask in which all
        // but the 4th and 5th bits are set (which is !48).
        debug_assert_eq!(((index.depth() - 16) & !48), 0, "invalid tier depth {}", index.depth());
        Self(index)
    }

    /// Returns a new [LeafNodeIndex] instantiated from the specified key inserted at the
    /// specified depth.
    ///
    /// The index value is computed by taking the n most significant bits of the most significant
    /// element of the key, where n is the specified depth.
    pub fn from_key(key: &RpoDigest, depth: u8) -> Self {
        let mse = get_key_prefix(key);
        Self::new(NodeIndex::new_unchecked(depth, mse >> (TieredSmt::MAX_DEPTH - depth)))
    }

    /// Returns a new [LeafNodeIndex] instantiated for testing purposes.
    #[cfg(test)]
    pub fn make(depth: u8, value: u64) -> Self {
        Self::new(NodeIndex::make(depth, value))
    }

    /// Traverses towards the root until the specified depth is reached.
    ///
    /// The new depth must be a valid tier depth - i.e., 16, 32, 48, or 64.
    pub fn move_up_to(&mut self, depth: u8) {
        debug_assert_eq!(((depth - 16) & !48), 0, "invalid tier depth: {depth}");
        self.0.move_up_to(depth);
    }
}

impl Deref for LeafNodeIndex {
    type Target = NodeIndex;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl From<NodeIndex> for LeafNodeIndex {
    fn from(value: NodeIndex) -> Self {
        Self::new(value)
    }
}

impl From<LeafNodeIndex> for NodeIndex {
    fn from(value: LeafNodeIndex) -> Self {
        value.0
    }
}
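
// NOTE: an illustrative sketch added by the editor, not part of the original change. It shows how
// `from_key` truncates the 64-bit key prefix to a tier index and how `move_up_to` shortens it;
// `ONE` and `Felt` are assumed to be in scope as in tests.rs.
#[cfg(test)]
#[test]
fn leaf_index_from_key_sketch() {
    let raw = 0b_10101010_10101010_00011111_11111111_00000000_00000000_00000000_00000000_u64;
    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);

    // the depth-16 index is the 16 most significant bits of the prefix
    assert_eq!(LeafNodeIndex::from_key(&key, 16).value(), raw >> 48);
    // the depth-32 index is the 32 most significant bits
    assert_eq!(LeafNodeIndex::from_key(&key, 32).value(), raw >> 32);

    // moving an index up to a higher tier keeps the shorter prefix
    let mut index = LeafNodeIndex::from_key(&key, 32);
    index.move_up_to(16);
    assert_eq!(index.value(), raw >> 48);
}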

// HELPER FUNCTIONS
// ================================================================================================

/// Returns the value representing the 64 most significant bits of the specified key.
fn get_key_prefix(key: &RpoDigest) -> u64 {
    Word::from(key)[3].as_int()
}

/// Returns the index value shifted to be in the most significant bit positions of the returned
/// u64 value.
fn index_to_prefix(index: &NodeIndex) -> u64 {
    index.value() << (TieredSmt::MAX_DEPTH - index.depth())
}

/// Returns the tiered common prefix length between the most significant elements of the provided
/// keys.
///
/// Specifically:
/// - returns 64 if the most significant elements are equal.
/// - returns 48 if the common prefix is between 48 and 63 bits.
/// - returns 32 if the common prefix is between 32 and 47 bits.
/// - returns 16 if the common prefix is between 16 and 31 bits.
/// - returns 0 if the common prefix is fewer than 16 bits.
fn get_common_prefix_tier_depth(key1: &RpoDigest, key2: &RpoDigest) -> u8 {
    let e1 = get_key_prefix(key1);
    let e2 = get_key_prefix(key2);
    let ex = (e1 ^ e2).leading_zeros() as u8;
    (ex / 16) * 16
}
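
// NOTE: a worked example added by the editor, not part of the original change. Two keys whose
// most significant elements agree on their first 20 bits round down to the 16-bit tier, since
// (20 / 16) * 16 = 16; `ONE` and `Felt` are assumed to be in scope as in tests.rs.
#[cfg(test)]
#[test]
fn common_prefix_tier_sketch() {
    // the prefixes agree on the first 20 bits and differ from bit 20 onward
    let key1 = RpoDigest::from([ONE, ONE, ONE, Felt::new(0xAAAA_A000_0000_0000)]);
    let key2 = RpoDigest::from([ONE, ONE, ONE, Felt::new(0xAAAA_AFFF_0000_0000)]);
    assert_eq!(get_common_prefix_tier_depth(&key1, &key2), 16);
}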

/// Computes node value for leaves at tiers 16, 32, or 48.
///
/// Node value is computed as: hash(key || value, domain = depth).
pub fn hash_upper_leaf(key: RpoDigest, value: Word, depth: u8) -> RpoDigest {
    const NUM_UPPER_TIERS: usize = TieredSmt::TIER_DEPTHS.len() - 1;
    debug_assert!(TieredSmt::TIER_DEPTHS[..NUM_UPPER_TIERS].contains(&depth));
    Rpo256::merge_in_domain(&[key, value.into()], depth.into())
}

/// Computes node value for leaves at the bottom tier (depth 64).
///
/// Node value is computed as: hash([key_0, value_0, ..., key_n, value_n], domain=64).
///
/// TODO: when hashing in domain is implemented for `hash_elements()`, combine this function with
/// the `hash_upper_leaf()` function.
pub fn hash_bottom_leaf(values: &[(RpoDigest, Word)]) -> RpoDigest {
    let mut elements = Vec::with_capacity(values.len() * 8);
    for (key, val) in values.iter() {
        elements.extend_from_slice(key.as_elements());
        elements.extend_from_slice(val.as_slice());
    }
    // TODO: hash in domain
    Rpo256::hash_elements(&elements)
}
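
// NOTE: a small hypothetical check added by the editor, not part of the original change. It
// illustrates the domain-separation rule stated above: the same (key, value) pair hashes to
// different leaf nodes at different tiers because the depth is the hashing domain. `ONE` is
// assumed to be in scope as in tests.rs.
#[cfg(test)]
#[test]
fn upper_leaf_domain_sketch() {
    let key = RpoDigest::from([ONE, ONE, ONE, ONE]);
    let value = [ONE, ONE, ONE, ONE];
    // distinct domains (depths 16 and 32) yield distinct leaf nodes
    assert_ne!(hash_upper_leaf(key, value, 16), hash_upper_leaf(key, value, 32));
}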

419  src/merkle/tiered_smt/nodes.rs  Normal file
@@ -0,0 +1,419 @@
use super::{
    BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, LeafNodeIndex, MerkleError, MerklePath,
    NodeIndex, Rpo256, RpoDigest, Vec,
};

// CONSTANTS
// ================================================================================================

/// The number of levels between tiers.
const TIER_SIZE: u8 = super::TieredSmt::TIER_SIZE;

/// Depths at which leaves can exist in a tiered SMT.
const TIER_DEPTHS: [u8; 4] = super::TieredSmt::TIER_DEPTHS;

/// Maximum node depth. This is also the bottom tier of the tree.
const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;

// NODE STORE
// ================================================================================================

/// A store of nodes for a Tiered Sparse Merkle tree.
///
/// The store contains information about all nodes as well as information about which of the nodes
/// represent leaf nodes in a Tiered Sparse Merkle tree. In the current implementation, [BTreeSet]s
/// are used to determine the position of the leaves in the tree.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct NodeStore {
    nodes: BTreeMap<NodeIndex, RpoDigest>,
    upper_leaves: BTreeSet<NodeIndex>,
    bottom_leaves: BTreeSet<u64>,
}

impl NodeStore {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------
    /// Returns a new instance of [NodeStore] instantiated with the specified root node.
    ///
    /// The root node is assumed to be a root of an empty sparse Merkle tree.
    pub fn new(root_node: RpoDigest) -> Self {
        let mut nodes = BTreeMap::default();
        nodes.insert(NodeIndex::root(), root_node);

        Self {
            nodes,
            upper_leaves: BTreeSet::default(),
            bottom_leaves: BTreeSet::default(),
        }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns a node at the specified index.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The specified index depth is 0 or greater than 64.
    /// - The node with the specified index does not exist in the Merkle tree. This is possible
    ///   when a leaf node with the same index prefix exists at a tier higher than the requested
    ///   node.
    pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
        self.validate_node_access(index)?;
        Ok(self.get_node_unchecked(&index))
    }

    /// Returns a Merkle path from the node at the specified index to the root.
    ///
    /// The node itself is not included in the path.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The specified index depth is 0 or greater than 64.
    /// - The node with the specified index does not exist in the Merkle tree. This is possible
    ///   when a leaf node with the same index prefix exists at a tier higher than the node to
    ///   which the path is requested.
    pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
        self.validate_node_access(index)?;

        let mut path = Vec::with_capacity(index.depth() as usize);
        for _ in 0..index.depth() {
            let node = self.get_node_unchecked(&index.sibling());
            path.push(node);
            index.move_up();
        }

        Ok(path.into())
    }

    /// Returns a Merkle path to the node specified by the key together with a flag indicating
    /// whether this node is a leaf at depths 16, 32, or 48.
    pub fn get_proof(&self, key: &RpoDigest) -> (MerklePath, NodeIndex, bool) {
        let (index, leaf_exists) = self.get_leaf_index(key);
        let index: NodeIndex = index.into();
        let path = self.get_path(index).expect("failed to retrieve Merkle path for a node index");
        (path, index, leaf_exists)
    }

    /// Returns an index at which a leaf node for the specified key should be inserted.
    ///
    /// The second value in the returned tuple is set to true if the node at the returned index
    /// is already a leaf node.
    pub fn get_leaf_index(&self, key: &RpoDigest) -> (LeafNodeIndex, bool) {
        // traverse the tree from the root down checking nodes at tiers 16, 32, and 48. Return if
        // a node at any of the tiers is either a leaf or a root of an empty subtree.
        const NUM_UPPER_TIERS: usize = TIER_DEPTHS.len() - 1;
        for &tier_depth in TIER_DEPTHS[..NUM_UPPER_TIERS].iter() {
            let index = LeafNodeIndex::from_key(key, tier_depth);
            if self.upper_leaves.contains(&index) {
                return (index, true);
            } else if !self.nodes.contains_key(&index) {
                return (index, false);
            }
        }

        // if we got here, that means all of the nodes checked so far are internal nodes, and
        // the new node would need to be inserted in the bottom tier.
        let index = LeafNodeIndex::from_key(key, MAX_DEPTH);
        (index, self.bottom_leaves.contains(&index.value()))
    }

    /// Traverses the tree up from the bottom tier starting at the specified leaf index and
    /// returns the depth of the first node which has more than one child. The returned depth
    /// is rounded up to the next tier.
    pub fn get_last_single_child_parent_depth(&self, leaf_index: u64) -> u8 {
        let mut index = NodeIndex::new_unchecked(MAX_DEPTH, leaf_index);

        for _ in (TIER_DEPTHS[0]..MAX_DEPTH).rev() {
            let sibling_index = index.sibling();
            if self.nodes.contains_key(&sibling_index) {
                break;
            }
            index.move_up();
        }

        let tier = (index.depth() - 1) / TIER_SIZE;
        TIER_DEPTHS[tier as usize]
    }

    // ITERATORS
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over all inner nodes of the Tiered Sparse Merkle tree (i.e., nodes not
    /// at depths 16, 32, 48, or 64).
    ///
    /// The iterator order is unspecified.
    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
        self.nodes.iter().filter_map(|(index, node)| {
            if self.is_internal_node(index) {
                Some(InnerNodeInfo {
                    value: *node,
                    left: self.get_node_unchecked(&index.left_child()),
                    right: self.get_node_unchecked(&index.right_child()),
                })
            } else {
                None
            }
        })
    }

    /// Returns an iterator over the upper leaves (i.e., leaves with depths 16, 32, 48) of the
    /// Tiered Sparse Merkle tree.
    pub fn upper_leaves(&self) -> impl Iterator<Item = (&NodeIndex, &RpoDigest)> {
        self.upper_leaves.iter().map(|index| (index, &self.nodes[index]))
    }

    /// Returns an iterator over the bottom leaves (i.e., leaves with depth 64) of the Tiered
    /// Sparse Merkle tree.
    pub fn bottom_leaves(&self) -> impl Iterator<Item = (&u64, &RpoDigest)> {
        self.bottom_leaves.iter().map(|value| {
            let index = NodeIndex::new_unchecked(MAX_DEPTH, *value);
            (value, &self.nodes[&index])
        })
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Replaces the leaf node at the specified index with a tree consisting of two leaves located
    /// at the specified indexes. Recomputes and returns the new root.
    pub fn replace_leaf_with_subtree(
        &mut self,
        leaf_index: LeafNodeIndex,
        subtree_leaves: [(LeafNodeIndex, RpoDigest); 2],
    ) -> RpoDigest {
        debug_assert!(self.is_non_empty_leaf(&leaf_index));
        debug_assert!(!is_empty_root(&subtree_leaves[0].1));
        debug_assert!(!is_empty_root(&subtree_leaves[1].1));
        debug_assert_eq!(subtree_leaves[0].0.depth(), subtree_leaves[1].0.depth());
        debug_assert!(leaf_index.depth() < subtree_leaves[0].0.depth());

        self.upper_leaves.remove(&leaf_index);

        if subtree_leaves[0].0 == subtree_leaves[1].0 {
            // if the subtree is for a single node at depth 64, we only need to insert one node
            debug_assert_eq!(subtree_leaves[0].0.depth(), MAX_DEPTH);
            debug_assert_eq!(subtree_leaves[0].1, subtree_leaves[1].1);
            self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1)
        } else {
            self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1);
            self.insert_leaf_node(subtree_leaves[1].0, subtree_leaves[1].1)
        }
    }

    /// Replaces a subtree containing the retained and the removed leaf nodes, with a leaf node
    /// containing the retained leaf.
    ///
    /// This has the effect of deleting the node at the `removed_leaf` index from the tree,
    /// moving the node at the `retained_leaf` index up to the tier specified by `new_depth`.
    pub fn replace_subtree_with_leaf(
        &mut self,
        removed_leaf: LeafNodeIndex,
        retained_leaf: LeafNodeIndex,
        new_depth: u8,
        node: RpoDigest,
    ) -> RpoDigest {
        debug_assert!(!is_empty_root(&node));
        debug_assert!(self.is_non_empty_leaf(&removed_leaf));
        debug_assert!(self.is_non_empty_leaf(&retained_leaf));
        debug_assert_eq!(removed_leaf.depth(), retained_leaf.depth());
        debug_assert!(removed_leaf.depth() > new_depth);

        // remove the branches leading up to the tier to which the retained leaf is to be moved
        self.remove_branch(removed_leaf, new_depth);
        self.remove_branch(retained_leaf, new_depth);

        // compute the index of the common root for retained and removed leaves
        let mut new_index = retained_leaf;
        new_index.move_up_to(new_depth);

        // insert the node at the root index
        self.insert_leaf_node(new_index, node)
    }

    /// Inserts the specified node at the specified index; recomputes and returns the new root
    /// of the Tiered Sparse Merkle tree.
    ///
    /// This method assumes that the provided node is a non-empty value, and that there is no node
    /// at the specified index.
    pub fn insert_leaf_node(&mut self, index: LeafNodeIndex, mut node: RpoDigest) -> RpoDigest {
        debug_assert!(!is_empty_root(&node));
        debug_assert_eq!(self.nodes.get(&index), None);

        // mark the node as the leaf
        if index.depth() == MAX_DEPTH {
            self.bottom_leaves.insert(index.value());
        } else {
            self.upper_leaves.insert(index.into());
        };

        // insert the node and update the path from the node to the root
        let mut index: NodeIndex = index.into();
        for _ in 0..index.depth() {
            self.nodes.insert(index, node);
            let sibling = self.get_node_unchecked(&index.sibling());
            node = Rpo256::merge(&index.build_node(node, sibling));
            index.move_up();
        }

        // update the root
        self.nodes.insert(NodeIndex::root(), node);
        node
    }

    /// Updates the node at the specified index with the specified node value; recomputes and
    /// returns the new root of the Tiered Sparse Merkle tree.
    ///
    /// This method can accept `node` as either an empty or a non-empty value.
    pub fn update_leaf_node(&mut self, index: LeafNodeIndex, mut node: RpoDigest) -> RpoDigest {
        debug_assert!(self.is_non_empty_leaf(&index));

        // if the value we are updating the node to is a root of an empty tree, clear the leaf
        // flag for this node
        if node == EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize] {
            if index.depth() == MAX_DEPTH {
                self.bottom_leaves.remove(&index.value());
            } else {
                self.upper_leaves.remove(&index);
            }
        } else {
            debug_assert!(!is_empty_root(&node));
        }

        // update the path from the node to the root
        let mut index: NodeIndex = index.into();
        for _ in 0..index.depth() {
            if node == EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize] {
                self.nodes.remove(&index);
            } else {
                self.nodes.insert(index, node);
            }

            let sibling = self.get_node_unchecked(&index.sibling());
            node = Rpo256::merge(&index.build_node(node, sibling));
            index.move_up();
        }

        // update the root
        self.nodes.insert(NodeIndex::root(), node);
        node
    }

    /// Replaces the leaf node at the specified index with a root of an empty subtree; recomputes
    /// and returns the new root of the Tiered Sparse Merkle tree.
    pub fn clear_leaf_node(&mut self, index: LeafNodeIndex) -> RpoDigest {
        debug_assert!(self.is_non_empty_leaf(&index));
        let node = EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize];
        self.update_leaf_node(index, node)
    }

    /// Truncates a branch starting with the specified leaf at the bottom tier to the new depth.
    ///
    /// This involves removing the part of the branch below the new depth, and then inserting a
    /// new node at the new depth.
    pub fn truncate_branch(
        &mut self,
        leaf_index: u64,
        new_depth: u8,
        node: RpoDigest,
    ) -> RpoDigest {
        debug_assert!(self.bottom_leaves.contains(&leaf_index));

        let mut leaf_index = LeafNodeIndex::new(NodeIndex::new_unchecked(MAX_DEPTH, leaf_index));
        self.remove_branch(leaf_index, new_depth);

        leaf_index.move_up_to(new_depth);
        self.insert_leaf_node(leaf_index, node)
    }

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Returns true if the node at the specified index is a non-empty leaf node.
    fn is_non_empty_leaf(&self, index: &LeafNodeIndex) -> bool {
        if index.depth() == MAX_DEPTH {
            self.bottom_leaves.contains(&index.value())
        } else {
            self.upper_leaves.contains(index)
        }
    }

    /// Returns true if the node at the specified index is an internal node - i.e., there is
    /// no leaf at that node and the node does not belong to the bottom tier.
    fn is_internal_node(&self, index: &NodeIndex) -> bool {
        if index.depth() == MAX_DEPTH {
            false
        } else {
            !self.upper_leaves.contains(index)
        }
    }

    /// Checks if the specified index is valid in the context of this Merkle tree.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The specified index depth is 0 or greater than 64.
    /// - The node for the specified index does not exist in the Merkle tree. This is possible
    ///   when an ancestor of the specified index is a leaf node.
    fn validate_node_access(&self, index: NodeIndex) -> Result<(), MerkleError> {
        if index.is_root() {
            return Err(MerkleError::DepthTooSmall(index.depth()));
        } else if index.depth() > MAX_DEPTH {
            return Err(MerkleError::DepthTooBig(index.depth() as u64));
        } else {
            // make sure that there are no leaf nodes in the ancestors of the index; since leaf
            // nodes can live only at specific depths, we just need to check these depths.
            let tier = ((index.depth() - 1) / TIER_SIZE) as usize;
            let mut tier_index = index;
            for &depth in TIER_DEPTHS[..tier].iter().rev() {
                tier_index.move_up_to(depth);
                if self.upper_leaves.contains(&tier_index) {
                    return Err(MerkleError::NodeNotInSet(index));
                }
            }
        }

        Ok(())
    }

    /// Returns a node at the specified index. If the node does not exist at this index, a root
    /// for an empty subtree at the index's depth is returned.
    ///
    /// Unlike [NodeStore::get_node()] this does not perform any checks to verify that the
    /// returned node is valid in the context of this tree.
    fn get_node_unchecked(&self, index: &NodeIndex) -> RpoDigest {
        match self.nodes.get(index) {
            Some(node) => *node,
            None => EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize],
        }
    }

    /// Removes a sequence of nodes starting at the specified index and traversing the tree up to
    /// the specified depth. The node at the `end_depth` is also removed, and the appropriate leaf
    /// flag is cleared.
    ///
    /// This method does not update any other nodes and does not recompute the tree root.
    fn remove_branch(&mut self, index: LeafNodeIndex, end_depth: u8) {
        if index.depth() == MAX_DEPTH {
            self.bottom_leaves.remove(&index.value());
        } else {
            self.upper_leaves.remove(&index);
        }

        let mut index: NodeIndex = index.into();
        assert!(index.depth() > end_depth);
        for _ in 0..(index.depth() - end_depth + 1) {
            self.nodes.remove(&index);
            index.move_up()
        }
    }
}

// HELPER FUNCTIONS
// ================================================================================================

/// Returns true if the specified node is a root of an empty tree or an empty value ([ZERO; 4]).
fn is_empty_root(node: &RpoDigest) -> bool {
    EmptySubtreeRoots::empty_hashes(MAX_DEPTH).contains(node)
}
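
// NOTE: a hypothetical check added by the editor, not part of the original change. It illustrates
// that `is_empty_root` matches any node on the empty-subtree path, not only the depth-64 empty
// root, because it scans the full table of empty hashes.
#[cfg(test)]
#[test]
fn is_empty_root_sketch() {
    // roots of empty subtrees at any depth (here 0 and 16) are recognized as empty
    assert!(is_empty_root(&EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[0]));
    assert!(is_empty_root(&EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[16]));
}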

170  src/merkle/tiered_smt/proof.rs  Normal file
@@ -0,0 +1,170 @@
use super::{
    get_common_prefix_tier_depth, get_key_prefix, hash_bottom_leaf, hash_upper_leaf,
    EmptySubtreeRoots, LeafNodeIndex, MerklePath, RpoDigest, TieredSmtProofError, Vec, Word,
};

// CONSTANTS
// ================================================================================================

/// Maximum node depth. This is also the bottom tier of the tree.
const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;

/// Value of an empty leaf.
pub const EMPTY_VALUE: Word = super::TieredSmt::EMPTY_VALUE;

/// Depths at which leaves can exist in a tiered SMT.
pub const TIER_DEPTHS: [u8; 4] = super::TieredSmt::TIER_DEPTHS;

// TIERED SPARSE MERKLE TREE PROOF
// ================================================================================================

/// A proof which can be used to assert membership (or non-membership) of key-value pairs in a
/// Tiered Sparse Merkle tree.
///
/// The proof consists of a Merkle path and one or more key-value entries which describe the node
/// located at the base of the path. If the node at the base of the path resolves to [ZERO; 4],
/// the entries will contain a single item with the value set to [ZERO; 4].
#[derive(PartialEq, Eq, Debug)]
pub struct TieredSmtProof {
    path: MerklePath,
    entries: Vec<(RpoDigest, Word)>,
}

impl TieredSmtProof {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------
    /// Returns a new instance of [TieredSmtProof] instantiated from the specified path and
    /// entries.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The length of the path is not a valid tier depth (i.e., 16, 32, 48, or 64).
    /// - Entries is an empty vector.
    /// - Entries contains more than 1 item, but the length of the path is not 64.
    /// - Entries contains more than 1 item, and one of the items has the value set to [ZERO; 4].
    /// - Entries contains multiple items with keys which don't share the same 64-bit prefix.
    pub fn new<I>(path: MerklePath, entries: I) -> Result<Self, TieredSmtProofError>
    where
        I: IntoIterator<Item = (RpoDigest, Word)>,
    {
        let entries: Vec<(RpoDigest, Word)> = entries.into_iter().collect();

        if !TIER_DEPTHS.into_iter().any(|e| e == path.depth()) {
            return Err(TieredSmtProofError::NotATierPath(path.depth()));
        }

        if entries.is_empty() {
            return Err(TieredSmtProofError::EntriesEmpty);
        }

        if entries.len() > 1 {
            if path.depth() != MAX_DEPTH {
                return Err(TieredSmtProofError::MultipleEntriesOutsideLastTier);
            }

            let prefix = get_key_prefix(&entries[0].0);
            for entry in entries.iter().skip(1) {
                if entry.1 == EMPTY_VALUE {
                    return Err(TieredSmtProofError::EmptyValueNotAllowed);
                }
                let current = get_key_prefix(&entry.0);
                if prefix != current {
                    return Err(TieredSmtProofError::MismatchedPrefixes(prefix, current));
                }
            }
        }

        Ok(Self { path, entries })
    }

    // PROOF VERIFIER
    // --------------------------------------------------------------------------------------------

    /// Returns true if a Tiered Sparse Merkle tree with the specified root contains the provided
    /// key-value pair.
    ///
    /// Note: this method cannot be used to assert non-membership. That is, if false is returned,
    /// it does not mean that the provided key-value pair is not in the tree.
    pub fn verify_membership(&self, key: &RpoDigest, value: &Word, root: &RpoDigest) -> bool {
        // Handles the following scenarios:
        // - the value is set
        // - empty leaf: there is an explicit entry for the key with the empty value
        // - shared 64-bit prefix: the target key is not included in the entries list, so the
        //   value is implicitly the empty word
        let v = match self.entries.iter().find(|(k, _)| k == key) {
            Some((_, v)) => v,
            None => &EMPTY_VALUE,
        };

        // the value must match for the proof to be valid
        if v != value {
            return false;
        }

        // if the proof is for an empty value, we can verify it against any key which has a common
        // prefix with the key stored in entries, but the common prefix must not be shorter than
        // the path length
        if self.is_value_empty()
            && get_common_prefix_tier_depth(key, &self.entries[0].0) < self.path.depth()
        {
            return false;
        }

        // make sure the Merkle path resolves to the correct root
        root == &self.compute_root()
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the value associated with the specific key according to this proof, or None if
    /// this proof does not contain a value for the specified key.
    ///
    /// A key-value pair generated by using this method should pass the `verify_membership()`
    /// check.
    pub fn get(&self, key: &RpoDigest) -> Option<Word> {
        if self.is_value_empty() {
            let common_prefix_tier = get_common_prefix_tier_depth(key, &self.entries[0].0);
            if common_prefix_tier < self.path.depth() {
                None
            } else {
                Some(EMPTY_VALUE)
            }
        } else {
            self.entries.iter().find(|(k, _)| k == key).map(|(_, value)| *value)
        }
    }

    /// Computes the root of a Tiered Sparse Merkle tree to which this proof resolves.
    pub fn compute_root(&self) -> RpoDigest {
        let node = self.build_node();
        let index = LeafNodeIndex::from_key(&self.entries[0].0, self.path.depth());
        self.path
            .compute_root(index.value(), node)
            .expect("failed to compute Merkle path root")
    }

    /// Consumes the proof and returns its parts.
    pub fn into_parts(self) -> (MerklePath, Vec<(RpoDigest, Word)>) {
        (self.path, self.entries)
    }

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Returns true if the proof is for an empty value.
    fn is_value_empty(&self) -> bool {
        self.entries[0].1 == EMPTY_VALUE
    }

    /// Converts the entries contained in this proof into a node value for the node at the base
    /// of the path contained in this proof.
    fn build_node(&self) -> RpoDigest {
        let depth = self.path.depth();
        if self.is_value_empty() {
            EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[depth as usize]
        } else if depth == MAX_DEPTH {
            hash_bottom_leaf(&self.entries)
        } else {
            let (key, value) = self.entries[0];
            hash_upper_leaf(key, value, depth)
        }
    }
}
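
// NOTE: a hypothetical end-to-end sketch added by the editor, not part of the original change.
// It ties the proof API together: a proof produced by `TieredSmt::prove` carries the value for
// its key and verifies against the tree root, while a mismatched value fails verification.
// `ONE`, `ZERO`, and `Felt` are assumed to be in scope as they are in tests.rs.
#[cfg(test)]
#[test]
fn tsmt_proof_roundtrip_sketch() {
    let mut smt = super::TieredSmt::default();
    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(42)]);
    let value = [ONE, ONE, ONE, ONE];
    smt.insert(key, value);

    let proof = smt.prove(key);
    // the proof carries the value for the key it was generated for ...
    assert_eq!(proof.get(&key), Some(value));
    // ... and verifies against the current tree root
    assert!(proof.verify_membership(&key, &value, &smt.root()));
    // a different value for the same key must not verify
    assert!(!proof.verify_membership(&key, &[ZERO; 4], &smt.root()));
}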

968  src/merkle/tiered_smt/tests.rs  Normal file
@@ -0,0 +1,968 @@
|
||||
use super::{
|
||||
super::{super::ONE, super::WORD_SIZE, Felt, MerkleStore, EMPTY_WORD, ZERO},
|
||||
EmptySubtreeRoots, InnerNodeInfo, NodeIndex, Rpo256, RpoDigest, TieredSmt, Vec, Word,
|
||||
};
|
||||
|
||||
// INSERTION TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[test]
|
||||
fn tsmt_insert_one() {
|
||||
let mut smt = TieredSmt::default();
|
||||
let mut store = MerkleStore::default();
|
||||
|
||||
let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
|
||||
let value = [ONE; WORD_SIZE];
|
||||
|
||||
// since the tree is empty, the first node will be inserted at depth 16 and the index will be
|
||||
// 16 most significant bits of the key
|
||||
let index = NodeIndex::make(16, raw >> 48);
|
||||
let leaf_node = build_leaf_node(key, value, 16);
|
||||
let tree_root = store.set_node(smt.root(), index, leaf_node).unwrap().root;
|
||||
|
||||
smt.insert(key, value);
|
||||
|
||||
assert_eq!(smt.root(), tree_root);
|
||||
|
||||
// make sure the value was inserted, and the node is at the expected index
|
||||
assert_eq!(smt.get_value(key), value);
|
||||
assert_eq!(smt.get_node(index).unwrap(), leaf_node);
|
||||
|
||||
// make sure the paths we get from the store and the tree match
|
||||
let expected_path = store.get_path(tree_root, index).unwrap();
|
||||
assert_eq!(smt.get_path(index).unwrap(), expected_path.path);
|
||||
|
||||
// make sure inner nodes match
|
||||
let expected_nodes = get_non_empty_nodes(&store);
|
||||
let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
|
||||
assert_eq!(actual_nodes.len(), expected_nodes.len());
|
||||
actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));
|
||||
|
||||
// make sure leaves are returned correctly
|
||||
let mut leaves = smt.upper_leaves();
|
||||
assert_eq!(leaves.next(), Some((leaf_node, key, value)));
|
||||
assert_eq!(leaves.next(), None);
|
||||
}

#[test]
fn tsmt_insert_two_16() {
    let mut smt = TieredSmt::default();
    let mut store = MerkleStore::default();

    // --- insert the first value ---------------------------------------------
    let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let val_a = [ONE; WORD_SIZE];
    smt.insert(key_a, val_a);

    // --- insert the second value --------------------------------------------
    // the key for this value has the same 16-bit prefix as the key for the first value;
    // thus, on insertion, both values should be pushed to the depth 32 tier
    let raw_b = 0b_10101010_10101010_10011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let val_b = [Felt::new(2); WORD_SIZE];
    smt.insert(key_b, val_b);

    // --- build Merkle store with equivalent data ----------------------------
    let mut tree_root = get_init_root();
    let index_a = NodeIndex::make(32, raw_a >> 32);
    let leaf_node_a = build_leaf_node(key_a, val_a, 32);
    tree_root = store.set_node(tree_root, index_a, leaf_node_a).unwrap().root;

    let index_b = NodeIndex::make(32, raw_b >> 32);
    let leaf_node_b = build_leaf_node(key_b, val_b, 32);
    tree_root = store.set_node(tree_root, index_b, leaf_node_b).unwrap().root;

    // --- verify that data is consistent between store and tree --------------
    assert_eq!(smt.root(), tree_root);

    assert_eq!(smt.get_value(key_a), val_a);
    assert_eq!(smt.get_node(index_a).unwrap(), leaf_node_a);
    let expected_path = store.get_path(tree_root, index_a).unwrap().path;
    assert_eq!(smt.get_path(index_a).unwrap(), expected_path);

    assert_eq!(smt.get_value(key_b), val_b);
    assert_eq!(smt.get_node(index_b).unwrap(), leaf_node_b);
    let expected_path = store.get_path(tree_root, index_b).unwrap().path;
    assert_eq!(smt.get_path(index_b).unwrap(), expected_path);

    // make sure inner nodes match; the store contains more entries because it keeps track of
    // all prior state, so we don't check that the number of inner nodes is the same in both
    let expected_nodes = get_non_empty_nodes(&store);
    let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
    actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));

    // make sure leaves are returned correctly
    let mut leaves = smt.upper_leaves();
    assert_eq!(leaves.next(), Some((leaf_node_a, key_a, val_a)));
    assert_eq!(leaves.next(), Some((leaf_node_b, key_b, val_b)));
    assert_eq!(leaves.next(), None);
}

#[test]
fn tsmt_insert_two_32() {
    let mut smt = TieredSmt::default();
    let mut store = MerkleStore::default();

    // --- insert the first value ---------------------------------------------
    let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let val_a = [ONE; WORD_SIZE];
    smt.insert(key_a, val_a);

    // --- insert the second value --------------------------------------------
    // the key for this value has the same 32-bit prefix as the key for the first value;
    // thus, on insertion, both values should be pushed to the depth 48 tier
    let raw_b = 0b_10101010_10101010_00011111_11111111_00010110_10010011_11100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let val_b = [Felt::new(2); WORD_SIZE];
    smt.insert(key_b, val_b);

    // --- build Merkle store with equivalent data ----------------------------
    let mut tree_root = get_init_root();
    let index_a = NodeIndex::make(48, raw_a >> 16);
    let leaf_node_a = build_leaf_node(key_a, val_a, 48);
    tree_root = store.set_node(tree_root, index_a, leaf_node_a).unwrap().root;

    let index_b = NodeIndex::make(48, raw_b >> 16);
    let leaf_node_b = build_leaf_node(key_b, val_b, 48);
    tree_root = store.set_node(tree_root, index_b, leaf_node_b).unwrap().root;

    // --- verify that data is consistent between store and tree --------------
    assert_eq!(smt.root(), tree_root);

    assert_eq!(smt.get_value(key_a), val_a);
    assert_eq!(smt.get_node(index_a).unwrap(), leaf_node_a);
    let expected_path = store.get_path(tree_root, index_a).unwrap().path;
    assert_eq!(smt.get_path(index_a).unwrap(), expected_path);

    assert_eq!(smt.get_value(key_b), val_b);
    assert_eq!(smt.get_node(index_b).unwrap(), leaf_node_b);
    let expected_path = store.get_path(tree_root, index_b).unwrap().path;
    assert_eq!(smt.get_path(index_b).unwrap(), expected_path);

    // make sure inner nodes match; the store contains more entries because it keeps track of
    // all prior state, so we don't check that the number of inner nodes is the same in both
    let expected_nodes = get_non_empty_nodes(&store);
    let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
    actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));
}

#[test]
fn tsmt_insert_three() {
    let mut smt = TieredSmt::default();
    let mut store = MerkleStore::default();

    // --- insert the first value ---------------------------------------------
    let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let val_a = [ONE; WORD_SIZE];
    smt.insert(key_a, val_a);

    // --- insert the second value --------------------------------------------
    // the key for this value has the same 16-bit prefix as the key for the first value;
    // thus, on insertion, both values should be pushed to the depth 32 tier
    let raw_b = 0b_10101010_10101010_10011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let val_b = [Felt::new(2); WORD_SIZE];
    smt.insert(key_b, val_b);

    // --- insert the third value ---------------------------------------------
    // the key for this value has the same 16-bit prefix as the keys for the first two values;
    // thus, on insertion, it will be placed in the depth 32 tier, but will not affect the
    // locations of the other two values
    let raw_c = 0b_10101010_10101010_11011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    let val_c = [Felt::new(3); WORD_SIZE];
    smt.insert(key_c, val_c);

    // --- build Merkle store with equivalent data ----------------------------
    let mut tree_root = get_init_root();
    let index_a = NodeIndex::make(32, raw_a >> 32);
    let leaf_node_a = build_leaf_node(key_a, val_a, 32);
    tree_root = store.set_node(tree_root, index_a, leaf_node_a).unwrap().root;

    let index_b = NodeIndex::make(32, raw_b >> 32);
    let leaf_node_b = build_leaf_node(key_b, val_b, 32);
    tree_root = store.set_node(tree_root, index_b, leaf_node_b).unwrap().root;

    let index_c = NodeIndex::make(32, raw_c >> 32);
    let leaf_node_c = build_leaf_node(key_c, val_c, 32);
    tree_root = store.set_node(tree_root, index_c, leaf_node_c).unwrap().root;

    // --- verify that data is consistent between store and tree --------------
    assert_eq!(smt.root(), tree_root);

    assert_eq!(smt.get_value(key_a), val_a);
    assert_eq!(smt.get_node(index_a).unwrap(), leaf_node_a);
    let expected_path = store.get_path(tree_root, index_a).unwrap().path;
    assert_eq!(smt.get_path(index_a).unwrap(), expected_path);

    assert_eq!(smt.get_value(key_b), val_b);
    assert_eq!(smt.get_node(index_b).unwrap(), leaf_node_b);
    let expected_path = store.get_path(tree_root, index_b).unwrap().path;
    assert_eq!(smt.get_path(index_b).unwrap(), expected_path);

    assert_eq!(smt.get_value(key_c), val_c);
    assert_eq!(smt.get_node(index_c).unwrap(), leaf_node_c);
    let expected_path = store.get_path(tree_root, index_c).unwrap().path;
    assert_eq!(smt.get_path(index_c).unwrap(), expected_path);

    // make sure inner nodes match; the store contains more entries because it keeps track of
    // all prior state, so we don't check that the number of inner nodes is the same in both
    let expected_nodes = get_non_empty_nodes(&store);
    let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
    actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));
}

// UPDATE TESTS
// ================================================================================================

#[test]
fn tsmt_update() {
    let mut smt = TieredSmt::default();
    let mut store = MerkleStore::default();

    // --- insert a value into the tree ---------------------------------------
    let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
    let value_a = [ONE; WORD_SIZE];
    smt.insert(key, value_a);

    // --- update the value ---------------------------------------------------
    let value_b = [Felt::new(2); WORD_SIZE];
    smt.insert(key, value_b);

    // --- verify consistency -------------------------------------------------
    let mut tree_root = get_init_root();
    let index = NodeIndex::make(16, raw >> 48);
    let leaf_node = build_leaf_node(key, value_b, 16);
    tree_root = store.set_node(tree_root, index, leaf_node).unwrap().root;

    assert_eq!(smt.root(), tree_root);

    assert_eq!(smt.get_value(key), value_b);
    assert_eq!(smt.get_node(index).unwrap(), leaf_node);
    let expected_path = store.get_path(tree_root, index).unwrap().path;
    assert_eq!(smt.get_path(index).unwrap(), expected_path);

    // make sure inner nodes match; the store contains more entries because it keeps track of
    // all prior state, so we don't check that the number of inner nodes is the same in both
    let expected_nodes = get_non_empty_nodes(&store);
    let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
    actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));
}

// DELETION TESTS
// ================================================================================================

#[test]
fn tsmt_delete_16() {
    let mut smt = TieredSmt::default();

    // --- insert a value into the tree ---------------------------------------
    let smt0 = smt.clone();
    let raw_a = 0b_01010101_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // --- insert another value into the tree ---------------------------------
    let smt1 = smt.clone();
    let raw_b = 0b_01011111_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let value_b = [ONE, ONE, ONE, ZERO];
    smt.insert(key_b, value_b);

    // --- delete the last inserted value -------------------------------------
    assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b);
    assert_eq!(smt, smt1);

    // --- delete the first inserted value ------------------------------------
    assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a);
    assert_eq!(smt, smt0);
}

#[test]
fn tsmt_delete_32() {
    let mut smt = TieredSmt::default();

    // --- insert a value into the tree ---------------------------------------
    let smt0 = smt.clone();
    let raw_a = 0b_01010101_01101100_01111111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // --- insert another with the same 16-bit prefix into the tree -----------
    let smt1 = smt.clone();
    let raw_b = 0b_01010101_01101100_00111111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let value_b = [ONE, ONE, ONE, ZERO];
    smt.insert(key_b, value_b);

    // --- insert the 3rd value with the same 16-bit prefix into the tree -----
    let smt2 = smt.clone();
    let raw_c = 0b_01010101_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    let value_c = [ONE, ONE, ZERO, ZERO];
    smt.insert(key_c, value_c);

    // --- delete the last inserted value -------------------------------------
    assert_eq!(smt.insert(key_c, EMPTY_WORD), value_c);
    assert_eq!(smt, smt2);

    // --- delete the second inserted value -----------------------------------
    assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b);
    assert_eq!(smt, smt1);

    // --- delete the first inserted value ------------------------------------
    assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a);
    assert_eq!(smt, smt0);
}

#[test]
fn tsmt_delete_48_same_32_bit_prefix() {
    let mut smt = TieredSmt::default();

    // test the case when all values share the same 32-bit prefix

    // --- insert a value into the tree ---------------------------------------
    let smt0 = smt.clone();
    let raw_a = 0b_01010101_01010101_11111111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // --- insert another with the same 32-bit prefix into the tree -----------
    let smt1 = smt.clone();
    let raw_b = 0b_01010101_01010101_11111111_11111111_11010110_10010011_11100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let value_b = [ONE, ONE, ONE, ZERO];
    smt.insert(key_b, value_b);

    // --- insert the 3rd value with the same 32-bit prefix into the tree -----
    let smt2 = smt.clone();
    let raw_c = 0b_01010101_01010101_11111111_11111111_11110110_10010011_11100000_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    let value_c = [ONE, ONE, ZERO, ZERO];
    smt.insert(key_c, value_c);

    // --- delete the last inserted value -------------------------------------
    assert_eq!(smt.insert(key_c, EMPTY_WORD), value_c);
    assert_eq!(smt, smt2);

    // --- delete the second inserted value -----------------------------------
    assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b);
    assert_eq!(smt, smt1);

    // --- delete the first inserted value ------------------------------------
    assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a);
    assert_eq!(smt, smt0);
}

#[test]
fn tsmt_delete_48_mixed_prefix() {
    let mut smt = TieredSmt::default();

    // test the case when some values share a 32-bit prefix and others share a 16-bit prefix

    // --- insert a value into the tree ---------------------------------------
    let smt0 = smt.clone();
    let raw_a = 0b_01010101_01010101_11111111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // --- insert another with the same 16-bit prefix into the tree -----------
    let smt1 = smt.clone();
    let raw_b = 0b_01010101_01010101_01111111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let value_b = [ONE, ONE, ONE, ZERO];
    smt.insert(key_b, value_b);

    // --- insert a value with the same 32-bit prefix as the first value -----
    let smt2 = smt.clone();
    let raw_c = 0b_01010101_01010101_11111111_11111111_11010110_10010011_11100000_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    let value_c = [ONE, ONE, ZERO, ZERO];
    smt.insert(key_c, value_c);

    // --- insert another value with the same 32-bit prefix as the first value
    let smt3 = smt.clone();
    let raw_d = 0b_01010101_01010101_11111111_11111111_11110110_10010011_11100000_00000000_u64;
    let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_d)]);
    let value_d = [ONE, ZERO, ZERO, ZERO];
    smt.insert(key_d, value_d);

    // --- delete the inserted values one-by-one ------------------------------
    assert_eq!(smt.insert(key_d, EMPTY_WORD), value_d);
    assert_eq!(smt, smt3);

    assert_eq!(smt.insert(key_c, EMPTY_WORD), value_c);
    assert_eq!(smt, smt2);

    assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b);
    assert_eq!(smt, smt1);

    assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a);
    assert_eq!(smt, smt0);
}

#[test]
fn tsmt_delete_64() {
    let mut smt = TieredSmt::default();

    // test cases where values share 48-bit, 32-bit, and full 64-bit prefixes

    // --- insert a value into the tree ---------------------------------------
    let smt0 = smt.clone();
    let raw_a = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // --- insert a value with the same 48-bit prefix into the tree -----------
    let smt1 = smt.clone();
    let raw_b = 0b_01010101_01010101_11111111_11111111_10110101_10101010_10111100_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let value_b = [ONE, ONE, ONE, ZERO];
    smt.insert(key_b, value_b);

    // --- insert a value with the same 32-bit prefix into the tree -----------
    let smt2 = smt.clone();
    let raw_c = 0b_01010101_01010101_11111111_11111111_11111101_10101010_10111100_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    let value_c = [ONE, ONE, ZERO, ZERO];
    smt.insert(key_c, value_c);

    // --- insert a value with the same 64-bit prefix as the first value ------
    let smt3 = smt.clone();
    let raw_d = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000000_u64;
    let key_d = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_d)]);
    let value_d = [ONE, ZERO, ZERO, ZERO];
    smt.insert(key_d, value_d);

    // --- delete the inserted values one-by-one ------------------------------
    assert_eq!(smt.insert(key_d, EMPTY_WORD), value_d);
    assert_eq!(smt, smt3);

    assert_eq!(smt.insert(key_c, EMPTY_WORD), value_c);
    assert_eq!(smt, smt2);

    assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b);
    assert_eq!(smt, smt1);

    assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a);
    assert_eq!(smt, smt0);
}

#[test]
fn tsmt_delete_64_leaf_promotion() {
    let mut smt = TieredSmt::default();

    // --- delete from bottom tier (no promotion to upper tiers) --------------

    // insert a value into the tree
    let raw_a = 0b_01010101_01010101_11111111_11111111_10101010_10101010_11111111_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // insert another value with a key having the same 64-bit prefix
    let key_b = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
    let value_b = [ONE, ONE, ONE, ZERO];
    smt.insert(key_b, value_b);

    // insert a value with a key which shares the same 48-bit prefix
    let raw_c = 0b_01010101_01010101_11111111_11111111_10101010_10101010_00111111_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    let value_c = [ONE, ONE, ZERO, ZERO];
    smt.insert(key_c, value_c);

    // delete entry A and compare to the tree which was built from B and C
    smt.insert(key_a, EMPTY_WORD);

    let mut expected_smt = TieredSmt::default();
    expected_smt.insert(key_b, value_b);
    expected_smt.insert(key_c, value_c);
    assert_eq!(smt, expected_smt);

    // entries B and C should stay at depth 64
    assert_eq!(smt.nodes.get_leaf_index(&key_b).0.depth(), 64);
    assert_eq!(smt.nodes.get_leaf_index(&key_c).0.depth(), 64);

    // --- delete from bottom tier (promotion to depth 48) --------------------

    let mut smt = TieredSmt::default();
    smt.insert(key_a, value_a);
    smt.insert(key_b, value_b);

    // insert a value with a key which shares the same 32-bit prefix
    let raw_c = 0b_01010101_01010101_11111111_11111111_11101010_10101010_11111111_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    smt.insert(key_c, value_c);

    // delete entry A and compare to the tree which was built from B and C
    smt.insert(key_a, EMPTY_WORD);

    let mut expected_smt = TieredSmt::default();
    expected_smt.insert(key_b, value_b);
    expected_smt.insert(key_c, value_c);
    assert_eq!(smt, expected_smt);

    // entry B moves to depth 48, entry C stays at depth 48
    assert_eq!(smt.nodes.get_leaf_index(&key_b).0.depth(), 48);
    assert_eq!(smt.nodes.get_leaf_index(&key_c).0.depth(), 48);

    // --- delete from bottom tier (promotion to depth 32) --------------------

    let mut smt = TieredSmt::default();
    smt.insert(key_a, value_a);
    smt.insert(key_b, value_b);

    // insert a value with a key which shares the same 16-bit prefix
    let raw_c = 0b_01010101_01010101_01111111_11111111_10101010_10101010_11111111_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    smt.insert(key_c, value_c);

    // delete entry A and compare to the tree which was built from B and C
    smt.insert(key_a, EMPTY_WORD);

    let mut expected_smt = TieredSmt::default();
    expected_smt.insert(key_b, value_b);
    expected_smt.insert(key_c, value_c);
    assert_eq!(smt, expected_smt);

    // entry B moves to depth 32, entry C stays at depth 32
    assert_eq!(smt.nodes.get_leaf_index(&key_b).0.depth(), 32);
    assert_eq!(smt.nodes.get_leaf_index(&key_c).0.depth(), 32);

    // --- delete from bottom tier (promotion to depth 16) --------------------

    let mut smt = TieredSmt::default();
    smt.insert(key_a, value_a);
    smt.insert(key_b, value_b);

    // insert a value with a key which shares a prefix of fewer than 16 bits
    let raw_c = 0b_01010101_01010100_11111111_11111111_10101010_10101010_11111111_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    smt.insert(key_c, value_c);

    // delete entry A and compare to the tree which was built from B and C
    smt.insert(key_a, EMPTY_WORD);

    let mut expected_smt = TieredSmt::default();
    expected_smt.insert(key_b, value_b);
    expected_smt.insert(key_c, value_c);
    assert_eq!(smt, expected_smt);

    // entry B moves to depth 16, entry C stays at depth 16
    assert_eq!(smt.nodes.get_leaf_index(&key_b).0.depth(), 16);
    assert_eq!(smt.nodes.get_leaf_index(&key_c).0.depth(), 16);
}
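
// Added sketch (not part of the original tests): the tier a promoted leaf lands on is the
// smallest tier depth strictly greater than the length of the common key prefix with its
// closest remaining neighbor, which can be computed with XOR + leading_zeros. Assumes
// `TieredSmt::TIER_DEPTHS` is accessible from this module.
#[test]
fn promotion_tier_sketch() {
    fn target_tier(a: u64, b: u64) -> u8 {
        let common = (a ^ b).leading_zeros() as u8; // length of the shared key prefix
        *TieredSmt::TIER_DEPTHS.iter().find(|&&t| t > common).unwrap_or(&64)
    }

    // 33-bit common prefix -> depth 48 tier (the "promotion to depth 48" case above)
    let raw_a = 0b_01010101_01010101_11111111_11111111_10101010_10101010_11111111_00000000_u64;
    let raw_c = 0b_01010101_01010101_11111111_11111111_11101010_10101010_11111111_00000000_u64;
    assert_eq!(target_tier(raw_a, raw_c), 48);

    // 15-bit common prefix -> depth 16 tier (the "promotion to depth 16" case above)
    let raw_d = 0b_01010101_01010100_11111111_11111111_10101010_10101010_11111111_00000000_u64;
    assert_eq!(target_tier(raw_a, raw_d), 16);
}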

#[test]
fn test_order_sensitivity() {
    let raw = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000001_u64;
    let value = [ONE; WORD_SIZE];

    let key_1 = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
    let key_2 = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw)]);

    let mut smt_1 = TieredSmt::default();

    smt_1.insert(key_1, value);
    smt_1.insert(key_2, value);
    smt_1.insert(key_2, EMPTY_WORD);

    let mut smt_2 = TieredSmt::default();
    smt_2.insert(key_1, value);

    assert_eq!(smt_1.root(), smt_2.root());
}

// BOTTOM TIER TESTS
// ================================================================================================

#[test]
fn tsmt_bottom_tier() {
    let mut smt = TieredSmt::default();
    let mut store = MerkleStore::default();

    // common prefix for the keys
    let prefix = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;

    // --- insert the first value ---------------------------------------------
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(prefix)]);
    let val_a = [ONE; WORD_SIZE];
    smt.insert(key_a, val_a);

    // --- insert the second value --------------------------------------------
    // this key has the same 64-bit prefix, and thus both values should end up in the same
    // node at depth 64
    let key_b = RpoDigest::from([ZERO, ONE, ONE, Felt::new(prefix)]);
    let val_b = [Felt::new(2); WORD_SIZE];
    smt.insert(key_b, val_b);

    // --- build Merkle store with equivalent data ----------------------------
    let index = NodeIndex::make(64, prefix);
    // entries in a bottom leaf are sorted by key, comparing elements from the most significant
    // (the last) one down; the keys here differ only in the first element, so key_b is smaller
    // than key_a
    let leaf_node = build_bottom_leaf_node(&[key_b, key_a], &[val_b, val_a]);
    let mut tree_root = get_init_root();
    tree_root = store.set_node(tree_root, index, leaf_node).unwrap().root;

    // --- verify that data is consistent between store and tree --------------
    assert_eq!(smt.root(), tree_root);

    assert_eq!(smt.get_value(key_a), val_a);
    assert_eq!(smt.get_value(key_b), val_b);

    assert_eq!(smt.get_node(index).unwrap(), leaf_node);
    let expected_path = store.get_path(tree_root, index).unwrap().path;
    assert_eq!(smt.get_path(index).unwrap(), expected_path);

    // make sure inner nodes match; the store contains more entries because it keeps track of
    // all prior state, so we don't check that the number of inner nodes is the same in both
    let expected_nodes = get_non_empty_nodes(&store);
    let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
    actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));

    // make sure leaves are returned correctly
    let smt_clone = smt.clone();
    let mut leaves = smt_clone.bottom_leaves();
    assert_eq!(leaves.next(), Some((leaf_node, vec![(key_b, val_b), (key_a, val_a)])));
    assert_eq!(leaves.next(), None);

    // --- update a leaf at the bottom tier -------------------------------------------------------

    let val_a2 = [Felt::new(3); WORD_SIZE];
    assert_eq!(smt.insert(key_a, val_a2), val_a);

    let leaf_node = build_bottom_leaf_node(&[key_b, key_a], &[val_b, val_a2]);
    store.set_node(tree_root, index, leaf_node).unwrap();

    let expected_nodes = get_non_empty_nodes(&store);
    let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
    actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));

    let mut leaves = smt.bottom_leaves();
    assert_eq!(leaves.next(), Some((leaf_node, vec![(key_b, val_b), (key_a, val_a2)])));
    assert_eq!(leaves.next(), None);
}
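
// Added sketch (not part of the original tests): the key ordering used for bottom-tier leaves,
// written out with plain integers. Elements are compared from the most significant (last) one
// down, and with all higher elements equal the first element breaks the tie.
#[test]
fn bottom_leaf_key_order_sketch() {
    use core::cmp::Ordering;

    let key_hi = [1u64, 1, 1, 42]; // stands in for [ONE, ONE, ONE, prefix]
    let key_lo = [0u64, 1, 1, 42]; // stands in for [ZERO, ONE, ONE, prefix]

    // compare from the last element down, mirroring the digest ordering used by the value store
    let mut ord = Ordering::Equal;
    for (a, b) in key_hi.iter().zip(key_lo.iter()).rev() {
        if a != b {
            ord = a.cmp(b);
            break;
        }
    }
    assert_eq!(ord, Ordering::Greater);
}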

#[test]
fn tsmt_bottom_tier_two() {
    let mut smt = TieredSmt::default();
    let mut store = MerkleStore::default();

    // --- insert the first value ---------------------------------------------
    let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let val_a = [ONE; WORD_SIZE];
    smt.insert(key_a, val_a);

    // --- insert the second value --------------------------------------------
    // the key for this value has the same 48-bit prefix as the key for the first value;
    // thus, on insertion, the two values should end up in different nodes at depth 64
    let raw_b = 0b_10101010_10101010_00011111_11111111_10010110_10010011_01100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let val_b = [Felt::new(2); WORD_SIZE];
    smt.insert(key_b, val_b);

    // --- build Merkle store with equivalent data ----------------------------
    let mut tree_root = get_init_root();
    let index_a = NodeIndex::make(64, raw_a);
    let leaf_node_a = build_bottom_leaf_node(&[key_a], &[val_a]);
    tree_root = store.set_node(tree_root, index_a, leaf_node_a).unwrap().root;

    let index_b = NodeIndex::make(64, raw_b);
    let leaf_node_b = build_bottom_leaf_node(&[key_b], &[val_b]);
    tree_root = store.set_node(tree_root, index_b, leaf_node_b).unwrap().root;

    // --- verify that data is consistent between store and tree --------------
    assert_eq!(smt.root(), tree_root);

    assert_eq!(smt.get_value(key_a), val_a);
    assert_eq!(smt.get_node(index_a).unwrap(), leaf_node_a);
    let expected_path = store.get_path(tree_root, index_a).unwrap().path;
    assert_eq!(smt.get_path(index_a).unwrap(), expected_path);

    assert_eq!(smt.get_value(key_b), val_b);
    assert_eq!(smt.get_node(index_b).unwrap(), leaf_node_b);
    let expected_path = store.get_path(tree_root, index_b).unwrap().path;
    assert_eq!(smt.get_path(index_b).unwrap(), expected_path);

    // make sure inner nodes match; the store contains more entries because it keeps track of
    // all prior state, so we don't check that the number of inner nodes is the same in both
    let expected_nodes = get_non_empty_nodes(&store);
    let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
    actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));

    // make sure leaves are returned correctly
    let mut leaves = smt.bottom_leaves();
    assert_eq!(leaves.next(), Some((leaf_node_b, vec![(key_b, val_b)])));
    assert_eq!(leaves.next(), Some((leaf_node_a, vec![(key_a, val_a)])));
    assert_eq!(leaves.next(), None);
}
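
// Added sketch (not part of the original tests): a bottom-tier leaf hashes the flattened
// [key, value] element sequence of all pairs in the node, so a single-pair leaf is just the
// RPO hash of 8 field elements. `build_bottom_leaf_node` is the helper defined at the bottom
// of this file.
#[test]
fn bottom_leaf_hash_sketch() {
    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(7)]);
    let value = [ONE; WORD_SIZE];

    let mut elements = key.as_elements().to_vec();
    elements.extend_from_slice(&value);
    assert_eq!(elements.len(), 8);

    assert_eq!(Rpo256::hash_elements(&elements), build_bottom_leaf_node(&[key], &[value]));
}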

// GET PROOF TESTS
// ================================================================================================

/// Tests membership and non-membership proofs for a single element at depth 64.
#[test]
fn tsmt_get_proof_single_element_64() {
    let mut smt = TieredSmt::default();

    let raw_a = 0b_00000000_00000001_00000000_00000001_00000000_00000001_00000000_00000001_u64;
    let key_a = [ONE, ONE, ONE, raw_a.into()].into();
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // push element `a` to depth 64 by inserting another value that shares the 48-bit prefix
    let raw_b = 0b_00000000_00000001_00000000_00000001_00000000_00000001_00000000_00000000_u64;
    let key_b = [ONE, ONE, ONE, raw_b.into()].into();
    smt.insert(key_b, [ONE, ONE, ONE, ONE]);

    // verify the proof for element `a`
    let proof = smt.prove(key_a);
    assert!(proof.verify_membership(&key_a, &value_a, &smt.root()));

    // check that a value that is not inserted in the tree produces a valid membership proof
    // for the empty word
    let key = [ZERO, ZERO, ZERO, ZERO].into();
    let proof = smt.prove(key);
    assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root()));

    // check that a key that shares the 64-bit prefix with `a`, but is not inserted, also has a
    // valid membership proof for the empty word
    let key = [ONE, ONE, ZERO, raw_a.into()].into();
    let proof = smt.prove(key);
    assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root()));
}

#[test]
fn tsmt_get_proof() {
    let mut smt = TieredSmt::default();

    // --- insert a value into the tree ---------------------------------------
    let raw_a = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // --- insert a value with the same 48-bit prefix into the tree -----------
    let raw_b = 0b_01010101_01010101_11111111_11111111_10110101_10101010_10111100_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let value_b = [ONE, ONE, ONE, ZERO];
    smt.insert(key_b, value_b);

    let smt_alt = smt.clone();

    // --- insert a value with the same 32-bit prefix into the tree -----------
    let raw_c = 0b_01010101_01010101_11111111_11111111_11111101_10101010_10111100_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    let value_c = [ONE, ONE, ZERO, ZERO];
    smt.insert(key_c, value_c);

    // --- insert a value with the same 64-bit prefix as A into the tree ------
    let raw_d = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000000_u64;
    let key_d = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_d)]);
    let value_d = [ONE, ZERO, ZERO, ZERO];
    smt.insert(key_d, value_d);

    // at this point the tree looks as follows:
    // - A and D are located in the same node at depth 64.
    // - B is located at depth 64 and shares the same 48-bit prefix with A and D.
    // - C is located at depth 48 and shares the same 32-bit prefix with A, B, and D.

    // --- generate proof for key A and test that it verifies correctly -------
    let proof = smt.prove(key_a);
    assert!(proof.verify_membership(&key_a, &value_a, &smt.root()));

    assert!(!proof.verify_membership(&key_a, &value_b, &smt.root()));
    assert!(!proof.verify_membership(&key_a, &EMPTY_WORD, &smt.root()));
    assert!(!proof.verify_membership(&key_b, &value_a, &smt.root()));
    assert!(!proof.verify_membership(&key_a, &value_a, &smt_alt.root()));

    assert_eq!(proof.get(&key_a), Some(value_a));
    assert_eq!(proof.get(&key_b), None);

    // since A and D are stored in the same node, we should be able to use the proof to verify
    // membership of D
    assert!(proof.verify_membership(&key_d, &value_d, &smt.root()));
    assert_eq!(proof.get(&key_d), Some(value_d));

    // --- generate proof for key B and test that it verifies correctly -------
    let proof = smt.prove(key_b);
    assert!(proof.verify_membership(&key_b, &value_b, &smt.root()));

    assert!(!proof.verify_membership(&key_b, &value_a, &smt.root()));
    assert!(!proof.verify_membership(&key_b, &EMPTY_WORD, &smt.root()));
    assert!(!proof.verify_membership(&key_a, &value_b, &smt.root()));
    assert!(!proof.verify_membership(&key_b, &value_b, &smt_alt.root()));

    assert_eq!(proof.get(&key_b), Some(value_b));
    assert_eq!(proof.get(&key_a), None);

    // --- generate proof for key C and test that it verifies correctly -------
    let proof = smt.prove(key_c);
    assert!(proof.verify_membership(&key_c, &value_c, &smt.root()));

    assert!(!proof.verify_membership(&key_c, &value_a, &smt.root()));
    assert!(!proof.verify_membership(&key_c, &EMPTY_WORD, &smt.root()));
    assert!(!proof.verify_membership(&key_a, &value_c, &smt.root()));
    assert!(!proof.verify_membership(&key_c, &value_c, &smt_alt.root()));

    assert_eq!(proof.get(&key_c), Some(value_c));
    assert_eq!(proof.get(&key_b), None);

    // --- generate proof for key D and test that it verifies correctly -------
    let proof = smt.prove(key_d);
    assert!(proof.verify_membership(&key_d, &value_d, &smt.root()));

    assert!(!proof.verify_membership(&key_d, &value_b, &smt.root()));
    assert!(!proof.verify_membership(&key_d, &EMPTY_WORD, &smt.root()));
    assert!(!proof.verify_membership(&key_b, &value_d, &smt.root()));
    assert!(!proof.verify_membership(&key_d, &value_d, &smt_alt.root()));

    assert_eq!(proof.get(&key_d), Some(value_d));
    assert_eq!(proof.get(&key_b), None);

    // since A and D are stored in the same node, we should be able to use the proof to verify
    // membership of A
    assert!(proof.verify_membership(&key_a, &value_a, &smt.root()));
    assert_eq!(proof.get(&key_a), Some(value_a));

    // --- generate proof for an empty key at depth 64 ------------------------
    // this key has the same 48-bit prefix as A but is different from B
    let raw = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000011_u64;
    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);

    let proof = smt.prove(key);
    assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root()));

    assert!(!proof.verify_membership(&key, &value_a, &smt.root()));
    assert!(!proof.verify_membership(&key, &EMPTY_WORD, &smt_alt.root()));

    assert_eq!(proof.get(&key), Some(EMPTY_WORD));
    assert_eq!(proof.get(&key_b), None);

    // the same proof should verify against any key with the same 64-bit prefix
    let key2 = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw)]);
    assert!(proof.verify_membership(&key2, &EMPTY_WORD, &smt.root()));
    assert_eq!(proof.get(&key2), Some(EMPTY_WORD));

    // but verifying it against a key that shares only a 63-bit prefix (or shorter) should fail
    let raw3 = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000010_u64;
    let key3 = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw3)]);
    assert!(!proof.verify_membership(&key3, &EMPTY_WORD, &smt.root()));
    assert_eq!(proof.get(&key3), None);

    // --- generate proof for an empty key at depth 48 ------------------------
    // this key has the same 32-bit prefix as A, B, C, and D, but is different from C
    let raw = 0b_01010101_01010101_11111111_11111111_00110101_10101010_11111100_00000000_u64;
    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);

    let proof = smt.prove(key);
    assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root()));

    assert!(!proof.verify_membership(&key, &value_a, &smt.root()));
    assert!(!proof.verify_membership(&key, &EMPTY_WORD, &smt_alt.root()));

    assert_eq!(proof.get(&key), Some(EMPTY_WORD));
    assert_eq!(proof.get(&key_b), None);

    // the same proof should verify against any key with the same 48-bit prefix
    let raw2 = 0b_01010101_01010101_11111111_11111111_00110101_10101010_01111100_00000000_u64;
    let key2 = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw2)]);
    assert!(proof.verify_membership(&key2, &EMPTY_WORD, &smt.root()));
    assert_eq!(proof.get(&key2), Some(EMPTY_WORD));

    // but verifying it against a key that shares only a 47-bit prefix (or shorter) should fail
    let raw3 = 0b_01010101_01010101_11111111_11111111_00110101_10101011_11111100_00000000_u64;
    let key3 = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw3)]);
    assert!(!proof.verify_membership(&key3, &EMPTY_WORD, &smt.root()));
    assert_eq!(proof.get(&key3), None);
}

// ERROR TESTS
// ================================================================================================

#[test]
fn tsmt_node_not_available() {
    let mut smt = TieredSmt::default();

    let raw = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
    let value = [ONE; WORD_SIZE];

    // build an index which is just below the inserted leaf node
    let index = NodeIndex::make(17, raw >> 47);

    // since we haven't inserted the value yet, we should be able to get the node and the path
    // for this index
    assert!(smt.get_node(index).is_ok());
    assert!(smt.get_path(index).is_ok());

    smt.insert(key, value);

    // but once the node is inserted, everything under it should be unavailable
    assert!(smt.get_node(index).is_err());
    assert!(smt.get_path(index).is_err());

    let index = NodeIndex::make(32, raw >> 32);
    assert!(smt.get_node(index).is_err());
    assert!(smt.get_path(index).is_err());

    let index = NodeIndex::make(34, raw >> 30);
    assert!(smt.get_node(index).is_err());
    assert!(smt.get_path(index).is_err());

    let index = NodeIndex::make(50, raw >> 14);
    assert!(smt.get_node(index).is_err());
    assert!(smt.get_path(index).is_err());

    let index = NodeIndex::make(64, raw);
    assert!(smt.get_node(index).is_err());
    assert!(smt.get_path(index).is_err());
}

// HELPER FUNCTIONS
// ================================================================================================

fn get_init_root() -> RpoDigest {
    EmptySubtreeRoots::empty_hashes(64)[0]
}

fn build_leaf_node(key: RpoDigest, value: Word, depth: u8) -> RpoDigest {
    Rpo256::merge_in_domain(&[key, value.into()], depth.into())
}

fn build_bottom_leaf_node(keys: &[RpoDigest], values: &[Word]) -> RpoDigest {
    assert_eq!(keys.len(), values.len());

    let mut elements = Vec::with_capacity(keys.len());
    for (key, val) in keys.iter().zip(values.iter()) {
        elements.extend_from_slice(key.as_elements());
        elements.extend_from_slice(val.as_slice());
    }

    Rpo256::hash_elements(&elements)
}

fn get_non_empty_nodes(store: &MerkleStore) -> Vec<InnerNodeInfo> {
    store
        .inner_nodes()
        .filter(|node| !is_empty_subtree(&node.value))
        .collect::<Vec<_>>()
}

fn is_empty_subtree(node: &RpoDigest) -> bool {
    EmptySubtreeRoots::empty_hashes(255).contains(node)
}
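
// Added usage sketch (not part of the original tests; it exercises only APIs already used
// above): insert a value, read it back, and check a membership proof against the current root.
#[test]
fn tsmt_usage_sketch() {
    let mut smt = TieredSmt::default();
    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(12345)]);
    let value = [ONE; WORD_SIZE];

    smt.insert(key, value);
    assert_eq!(smt.get_value(key), value);

    let proof = smt.prove(key);
    assert!(proof.verify_membership(&key, &value, &smt.root()));

    // deleting (inserting the empty word) returns the previous value
    assert_eq!(smt.insert(key, EMPTY_WORD), value);
}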
src/merkle/tiered_smt/values.rs (new file, 584 lines)
@@ -0,0 +1,584 @@
use super::{get_key_prefix, BTreeMap, LeafNodeIndex, RpoDigest, StarkField, Vec, Word};
use crate::utils::vec;
use core::{
    cmp::{Ord, Ordering},
    ops::RangeBounds,
};
use winter_utils::collections::btree_map::Entry;

// CONSTANTS
// ================================================================================================

/// Depths at which leaves can exist in a tiered SMT.
const TIER_DEPTHS: [u8; 4] = super::TieredSmt::TIER_DEPTHS;

/// Maximum node depth. This is also the bottom tier of the tree.
const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;

// VALUE STORE
// ================================================================================================
/// A store for key-value pairs for a Tiered Sparse Merkle tree.
///
/// The store is organized in a [BTreeMap] where keys are the 64 most significant bits of a key,
/// and the values are the corresponding key-value pairs (or a list of key-value pairs if more
/// than a single key-value pair shares the same 64-bit prefix).
///
/// The store supports lookup by the full key (i.e. [RpoDigest]) as well as by the 64-bit key
/// prefix.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct ValueStore {
    values: BTreeMap<u64, StoreEntry>,
}

impl ValueStore {
    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns a reference to the value stored under the specified key, or None if there is no
    /// value associated with the specified key.
    pub fn get(&self, key: &RpoDigest) -> Option<&Word> {
        let prefix = get_key_prefix(key);
        self.values.get(&prefix).and_then(|entry| entry.get(key))
    }

    /// Returns the first key-value pair such that the key prefix is greater than or equal to
    /// the specified prefix.
    pub fn get_first(&self, prefix: u64) -> Option<&(RpoDigest, Word)> {
        self.range(prefix..).next()
    }

    /// Returns the first key-value pair such that the key prefix is greater than or equal to
    /// the specified prefix and the key is not equal to the exclude_key value.
    pub fn get_first_filtered(
        &self,
        prefix: u64,
        exclude_key: &RpoDigest,
    ) -> Option<&(RpoDigest, Word)> {
        self.range(prefix..).find(|(key, _)| key != exclude_key)
    }

    /// Returns a vector with key-value pairs for all keys with the specified 64-bit prefix, or
    /// None if no keys with the specified prefix are present in this store.
    pub fn get_all(&self, prefix: u64) -> Option<Vec<(RpoDigest, Word)>> {
        self.values.get(&prefix).map(|entry| match entry {
            StoreEntry::Single(kv_pair) => vec![*kv_pair],
            StoreEntry::List(kv_pairs) => kv_pairs.clone(),
        })
    }

    /// Returns information about a sibling of a leaf node with the specified index, but only if
    /// this is the only sibling the leaf has in some subtree starting at the first tier.
    ///
    /// For example, if `index` is an index at depth 32, and there is a leaf node at depth 32
    /// with the same root at depth 16 as `index`, we say that this leaf is a lone sibling.
    ///
    /// The returned tuple contains: the key-value pair of the sibling as well as the index of
    /// the node for the root of the common subtree in which both nodes are leaves.
    ///
    /// This method assumes that the key-value pair for the specified index has already been
    /// removed from the store.
    pub fn get_lone_sibling(
        &self,
        index: LeafNodeIndex,
    ) -> Option<(&RpoDigest, &Word, LeafNodeIndex)> {
        // iterate over tiers from top to bottom, looking at the tiers which are strictly above
        // the depth of the index; for an index at the bottom tier these are the tiers at depths
        // 16, 32, and 48. For each tier, check if the parent of the index at that tier contains
        // a single node. Since we cannot move nodes at depth 16 to a higher tier, leaves at the
        // first tier never have "lone siblings" (no tier lies strictly above them).
        for &tier_depth in TIER_DEPTHS.iter().filter(|&t| index.depth() > *t) {
            // compute the index of the root at a higher tier
            let mut parent_index = index;
            parent_index.move_up_to(tier_depth);

            // find the lone sibling, if any; we need to handle the "last node" at a given tier
            // separately to specify the bounds for the search correctly
            let start_prefix = parent_index.value() << (MAX_DEPTH - tier_depth);
            let sibling = if start_prefix.leading_ones() as u8 == tier_depth {
                let mut iter = self.range(start_prefix..);
                iter.next().filter(|_| iter.next().is_none())
            } else {
                let end_prefix = (parent_index.value() + 1) << (MAX_DEPTH - tier_depth);
                let mut iter = self.range(start_prefix..end_prefix);
                iter.next().filter(|_| iter.next().is_none())
            };

            if let Some((key, value)) = sibling {
                return Some((key, value, parent_index));
            }
        }

        None
    }

    /// Returns an iterator over all key-value pairs in this store.
    pub fn iter(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
        self.values.iter().flat_map(|(_, entry)| entry.iter())
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Inserts the specified key-value pair into this store and returns the value previously
    /// associated with the specified key.
    ///
    /// If no value was previously associated with the specified key, None is returned.
    pub fn insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
        let prefix = get_key_prefix(&key);
        match self.values.entry(prefix) {
            Entry::Occupied(mut entry) => entry.get_mut().insert(key, value),
            Entry::Vacant(entry) => {
                entry.insert(StoreEntry::new(key, value));
                None
            }
        }
    }

    /// Removes the key-value pair for the specified key from this store and returns the value
    /// associated with this key.
    ///
    /// If no value was associated with the specified key, None is returned.
    pub fn remove(&mut self, key: &RpoDigest) -> Option<Word> {
        let prefix = get_key_prefix(key);
        match self.values.entry(prefix) {
            Entry::Occupied(mut entry) => {
                let (value, remove_entry) = entry.get_mut().remove(key);
                if remove_entry {
                    entry.remove_entry();
                }
                value
            }
            Entry::Vacant(_) => None,
        }
    }

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over all key-value pairs contained in this store such that the most
    /// significant 64 bits of the key lie within the specified bounds.
    ///
    /// The order of iteration is from the smallest to the largest key.
    fn range<R: RangeBounds<u64>>(&self, bounds: R) -> impl Iterator<Item = &(RpoDigest, Word)> {
        self.values.range(bounds).flat_map(|(_, entry)| entry.iter())
    }
}
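
// Added usage sketch (not part of the original file; it exercises only APIs defined above):
// two keys sharing a 64-bit prefix collapse into a single map entry, while lookups and
// removals still resolve the full key.
#[test]
fn value_store_usage_sketch() {
    use crate::{Felt, ONE, ZERO};

    let mut store = ValueStore::default();
    let raw = 42_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
    let key_b = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw)]);

    assert!(store.insert(key_a, [ONE; 4]).is_none());
    assert!(store.insert(key_b, [ONE, ONE, ONE, ZERO]).is_none());

    // both pairs are retrievable by full key and, together, under the shared prefix
    assert_eq!(store.get(&key_a), Some(&[ONE; 4]));
    assert_eq!(store.get_all(raw).map(|pairs| pairs.len()), Some(2));

    // removing one pair leaves the other in place
    assert_eq!(store.remove(&key_b), Some([ONE, ONE, ONE, ZERO]));
    assert_eq!(store.get(&key_b), None);
}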

// VALUE NODE
// ================================================================================================

/// An entry in the [ValueStore].
///
/// An entry can contain either a single key-value pair or a vector of key-value pairs sorted by
/// key.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum StoreEntry {
    Single((RpoDigest, Word)),
    List(Vec<(RpoDigest, Word)>),
}

impl StoreEntry {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------
    /// Returns a new [StoreEntry] instantiated with a single key-value pair.
    pub fn new(key: RpoDigest, value: Word) -> Self {
        Self::Single((key, value))
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the value associated with the specified key, or None if this entry does not
    /// contain a value associated with the specified key.
    pub fn get(&self, key: &RpoDigest) -> Option<&Word> {
        match self {
            StoreEntry::Single(kv_pair) => {
                if kv_pair.0 == *key {
                    Some(&kv_pair.1)
                } else {
                    None
                }
            }
            StoreEntry::List(kv_pairs) => {
                match kv_pairs.binary_search_by(|kv_pair| cmp_digests(&kv_pair.0, key)) {
                    Ok(pos) => Some(&kv_pairs[pos].1),
                    Err(_) => None,
                }
            }
        }
    }

    /// Returns an iterator over all key-value pairs in this entry.
    pub fn iter(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
        EntryIterator { entry: self, pos: 0 }
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Inserts the specified key-value pair into this entry and returns the value previously
    /// associated with the specified key, or None if no value was associated with the specified
    /// key.
    ///
    /// If a new key is inserted, this will also transform a `Single` entry into a `List` entry.
    pub fn insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
        match self {
            StoreEntry::Single(kv_pair) => {
                // if the key is already in this entry, update the value and return
                if kv_pair.0 == key {
                    let old_value = kv_pair.1;
                    kv_pair.1 = value;
                    return Some(old_value);
                }

                // transform the entry into a list entry, and make sure the key-value pairs
                // are sorted by key
                let mut pairs = vec![*kv_pair, (key, value)];
                pairs.sort_by(|a, b| cmp_digests(&a.0, &b.0));

                *self = StoreEntry::List(pairs);
                None
            }
            StoreEntry::List(pairs) => {
                match pairs.binary_search_by(|kv_pair| cmp_digests(&kv_pair.0, &key)) {
                    Ok(pos) => {
                        let old_value = pairs[pos].1;
                        pairs[pos].1 = value;
                        Some(old_value)
                    }
                    Err(pos) => {
                        pairs.insert(pos, (key, value));
                        None
                    }
                }
            }
        }
    }

    /// Removes the key-value pair with the specified key from this entry, and returns the value
    /// of the removed pair. If the entry did not contain a key-value pair for the specified
    /// key, None is returned.
    ///
    /// If the last key-value pair was removed from the entry, the second tuple value will be
    /// set to true.
    pub fn remove(&mut self, key: &RpoDigest) -> (Option<Word>, bool) {
        match self {
            StoreEntry::Single(kv_pair) => {
                if kv_pair.0 == *key {
                    (Some(kv_pair.1), true)
                } else {
                    (None, false)
                }
            }
            StoreEntry::List(kv_pairs) => {
                match kv_pairs.binary_search_by(|kv_pair| cmp_digests(&kv_pair.0, key)) {
                    Ok(pos) => {
                        let kv_pair = kv_pairs.remove(pos);
                        if kv_pairs.len() == 1 {
                            *self = StoreEntry::Single(kv_pairs[0]);
                        }
                        (Some(kv_pair.1), false)
                    }
                    Err(_) => (None, false),
                }
            }
        }
    }
}
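
// Added sketch (not part of the original file): inserting a second distinct key upgrades a
// Single entry to a sorted List, and removing back down to one pair collapses it to Single
// again.
#[test]
fn store_entry_shape_sketch() {
    use crate::{Felt, ONE, ZERO};

    let raw = 1_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
    let key_b = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw)]);

    let mut entry = StoreEntry::new(key_a, [ONE; 4]);
    assert!(matches!(entry, StoreEntry::Single(_)));

    assert!(entry.insert(key_b, [ZERO; 4]).is_none());
    assert!(matches!(entry, StoreEntry::List(_)));

    // removing one of the two pairs collapses the entry back to Single
    let (removed, remove_entry) = entry.remove(&key_b);
    assert_eq!(removed, Some([ZERO; 4]));
    assert!(!remove_entry);
    assert!(matches!(entry, StoreEntry::Single(_)));
}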

/// A custom iterator over key-value pairs of a [StoreEntry].
///
/// For a `Single` entry this returns only one value, but for a `List` entry, this iterates over
/// the entire list of key-value pairs.
pub struct EntryIterator<'a> {
    entry: &'a StoreEntry,
    pos: usize,
}

impl<'a> Iterator for EntryIterator<'a> {
    type Item = &'a (RpoDigest, Word);

    fn next(&mut self) -> Option<Self::Item> {
        match self.entry {
            StoreEntry::Single(kv_pair) => {
                if self.pos == 0 {
                    self.pos = 1;
                    Some(kv_pair)
                } else {
                    None
                }
            }
            StoreEntry::List(kv_pairs) => {
                if self.pos >= kv_pairs.len() {
                    None
                } else {
                    let kv_pair = &kv_pairs[self.pos];
                    self.pos += 1;
                    Some(kv_pair)
                }
            }
        }
    }
}
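
// Added sketch (not part of the original file): iteration follows the entry's stored order,
// which for a List entry is ascending key order as maintained by `insert`.
#[test]
fn entry_iterator_order_sketch() {
    use crate::{Felt, ONE, ZERO};

    let raw = 9_u64;
    let key_small = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw)]);
    let key_large = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);

    // insert in large-then-small order; iteration still yields ascending key order
    let mut entry = StoreEntry::new(key_large, [ONE; 4]);
    assert!(entry.insert(key_small, [ZERO; 4]).is_none());

    let keys = entry.iter().map(|(key, _)| *key).collect::<Vec<_>>();
    assert_eq!(keys, vec![key_small, key_large]);
}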

// HELPER FUNCTIONS
// ================================================================================================

/// Compares two digests element-by-element using their integer representations starting with
/// the most significant element.
fn cmp_digests(d1: &RpoDigest, d2: &RpoDigest) -> Ordering {
    let d1 = Word::from(d1);
    let d2 = Word::from(d2);

    for (v1, v2) in d1.iter().zip(d2.iter()).rev() {
        let v1 = v1.as_int();
        let v2 = v2.as_int();
        if v1 != v2 {
            return v1.cmp(&v2);
        }
    }

    Ordering::Equal
}
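
// Added worked example (not part of the original file): the last digest element is the most
// significant, so a digest with a larger last element compares greater regardless of the
// earlier elements.
#[test]
fn cmp_digests_order_sketch() {
    use crate::{Felt, ONE, ZERO};

    let lo = RpoDigest::from([ONE, ONE, ONE, Felt::new(1)]);
    let hi = RpoDigest::from([ZERO, ZERO, ZERO, Felt::new(2)]);
    assert_eq!(cmp_digests(&lo, &hi), Ordering::Less);

    // with all higher elements equal, the first (least significant) element breaks the tie
    let mid = RpoDigest::from([ONE, ZERO, ZERO, Felt::new(2)]);
    assert_eq!(cmp_digests(&mid, &hi), Ordering::Greater);
}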

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::{LeafNodeIndex, RpoDigest, StoreEntry, ValueStore};
    use crate::{Felt, ONE, WORD_SIZE, ZERO};

    #[test]
    fn test_insert() {
        let mut store = ValueStore::default();

        // insert the first key-value pair into the store
        let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
        let value_a = [ONE; WORD_SIZE];

        assert!(store.insert(key_a, value_a).is_none());
        assert_eq!(store.values.len(), 1);

        let entry = store.values.get(&raw_a).unwrap();
        let expected_entry = StoreEntry::Single((key_a, value_a));
        assert_eq!(entry, &expected_entry);

        // insert a key-value pair with a different key into the store; since the keys are
        // different, another entry is added to the values map
        let raw_b = 0b_11111110_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
        let value_b = [ONE, ZERO, ONE, ZERO];

        assert!(store.insert(key_b, value_b).is_none());
        assert_eq!(store.values.len(), 2);

        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 = StoreEntry::Single((key_a, value_a));
        assert_eq!(entry1, &expected_entry1);

        let entry2 = store.values.get(&raw_b).unwrap();
        let expected_entry2 = StoreEntry::Single((key_b, value_b));
        assert_eq!(entry2, &expected_entry2);

        // insert a key-value pair with the same 64-bit key prefix as the first key; this should
        // transform the first entry into a List entry
        let key_c = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
        let value_c = [ONE, ONE, ZERO, ZERO];

        assert!(store.insert(key_c, value_c).is_none());
        assert_eq!(store.values.len(), 2);

        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 = StoreEntry::List(vec![(key_c, value_c), (key_a, value_a)]);
        assert_eq!(entry1, &expected_entry1);

        let entry2 = store.values.get(&raw_b).unwrap();
        let expected_entry2 = StoreEntry::Single((key_b, value_b));
        assert_eq!(entry2, &expected_entry2);

        // replace values for keys a and b
        let value_a2 = [ONE, ONE, ONE, ZERO];
        let value_b2 = [ZERO, ZERO, ZERO, ONE];

        assert_eq!(store.insert(key_a, value_a2), Some(value_a));
        assert_eq!(store.values.len(), 2);

        assert_eq!(store.insert(key_b, value_b2), Some(value_b));
        assert_eq!(store.values.len(), 2);

        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 = StoreEntry::List(vec![(key_c, value_c), (key_a, value_a2)]);
        assert_eq!(entry1, &expected_entry1);

        let entry2 = store.values.get(&raw_b).unwrap();
        let expected_entry2 = StoreEntry::Single((key_b, value_b2));
        assert_eq!(entry2, &expected_entry2);

        // insert one more key-value pair with the same 64-bit key prefix as the first key
        let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
        let value_d = [ZERO, ONE, ZERO, ZERO];

        assert!(store.insert(key_d, value_d).is_none());
        assert_eq!(store.values.len(), 2);

        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 =
            StoreEntry::List(vec![(key_c, value_c), (key_a, value_a2), (key_d, value_d)]);
        assert_eq!(entry1, &expected_entry1);

        let entry2 = store.values.get(&raw_b).unwrap();
        let expected_entry2 = StoreEntry::Single((key_b, value_b2));
        assert_eq!(entry2, &expected_entry2);
    }

    #[test]
    fn test_remove() {
        // populate the value store
        let mut store = ValueStore::default();

        let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
        let value_a = [ONE; WORD_SIZE];
        store.insert(key_a, value_a);

        let raw_b = 0b_11111110_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
        let value_b = [ONE, ZERO, ONE, ZERO];
        store.insert(key_b, value_b);

        let key_c = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
        let value_c = [ONE, ONE, ZERO, ZERO];
        store.insert(key_c, value_c);

        let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
        let value_d = [ZERO, ONE, ZERO, ZERO];
        store.insert(key_d, value_d);

        assert_eq!(store.values.len(), 2);

        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 =
            StoreEntry::List(vec![(key_c, value_c), (key_a, value_a), (key_d, value_d)]);
        assert_eq!(entry1, &expected_entry1);

        let entry2 = store.values.get(&raw_b).unwrap();
        let expected_entry2 = StoreEntry::Single((key_b, value_b));
        assert_eq!(entry2, &expected_entry2);

        // remove non-existent keys
        let key_e = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_a)]);
        assert!(store.remove(&key_e).is_none());

        let raw_f = 0b_11111110_11111111_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_f = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_f)]);
        assert!(store.remove(&key_f).is_none());

        // remove keys from the list entry
        assert_eq!(store.remove(&key_c).unwrap(), value_c);
        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 = StoreEntry::List(vec![(key_a, value_a), (key_d, value_d)]);
        assert_eq!(entry1, &expected_entry1);

        assert_eq!(store.remove(&key_a).unwrap(), value_a);
        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 = StoreEntry::Single((key_d, value_d));
        assert_eq!(entry1, &expected_entry1);

        assert_eq!(store.remove(&key_d).unwrap(), value_d);
        assert!(store.values.get(&raw_a).is_none());
        assert_eq!(store.values.len(), 1);

        // remove a key from a single entry
        assert_eq!(store.remove(&key_b).unwrap(), value_b);
        assert!(store.values.get(&raw_b).is_none());
        assert_eq!(store.values.len(), 0);
    }

    #[test]
    fn test_range() {
        // populate the value store
        let mut store = ValueStore::default();

        let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
        let value_a = [ONE; WORD_SIZE];
        store.insert(key_a, value_a);

        let raw_b = 0b_11111110_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
        let value_b = [ONE, ZERO, ONE, ZERO];
        store.insert(key_b, value_b);

        let key_c = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
        let value_c = [ONE, ONE, ZERO, ZERO];
        store.insert(key_c, value_c);

        let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
        let value_d = [ZERO, ONE, ZERO, ZERO];
        store.insert(key_d, value_d);

        let raw_e = 0b_10101000_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_e = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_e)]);
        let value_e = [ZERO, ZERO, ZERO, ONE];
        store.insert(key_e, value_e);

        // check the entire range
        let mut iter = store.range(..u64::MAX);
        assert_eq!(iter.next(), Some(&(key_e, value_e)));
        assert_eq!(iter.next(), Some(&(key_c, value_c)));
        assert_eq!(iter.next(), Some(&(key_a, value_a)));
        assert_eq!(iter.next(), Some(&(key_d, value_d)));
        assert_eq!(iter.next(), Some(&(key_b, value_b)));
        assert_eq!(iter.next(), None);

        // check all but e
        let mut iter = store.range(raw_a..u64::MAX);
        assert_eq!(iter.next(), Some(&(key_c, value_c)));
        assert_eq!(iter.next(), Some(&(key_a, value_a)));
        assert_eq!(iter.next(), Some(&(key_d, value_d)));
        assert_eq!(iter.next(), Some(&(key_b, value_b)));
        assert_eq!(iter.next(), None);

        // check all but e and b
        let mut iter = store.range(raw_a..raw_b);
        assert_eq!(iter.next(), Some(&(key_c, value_c)));
        assert_eq!(iter.next(), Some(&(key_a, value_a)));
        assert_eq!(iter.next(), Some(&(key_d, value_d)));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn test_get_lone_sibling() {
        // populate the value store
        let mut store = ValueStore::default();

        let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
        let value_a = [ONE; WORD_SIZE];
        store.insert(key_a, value_a);

        let raw_b = 0b_11111111_11111111_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
        let value_b = [ONE, ZERO, ONE, ZERO];
        store.insert(key_b, value_b);

        // check sibling node for `a`
        let index = LeafNodeIndex::make(32, 0b_10101010_10101010_00011111_11111110);
        let parent_index = LeafNodeIndex::make(16, 0b_10101010_10101010);
        assert_eq!(store.get_lone_sibling(index), Some((&key_a, &value_a, parent_index)));

        // check sibling node for `b`
        let index = LeafNodeIndex::make(32, 0b_11111111_11111111_00011111_11111111);
        let parent_index = LeafNodeIndex::make(16, 0b_11111111_11111111);
        assert_eq!(store.get_lone_sibling(index), Some((&key_b, &value_b, parent_index)));

        // check some other sibling for some other index
        let index = LeafNodeIndex::make(32, 0b_11101010_10101010);
        assert_eq!(store.get_lone_sibling(index), None);
    }
}
src/rand/mod.rs (new file, 19 lines)
@@ -0,0 +1,19 @@
//! Pseudo-random element generation.

pub use winter_crypto::{DefaultRandomCoin as WinterRandomCoin, RandomCoin, RandomCoinError};

use crate::{Felt, FieldElement, StarkField, Word, ZERO};

mod rpo;
pub use rpo::RpoRandomCoin;

/// Pseudo-random element generator.
///
/// An instance can be used to draw, uniformly at random, base field elements as well as [Word]s.
pub trait FeltRng {
    /// Draw, uniformly at random, a base field element.
    fn draw_element(&mut self) -> Felt;

    /// Draw, uniformly at random, a [Word].
    fn draw_word(&mut self) -> Word;
}
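
// A short usage sketch: any [FeltRng] implementation (e.g., the [RpoRandomCoin] re-exported
// above) can be driven through this trait without knowing the underlying construction.
#[cfg(test)]
#[allow(dead_code)]
fn felt_rng_usage_sketch() {
    let mut rng = RpoRandomCoin::new([ZERO; 4]);

    // draw a single base field element and then a full word
    let _element: Felt = rng.draw_element();
    let _word: Word = rng.draw_word();
}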
src/rand/rpo.rs (new file, 267 lines)
@@ -0,0 +1,267 @@
use super::{Felt, FeltRng, FieldElement, StarkField, Word, ZERO};
use crate::{
    hash::rpo::{Rpo256, RpoDigest},
    utils::{
        collections::Vec, string::ToString, vec, ByteReader, ByteWriter, Deserializable,
        DeserializationError, Serializable,
    },
};
pub use winter_crypto::{RandomCoin, RandomCoinError};

// CONSTANTS
// ================================================================================================

const STATE_WIDTH: usize = Rpo256::STATE_WIDTH;
const RATE_START: usize = Rpo256::RATE_RANGE.start;
const RATE_END: usize = Rpo256::RATE_RANGE.end;
const HALF_RATE_WIDTH: usize = (Rpo256::RATE_RANGE.end - Rpo256::RATE_RANGE.start) / 2;

// RPO RANDOM COIN
// ================================================================================================
/// A simplified version of the `SPONGE_PRG` reseedable pseudo-random number generator algorithm
/// described in <https://eprint.iacr.org/2011/499.pdf>.
///
/// The simplification is related to the following facts:
/// 1. A call to the reseed method implies one and only one call to the permutation function.
///    This is possible because in our case we never reseed with more than 4 field elements.
/// 2. As a result of the previous point, we don't make use of an input buffer to accumulate seed
///    material.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RpoRandomCoin {
    state: [Felt; STATE_WIDTH],
    current: usize,
}

impl RpoRandomCoin {
    /// Returns a new [RpoRandomCoin] initialized with the specified seed.
    pub fn new(seed: Word) -> Self {
        let mut state = [ZERO; STATE_WIDTH];

        for i in 0..HALF_RATE_WIDTH {
            state[RATE_START + i] += seed[i];
        }

        // Absorb
        Rpo256::apply_permutation(&mut state);

        RpoRandomCoin { state, current: RATE_START }
    }

    /// Returns an [RpoRandomCoin] instantiated from the provided components.
    ///
    /// # Panics
    /// Panics if `current` is smaller than 4 or greater than or equal to 12.
    pub fn from_parts(state: [Felt; STATE_WIDTH], current: usize) -> Self {
        assert!(
            (RATE_START..RATE_END).contains(&current),
            "current value outside of valid range"
        );
        Self { state, current }
    }

    /// Returns components of this random coin.
    pub fn into_parts(self) -> ([Felt; STATE_WIDTH], usize) {
        (self.state, self.current)
    }

    fn draw_basefield(&mut self) -> Felt {
        if self.current == RATE_END {
            Rpo256::apply_permutation(&mut self.state);
            self.current = RATE_START;
        }

        self.current += 1;
        self.state[self.current - 1]
    }
}

// RANDOM COIN IMPLEMENTATION
// ------------------------------------------------------------------------------------------------

impl RandomCoin for RpoRandomCoin {
    type BaseField = Felt;
    type Hasher = Rpo256;

    fn new(seed: &[Self::BaseField]) -> Self {
        let digest: Word = Rpo256::hash_elements(seed).into();
        Self::new(digest)
    }

    fn reseed(&mut self, data: RpoDigest) {
        // Reset buffer
        self.current = RATE_START;

        // Add the new seed material to the first half of the rate portion of the RPO state
        let data: Word = data.into();

        self.state[RATE_START] += data[0];
        self.state[RATE_START + 1] += data[1];
        self.state[RATE_START + 2] += data[2];
        self.state[RATE_START + 3] += data[3];

        // Absorb
        Rpo256::apply_permutation(&mut self.state);
    }

    fn check_leading_zeros(&self, value: u64) -> u32 {
        let value = Felt::new(value);
        let mut state_tmp = self.state;

        state_tmp[RATE_START] += value;

        Rpo256::apply_permutation(&mut state_tmp);

        let first_rate_element = state_tmp[RATE_START].as_int();
        first_rate_element.trailing_zeros()
    }

    fn draw<E: FieldElement<BaseField = Felt>>(&mut self) -> Result<E, RandomCoinError> {
        let ext_degree = E::EXTENSION_DEGREE;
        let mut result = vec![ZERO; ext_degree];
        for r in result.iter_mut().take(ext_degree) {
            *r = self.draw_basefield();
        }

        let result = E::slice_from_base_elements(&result);
        Ok(result[0])
    }

    fn draw_integers(
        &mut self,
        num_values: usize,
        domain_size: usize,
        nonce: u64,
    ) -> Result<Vec<usize>, RandomCoinError> {
        assert!(domain_size.is_power_of_two(), "domain size must be a power of two");
        assert!(num_values < domain_size, "number of values must be smaller than domain size");

        // absorb the nonce
        let nonce = Felt::new(nonce);
        self.state[RATE_START] += nonce;
        Rpo256::apply_permutation(&mut self.state);

        // reset the buffer
        self.current = RATE_START;

        // build a bitmask which maps field elements into the domain [0, domain_size)
        let v_mask = (domain_size - 1) as u64;

        // draw values from the PRNG until we collect as many values as specified by num_values
        let mut values = Vec::new();
        for _ in 0..1000 {
            // get the next pseudo-random field element
            let value = self.draw_basefield().as_int();

            // use the mask to get a value within the range
            let value = (value & v_mask) as usize;

            values.push(value);
            if values.len() == num_values {
                break;
            }
        }

        if values.len() < num_values {
            return Err(RandomCoinError::FailedToDrawIntegers(num_values, values.len(), 1000));
        }

        Ok(values)
    }
}
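
// A usage sketch for the [RandomCoin] interface implemented above: bind the coin to a
// commitment via `reseed`, then draw query positions from a power-of-two domain.
#[cfg(test)]
#[allow(dead_code)]
fn random_coin_usage_sketch() {
    let mut coin = RpoRandomCoin::new([ZERO; 4]);

    // absorb a previously computed commitment into the coin state
    let commitment = Rpo256::hash_elements(&[ZERO, ZERO]);
    coin.reseed(commitment);

    // draw 8 pseudo-random positions in the domain [0, 1024)
    let positions = coin.draw_integers(8, 1024, 0).unwrap();
    assert_eq!(positions.len(), 8);
}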

// FELT RNG IMPLEMENTATION
// ------------------------------------------------------------------------------------------------

impl FeltRng for RpoRandomCoin {
    fn draw_element(&mut self) -> Felt {
        self.draw_basefield()
    }

    fn draw_word(&mut self) -> Word {
        let mut output = [ZERO; 4];
        for o in output.iter_mut() {
            *o = self.draw_basefield();
        }
        output
    }
}

// SERIALIZATION
// ------------------------------------------------------------------------------------------------

impl Serializable for RpoRandomCoin {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        self.state.iter().for_each(|v| v.write_into(target));
        // casting to u8 is OK because `current` is always between 4 and 12.
        target.write_u8(self.current as u8);
    }
}

impl Deserializable for RpoRandomCoin {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let state = [
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
        ];
        let current = source.read_u8()? as usize;
        if !(RATE_START..RATE_END).contains(&current) {
            return Err(DeserializationError::InvalidValue(
                "current value outside of valid range".to_string(),
            ));
        }
        Ok(Self { state, current })
    }
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::{Deserializable, FeltRng, RpoRandomCoin, Serializable, ZERO};
    use crate::ONE;

    #[test]
    fn test_feltrng_felt() {
        let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
        let output = rpocoin.draw_element();

        let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
        let expected = rpocoin.draw_basefield();

        assert_eq!(output, expected);
    }

    #[test]
    fn test_feltrng_word() {
        let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
        let output = rpocoin.draw_word();

        let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
        let mut expected = [ZERO; 4];
        for o in expected.iter_mut() {
            *o = rpocoin.draw_basefield();
        }

        assert_eq!(output, expected);
    }

    #[test]
    fn test_feltrng_serialization() {
        let coin1 = RpoRandomCoin::from_parts([ONE; 12], 5);

        let bytes = coin1.to_bytes();
        let coin2 = RpoRandomCoin::read_from_bytes(&bytes).unwrap();
        assert_eq!(coin1, coin2);
    }
}
src/utils.rs (deleted file, 21 lines)
@@ -1,21 +0,0 @@
use super::Word;
use crate::utils::string::String;
use core::fmt::{self, Write};

// RE-EXPORTS
// ================================================================================================
pub use winter_utils::{
    collections, string, uninit_vector, ByteReader, ByteWriter, Deserializable,
    DeserializationError, Serializable, SliceReader,
};

/// Converts a [Word] into hex.
pub fn word_to_hex(w: &Word) -> Result<String, fmt::Error> {
    let mut s = String::new();

    for byte in w.iter().flat_map(|e| e.to_bytes()) {
        write!(s, "{byte:02x}")?;
    }

    Ok(s)
}
src/utils/diff.rs (new file, 31 lines)
@@ -0,0 +1,31 @@
/// A trait for computing the difference between two objects.
pub trait Diff<K: Ord + Clone, V: Clone> {
    /// The type that describes the difference between two objects.
    type DiffType;

    /// Returns a [Self::DiffType] object that represents the difference between this object and
    /// other.
    fn diff(&self, other: &Self) -> Self::DiffType;
}

/// A trait for applying the difference between two objects.
pub trait ApplyDiff<K: Ord + Clone, V: Clone> {
    /// The type that describes the difference between two objects.
    type DiffType;

    /// Applies the provided changes described by [Self::DiffType] to the object implementing
    /// this trait.
    fn apply(&mut self, diff: Self::DiffType);
}

/// A trait for applying the difference between two objects with the possibility of failure.
pub trait TryApplyDiff<K: Ord + Clone, V: Clone> {
    /// The type that describes the difference between two objects.
    type DiffType;

    /// An error type that can be returned if the changes cannot be applied.
    type Error;

    /// Applies the provided changes described by [Self::DiffType] to the object implementing
    /// this trait. Returns an error if the changes cannot be applied.
    fn try_apply(&mut self, diff: Self::DiffType) -> Result<(), Self::Error>;
}
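
// A hypothetical sketch of a fallible application: `StrictMap` and `MissingKey` are
// illustrative types (not part of this crate) whose `try_apply` rejects removals of keys
// that are not present.
#[cfg(test)]
#[allow(dead_code)]
mod try_apply_sketch {
    use super::TryApplyDiff;
    use winter_utils::collections::{BTreeMap, BTreeSet};

    struct StrictMap(BTreeMap<u64, u64>);
    struct MissingKey(u64);

    impl TryApplyDiff<u64, u64> for StrictMap {
        // a diff is modeled here as (updated pairs, removed keys)
        type DiffType = (BTreeMap<u64, u64>, BTreeSet<u64>);
        type Error = MissingKey;

        fn try_apply(&mut self, diff: Self::DiffType) -> Result<(), Self::Error> {
            let (updated, removed) = diff;

            // fail before mutating anything if a removal cannot be honored
            for k in removed.iter() {
                if !self.0.contains_key(k) {
                    return Err(MissingKey(*k));
                }
            }

            self.0.extend(updated);
            for k in removed {
                self.0.remove(&k);
            }
            Ok(())
        }
    }
}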
src/utils/kv_map.rs (new file, 504 lines)
@@ -0,0 +1,504 @@
use super::{collections::ApplyDiff, diff::Diff};
use core::cell::RefCell;
use winter_utils::{
    collections::{btree_map::IntoIter, BTreeMap, BTreeSet},
    Box,
};

// KEY-VALUE MAP TRAIT
// ================================================================================================

/// A trait that defines the interface for a key-value map.
pub trait KvMap<K: Ord + Clone, V: Clone>:
    Extend<(K, V)> + FromIterator<(K, V)> + IntoIterator<Item = (K, V)>
{
    fn get(&self, key: &K) -> Option<&V>;
    fn contains_key(&self, key: &K) -> bool;
    fn len(&self) -> usize;
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    fn insert(&mut self, key: K, value: V) -> Option<V>;
    fn remove(&mut self, key: &K) -> Option<V>;

    fn iter(&self) -> Box<dyn Iterator<Item = (&K, &V)> + '_>;
}

// BTREE MAP `KvMap` IMPLEMENTATION
// ================================================================================================

impl<K: Ord + Clone, V: Clone> KvMap<K, V> for BTreeMap<K, V> {
    fn get(&self, key: &K) -> Option<&V> {
        self.get(key)
    }

    fn contains_key(&self, key: &K) -> bool {
        self.contains_key(key)
    }

    fn len(&self) -> usize {
        self.len()
    }

    fn insert(&mut self, key: K, value: V) -> Option<V> {
        self.insert(key, value)
    }

    fn remove(&mut self, key: &K) -> Option<V> {
        self.remove(key)
    }

    fn iter(&self) -> Box<dyn Iterator<Item = (&K, &V)> + '_> {
        Box::new(self.iter())
    }
}
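
// A sketch of writing generic code against the [KvMap] trait, so that [BTreeMap] and the
// [RecordingMap] defined below are interchangeable.
#[cfg(test)]
#[allow(dead_code)]
fn count_even_keys<M: KvMap<u64, u64>>(map: &M) -> usize {
    // iterate through the boxed iterator provided by the trait
    map.iter().filter(|&(k, _)| *k % 2 == 0).count()
}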

// RECORDING MAP
// ================================================================================================

/// A [RecordingMap] that records read requests to the underlying key-value map.
///
/// The data recorder is used to generate a proof for read requests.
///
/// The [RecordingMap] is composed of three parts:
/// - `data`: which contains the current set of key-value pairs in the map.
/// - `updates`: which tracks keys for which values have been changed since the map was
///   instantiated. Updates include insertions, removals, and updates of values under existing
///   keys.
/// - `trace`: which contains the key-value pairs from the original data which have been accessed
///   since the map was instantiated.
#[derive(Debug, Default, Clone, Eq, PartialEq)]
pub struct RecordingMap<K, V> {
    data: BTreeMap<K, V>,
    updates: BTreeSet<K>,
    trace: RefCell<BTreeMap<K, V>>,
}

impl<K: Ord + Clone, V: Clone> RecordingMap<K, V> {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------
    /// Returns a new [RecordingMap] instance initialized with the provided key-value pairs.
    pub fn new(init: impl IntoIterator<Item = (K, V)>) -> Self {
        RecordingMap {
            data: init.into_iter().collect(),
            updates: BTreeSet::new(),
            trace: RefCell::new(BTreeMap::new()),
        }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    pub fn inner(&self) -> &BTreeMap<K, V> {
        &self.data
    }

    // FINALIZER
    // --------------------------------------------------------------------------------------------

    /// Consumes the [RecordingMap] and returns a ([BTreeMap], [BTreeMap]) tuple. The first
    /// element of the tuple is a map that represents the state of the map at the time
    /// `.finalize()` is called. The second element contains the key-value pairs from the initial
    /// data set that were read during recording.
    pub fn finalize(self) -> (BTreeMap<K, V>, BTreeMap<K, V>) {
        (self.data, self.trace.take())
    }

    // TEST HELPERS
    // --------------------------------------------------------------------------------------------

    #[cfg(test)]
    pub fn trace_len(&self) -> usize {
        self.trace.borrow().len()
    }

    #[cfg(test)]
    pub fn updates_len(&self) -> usize {
        self.updates.len()
    }
}

impl<K: Ord + Clone, V: Clone> KvMap<K, V> for RecordingMap<K, V> {
    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns a reference to the value associated with the given key if the value exists.
    ///
    /// If the key is part of the initial data set, the key access is recorded.
    fn get(&self, key: &K) -> Option<&V> {
        self.data.get(key).map(|value| {
            if !self.updates.contains(key) {
                self.trace.borrow_mut().insert(key.clone(), value.clone());
            }
            value
        })
    }

    /// Returns a boolean to indicate whether the given key exists in the data set.
    ///
    /// If the key is part of the initial data set, the key access is recorded.
    fn contains_key(&self, key: &K) -> bool {
        self.get(key).is_some()
    }

    /// Returns the number of key-value pairs in the data set.
    fn len(&self) -> usize {
        self.data.len()
    }

    // MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Inserts a key-value pair into the data set.
    ///
    /// If the key already exists in the data set, the value is updated and the old value is
    /// returned.
    fn insert(&mut self, key: K, value: V) -> Option<V> {
        let new_update = self.updates.insert(key.clone());
        self.data.insert(key.clone(), value).map(|old_value| {
            if new_update {
                self.trace.borrow_mut().insert(key, old_value.clone());
            }
            old_value
        })
    }

    /// Removes a key-value pair from the data set.
    ///
    /// If the key exists in the data set, the old value is returned.
    fn remove(&mut self, key: &K) -> Option<V> {
        self.data.remove(key).map(|old_value| {
            let new_update = self.updates.insert(key.clone());
            if new_update {
                self.trace.borrow_mut().insert(key.clone(), old_value.clone());
            }
            old_value
        })
    }

    // ITERATION
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over the key-value pairs in the data set.
    fn iter(&self) -> Box<dyn Iterator<Item = (&K, &V)> + '_> {
        Box::new(self.data.iter())
    }
}
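
// A usage sketch tying the pieces above together: reads against the initial data set are
// recorded in the trace, and `finalize` splits the final state from the proof of what was read.
#[cfg(test)]
#[allow(dead_code)]
fn recording_map_usage_sketch() {
    let map = RecordingMap::new([(1u64, 10u64), (2, 20)]);

    // reading key 1 records it in the trace; key 2 is never touched
    assert_eq!(map.get(&1), Some(&10));

    let (state, trace) = map.finalize();
    assert_eq!(state.len(), 2);
    assert!(trace.contains_key(&1) && !trace.contains_key(&2));
}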

impl<K: Clone + Ord, V: Clone> Extend<(K, V)> for RecordingMap<K, V> {
    fn extend<T: IntoIterator<Item = (K, V)>>(&mut self, iter: T) {
        iter.into_iter().for_each(move |(k, v)| {
            self.insert(k, v);
        });
    }
}

impl<K: Clone + Ord, V: Clone> FromIterator<(K, V)> for RecordingMap<K, V> {
    fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
        Self::new(iter)
    }
}

impl<K: Clone + Ord, V: Clone> IntoIterator for RecordingMap<K, V> {
    type Item = (K, V);
    type IntoIter = IntoIter<K, V>;

    fn into_iter(self) -> Self::IntoIter {
        self.data.into_iter()
    }
}

// KV MAP DIFF
// ================================================================================================
/// [KvMapDiff] stores the difference between two key-value maps.
///
/// The [KvMapDiff] is composed of two parts:
/// - `updated` - a map of key-value pairs that were updated in the second map compared to the
///   first map. This includes new key-value pairs.
/// - `removed` - a set of keys that were removed from the second map compared to the first map.
#[derive(Debug, Clone)]
pub struct KvMapDiff<K, V> {
    pub updated: BTreeMap<K, V>,
    pub removed: BTreeSet<K>,
}

impl<K, V> KvMapDiff<K, V> {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------
    /// Creates a new [KvMapDiff] instance.
    pub fn new() -> Self {
        KvMapDiff {
            updated: BTreeMap::new(),
            removed: BTreeSet::new(),
        }
    }
}

impl<K, V> Default for KvMapDiff<K, V> {
    fn default() -> Self {
        Self::new()
    }
}

impl<K: Ord + Clone, V: Clone + PartialEq, T: KvMap<K, V>> Diff<K, V> for T {
    type DiffType = KvMapDiff<K, V>;

    fn diff(&self, other: &T) -> Self::DiffType {
        let mut diff = KvMapDiff::default();
        for (k, v) in self.iter() {
            if let Some(other_value) = other.get(k) {
                if v != other_value {
                    diff.updated.insert(k.clone(), other_value.clone());
                }
            } else {
                diff.removed.insert(k.clone());
            }
        }
        for (k, v) in other.iter() {
            if self.get(k).is_none() {
                diff.updated.insert(k.clone(), v.clone());
            }
        }
        diff
    }
}

impl<K: Ord + Clone, V: Clone, T: KvMap<K, V>> ApplyDiff<K, V> for T {
    type DiffType = KvMapDiff<K, V>;

    fn apply(&mut self, diff: Self::DiffType) {
        for (k, v) in diff.updated {
            self.insert(k, v);
        }
        for k in diff.removed {
            self.remove(&k);
        }
    }
}
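
// A short end-to-end sketch of the blanket impls above: compute the diff between two
// [BTreeMap]s and replay it onto the first one.
#[cfg(test)]
#[allow(dead_code)]
fn diff_roundtrip_sketch() {
    let mut before: BTreeMap<u64, u64> = [(1, 1), (2, 2)].into_iter().collect();
    let after: BTreeMap<u64, u64> = [(2, 20), (3, 3)].into_iter().collect();

    // updated: {2: 20, 3: 3}; removed: {1}
    let diff = before.diff(&after);
    before.apply(diff);
    assert_eq!(before, after);
}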

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::*;

    const ITEMS: [(u64, u64); 5] = [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)];

    #[test]
    fn test_get_item() {
        // instantiate a recording map
        let map = RecordingMap::new(ITEMS.to_vec());

        // get a few items
        let get_items = [0, 1, 2];
        for key in get_items.iter() {
            map.get(key);
        }

        // convert the map into a proof
        let (_, proof) = map.finalize();

        // check that the proof contains the expected values
        for (key, value) in ITEMS.iter() {
            match get_items.contains(key) {
                true => assert_eq!(proof.get(key), Some(value)),
                false => assert_eq!(proof.get(key), None),
            }
        }
    }

    #[test]
    fn test_contains_key() {
        // instantiate a recording map
        let map = RecordingMap::new(ITEMS.to_vec());

        // check if the map contains a few items
        let get_items = [0, 1, 2];
        for key in get_items.iter() {
            map.contains_key(key);
        }

        // convert the map into a proof
        let (_, proof) = map.finalize();

        // check that the proof contains the expected values
        for (key, _) in ITEMS.iter() {
            match get_items.contains(key) {
                true => assert!(proof.contains_key(key)),
                false => assert!(!proof.contains_key(key)),
            }
        }
    }

    #[test]
    fn test_len() {
        // instantiate a recording map
        let mut map = RecordingMap::new(ITEMS.to_vec());
        // length of the map should be equal to the number of items
        assert_eq!(map.len(), ITEMS.len());

        // inserting an entry with a key that already exists should not change the length, but
        // it does add entries to the trace and update sets
        map.insert(4, 5);
        assert_eq!(map.len(), ITEMS.len());
        assert_eq!(map.trace_len(), 1);
        assert_eq!(map.updates_len(), 1);

        // inserting an entry with a new key should increase the length; it should also record
        // the key as an updated key, but the trace length does not change since old values were
        // not touched
        map.insert(5, 5);
        assert_eq!(map.len(), ITEMS.len() + 1);
        assert_eq!(map.trace_len(), 1);
        assert_eq!(map.updates_len(), 2);

        // get some items so that they are saved in the trace; this should record original items
        // in the trace, but should not affect the set of updates
        let get_items = [0, 1, 2];
        for key in get_items.iter() {
            map.contains_key(key);
        }
        assert_eq!(map.trace_len(), 4);
        assert_eq!(map.updates_len(), 2);

        // read the same items again; this should not have any effect on either length, trace,
        // or the set of updates
        let get_items = [0, 1, 2];
        for key in get_items.iter() {
            map.contains_key(key);
        }
        assert_eq!(map.trace_len(), 4);
        assert_eq!(map.updates_len(), 2);

        // read a newly inserted item; this should not affect either length, trace, or the set
        // of updates
        let _val = map.get(&5).unwrap();
        assert_eq!(map.trace_len(), 4);
        assert_eq!(map.updates_len(), 2);

        // update a newly inserted item; this should not affect either length, trace, or the
        // set of updates
        map.insert(5, 11);
        assert_eq!(map.trace_len(), 4);
        assert_eq!(map.updates_len(), 2);

        // Note: the length reported by the proof will be different from the length originally
        // reported by the map.
        let (_, proof) = map.finalize();

        // length of the proof should be equal to get_items + 1; the extra item is the original
        // value at key = 4u64
        assert_eq!(proof.len(), get_items.len() + 1);
    }

    #[test]
    fn test_iter() {
        let mut map = RecordingMap::new(ITEMS.to_vec());
        assert!(map.iter().all(|(x, y)| ITEMS.contains(&(*x, *y))));

        // when inserting an entry with a key that already exists, the iterator should return
        // the new value
        let new_value = 5;
        map.insert(4, new_value);
        assert_eq!(map.iter().count(), ITEMS.len());
        assert!(map.iter().all(|(x, y)| if x == &4 {
            y == &new_value
        } else {
            ITEMS.contains(&(*x, *y))
        }));
    }

    #[test]
    fn test_is_empty() {
        // instantiate an empty recording map
        let empty_map: RecordingMap<u64, u64> = RecordingMap::default();
        assert!(empty_map.is_empty());

        // instantiate a non-empty recording map
        let map = RecordingMap::new(ITEMS.to_vec());
        assert!(!map.is_empty());
    }

    #[test]
    fn test_remove() {
        let mut map = RecordingMap::new(ITEMS.to_vec());

        // remove an item that exists
        let key = 0;
        let value = map.remove(&key).unwrap();
        assert_eq!(value, ITEMS[0].1);
        assert_eq!(map.len(), ITEMS.len() - 1);
        assert_eq!(map.trace_len(), 1);
        assert_eq!(map.updates_len(), 1);

        // add the item back and then remove it again
        let key = 0;
        let value = 0;
        map.insert(key, value);
        let value = map.remove(&key).unwrap();
        assert_eq!(value, 0);
        assert_eq!(map.len(), ITEMS.len() - 1);
        assert_eq!(map.trace_len(), 1);
        assert_eq!(map.updates_len(), 1);

        // remove an item that does not exist
        let key = 100;
        let value = map.remove(&key);
        assert_eq!(value, None);
        assert_eq!(map.len(), ITEMS.len() - 1);
        assert_eq!(map.trace_len(), 1);
        assert_eq!(map.updates_len(), 1);

        // insert a new item and then remove it
        let key = 100;
        let value = 100;
        map.insert(key, value);
        let value = map.remove(&key).unwrap();
        assert_eq!(value, 100);
        assert_eq!(map.len(), ITEMS.len() - 1);
        assert_eq!(map.trace_len(), 1);
        assert_eq!(map.updates_len(), 2);

        // convert the map into a proof
        let (_, proof) = map.finalize();

        // check that the proof contains the expected values
        for (key, value) in ITEMS.iter() {
            match key {
                0 => assert_eq!(proof.get(key), Some(value)),
                _ => assert_eq!(proof.get(key), None),
            }
        }
    }

    #[test]
    fn test_kv_map_diff() {
        let mut initial_state = ITEMS.into_iter().collect::<BTreeMap<_, _>>();
        let mut map = RecordingMap::new(initial_state.clone());

        // remove an item that exists
        let key = 0;
        let _value = map.remove(&key).unwrap();

        // add a new item
        let key = 100;
        let value = 100;
        map.insert(key, value);

        // update an existing item
        let key = 1;
        let value = 100;
        map.insert(key, value);

        // compute a diff
        let diff = initial_state.diff(map.inner());
        assert!(diff.updated.len() == 2);
        assert!(diff.updated.iter().all(|(k, v)| [(100, 100), (1, 100)].contains(&(*k, *v))));
        assert!(diff.removed.len() == 1);
        assert!(diff.removed.first() == Some(&0));

        // apply the diff to the initial state and assert the contents are the same as the map
        initial_state.apply(diff);
        assert!(initial_state.iter().eq(map.iter()));
    }
}
src/utils/mod.rs (new file, 113 lines)
@@ -0,0 +1,113 @@
//! Utilities used in this crate which can also be generally useful downstream.

use super::{utils::string::String, Word};
use core::fmt::{self, Display, Write};

#[cfg(not(feature = "std"))]
pub use alloc::{format, vec};

#[cfg(feature = "std")]
pub use std::{format, vec};

mod diff;
mod kv_map;

// RE-EXPORTS
// ================================================================================================
pub use winter_utils::{
    string, uninit_vector, Box, ByteReader, ByteWriter, Deserializable, DeserializationError,
    Serializable, SliceReader,
};

pub mod collections {
    pub use super::diff::*;
    pub use super::kv_map::*;
    pub use winter_utils::collections::*;
}

// UTILITY FUNCTIONS
// ================================================================================================

/// Converts a [Word] into hex.
pub fn word_to_hex(w: &Word) -> Result<String, fmt::Error> {
    let mut s = String::new();

    for byte in w.iter().flat_map(|e| e.to_bytes()) {
        write!(s, "{byte:02x}")?;
    }

    Ok(s)
}
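
// A quick sketch of the helper above, assuming each field element serializes to 8 bytes: a
// [Word] of four elements renders to 64 hex characters.
#[cfg(test)]
#[allow(dead_code)]
fn word_to_hex_sketch() {
    use crate::{ONE, ZERO};

    let word: Word = [ONE, ZERO, ONE, ZERO];
    let hex = word_to_hex(&word).unwrap();
    assert_eq!(hex.len(), 64);
}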

/// Renders an array of bytes as hex into a String.
pub fn bytes_to_hex_string<const N: usize>(data: [u8; N]) -> String {
    let mut s = String::with_capacity(N + 2);

    s.push_str("0x");
    for byte in data.iter() {
        write!(s, "{byte:02x}").expect("formatting hex failed");
    }

    s
}

/// Defines errors which can occur during parsing of hexadecimal strings.
#[derive(Debug)]
pub enum HexParseError {
    InvalidLength { expected: usize, actual: usize },
    MissingPrefix,
    InvalidChar,
    OutOfRange,
}

impl Display for HexParseError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            HexParseError::InvalidLength { expected, actual } => {
                write!(f, "Hex encoded RpoDigest must have length 66, including the 0x prefix. expected {expected} got {actual}")
            }
            HexParseError::MissingPrefix => {
                write!(f, "Hex encoded RpoDigest must start with 0x prefix")
            }
            HexParseError::InvalidChar => {
                write!(f, "Hex encoded RpoDigest must contain characters [a-zA-Z0-9]")
            }
            HexParseError::OutOfRange => {
                write!(f, "Hex encoded values of an RpoDigest must be inside the field modulus")
            }
        }
    }
}

#[cfg(feature = "std")]
impl std::error::Error for HexParseError {}

/// Parses a hex string into an array of bytes of known size.
pub fn hex_to_bytes<const N: usize>(value: &str) -> Result<[u8; N], HexParseError> {
    let expected: usize = (N * 2) + 2;
    if value.len() != expected {
        return Err(HexParseError::InvalidLength { expected, actual: value.len() });
    }

    if !value.starts_with("0x") {
        return Err(HexParseError::MissingPrefix);
    }

    let mut data = value.bytes().skip(2).map(|v| match v {
        b'0'..=b'9' => Ok(v - b'0'),
        b'a'..=b'f' => Ok(v - b'a' + 10),
        b'A'..=b'F' => Ok(v - b'A' + 10),
        _ => Err(HexParseError::InvalidChar),
    });

    let mut decoded = [0u8; N];
    #[allow(clippy::needless_range_loop)]
    for pos in 0..N {
        // These `unwrap` calls are okay because the length was checked above
        let high: u8 = data.next().unwrap()?;
        let low: u8 = data.next().unwrap()?;
        decoded[pos] = (high << 4) + low;
    }

    Ok(decoded)
}
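
// A round-trip sketch tying the two hex helpers together: render bytes as hex and parse them
// back into the same array.
#[cfg(test)]
#[allow(dead_code)]
fn hex_roundtrip_sketch() {
    let data = [0x00u8, 0x01, 0xab, 0xff];
    let hex = bytes_to_hex_string(data);
    assert_eq!(hex, "0x0001abff");

    let parsed: [u8; 4] = hex_to_bytes(&hex).unwrap();
    assert_eq!(parsed, data);
}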