mirror of
https://github.com/arnaucube/miden-crypto.git
synced 2026-01-11 08:31:30 +01:00
Compare commits
40 Commits
v0.8.0
...
al-gkr-bas
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0f06fa30a9 | ||
|
|
0b074a795d | ||
|
|
862ccf54dd | ||
|
|
88bcdfd576 | ||
|
|
290894f497 | ||
|
|
4aac00884c | ||
|
|
2ef6f79656 | ||
|
|
5142e2fd31 | ||
|
|
9fb41337ec | ||
|
|
0296e05ccd | ||
|
|
499f97046d | ||
|
|
600feafe53 | ||
|
|
9d854f1fcb | ||
|
|
af76cb10d0 | ||
|
|
4758e0672f | ||
|
|
8bb080a91d | ||
|
|
e5f3b28645 | ||
|
|
29e0d07129 | ||
|
|
81a94ecbe7 | ||
|
|
223fbf887d | ||
|
|
9e77a7c9b7 | ||
|
|
894e20fe0c | ||
|
|
7ec7b06574 | ||
|
|
2499a8a2dd | ||
|
|
800994c69b | ||
|
|
26560605bf | ||
|
|
672340d0c2 | ||
|
|
8083b02aef | ||
|
|
ecb8719d45 | ||
|
|
4144f98560 | ||
|
|
c726050957 | ||
|
|
9239340888 | ||
|
|
97ee9298a4 | ||
|
|
bfae06e128 | ||
|
|
b4e2d63c10 | ||
|
|
9679329746 | ||
|
|
2bbea37dbe | ||
|
|
83000940da | ||
|
|
f44175e7a9 | ||
|
|
4cf8eebff5 |
138
.github/workflows/ci.yml
vendored
138
.github/workflows/ci.yml
vendored
@@ -4,74 +4,46 @@ on:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
types: [opened, reopened, synchronize]
|
||||
types: [opened, repoened, synchronize]
|
||||
|
||||
jobs:
|
||||
rustfmt:
|
||||
name: rustfmt ${{matrix.toolchain}} on ${{matrix.os}}
|
||||
runs-on: ${{matrix.os}}-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
toolchain: [nightly]
|
||||
os: [ubuntu]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Install minimal Rust with rustfmt
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: ${{matrix.toolchain}}
|
||||
components: rustfmt
|
||||
override: true
|
||||
- name: fmt
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: fmt
|
||||
args: --all -- --check
|
||||
|
||||
clippy:
|
||||
name: clippy ${{matrix.toolchain}} on ${{matrix.os}}
|
||||
runs-on: ${{matrix.os}}-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
toolchain: [nightly]
|
||||
os: [ubuntu]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
- name: Install minimal Rust with clippy
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: ${{matrix.toolchain}}
|
||||
components: clippy
|
||||
override: true
|
||||
- name: Clippy
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: clippy
|
||||
args: --all-targets -- -D clippy::all -D warnings
|
||||
- name: Clippy all features
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: clippy
|
||||
args: --all-targets --all-features -- -D clippy::all -D warnings
|
||||
|
||||
test:
|
||||
name: test ${{matrix.toolchain}} on ${{matrix.os}} with ${{matrix.features}}
|
||||
build:
|
||||
name: Build ${{matrix.toolchain}} on ${{matrix.os}} with ${{matrix.args}}
|
||||
runs-on: ${{matrix.os}}-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
toolchain: [stable, nightly]
|
||||
os: [ubuntu]
|
||||
features: ["--features default,serde", --no-default-features]
|
||||
timeout-minutes: 30
|
||||
target: [wasm32-unknown-unknown]
|
||||
args: [--no-default-features --target wasm32-unknown-unknown]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@main
|
||||
with:
|
||||
submodules: recursive
|
||||
- name: Install rust
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: ${{matrix.toolchain}}
|
||||
override: true
|
||||
- run: rustup target add ${{matrix.target}}
|
||||
- name: Test
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: build
|
||||
args: ${{matrix.args}}
|
||||
|
||||
test:
|
||||
name: Test ${{matrix.toolchain}} on ${{matrix.os}} with ${{matrix.features}}
|
||||
runs-on: ${{matrix.os}}-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
toolchain: [stable, nightly]
|
||||
os: [ubuntu]
|
||||
features: ["--features default,std,serde", --no-default-features]
|
||||
steps:
|
||||
- uses: actions/checkout@main
|
||||
with:
|
||||
submodules: recursive
|
||||
- name: Install rust
|
||||
@@ -85,49 +57,45 @@ jobs:
|
||||
command: test
|
||||
args: ${{matrix.features}}
|
||||
|
||||
no-std:
|
||||
name: build ${{matrix.toolchain}} no-std for wasm32-unknown-unknown
|
||||
clippy:
|
||||
name: Clippy with ${{matrix.features}}
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
toolchain: [stable, nightly]
|
||||
features: ["--features default,std,serde", --no-default-features]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@main
|
||||
with:
|
||||
submodules: recursive
|
||||
- name: Install rust
|
||||
- name: Install minimal nightly with clippy
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: ${{matrix.toolchain}}
|
||||
profile: minimal
|
||||
toolchain: nightly
|
||||
components: clippy
|
||||
override: true
|
||||
- run: rustup target add wasm32-unknown-unknown
|
||||
- name: Build
|
||||
- name: Clippy
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: build
|
||||
args: --no-default-features --target wasm32-unknown-unknown
|
||||
command: clippy
|
||||
args: --all ${{matrix.features}} -- -D clippy::all -D warnings
|
||||
|
||||
docs:
|
||||
name: Verify the docs on ${{matrix.toolchain}}
|
||||
rustfmt:
|
||||
name: rustfmt
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
toolchain: [stable]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: recursive
|
||||
- name: Install rust
|
||||
- uses: actions/checkout@main
|
||||
- name: Install minimal stable with rustfmt
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: ${{matrix.toolchain}}
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
components: rustfmt
|
||||
override: true
|
||||
- name: Check docs
|
||||
|
||||
- name: rustfmt
|
||||
uses: actions-rs/cargo@v1
|
||||
env:
|
||||
RUSTDOCFLAGS: -D warnings
|
||||
with:
|
||||
command: doc
|
||||
args: --verbose --all-features --keep-going
|
||||
command: fmt
|
||||
args: --all -- --check
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -11,6 +11,3 @@ Cargo.lock
|
||||
|
||||
# Generated by cmake
|
||||
cmake-build-*
|
||||
|
||||
# VS Code
|
||||
.vscode/
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
## 0.8.0 (2024-02-14)
|
||||
## 0.8.0 (TBD)
|
||||
|
||||
* Implemented the `PartialMmr` data structure (#195).
|
||||
* Updated Winterfell dependency to v0.7 (#200)
|
||||
* Implemented RPX hash function (#201).
|
||||
* Added `FeltRng` and `RpoRandomCoin` (#237).
|
||||
* Accelerated RPO/RPX hash functions using AVX512 instructions (#234).
|
||||
* Added `inner_nodes()` method to `PartialMmr` (#238).
|
||||
* Improved `PartialMmr::apply_delta()` (#242).
|
||||
* Refactored `SimpleSmt` struct (#245).
|
||||
* Replaced `TieredSmt` struct with `Smt` struct (#254, #277).
|
||||
* Updated Winterfell dependency to v0.8 (#275).
|
||||
|
||||
## 0.7.1 (2023-10-10)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ For example, a new change to the AIR crate might have the following message: `fe
|
||||
// ================================================================================
|
||||
```
|
||||
|
||||
- [Rustfmt](https://github.com/rust-lang/rustfmt) and [Clippy](https://github.com/rust-lang/rust-clippy) linting is included in CI pipeline. Anyways it's preferable to run linting locally before push:
|
||||
- [Rustfmt](https://github.com/rust-lang/rustfmt) and [Clippy](https://github.com/rust-lang/rust-clippy) linting is included in CI pipeline. Anyways it's prefferable to run linting locally before push:
|
||||
```
|
||||
cargo fix --allow-staged --allow-dirty --all-targets --all-features; cargo fmt; cargo clippy --workspace --all-targets --all-features -- -D warnings
|
||||
```
|
||||
|
||||
34
Cargo.toml
34
Cargo.toml
@@ -10,7 +10,7 @@ documentation = "https://docs.rs/miden-crypto/0.8.0"
|
||||
categories = ["cryptography", "no-std"]
|
||||
keywords = ["miden", "crypto", "hash", "merkle"]
|
||||
edition = "2021"
|
||||
rust-version = "1.75"
|
||||
rust-version = "1.73"
|
||||
|
||||
[[bin]]
|
||||
name = "miden-crypto"
|
||||
@@ -35,30 +35,26 @@ harness = false
|
||||
default = ["std"]
|
||||
executable = ["dep:clap", "dep:rand_utils", "std"]
|
||||
serde = ["dep:serde", "serde?/alloc", "winter_math/serde"]
|
||||
std = [
|
||||
"blake3/std",
|
||||
"dep:cc",
|
||||
"dep:libc",
|
||||
"winter_crypto/std",
|
||||
"winter_math/std",
|
||||
"winter_utils/std",
|
||||
]
|
||||
std = ["blake3/std", "dep:cc", "dep:libc", "winter_crypto/std", "winter_math/std", "winter_utils/std"]
|
||||
sve = ["std"]
|
||||
|
||||
[dependencies]
|
||||
blake3 = { version = "1.5", default-features = false }
|
||||
clap = { version = "4.5", features = ["derive"], optional = true }
|
||||
libc = { version = "0.2", default-features = false, optional = true }
|
||||
rand_utils = { version = "0.8", package = "winter-rand-utils", optional = true }
|
||||
serde = { version = "1.0", features = ["derive"], default-features = false, optional = true }
|
||||
winter_crypto = { version = "0.8", package = "winter-crypto", default-features = false }
|
||||
winter_math = { version = "0.8", package = "winter-math", default-features = false }
|
||||
winter_utils = { version = "0.8", package = "winter-utils", default-features = false }
|
||||
clap = { version = "4.4", features = ["derive"], optional = true }
|
||||
libc = { version = "0.2", default-features = false, optional = true }
|
||||
rand_utils = { version = "0.7", package = "winter-rand-utils", optional = true }
|
||||
serde = { version = "1.0", features = [ "derive" ], default-features = false, optional = true }
|
||||
winter_crypto = { version = "0.7", package = "winter-crypto", default-features = false }
|
||||
winter_math = { version = "0.7", package = "winter-math", default-features = false }
|
||||
winter_utils = { version = "0.7", package = "winter-utils", default-features = false }
|
||||
rayon = "1.8.0"
|
||||
rand = "0.8.4"
|
||||
rand_core = { version = "0.5", default-features = false }
|
||||
|
||||
[dev-dependencies]
|
||||
seq-macro = { version = "0.3" }
|
||||
criterion = { version = "0.5", features = ["html_reports"] }
|
||||
proptest = "1.4"
|
||||
rand_utils = { version = "0.8", package = "winter-rand-utils" }
|
||||
proptest = "1.3"
|
||||
rand_utils = { version = "0.7", package = "winter-rand-utils" }
|
||||
|
||||
[build-dependencies]
|
||||
cc = { version = "1.0", features = ["parallel"], optional = true }
|
||||
|
||||
12
README.md
12
README.md
@@ -19,7 +19,7 @@ For performance benchmarks of these hash functions and their comparison to other
|
||||
* `PartialMerkleTree`: a partial view of a Merkle tree where some sub-trees may not be known. This is similar to a collection of Merkle paths all resolving to the same root. The length of the paths can be at most 64.
|
||||
* `PartialMmr`: a partial view of a Merkle mountain range structure.
|
||||
* `SimpleSmt`: a Sparse Merkle Tree (with no compaction), mapping 64-bit keys to 4-element values.
|
||||
* `Smt`: a Sparse Merkle tree (with compaction at depth 64), mapping 4-element keys to 4-element values.
|
||||
* `TieredSmt`: a Sparse Merkle tree (with compaction), mapping 4-element keys to 4-element values.
|
||||
|
||||
The module also contains additional supporting components such as `NodeIndex`, `MerklePath`, and `MerkleError` to assist with tree indexation, opening proofs, and reporting inconsistent arguments/state.
|
||||
|
||||
@@ -46,16 +46,10 @@ Both of these features imply the use of [alloc](https://doc.rust-lang.org/alloc/
|
||||
|
||||
To compile with `no_std`, disable default features via `--no-default-features` flag.
|
||||
|
||||
### AVX2 acceleration
|
||||
On platforms with [AVX2](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) support, RPO and RPX hash function can be accelerated by using the vector processing unit. To enable AVX2 acceleration, the code needs to be compiled with the `avx2` target feature enabled. For example:
|
||||
```shell
|
||||
RUSTFLAGS="-C target-feature=+avx2" cargo build --release
|
||||
```
|
||||
|
||||
### SVE acceleration
|
||||
On platforms with [SVE](https://en.wikipedia.org/wiki/AArch64#Scalable_Vector_Extension_(SVE)) support, RPO and RPX hash function can be accelerated by using the vector processing unit. To enable SVE acceleration, the code needs to be compiled with the `sve` target feature enabled. For example:
|
||||
On platforms with [SVE](https://en.wikipedia.org/wiki/AArch64#Scalable_Vector_Extension_(SVE)) support, RPO hash function can be accelerated by using the vector processing unit. To enable SVE acceleration, the code needs to be compiled with the `sve` feature enabled. This feature has an effect only if the platform exposes `target-feature=sve` flag. On some platforms (e.g., Graviton 3), for this flag to be set, the compilation must be done in "native" mode. For example, to enable SVE acceleration on Graviton 3, we can execute the following:
|
||||
```shell
|
||||
RUSTFLAGS="-C target-feature=+sve" cargo build --release
|
||||
RUSTFLAGS="-C target-cpu=native" cargo build --release --features sve
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
@@ -22,7 +22,6 @@ The second scenario is that of sequential hashing where we take a sequence of le
|
||||
| Apple M2 Max | 71 ns | 233 ns | 1.3 µs | 7.9 µs | 4.6 µs | 2.4 µs |
|
||||
| Amazon Graviton 3 | 108 ns | | | | 5.3 µs | 3.1 µs |
|
||||
| AMD Ryzen 9 5950X | 64 ns | 273 ns | 1.2 µs | 9.1 µs | 5.5 µs | |
|
||||
| AMD EPYC 9R14 | 83 ns | | | | 4.3 µs | 2.4 µs |
|
||||
| Intel Core i5-8279U | 68 ns | 536 ns | 2.0 µs | 13.6 µs | 8.5 µs | 4.4 µs |
|
||||
| Intel Xeon 8375C | 67 ns | | | | 8.2 µs | |
|
||||
|
||||
@@ -34,13 +33,11 @@ The second scenario is that of sequential hashing where we take a sequence of le
|
||||
| Apple M2 Max | 0.9 µs | 1.5 µs | 17.4 µs | 103 µs | 60 µs | 31 µs |
|
||||
| Amazon Graviton 3 | 1.4 µs | | | | 69 µs | 41 µs |
|
||||
| AMD Ryzen 9 5950X | 0.8 µs | 1.7 µs | 15.7 µs | 120 µs | 72 µs | |
|
||||
| AMD EPYC 9R14 | 0.9 µs | | | | 56 µs | 32 µs |
|
||||
| Intel Core i5-8279U | 0.9 µs | | | | 107 µs | 56 µs |
|
||||
| Intel Xeon 8375C | 0.8 µs | | | | 110 µs | |
|
||||
|
||||
Notes:
|
||||
- On Graviton 3, RPO256 and RPX256 are run with SVE acceleration enabled.
|
||||
- On AMD EPYC 9R14, RPO256 and RPX256 are run with AVX2 acceleration enabled.
|
||||
|
||||
### Instructions
|
||||
Before you can run the benchmarks, you'll need to make sure you have Rust [installed](https://www.rust-lang.org/tools/install). After that, to run the benchmarks for RPO and BLAKE3, clone the current repository, and from the root directory of the repo run the following:
|
||||
|
||||
@@ -32,6 +32,7 @@ fn rpo256_2to1(c: &mut Criterion) {
|
||||
|
||||
fn rpo256_sequential(c: &mut Criterion) {
|
||||
let v: [Felt; 100] = (0..100)
|
||||
.into_iter()
|
||||
.map(Felt::new)
|
||||
.collect::<Vec<Felt>>()
|
||||
.try_into()
|
||||
@@ -44,6 +45,7 @@ fn rpo256_sequential(c: &mut Criterion) {
|
||||
bench.iter_batched(
|
||||
|| {
|
||||
let v: [Felt; 100] = (0..100)
|
||||
.into_iter()
|
||||
.map(|_| Felt::new(rand_value()))
|
||||
.collect::<Vec<Felt>>()
|
||||
.try_into()
|
||||
@@ -78,6 +80,7 @@ fn rpx256_2to1(c: &mut Criterion) {
|
||||
|
||||
fn rpx256_sequential(c: &mut Criterion) {
|
||||
let v: [Felt; 100] = (0..100)
|
||||
.into_iter()
|
||||
.map(Felt::new)
|
||||
.collect::<Vec<Felt>>()
|
||||
.try_into()
|
||||
@@ -90,6 +93,7 @@ fn rpx256_sequential(c: &mut Criterion) {
|
||||
bench.iter_batched(
|
||||
|| {
|
||||
let v: [Felt; 100] = (0..100)
|
||||
.into_iter()
|
||||
.map(|_| Felt::new(rand_value()))
|
||||
.collect::<Vec<Felt>>()
|
||||
.try_into()
|
||||
@@ -125,6 +129,7 @@ fn blake3_2to1(c: &mut Criterion) {
|
||||
|
||||
fn blake3_sequential(c: &mut Criterion) {
|
||||
let v: [Felt; 100] = (0..100)
|
||||
.into_iter()
|
||||
.map(Felt::new)
|
||||
.collect::<Vec<Felt>>()
|
||||
.try_into()
|
||||
@@ -137,6 +142,7 @@ fn blake3_sequential(c: &mut Criterion) {
|
||||
bench.iter_batched(
|
||||
|| {
|
||||
let v: [Felt; 100] = (0..100)
|
||||
.into_iter()
|
||||
.map(|_| Felt::new(rand_value()))
|
||||
.collect::<Vec<Felt>>()
|
||||
.try_into()
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
use core::mem::swap;
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
use miden_crypto::{
|
||||
merkle::{LeafIndex, SimpleSmt},
|
||||
Felt, Word,
|
||||
};
|
||||
use miden_crypto::{merkle::SimpleSmt, Felt, Word};
|
||||
use rand_utils::prng_array;
|
||||
use seq_macro::seq;
|
||||
|
||||
fn smt_rpo(c: &mut Criterion) {
|
||||
// setup trees
|
||||
|
||||
let mut seed = [0u8; 32];
|
||||
let leaf = generate_word(&mut seed);
|
||||
let mut trees = vec![];
|
||||
|
||||
seq!(DEPTH in 14..=20 {
|
||||
let leaves = ((1 << DEPTH) - 1) as u64;
|
||||
for depth in 14..=20 {
|
||||
let leaves = ((1 << depth) - 1) as u64;
|
||||
for count in [1, leaves / 2, leaves] {
|
||||
let entries: Vec<_> = (0..count)
|
||||
.map(|i| {
|
||||
@@ -22,45 +18,50 @@ fn smt_rpo(c: &mut Criterion) {
|
||||
(i, word)
|
||||
})
|
||||
.collect();
|
||||
let mut tree = SimpleSmt::<DEPTH>::with_leaves(entries).unwrap();
|
||||
|
||||
// benchmark 1
|
||||
let mut insert = c.benchmark_group("smt update_leaf".to_string());
|
||||
{
|
||||
let depth = DEPTH;
|
||||
let key = count >> 2;
|
||||
insert.bench_with_input(
|
||||
format!("simple smt(depth:{depth},count:{count})"),
|
||||
&(key, leaf),
|
||||
|b, (key, leaf)| {
|
||||
b.iter(|| {
|
||||
tree.insert(black_box(LeafIndex::<DEPTH>::new(*key).unwrap()), black_box(*leaf));
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
}
|
||||
insert.finish();
|
||||
|
||||
// benchmark 2
|
||||
let mut path = c.benchmark_group("smt get_leaf_path".to_string());
|
||||
{
|
||||
let depth = DEPTH;
|
||||
let key = count >> 2;
|
||||
path.bench_with_input(
|
||||
format!("simple smt(depth:{depth},count:{count})"),
|
||||
&key,
|
||||
|b, key| {
|
||||
b.iter(|| {
|
||||
tree.open(black_box(&LeafIndex::<DEPTH>::new(*key).unwrap()));
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
}
|
||||
path.finish();
|
||||
let tree = SimpleSmt::with_leaves(depth, entries).unwrap();
|
||||
trees.push((tree, count));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let leaf = generate_word(&mut seed);
|
||||
|
||||
// benchmarks
|
||||
|
||||
let mut insert = c.benchmark_group(format!("smt update_leaf"));
|
||||
|
||||
for (tree, count) in trees.iter_mut() {
|
||||
let depth = tree.depth();
|
||||
let key = *count >> 2;
|
||||
insert.bench_with_input(
|
||||
format!("simple smt(depth:{depth},count:{count})"),
|
||||
&(key, leaf),
|
||||
|b, (key, leaf)| {
|
||||
b.iter(|| {
|
||||
tree.update_leaf(black_box(*key), black_box(*leaf)).unwrap();
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
insert.finish();
|
||||
|
||||
let mut path = c.benchmark_group(format!("smt get_leaf_path"));
|
||||
|
||||
for (tree, count) in trees.iter_mut() {
|
||||
let depth = tree.depth();
|
||||
let key = *count >> 2;
|
||||
path.bench_with_input(
|
||||
format!("simple smt(depth:{depth},count:{count})"),
|
||||
&key,
|
||||
|b, key| {
|
||||
b.iter(|| {
|
||||
tree.get_leaf_path(black_box(*key)).unwrap();
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
path.finish();
|
||||
}
|
||||
|
||||
criterion_group!(smt_group, smt_rpo);
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
use criterion::{black_box, criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
|
||||
use miden_crypto::merkle::{
|
||||
DefaultMerkleStore as MerkleStore, LeafIndex, MerkleTree, NodeIndex, SimpleSmt, SMT_MAX_DEPTH,
|
||||
};
|
||||
use miden_crypto::merkle::{DefaultMerkleStore as MerkleStore, MerkleTree, NodeIndex, SimpleSmt};
|
||||
use miden_crypto::Word;
|
||||
use miden_crypto::{hash::rpo::RpoDigest, Felt};
|
||||
use rand_utils::{rand_array, rand_value};
|
||||
@@ -17,7 +15,7 @@ fn random_rpo_digest() -> RpoDigest {
|
||||
|
||||
/// Generates a random `Word`.
|
||||
fn random_word() -> Word {
|
||||
rand_array::<Felt, 4>()
|
||||
rand_array::<Felt, 4>().into()
|
||||
}
|
||||
|
||||
/// Generates an index at the specified depth in `0..range`.
|
||||
@@ -30,26 +28,26 @@ fn random_index(range: u64, depth: u8) -> NodeIndex {
|
||||
fn get_empty_leaf_simplesmt(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("get_empty_leaf_simplesmt");
|
||||
|
||||
const DEPTH: u8 = SMT_MAX_DEPTH;
|
||||
let depth = SimpleSmt::MAX_DEPTH;
|
||||
let size = u64::MAX;
|
||||
|
||||
// both SMT and the store are pre-populated with empty hashes, accessing these values is what is
|
||||
// being benchmarked here, so no values are inserted into the backends
|
||||
let smt = SimpleSmt::<DEPTH>::new().unwrap();
|
||||
let smt = SimpleSmt::new(depth).unwrap();
|
||||
let store = MerkleStore::from(&smt);
|
||||
let root = smt.root();
|
||||
|
||||
group.bench_function(BenchmarkId::new("SimpleSmt", DEPTH), |b| {
|
||||
group.bench_function(BenchmarkId::new("SimpleSmt", depth), |b| {
|
||||
b.iter_batched(
|
||||
|| random_index(size, DEPTH),
|
||||
|| random_index(size, depth),
|
||||
|index| black_box(smt.get_node(index)),
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
});
|
||||
|
||||
group.bench_function(BenchmarkId::new("MerkleStore", DEPTH), |b| {
|
||||
group.bench_function(BenchmarkId::new("MerkleStore", depth), |b| {
|
||||
b.iter_batched(
|
||||
|| random_index(size, DEPTH),
|
||||
|| random_index(size, depth),
|
||||
|index| black_box(store.get_node(root, index)),
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
@@ -106,14 +104,15 @@ fn get_leaf_simplesmt(c: &mut Criterion) {
|
||||
.enumerate()
|
||||
.map(|(c, v)| (c.try_into().unwrap(), v.into()))
|
||||
.collect::<Vec<(u64, Word)>>();
|
||||
let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
|
||||
let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap();
|
||||
let store = MerkleStore::from(&smt);
|
||||
let depth = smt.depth();
|
||||
let root = smt.root();
|
||||
let size_u64 = size as u64;
|
||||
|
||||
group.bench_function(BenchmarkId::new("SimpleSmt", size), |b| {
|
||||
b.iter_batched(
|
||||
|| random_index(size_u64, SMT_MAX_DEPTH),
|
||||
|| random_index(size_u64, depth),
|
||||
|index| black_box(smt.get_node(index)),
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
@@ -121,7 +120,7 @@ fn get_leaf_simplesmt(c: &mut Criterion) {
|
||||
|
||||
group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
|
||||
b.iter_batched(
|
||||
|| random_index(size_u64, SMT_MAX_DEPTH),
|
||||
|| random_index(size_u64, depth),
|
||||
|index| black_box(store.get_node(root, index)),
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
@@ -133,18 +132,18 @@ fn get_leaf_simplesmt(c: &mut Criterion) {
|
||||
fn get_node_of_empty_simplesmt(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("get_node_of_empty_simplesmt");
|
||||
|
||||
const DEPTH: u8 = SMT_MAX_DEPTH;
|
||||
let depth = SimpleSmt::MAX_DEPTH;
|
||||
|
||||
// both SMT and the store are pre-populated with the empty hashes, accessing the internal nodes
|
||||
// of these values is what is being benchmarked here, so no values are inserted into the
|
||||
// backends.
|
||||
let smt = SimpleSmt::<DEPTH>::new().unwrap();
|
||||
let smt = SimpleSmt::new(depth).unwrap();
|
||||
let store = MerkleStore::from(&smt);
|
||||
let root = smt.root();
|
||||
let half_depth = DEPTH / 2;
|
||||
let half_depth = depth / 2;
|
||||
let half_size = 2_u64.pow(half_depth as u32);
|
||||
|
||||
group.bench_function(BenchmarkId::new("SimpleSmt", DEPTH), |b| {
|
||||
group.bench_function(BenchmarkId::new("SimpleSmt", depth), |b| {
|
||||
b.iter_batched(
|
||||
|| random_index(half_size, half_depth),
|
||||
|index| black_box(smt.get_node(index)),
|
||||
@@ -152,7 +151,7 @@ fn get_node_of_empty_simplesmt(c: &mut Criterion) {
|
||||
)
|
||||
});
|
||||
|
||||
group.bench_function(BenchmarkId::new("MerkleStore", DEPTH), |b| {
|
||||
group.bench_function(BenchmarkId::new("MerkleStore", depth), |b| {
|
||||
b.iter_batched(
|
||||
|| random_index(half_size, half_depth),
|
||||
|index| black_box(store.get_node(root, index)),
|
||||
@@ -213,10 +212,10 @@ fn get_node_simplesmt(c: &mut Criterion) {
|
||||
.enumerate()
|
||||
.map(|(c, v)| (c.try_into().unwrap(), v.into()))
|
||||
.collect::<Vec<(u64, Word)>>();
|
||||
let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
|
||||
let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap();
|
||||
let store = MerkleStore::from(&smt);
|
||||
let root = smt.root();
|
||||
let half_depth = SMT_MAX_DEPTH / 2;
|
||||
let half_depth = smt.depth() / 2;
|
||||
let half_size = 2_u64.pow(half_depth as u32);
|
||||
|
||||
group.bench_function(BenchmarkId::new("SimpleSmt", size), |b| {
|
||||
@@ -287,24 +286,23 @@ fn get_leaf_path_simplesmt(c: &mut Criterion) {
|
||||
.enumerate()
|
||||
.map(|(c, v)| (c.try_into().unwrap(), v.into()))
|
||||
.collect::<Vec<(u64, Word)>>();
|
||||
let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
|
||||
let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap();
|
||||
let store = MerkleStore::from(&smt);
|
||||
let depth = smt.depth();
|
||||
let root = smt.root();
|
||||
let size_u64 = size as u64;
|
||||
|
||||
group.bench_function(BenchmarkId::new("SimpleSmt", size), |b| {
|
||||
b.iter_batched(
|
||||
|| random_index(size_u64, SMT_MAX_DEPTH),
|
||||
|index| {
|
||||
black_box(smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(index.value()).unwrap()))
|
||||
},
|
||||
|| random_index(size_u64, depth),
|
||||
|index| black_box(smt.get_path(index)),
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
});
|
||||
|
||||
group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
|
||||
b.iter_batched(
|
||||
|| random_index(size_u64, SMT_MAX_DEPTH),
|
||||
|| random_index(size_u64, depth),
|
||||
|index| black_box(store.get_path(root, index)),
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
@@ -354,7 +352,7 @@ fn new(c: &mut Criterion) {
|
||||
.map(|(c, v)| (c.try_into().unwrap(), v.into()))
|
||||
.collect::<Vec<(u64, Word)>>()
|
||||
},
|
||||
|l| black_box(SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(l)),
|
||||
|l| black_box(SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, l)),
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
});
|
||||
@@ -369,7 +367,7 @@ fn new(c: &mut Criterion) {
|
||||
.collect::<Vec<(u64, Word)>>()
|
||||
},
|
||||
|l| {
|
||||
let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(l).unwrap();
|
||||
let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, l).unwrap();
|
||||
black_box(MerkleStore::from(&smt));
|
||||
},
|
||||
BatchSize::SmallInput,
|
||||
@@ -435,17 +433,16 @@ fn update_leaf_simplesmt(c: &mut Criterion) {
|
||||
.enumerate()
|
||||
.map(|(c, v)| (c.try_into().unwrap(), v.into()))
|
||||
.collect::<Vec<(u64, Word)>>();
|
||||
let mut smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
|
||||
let mut smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap();
|
||||
let mut store = MerkleStore::from(&smt);
|
||||
let depth = smt.depth();
|
||||
let root = smt.root();
|
||||
let size_u64 = size as u64;
|
||||
|
||||
group.bench_function(BenchmarkId::new("SimpleSMT", size), |b| {
|
||||
b.iter_batched(
|
||||
|| (rand_value::<u64>() % size_u64, random_word()),
|
||||
|(index, value)| {
|
||||
black_box(smt.insert(LeafIndex::<SMT_MAX_DEPTH>::new(index).unwrap(), value))
|
||||
},
|
||||
|(index, value)| black_box(smt.update_leaf(index, value)),
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
});
|
||||
@@ -453,7 +450,7 @@ fn update_leaf_simplesmt(c: &mut Criterion) {
|
||||
let mut store_root = root;
|
||||
group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
|
||||
b.iter_batched(
|
||||
|| (random_index(size_u64, SMT_MAX_DEPTH), random_word()),
|
||||
|| (random_index(size_u64, depth), random_word()),
|
||||
|(index, value)| {
|
||||
// The MerkleTree automatically updates its internal root, the Store maintains
|
||||
// the old root and adds the new one. Here we update the root to have a fair
|
||||
|
||||
4
build.rs
4
build.rs
@@ -2,7 +2,7 @@ fn main() {
|
||||
#[cfg(feature = "std")]
|
||||
compile_rpo_falcon();
|
||||
|
||||
#[cfg(target_feature = "sve")]
|
||||
#[cfg(all(target_feature = "sve", feature = "sve"))]
|
||||
compile_arch_arm64_sve();
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ fn compile_rpo_falcon() {
|
||||
.compile("rpo_falcon512");
|
||||
}
|
||||
|
||||
#[cfg(target_feature = "sve")]
|
||||
#[cfg(all(target_feature = "sve", feature = "sve"))]
|
||||
fn compile_arch_arm64_sve() {
|
||||
const RPO_SVE_PATH: &str = "arch/arm64-sve/rpo";
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ use super::{
|
||||
};
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
use super::{ffi, NonceBytes, NONCE_LEN, PK_LEN, SIG_LEN, SK_LEN};
|
||||
use super::{ffi, NonceBytes, StarkField, NONCE_LEN, PK_LEN, SIG_LEN, SK_LEN};
|
||||
|
||||
// PUBLIC KEY
|
||||
// ================================================================================================
|
||||
|
||||
@@ -4,7 +4,7 @@ use crate::{
|
||||
collections::Vec, ByteReader, ByteWriter, Deserializable, DeserializationError,
|
||||
Serializable,
|
||||
},
|
||||
Felt, Word, ZERO,
|
||||
Felt, StarkField, Word, ZERO,
|
||||
};
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
@@ -39,10 +39,10 @@ const NONCE_LEN: usize = 40;
|
||||
const NONCE_ELEMENTS: usize = 8;
|
||||
|
||||
/// Public key length as a u8 vector.
|
||||
pub const PK_LEN: usize = 897;
|
||||
const PK_LEN: usize = 897;
|
||||
|
||||
/// Secret key length as a u8 vector.
|
||||
pub const SK_LEN: usize = 1281;
|
||||
const SK_LEN: usize = 1281;
|
||||
|
||||
/// Signature length as a u8 vector.
|
||||
const SIG_LEN: usize = 626;
|
||||
|
||||
@@ -4,7 +4,7 @@ use core::ops::{Add, Mul, Sub};
|
||||
// FALCON POLYNOMIAL
|
||||
// ================================================================================================
|
||||
|
||||
/// A polynomial over Z_p\[x\]/(phi) where phi := x^512 + 1
|
||||
/// A polynomial over Z_p[x]/(phi) where phi := x^512 + 1
|
||||
#[derive(Debug, Copy, Clone, PartialEq)]
|
||||
pub struct Polynomial([u16; N]);
|
||||
|
||||
@@ -24,7 +24,7 @@ impl Polynomial {
|
||||
Self(data)
|
||||
}
|
||||
|
||||
/// Decodes raw bytes representing a public key into a polynomial in Z_p\[x\]/(phi).
|
||||
/// Decodes raw bytes representing a public key into a polynomial in Z_p[x]/(phi).
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
@@ -69,14 +69,14 @@ impl Polynomial {
|
||||
}
|
||||
}
|
||||
|
||||
/// Decodes the signature into the coefficients of a polynomial in Z_p\[x\]/(phi). It assumes
|
||||
/// Decodes the signature into the coefficients of a polynomial in Z_p[x]/(phi). It assumes
|
||||
/// that the signature has been encoded using the uncompressed format.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// - The signature has been encoded using a different algorithm than the reference compressed
|
||||
/// encoding algorithm.
|
||||
/// - The encoded signature polynomial is in Z_p\[x\]/(phi') where phi' = x^N' + 1 and N' != 512.
|
||||
/// - The encoded signature polynomial is in Z_p[x]/(phi') where phi' = x^N' + 1 and N' != 512.
|
||||
/// - While decoding the high bits of a coefficient, the current accumulated value of its
|
||||
/// high bits is larger than 2048.
|
||||
/// - The decoded coefficient is -0.
|
||||
@@ -149,12 +149,12 @@ impl Polynomial {
|
||||
// POLYNOMIAL OPERATIONS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Multiplies two polynomials over Z_p\[x\] without reducing modulo p. Given that the degrees
|
||||
/// Multiplies two polynomials over Z_p[x] without reducing modulo p. Given that the degrees
|
||||
/// of the input polynomials are less than 512 and their coefficients are less than the modulus
|
||||
/// q equal to 12289, the resulting product polynomial is guaranteed to have coefficients less
|
||||
/// than the Miden prime.
|
||||
///
|
||||
/// Note that this multiplication is not over Z_p\[x\]/(phi).
|
||||
/// Note that this multiplication is not over Z_p[x]/(phi).
|
||||
pub fn mul_modulo_p(a: &Self, b: &Self) -> [u64; 1024] {
|
||||
let mut c = [0; 2 * N];
|
||||
for i in 0..N {
|
||||
@@ -166,8 +166,8 @@ impl Polynomial {
|
||||
c
|
||||
}
|
||||
|
||||
/// Reduces a polynomial, that is the product of two polynomials over Z_p\[x\], modulo
|
||||
/// the irreducible polynomial phi. This results in an element in Z_p\[x\]/(phi).
|
||||
/// Reduces a polynomial, that is the product of two polynomials over Z_p[x], modulo
|
||||
/// the irreducible polynomial phi. This results in an element in Z_p[x]/(phi).
|
||||
pub fn reduce_negacyclic(a: &[u64; 1024]) -> Self {
|
||||
let mut c = [0; N];
|
||||
for i in 0..N {
|
||||
@@ -181,7 +181,7 @@ impl Polynomial {
|
||||
Self(c)
|
||||
}
|
||||
|
||||
/// Computes the norm squared of a polynomial in Z_p\[x\]/(phi) after normalizing its
|
||||
/// Computes the norm squared of a polynomial in Z_p[x]/(phi) after normalizing its
|
||||
/// coefficients to be in the interval (-p/2, p/2].
|
||||
pub fn sq_norm(&self) -> u64 {
|
||||
let mut res = 0;
|
||||
@@ -203,7 +203,7 @@ impl Default for Polynomial {
|
||||
}
|
||||
}
|
||||
|
||||
/// Multiplication over Z_p\[x\]/(phi)
|
||||
/// Multiplication over Z_p[x]/(phi)
|
||||
impl Mul for Polynomial {
|
||||
type Output = Self;
|
||||
|
||||
@@ -227,7 +227,7 @@ impl Mul for Polynomial {
|
||||
}
|
||||
}
|
||||
|
||||
/// Addition over Z_p\[x\]/(phi)
|
||||
/// Addition over Z_p[x]/(phi)
|
||||
impl Add for Polynomial {
|
||||
type Output = Self;
|
||||
|
||||
@@ -239,7 +239,7 @@ impl Add for Polynomial {
|
||||
}
|
||||
}
|
||||
|
||||
/// Subtraction over Z_p\[x\]/(phi)
|
||||
/// Subtraction over Z_p[x]/(phi)
|
||||
impl Sub for Polynomial {
|
||||
type Output = Self;
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use super::{
|
||||
ByteReader, ByteWriter, Deserializable, DeserializationError, Felt, NonceBytes, NonceElements,
|
||||
Polynomial, PublicKeyBytes, Rpo256, Serializable, SignatureBytes, Word, MODULUS, N,
|
||||
ByteReader, ByteWriter, Deserializable, DeserializationError, NonceBytes, NonceElements,
|
||||
Polynomial, PublicKeyBytes, Rpo256, Serializable, SignatureBytes, StarkField, Word, MODULUS, N,
|
||||
SIG_L2_BOUND, ZERO,
|
||||
};
|
||||
use crate::utils::string::ToString;
|
||||
@@ -11,7 +11,7 @@ use core::cell::OnceCell;
|
||||
|
||||
/// An RPO Falcon512 signature over a message.
|
||||
///
|
||||
/// The signature is a pair of polynomials (s1, s2) in (Z_p\[x\]/(phi))^2, where:
|
||||
/// The signature is a pair of polynomials (s1, s2) in (Z_p[x]/(phi))^2, where:
|
||||
/// - p := 12289
|
||||
/// - phi := x^512 + 1
|
||||
/// - s1 = c - s2 * h
|
||||
@@ -41,7 +41,6 @@ use core::cell::OnceCell;
|
||||
/// 3. 625 bytes encoding the `s2` polynomial above.
|
||||
///
|
||||
/// The total size of the signature (including the extended public key) is 1563 bytes.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Signature {
|
||||
pub(super) pk: PublicKeyBytes,
|
||||
pub(super) sig: SignatureBytes,
|
||||
@@ -87,7 +86,7 @@ impl Signature {
|
||||
// HASH-TO-POINT
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a polynomial in Z_p\[x\]/(phi) representing the hash of the provided message.
|
||||
/// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message.
|
||||
pub fn hash_to_point(&self, message: Word) -> Polynomial {
|
||||
hash_to_point(message, &self.nonce())
|
||||
}
|
||||
@@ -134,7 +133,7 @@ impl Deserializable for Signature {
|
||||
let pk_polynomial = Polynomial::from_pub_key(&pk)
|
||||
.map_err(|err| DeserializationError::InvalidValue(err.to_string()))?
|
||||
.into();
|
||||
let sig_polynomial = Polynomial::from_signature(&sig)
|
||||
let sig_polynomial = Polynomial::from_signature(&sig[41..])
|
||||
.map_err(|err| DeserializationError::InvalidValue(err.to_string()))?
|
||||
.into();
|
||||
|
||||
@@ -182,9 +181,7 @@ fn decode_nonce(nonce: &NonceBytes) -> NonceElements {
|
||||
let mut result = [ZERO; 8];
|
||||
for (i, bytes) in nonce.chunks(5).enumerate() {
|
||||
buffer[..5].copy_from_slice(bytes);
|
||||
// we can safely (without overflow) create a new Felt from u64 value here since this value
|
||||
// contains at most 5 bytes
|
||||
result[i] = Felt::new(u64::from_le_bytes(buffer));
|
||||
result[i] = u64::from_le_bytes(buffer).into();
|
||||
}
|
||||
|
||||
result
|
||||
@@ -196,7 +193,7 @@ fn decode_nonce(nonce: &NonceBytes) -> NonceElements {
|
||||
#[cfg(all(test, feature = "std"))]
|
||||
mod tests {
|
||||
use super::{
|
||||
super::{ffi::*, Felt, KeyPair},
|
||||
super::{ffi::*, Felt},
|
||||
*,
|
||||
};
|
||||
use libc::c_void;
|
||||
@@ -271,14 +268,4 @@ mod tests {
|
||||
let nonce = decode_nonce(&nonce);
|
||||
assert_eq!(res, hash_to_point(msg_felts, &nonce).inner());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_serialization_round_trip() {
|
||||
let key = KeyPair::new().unwrap();
|
||||
let signature = key.sign(Word::default()).unwrap();
|
||||
let serialized = signature.to_bytes();
|
||||
let deserialized = Signature::read_from_bytes(&serialized).unwrap();
|
||||
assert_eq!(signature.sig_poly(), deserialized.sig_poly());
|
||||
assert_eq!(signature.pub_key_poly(), deserialized.pub_key_poly());
|
||||
}
|
||||
}
|
||||
|
||||
977
src/gkr/circuit/mod.rs
Normal file
977
src/gkr/circuit/mod.rs
Normal file
@@ -0,0 +1,977 @@
|
||||
use alloc::sync::Arc;
|
||||
use winter_crypto::{ElementHasher, RandomCoin};
|
||||
use winter_math::fields::f64::BaseElement;
|
||||
use winter_math::FieldElement;
|
||||
|
||||
use crate::gkr::multivariate::{
|
||||
ComposedMultiLinearsOracle, EqPolynomial, GkrCompositionVanilla, MultiLinearOracle,
|
||||
};
|
||||
use crate::gkr::sumcheck::{sum_check_verify, Claim};
|
||||
|
||||
use super::multivariate::{
|
||||
gen_plain_gkr_oracle, gkr_composition_from_composition_polys, ComposedMultiLinears,
|
||||
CompositionPolynomial, MultiLinear,
|
||||
};
|
||||
use super::sumcheck::{
|
||||
sum_check_prove, sum_check_verify_and_reduce, FinalEvaluationClaim,
|
||||
PartialProof as SumcheckInstanceProof, RoundProof as SumCheckRoundProof, Witness,
|
||||
};
|
||||
|
||||
/// Layered circuit for computing a sum of fractions.
|
||||
///
|
||||
/// The circuit computes a sum of fractions based on the formula a / c + b / d = (a * d + b * c) / (c * d)
|
||||
/// which defines a "gate" ((a, b), (c, d)) --> (a * d + b * c, c * d) upon which the `FractionalSumCircuit`
|
||||
/// is built.
|
||||
/// TODO: Swap 1 and 0
|
||||
#[derive(Debug)]
|
||||
pub struct FractionalSumCircuit<E: FieldElement> {
|
||||
p_1_vec: Vec<MultiLinear<E>>,
|
||||
p_0_vec: Vec<MultiLinear<E>>,
|
||||
q_1_vec: Vec<MultiLinear<E>>,
|
||||
q_0_vec: Vec<MultiLinear<E>>,
|
||||
}
|
||||
|
||||
impl<E: FieldElement> FractionalSumCircuit<E> {
|
||||
/// Computes The values of the gates outputs for each of the layers of the fractional sum circuit.
|
||||
pub fn new_(num_den: &Vec<MultiLinear<E>>) -> Self {
|
||||
let mut p_1_vec: Vec<MultiLinear<E>> = Vec::new();
|
||||
let mut p_0_vec: Vec<MultiLinear<E>> = Vec::new();
|
||||
let mut q_1_vec: Vec<MultiLinear<E>> = Vec::new();
|
||||
let mut q_0_vec: Vec<MultiLinear<E>> = Vec::new();
|
||||
|
||||
let num_layers = num_den[0].len().ilog2() as usize;
|
||||
|
||||
p_1_vec.push(num_den[0].to_owned());
|
||||
p_0_vec.push(num_den[1].to_owned());
|
||||
q_1_vec.push(num_den[2].to_owned());
|
||||
q_0_vec.push(num_den[3].to_owned());
|
||||
|
||||
for i in 0..num_layers {
|
||||
let (output_p_1, output_p_0, output_q_1, output_q_0) =
|
||||
FractionalSumCircuit::compute_layer(
|
||||
&p_1_vec[i],
|
||||
&p_0_vec[i],
|
||||
&q_1_vec[i],
|
||||
&q_0_vec[i],
|
||||
);
|
||||
p_1_vec.push(output_p_1);
|
||||
p_0_vec.push(output_p_0);
|
||||
q_1_vec.push(output_q_1);
|
||||
q_0_vec.push(output_q_0);
|
||||
}
|
||||
|
||||
FractionalSumCircuit { p_1_vec, p_0_vec, q_1_vec, q_0_vec }
|
||||
}
|
||||
|
||||
/// Compute the output values of the layer given a set of input values
|
||||
fn compute_layer(
|
||||
inp_p_1: &MultiLinear<E>,
|
||||
inp_p_0: &MultiLinear<E>,
|
||||
inp_q_1: &MultiLinear<E>,
|
||||
inp_q_0: &MultiLinear<E>,
|
||||
) -> (MultiLinear<E>, MultiLinear<E>, MultiLinear<E>, MultiLinear<E>) {
|
||||
let len = inp_q_1.len();
|
||||
let outp_p_1 = (0..len / 2)
|
||||
.map(|i| inp_p_1[i] * inp_q_0[i] + inp_p_0[i] * inp_q_1[i])
|
||||
.collect::<Vec<E>>();
|
||||
let outp_p_0 = (len / 2..len)
|
||||
.map(|i| inp_p_1[i] * inp_q_0[i] + inp_p_0[i] * inp_q_1[i])
|
||||
.collect::<Vec<E>>();
|
||||
let outp_q_1 = (0..len / 2).map(|i| inp_q_1[i] * inp_q_0[i]).collect::<Vec<E>>();
|
||||
let outp_q_0 = (len / 2..len).map(|i| inp_q_1[i] * inp_q_0[i]).collect::<Vec<E>>();
|
||||
|
||||
(
|
||||
MultiLinear::new(outp_p_1),
|
||||
MultiLinear::new(outp_p_0),
|
||||
MultiLinear::new(outp_q_1),
|
||||
MultiLinear::new(outp_q_0),
|
||||
)
|
||||
}
|
||||
|
||||
/// Computes The values of the gates outputs for each of the layers of the fractional sum circuit.
|
||||
pub fn new(poly: &MultiLinear<E>) -> Self {
|
||||
let mut p_1_vec: Vec<MultiLinear<E>> = Vec::new();
|
||||
let mut p_0_vec: Vec<MultiLinear<E>> = Vec::new();
|
||||
let mut q_1_vec: Vec<MultiLinear<E>> = Vec::new();
|
||||
let mut q_0_vec: Vec<MultiLinear<E>> = Vec::new();
|
||||
|
||||
let num_layers = poly.len().ilog2() as usize - 1;
|
||||
let (output_p, output_q) = poly.split(poly.len() / 2);
|
||||
let (output_p_1, output_p_0) = output_p.split(output_p.len() / 2);
|
||||
let (output_q_1, output_q_0) = output_q.split(output_q.len() / 2);
|
||||
|
||||
p_1_vec.push(output_p_1);
|
||||
p_0_vec.push(output_p_0);
|
||||
q_1_vec.push(output_q_1);
|
||||
q_0_vec.push(output_q_0);
|
||||
|
||||
for i in 0..num_layers - 1 {
|
||||
let (output_p_1, output_p_0, output_q_1, output_q_0) =
|
||||
FractionalSumCircuit::compute_layer(
|
||||
&p_1_vec[i],
|
||||
&p_0_vec[i],
|
||||
&q_1_vec[i],
|
||||
&q_0_vec[i],
|
||||
);
|
||||
p_1_vec.push(output_p_1);
|
||||
p_0_vec.push(output_p_0);
|
||||
q_1_vec.push(output_q_1);
|
||||
q_0_vec.push(output_q_0);
|
||||
}
|
||||
|
||||
FractionalSumCircuit { p_1_vec, p_0_vec, q_1_vec, q_0_vec }
|
||||
}
|
||||
|
||||
/// Given a value r, computes the evaluation of the last layer at r when interpreted as (two)
|
||||
/// multilinear polynomials.
|
||||
pub fn evaluate(&self, r: E) -> (E, E) {
|
||||
let len = self.p_1_vec.len();
|
||||
assert_eq!(self.p_1_vec[len - 1].num_variables(), 0);
|
||||
assert_eq!(self.p_0_vec[len - 1].num_variables(), 0);
|
||||
assert_eq!(self.q_1_vec[len - 1].num_variables(), 0);
|
||||
assert_eq!(self.q_0_vec[len - 1].num_variables(), 0);
|
||||
|
||||
let mut p_1 = self.p_1_vec[len - 1].clone();
|
||||
p_1.extend(&self.p_0_vec[len - 1]);
|
||||
let mut q_1 = self.q_1_vec[len - 1].clone();
|
||||
q_1.extend(&self.q_0_vec[len - 1]);
|
||||
|
||||
(p_1.evaluate(&[r]), q_1.evaluate(&[r]))
|
||||
}
|
||||
}
|
||||
|
||||
/// A proof for reducing a claim on the correctness of the output of a layer to that of:
|
||||
///
|
||||
/// 1. Correctness of a sumcheck proof on the claimed output.
|
||||
/// 2. Correctness of the evaluation of the input (to the said layer) at a random point when
|
||||
/// interpreted as multilinear polynomial.
|
||||
///
|
||||
/// The verifier will then have to work backward and:
|
||||
///
|
||||
/// 1. Verify that the sumcheck proof is valid.
|
||||
/// 2. Recurse on the (claimed evaluations) using the same approach as above.
|
||||
///
|
||||
/// Note that the following struct batches proofs for many circuits of the same type that
|
||||
/// are independent i.e., parallel.
|
||||
#[derive(Debug)]
|
||||
pub struct LayerProof<E: FieldElement> {
|
||||
pub proof: SumcheckInstanceProof<E>,
|
||||
pub claims_sum_p1: E,
|
||||
pub claims_sum_p0: E,
|
||||
pub claims_sum_q1: E,
|
||||
pub claims_sum_q0: E,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl<E: FieldElement<BaseField = BaseElement> + 'static> LayerProof<E> {
|
||||
/// Checks the validity of a `LayerProof`.
|
||||
///
|
||||
/// It first reduces the 2 claims to 1 claim using randomness and then checks that the sumcheck
|
||||
/// protocol was correctly executed.
|
||||
///
|
||||
/// The method outputs:
|
||||
///
|
||||
/// 1. A vector containing the randomness sent by the verifier throughout the course of the
|
||||
/// sum-check protocol.
|
||||
/// 2. The (claimed) evaluation of the inner polynomial (i.e., the one being summed) at the this random vector.
|
||||
/// 3. The random value used in the 2-to-1 reduction of the 2 sumchecks.
|
||||
pub fn verify_sum_check_before_last<
|
||||
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||
H: ElementHasher<BaseField = BaseElement>,
|
||||
>(
|
||||
&self,
|
||||
claim: (E, E),
|
||||
num_rounds: usize,
|
||||
transcript: &mut C,
|
||||
) -> ((E, Vec<E>), E) {
|
||||
// Absorb the claims
|
||||
let data = vec![claim.0, claim.1];
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
// Squeeze challenge to reduce two sumchecks to one
|
||||
let r_sum_check: E = transcript.draw().unwrap();
|
||||
|
||||
// Run the sumcheck protocol
|
||||
|
||||
// Given r_sum_check and claim, we create a Claim with the GKR composer and then call the generic sum-check verifier
|
||||
let reduced_claim = claim.0 + claim.1 * r_sum_check;
|
||||
|
||||
// Create vanilla oracle
|
||||
let oracle = gen_plain_gkr_oracle(num_rounds, r_sum_check);
|
||||
|
||||
// Create sum-check claim
|
||||
let transformed_claim = Claim {
|
||||
sum_value: reduced_claim,
|
||||
polynomial: oracle,
|
||||
};
|
||||
|
||||
let reduced_gkr_claim =
|
||||
sum_check_verify_and_reduce(&transformed_claim, self.proof.clone(), transcript);
|
||||
|
||||
(reduced_gkr_claim, r_sum_check)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct GkrClaim<E: FieldElement + 'static> {
|
||||
evaluation_point: Vec<E>,
|
||||
claimed_evaluation: (E, E),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CircuitProof<E: FieldElement + 'static> {
|
||||
pub proof: Vec<LayerProof<E>>,
|
||||
}
|
||||
|
||||
impl<E: FieldElement<BaseField = BaseElement> + 'static> CircuitProof<E> {
|
||||
pub fn prove<
|
||||
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||
H: ElementHasher<BaseField = BaseElement>,
|
||||
>(
|
||||
circuit: &mut FractionalSumCircuit<E>,
|
||||
transcript: &mut C,
|
||||
) -> (Self, Vec<E>, Vec<Vec<E>>) {
|
||||
let mut proof_layers: Vec<LayerProof<E>> = Vec::new();
|
||||
let num_layers = circuit.p_0_vec.len();
|
||||
|
||||
let data = vec![
|
||||
circuit.p_1_vec[num_layers - 1][0],
|
||||
circuit.p_0_vec[num_layers - 1][0],
|
||||
circuit.q_1_vec[num_layers - 1][0],
|
||||
circuit.q_0_vec[num_layers - 1][0],
|
||||
];
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
// Challenge to reduce p1, p0, q1, q0 to pr, qr
|
||||
let r_cord = transcript.draw().unwrap();
|
||||
|
||||
// Compute the (2-to-1 folded) claim
|
||||
let mut claim = circuit.evaluate(r_cord);
|
||||
let mut all_rand = Vec::new();
|
||||
|
||||
let mut rand = Vec::new();
|
||||
rand.push(r_cord);
|
||||
for layer_id in (0..num_layers - 1).rev() {
|
||||
let len = circuit.p_0_vec[layer_id].len();
|
||||
|
||||
// Construct the Lagrange kernel evaluated at previous GKR round randomness.
|
||||
// TODO: Treat the direction of doing sum-check more robustly.
|
||||
let mut rand_reversed = rand.clone();
|
||||
rand_reversed.reverse();
|
||||
let eq_evals = EqPolynomial::new(rand_reversed.clone()).evaluations();
|
||||
let mut poly_x = MultiLinear::from_values(&eq_evals);
|
||||
assert_eq!(poly_x.len(), len);
|
||||
|
||||
let num_rounds = poly_x.len().ilog2() as usize;
|
||||
|
||||
// 1. A is a polynomial containing the evaluations `p_1`.
|
||||
// 2. B is a polynomial containing the evaluations `p_0`.
|
||||
// 3. C is a polynomial containing the evaluations `q_1`.
|
||||
// 4. D is a polynomial containing the evaluations `q_0`.
|
||||
let poly_a: &mut MultiLinear<E>;
|
||||
let poly_b: &mut MultiLinear<E>;
|
||||
let poly_c: &mut MultiLinear<E>;
|
||||
let poly_d: &mut MultiLinear<E>;
|
||||
poly_a = &mut circuit.p_1_vec[layer_id];
|
||||
poly_b = &mut circuit.p_0_vec[layer_id];
|
||||
poly_c = &mut circuit.q_1_vec[layer_id];
|
||||
poly_d = &mut circuit.q_0_vec[layer_id];
|
||||
|
||||
let poly_vec_par = (poly_a, poly_b, poly_c, poly_d, &mut poly_x);
|
||||
|
||||
// The (non-linear) polynomial combining the multilinear polynomials
|
||||
let comb_func = |a: &E, b: &E, c: &E, d: &E, x: &E, rho: &E| -> E {
|
||||
(*a * *d + *b * *c + *rho * *c * *d) * *x
|
||||
};
|
||||
|
||||
// Run the sumcheck protocol
|
||||
let (proof, rand_sumcheck, claims_sum) = sum_check_prover_gkr_before_last::<E, _, _>(
|
||||
claim,
|
||||
num_rounds,
|
||||
poly_vec_par,
|
||||
comb_func,
|
||||
transcript,
|
||||
);
|
||||
|
||||
let (claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0, _claims_eq) =
|
||||
claims_sum;
|
||||
|
||||
let data = vec![claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0];
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
// Produce a random challenge to condense claims into a single claim
|
||||
let r_layer = transcript.draw().unwrap();
|
||||
|
||||
claim = (
|
||||
claims_sum_p1 + r_layer * (claims_sum_p0 - claims_sum_p1),
|
||||
claims_sum_q1 + r_layer * (claims_sum_q0 - claims_sum_q1),
|
||||
);
|
||||
|
||||
// Collect the randomness used for the current layer in order to construct the random
|
||||
// point where the input multilinear polynomials were evaluated.
|
||||
let mut ext = rand_sumcheck;
|
||||
ext.push(r_layer);
|
||||
all_rand.push(rand);
|
||||
rand = ext;
|
||||
|
||||
proof_layers.push(LayerProof {
|
||||
proof,
|
||||
claims_sum_p1,
|
||||
claims_sum_p0,
|
||||
claims_sum_q1,
|
||||
claims_sum_q0,
|
||||
});
|
||||
}
|
||||
|
||||
(CircuitProof { proof: proof_layers }, rand, all_rand)
|
||||
}
|
||||
|
||||
pub fn prove_virtual_bus<
|
||||
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||
H: ElementHasher<BaseField = BaseElement>,
|
||||
>(
|
||||
composition_polys: Vec<Vec<Arc<dyn CompositionPolynomial<E>>>>,
|
||||
mls: &mut Vec<MultiLinear<E>>,
|
||||
transcript: &mut C,
|
||||
) -> (Vec<E>, Self, super::sumcheck::FullProof<E>) {
|
||||
let num_evaluations = 1 << mls[0].num_variables();
|
||||
|
||||
// I) Evaluate the numerators and denominators over the boolean hyper-cube
|
||||
let mut num_den: Vec<Vec<E>> = vec![vec![]; 4];
|
||||
for i in 0..num_evaluations {
|
||||
for j in 0..4 {
|
||||
let query: Vec<E> = mls.iter().map(|ml| ml[i]).collect();
|
||||
|
||||
composition_polys[j].iter().for_each(|c| {
|
||||
let evaluation = c.as_ref().evaluate(&query);
|
||||
num_den[j].push(evaluation);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// II) Evaluate the GKR fractional sum circuit
|
||||
let input: Vec<MultiLinear<E>> =
|
||||
(0..4).map(|i| MultiLinear::from_values(&num_den[i])).collect();
|
||||
let mut circuit = FractionalSumCircuit::new_(&input);
|
||||
|
||||
// III) Run the GKR prover for all layers except the last one
|
||||
let (gkr_proofs, GkrClaim { evaluation_point, claimed_evaluation }) =
|
||||
CircuitProof::prove_before_final(&mut circuit, transcript);
|
||||
|
||||
// IV) Run the sum-check prover for the last GKR layer counting backwards i.e., first layer
|
||||
// in the circuit.
|
||||
|
||||
// 1) Build the EQ polynomial (Lagrange kernel) at the randomness sampled during the previous
|
||||
// sum-check protocol run
|
||||
let mut rand_reversed = evaluation_point.clone();
|
||||
rand_reversed.reverse();
|
||||
let eq_evals = EqPolynomial::new(rand_reversed.clone()).evaluations();
|
||||
let poly_x = MultiLinear::from_values(&eq_evals);
|
||||
|
||||
// 2) Add the Lagrange kernel to the list of MLs
|
||||
mls.push(poly_x);
|
||||
|
||||
// 3) Absorb the final sum-check claims and generate randomness for 2-to-1 sum-check reduction
|
||||
let data = vec![claimed_evaluation.0, claimed_evaluation.1];
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
// Squeeze challenge to reduce two sumchecks to one
|
||||
let r_sum_check = transcript.draw().unwrap();
|
||||
let reduced_claim = claimed_evaluation.0 + claimed_evaluation.1 * r_sum_check;
|
||||
|
||||
// 4) Create the composed ML representing the numerators and denominators of the topmost GKR layer
|
||||
let gkr_final_composed_ml = gkr_composition_from_composition_polys(
|
||||
&composition_polys,
|
||||
r_sum_check,
|
||||
1 << mls[0].num_variables,
|
||||
);
|
||||
let composed_ml =
|
||||
ComposedMultiLinears::new(Arc::new(gkr_final_composed_ml.clone()), mls.to_vec());
|
||||
|
||||
// 5) Create the composed ML oracle. This will be used for verifying the FinalEvaluationClaim downstream
|
||||
// TODO: This should be an input to the current function.
|
||||
// TODO: Make MultiLinearOracle a variant in an enum so that it is possible to capture other types of oracles.
|
||||
// For example, shifts of polynomials, Lagrange kernels at a random point or periodic (transparent) polynomials.
|
||||
let left_num_oracle = MultiLinearOracle { id: 0 };
|
||||
let right_num_oracle = MultiLinearOracle { id: 1 };
|
||||
let left_denom_oracle = MultiLinearOracle { id: 2 };
|
||||
let right_denom_oracle = MultiLinearOracle { id: 3 };
|
||||
let eq_oracle = MultiLinearOracle { id: 4 };
|
||||
let composed_ml_oracle = ComposedMultiLinearsOracle {
|
||||
composer: (Arc::new(gkr_final_composed_ml.clone())),
|
||||
multi_linears: vec![
|
||||
eq_oracle,
|
||||
left_num_oracle,
|
||||
right_num_oracle,
|
||||
left_denom_oracle,
|
||||
right_denom_oracle,
|
||||
],
|
||||
};
|
||||
|
||||
// 6) Create the claim for the final sum-check protocol.
|
||||
let claim = Claim {
|
||||
sum_value: reduced_claim,
|
||||
polynomial: composed_ml_oracle.clone(),
|
||||
};
|
||||
|
||||
// 7) Create the witness for the sum-check claim.
|
||||
let witness = Witness { polynomial: composed_ml };
|
||||
let output = sum_check_prove(&claim, composed_ml_oracle, witness, transcript);
|
||||
|
||||
// 8) Create the claimed output of the circuit.
|
||||
let circuit_outputs = vec![
|
||||
circuit.p_1_vec.last().unwrap()[0],
|
||||
circuit.p_0_vec.last().unwrap()[0],
|
||||
circuit.q_1_vec.last().unwrap()[0],
|
||||
circuit.q_0_vec.last().unwrap()[0],
|
||||
];
|
||||
|
||||
// 9) Return:
|
||||
// 1. The claimed circuit outputs.
|
||||
// 2. GKR proofs of all circuit layers except the initial layer.
|
||||
// 3. Output of the final sum-check protocol.
|
||||
(circuit_outputs, gkr_proofs, output)
|
||||
}
|
||||
|
||||
pub fn prove_before_final<
|
||||
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||
H: ElementHasher<BaseField = BaseElement>,
|
||||
>(
|
||||
sum_circuits: &mut FractionalSumCircuit<E>,
|
||||
transcript: &mut C,
|
||||
) -> (Self, GkrClaim<E>) {
|
||||
let mut proof_layers: Vec<LayerProof<E>> = Vec::new();
|
||||
let num_layers = sum_circuits.p_0_vec.len();
|
||||
|
||||
let data = vec![
|
||||
sum_circuits.p_1_vec[num_layers - 1][0],
|
||||
sum_circuits.p_0_vec[num_layers - 1][0],
|
||||
sum_circuits.q_1_vec[num_layers - 1][0],
|
||||
sum_circuits.q_0_vec[num_layers - 1][0],
|
||||
];
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
// Challenge to reduce p1, p0, q1, q0 to pr, qr
|
||||
let r_cord = transcript.draw().unwrap();
|
||||
|
||||
// Compute the (2-to-1 folded) claim
|
||||
let mut claims_to_verify = sum_circuits.evaluate(r_cord);
|
||||
let mut all_rand = Vec::new();
|
||||
|
||||
let mut rand = Vec::new();
|
||||
rand.push(r_cord);
|
||||
for layer_id in (1..num_layers - 1).rev() {
|
||||
let len = sum_circuits.p_0_vec[layer_id].len();
|
||||
|
||||
// Construct the Lagrange kernel evaluated at previous GKR round randomness.
|
||||
// TODO: Treat the direction of doing sum-check more robustly.
|
||||
let mut rand_reversed = rand.clone();
|
||||
rand_reversed.reverse();
|
||||
let eq_evals = EqPolynomial::new(rand_reversed.clone()).evaluations();
|
||||
let mut poly_x = MultiLinear::from_values(&eq_evals);
|
||||
assert_eq!(poly_x.len(), len);
|
||||
|
||||
let num_rounds = poly_x.len().ilog2() as usize;
|
||||
|
||||
// 1. A is a polynomial containing the evaluations `p_1`.
|
||||
// 2. B is a polynomial containing the evaluations `p_0`.
|
||||
// 3. C is a polynomial containing the evaluations `q_1`.
|
||||
// 4. D is a polynomial containing the evaluations `q_0`.
|
||||
let poly_a: &mut MultiLinear<E>;
|
||||
let poly_b: &mut MultiLinear<E>;
|
||||
let poly_c: &mut MultiLinear<E>;
|
||||
let poly_d: &mut MultiLinear<E>;
|
||||
poly_a = &mut sum_circuits.p_1_vec[layer_id];
|
||||
poly_b = &mut sum_circuits.p_0_vec[layer_id];
|
||||
poly_c = &mut sum_circuits.q_1_vec[layer_id];
|
||||
poly_d = &mut sum_circuits.q_0_vec[layer_id];
|
||||
|
||||
let poly_vec = (poly_a, poly_b, poly_c, poly_d, &mut poly_x);
|
||||
|
||||
let claim = claims_to_verify;
|
||||
|
||||
// The (non-linear) polynomial combining the multilinear polynomials
|
||||
let comb_func = |a: &E, b: &E, c: &E, d: &E, x: &E, rho: &E| -> E {
|
||||
(*a * *d + *b * *c + *rho * *c * *d) * *x
|
||||
};
|
||||
|
||||
// Run the sumcheck protocol
|
||||
let (proof, rand_sumcheck, claims_sum) = sum_check_prover_gkr_before_last::<E, _, _>(
|
||||
claim, num_rounds, poly_vec, comb_func, transcript,
|
||||
);
|
||||
|
||||
let (claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0, _claims_eq) =
|
||||
claims_sum;
|
||||
|
||||
let data = vec![claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0];
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
// Produce a random challenge to condense claims into a single claim
|
||||
let r_layer = transcript.draw().unwrap();
|
||||
|
||||
claims_to_verify = (
|
||||
claims_sum_p1 + r_layer * (claims_sum_p0 - claims_sum_p1),
|
||||
claims_sum_q1 + r_layer * (claims_sum_q0 - claims_sum_q1),
|
||||
);
|
||||
|
||||
// Collect the randomness used for the current layer in order to construct the random
|
||||
// point where the input multilinear polynomials were evaluated.
|
||||
let mut ext = rand_sumcheck;
|
||||
ext.push(r_layer);
|
||||
all_rand.push(rand);
|
||||
rand = ext;
|
||||
|
||||
proof_layers.push(LayerProof {
|
||||
proof,
|
||||
claims_sum_p1,
|
||||
claims_sum_p0,
|
||||
claims_sum_q1,
|
||||
claims_sum_q0,
|
||||
});
|
||||
}
|
||||
let gkr_claim = GkrClaim {
|
||||
evaluation_point: rand.clone(),
|
||||
claimed_evaluation: claims_to_verify,
|
||||
};
|
||||
|
||||
(CircuitProof { proof: proof_layers }, gkr_claim)
|
||||
}
|
||||
|
||||
pub fn verify<
|
||||
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||
H: ElementHasher<BaseField = BaseElement>,
|
||||
>(
|
||||
&self,
|
||||
claims_sum_vec: &[E],
|
||||
transcript: &mut C,
|
||||
) -> ((E, E), Vec<E>) {
|
||||
let num_layers = self.proof.len() as usize - 1;
|
||||
let mut rand: Vec<E> = Vec::new();
|
||||
|
||||
let data = claims_sum_vec;
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
let r_cord = transcript.draw().unwrap();
|
||||
|
||||
let p_poly_coef = vec![claims_sum_vec[0], claims_sum_vec[1]];
|
||||
let q_poly_coef = vec![claims_sum_vec[2], claims_sum_vec[3]];
|
||||
|
||||
let p_poly = MultiLinear::new(p_poly_coef);
|
||||
let q_poly = MultiLinear::new(q_poly_coef);
|
||||
let p_eval = p_poly.evaluate(&[r_cord]);
|
||||
let q_eval = q_poly.evaluate(&[r_cord]);
|
||||
|
||||
let mut reduced_claim = (p_eval, q_eval);
|
||||
|
||||
rand.push(r_cord);
|
||||
for (num_rounds, i) in (0..num_layers).enumerate() {
|
||||
let ((claim_last, rand_sumcheck), r_two_sumchecks) = self.proof[i]
|
||||
.verify_sum_check_before_last::<_, _>(reduced_claim, num_rounds + 1, transcript);
|
||||
|
||||
let claims_sum_p1 = &self.proof[i].claims_sum_p1;
|
||||
let claims_sum_p0 = &self.proof[i].claims_sum_p0;
|
||||
let claims_sum_q1 = &self.proof[i].claims_sum_q1;
|
||||
let claims_sum_q0 = &self.proof[i].claims_sum_q0;
|
||||
|
||||
let data = vec![
|
||||
claims_sum_p1.clone(),
|
||||
claims_sum_p0.clone(),
|
||||
claims_sum_q1.clone(),
|
||||
claims_sum_q0.clone(),
|
||||
];
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
assert_eq!(rand.len(), rand_sumcheck.len());
|
||||
|
||||
let eq: E = (0..rand.len())
|
||||
.map(|i| {
|
||||
rand[i] * rand_sumcheck[i] + (E::ONE - rand[i]) * (E::ONE - rand_sumcheck[i])
|
||||
})
|
||||
.fold(E::ONE, |acc, term| acc * term);
|
||||
|
||||
let claim_expected: E = (*claims_sum_p1 * *claims_sum_q0
|
||||
+ *claims_sum_p0 * *claims_sum_q1
|
||||
+ r_two_sumchecks * *claims_sum_q1 * *claims_sum_q0)
|
||||
* eq;
|
||||
|
||||
assert_eq!(claim_expected, claim_last);
|
||||
|
||||
// Produce a random challenge to condense claims into a single claim
|
||||
let r_layer = transcript.draw().unwrap();
|
||||
|
||||
reduced_claim = (
|
||||
*claims_sum_p1 + r_layer * (*claims_sum_p0 - *claims_sum_p1),
|
||||
*claims_sum_q1 + r_layer * (*claims_sum_q0 - *claims_sum_q1),
|
||||
);
|
||||
|
||||
// Collect the randomness' used for the current layer in order to construct the random
|
||||
// point where the input multilinear polynomials were evaluated.
|
||||
let mut ext = rand_sumcheck;
|
||||
ext.push(r_layer);
|
||||
rand = ext;
|
||||
}
|
||||
(reduced_claim, rand)
|
||||
}
|
||||
|
||||
pub fn verify_virtual_bus<
|
||||
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||
H: ElementHasher<BaseField = BaseElement>,
|
||||
>(
|
||||
&self,
|
||||
composition_polys: Vec<Vec<Arc<dyn CompositionPolynomial<E>>>>,
|
||||
final_layer_proof: super::sumcheck::FullProof<E>,
|
||||
claims_sum_vec: &[E],
|
||||
transcript: &mut C,
|
||||
) -> (FinalEvaluationClaim<E>, Vec<E>) {
|
||||
let num_layers = self.proof.len() as usize;
|
||||
let mut rand: Vec<E> = Vec::new();
|
||||
|
||||
// Check that a/b + d/e is equal to 0
|
||||
assert_ne!(claims_sum_vec[2], E::ZERO);
|
||||
assert_ne!(claims_sum_vec[3], E::ZERO);
|
||||
assert_eq!(
|
||||
claims_sum_vec[0] * claims_sum_vec[3] + claims_sum_vec[1] * claims_sum_vec[2],
|
||||
E::ZERO
|
||||
);
|
||||
|
||||
let data = claims_sum_vec;
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
let r_cord = transcript.draw().unwrap();
|
||||
|
||||
let p_poly_coef = vec![claims_sum_vec[0], claims_sum_vec[1]];
|
||||
let q_poly_coef = vec![claims_sum_vec[2], claims_sum_vec[3]];
|
||||
|
||||
let p_poly = MultiLinear::new(p_poly_coef);
|
||||
let q_poly = MultiLinear::new(q_poly_coef);
|
||||
let p_eval = p_poly.evaluate(&[r_cord]);
|
||||
let q_eval = q_poly.evaluate(&[r_cord]);
|
||||
|
||||
let mut reduced_claim = (p_eval, q_eval);
|
||||
|
||||
// I) Verify all GKR layers but for the last one counting backwards.
|
||||
rand.push(r_cord);
|
||||
for (num_rounds, i) in (0..num_layers).enumerate() {
|
||||
let ((claim_last, rand_sumcheck), r_two_sumchecks) = self.proof[i]
|
||||
.verify_sum_check_before_last::<_, _>(reduced_claim, num_rounds + 1, transcript);
|
||||
|
||||
let claims_sum_p1 = &self.proof[i].claims_sum_p1;
|
||||
let claims_sum_p0 = &self.proof[i].claims_sum_p0;
|
||||
let claims_sum_q1 = &self.proof[i].claims_sum_q1;
|
||||
let claims_sum_q0 = &self.proof[i].claims_sum_q0;
|
||||
|
||||
let data = vec![
|
||||
claims_sum_p1.clone(),
|
||||
claims_sum_p0.clone(),
|
||||
claims_sum_q1.clone(),
|
||||
claims_sum_q0.clone(),
|
||||
];
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
assert_eq!(rand.len(), rand_sumcheck.len());
|
||||
|
||||
let eq: E = (0..rand.len())
|
||||
.map(|i| {
|
||||
rand[i] * rand_sumcheck[i] + (E::ONE - rand[i]) * (E::ONE - rand_sumcheck[i])
|
||||
})
|
||||
.fold(E::ONE, |acc, term| acc * term);
|
||||
|
||||
let claim_expected: E = (*claims_sum_p1 * *claims_sum_q0
|
||||
+ *claims_sum_p0 * *claims_sum_q1
|
||||
+ r_two_sumchecks * *claims_sum_q1 * *claims_sum_q0)
|
||||
* eq;
|
||||
|
||||
assert_eq!(claim_expected, claim_last);
|
||||
|
||||
// Produce a random challenge to condense claims into a single claim
|
||||
let r_layer = transcript.draw().unwrap();
|
||||
|
||||
reduced_claim = (
|
||||
*claims_sum_p1 + r_layer * (*claims_sum_p0 - *claims_sum_p1),
|
||||
*claims_sum_q1 + r_layer * (*claims_sum_q0 - *claims_sum_q1),
|
||||
);
|
||||
|
||||
let mut ext = rand_sumcheck;
|
||||
ext.push(r_layer);
|
||||
rand = ext;
|
||||
}
|
||||
|
||||
// II) Verify the final GKR layer counting backwards.
|
||||
|
||||
// Absorb the claims
|
||||
let data = vec![reduced_claim.0, reduced_claim.1];
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
// Squeeze challenge to reduce two sumchecks to one
|
||||
let r_sum_check = transcript.draw().unwrap();
|
||||
let reduced_claim = reduced_claim.0 + reduced_claim.1 * r_sum_check;
|
||||
|
||||
let gkr_final_composed_ml = gkr_composition_from_composition_polys(
|
||||
&composition_polys,
|
||||
r_sum_check,
|
||||
1 << (num_layers + 1),
|
||||
);
|
||||
|
||||
// TODO: refactor
|
||||
let composed_ml_oracle = {
|
||||
let left_num_oracle = MultiLinearOracle { id: 0 };
|
||||
let right_num_oracle = MultiLinearOracle { id: 1 };
|
||||
let left_denom_oracle = MultiLinearOracle { id: 2 };
|
||||
let right_denom_oracle = MultiLinearOracle { id: 3 };
|
||||
let eq_oracle = MultiLinearOracle { id: 4 };
|
||||
ComposedMultiLinearsOracle {
|
||||
composer: (Arc::new(gkr_final_composed_ml.clone())),
|
||||
multi_linears: vec![
|
||||
eq_oracle,
|
||||
left_num_oracle,
|
||||
right_num_oracle,
|
||||
left_denom_oracle,
|
||||
right_denom_oracle,
|
||||
],
|
||||
}
|
||||
};
|
||||
|
||||
let claim = Claim {
|
||||
sum_value: reduced_claim,
|
||||
polynomial: composed_ml_oracle.clone(),
|
||||
};
|
||||
|
||||
let final_eval_claim = sum_check_verify(&claim, final_layer_proof, transcript);
|
||||
|
||||
(final_eval_claim, rand)
|
||||
}
|
||||
}
|
||||
|
||||
fn sum_check_prover_gkr_before_last<
|
||||
E: FieldElement<BaseField = BaseElement>,
|
||||
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||
H: ElementHasher<BaseField = BaseElement>,
|
||||
>(
|
||||
claim: (E, E),
|
||||
num_rounds: usize,
|
||||
ml_polys: (
|
||||
&mut MultiLinear<E>,
|
||||
&mut MultiLinear<E>,
|
||||
&mut MultiLinear<E>,
|
||||
&mut MultiLinear<E>,
|
||||
&mut MultiLinear<E>,
|
||||
),
|
||||
comb_func: impl Fn(&E, &E, &E, &E, &E, &E) -> E,
|
||||
transcript: &mut C,
|
||||
) -> (SumcheckInstanceProof<E>, Vec<E>, (E, E, E, E, E)) {
|
||||
// Absorb the claims
|
||||
let data = vec![claim.0, claim.1];
|
||||
transcript.reseed(H::hash_elements(&data));
|
||||
|
||||
// Squeeze challenge to reduce two sumchecks to one
|
||||
let r_sum_check = transcript.draw().unwrap();
|
||||
|
||||
let (poly_a, poly_b, poly_c, poly_d, poly_x) = ml_polys;
|
||||
|
||||
let mut e = claim.0 + claim.1 * r_sum_check;
|
||||
|
||||
let mut r: Vec<E> = Vec::new();
|
||||
let mut round_proofs: Vec<SumCheckRoundProof<E>> = Vec::new();
|
||||
|
||||
for _j in 0..num_rounds {
|
||||
let evals: (E, E, E) = {
|
||||
let mut eval_point_0 = E::ZERO;
|
||||
let mut eval_point_2 = E::ZERO;
|
||||
let mut eval_point_3 = E::ZERO;
|
||||
|
||||
let len = poly_a.len() / 2;
|
||||
for i in 0..len {
|
||||
// The interpolation formula for a linear function is:
|
||||
// z * A(x) + (1 - z) * A (y)
|
||||
// z * A(1) + (1 - z) * A(0)
|
||||
|
||||
// eval at z = 0: A(1)
|
||||
eval_point_0 += comb_func(
|
||||
&poly_a[i << 1],
|
||||
&poly_b[i << 1],
|
||||
&poly_c[i << 1],
|
||||
&poly_d[i << 1],
|
||||
&poly_x[i << 1],
|
||||
&r_sum_check,
|
||||
);
|
||||
|
||||
let poly_a_u = poly_a[(i << 1) + 1];
|
||||
let poly_a_v = poly_a[i << 1];
|
||||
let poly_b_u = poly_b[(i << 1) + 1];
|
||||
let poly_b_v = poly_b[i << 1];
|
||||
let poly_c_u = poly_c[(i << 1) + 1];
|
||||
let poly_c_v = poly_c[i << 1];
|
||||
let poly_d_u = poly_d[(i << 1) + 1];
|
||||
let poly_d_v = poly_d[i << 1];
|
||||
let poly_x_u = poly_x[(i << 1) + 1];
|
||||
let poly_x_v = poly_x[i << 1];
|
||||
|
||||
// eval at z = 2: 2 * A(1) - A(0)
|
||||
let poly_a_extrapolated_point = poly_a_u + poly_a_u - poly_a_v;
|
||||
let poly_b_extrapolated_point = poly_b_u + poly_b_u - poly_b_v;
|
||||
let poly_c_extrapolated_point = poly_c_u + poly_c_u - poly_c_v;
|
||||
let poly_d_extrapolated_point = poly_d_u + poly_d_u - poly_d_v;
|
||||
let poly_x_extrapolated_point = poly_x_u + poly_x_u - poly_x_v;
|
||||
eval_point_2 += comb_func(
|
||||
&poly_a_extrapolated_point,
|
||||
&poly_b_extrapolated_point,
|
||||
&poly_c_extrapolated_point,
|
||||
&poly_d_extrapolated_point,
|
||||
&poly_x_extrapolated_point,
|
||||
&r_sum_check,
|
||||
);
|
||||
|
||||
// eval at z = 3: 3 * A(1) - 2 * A(0) = 2 * A(1) - A(0) + A(1) - A(0)
|
||||
// hence we can compute the evaluation at z + 1 from that of z for z > 1
|
||||
let poly_a_extrapolated_point = poly_a_extrapolated_point + poly_a_u - poly_a_v;
|
||||
let poly_b_extrapolated_point = poly_b_extrapolated_point + poly_b_u - poly_b_v;
|
||||
let poly_c_extrapolated_point = poly_c_extrapolated_point + poly_c_u - poly_c_v;
|
||||
let poly_d_extrapolated_point = poly_d_extrapolated_point + poly_d_u - poly_d_v;
|
||||
let poly_x_extrapolated_point = poly_x_extrapolated_point + poly_x_u - poly_x_v;
|
||||
|
||||
eval_point_3 += comb_func(
|
||||
&poly_a_extrapolated_point,
|
||||
&poly_b_extrapolated_point,
|
||||
&poly_c_extrapolated_point,
|
||||
&poly_d_extrapolated_point,
|
||||
&poly_x_extrapolated_point,
|
||||
&r_sum_check,
|
||||
);
|
||||
}
|
||||
|
||||
(eval_point_0, eval_point_2, eval_point_3)
|
||||
};
|
||||
|
||||
let eval_0 = evals.0;
|
||||
let eval_2 = evals.1;
|
||||
let eval_3 = evals.2;
|
||||
|
||||
let evals = vec![e - eval_0, eval_2, eval_3];
|
||||
let compressed_poly = SumCheckRoundProof { poly_evals: evals };
|
||||
|
||||
// append the prover's message to the transcript
|
||||
transcript.reseed(H::hash_elements(&compressed_poly.poly_evals));
|
||||
|
||||
// derive the verifier's challenge for the next round
|
||||
let r_j = transcript.draw().unwrap();
|
||||
r.push(r_j);
|
||||
|
||||
poly_a.bind_assign(r_j);
|
||||
poly_b.bind_assign(r_j);
|
||||
poly_c.bind_assign(r_j);
|
||||
poly_d.bind_assign(r_j);
|
||||
|
||||
poly_x.bind_assign(r_j);
|
||||
|
||||
e = compressed_poly.evaluate(e, r_j);
|
||||
|
||||
round_proofs.push(compressed_poly);
|
||||
}
|
||||
let claims_sum = (poly_a[0], poly_b[0], poly_c[0], poly_d[0], poly_x[0]);
|
||||
|
||||
(SumcheckInstanceProof { round_proofs }, r, claims_sum)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod sum_circuit_tests {
|
||||
use crate::rand::RpoRandomCoin;
|
||||
|
||||
use super::*;
|
||||
use rand::Rng;
|
||||
use rand_utils::rand_value;
|
||||
use BaseElement as Felt;
|
||||
|
||||
/// The following tests the fractional sum circuit to check that \sum_{i = 0}^{log(m)-1} m / 2^{i} = 2 * (m - 1)
|
||||
#[test]
|
||||
fn sum_circuit_example() {
|
||||
let n = 4; // n := log(m)
|
||||
let mut inp: Vec<Felt> = (0..n).map(|_| Felt::from(1_u64 << n)).collect();
|
||||
let inp_: Vec<Felt> = (0..n).map(|i| Felt::from(1_u64 << i)).collect();
|
||||
inp.extend(inp_.iter());
|
||||
|
||||
let summation = MultiLinear::new(inp);
|
||||
|
||||
let expected_output = Felt::from(2 * ((1_u64 << n) - 1));
|
||||
|
||||
let mut circuit = FractionalSumCircuit::new(&summation);
|
||||
|
||||
let seed = [BaseElement::ZERO; 4];
|
||||
let mut transcript = RpoRandomCoin::new(seed.into());
|
||||
|
||||
let (proof, _evals, _) = CircuitProof::prove(&mut circuit, &mut transcript);
|
||||
|
||||
let (p1, q1) = circuit.evaluate(Felt::from(1_u8));
|
||||
let (p0, q0) = circuit.evaluate(Felt::from(0_u8));
|
||||
assert_eq!(expected_output, (p1 * q0 + q1 * p0) / (q1 * q0));
|
||||
|
||||
let seed = [BaseElement::ZERO; 4];
|
||||
let mut transcript = RpoRandomCoin::new(seed.into());
|
||||
let claims = vec![p0, p1, q0, q1];
|
||||
proof.verify(&claims, &mut transcript);
|
||||
}
|
||||
|
||||
// Test the fractional sum GKR in the context of LogUp.
|
||||
#[test]
|
||||
fn log_up() {
|
||||
use rand::distributions::Slice;
|
||||
|
||||
let n: usize = 16;
|
||||
let num_w: usize = 31; // This should be of the form 2^k - 1
|
||||
let rng = rand::thread_rng();
|
||||
|
||||
let t_table: Vec<u32> = (0..(1 << n)).collect();
|
||||
let mut m_table: Vec<u32> = (0..(1 << n)).map(|_| 0).collect();
|
||||
|
||||
let t_table_slice = Slice::new(&t_table).unwrap();
|
||||
|
||||
// Construct the witness columns. Uses sampling with replacement in order to have multiplicities
|
||||
// different from 1.
|
||||
let mut w_tables = Vec::new();
|
||||
for _ in 0..num_w {
|
||||
let wi_table: Vec<u32> =
|
||||
rng.clone().sample_iter(&t_table_slice).cloned().take(1 << n).collect();
|
||||
|
||||
// Construct the multiplicities
|
||||
wi_table.iter().for_each(|w| {
|
||||
m_table[*w as usize] += 1;
|
||||
});
|
||||
w_tables.push(wi_table)
|
||||
}
|
||||
|
||||
// The numerators
|
||||
let mut p: Vec<Felt> = m_table.iter().map(|m| Felt::from(*m as u32)).collect();
|
||||
p.extend((0..(num_w * (1 << n))).map(|_| Felt::from(1_u32)).collect::<Vec<Felt>>());
|
||||
|
||||
// Sample the challenge alpha to construct the denominators.
|
||||
let alpha = rand_value();
|
||||
|
||||
// Construct the denominators
|
||||
let mut q: Vec<Felt> = t_table.iter().map(|t| Felt::from(*t) - alpha).collect();
|
||||
for w_table in w_tables {
|
||||
q.extend(w_table.iter().map(|w| alpha - Felt::from(*w)).collect::<Vec<Felt>>());
|
||||
}
|
||||
|
||||
// Build the input to the fractional sum GKR circuit
|
||||
p.extend(q);
|
||||
let input = p;
|
||||
|
||||
let summation = MultiLinear::new(input);
|
||||
|
||||
let expected_output = Felt::from(0_u8);
|
||||
|
||||
let mut circuit = FractionalSumCircuit::new(&summation);
|
||||
|
||||
let seed = [BaseElement::ZERO; 4];
|
||||
let mut transcript = RpoRandomCoin::new(seed.into());
|
||||
|
||||
let (proof, _evals, _) = CircuitProof::prove(&mut circuit, &mut transcript);
|
||||
|
||||
let (p1, q1) = circuit.evaluate(Felt::from(1_u8));
|
||||
let (p0, q0) = circuit.evaluate(Felt::from(0_u8));
|
||||
assert_eq!(expected_output, (p1 * q0 + q1 * p0) / (q1 * q0)); // This check should be part of verification
|
||||
|
||||
let seed = [BaseElement::ZERO; 4];
|
||||
let mut transcript = RpoRandomCoin::new(seed.into());
|
||||
let claims = vec![p0, p1, q0, q1];
|
||||
proof.verify(&claims, &mut transcript);
|
||||
}
|
||||
}
|
||||
7
src/gkr/mod.rs
Normal file
7
src/gkr/mod.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
#![allow(unused_imports)]
|
||||
#![allow(dead_code)]
|
||||
|
||||
mod sumcheck;
|
||||
mod multivariate;
|
||||
mod utils;
|
||||
mod circuit;
|
||||
34
src/gkr/multivariate/eq_poly.rs
Normal file
34
src/gkr/multivariate/eq_poly.rs
Normal file
@@ -0,0 +1,34 @@
|
||||
use super::FieldElement;
|
||||
|
||||
pub struct EqPolynomial<E> {
|
||||
r: Vec<E>,
|
||||
}
|
||||
|
||||
impl<E: FieldElement> EqPolynomial<E> {
|
||||
pub fn new(r: Vec<E>) -> Self {
|
||||
EqPolynomial { r }
|
||||
}
|
||||
|
||||
pub fn evaluate(&self, rho: &[E]) -> E {
|
||||
assert_eq!(self.r.len(), rho.len());
|
||||
(0..rho.len())
|
||||
.map(|i| self.r[i] * rho[i] + (E::ONE - self.r[i]) * (E::ONE - rho[i]))
|
||||
.fold(E::ONE, |acc, term| acc * term)
|
||||
}
|
||||
|
||||
pub fn evaluations(&self) -> Vec<E> {
|
||||
let nu = self.r.len();
|
||||
|
||||
let mut evals: Vec<E> = vec![E::ONE; 1 << nu];
|
||||
let mut size = 1;
|
||||
for j in 0..nu {
|
||||
size *= 2;
|
||||
for i in (0..size).rev().step_by(2) {
|
||||
let scalar = evals[i / 2];
|
||||
evals[i] = scalar * self.r[j];
|
||||
evals[i - 1] = scalar - evals[i];
|
||||
}
|
||||
}
|
||||
evals
|
||||
}
|
||||
}
|
||||
543
src/gkr/multivariate/mod.rs
Normal file
543
src/gkr/multivariate/mod.rs
Normal file
@@ -0,0 +1,543 @@
|
||||
use core::ops::Index;
|
||||
|
||||
use alloc::sync::Arc;
|
||||
use winter_math::{fields::f64::BaseElement, log2, FieldElement, StarkField};
|
||||
|
||||
mod eq_poly;
|
||||
pub use eq_poly::EqPolynomial;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MultiLinear<E: FieldElement> {
|
||||
pub num_variables: usize,
|
||||
pub evaluations: Vec<E>,
|
||||
}
|
||||
|
||||
impl<E: FieldElement> MultiLinear<E> {
|
||||
pub fn new(values: Vec<E>) -> Self {
|
||||
Self {
|
||||
num_variables: log2(values.len()) as usize,
|
||||
evaluations: values,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_values(values: &[E]) -> Self {
|
||||
Self {
|
||||
num_variables: log2(values.len()) as usize,
|
||||
evaluations: values.to_owned(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn num_variables(&self) -> usize {
|
||||
self.num_variables
|
||||
}
|
||||
|
||||
pub fn evaluations(&self) -> &[E] {
|
||||
&self.evaluations
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.evaluations.len()
|
||||
}
|
||||
|
||||
pub fn evaluate(&self, query: &[E]) -> E {
|
||||
let tensored_query = tensorize(query);
|
||||
inner_product(&self.evaluations, &tensored_query)
|
||||
}
|
||||
|
||||
pub fn bind(&self, round_challenge: E) -> Self {
|
||||
let mut result = vec![E::ZERO; 1 << (self.num_variables() - 1)];
|
||||
for i in 0..(1 << (self.num_variables() - 1)) {
|
||||
result[i] = self.evaluations[i << 1]
|
||||
+ round_challenge * (self.evaluations[(i << 1) + 1] - self.evaluations[i << 1]);
|
||||
}
|
||||
Self::from_values(&result)
|
||||
}
|
||||
|
||||
pub fn bind_assign(&mut self, round_challenge: E) {
|
||||
let mut result = vec![E::ZERO; 1 << (self.num_variables() - 1)];
|
||||
for i in 0..(1 << (self.num_variables() - 1)) {
|
||||
result[i] = self.evaluations[i << 1]
|
||||
+ round_challenge * (self.evaluations[(i << 1) + 1] - self.evaluations[i << 1]);
|
||||
}
|
||||
*self = Self::from_values(&result);
|
||||
}
|
||||
|
||||
pub fn split(&self, at: usize) -> (Self, Self) {
|
||||
assert!(at < self.len());
|
||||
(
|
||||
Self::new(self.evaluations[..at].to_vec()),
|
||||
Self::new(self.evaluations[at..2 * at].to_vec()),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn extend(&mut self, other: &MultiLinear<E>) {
|
||||
let other_vec = other.evaluations.to_vec();
|
||||
assert_eq!(other_vec.len(), self.len());
|
||||
self.evaluations.extend(other_vec);
|
||||
self.num_variables += 1;
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: FieldElement> Index<usize> for MultiLinear<E> {
|
||||
type Output = E;
|
||||
|
||||
fn index(&self, index: usize) -> &E {
|
||||
&(self.evaluations[index])
|
||||
}
|
||||
}
|
||||
|
||||
/// A multi-variate polynomial for composing individual multi-linear polynomials
|
||||
pub trait CompositionPolynomial<E: FieldElement>: Sync + Send {
|
||||
/// The number of variables when interpreted as a multi-variate polynomial.
|
||||
fn num_variables(&self) -> usize;
|
||||
|
||||
/// Maximum degree in all variables.
|
||||
fn max_degree(&self) -> usize;
|
||||
|
||||
/// Given a query, of length equal the number of variables, evaluate [Self] at this query.
|
||||
fn evaluate(&self, query: &[E]) -> E;
|
||||
}
|
||||
|
||||
pub struct ComposedMultiLinears<E: FieldElement> {
|
||||
pub composer: Arc<dyn CompositionPolynomial<E>>,
|
||||
pub multi_linears: Vec<MultiLinear<E>>,
|
||||
}
|
||||
|
||||
impl<E: FieldElement> ComposedMultiLinears<E> {
|
||||
pub fn new(
|
||||
composer: Arc<dyn CompositionPolynomial<E>>,
|
||||
multi_linears: Vec<MultiLinear<E>>,
|
||||
) -> Self {
|
||||
Self { composer, multi_linears }
|
||||
}
|
||||
|
||||
pub fn num_ml(&self) -> usize {
|
||||
self.multi_linears.len()
|
||||
}
|
||||
|
||||
pub fn num_variables(&self) -> usize {
|
||||
self.composer.num_variables()
|
||||
}
|
||||
|
||||
pub fn num_variables_ml(&self) -> usize {
|
||||
self.multi_linears[0].num_variables
|
||||
}
|
||||
|
||||
pub fn degree(&self) -> usize {
|
||||
self.composer.max_degree()
|
||||
}
|
||||
|
||||
pub fn bind(&self, round_challenge: E) -> ComposedMultiLinears<E> {
|
||||
let result: Vec<MultiLinear<E>> =
|
||||
self.multi_linears.iter().map(|f| f.bind(round_challenge)).collect();
|
||||
|
||||
Self {
|
||||
composer: self.composer.clone(),
|
||||
multi_linears: result,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ComposedMultiLinearsOracle<E: FieldElement> {
|
||||
pub composer: Arc<dyn CompositionPolynomial<E>>,
|
||||
pub multi_linears: Vec<MultiLinearOracle>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MultiLinearOracle {
|
||||
pub id: usize,
|
||||
}
|
||||
|
||||
// Composition polynomials
|
||||
|
||||
pub struct IdentityComposition {
|
||||
num_variables: usize,
|
||||
}
|
||||
|
||||
impl IdentityComposition {
|
||||
pub fn new() -> Self {
|
||||
Self { num_variables: 1 }
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> CompositionPolynomial<E> for IdentityComposition
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
fn num_variables(&self) -> usize {
|
||||
self.num_variables
|
||||
}
|
||||
|
||||
fn max_degree(&self) -> usize {
|
||||
self.num_variables
|
||||
}
|
||||
|
||||
fn evaluate(&self, query: &[E]) -> E {
|
||||
assert_eq!(query.len(), 1);
|
||||
query[0]
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ProjectionComposition {
|
||||
coordinate: usize,
|
||||
}
|
||||
|
||||
impl ProjectionComposition {
|
||||
pub fn new(coordinate: usize) -> Self {
|
||||
Self { coordinate }
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> CompositionPolynomial<E> for ProjectionComposition
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
fn num_variables(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn max_degree(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn evaluate(&self, query: &[E]) -> E {
|
||||
query[self.coordinate]
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LogUpDenominatorTableComposition<E>
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
projection_coordinate: usize,
|
||||
alpha: E,
|
||||
}
|
||||
|
||||
impl<E> LogUpDenominatorTableComposition<E>
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
pub fn new(projection_coordinate: usize, alpha: E) -> Self {
|
||||
Self { projection_coordinate, alpha }
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> CompositionPolynomial<E> for LogUpDenominatorTableComposition<E>
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
fn num_variables(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn max_degree(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn evaluate(&self, query: &[E]) -> E {
|
||||
query[self.projection_coordinate] + self.alpha
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LogUpDenominatorWitnessComposition<E>
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
projection_coordinate: usize,
|
||||
alpha: E,
|
||||
}
|
||||
|
||||
impl<E> LogUpDenominatorWitnessComposition<E>
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
pub fn new(projection_coordinate: usize, alpha: E) -> Self {
|
||||
Self { projection_coordinate, alpha }
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> CompositionPolynomial<E> for LogUpDenominatorWitnessComposition<E>
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
fn num_variables(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn max_degree(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn evaluate(&self, query: &[E]) -> E {
|
||||
-(query[self.projection_coordinate] + self.alpha)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ProductComposition {
|
||||
num_variables: usize,
|
||||
}
|
||||
|
||||
impl ProductComposition {
|
||||
pub fn new(num_variables: usize) -> Self {
|
||||
Self { num_variables }
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> CompositionPolynomial<E> for ProductComposition
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
fn num_variables(&self) -> usize {
|
||||
self.num_variables
|
||||
}
|
||||
|
||||
fn max_degree(&self) -> usize {
|
||||
self.num_variables
|
||||
}
|
||||
|
||||
fn evaluate(&self, query: &[E]) -> E {
|
||||
query.iter().fold(E::ONE, |acc, x| acc * *x)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SumComposition {
|
||||
num_variables: usize,
|
||||
}
|
||||
|
||||
impl SumComposition {
|
||||
pub fn new(num_variables: usize) -> Self {
|
||||
Self { num_variables }
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> CompositionPolynomial<E> for SumComposition
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
fn num_variables(&self) -> usize {
|
||||
self.num_variables
|
||||
}
|
||||
|
||||
fn max_degree(&self) -> usize {
|
||||
self.num_variables
|
||||
}
|
||||
|
||||
fn evaluate(&self, query: &[E]) -> E {
|
||||
query.iter().fold(E::ZERO, |acc, x| acc + *x)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct GkrCompositionVanilla<E: 'static>
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
num_variables_ml: usize,
|
||||
num_variables_merge: usize,
|
||||
combining_randomness: E,
|
||||
gkr_randomness: Vec<E>,
|
||||
}
|
||||
|
||||
impl<E> GkrCompositionVanilla<E>
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
pub fn new(
|
||||
num_variables_ml: usize,
|
||||
num_variables_merge: usize,
|
||||
combining_randomness: E,
|
||||
gkr_randomness: Vec<E>,
|
||||
) -> Self {
|
||||
Self {
|
||||
num_variables_ml,
|
||||
num_variables_merge,
|
||||
combining_randomness,
|
||||
gkr_randomness,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> CompositionPolynomial<E> for GkrCompositionVanilla<E>
|
||||
where
|
||||
E: FieldElement,
|
||||
{
|
||||
fn num_variables(&self) -> usize {
|
||||
self.num_variables_ml // + TODO
|
||||
}
|
||||
|
||||
fn max_degree(&self) -> usize {
|
||||
self.num_variables_ml //TODO
|
||||
}
|
||||
|
||||
fn evaluate(&self, query: &[E]) -> E {
|
||||
let eval_left_numerator = query[0];
|
||||
let eval_right_numerator = query[1];
|
||||
let eval_left_denominator = query[2];
|
||||
let eval_right_denominator = query[3];
|
||||
let eq_eval = query[4];
|
||||
|
||||
eq_eval
|
||||
* ((eval_left_numerator * eval_right_denominator
|
||||
+ eval_right_numerator * eval_left_denominator)
|
||||
+ eval_left_denominator * eval_right_denominator * self.combining_randomness)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct GkrComposition<E>
|
||||
where
|
||||
E: FieldElement<BaseField = BaseElement>,
|
||||
{
|
||||
pub num_variables_ml: usize,
|
||||
pub combining_randomness: E,
|
||||
|
||||
eq_composer: Arc<dyn CompositionPolynomial<E>>,
|
||||
right_numerator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
|
||||
left_numerator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
|
||||
right_denominator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
|
||||
left_denominator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
|
||||
}
|
||||
|
||||
impl<E> GkrComposition<E>
|
||||
where
|
||||
E: FieldElement<BaseField = BaseElement>,
|
||||
{
|
||||
pub fn new(
|
||||
num_variables_ml: usize,
|
||||
combining_randomness: E,
|
||||
eq_composer: Arc<dyn CompositionPolynomial<E>>,
|
||||
right_numerator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
|
||||
left_numerator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
|
||||
right_denominator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
|
||||
left_denominator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
num_variables_ml,
|
||||
combining_randomness,
|
||||
eq_composer,
|
||||
right_numerator_composer,
|
||||
left_numerator_composer,
|
||||
right_denominator_composer,
|
||||
left_denominator_composer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> CompositionPolynomial<E> for GkrComposition<E>
|
||||
where
|
||||
E: FieldElement<BaseField = BaseElement>,
|
||||
{
|
||||
fn num_variables(&self) -> usize {
|
||||
self.num_variables_ml // + TODO
|
||||
}
|
||||
|
||||
fn max_degree(&self) -> usize {
|
||||
3 // TODO
|
||||
}
|
||||
|
||||
fn evaluate(&self, query: &[E]) -> E {
|
||||
let eval_right_numerator = self.right_numerator_composer[0].evaluate(query);
|
||||
let eval_left_numerator = self.left_numerator_composer[0].evaluate(query);
|
||||
let eval_right_denominator = self.right_denominator_composer[0].evaluate(query);
|
||||
let eval_left_denominator = self.left_denominator_composer[0].evaluate(query);
|
||||
let eq_eval = self.eq_composer.evaluate(query);
|
||||
|
||||
let res = eq_eval
|
||||
* ((eval_left_numerator * eval_right_denominator
|
||||
+ eval_right_numerator * eval_left_denominator)
|
||||
+ eval_left_denominator * eval_right_denominator * self.combining_randomness);
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates a composed ML polynomial for the initial GKR layer from a vector of composition
|
||||
/// polynomials.
|
||||
/// The composition polynomials are divided into LeftNumerator, RightNumerator, LeftDenominator
|
||||
/// and RightDenominator.
|
||||
/// TODO: Generalize this to the case where each numerator/denominator contains more than one
|
||||
/// composition polynomial i.e., a merged composed ML polynomial.
|
||||
pub fn gkr_composition_from_composition_polys<
|
||||
E: FieldElement<BaseField = BaseElement> + 'static,
|
||||
>(
|
||||
composition_polys: &Vec<Vec<Arc<dyn CompositionPolynomial<E>>>>,
|
||||
combining_randomness: E,
|
||||
num_variables: usize,
|
||||
) -> GkrComposition<E> {
|
||||
let eq_composer = Arc::new(ProjectionComposition::new(4));
|
||||
let left_numerator = composition_polys[0].to_owned();
|
||||
let right_numerator = composition_polys[1].to_owned();
|
||||
let left_denominator = composition_polys[2].to_owned();
|
||||
let right_denominator = composition_polys[3].to_owned();
|
||||
GkrComposition::new(
|
||||
num_variables,
|
||||
combining_randomness,
|
||||
eq_composer,
|
||||
right_numerator,
|
||||
left_numerator,
|
||||
right_denominator,
|
||||
left_denominator,
|
||||
)
|
||||
}
|
||||
|
||||
/// Generates a plain oracle for the sum-check protocol except the final one.
|
||||
pub fn gen_plain_gkr_oracle<E: FieldElement<BaseField = BaseElement> + 'static>(
|
||||
num_rounds: usize,
|
||||
r_sum_check: E,
|
||||
) -> ComposedMultiLinearsOracle<E> {
|
||||
let gkr_composer = Arc::new(GkrCompositionVanilla::new(num_rounds, 0, r_sum_check, vec![]));
|
||||
|
||||
let ml_oracles = vec![
|
||||
MultiLinearOracle { id: 0 },
|
||||
MultiLinearOracle { id: 1 },
|
||||
MultiLinearOracle { id: 2 },
|
||||
MultiLinearOracle { id: 3 },
|
||||
MultiLinearOracle { id: 4 },
|
||||
];
|
||||
|
||||
let oracle = ComposedMultiLinearsOracle {
|
||||
composer: gkr_composer,
|
||||
multi_linears: ml_oracles,
|
||||
};
|
||||
oracle
|
||||
}
|
||||
|
||||
fn to_index<E: FieldElement<BaseField = BaseElement>>(index: &[E]) -> usize {
|
||||
let res = index.iter().fold(E::ZERO, |acc, term| acc * E::ONE.double() + (*term));
|
||||
let res = res.base_element(0);
|
||||
res.as_int() as usize
|
||||
}
|
||||
|
||||
fn inner_product<E: FieldElement>(evaluations: &[E], tensored_query: &[E]) -> E {
|
||||
assert_eq!(evaluations.len(), tensored_query.len());
|
||||
evaluations
|
||||
.iter()
|
||||
.zip(tensored_query.iter())
|
||||
.fold(E::ZERO, |acc, (x_i, y_i)| acc + *x_i * *y_i)
|
||||
}
|
||||
|
||||
pub fn tensorize<E: FieldElement>(query: &[E]) -> Vec<E> {
|
||||
let nu = query.len();
|
||||
let n = 1 << nu;
|
||||
|
||||
(0..n).map(|i| lagrange_basis_eval(query, i)).collect()
|
||||
}
|
||||
|
||||
fn lagrange_basis_eval<E: FieldElement>(query: &[E], i: usize) -> E {
|
||||
query
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(j, x_j)| if i & (1 << j) == 0 { E::ONE - *x_j } else { *x_j })
|
||||
.fold(E::ONE, |acc, v| acc * v)
|
||||
}
|
||||
|
||||
pub fn compute_claim<E: FieldElement>(poly: &ComposedMultiLinears<E>) -> E {
|
||||
let cube_size = 1 << poly.num_variables_ml();
|
||||
let mut res = E::ZERO;
|
||||
|
||||
for i in 0..cube_size {
|
||||
let eval_point: Vec<E> =
|
||||
poly.multi_linears.iter().map(|poly| poly.evaluations[i]).collect();
|
||||
res += poly.composer.evaluate(&eval_point);
|
||||
}
|
||||
res
|
||||
}
|
||||
108
src/gkr/sumcheck/mod.rs
Normal file
108
src/gkr/sumcheck/mod.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
use super::{
|
||||
multivariate::{ComposedMultiLinears, ComposedMultiLinearsOracle},
|
||||
utils::{barycentric_weights, evaluate_barycentric},
|
||||
};
|
||||
use winter_math::FieldElement;
|
||||
|
||||
mod prover;
|
||||
pub use prover::sum_check_prove;
|
||||
mod verifier;
|
||||
pub use verifier::{sum_check_verify, sum_check_verify_and_reduce};
|
||||
mod tests;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RoundProof<E> {
|
||||
pub poly_evals: Vec<E>,
|
||||
}
|
||||
|
||||
impl<E: FieldElement> RoundProof<E> {
|
||||
pub fn to_evals(&self, claim: E) -> Vec<E> {
|
||||
let mut result = vec![];
|
||||
|
||||
// s(0) + s(1) = claim
|
||||
let c0 = claim - self.poly_evals[0];
|
||||
|
||||
result.push(c0);
|
||||
result.extend_from_slice(&self.poly_evals);
|
||||
result
|
||||
}
|
||||
|
||||
// TODO: refactor once we move to coefficient form
|
||||
pub(crate) fn evaluate(&self, claim: E, r: E) -> E {
|
||||
let poly_evals = self.to_evals(claim);
|
||||
|
||||
let points: Vec<E> = (0..poly_evals.len()).map(|i| E::from(i as u8)).collect();
|
||||
let evalss: Vec<(E, E)> =
|
||||
points.iter().zip(poly_evals.iter()).map(|(x, y)| (*x, *y)).collect();
|
||||
let weights = barycentric_weights(&evalss);
|
||||
let new_claim = evaluate_barycentric(&evalss, r, &weights);
|
||||
new_claim
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PartialProof<E> {
|
||||
pub round_proofs: Vec<RoundProof<E>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FinalEvaluationClaim<E: FieldElement> {
|
||||
pub evaluation_point: Vec<E>,
|
||||
pub claimed_evaluation: E,
|
||||
pub polynomial: ComposedMultiLinearsOracle<E>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FullProof<E: FieldElement> {
|
||||
pub sum_check_proof: PartialProof<E>,
|
||||
pub final_evaluation_claim: FinalEvaluationClaim<E>,
|
||||
}
|
||||
|
||||
pub struct Claim<E: FieldElement> {
|
||||
pub sum_value: E,
|
||||
pub polynomial: ComposedMultiLinearsOracle<E>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RoundClaim<E: FieldElement> {
|
||||
pub partial_eval_point: Vec<E>,
|
||||
pub current_claim: E,
|
||||
}
|
||||
|
||||
pub struct RoundOutput<E: FieldElement> {
|
||||
proof: PartialProof<E>,
|
||||
witness: Witness<E>,
|
||||
}
|
||||
|
||||
impl<E: FieldElement> From<Claim<E>> for RoundClaim<E> {
|
||||
fn from(value: Claim<E>) -> Self {
|
||||
Self {
|
||||
partial_eval_point: vec![],
|
||||
current_claim: value.sum_value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Witness<E: FieldElement> {
|
||||
pub(crate) polynomial: ComposedMultiLinears<E>,
|
||||
}
|
||||
|
||||
pub fn reduce_claim<E: FieldElement>(
|
||||
current_poly: RoundProof<E>,
|
||||
current_round_claim: RoundClaim<E>,
|
||||
round_challenge: E,
|
||||
) -> RoundClaim<E> {
|
||||
let poly_evals = current_poly.to_evals(current_round_claim.current_claim);
|
||||
let points: Vec<E> = (0..poly_evals.len()).map(|i| E::from(i as u8)).collect();
|
||||
let evalss: Vec<(E, E)> = points.iter().zip(poly_evals.iter()).map(|(x, y)| (*x, *y)).collect();
|
||||
let weights = barycentric_weights(&evalss);
|
||||
let new_claim = evaluate_barycentric(&evalss, round_challenge, &weights);
|
||||
|
||||
let mut new_partial_eval_point = current_round_claim.partial_eval_point;
|
||||
new_partial_eval_point.push(round_challenge);
|
||||
|
||||
RoundClaim {
|
||||
partial_eval_point: new_partial_eval_point,
|
||||
current_claim: new_claim,
|
||||
}
|
||||
}
|
||||
109
src/gkr/sumcheck/prover.rs
Normal file
109
src/gkr/sumcheck/prover.rs
Normal file
@@ -0,0 +1,109 @@
|
||||
use super::{Claim, FullProof, RoundProof, Witness};
|
||||
use crate::gkr::{
|
||||
multivariate::{ComposedMultiLinears, ComposedMultiLinearsOracle},
|
||||
sumcheck::{reduce_claim, FinalEvaluationClaim, PartialProof, RoundClaim, RoundOutput},
|
||||
};
|
||||
use rayon::iter::{IntoParallelIterator, ParallelIterator};
|
||||
use winter_crypto::{ElementHasher, RandomCoin};
|
||||
use winter_math::{fields::f64::BaseElement, FieldElement};
|
||||
|
||||
pub fn sum_check_prove<
|
||||
E: FieldElement<BaseField = BaseElement>,
|
||||
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||
H: ElementHasher<BaseField = BaseElement>,
|
||||
>(
|
||||
claim: &Claim<E>,
|
||||
oracle: ComposedMultiLinearsOracle<E>,
|
||||
witness: Witness<E>,
|
||||
coin: &mut C,
|
||||
) -> FullProof<E> {
|
||||
// Setup first round
|
||||
let mut prev_claim = RoundClaim {
|
||||
partial_eval_point: vec![],
|
||||
current_claim: claim.sum_value.clone(),
|
||||
};
|
||||
let prev_proof = PartialProof { round_proofs: vec![] };
|
||||
let num_vars = witness.polynomial.num_variables_ml();
|
||||
let prev_output = RoundOutput { proof: prev_proof, witness };
|
||||
|
||||
let mut output = sumcheck_round(prev_output);
|
||||
let poly_evals = &output.proof.round_proofs[0].poly_evals;
|
||||
coin.reseed(H::hash_elements(&poly_evals));
|
||||
|
||||
for i in 1..num_vars {
|
||||
let round_challenge = coin.draw().unwrap();
|
||||
let new_claim = reduce_claim(
|
||||
output.proof.round_proofs.last().unwrap().clone(),
|
||||
prev_claim,
|
||||
round_challenge,
|
||||
);
|
||||
output.witness.polynomial = output.witness.polynomial.bind(round_challenge);
|
||||
|
||||
output = sumcheck_round(output);
|
||||
prev_claim = new_claim;
|
||||
|
||||
let poly_evals = &output.proof.round_proofs[i].poly_evals;
|
||||
coin.reseed(H::hash_elements(&poly_evals));
|
||||
}
|
||||
|
||||
let round_challenge = coin.draw().unwrap();
|
||||
let RoundClaim { partial_eval_point, current_claim } = reduce_claim(
|
||||
output.proof.round_proofs.last().unwrap().clone(),
|
||||
prev_claim,
|
||||
round_challenge,
|
||||
);
|
||||
let final_eval_claim = FinalEvaluationClaim {
|
||||
evaluation_point: partial_eval_point,
|
||||
claimed_evaluation: current_claim,
|
||||
polynomial: oracle,
|
||||
};
|
||||
|
||||
FullProof {
|
||||
sum_check_proof: output.proof,
|
||||
final_evaluation_claim: final_eval_claim,
|
||||
}
|
||||
}
|
||||
|
||||
fn sumcheck_round<E: FieldElement>(prev_proof: RoundOutput<E>) -> RoundOutput<E> {
|
||||
let RoundOutput { mut proof, witness } = prev_proof;
|
||||
|
||||
let polynomial = witness.polynomial;
|
||||
let num_ml = polynomial.num_ml();
|
||||
let num_vars = polynomial.num_variables_ml();
|
||||
let num_rounds = num_vars - 1;
|
||||
|
||||
let mut evals_zero = vec![E::ZERO; num_ml];
|
||||
let mut evals_one = vec![E::ZERO; num_ml];
|
||||
let mut deltas = vec![E::ZERO; num_ml];
|
||||
let mut evals_x = vec![E::ZERO; num_ml];
|
||||
|
||||
let total_evals = (0..1 << num_rounds).into_iter().map(|i| {
|
||||
for (j, ml) in polynomial.multi_linears.iter().enumerate() {
|
||||
evals_zero[j] = ml.evaluations[(i << 1) as usize];
|
||||
evals_one[j] = ml.evaluations[(i << 1) + 1];
|
||||
}
|
||||
let mut total_evals = vec![E::ZERO; polynomial.degree()];
|
||||
total_evals[0] = polynomial.composer.evaluate(&evals_one);
|
||||
evals_zero
|
||||
.iter()
|
||||
.zip(evals_one.iter().zip(deltas.iter_mut().zip(evals_x.iter_mut())))
|
||||
.for_each(|(a0, (a1, (delta, evx)))| {
|
||||
*delta = *a1 - *a0;
|
||||
*evx = *a1;
|
||||
});
|
||||
total_evals.iter_mut().skip(1).for_each(|e| {
|
||||
evals_x.iter_mut().zip(deltas.iter()).for_each(|(evx, delta)| {
|
||||
*evx += *delta;
|
||||
});
|
||||
*e = polynomial.composer.evaluate(&evals_x);
|
||||
});
|
||||
total_evals
|
||||
});
|
||||
let evaluations = total_evals.fold(vec![E::ZERO; polynomial.degree()], |mut acc, evals| {
|
||||
acc.iter_mut().zip(evals.iter()).for_each(|(a, ev)| *a += *ev);
|
||||
acc
|
||||
});
|
||||
let proof_update = RoundProof { poly_evals: evaluations };
|
||||
proof.round_proofs.push(proof_update);
|
||||
RoundOutput { proof, witness: Witness { polynomial } }
|
||||
}
|
||||
199
src/gkr/sumcheck/tests.rs
Normal file
199
src/gkr/sumcheck/tests.rs
Normal file
@@ -0,0 +1,199 @@
|
||||
use alloc::sync::Arc;
|
||||
use rand::{distributions::Uniform, SeedableRng};
|
||||
use winter_crypto::RandomCoin;
|
||||
use winter_math::{fields::f64::BaseElement, FieldElement};
|
||||
|
||||
use crate::{
|
||||
gkr::{
|
||||
circuit::{CircuitProof, FractionalSumCircuit},
|
||||
multivariate::{
|
||||
compute_claim, gkr_composition_from_composition_polys, ComposedMultiLinears,
|
||||
ComposedMultiLinearsOracle, CompositionPolynomial, EqPolynomial, GkrComposition,
|
||||
GkrCompositionVanilla, LogUpDenominatorTableComposition,
|
||||
LogUpDenominatorWitnessComposition, MultiLinear, MultiLinearOracle,
|
||||
ProjectionComposition, SumComposition,
|
||||
},
|
||||
sumcheck::{
|
||||
prover::sum_check_prove, verifier::sum_check_verify, Claim, FinalEvaluationClaim,
|
||||
FullProof, Witness,
|
||||
},
|
||||
},
|
||||
hash::rpo::Rpo256,
|
||||
rand::RpoRandomCoin,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn gkr_workflow() {
|
||||
// generate the data witness for the LogUp argument
|
||||
let mut mls = generate_logup_witness::<BaseElement>(3);
|
||||
|
||||
// the is sampled after receiving the main trace commitment
|
||||
let alpha = rand_utils::rand_value();
|
||||
|
||||
// the composition polynomials defining the numerators/denominators
|
||||
let composition_polys: Vec<Vec<Arc<dyn CompositionPolynomial<BaseElement>>>> = vec![
|
||||
// left num
|
||||
vec![Arc::new(ProjectionComposition::new(0))],
|
||||
// right num
|
||||
vec![Arc::new(ProjectionComposition::new(1))],
|
||||
// left den
|
||||
vec![Arc::new(LogUpDenominatorTableComposition::new(2, alpha))],
|
||||
// right den
|
||||
vec![Arc::new(LogUpDenominatorWitnessComposition::new(3, alpha))],
|
||||
];
|
||||
|
||||
// run the GKR prover to obtain:
|
||||
// 1. The fractional sum circuit output.
|
||||
// 2. GKR proofs up to the last circuit layer counting backwards.
|
||||
// 3. GKR proof (i.e., a sum-check proof) for the last circuit layer counting backwards.
|
||||
let seed = [BaseElement::ZERO; 4];
|
||||
let mut transcript = RpoRandomCoin::new(seed.into());
|
||||
let (circuit_outputs, gkr_before_last_proof, final_layer_proof) =
|
||||
CircuitProof::prove_virtual_bus(composition_polys.clone(), &mut mls, &mut transcript);
|
||||
|
||||
let seed = [BaseElement::ZERO; 4];
|
||||
let mut transcript = RpoRandomCoin::new(seed.into());
|
||||
|
||||
// run the GKR verifier to obtain:
|
||||
// 1. A final evaluation claim.
|
||||
// 2. Randomness defining the Lagrange kernel in the final sum-check protocol. Note that this
|
||||
// Lagrange kernel is different from the one used by the STARK (outer) prover to open the MLs
|
||||
// at the evaluation point.
|
||||
let (final_eval_claim, gkr_lagrange_kernel_rand) = gkr_before_last_proof.verify_virtual_bus(
|
||||
composition_polys.clone(),
|
||||
final_layer_proof,
|
||||
&circuit_outputs,
|
||||
&mut transcript,
|
||||
);
|
||||
|
||||
// the final verification step is composed of:
|
||||
// 1. Querying the oracles for the openings at the evaluation point. This will be done by the
|
||||
// (outer) STARK prover using:
|
||||
// a. The Lagrange kernel (auxiliary) column at the evaluation point.
|
||||
// b. An extra (auxiliary) column to compute an inner product between two vectors. The first
|
||||
// being the Lagrange kernel and the second being (\sum_{j=0}^3 mls[j][i] * \lambda_i)_{i\in\{0,..,n\}}
|
||||
// 2. Evaluating the composition polynomial at the previous openings and checking equality with
|
||||
// the claimed evaluation.
|
||||
|
||||
// 1. Querying the oracles
|
||||
|
||||
let FinalEvaluationClaim {
|
||||
evaluation_point,
|
||||
claimed_evaluation,
|
||||
polynomial,
|
||||
} = final_eval_claim;
|
||||
|
||||
// The evaluation of the EQ polynomial can be done by the verifier directly
|
||||
let eq = (0..gkr_lagrange_kernel_rand.len())
|
||||
.map(|i| {
|
||||
gkr_lagrange_kernel_rand[i] * evaluation_point[i]
|
||||
+ (BaseElement::ONE - gkr_lagrange_kernel_rand[i])
|
||||
* (BaseElement::ONE - evaluation_point[i])
|
||||
})
|
||||
.fold(BaseElement::ONE, |acc, term| acc * term);
|
||||
|
||||
// These are the queries to the oracles.
|
||||
// They should be provided by the prover non-deterministically
|
||||
let left_num_eval = mls[0].evaluate(&evaluation_point);
|
||||
let right_num_eval = mls[1].evaluate(&evaluation_point);
|
||||
let left_den_eval = mls[2].evaluate(&evaluation_point);
|
||||
let right_den_eval = mls[3].evaluate(&evaluation_point);
|
||||
|
||||
// The verifier absorbs the claimed openings and generates batching randomness lambda
|
||||
let mut query = vec![left_num_eval, right_num_eval, left_den_eval, right_den_eval];
|
||||
transcript.reseed(Rpo256::hash_elements(&query));
|
||||
let lambdas: Vec<BaseElement> = vec![
|
||||
transcript.draw().unwrap(),
|
||||
transcript.draw().unwrap(),
|
||||
transcript.draw().unwrap(),
|
||||
];
|
||||
let batched_query =
|
||||
query[0] + query[1] * lambdas[0] + query[2] * lambdas[1] + query[3] * lambdas[2];
|
||||
|
||||
// The prover generates the Lagrange kernel as an auxiliary column
|
||||
let mut rev_evaluation_point = evaluation_point;
|
||||
rev_evaluation_point.reverse();
|
||||
let lagrange_kernel = EqPolynomial::new(rev_evaluation_point).evaluations();
|
||||
|
||||
// The prover generates the additional auxiliary column for the inner product
|
||||
let tmp_col: Vec<BaseElement> = (0..mls[0].len())
|
||||
.map(|i| {
|
||||
mls[0][i] + mls[1][i] * lambdas[0] + mls[2][i] * lambdas[1] + mls[3][i] * lambdas[2]
|
||||
})
|
||||
.collect();
|
||||
let mut running_sum_col = vec![BaseElement::ZERO; tmp_col.len() + 1];
|
||||
running_sum_col[0] = BaseElement::ZERO;
|
||||
for i in 1..(tmp_col.len() + 1) {
|
||||
running_sum_col[i] = running_sum_col[i - 1] + tmp_col[i - 1] * lagrange_kernel[i - 1];
|
||||
}
|
||||
|
||||
// Boundary constraint to check correctness of openings
|
||||
assert_eq!(batched_query, *running_sum_col.last().unwrap());
|
||||
|
||||
// 2) Final evaluation and check
|
||||
query.push(eq);
|
||||
let verifier_computed = polynomial.composer.evaluate(&query);
|
||||
|
||||
assert_eq!(verifier_computed, claimed_evaluation);
|
||||
}
|
||||
|
||||
pub fn generate_logup_witness<E: FieldElement>(trace_len: usize) -> Vec<MultiLinear<E>> {
|
||||
let num_variables_ml = trace_len;
|
||||
let num_evaluations = 1 << num_variables_ml;
|
||||
let num_witnesses = 1;
|
||||
let (p, q) = generate_logup_data::<E>(num_variables_ml, num_witnesses);
|
||||
let numerators: Vec<Vec<E>> = p.chunks(num_evaluations).map(|x| x.into()).collect();
|
||||
let denominators: Vec<Vec<E>> = q.chunks(num_evaluations).map(|x| x.into()).collect();
|
||||
|
||||
let mut mls = vec![];
|
||||
for i in 0..2 {
|
||||
let ml = MultiLinear::from_values(&numerators[i]);
|
||||
mls.push(ml);
|
||||
}
|
||||
for i in 0..2 {
|
||||
let ml = MultiLinear::from_values(&denominators[i]);
|
||||
mls.push(ml);
|
||||
}
|
||||
mls
|
||||
}
|
||||
|
||||
pub fn generate_logup_data<E: FieldElement>(
|
||||
trace_len: usize,
|
||||
num_witnesses: usize,
|
||||
) -> (Vec<E>, Vec<E>) {
|
||||
use rand::distributions::Slice;
|
||||
use rand::Rng;
|
||||
let n: usize = trace_len;
|
||||
let num_w: usize = num_witnesses; // This should be of the form 2^k - 1
|
||||
let rng = rand::rngs::StdRng::seed_from_u64(0);
|
||||
|
||||
let t_table: Vec<u32> = (0..(1 << n)).collect();
|
||||
let mut m_table: Vec<u32> = (0..(1 << n)).map(|_| 0).collect();
|
||||
|
||||
let t_table_slice = Slice::new(&t_table).unwrap();
|
||||
|
||||
// Construct the witness columns. Uses sampling with replacement in order to have multiplicities
|
||||
// different from 1.
|
||||
let mut w_tables = Vec::new();
|
||||
for _ in 0..num_w {
|
||||
let wi_table: Vec<u32> =
|
||||
rng.clone().sample_iter(&t_table_slice).cloned().take(1 << n).collect();
|
||||
|
||||
// Construct the multiplicities
|
||||
wi_table.iter().for_each(|w| {
|
||||
m_table[*w as usize] += 1;
|
||||
});
|
||||
w_tables.push(wi_table)
|
||||
}
|
||||
|
||||
// The numerators
|
||||
let mut p: Vec<E> = m_table.iter().map(|m| E::from(*m as u32)).collect();
|
||||
p.extend((0..(num_w * (1 << n))).map(|_| E::from(1_u32)).collect::<Vec<E>>());
|
||||
|
||||
// Construct the denominators
|
||||
let mut q: Vec<E> = t_table.iter().map(|t| E::from(*t)).collect();
|
||||
for w_table in w_tables {
|
||||
q.extend(w_table.iter().map(|w| E::from(*w)).collect::<Vec<E>>());
|
||||
}
|
||||
(p, q)
|
||||
}
|
||||
71
src/gkr/sumcheck/verifier.rs
Normal file
71
src/gkr/sumcheck/verifier.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
use winter_crypto::{ElementHasher, RandomCoin};
|
||||
use winter_math::{fields::f64::BaseElement, FieldElement};
|
||||
|
||||
use crate::gkr::utils::{barycentric_weights, evaluate_barycentric};
|
||||
|
||||
use super::{Claim, FinalEvaluationClaim, FullProof, PartialProof};
|
||||
|
||||
pub fn sum_check_verify_and_reduce<
|
||||
E: FieldElement<BaseField = BaseElement>,
|
||||
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||
H: ElementHasher<BaseField = BaseElement>,
|
||||
>(
|
||||
claim: &Claim<E>,
|
||||
proofs: PartialProof<E>,
|
||||
coin: &mut C,
|
||||
) -> (E, Vec<E>) {
|
||||
let degree = 3;
|
||||
let points: Vec<E> = (0..degree + 1).map(|x| E::from(x as u8)).collect();
|
||||
let mut sum_value = claim.sum_value.clone();
|
||||
let mut randomness = vec![];
|
||||
|
||||
for proof in proofs.round_proofs {
|
||||
let partial_evals = proof.poly_evals.clone();
|
||||
coin.reseed(H::hash_elements(&partial_evals));
|
||||
|
||||
// get r
|
||||
let r: E = coin.draw().unwrap();
|
||||
randomness.push(r);
|
||||
let evals = proof.to_evals(sum_value);
|
||||
|
||||
let point_evals: Vec<_> = points.iter().zip(evals.iter()).map(|(x, y)| (*x, *y)).collect();
|
||||
let weights = barycentric_weights(&point_evals);
|
||||
sum_value = evaluate_barycentric(&point_evals, r, &weights);
|
||||
}
|
||||
(sum_value, randomness)
|
||||
}
|
||||
|
||||
pub fn sum_check_verify<
|
||||
E: FieldElement<BaseField = BaseElement>,
|
||||
C: RandomCoin<Hasher = H, BaseField = BaseElement>,
|
||||
H: ElementHasher<BaseField = BaseElement>,
|
||||
>(
|
||||
claim: &Claim<E>,
|
||||
proofs: FullProof<E>,
|
||||
coin: &mut C,
|
||||
) -> FinalEvaluationClaim<E> {
|
||||
let FullProof {
|
||||
sum_check_proof: proofs,
|
||||
final_evaluation_claim,
|
||||
} = proofs;
|
||||
let Claim { mut sum_value, polynomial } = claim;
|
||||
let degree = polynomial.composer.max_degree();
|
||||
let points: Vec<E> = (0..degree + 1).map(|x| E::from(x as u8)).collect();
|
||||
|
||||
for proof in proofs.round_proofs {
|
||||
let partial_evals = proof.poly_evals.clone();
|
||||
coin.reseed(H::hash_elements(&partial_evals));
|
||||
|
||||
// get r
|
||||
let r: E = coin.draw().unwrap();
|
||||
let evals = proof.to_evals(sum_value);
|
||||
|
||||
let point_evals: Vec<_> = points.iter().zip(evals.iter()).map(|(x, y)| (*x, *y)).collect();
|
||||
let weights = barycentric_weights(&point_evals);
|
||||
sum_value = evaluate_barycentric(&point_evals, r, &weights);
|
||||
}
|
||||
|
||||
assert_eq!(final_evaluation_claim.claimed_evaluation, sum_value);
|
||||
|
||||
final_evaluation_claim
|
||||
}
|
||||
33
src/gkr/utils/mod.rs
Normal file
33
src/gkr/utils/mod.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
use winter_math::{FieldElement, batch_inversion};
|
||||
|
||||
|
||||
pub fn barycentric_weights<E: FieldElement>(points: &[(E, E)]) -> Vec<E> {
|
||||
let n = points.len();
|
||||
let tmp = (0..n)
|
||||
.map(|i| (0..n).filter(|&j| j != i).fold(E::ONE, |acc, j| acc * (points[i].0 - points[j].0)))
|
||||
.collect::<Vec<_>>();
|
||||
batch_inversion(&tmp)
|
||||
}
|
||||
|
||||
pub fn evaluate_barycentric<E: FieldElement>(
|
||||
points: &[(E, E)],
|
||||
x: E,
|
||||
barycentric_weights: &[E],
|
||||
) -> E {
|
||||
for &(x_i, y_i) in points {
|
||||
if x_i == x {
|
||||
return y_i;
|
||||
}
|
||||
}
|
||||
|
||||
let l_x: E = points.iter().fold(E::ONE, |acc, &(x_i, _y_i)| acc * (x - x_i));
|
||||
|
||||
let sum = (0..points.len()).fold(E::ZERO, |acc, i| {
|
||||
let x_i = points[i].0;
|
||||
let y_i = points[i].1;
|
||||
let w_i = barycentric_weights[i];
|
||||
acc + (w_i / (x - x_i) * y_i)
|
||||
});
|
||||
|
||||
l_x * sum
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::{Digest, ElementHasher, Felt, FieldElement, Hasher};
|
||||
use super::{Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField};
|
||||
use crate::utils::{
|
||||
bytes_to_hex_string, hex_to_bytes, string::String, ByteReader, ByteWriter, Deserializable,
|
||||
DeserializationError, HexParseError, Serializable,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#[cfg(target_feature = "sve")]
|
||||
#[cfg(all(target_feature = "sve", feature = "sve"))]
|
||||
pub mod optimized {
|
||||
use crate::hash::rescue::STATE_WIDTH;
|
||||
use crate::Felt;
|
||||
@@ -78,7 +78,7 @@ pub mod optimized {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(any(target_feature = "avx2", target_feature = "sve")))]
|
||||
#[cfg(not(any(target_feature = "avx2", all(target_feature = "sve", feature = "sve"))))]
|
||||
pub mod optimized {
|
||||
use crate::hash::rescue::STATE_WIDTH;
|
||||
use crate::Felt;
|
||||
|
||||
@@ -33,11 +33,6 @@ impl RpoDigest {
|
||||
{
|
||||
digests.flat_map(|d| d.0.iter())
|
||||
}
|
||||
|
||||
/// Returns hexadecimal representation of this digest prefixed with `0x`.
|
||||
pub fn to_hex(&self) -> String {
|
||||
bytes_to_hex_string(self.as_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
impl Digest for RpoDigest {
|
||||
@@ -163,7 +158,7 @@ impl From<RpoDigest> for [u8; DIGEST_BYTES] {
|
||||
impl From<RpoDigest> for String {
|
||||
/// The returned string starts with `0x`.
|
||||
fn from(value: RpoDigest) -> Self {
|
||||
value.to_hex()
|
||||
bytes_to_hex_string(value.as_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -234,12 +229,15 @@ impl TryFrom<[u64; DIGEST_SIZE]> for RpoDigest {
|
||||
type Error = RpoDigestError;
|
||||
|
||||
fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
|
||||
Ok(Self([
|
||||
value[0].try_into().map_err(|_| RpoDigestError::InvalidInteger)?,
|
||||
value[1].try_into().map_err(|_| RpoDigestError::InvalidInteger)?,
|
||||
value[2].try_into().map_err(|_| RpoDigestError::InvalidInteger)?,
|
||||
value[3].try_into().map_err(|_| RpoDigestError::InvalidInteger)?,
|
||||
]))
|
||||
if value[0] >= Felt::MODULUS
|
||||
|| value[1] >= Felt::MODULUS
|
||||
|| value[2] >= Felt::MODULUS
|
||||
|| value[3] >= Felt::MODULUS
|
||||
{
|
||||
return Err(RpoDigestError::InvalidInteger);
|
||||
}
|
||||
|
||||
Ok(Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()]))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ mod tests;
|
||||
/// * Number of founds: 7.
|
||||
/// * S-Box degree: 7.
|
||||
///
|
||||
/// The above parameters target a 128-bit security level. The digest consists of four field elements
|
||||
/// The above parameters target 128-bit security level. The digest consists of four field elements
|
||||
/// and it can be serialized into 32 bytes (256 bits).
|
||||
///
|
||||
/// ## Hash output consistency
|
||||
@@ -55,7 +55,13 @@ mod tests;
|
||||
pub struct Rpo256();
|
||||
|
||||
impl Hasher for Rpo256 {
|
||||
/// Rpo256 collision resistance is 128-bits.
|
||||
/// Rpo256 collision resistance is the same as the security level, that is 128-bits.
|
||||
///
|
||||
/// #### Collision resistance
|
||||
///
|
||||
/// However, our setup of the capacity registers might drop it to 126.
|
||||
///
|
||||
/// Related issue: [#69](https://github.com/0xPolygonMiden/crypto/issues/69)
|
||||
const COLLISION_RESISTANCE: u32 = 128;
|
||||
|
||||
type Digest = RpoDigest;
|
||||
|
||||
@@ -33,11 +33,6 @@ impl RpxDigest {
|
||||
{
|
||||
digests.flat_map(|d| d.0.iter())
|
||||
}
|
||||
|
||||
/// Returns hexadecimal representation of this digest prefixed with `0x`.
|
||||
pub fn to_hex(&self) -> String {
|
||||
bytes_to_hex_string(self.as_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
impl Digest for RpxDigest {
|
||||
@@ -163,7 +158,7 @@ impl From<RpxDigest> for [u8; DIGEST_BYTES] {
|
||||
impl From<RpxDigest> for String {
|
||||
/// The returned string starts with `0x`.
|
||||
fn from(value: RpxDigest) -> Self {
|
||||
value.to_hex()
|
||||
bytes_to_hex_string(value.as_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -234,12 +229,15 @@ impl TryFrom<[u64; DIGEST_SIZE]> for RpxDigest {
|
||||
type Error = RpxDigestError;
|
||||
|
||||
fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
|
||||
Ok(Self([
|
||||
value[0].try_into().map_err(|_| RpxDigestError::InvalidInteger)?,
|
||||
value[1].try_into().map_err(|_| RpxDigestError::InvalidInteger)?,
|
||||
value[2].try_into().map_err(|_| RpxDigestError::InvalidInteger)?,
|
||||
value[3].try_into().map_err(|_| RpxDigestError::InvalidInteger)?,
|
||||
]))
|
||||
if value[0] >= Felt::MODULUS
|
||||
|| value[1] >= Felt::MODULUS
|
||||
|| value[2] >= Felt::MODULUS
|
||||
|| value[3] >= Felt::MODULUS
|
||||
{
|
||||
return Err(RpxDigestError::InvalidInteger);
|
||||
}
|
||||
|
||||
Ok(Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()]))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@ use super::{
|
||||
add_constants, add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox, apply_inv_sbox,
|
||||
apply_mds, apply_sbox, CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher,
|
||||
StarkField, ARK1, ARK2, BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE,
|
||||
DIGEST_SIZE, INPUT1_RANGE, INPUT2_RANGE, MDS, NUM_ROUNDS, RATE_RANGE, RATE_WIDTH, STATE_WIDTH,
|
||||
ZERO,
|
||||
DIGEST_SIZE, INPUT1_RANGE, INPUT2_RANGE, MDS, NUM_ROUNDS, ONE, RATE_RANGE, RATE_WIDTH,
|
||||
STATE_WIDTH, ZERO,
|
||||
};
|
||||
use core::{convert::TryInto, ops::Range};
|
||||
|
||||
@@ -30,7 +30,7 @@ pub type CubicExtElement = CubeExtension<Felt>;
|
||||
/// - (M): `apply_mds` → `add_constants`.
|
||||
/// * Permutation: (FB) (E) (FB) (E) (FB) (E) (M).
|
||||
///
|
||||
/// The above parameters target a 128-bit security level. The digest consists of four field elements
|
||||
/// The above parameters target 128-bit security level. The digest consists of four field elements
|
||||
/// and it can be serialized into 32 bytes (256 bits).
|
||||
///
|
||||
/// ## Hash output consistency
|
||||
@@ -58,7 +58,13 @@ pub type CubicExtElement = CubeExtension<Felt>;
|
||||
pub struct Rpx256();
|
||||
|
||||
impl Hasher for Rpx256 {
|
||||
/// Rpx256 collision resistance is 128-bits.
|
||||
/// Rpx256 collision resistance is the same as the security level, that is 128-bits.
|
||||
///
|
||||
/// #### Collision resistance
|
||||
///
|
||||
/// However, our setup of the capacity registers might drop it to 126.
|
||||
///
|
||||
/// Related issue: [#69](https://github.com/0xPolygonMiden/crypto/issues/69)
|
||||
const COLLISION_RESISTANCE: u32 = 128;
|
||||
|
||||
type Digest = RpxDigest;
|
||||
@@ -67,16 +73,14 @@ impl Hasher for Rpx256 {
|
||||
// initialize the state with zeroes
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
|
||||
// determine the number of field elements needed to encode `bytes` when each field element
|
||||
// represents at most 7 bytes.
|
||||
let num_field_elem = bytes.len().div_ceil(BINARY_CHUNK_SIZE);
|
||||
|
||||
// set the first capacity element to `RATE_WIDTH + (num_field_elem % RATE_WIDTH)`. We do
|
||||
// this to achieve:
|
||||
// 1. Domain separating hashing of `[u8]` from hashing of `[Felt]`.
|
||||
// 2. Avoiding collisions at the `[Felt]` representation of the encoded bytes.
|
||||
state[CAPACITY_RANGE.start] =
|
||||
Felt::from((RATE_WIDTH + (num_field_elem % RATE_WIDTH)) as u8);
|
||||
// set the capacity (first element) to a flag on whether or not the input length is evenly
|
||||
// divided by the rate. this will prevent collisions between padded and non-padded inputs,
|
||||
// and will rule out the need to perform an extra permutation in case of evenly divided
|
||||
// inputs.
|
||||
let is_rate_multiple = bytes.len() % RATE_WIDTH == 0;
|
||||
if !is_rate_multiple {
|
||||
state[CAPACITY_RANGE.start] = ONE;
|
||||
}
|
||||
|
||||
// initialize a buffer to receive the little-endian elements.
|
||||
let mut buf = [0_u8; 8];
|
||||
@@ -88,12 +92,12 @@ impl Hasher for Rpx256 {
|
||||
// `i` is not zero, then the chunks count wasn't enough to fill the state range, and an
|
||||
// additional permutation must be performed.
|
||||
let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| {
|
||||
// copy the chunk into the buffer
|
||||
if i != num_field_elem - 1 {
|
||||
// the last element of the iteration may or may not be a full chunk. if it's not, then
|
||||
// we need to pad the remainder bytes of the chunk with zeroes, separated by a `1`.
|
||||
// this will avoid collisions.
|
||||
if chunk.len() == BINARY_CHUNK_SIZE {
|
||||
buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
|
||||
} else {
|
||||
// on the last iteration, we pad `buf` with a 1 followed by as many 0's as are
|
||||
// needed to fill it
|
||||
buf.fill(0);
|
||||
buf[..chunk.len()].copy_from_slice(chunk);
|
||||
buf[chunk.len()] = 1;
|
||||
@@ -116,10 +120,10 @@ impl Hasher for Rpx256 {
|
||||
// if we absorbed some elements but didn't apply a permutation to them (would happen when
|
||||
// the number of elements is not a multiple of RATE_WIDTH), apply the RPX permutation. we
|
||||
// don't need to apply any extra padding because the first capacity element contains a
|
||||
// flag indicating the number of field elements constituting the last block when the latter
|
||||
// is not divisible by `RATE_WIDTH`.
|
||||
// flag indicating whether the input is evenly divisible by the rate.
|
||||
if i != 0 {
|
||||
state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO);
|
||||
state[RATE_RANGE.start + i] = ONE;
|
||||
Self::apply_permutation(&mut state);
|
||||
}
|
||||
|
||||
@@ -144,20 +148,25 @@ impl Hasher for Rpx256 {
|
||||
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
|
||||
// initialize the state as follows:
|
||||
// - seed is copied into the first 4 elements of the rate portion of the state.
|
||||
// - if the value fits into a single field element, copy it into the fifth rate element and
|
||||
// set the first capacity element to 5.
|
||||
// - if the value fits into a single field element, copy it into the fifth rate element
|
||||
// and set the sixth rate element to 1.
|
||||
// - if the value doesn't fit into a single field element, split it into two field
|
||||
// elements, copy them into rate elements 5 and 6 and set the first capacity element to 6.
|
||||
// elements, copy them into rate elements 5 and 6, and set the seventh rate element
|
||||
// to 1.
|
||||
// - set the first capacity element to 1
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
|
||||
state[INPUT2_RANGE.start] = Felt::new(value);
|
||||
if value < Felt::MODULUS {
|
||||
state[CAPACITY_RANGE.start] = Felt::from(5_u8);
|
||||
state[INPUT2_RANGE.start + 1] = ONE;
|
||||
} else {
|
||||
state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
|
||||
state[CAPACITY_RANGE.start] = Felt::from(6_u8);
|
||||
state[INPUT2_RANGE.start + 2] = ONE;
|
||||
}
|
||||
|
||||
// common padding for both cases
|
||||
state[CAPACITY_RANGE.start] = ONE;
|
||||
|
||||
// apply the RPX permutation and return the first four elements of the state
|
||||
Self::apply_permutation(&mut state);
|
||||
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
|
||||
@@ -172,9 +181,11 @@ impl ElementHasher for Rpx256 {
|
||||
let elements = E::slice_as_base_elements(elements);
|
||||
|
||||
// initialize state to all zeros, except for the first element of the capacity part, which
|
||||
// is set to `elements.len() % RATE_WIDTH`.
|
||||
// is set to 1 if the number of elements is not a multiple of RATE_WIDTH.
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
state[CAPACITY_RANGE.start] = Self::BaseField::from((elements.len() % RATE_WIDTH) as u8);
|
||||
if elements.len() % RATE_WIDTH != 0 {
|
||||
state[CAPACITY_RANGE.start] = ONE;
|
||||
}
|
||||
|
||||
// absorb elements into the state one by one until the rate portion of the state is filled
|
||||
// up; then apply the Rescue permutation and start absorbing again; repeat until all
|
||||
@@ -191,8 +202,11 @@ impl ElementHasher for Rpx256 {
|
||||
|
||||
// if we absorbed some elements but didn't apply a permutation to them (would happen when
|
||||
// the number of elements is not a multiple of RATE_WIDTH), apply the RPX permutation after
|
||||
// padding by as many 0 as necessary to make the input length a multiple of the RATE_WIDTH.
|
||||
// padding by appending a 1 followed by as many 0 as necessary to make the input length a
|
||||
// multiple of the RATE_WIDTH.
|
||||
if i > 0 {
|
||||
state[RATE_RANGE.start + i] = ONE;
|
||||
i += 1;
|
||||
while i != RATE_WIDTH {
|
||||
state[RATE_RANGE.start + i] = ZERO;
|
||||
i += 1;
|
||||
@@ -340,7 +354,7 @@ impl Rpx256 {
|
||||
add_constants(state, &ARK1[round]);
|
||||
}
|
||||
|
||||
/// Computes an exponentiation to the power 7 in cubic extension field.
|
||||
/// Computes an exponentiation to the power 7 in cubic extension field
|
||||
#[inline(always)]
|
||||
pub fn exp7(x: CubeExtension<Felt>) -> CubeExtension<Felt> {
|
||||
let x2 = x.square();
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#![cfg_attr(not(feature = "std"), no_std)]
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
#[cfg_attr(test, macro_use)]
|
||||
//#[cfg(not(feature = "std"))]
|
||||
//#[cfg_attr(test, macro_use)]
|
||||
extern crate alloc;
|
||||
|
||||
pub mod dsa;
|
||||
@@ -9,6 +9,7 @@ pub mod hash;
|
||||
pub mod merkle;
|
||||
pub mod rand;
|
||||
pub mod utils;
|
||||
pub mod gkr;
|
||||
|
||||
// RE-EXPORTS
|
||||
// ================================================================================================
|
||||
|
||||
52
src/main.rs
52
src/main.rs
@@ -1,14 +1,19 @@
|
||||
use clap::Parser;
|
||||
use miden_crypto::{
|
||||
hash::rpo::{Rpo256, RpoDigest},
|
||||
merkle::{MerkleError, Smt},
|
||||
merkle::{MerkleError, TieredSmt},
|
||||
Felt, Word, ONE,
|
||||
};
|
||||
use rand_utils::rand_value;
|
||||
use std::time::Instant;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(name = "Benchmark", about = "SMT benchmark", version, rename_all = "kebab-case")]
|
||||
#[clap(
|
||||
name = "Benchmark",
|
||||
about = "Tiered SMT benchmark",
|
||||
version,
|
||||
rename_all = "kebab-case"
|
||||
)]
|
||||
pub struct BenchmarkCmd {
|
||||
/// Size of the tree
|
||||
#[clap(short = 's', long = "size")]
|
||||
@@ -16,11 +21,11 @@ pub struct BenchmarkCmd {
|
||||
}
|
||||
|
||||
fn main() {
|
||||
benchmark_smt();
|
||||
benchmark_tsmt();
|
||||
}
|
||||
|
||||
/// Run a benchmark for [`Smt`].
|
||||
pub fn benchmark_smt() {
|
||||
/// Run a benchmark for the Tiered SMT.
|
||||
pub fn benchmark_tsmt() {
|
||||
let args = BenchmarkCmd::parse();
|
||||
let tree_size = args.size;
|
||||
|
||||
@@ -37,25 +42,38 @@ pub fn benchmark_smt() {
|
||||
proof_generation(&mut tree, tree_size).unwrap();
|
||||
}
|
||||
|
||||
/// Runs the construction benchmark for [`Smt`], returning the constructed tree.
|
||||
pub fn construction(entries: Vec<(RpoDigest, Word)>, size: u64) -> Result<Smt, MerkleError> {
|
||||
/// Runs the construction benchmark for the Tiered SMT, returning the constructed tree.
|
||||
pub fn construction(entries: Vec<(RpoDigest, Word)>, size: u64) -> Result<TieredSmt, MerkleError> {
|
||||
println!("Running a construction benchmark:");
|
||||
let now = Instant::now();
|
||||
let tree = Smt::with_entries(entries)?;
|
||||
let tree = TieredSmt::with_entries(entries)?;
|
||||
let elapsed = now.elapsed();
|
||||
println!(
|
||||
"Constructed a SMT with {} key-value pairs in {:.3} seconds",
|
||||
"Constructed a TSMT with {} key-value pairs in {:.3} seconds",
|
||||
size,
|
||||
elapsed.as_secs_f32(),
|
||||
);
|
||||
|
||||
println!("Number of leaf nodes: {}\n", tree.leaves().count());
|
||||
// Count how many nodes end up at each tier
|
||||
let mut nodes_num_16_32_48 = (0, 0, 0);
|
||||
|
||||
tree.upper_leaf_nodes().for_each(|(index, _)| match index.depth() {
|
||||
16 => nodes_num_16_32_48.0 += 1,
|
||||
32 => nodes_num_16_32_48.1 += 1,
|
||||
48 => nodes_num_16_32_48.2 += 1,
|
||||
_ => unreachable!(),
|
||||
});
|
||||
|
||||
println!("Number of nodes on depth 16: {}", nodes_num_16_32_48.0);
|
||||
println!("Number of nodes on depth 32: {}", nodes_num_16_32_48.1);
|
||||
println!("Number of nodes on depth 48: {}", nodes_num_16_32_48.2);
|
||||
println!("Number of nodes on depth 64: {}\n", tree.bottom_leaves().count());
|
||||
|
||||
Ok(tree)
|
||||
}
|
||||
|
||||
/// Runs the insertion benchmark for the [`Smt`].
|
||||
pub fn insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
|
||||
/// Runs the insertion benchmark for the Tiered SMT.
|
||||
pub fn insertion(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleError> {
|
||||
println!("Running an insertion benchmark:");
|
||||
|
||||
let mut insertion_times = Vec::new();
|
||||
@@ -71,7 +89,7 @@ pub fn insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
|
||||
}
|
||||
|
||||
println!(
|
||||
"An average insertion time measured by 20 inserts into a SMT with {} key-value pairs is {:.3} milliseconds\n",
|
||||
"An average insertion time measured by 20 inserts into a TSMT with {} key-value pairs is {:.3} milliseconds\n",
|
||||
size,
|
||||
// calculate the average by dividing by 20 and convert to milliseconds by multiplying by
|
||||
// 1000. As a result, we can only multiply by 50
|
||||
@@ -81,8 +99,8 @@ pub fn insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Runs the proof generation benchmark for the [`Smt`].
|
||||
pub fn proof_generation(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
|
||||
/// Runs the proof generation benchmark for the Tiered SMT.
|
||||
pub fn proof_generation(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleError> {
|
||||
println!("Running a proof generation benchmark:");
|
||||
|
||||
let mut insertion_times = Vec::new();
|
||||
@@ -93,13 +111,13 @@ pub fn proof_generation(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
|
||||
tree.insert(test_key, test_value);
|
||||
|
||||
let now = Instant::now();
|
||||
let _proof = tree.open(&test_key);
|
||||
let _proof = tree.prove(test_key);
|
||||
let elapsed = now.elapsed();
|
||||
insertion_times.push(elapsed.as_secs_f32());
|
||||
}
|
||||
|
||||
println!(
|
||||
"An average proving time measured by 20 value proofs in a SMT with {} key-value pairs in {:.3} microseconds",
|
||||
"An average proving time measured by 20 value proofs in a TSMT with {} key-value pairs in {:.3} microseconds",
|
||||
size,
|
||||
// calculate the average by dividing by 20 and convert to microseconds by multiplying by
|
||||
// 1000000. As a result, we can only multiply by 50000
|
||||
|
||||
156
src/merkle/delta.rs
Normal file
156
src/merkle/delta.rs
Normal file
@@ -0,0 +1,156 @@
|
||||
use super::{
|
||||
BTreeMap, KvMap, MerkleError, MerkleStore, NodeIndex, RpoDigest, StoreNode, Vec, Word,
|
||||
};
|
||||
use crate::utils::collections::Diff;
|
||||
|
||||
#[cfg(test)]
|
||||
use super::{super::ONE, Felt, SimpleSmt, EMPTY_WORD, ZERO};
|
||||
|
||||
// MERKLE STORE DELTA
|
||||
// ================================================================================================
|
||||
|
||||
/// [MerkleStoreDelta] stores a vector of ([RpoDigest], [MerkleTreeDelta]) tuples where the
|
||||
/// [RpoDigest] represents the root of the Merkle tree and [MerkleTreeDelta] represents the
|
||||
/// differences between the initial and final Merkle tree states.
|
||||
#[derive(Default, Debug, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct MerkleStoreDelta(pub Vec<(RpoDigest, MerkleTreeDelta)>);
|
||||
|
||||
// MERKLE TREE DELTA
|
||||
// ================================================================================================
|
||||
|
||||
/// [MerkleDelta] stores the differences between the initial and final Merkle tree states.
|
||||
///
|
||||
/// The differences are represented as follows:
|
||||
/// - depth: the depth of the merkle tree.
|
||||
/// - cleared_slots: indexes of slots where values were set to [ZERO; 4].
|
||||
/// - updated_slots: index-value pairs of slots where values were set to non [ZERO; 4] values.
|
||||
#[cfg(not(test))]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct MerkleTreeDelta {
|
||||
depth: u8,
|
||||
cleared_slots: Vec<u64>,
|
||||
updated_slots: Vec<(u64, Word)>,
|
||||
}
|
||||
|
||||
impl MerkleTreeDelta {
|
||||
// CONSTRUCTOR
|
||||
// --------------------------------------------------------------------------------------------
|
||||
pub fn new(depth: u8) -> Self {
|
||||
Self {
|
||||
depth,
|
||||
cleared_slots: Vec::new(),
|
||||
updated_slots: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
// ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
/// Returns the depth of the Merkle tree the [MerkleDelta] is associated with.
|
||||
pub fn depth(&self) -> u8 {
|
||||
self.depth
|
||||
}
|
||||
|
||||
/// Returns the indexes of slots where values were set to [ZERO; 4].
|
||||
pub fn cleared_slots(&self) -> &[u64] {
|
||||
&self.cleared_slots
|
||||
}
|
||||
|
||||
/// Returns the index-value pairs of slots where values were set to non [ZERO; 4] values.
|
||||
pub fn updated_slots(&self) -> &[(u64, Word)] {
|
||||
&self.updated_slots
|
||||
}
|
||||
|
||||
// MODIFIERS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
/// Adds a slot index to the list of cleared slots.
|
||||
pub fn add_cleared_slot(&mut self, index: u64) {
|
||||
self.cleared_slots.push(index);
|
||||
}
|
||||
|
||||
/// Adds a slot index and a value to the list of updated slots.
|
||||
pub fn add_updated_slot(&mut self, index: u64, value: Word) {
|
||||
self.updated_slots.push((index, value));
|
||||
}
|
||||
}
|
||||
|
||||
/// Extracts a [MerkleDelta] object by comparing the leaves of two Merkle trees specifies by
|
||||
/// their roots and depth.
|
||||
pub fn merkle_tree_delta<T: KvMap<RpoDigest, StoreNode>>(
|
||||
tree_root_1: RpoDigest,
|
||||
tree_root_2: RpoDigest,
|
||||
depth: u8,
|
||||
merkle_store: &MerkleStore<T>,
|
||||
) -> Result<MerkleTreeDelta, MerkleError> {
|
||||
if tree_root_1 == tree_root_2 {
|
||||
return Ok(MerkleTreeDelta::new(depth));
|
||||
}
|
||||
|
||||
let tree_1_leaves: BTreeMap<NodeIndex, RpoDigest> =
|
||||
merkle_store.non_empty_leaves(tree_root_1, depth).collect();
|
||||
let tree_2_leaves: BTreeMap<NodeIndex, RpoDigest> =
|
||||
merkle_store.non_empty_leaves(tree_root_2, depth).collect();
|
||||
let diff = tree_1_leaves.diff(&tree_2_leaves);
|
||||
|
||||
// TODO: Refactor this diff implementation to prevent allocation of both BTree and Vec.
|
||||
Ok(MerkleTreeDelta {
|
||||
depth,
|
||||
cleared_slots: diff.removed.into_iter().map(|index| index.value()).collect(),
|
||||
updated_slots: diff
|
||||
.updated
|
||||
.into_iter()
|
||||
.map(|(index, leaf)| (index.value(), *leaf))
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
// INTERNALS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
#[cfg(test)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct MerkleTreeDelta {
|
||||
pub depth: u8,
|
||||
pub cleared_slots: Vec<u64>,
|
||||
pub updated_slots: Vec<(u64, Word)>,
|
||||
}
|
||||
|
||||
// MERKLE DELTA
|
||||
// ================================================================================================
|
||||
#[test]
|
||||
fn test_compute_merkle_delta() {
|
||||
let entries = vec![
|
||||
(10, [ZERO, ONE, Felt::new(2), Felt::new(3)]),
|
||||
(15, [Felt::new(4), Felt::new(5), Felt::new(6), Felt::new(7)]),
|
||||
(20, [Felt::new(8), Felt::new(9), Felt::new(10), Felt::new(11)]),
|
||||
(31, [Felt::new(12), Felt::new(13), Felt::new(14), Felt::new(15)]),
|
||||
];
|
||||
let simple_smt = SimpleSmt::with_leaves(30, entries.clone()).unwrap();
|
||||
let mut store: MerkleStore = (&simple_smt).into();
|
||||
let root = simple_smt.root();
|
||||
|
||||
// add a new node
|
||||
let new_value = [Felt::new(16), Felt::new(17), Felt::new(18), Felt::new(19)];
|
||||
let new_index = NodeIndex::new(simple_smt.depth(), 32).unwrap();
|
||||
let root = store.set_node(root, new_index, new_value.into()).unwrap().root;
|
||||
|
||||
// update an existing node
|
||||
let update_value = [Felt::new(20), Felt::new(21), Felt::new(22), Felt::new(23)];
|
||||
let update_idx = NodeIndex::new(simple_smt.depth(), entries[0].0).unwrap();
|
||||
let root = store.set_node(root, update_idx, update_value.into()).unwrap().root;
|
||||
|
||||
// remove a node
|
||||
let remove_idx = NodeIndex::new(simple_smt.depth(), entries[1].0).unwrap();
|
||||
let root = store.set_node(root, remove_idx, EMPTY_WORD.into()).unwrap().root;
|
||||
|
||||
let merkle_delta =
|
||||
merkle_tree_delta(simple_smt.root(), root, simple_smt.depth(), &store).unwrap();
|
||||
let expected_merkle_delta = MerkleTreeDelta {
|
||||
depth: simple_smt.depth(),
|
||||
cleared_slots: vec![remove_idx.value()],
|
||||
updated_slots: vec![(update_idx.value(), update_value), (new_index.value(), new_value)],
|
||||
};
|
||||
|
||||
assert_eq!(merkle_delta, expected_merkle_delta);
|
||||
}
|
||||
@@ -4,8 +4,6 @@ use crate::{
|
||||
};
|
||||
use core::fmt;
|
||||
|
||||
use super::smt::SmtLeafError;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum MerkleError {
|
||||
ConflictingRoots(Vec<RpoDigest>),
|
||||
@@ -22,7 +20,6 @@ pub enum MerkleError {
|
||||
NodeNotInStore(RpoDigest, NodeIndex),
|
||||
NumLeavesNotPowerOfTwo(usize),
|
||||
RootNotInStore(RpoDigest),
|
||||
SmtLeaf(SmtLeafError),
|
||||
}
|
||||
|
||||
impl fmt::Display for MerkleError {
|
||||
@@ -53,16 +50,9 @@ impl fmt::Display for MerkleError {
|
||||
write!(f, "the leaves count {leaves} is not a power of 2")
|
||||
}
|
||||
RootNotInStore(root) => write!(f, "the root {:?} is not in the store", root),
|
||||
SmtLeaf(smt_leaf_error) => write!(f, "smt leaf error: {smt_leaf_error}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl std::error::Error for MerkleError {}
|
||||
|
||||
impl From<SmtLeafError> for MerkleError {
|
||||
fn from(value: SmtLeafError) -> Self {
|
||||
Self::SmtLeaf(value)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::{Felt, MerkleError, RpoDigest};
|
||||
use super::{Felt, MerkleError, RpoDigest, StarkField};
|
||||
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
|
||||
use core::fmt::Display;
|
||||
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
use super::super::{RpoDigest, Vec};
|
||||
|
||||
/// Container for the update data of a [super::PartialMmr]
|
||||
/// Container for the update data of a [PartialMmr]
|
||||
#[derive(Debug)]
|
||||
pub struct MmrDelta {
|
||||
/// The new version of the [super::Mmr]
|
||||
/// The new version of the [Mmr]
|
||||
pub forest: usize,
|
||||
|
||||
/// Update data.
|
||||
///
|
||||
/// The data is packed as follows:
|
||||
/// 1. All the elements needed to perform authentication path updates. These are the right
|
||||
/// siblings required to perform tree merges on the [super::PartialMmr].
|
||||
/// siblings required to perform tree merges on the [PartialMmr].
|
||||
/// 2. The new peaks.
|
||||
pub data: Vec<RpoDigest>,
|
||||
}
|
||||
|
||||
@@ -280,7 +280,7 @@ impl Mmr {
|
||||
// Update the depth of the tree to correspond to a subtree
|
||||
forest_target >>= 1;
|
||||
|
||||
// compute the indices of the right and left subtrees based on the post-order
|
||||
// compute the indeces of the right and left subtrees based on the post-order
|
||||
let right_offset = index - 1;
|
||||
let left_offset = right_offset - nodes_in_forest(forest_target);
|
||||
|
||||
|
||||
@@ -10,11 +10,6 @@ use crate::{
|
||||
},
|
||||
};
|
||||
|
||||
// TYPE ALIASES
|
||||
// ================================================================================================
|
||||
|
||||
type NodeMap = BTreeMap<InOrderIndex, RpoDigest>;
|
||||
|
||||
// PARTIAL MERKLE MOUNTAIN RANGE
|
||||
// ================================================================================================
|
||||
/// Partially materialized Merkle Mountain Range (MMR), used to efficiently store and update the
|
||||
@@ -51,16 +46,16 @@ pub struct PartialMmr {
|
||||
|
||||
/// Authentication nodes used to construct merkle paths for a subset of the MMR's leaves.
|
||||
///
|
||||
/// This does not include the MMR's peaks nor the tracked nodes, only the elements required to
|
||||
/// construct their authentication paths. This property is used to detect when elements can be
|
||||
/// safely removed, because they are no longer required to authenticate any element in the
|
||||
/// [PartialMmr].
|
||||
/// This does not include the MMR's peaks nor the tracked nodes, only the elements required
|
||||
/// to construct their authentication paths. This property is used to detect when elements can
|
||||
/// be safely removed from, because they are no longer required to authenticate any element in
|
||||
/// the [PartialMmr].
|
||||
///
|
||||
/// The elements in the MMR are referenced using a in-order tree index. This indexing scheme
|
||||
/// permits for easy computation of the relative nodes (left/right children, sibling, parent),
|
||||
/// which is useful for traversal. The indexing is also stable, meaning that merges to the
|
||||
/// trees in the MMR can be represented without rewrites of the indexes.
|
||||
pub(crate) nodes: NodeMap,
|
||||
pub(crate) nodes: BTreeMap<InOrderIndex, RpoDigest>,
|
||||
|
||||
/// Flag indicating if the odd element should be tracked.
|
||||
///
|
||||
@@ -73,27 +68,16 @@ impl PartialMmr {
|
||||
// CONSTRUCTORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a new [PartialMmr] instantiated from the specified peaks.
|
||||
/// Constructs a [PartialMmr] from the given [MmrPeaks].
|
||||
pub fn from_peaks(peaks: MmrPeaks) -> Self {
|
||||
let forest = peaks.num_leaves();
|
||||
let peaks = peaks.into();
|
||||
let peaks = peaks.peaks().to_vec();
|
||||
let nodes = BTreeMap::new();
|
||||
let track_latest = false;
|
||||
|
||||
Self { forest, peaks, nodes, track_latest }
|
||||
}
|
||||
|
||||
/// Returns a new [PartialMmr] instantiated from the specified components.
|
||||
///
|
||||
/// This constructor does not check the consistency between peaks and nodes. If the specified
|
||||
/// peaks are nodes are inconsistent, the returned partial MMR may exhibit undefined behavior.
|
||||
pub fn from_parts(peaks: MmrPeaks, nodes: NodeMap, track_latest: bool) -> Self {
|
||||
let forest = peaks.num_leaves();
|
||||
let peaks = peaks.into();
|
||||
|
||||
Self { forest, peaks, nodes, track_latest }
|
||||
}
|
||||
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
@@ -117,31 +101,14 @@ impl PartialMmr {
|
||||
MmrPeaks::new(self.forest, self.peaks.clone()).expect("invalid MMR peaks")
|
||||
}
|
||||
|
||||
/// Returns true if this partial MMR tracks an authentication path for the leaf at the
|
||||
/// specified position.
|
||||
pub fn is_tracked(&self, pos: usize) -> bool {
|
||||
if pos >= self.forest {
|
||||
return false;
|
||||
} else if pos == self.forest - 1 && self.forest & 1 != 0 {
|
||||
// if the number of leaves in the MMR is odd and the position is for the last leaf
|
||||
// whether the leaf is tracked is defined by the `track_latest` flag
|
||||
return self.track_latest;
|
||||
}
|
||||
|
||||
let leaf_index = InOrderIndex::from_leaf_pos(pos);
|
||||
self.is_tracked_node(&leaf_index)
|
||||
}
|
||||
|
||||
/// Given a leaf position, returns the Merkle path to its corresponding peak, or None if this
|
||||
/// partial MMR does not track an authentication paths for the specified leaf.
|
||||
/// Given a leaf position, returns the Merkle path to its corresponding peak.
|
||||
///
|
||||
/// If the position is greater-or-equal than the tree size an error is returned. If the
|
||||
/// requested value is not tracked returns `None`.
|
||||
///
|
||||
/// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were
|
||||
/// added, this corresponds to the MMR size _prior_ to adding the element. So the 1st element
|
||||
/// has position 0, the second position 1, and so on.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the specified position is greater-or-equal than the number of leaves
|
||||
/// in the underlying MMR.
|
||||
pub fn open(&self, pos: usize) -> Result<Option<MmrProof>, MmrError> {
|
||||
let tree_bit =
|
||||
leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::InvalidPosition(pos))?;
|
||||
@@ -182,13 +149,13 @@ impl PartialMmr {
|
||||
///
|
||||
/// The order of iteration is not defined. If a leaf is not presented in this partial MMR it
|
||||
/// is silently ignored.
|
||||
pub fn inner_nodes<'a, I: Iterator<Item = (usize, RpoDigest)> + 'a>(
|
||||
pub fn inner_nodes<'a, I: Iterator<Item = &'a (usize, RpoDigest)> + 'a>(
|
||||
&'a self,
|
||||
mut leaves: I,
|
||||
) -> impl Iterator<Item = InnerNodeInfo> + '_ {
|
||||
let stack = if let Some((pos, leaf)) = leaves.next() {
|
||||
let idx = InOrderIndex::from_leaf_pos(pos);
|
||||
vec![(idx, leaf)]
|
||||
let idx = InOrderIndex::from_leaf_pos(*pos);
|
||||
vec![(idx, *leaf)]
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
@@ -204,93 +171,20 @@ impl PartialMmr {
|
||||
// STATE MUTATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Adds a new peak and optionally track it. Returns a vector of the authentication nodes
|
||||
/// inserted into this [PartialMmr] as a result of this operation.
|
||||
/// Add the authentication path represented by [MerklePath] if it is valid.
|
||||
///
|
||||
/// When `track` is `true` the new leaf is tracked.
|
||||
pub fn add(&mut self, leaf: RpoDigest, track: bool) -> Vec<(InOrderIndex, RpoDigest)> {
|
||||
self.forest += 1;
|
||||
let merges = self.forest.trailing_zeros() as usize;
|
||||
let mut new_nodes = Vec::with_capacity(merges);
|
||||
|
||||
let peak = if merges == 0 {
|
||||
self.track_latest = track;
|
||||
leaf
|
||||
} else {
|
||||
let mut track_right = track;
|
||||
let mut track_left = self.track_latest;
|
||||
|
||||
let mut right = leaf;
|
||||
let mut right_idx = forest_to_rightmost_index(self.forest);
|
||||
|
||||
for _ in 0..merges {
|
||||
let left = self.peaks.pop().expect("Missing peak");
|
||||
let left_idx = right_idx.sibling();
|
||||
|
||||
if track_right {
|
||||
let old = self.nodes.insert(left_idx, left);
|
||||
new_nodes.push((left_idx, left));
|
||||
|
||||
debug_assert!(
|
||||
old.is_none(),
|
||||
"Idx {:?} already contained an element {:?}",
|
||||
left_idx,
|
||||
old
|
||||
);
|
||||
};
|
||||
if track_left {
|
||||
let old = self.nodes.insert(right_idx, right);
|
||||
new_nodes.push((right_idx, right));
|
||||
|
||||
debug_assert!(
|
||||
old.is_none(),
|
||||
"Idx {:?} already contained an element {:?}",
|
||||
right_idx,
|
||||
old
|
||||
);
|
||||
};
|
||||
|
||||
// Update state for the next iteration.
|
||||
// --------------------------------------------------------------------------------
|
||||
|
||||
// This layer is merged, go up one layer.
|
||||
right_idx = right_idx.parent();
|
||||
|
||||
// Merge the current layer. The result is either the right element of the next
|
||||
// merge, or a new peak.
|
||||
right = Rpo256::merge(&[left, right]);
|
||||
|
||||
// This iteration merged the left and right nodes, the new value is always used as
|
||||
// the next iteration's right node. Therefore the tracking flags of this iteration
|
||||
// have to be merged into the right side only.
|
||||
track_right = track_right || track_left;
|
||||
|
||||
// On the next iteration, a peak will be merged. If any of its children are tracked,
|
||||
// then we have to track the left side
|
||||
track_left = self.is_tracked_node(&right_idx.sibling());
|
||||
}
|
||||
right
|
||||
};
|
||||
|
||||
self.peaks.push(peak);
|
||||
|
||||
new_nodes
|
||||
}
|
||||
|
||||
/// Adds the authentication path represented by [MerklePath] if it is valid.
|
||||
///
|
||||
/// The `leaf_pos` refers to the global position of the leaf in the MMR, these are 0-indexed
|
||||
/// The `index` refers to the global position of the leaf in the MMR, these are 0-indexed
|
||||
/// values assigned in a strictly monotonic fashion as elements are inserted into the MMR,
|
||||
/// this value corresponds to the values used in the MMR structure.
|
||||
///
|
||||
/// The `leaf` corresponds to the value at `leaf_pos`, and `path` is the authentication path for
|
||||
/// that element up to its corresponding Mmr peak. The `leaf` is only used to compute the root
|
||||
/// The `node` corresponds to the value at `index`, and `path` is the authentication path for
|
||||
/// that element up to its corresponding Mmr peak. The `node` is only used to compute the root
|
||||
/// from the authentication path to valid the data, only the authentication data is saved in
|
||||
/// the structure. If the value is required it should be stored out-of-band.
|
||||
pub fn track(
|
||||
pub fn add(
|
||||
&mut self,
|
||||
leaf_pos: usize,
|
||||
leaf: RpoDigest,
|
||||
index: usize,
|
||||
node: RpoDigest,
|
||||
path: &MerklePath,
|
||||
) -> Result<(), MmrError> {
|
||||
// Checks there is a tree with same depth as the authentication path, if not the path is
|
||||
@@ -300,42 +194,42 @@ impl PartialMmr {
|
||||
return Err(MmrError::UnknownPeak);
|
||||
};
|
||||
|
||||
if leaf_pos + 1 == self.forest
|
||||
if index + 1 == self.forest
|
||||
&& path.depth() == 0
|
||||
&& self.peaks.last().map_or(false, |v| *v == leaf)
|
||||
&& self.peaks.last().map_or(false, |v| *v == node)
|
||||
{
|
||||
self.track_latest = true;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// ignore the trees smaller than the target (these elements are position after the current
|
||||
// target and don't affect the target leaf_pos)
|
||||
// target and don't affect the target index)
|
||||
let target_forest = self.forest ^ (self.forest & (tree - 1));
|
||||
let peak_pos = (target_forest.count_ones() - 1) as usize;
|
||||
|
||||
// translate from mmr leaf_pos to merkle path
|
||||
let path_idx = leaf_pos - (target_forest ^ tree);
|
||||
// translate from mmr index to merkle path
|
||||
let path_idx = index - (target_forest ^ tree);
|
||||
|
||||
// Compute the root of the authentication path, and check it matches the current version of
|
||||
// the PartialMmr.
|
||||
let computed = path.compute_root(path_idx as u64, leaf).map_err(MmrError::MerkleError)?;
|
||||
let computed = path.compute_root(path_idx as u64, node).map_err(MmrError::MerkleError)?;
|
||||
if self.peaks[peak_pos] != computed {
|
||||
return Err(MmrError::InvalidPeak);
|
||||
}
|
||||
|
||||
let mut idx = InOrderIndex::from_leaf_pos(leaf_pos);
|
||||
for leaf in path.nodes() {
|
||||
self.nodes.insert(idx.sibling(), *leaf);
|
||||
let mut idx = InOrderIndex::from_leaf_pos(index);
|
||||
for node in path.nodes() {
|
||||
self.nodes.insert(idx.sibling(), *node);
|
||||
idx = idx.parent();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Removes a leaf of the [PartialMmr] and the unused nodes from the authentication path.
|
||||
/// Remove a leaf of the [PartialMmr] and the unused nodes from the authentication path.
|
||||
///
|
||||
/// Note: `leaf_pos` corresponds to the position in the MMR and not on an individual tree.
|
||||
pub fn untrack(&mut self, leaf_pos: usize) {
|
||||
pub fn remove(&mut self, leaf_pos: usize) {
|
||||
let mut idx = InOrderIndex::from_leaf_pos(leaf_pos);
|
||||
|
||||
self.nodes.remove(&idx.sibling());
|
||||
@@ -531,14 +425,14 @@ impl From<&PartialMmr> for MmrPeaks {
|
||||
// ================================================================================================
|
||||
|
||||
/// An iterator over every inner node of the [PartialMmr].
|
||||
pub struct InnerNodeIterator<'a, I: Iterator<Item = (usize, RpoDigest)>> {
|
||||
nodes: &'a NodeMap,
|
||||
pub struct InnerNodeIterator<'a, I: Iterator<Item = &'a (usize, RpoDigest)>> {
|
||||
nodes: &'a BTreeMap<InOrderIndex, RpoDigest>,
|
||||
leaves: I,
|
||||
stack: Vec<(InOrderIndex, RpoDigest)>,
|
||||
seen_nodes: BTreeSet<InOrderIndex>,
|
||||
}
|
||||
|
||||
impl<'a, I: Iterator<Item = (usize, RpoDigest)>> Iterator for InnerNodeIterator<'a, I> {
|
||||
impl<'a, I: Iterator<Item = &'a (usize, RpoDigest)>> Iterator for InnerNodeIterator<'a, I> {
|
||||
type Item = InnerNodeInfo;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
@@ -565,8 +459,8 @@ impl<'a, I: Iterator<Item = (usize, RpoDigest)>> Iterator for InnerNodeIterator<
|
||||
|
||||
// the previous leaf has been processed, try to process the next leaf
|
||||
if let Some((pos, leaf)) = self.leaves.next() {
|
||||
let idx = InOrderIndex::from_leaf_pos(pos);
|
||||
self.stack.push((idx, leaf));
|
||||
let idx = InOrderIndex::from_leaf_pos(*pos);
|
||||
self.stack.push((idx, *leaf));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -596,29 +490,12 @@ fn forest_to_root_index(forest: usize) -> InOrderIndex {
|
||||
InOrderIndex::new(idx.try_into().unwrap())
|
||||
}
|
||||
|
||||
/// Given the description of a `forest`, returns the index of the right most element.
|
||||
fn forest_to_rightmost_index(forest: usize) -> InOrderIndex {
|
||||
// Count total size of all trees in the forest.
|
||||
let nodes = nodes_in_forest(forest);
|
||||
|
||||
// Add the count for the parent nodes that separate each tree. These are allocated but
|
||||
// currently empty, and correspond to the nodes that will be used once the trees are merged.
|
||||
let open_trees = (forest.count_ones() - 1) as usize;
|
||||
|
||||
let idx = nodes + open_trees;
|
||||
|
||||
InOrderIndex::new(idx.try_into().unwrap())
|
||||
}
|
||||
|
||||
// TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
forest_to_rightmost_index, forest_to_root_index, BTreeSet, InOrderIndex, MmrPeaks,
|
||||
PartialMmr, RpoDigest, Vec,
|
||||
};
|
||||
use super::{forest_to_root_index, BTreeSet, InOrderIndex, PartialMmr, RpoDigest, Vec};
|
||||
use crate::merkle::{int_to_node, MerkleStore, Mmr, NodeIndex};
|
||||
|
||||
const LEAVES: [RpoDigest; 7] = [
|
||||
@@ -657,33 +534,6 @@ mod tests {
|
||||
assert_eq!(forest_to_root_index(0b1110), idx(26));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_forest_to_rightmost_index() {
|
||||
fn idx(pos: usize) -> InOrderIndex {
|
||||
InOrderIndex::new(pos.try_into().unwrap())
|
||||
}
|
||||
|
||||
for forest in 1..256 {
|
||||
assert!(forest_to_rightmost_index(forest).inner() % 2 == 1, "Leaves are always odd");
|
||||
}
|
||||
|
||||
assert_eq!(forest_to_rightmost_index(0b0001), idx(1));
|
||||
assert_eq!(forest_to_rightmost_index(0b0010), idx(3));
|
||||
assert_eq!(forest_to_rightmost_index(0b0011), idx(5));
|
||||
assert_eq!(forest_to_rightmost_index(0b0100), idx(7));
|
||||
assert_eq!(forest_to_rightmost_index(0b0101), idx(9));
|
||||
assert_eq!(forest_to_rightmost_index(0b0110), idx(11));
|
||||
assert_eq!(forest_to_rightmost_index(0b0111), idx(13));
|
||||
assert_eq!(forest_to_rightmost_index(0b1000), idx(15));
|
||||
assert_eq!(forest_to_rightmost_index(0b1001), idx(17));
|
||||
assert_eq!(forest_to_rightmost_index(0b1010), idx(19));
|
||||
assert_eq!(forest_to_rightmost_index(0b1011), idx(21));
|
||||
assert_eq!(forest_to_rightmost_index(0b1100), idx(23));
|
||||
assert_eq!(forest_to_rightmost_index(0b1101), idx(25));
|
||||
assert_eq!(forest_to_rightmost_index(0b1110), idx(27));
|
||||
assert_eq!(forest_to_rightmost_index(0b1111), idx(29));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partial_mmr_apply_delta() {
|
||||
// build an MMR with 10 nodes (2 peaks) and a partial MMR based on it
|
||||
@@ -695,13 +545,13 @@ mod tests {
|
||||
{
|
||||
let node = mmr.get(1).unwrap();
|
||||
let proof = mmr.open(1, mmr.forest()).unwrap();
|
||||
partial_mmr.track(1, node, &proof.merkle_path).unwrap();
|
||||
partial_mmr.add(1, node, &proof.merkle_path).unwrap();
|
||||
}
|
||||
|
||||
{
|
||||
let node = mmr.get(8).unwrap();
|
||||
let proof = mmr.open(8, mmr.forest()).unwrap();
|
||||
partial_mmr.track(8, node, &proof.merkle_path).unwrap();
|
||||
partial_mmr.add(8, node, &proof.merkle_path).unwrap();
|
||||
}
|
||||
|
||||
// add 2 more nodes into the MMR and validate apply_delta()
|
||||
@@ -714,7 +564,7 @@ mod tests {
|
||||
{
|
||||
let node = mmr.get(12).unwrap();
|
||||
let proof = mmr.open(12, mmr.forest()).unwrap();
|
||||
partial_mmr.track(12, node, &proof.merkle_path).unwrap();
|
||||
partial_mmr.add(12, node, &proof.merkle_path).unwrap();
|
||||
assert!(partial_mmr.track_latest);
|
||||
}
|
||||
|
||||
@@ -773,14 +623,14 @@ mod tests {
|
||||
|
||||
// create partial MMR and add authentication path to node at position 1
|
||||
let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();
|
||||
partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
|
||||
partial_mmr.add(1, node1, &proof1.merkle_path).unwrap();
|
||||
|
||||
// empty iterator should have no nodes
|
||||
assert_eq!(partial_mmr.inner_nodes([].iter().cloned()).next(), None);
|
||||
assert_eq!(partial_mmr.inner_nodes([].iter()).next(), None);
|
||||
|
||||
// build Merkle store from authentication paths in partial MMR
|
||||
let mut store: MerkleStore = MerkleStore::new();
|
||||
store.extend(partial_mmr.inner_nodes([(1, node1)].iter().cloned()));
|
||||
store.extend(partial_mmr.inner_nodes([(1, node1)].iter()));
|
||||
|
||||
let index1 = NodeIndex::new(2, 1).unwrap();
|
||||
let path1 = store.get_path(first_peak, index1).unwrap().path;
|
||||
@@ -798,19 +648,19 @@ mod tests {
|
||||
let node2 = mmr.get(2).unwrap();
|
||||
let proof2 = mmr.open(2, mmr.forest()).unwrap();
|
||||
|
||||
partial_mmr.track(0, node0, &proof0.merkle_path).unwrap();
|
||||
partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
|
||||
partial_mmr.track(2, node2, &proof2.merkle_path).unwrap();
|
||||
partial_mmr.add(0, node0, &proof0.merkle_path).unwrap();
|
||||
partial_mmr.add(1, node1, &proof1.merkle_path).unwrap();
|
||||
partial_mmr.add(2, node2, &proof2.merkle_path).unwrap();
|
||||
|
||||
// make sure there are no duplicates
|
||||
let leaves = [(0, node0), (1, node1), (2, node2)];
|
||||
let mut nodes = BTreeSet::new();
|
||||
for node in partial_mmr.inner_nodes(leaves.iter().cloned()) {
|
||||
for node in partial_mmr.inner_nodes(leaves.iter()) {
|
||||
assert!(nodes.insert(node.value));
|
||||
}
|
||||
|
||||
// and also that the store is still be built correctly
|
||||
store.extend(partial_mmr.inner_nodes(leaves.iter().cloned()));
|
||||
store.extend(partial_mmr.inner_nodes(leaves.iter()));
|
||||
|
||||
let index0 = NodeIndex::new(2, 0).unwrap();
|
||||
let index1 = NodeIndex::new(2, 1).unwrap();
|
||||
@@ -832,12 +682,12 @@ mod tests {
|
||||
let node5 = mmr.get(5).unwrap();
|
||||
let proof5 = mmr.open(5, mmr.forest()).unwrap();
|
||||
|
||||
partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
|
||||
partial_mmr.track(5, node5, &proof5.merkle_path).unwrap();
|
||||
partial_mmr.add(1, node1, &proof1.merkle_path).unwrap();
|
||||
partial_mmr.add(5, node5, &proof5.merkle_path).unwrap();
|
||||
|
||||
// build Merkle store from authentication paths in partial MMR
|
||||
let mut store: MerkleStore = MerkleStore::new();
|
||||
store.extend(partial_mmr.inner_nodes([(1, node1), (5, node5)].iter().cloned()));
|
||||
store.extend(partial_mmr.inner_nodes([(1, node1), (5, node5)].iter()));
|
||||
|
||||
let index1 = NodeIndex::new(2, 1).unwrap();
|
||||
let index5 = NodeIndex::new(1, 1).unwrap();
|
||||
@@ -850,62 +700,4 @@ mod tests {
|
||||
assert_eq!(path1, proof1.merkle_path);
|
||||
assert_eq!(path5, proof5.merkle_path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partial_mmr_add_without_track() {
|
||||
let mut mmr = Mmr::default();
|
||||
let empty_peaks = MmrPeaks::new(0, vec![]).unwrap();
|
||||
let mut partial_mmr = PartialMmr::from_peaks(empty_peaks);
|
||||
|
||||
for el in (0..256).map(int_to_node) {
|
||||
mmr.add(el);
|
||||
partial_mmr.add(el, false);
|
||||
|
||||
let mmr_peaks = mmr.peaks(mmr.forest()).unwrap();
|
||||
assert_eq!(mmr_peaks, partial_mmr.peaks());
|
||||
assert_eq!(mmr.forest(), partial_mmr.forest());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partial_mmr_add_with_track() {
|
||||
let mut mmr = Mmr::default();
|
||||
let empty_peaks = MmrPeaks::new(0, vec![]).unwrap();
|
||||
let mut partial_mmr = PartialMmr::from_peaks(empty_peaks);
|
||||
|
||||
for i in 0..256 {
|
||||
let el = int_to_node(i);
|
||||
mmr.add(el);
|
||||
partial_mmr.add(el, true);
|
||||
|
||||
let mmr_peaks = mmr.peaks(mmr.forest()).unwrap();
|
||||
assert_eq!(mmr_peaks, partial_mmr.peaks());
|
||||
assert_eq!(mmr.forest(), partial_mmr.forest());
|
||||
|
||||
for pos in 0..i {
|
||||
let mmr_proof = mmr.open(pos as usize, mmr.forest()).unwrap();
|
||||
let partialmmr_proof = partial_mmr.open(pos as usize).unwrap().unwrap();
|
||||
assert_eq!(mmr_proof, partialmmr_proof);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partial_mmr_add_existing_track() {
|
||||
let mut mmr = Mmr::from((0..7).map(int_to_node));
|
||||
|
||||
// derive a partial Mmr from it which tracks authentication path to leaf 5
|
||||
let mut partial_mmr = PartialMmr::from_peaks(mmr.peaks(mmr.forest()).unwrap());
|
||||
let path_to_5 = mmr.open(5, mmr.forest()).unwrap().merkle_path;
|
||||
let leaf_at_5 = mmr.get(5).unwrap();
|
||||
partial_mmr.track(5, leaf_at_5, &path_to_5).unwrap();
|
||||
|
||||
// add a new leaf to both Mmr and partial Mmr
|
||||
let leaf_at_7 = int_to_node(7);
|
||||
mmr.add(leaf_at_7);
|
||||
partial_mmr.add(leaf_at_7, false);
|
||||
|
||||
// the openings should be the same
|
||||
assert_eq!(mmr.open(5, mmr.forest()).unwrap(), partial_mmr.open(5).unwrap().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,9 +132,3 @@ impl MmrPeaks {
|
||||
elements
|
||||
}
|
||||
}
|
||||
|
||||
impl From<MmrPeaks> for Vec<RpoDigest> {
|
||||
fn from(peaks: MmrPeaks) -> Self {
|
||||
peaks.peaks
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,9 +2,6 @@
|
||||
use super::super::MerklePath;
|
||||
use super::{full::high_bitmask, leaf_to_corresponding_tree};
|
||||
|
||||
// MMR PROOF
|
||||
// ================================================================================================
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct MmrProof {
|
||||
@@ -29,78 +26,9 @@ impl MmrProof {
|
||||
self.position - forest_before
|
||||
}
|
||||
|
||||
/// Returns index of the MMR peak against which the Merkle path in this proof can be verified.
|
||||
pub fn peak_index(&self) -> usize {
|
||||
let root = leaf_to_corresponding_tree(self.position, self.forest)
|
||||
.expect("position must be part of the forest");
|
||||
let smaller_peak_mask = 2_usize.pow(root) as usize - 1;
|
||||
let num_smaller_peaks = (self.forest & smaller_peak_mask).count_ones();
|
||||
(self.forest.count_ones() - num_smaller_peaks - 1) as usize
|
||||
}
|
||||
}
|
||||
|
||||
// TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{MerklePath, MmrProof};
|
||||
|
||||
#[test]
|
||||
fn test_peak_index() {
|
||||
// --- single peak forest ---------------------------------------------
|
||||
let forest = 11;
|
||||
|
||||
// the first 4 leaves belong to peak 0
|
||||
for position in 0..8 {
|
||||
let proof = make_dummy_proof(forest, position);
|
||||
assert_eq!(proof.peak_index(), 0);
|
||||
}
|
||||
|
||||
// --- forest with non-consecutive peaks ------------------------------
|
||||
let forest = 11;
|
||||
|
||||
// the first 8 leaves belong to peak 0
|
||||
for position in 0..8 {
|
||||
let proof = make_dummy_proof(forest, position);
|
||||
assert_eq!(proof.peak_index(), 0);
|
||||
}
|
||||
|
||||
// the next 2 leaves belong to peak 1
|
||||
for position in 8..10 {
|
||||
let proof = make_dummy_proof(forest, position);
|
||||
assert_eq!(proof.peak_index(), 1);
|
||||
}
|
||||
|
||||
// the last leaf is the peak 2
|
||||
let proof = make_dummy_proof(forest, 10);
|
||||
assert_eq!(proof.peak_index(), 2);
|
||||
|
||||
// --- forest with consecutive peaks ----------------------------------
|
||||
let forest = 7;
|
||||
|
||||
// the first 4 leaves belong to peak 0
|
||||
for position in 0..4 {
|
||||
let proof = make_dummy_proof(forest, position);
|
||||
assert_eq!(proof.peak_index(), 0);
|
||||
}
|
||||
|
||||
// the next 2 leaves belong to peak 1
|
||||
for position in 4..6 {
|
||||
let proof = make_dummy_proof(forest, position);
|
||||
assert_eq!(proof.peak_index(), 1);
|
||||
}
|
||||
|
||||
// the last leaf is the peak 2
|
||||
let proof = make_dummy_proof(forest, 6);
|
||||
assert_eq!(proof.peak_index(), 2);
|
||||
}
|
||||
|
||||
fn make_dummy_proof(forest: usize, position: usize) -> MmrProof {
|
||||
MmrProof {
|
||||
forest,
|
||||
position,
|
||||
merkle_path: MerklePath::default(),
|
||||
}
|
||||
(self.forest.count_ones() - root - 1) as usize
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,14 +114,13 @@ const LEAVES: [RpoDigest; 7] = [
|
||||
|
||||
#[test]
|
||||
fn test_mmr_simple() {
|
||||
let mut postorder = vec![
|
||||
LEAVES[0],
|
||||
LEAVES[1],
|
||||
merge(LEAVES[0], LEAVES[1]),
|
||||
LEAVES[2],
|
||||
LEAVES[3],
|
||||
merge(LEAVES[2], LEAVES[3]),
|
||||
];
|
||||
let mut postorder = Vec::new();
|
||||
postorder.push(LEAVES[0]);
|
||||
postorder.push(LEAVES[1]);
|
||||
postorder.push(merge(LEAVES[0], LEAVES[1]));
|
||||
postorder.push(LEAVES[2]);
|
||||
postorder.push(LEAVES[3]);
|
||||
postorder.push(merge(LEAVES[2], LEAVES[3]));
|
||||
postorder.push(merge(postorder[2], postorder[5]));
|
||||
postorder.push(LEAVES[4]);
|
||||
postorder.push(LEAVES[5]);
|
||||
@@ -769,7 +768,7 @@ fn test_partial_mmr_simple() {
|
||||
// check state after adding tracking one element
|
||||
let proof1 = mmr.open(0, mmr.forest()).unwrap();
|
||||
let el1 = mmr.get(proof1.position).unwrap();
|
||||
partial.track(proof1.position, el1, &proof1.merkle_path).unwrap();
|
||||
partial.add(proof1.position, el1, &proof1.merkle_path).unwrap();
|
||||
|
||||
// check the number of nodes increased by the number of nodes in the proof
|
||||
assert_eq!(partial.nodes.len(), proof1.merkle_path.len());
|
||||
@@ -781,7 +780,7 @@ fn test_partial_mmr_simple() {
|
||||
|
||||
let proof2 = mmr.open(1, mmr.forest()).unwrap();
|
||||
let el2 = mmr.get(proof2.position).unwrap();
|
||||
partial.track(proof2.position, el2, &proof2.merkle_path).unwrap();
|
||||
partial.add(proof2.position, el2, &proof2.merkle_path).unwrap();
|
||||
|
||||
// check the number of nodes increased by a single element (the one that is not shared)
|
||||
assert_eq!(partial.nodes.len(), 3);
|
||||
@@ -800,7 +799,7 @@ fn test_partial_mmr_update_single() {
|
||||
let mut partial: PartialMmr = full.peaks(full.forest()).unwrap().into();
|
||||
|
||||
let proof = full.open(0, full.forest()).unwrap();
|
||||
partial.track(proof.position, zero, &proof.merkle_path).unwrap();
|
||||
partial.add(proof.position, zero, &proof.merkle_path).unwrap();
|
||||
|
||||
for i in 1..100 {
|
||||
let node = int_to_node(i);
|
||||
@@ -812,7 +811,7 @@ fn test_partial_mmr_update_single() {
|
||||
assert_eq!(partial.peaks(), full.peaks(full.forest()).unwrap());
|
||||
|
||||
let proof1 = full.open(i as usize, full.forest()).unwrap();
|
||||
partial.track(proof1.position, node, &proof1.merkle_path).unwrap();
|
||||
partial.add(proof1.position, node, &proof1.merkle_path).unwrap();
|
||||
let proof2 = partial.open(proof1.position).unwrap().unwrap();
|
||||
assert_eq!(proof1.merkle_path, proof2.merkle_path);
|
||||
}
|
||||
@@ -828,11 +827,11 @@ fn test_mmr_add_invalid_odd_leaf() {
|
||||
|
||||
// None of the other leaves should work
|
||||
for node in LEAVES.iter().cloned().rev().skip(1) {
|
||||
let result = partial.track(LEAVES.len() - 1, node, &empty);
|
||||
let result = partial.add(LEAVES.len() - 1, node, &empty);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
let result = partial.track(LEAVES.len() - 1, LEAVES[6], &empty);
|
||||
let result = partial.add(LEAVES.len() - 1, LEAVES[6], &empty);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
use super::{
|
||||
hash::rpo::{Rpo256, RpoDigest},
|
||||
utils::collections::{vec, BTreeMap, BTreeSet, KvMap, RecordingMap, Vec},
|
||||
Felt, Word, EMPTY_WORD, ZERO,
|
||||
utils::collections::{vec, BTreeMap, BTreeSet, KvMap, RecordingMap, TryApplyDiff, Vec},
|
||||
Felt, StarkField, Word, EMPTY_WORD, ZERO,
|
||||
};
|
||||
|
||||
// REEXPORTS
|
||||
@@ -12,6 +12,9 @@ use super::{
|
||||
mod empty_roots;
|
||||
pub use empty_roots::EmptySubtreeRoots;
|
||||
|
||||
mod delta;
|
||||
pub use delta::{merkle_tree_delta, MerkleStoreDelta, MerkleTreeDelta};
|
||||
|
||||
mod index;
|
||||
pub use index::NodeIndex;
|
||||
|
||||
@@ -21,11 +24,11 @@ pub use merkle_tree::{path_to_text, tree_to_text, MerkleTree};
|
||||
mod path;
|
||||
pub use path::{MerklePath, RootPath, ValuePath};
|
||||
|
||||
mod smt;
|
||||
pub use smt::{
|
||||
LeafIndex, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH,
|
||||
SMT_MAX_DEPTH, SMT_MIN_DEPTH,
|
||||
};
|
||||
mod simple_smt;
|
||||
pub use simple_smt::SimpleSmt;
|
||||
|
||||
mod tiered_smt;
|
||||
pub use tiered_smt::{TieredSmt, TieredSmtProof, TieredSmtProofError};
|
||||
|
||||
mod mmr;
|
||||
pub use mmr::{InOrderIndex, Mmr, MmrDelta, MmrError, MmrPeaks, MmrProof, PartialMmr};
|
||||
|
||||
@@ -179,7 +179,7 @@ impl PartialMerkleTree {
|
||||
/// # Errors
|
||||
/// Returns an error if the specified NodeIndex is not contained in the nodes map.
|
||||
pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
|
||||
self.nodes.get(&index).ok_or(MerkleError::NodeNotInSet(index)).copied()
|
||||
self.nodes.get(&index).ok_or(MerkleError::NodeNotInSet(index)).map(|hash| *hash)
|
||||
}
|
||||
|
||||
/// Returns true if provided index contains in the leaves set, false otherwise.
|
||||
|
||||
@@ -209,7 +209,7 @@ fn get_paths() {
|
||||
// Which have leaf nodes 20, 22, 23, 32 and 33. Hence overall we will have 5 paths -- one path
|
||||
// for each leaf.
|
||||
|
||||
let leaves = [NODE20, NODE22, NODE23, NODE32, NODE33];
|
||||
let leaves = vec![NODE20, NODE22, NODE23, NODE32, NODE33];
|
||||
let expected_paths: Vec<(NodeIndex, ValuePath)> = leaves
|
||||
.iter()
|
||||
.map(|&leaf| {
|
||||
@@ -257,7 +257,7 @@ fn leaves() {
|
||||
let value32 = mt.get_node(NODE32).unwrap();
|
||||
let value33 = mt.get_node(NODE33).unwrap();
|
||||
|
||||
let leaves = [(NODE11, value11), (NODE20, value20), (NODE32, value32), (NODE33, value33)];
|
||||
let leaves = vec![(NODE11, value11), (NODE20, value20), (NODE32, value32), (NODE33, value33)];
|
||||
|
||||
let expected_leaves = leaves.iter().copied();
|
||||
assert!(expected_leaves.eq(pmt.leaves()));
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
use crate::Word;
|
||||
|
||||
use super::{vec, InnerNodeInfo, MerkleError, NodeIndex, Rpo256, RpoDigest, Vec};
|
||||
use core::ops::{Deref, DerefMut};
|
||||
use winter_utils::{ByteReader, Deserializable, DeserializationError, Serializable};
|
||||
@@ -165,7 +163,7 @@ impl<'a> Iterator for InnerNodeIterator<'a> {
|
||||
// MERKLE PATH CONTAINERS
|
||||
// ================================================================================================
|
||||
|
||||
/// A container for a [crate::Word] value and its [MerklePath] opening.
|
||||
/// A container for a [Word] value and its [MerklePath] opening.
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq)]
|
||||
pub struct ValuePath {
|
||||
/// The node value opening for `path`.
|
||||
@@ -176,18 +174,12 @@ pub struct ValuePath {
|
||||
|
||||
impl ValuePath {
|
||||
/// Returns a new [ValuePath] instantiated from the specified value and path.
|
||||
pub fn new(value: RpoDigest, path: MerklePath) -> Self {
|
||||
Self { value, path }
|
||||
pub fn new(value: RpoDigest, path: Vec<RpoDigest>) -> Self {
|
||||
Self { value, path: MerklePath::new(path) }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(MerklePath, Word)> for ValuePath {
|
||||
fn from((path, value): (MerklePath, Word)) -> Self {
|
||||
ValuePath::new(value.into(), path)
|
||||
}
|
||||
}
|
||||
|
||||
/// A container for a [MerklePath] and its [crate::Word] root.
|
||||
/// A container for a [MerklePath] and its [Word] root.
|
||||
///
|
||||
/// This structure does not provide any guarantees regarding the correctness of the path to the
|
||||
/// root. For more information, check [MerklePath::verify].
|
||||
@@ -206,14 +198,14 @@ impl Serializable for MerklePath {
|
||||
fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
|
||||
assert!(self.nodes.len() <= u8::MAX.into(), "Length enforced in the constructor");
|
||||
target.write_u8(self.nodes.len() as u8);
|
||||
target.write_many(&self.nodes);
|
||||
self.nodes.write_into(target);
|
||||
}
|
||||
}
|
||||
|
||||
impl Deserializable for MerklePath {
|
||||
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
|
||||
let count = source.read_u8()?.into();
|
||||
let nodes = source.read_many::<RpoDigest>(count)?;
|
||||
let nodes = RpoDigest::read_batch_from(source, count)?;
|
||||
Ok(Self { nodes })
|
||||
}
|
||||
}
|
||||
|
||||
389
src/merkle/simple_smt/mod.rs
Normal file
389
src/merkle/simple_smt/mod.rs
Normal file
@@ -0,0 +1,389 @@
|
||||
use super::{
|
||||
BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, MerkleTreeDelta,
|
||||
NodeIndex, Rpo256, RpoDigest, StoreNode, TryApplyDiff, Vec, Word,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
// SPARSE MERKLE TREE
|
||||
// ================================================================================================
|
||||
|
||||
/// A sparse Merkle tree with 64-bit keys and 4-element leaf values, without compaction.
|
||||
///
|
||||
/// The root of the tree is recomputed on each new leaf update.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct SimpleSmt {
|
||||
depth: u8,
|
||||
root: RpoDigest,
|
||||
leaves: BTreeMap<u64, Word>,
|
||||
branches: BTreeMap<NodeIndex, BranchNode>,
|
||||
}
|
||||
|
||||
impl SimpleSmt {
|
||||
// CONSTANTS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Minimum supported depth.
|
||||
pub const MIN_DEPTH: u8 = 1;
|
||||
|
||||
/// Maximum supported depth.
|
||||
pub const MAX_DEPTH: u8 = 64;
|
||||
|
||||
/// Value of an empty leaf.
|
||||
pub const EMPTY_VALUE: Word = super::EMPTY_WORD;
|
||||
|
||||
// CONSTRUCTORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a new [SimpleSmt] instantiated with the specified depth.
|
||||
///
|
||||
/// All leaves in the returned tree are set to [ZERO; 4].
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the depth is 0 or is greater than 64.
|
||||
pub fn new(depth: u8) -> Result<Self, MerkleError> {
|
||||
// validate the range of the depth.
|
||||
if depth < Self::MIN_DEPTH {
|
||||
return Err(MerkleError::DepthTooSmall(depth));
|
||||
} else if Self::MAX_DEPTH < depth {
|
||||
return Err(MerkleError::DepthTooBig(depth as u64));
|
||||
}
|
||||
|
||||
let root = *EmptySubtreeRoots::entry(depth, 0);
|
||||
|
||||
Ok(Self {
|
||||
root,
|
||||
depth,
|
||||
leaves: BTreeMap::new(),
|
||||
branches: BTreeMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a new [SimpleSmt] instantiated with the specified depth and with leaves
|
||||
/// set as specified by the provided entries.
|
||||
///
|
||||
/// All leaves omitted from the entries list are set to [ZERO; 4].
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// - If the depth is 0 or is greater than 64.
|
||||
/// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
|
||||
/// - The provided entries contain multiple values for the same key.
|
||||
pub fn with_leaves(
|
||||
depth: u8,
|
||||
entries: impl IntoIterator<Item = (u64, Word)>,
|
||||
) -> Result<Self, MerkleError> {
|
||||
// create an empty tree
|
||||
let mut tree = Self::new(depth)?;
|
||||
|
||||
// compute the max number of entries. We use an upper bound of depth 63 because we consider
|
||||
// passing in a vector of size 2^64 infeasible.
|
||||
let max_num_entries = 2_usize.pow(tree.depth.min(63).into());
|
||||
|
||||
// This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so
|
||||
// entries with the empty value need additional tracking.
|
||||
let mut key_set_to_zero = BTreeSet::new();
|
||||
|
||||
for (idx, (key, value)) in entries.into_iter().enumerate() {
|
||||
if idx >= max_num_entries {
|
||||
return Err(MerkleError::InvalidNumEntries(max_num_entries));
|
||||
}
|
||||
|
||||
let old_value = tree.update_leaf(key, value)?;
|
||||
|
||||
if old_value != Self::EMPTY_VALUE || key_set_to_zero.contains(&key) {
|
||||
return Err(MerkleError::DuplicateValuesForIndex(key));
|
||||
}
|
||||
|
||||
if value == Self::EMPTY_VALUE {
|
||||
key_set_to_zero.insert(key);
|
||||
};
|
||||
}
|
||||
Ok(tree)
|
||||
}
|
||||
|
||||
/// Wrapper around [`SimpleSmt::with_leaves`] which inserts leaves at contiguous indices
|
||||
/// starting at index 0.
|
||||
pub fn with_contiguous_leaves(
|
||||
depth: u8,
|
||||
entries: impl IntoIterator<Item = Word>,
|
||||
) -> Result<Self, MerkleError> {
|
||||
Self::with_leaves(
|
||||
depth,
|
||||
entries
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(idx, word)| (idx.try_into().expect("tree max depth is 2^8"), word)),
|
||||
)
|
||||
}
|
||||
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns the root of this Merkle tree.
|
||||
pub const fn root(&self) -> RpoDigest {
|
||||
self.root
|
||||
}
|
||||
|
||||
/// Returns the depth of this Merkle tree.
|
||||
pub const fn depth(&self) -> u8 {
|
||||
self.depth
|
||||
}
|
||||
|
||||
/// Returns a node at the specified index.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the specified index has depth set to 0 or the depth is greater than
|
||||
/// the depth of this Merkle tree.
|
||||
pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
|
||||
if index.is_root() {
|
||||
Err(MerkleError::DepthTooSmall(index.depth()))
|
||||
} else if index.depth() > self.depth() {
|
||||
Err(MerkleError::DepthTooBig(index.depth() as u64))
|
||||
} else if index.depth() == self.depth() {
|
||||
// the lookup in empty_hashes could fail only if empty_hashes were not built correctly
|
||||
// by the constructor as we check the depth of the lookup above.
|
||||
let leaf_pos = index.value();
|
||||
let leaf = match self.get_leaf_node(leaf_pos) {
|
||||
Some(word) => word.into(),
|
||||
None => *EmptySubtreeRoots::entry(self.depth, index.depth()),
|
||||
};
|
||||
Ok(leaf)
|
||||
} else {
|
||||
Ok(self.get_branch_node(&index).parent())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a value of the leaf at the specified index.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the index is greater than the maximum tree capacity, that is 2^{depth}.
|
||||
pub fn get_leaf(&self, index: u64) -> Result<Word, MerkleError> {
|
||||
let index = NodeIndex::new(self.depth, index)?;
|
||||
Ok(self.get_node(index)?.into())
|
||||
}
|
||||
|
||||
/// Returns a Merkle path from the node at the specified index to the root.
|
||||
///
|
||||
/// The node itself is not included in the path.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the specified index has depth set to 0 or the depth is greater than
|
||||
/// the depth of this Merkle tree.
|
||||
pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
|
||||
if index.is_root() {
|
||||
return Err(MerkleError::DepthTooSmall(index.depth()));
|
||||
} else if index.depth() > self.depth() {
|
||||
return Err(MerkleError::DepthTooBig(index.depth() as u64));
|
||||
}
|
||||
|
||||
let mut path = Vec::with_capacity(index.depth() as usize);
|
||||
for _ in 0..index.depth() {
|
||||
let is_right = index.is_value_odd();
|
||||
index.move_up();
|
||||
let BranchNode { left, right } = self.get_branch_node(&index);
|
||||
let value = if is_right { left } else { right };
|
||||
path.push(value);
|
||||
}
|
||||
Ok(MerklePath::new(path))
|
||||
}
|
||||
|
||||
/// Return a Merkle path from the leaf at the specified index to the root.
|
||||
///
|
||||
/// The leaf itself is not included in the path.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the index is greater than the maximum tree capacity, that is 2^{depth}.
|
||||
pub fn get_leaf_path(&self, index: u64) -> Result<MerklePath, MerkleError> {
|
||||
let index = NodeIndex::new(self.depth(), index)?;
|
||||
self.get_path(index)
|
||||
}
|
||||
|
||||
// ITERATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns an iterator over the leaves of this [SimpleSmt].
|
||||
pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
|
||||
self.leaves.iter().map(|(i, w)| (*i, w))
|
||||
}
|
||||
|
||||
/// Returns an iterator over the inner nodes of this Merkle tree.
|
||||
pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
|
||||
self.branches.values().map(|e| InnerNodeInfo {
|
||||
value: e.parent(),
|
||||
left: e.left,
|
||||
right: e.right,
|
||||
})
|
||||
}
|
||||
|
||||
// STATE MUTATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Updates value of the leaf at the specified index returning the old leaf value.
|
||||
///
|
||||
/// This also recomputes all hashes between the leaf and the root, updating the root itself.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the index is greater than the maximum tree capacity, that is 2^{depth}.
|
||||
pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<Word, MerkleError> {
|
||||
// validate the index before modifying the structure
|
||||
let idx = NodeIndex::new(self.depth(), index)?;
|
||||
|
||||
let old_value = self.insert_leaf_node(index, value).unwrap_or(Self::EMPTY_VALUE);
|
||||
|
||||
// if the old value and new value are the same, there is nothing to update
|
||||
if value == old_value {
|
||||
return Ok(value);
|
||||
}
|
||||
|
||||
self.recompute_nodes_from_index_to_root(idx, RpoDigest::from(value));
|
||||
|
||||
Ok(old_value)
|
||||
}
|
||||
|
||||
/// Inserts a subtree at the specified index. The depth at which the subtree is inserted is
|
||||
/// computed as `self.depth() - subtree.depth()`.
|
||||
///
|
||||
/// Returns the new root.
|
||||
pub fn set_subtree(
|
||||
&mut self,
|
||||
subtree_insertion_index: u64,
|
||||
subtree: SimpleSmt,
|
||||
) -> Result<RpoDigest, MerkleError> {
|
||||
if subtree.depth() > self.depth() {
|
||||
return Err(MerkleError::InvalidSubtreeDepth {
|
||||
subtree_depth: subtree.depth(),
|
||||
tree_depth: self.depth(),
|
||||
});
|
||||
}
|
||||
|
||||
// Verify that `subtree_insertion_index` is valid.
|
||||
let subtree_root_insertion_depth = self.depth() - subtree.depth();
|
||||
let subtree_root_index =
|
||||
NodeIndex::new(subtree_root_insertion_depth, subtree_insertion_index)?;
|
||||
|
||||
// add leaves
|
||||
// --------------
|
||||
|
||||
// The subtree's leaf indices live in their own context - i.e. a subtree of depth `d`. If we
|
||||
// insert the subtree at `subtree_insertion_index = 0`, then the subtree leaf indices are
|
||||
// valid as they are. However, consider what happens when we insert at
|
||||
// `subtree_insertion_index = 1`. The first leaf of our subtree now will have index `2^d`;
|
||||
// you can see it as there's a full subtree sitting on its left. In general, for
|
||||
// `subtree_insertion_index = i`, there are `i` subtrees sitting before the subtree we want
|
||||
// to insert, so we need to adjust all its leaves by `i * 2^d`.
|
||||
let leaf_index_shift: u64 = subtree_insertion_index * 2_u64.pow(subtree.depth().into());
|
||||
for (subtree_leaf_idx, leaf_value) in subtree.leaves() {
|
||||
let new_leaf_idx = leaf_index_shift + subtree_leaf_idx;
|
||||
debug_assert!(new_leaf_idx < 2_u64.pow(self.depth().into()));
|
||||
|
||||
self.insert_leaf_node(new_leaf_idx, *leaf_value);
|
||||
}
|
||||
|
||||
// add subtree's branch nodes (which includes the root)
|
||||
// --------------
|
||||
for (branch_idx, branch_node) in subtree.branches {
|
||||
let new_branch_idx = {
|
||||
let new_depth = subtree_root_insertion_depth + branch_idx.depth();
|
||||
let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into())
|
||||
+ branch_idx.value();
|
||||
|
||||
NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid")
|
||||
};
|
||||
|
||||
self.branches.insert(new_branch_idx, branch_node);
|
||||
}
|
||||
|
||||
// recompute nodes starting from subtree root
|
||||
// --------------
|
||||
self.recompute_nodes_from_index_to_root(subtree_root_index, subtree.root);
|
||||
|
||||
Ok(self.root)
|
||||
}
|
||||
|
||||
// HELPER METHODS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Recomputes the branch nodes (including the root) from `index` all the way to the root.
|
||||
/// `node_hash_at_index` is the hash of the node stored at index.
|
||||
fn recompute_nodes_from_index_to_root(
|
||||
&mut self,
|
||||
mut index: NodeIndex,
|
||||
node_hash_at_index: RpoDigest,
|
||||
) {
|
||||
let mut value = node_hash_at_index;
|
||||
for _ in 0..index.depth() {
|
||||
let is_right = index.is_value_odd();
|
||||
index.move_up();
|
||||
let BranchNode { left, right } = self.get_branch_node(&index);
|
||||
let (left, right) = if is_right { (left, value) } else { (value, right) };
|
||||
self.insert_branch_node(index, left, right);
|
||||
value = Rpo256::merge(&[left, right]);
|
||||
}
|
||||
self.root = value;
|
||||
}
|
||||
|
||||
fn get_leaf_node(&self, key: u64) -> Option<Word> {
|
||||
self.leaves.get(&key).copied()
|
||||
}
|
||||
|
||||
fn insert_leaf_node(&mut self, key: u64, node: Word) -> Option<Word> {
|
||||
self.leaves.insert(key, node)
|
||||
}
|
||||
|
||||
fn get_branch_node(&self, index: &NodeIndex) -> BranchNode {
|
||||
self.branches.get(index).cloned().unwrap_or_else(|| {
|
||||
let node = EmptySubtreeRoots::entry(self.depth, index.depth() + 1);
|
||||
BranchNode { left: *node, right: *node }
|
||||
})
|
||||
}
|
||||
|
||||
fn insert_branch_node(&mut self, index: NodeIndex, left: RpoDigest, right: RpoDigest) {
|
||||
let branch = BranchNode { left, right };
|
||||
self.branches.insert(index, branch);
|
||||
}
|
||||
}
|
||||
|
||||
// BRANCH NODE
|
||||
// ================================================================================================
|
||||
|
||||
#[derive(Debug, Default, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
struct BranchNode {
|
||||
left: RpoDigest,
|
||||
right: RpoDigest,
|
||||
}
|
||||
|
||||
impl BranchNode {
|
||||
fn parent(&self) -> RpoDigest {
|
||||
Rpo256::merge(&[self.left, self.right])
|
||||
}
|
||||
}
|
||||
|
||||
// TRY APPLY DIFF
|
||||
// ================================================================================================
|
||||
impl TryApplyDiff<RpoDigest, StoreNode> for SimpleSmt {
|
||||
type Error = MerkleError;
|
||||
type DiffType = MerkleTreeDelta;
|
||||
|
||||
fn try_apply(&mut self, diff: MerkleTreeDelta) -> Result<(), MerkleError> {
|
||||
if diff.depth() != self.depth() {
|
||||
return Err(MerkleError::InvalidDepth {
|
||||
expected: self.depth(),
|
||||
provided: diff.depth(),
|
||||
});
|
||||
}
|
||||
|
||||
for slot in diff.cleared_slots() {
|
||||
self.update_leaf(*slot, Self::EMPTY_VALUE)?;
|
||||
}
|
||||
|
||||
for (slot, value) in diff.updated_slots() {
|
||||
self.update_leaf(*slot, *value)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,15 +1,10 @@
|
||||
use super::{
|
||||
super::{MerkleError, RpoDigest, SimpleSmt},
|
||||
NodeIndex,
|
||||
super::{InnerNodeInfo, MerkleError, MerkleTree, RpoDigest, SimpleSmt, EMPTY_WORD},
|
||||
NodeIndex, Rpo256, Vec,
|
||||
};
|
||||
use crate::{
|
||||
hash::rpo::Rpo256,
|
||||
merkle::{
|
||||
digests_to_words, int_to_leaf, int_to_node, smt::SparseMerkleTree, EmptySubtreeRoots,
|
||||
InnerNodeInfo, LeafIndex, MerkleTree,
|
||||
},
|
||||
utils::collections::Vec,
|
||||
Word, EMPTY_WORD,
|
||||
merkle::{digests_to_words, int_to_leaf, int_to_node, EmptySubtreeRoots},
|
||||
Word,
|
||||
};
|
||||
|
||||
// TEST DATA
|
||||
@@ -39,27 +34,26 @@ const ZERO_VALUES8: [Word; 8] = [int_to_leaf(0); 8];
|
||||
#[test]
|
||||
fn build_empty_tree() {
|
||||
// tree of depth 3
|
||||
let smt = SimpleSmt::<3>::new().unwrap();
|
||||
let mt = MerkleTree::new(ZERO_VALUES8).unwrap();
|
||||
let smt = SimpleSmt::new(3).unwrap();
|
||||
let mt = MerkleTree::new(ZERO_VALUES8.to_vec()).unwrap();
|
||||
assert_eq!(mt.root(), smt.root());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_sparse_tree() {
|
||||
const DEPTH: u8 = 3;
|
||||
let mut smt = SimpleSmt::<DEPTH>::new().unwrap();
|
||||
let mut smt = SimpleSmt::new(3).unwrap();
|
||||
let mut values = ZERO_VALUES8.to_vec();
|
||||
|
||||
// insert single value
|
||||
let key = 6;
|
||||
let new_node = int_to_leaf(7);
|
||||
values[key as usize] = new_node;
|
||||
let old_value = smt.insert(LeafIndex::<DEPTH>::new(key).unwrap(), new_node);
|
||||
let old_value = smt.update_leaf(key, new_node).expect("Failed to update leaf");
|
||||
let mt2 = MerkleTree::new(values.clone()).unwrap();
|
||||
assert_eq!(mt2.root(), smt.root());
|
||||
assert_eq!(
|
||||
mt2.get_path(NodeIndex::make(3, 6)).unwrap(),
|
||||
smt.open(&LeafIndex::<3>::new(6).unwrap()).path
|
||||
smt.get_path(NodeIndex::make(3, 6)).unwrap()
|
||||
);
|
||||
assert_eq!(old_value, EMPTY_WORD);
|
||||
|
||||
@@ -67,12 +61,12 @@ fn build_sparse_tree() {
|
||||
let key = 2;
|
||||
let new_node = int_to_leaf(3);
|
||||
values[key as usize] = new_node;
|
||||
let old_value = smt.insert(LeafIndex::<DEPTH>::new(key).unwrap(), new_node);
|
||||
let old_value = smt.update_leaf(key, new_node).expect("Failed to update leaf");
|
||||
let mt3 = MerkleTree::new(values).unwrap();
|
||||
assert_eq!(mt3.root(), smt.root());
|
||||
assert_eq!(
|
||||
mt3.get_path(NodeIndex::make(3, 2)).unwrap(),
|
||||
smt.open(&LeafIndex::<3>::new(2).unwrap()).path
|
||||
smt.get_path(NodeIndex::make(3, 2)).unwrap()
|
||||
);
|
||||
assert_eq!(old_value, EMPTY_WORD);
|
||||
}
|
||||
@@ -80,12 +74,14 @@ fn build_sparse_tree() {
|
||||
/// Tests that [`SimpleSmt::with_contiguous_leaves`] works as expected
|
||||
#[test]
|
||||
fn build_contiguous_tree() {
|
||||
let tree_with_leaves =
|
||||
SimpleSmt::<2>::with_leaves([0, 1, 2, 3].into_iter().zip(digests_to_words(&VALUES4)))
|
||||
.unwrap();
|
||||
let tree_with_leaves = SimpleSmt::with_leaves(
|
||||
2,
|
||||
[0, 1, 2, 3].into_iter().zip(digests_to_words(&VALUES4).into_iter()),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let tree_with_contiguous_leaves =
|
||||
SimpleSmt::<2>::with_contiguous_leaves(digests_to_words(&VALUES4)).unwrap();
|
||||
SimpleSmt::with_contiguous_leaves(2, digests_to_words(&VALUES4).into_iter()).unwrap();
|
||||
|
||||
assert_eq!(tree_with_leaves, tree_with_contiguous_leaves);
|
||||
}
|
||||
@@ -93,7 +89,8 @@ fn build_contiguous_tree() {
|
||||
#[test]
|
||||
fn test_depth2_tree() {
|
||||
let tree =
|
||||
SimpleSmt::<2>::with_leaves(KEYS4.into_iter().zip(digests_to_words(&VALUES4))).unwrap();
|
||||
SimpleSmt::with_leaves(2, KEYS4.into_iter().zip(digests_to_words(&VALUES4).into_iter()))
|
||||
.unwrap();
|
||||
|
||||
// check internal structure
|
||||
let (root, node2, node3) = compute_internal_nodes();
|
||||
@@ -108,16 +105,21 @@ fn test_depth2_tree() {
|
||||
assert_eq!(VALUES4[3], tree.get_node(NodeIndex::make(2, 3)).unwrap());
|
||||
|
||||
// check get_path(): depth 2
|
||||
assert_eq!(vec![VALUES4[1], node3], *tree.open(&LeafIndex::<2>::new(0).unwrap()).path);
|
||||
assert_eq!(vec![VALUES4[0], node3], *tree.open(&LeafIndex::<2>::new(1).unwrap()).path);
|
||||
assert_eq!(vec![VALUES4[3], node2], *tree.open(&LeafIndex::<2>::new(2).unwrap()).path);
|
||||
assert_eq!(vec![VALUES4[2], node2], *tree.open(&LeafIndex::<2>::new(3).unwrap()).path);
|
||||
assert_eq!(vec![VALUES4[1], node3], *tree.get_path(NodeIndex::make(2, 0)).unwrap());
|
||||
assert_eq!(vec![VALUES4[0], node3], *tree.get_path(NodeIndex::make(2, 1)).unwrap());
|
||||
assert_eq!(vec![VALUES4[3], node2], *tree.get_path(NodeIndex::make(2, 2)).unwrap());
|
||||
assert_eq!(vec![VALUES4[2], node2], *tree.get_path(NodeIndex::make(2, 3)).unwrap());
|
||||
|
||||
// check get_path(): depth 1
|
||||
assert_eq!(vec![node3], *tree.get_path(NodeIndex::make(1, 0)).unwrap());
|
||||
assert_eq!(vec![node2], *tree.get_path(NodeIndex::make(1, 1)).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inner_node_iterator() -> Result<(), MerkleError> {
|
||||
let tree =
|
||||
SimpleSmt::<2>::with_leaves(KEYS4.into_iter().zip(digests_to_words(&VALUES4))).unwrap();
|
||||
SimpleSmt::with_leaves(2, KEYS4.into_iter().zip(digests_to_words(&VALUES4).into_iter()))
|
||||
.unwrap();
|
||||
|
||||
// check depth 2
|
||||
assert_eq!(VALUES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap());
|
||||
@@ -147,9 +149,9 @@ fn test_inner_node_iterator() -> Result<(), MerkleError> {
|
||||
|
||||
#[test]
|
||||
fn update_leaf() {
|
||||
const DEPTH: u8 = 3;
|
||||
let mut tree =
|
||||
SimpleSmt::<DEPTH>::with_leaves(KEYS8.into_iter().zip(digests_to_words(&VALUES8))).unwrap();
|
||||
SimpleSmt::with_leaves(3, KEYS8.into_iter().zip(digests_to_words(&VALUES8).into_iter()))
|
||||
.unwrap();
|
||||
|
||||
// update one value
|
||||
let key = 3;
|
||||
@@ -158,7 +160,7 @@ fn update_leaf() {
|
||||
expected_values[key] = new_node;
|
||||
let expected_tree = MerkleTree::new(expected_values.clone()).unwrap();
|
||||
|
||||
let old_leaf = tree.insert(LeafIndex::<DEPTH>::new(key as u64).unwrap(), new_node);
|
||||
let old_leaf = tree.update_leaf(key as u64, new_node).unwrap();
|
||||
assert_eq!(expected_tree.root(), tree.root);
|
||||
assert_eq!(old_leaf, *VALUES8[key]);
|
||||
|
||||
@@ -168,7 +170,7 @@ fn update_leaf() {
|
||||
expected_values[key] = new_node;
|
||||
let expected_tree = MerkleTree::new(expected_values.clone()).unwrap();
|
||||
|
||||
let old_leaf = tree.insert(LeafIndex::<DEPTH>::new(key as u64).unwrap(), new_node);
|
||||
let old_leaf = tree.update_leaf(key as u64, new_node).unwrap();
|
||||
assert_eq!(expected_tree.root(), tree.root);
|
||||
assert_eq!(old_leaf, *VALUES8[key]);
|
||||
}
|
||||
@@ -200,22 +202,29 @@ fn small_tree_opening_is_consistent() {
|
||||
|
||||
let k = Rpo256::merge(&[i, j]);
|
||||
|
||||
let depth = 3;
|
||||
let entries = vec![(0, a), (1, b), (4, c), (7, d)];
|
||||
let tree = SimpleSmt::<3>::with_leaves(entries).unwrap();
|
||||
let tree = SimpleSmt::with_leaves(depth, entries).unwrap();
|
||||
|
||||
assert_eq!(tree.root(), k);
|
||||
|
||||
let cases: Vec<(u64, Vec<RpoDigest>)> = vec![
|
||||
(0, vec![b.into(), f, j]),
|
||||
(1, vec![a.into(), f, j]),
|
||||
(4, vec![z.into(), h, i]),
|
||||
(7, vec![z.into(), g, i]),
|
||||
let cases: Vec<(u8, u64, Vec<RpoDigest>)> = vec![
|
||||
(3, 0, vec![b.into(), f, j]),
|
||||
(3, 1, vec![a.into(), f, j]),
|
||||
(3, 4, vec![z.into(), h, i]),
|
||||
(3, 7, vec![z.into(), g, i]),
|
||||
(2, 0, vec![f, j]),
|
||||
(2, 1, vec![e, j]),
|
||||
(2, 2, vec![h, i]),
|
||||
(2, 3, vec![g, i]),
|
||||
(1, 0, vec![j]),
|
||||
(1, 1, vec![i]),
|
||||
];
|
||||
|
||||
for (key, path) in cases {
|
||||
let opening = tree.open(&LeafIndex::<3>::new(key).unwrap());
|
||||
for (depth, key, path) in cases {
|
||||
let opening = tree.get_path(NodeIndex::make(depth, key)).unwrap();
|
||||
|
||||
assert_eq!(path, *opening.path);
|
||||
assert_eq!(path, *opening);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -237,12 +246,12 @@ fn test_simplesmt_fail_on_duplicates() {
|
||||
for (first, second) in values.iter() {
|
||||
// consecutive
|
||||
let entries = [(1, *first), (1, *second)];
|
||||
let smt = SimpleSmt::<64>::with_leaves(entries);
|
||||
let smt = SimpleSmt::with_leaves(64, entries);
|
||||
assert_eq!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));
|
||||
|
||||
// not consecutive
|
||||
let entries = [(1, *first), (5, int_to_leaf(5)), (1, *second)];
|
||||
let smt = SimpleSmt::<64>::with_leaves(entries);
|
||||
let smt = SimpleSmt::with_leaves(64, entries);
|
||||
assert_eq!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));
|
||||
}
|
||||
}
|
||||
@@ -250,10 +259,56 @@ fn test_simplesmt_fail_on_duplicates() {
|
||||
#[test]
|
||||
fn with_no_duplicates_empty_node() {
|
||||
let entries = [(1_u64, int_to_leaf(0)), (5, int_to_leaf(2))];
|
||||
let smt = SimpleSmt::<64>::with_leaves(entries);
|
||||
let smt = SimpleSmt::with_leaves(64, entries);
|
||||
assert!(smt.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simplesmt_update_nonexisting_leaf_with_zero() {
|
||||
// TESTING WITH EMPTY WORD
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
// Depth 1 has 2 leaf. Position is 0-indexed, position 2 doesn't exist.
|
||||
let mut smt = SimpleSmt::new(1).unwrap();
|
||||
let result = smt.update_leaf(2, EMPTY_WORD);
|
||||
assert!(!smt.leaves.contains_key(&2));
|
||||
assert!(result.is_err());
|
||||
|
||||
// Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
|
||||
let mut smt = SimpleSmt::new(2).unwrap();
|
||||
let result = smt.update_leaf(4, EMPTY_WORD);
|
||||
assert!(!smt.leaves.contains_key(&4));
|
||||
assert!(result.is_err());
|
||||
|
||||
// Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
|
||||
let mut smt = SimpleSmt::new(3).unwrap();
|
||||
let result = smt.update_leaf(8, EMPTY_WORD);
|
||||
assert!(!smt.leaves.contains_key(&8));
|
||||
assert!(result.is_err());
|
||||
|
||||
// TESTING WITH A VALUE
|
||||
// --------------------------------------------------------------------------------------------
|
||||
let value = int_to_node(1);
|
||||
|
||||
// Depth 1 has 2 leaves. Position is 0-indexed, position 1 doesn't exist.
|
||||
let mut smt = SimpleSmt::new(1).unwrap();
|
||||
let result = smt.update_leaf(2, *value);
|
||||
assert!(!smt.leaves.contains_key(&2));
|
||||
assert!(result.is_err());
|
||||
|
||||
// Depth 2 has 4 leaves. Position is 0-indexed, position 2 doesn't exist.
|
||||
let mut smt = SimpleSmt::new(2).unwrap();
|
||||
let result = smt.update_leaf(4, *value);
|
||||
assert!(!smt.leaves.contains_key(&4));
|
||||
assert!(result.is_err());
|
||||
|
||||
// Depth 3 has 8 leaves. Position is 0-indexed, position 4 doesn't exist.
|
||||
let mut smt = SimpleSmt::new(3).unwrap();
|
||||
let result = smt.update_leaf(8, *value);
|
||||
assert!(!smt.leaves.contains_key(&8));
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simplesmt_with_leaves_nonexisting_leaf() {
|
||||
// TESTING WITH EMPTY WORD
|
||||
@@ -261,17 +316,17 @@ fn test_simplesmt_with_leaves_nonexisting_leaf() {
|
||||
|
||||
// Depth 1 has 2 leaf. Position is 0-indexed, position 2 doesn't exist.
|
||||
let leaves = [(2, EMPTY_WORD)];
|
||||
let result = SimpleSmt::<1>::with_leaves(leaves);
|
||||
let result = SimpleSmt::with_leaves(1, leaves);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
|
||||
let leaves = [(4, EMPTY_WORD)];
|
||||
let result = SimpleSmt::<2>::with_leaves(leaves);
|
||||
let result = SimpleSmt::with_leaves(2, leaves);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
|
||||
let leaves = [(8, EMPTY_WORD)];
|
||||
let result = SimpleSmt::<3>::with_leaves(leaves);
|
||||
let result = SimpleSmt::with_leaves(3, leaves);
|
||||
assert!(result.is_err());
|
||||
|
||||
// TESTING WITH A VALUE
|
||||
@@ -280,17 +335,17 @@ fn test_simplesmt_with_leaves_nonexisting_leaf() {
|
||||
|
||||
// Depth 1 has 2 leaves. Position is 0-indexed, position 2 doesn't exist.
|
||||
let leaves = [(2, *value)];
|
||||
let result = SimpleSmt::<1>::with_leaves(leaves);
|
||||
let result = SimpleSmt::with_leaves(1, leaves);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
|
||||
let leaves = [(4, *value)];
|
||||
let result = SimpleSmt::<2>::with_leaves(leaves);
|
||||
let result = SimpleSmt::with_leaves(2, leaves);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
|
||||
let leaves = [(8, *value)];
|
||||
let result = SimpleSmt::<3>::with_leaves(leaves);
|
||||
let result = SimpleSmt::with_leaves(3, leaves);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
@@ -328,15 +383,16 @@ fn test_simplesmt_set_subtree() {
|
||||
// / \
|
||||
// c 0
|
||||
let subtree = {
|
||||
let depth = 1;
|
||||
let entries = vec![(0, c)];
|
||||
SimpleSmt::<1>::with_leaves(entries).unwrap()
|
||||
SimpleSmt::with_leaves(depth, entries).unwrap()
|
||||
};
|
||||
|
||||
// insert subtree
|
||||
const TREE_DEPTH: u8 = 3;
|
||||
let tree = {
|
||||
let depth = 3;
|
||||
let entries = vec![(0, a), (1, b), (7, d)];
|
||||
let mut tree = SimpleSmt::<TREE_DEPTH>::with_leaves(entries).unwrap();
|
||||
let mut tree = SimpleSmt::with_leaves(depth, entries).unwrap();
|
||||
|
||||
tree.set_subtree(2, subtree).unwrap();
|
||||
|
||||
@@ -344,8 +400,8 @@ fn test_simplesmt_set_subtree() {
|
||||
};
|
||||
|
||||
assert_eq!(tree.root(), k);
|
||||
assert_eq!(tree.get_leaf(&LeafIndex::<TREE_DEPTH>::new(4).unwrap()), c);
|
||||
assert_eq!(tree.get_inner_node(NodeIndex::new_unchecked(2, 2)).hash(), g);
|
||||
assert_eq!(tree.get_leaf(4).unwrap(), c);
|
||||
assert_eq!(tree.get_branch_node(&NodeIndex::new_unchecked(2, 2)).parent(), g);
|
||||
}
|
||||
|
||||
/// Ensures that an invalid input node index into `set_subtree()` incurs no mutation of the tree
|
||||
@@ -373,13 +429,15 @@ fn test_simplesmt_set_subtree_unchanged_for_wrong_index() {
|
||||
// / \
|
||||
// c 0
|
||||
let subtree = {
|
||||
let depth = 1;
|
||||
let entries = vec![(0, c)];
|
||||
SimpleSmt::<1>::with_leaves(entries).unwrap()
|
||||
SimpleSmt::with_leaves(depth, entries).unwrap()
|
||||
};
|
||||
|
||||
let mut tree = {
|
||||
let depth = 3;
|
||||
let entries = vec![(0, a), (1, b), (7, d)];
|
||||
SimpleSmt::<3>::with_leaves(entries).unwrap()
|
||||
SimpleSmt::with_leaves(depth, entries).unwrap()
|
||||
};
|
||||
let tree_root_before_insertion = tree.root();
|
||||
|
||||
@@ -409,20 +467,21 @@ fn test_simplesmt_set_subtree_entire_tree() {
|
||||
let c = Word::from(Rpo256::merge(&[b.into(); 2]));
|
||||
let d = Word::from(Rpo256::merge(&[c.into(); 2]));
|
||||
|
||||
let depth = 3;
|
||||
|
||||
// subtree: E3
|
||||
const DEPTH: u8 = 3;
|
||||
let subtree = { SimpleSmt::<DEPTH>::with_leaves(Vec::new()).unwrap() };
|
||||
assert_eq!(subtree.root(), *EmptySubtreeRoots::entry(DEPTH, 0));
|
||||
let subtree = { SimpleSmt::with_leaves(depth, Vec::new()).unwrap() };
|
||||
assert_eq!(subtree.root(), *EmptySubtreeRoots::entry(depth, 0));
|
||||
|
||||
// insert subtree
|
||||
let mut tree = {
|
||||
let entries = vec![(0, a), (1, b), (4, c), (7, d)];
|
||||
SimpleSmt::<3>::with_leaves(entries).unwrap()
|
||||
SimpleSmt::with_leaves(depth, entries).unwrap()
|
||||
};
|
||||
|
||||
tree.set_subtree(0, subtree).unwrap();
|
||||
|
||||
assert_eq!(tree.root(), *EmptySubtreeRoots::entry(DEPTH, 0));
|
||||
assert_eq!(tree.root(), *EmptySubtreeRoots::entry(depth, 0));
|
||||
}
|
||||
|
||||
// HELPER FUNCTIONS
|
||||
@@ -1,86 +0,0 @@
|
||||
use core::fmt;
|
||||
|
||||
use crate::{
|
||||
hash::rpo::RpoDigest,
|
||||
merkle::{LeafIndex, SMT_DEPTH},
|
||||
utils::collections::Vec,
|
||||
Word,
|
||||
};
|
||||
|
||||
// SMT LEAF ERROR
|
||||
// =================================================================================================
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum SmtLeafError {
|
||||
InconsistentKeys {
|
||||
entries: Vec<(RpoDigest, Word)>,
|
||||
key_1: RpoDigest,
|
||||
key_2: RpoDigest,
|
||||
},
|
||||
InvalidNumEntriesForMultiple(usize),
|
||||
SingleKeyInconsistentWithLeafIndex {
|
||||
key: RpoDigest,
|
||||
leaf_index: LeafIndex<SMT_DEPTH>,
|
||||
},
|
||||
MultipleKeysInconsistentWithLeafIndex {
|
||||
leaf_index_from_keys: LeafIndex<SMT_DEPTH>,
|
||||
leaf_index_supplied: LeafIndex<SMT_DEPTH>,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl std::error::Error for SmtLeafError {}
|
||||
|
||||
impl fmt::Display for SmtLeafError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
use SmtLeafError::*;
|
||||
match self {
|
||||
InvalidNumEntriesForMultiple(num_entries) => {
|
||||
write!(f, "Multiple leaf requires 2 or more entries. Got: {num_entries}")
|
||||
}
|
||||
InconsistentKeys { entries, key_1, key_2 } => {
|
||||
write!(f, "Multiple leaf requires all keys to map to the same leaf index. Offending keys: {key_1} and {key_2}. Entries: {entries:?}.")
|
||||
}
|
||||
SingleKeyInconsistentWithLeafIndex { key, leaf_index } => {
|
||||
write!(
|
||||
f,
|
||||
"Single key in leaf inconsistent with leaf index. Key: {key}, leaf index: {}",
|
||||
leaf_index.value()
|
||||
)
|
||||
}
|
||||
MultipleKeysInconsistentWithLeafIndex {
|
||||
leaf_index_from_keys,
|
||||
leaf_index_supplied,
|
||||
} => {
|
||||
write!(
|
||||
f,
|
||||
"Keys in entries map to leaf index {}, but leaf index {} was supplied",
|
||||
leaf_index_from_keys.value(),
|
||||
leaf_index_supplied.value()
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SMT PROOF ERROR
|
||||
// =================================================================================================
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum SmtProofError {
|
||||
InvalidPathLength(usize),
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl std::error::Error for SmtProofError {}
|
||||
|
||||
impl fmt::Display for SmtProofError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
use SmtProofError::*;
|
||||
match self {
|
||||
InvalidPathLength(path_length) => {
|
||||
write!(f, "Invalid Merkle path length. Expected {SMT_DEPTH}, got {path_length}")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,372 +0,0 @@
|
||||
use core::cmp::Ordering;
|
||||
|
||||
use crate::utils::{collections::Vec, string::ToString, vec};
|
||||
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
|
||||
|
||||
use super::{Felt, LeafIndex, Rpo256, RpoDigest, SmtLeafError, Word, EMPTY_WORD, SMT_DEPTH};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub enum SmtLeaf {
|
||||
Empty(LeafIndex<SMT_DEPTH>),
|
||||
Single((RpoDigest, Word)),
|
||||
Multiple(Vec<(RpoDigest, Word)>),
|
||||
}
|
||||
|
||||
impl SmtLeaf {
|
||||
// CONSTRUCTORS
|
||||
// ---------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a new leaf with the specified entries
|
||||
///
|
||||
/// # Errors
|
||||
/// - Returns an error if 2 keys in `entries` map to a different leaf index
|
||||
/// - Returns an error if 1 or more keys in `entries` map to a leaf index
|
||||
/// different from `leaf_index`
|
||||
pub fn new(
|
||||
entries: Vec<(RpoDigest, Word)>,
|
||||
leaf_index: LeafIndex<SMT_DEPTH>,
|
||||
) -> Result<Self, SmtLeafError> {
|
||||
match entries.len() {
|
||||
0 => Ok(Self::new_empty(leaf_index)),
|
||||
1 => {
|
||||
let (key, value) = entries[0];
|
||||
|
||||
if LeafIndex::<SMT_DEPTH>::from(key) != leaf_index {
|
||||
return Err(SmtLeafError::SingleKeyInconsistentWithLeafIndex {
|
||||
key,
|
||||
leaf_index,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(Self::new_single(key, value))
|
||||
}
|
||||
_ => {
|
||||
let leaf = Self::new_multiple(entries)?;
|
||||
|
||||
// `new_multiple()` checked that all keys map to the same leaf index. We still need
|
||||
// to ensure that that leaf index is `leaf_index`.
|
||||
if leaf.index() != leaf_index {
|
||||
Err(SmtLeafError::MultipleKeysInconsistentWithLeafIndex {
|
||||
leaf_index_from_keys: leaf.index(),
|
||||
leaf_index_supplied: leaf_index,
|
||||
})
|
||||
} else {
|
||||
Ok(leaf)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new empty leaf with the specified leaf index
|
||||
pub fn new_empty(leaf_index: LeafIndex<SMT_DEPTH>) -> Self {
|
||||
Self::Empty(leaf_index)
|
||||
}
|
||||
|
||||
/// Returns a new single leaf with the specified entry. The leaf index is derived from the
|
||||
/// entry's key.
|
||||
pub fn new_single(key: RpoDigest, value: Word) -> Self {
|
||||
Self::Single((key, value))
|
||||
}
|
||||
|
||||
/// Returns a new single leaf with the specified entry. The leaf index is derived from the
|
||||
/// entries' keys.
|
||||
///
|
||||
/// # Errors
|
||||
/// - Returns an error if 2 keys in `entries` map to a different leaf index
|
||||
pub fn new_multiple(entries: Vec<(RpoDigest, Word)>) -> Result<Self, SmtLeafError> {
|
||||
if entries.len() < 2 {
|
||||
return Err(SmtLeafError::InvalidNumEntriesForMultiple(entries.len()));
|
||||
}
|
||||
|
||||
// Check that all keys map to the same leaf index
|
||||
{
|
||||
let mut keys = entries.iter().map(|(key, _)| key);
|
||||
|
||||
let first_key = *keys.next().expect("ensured at least 2 entries");
|
||||
let first_leaf_index: LeafIndex<SMT_DEPTH> = first_key.into();
|
||||
|
||||
for &next_key in keys {
|
||||
let next_leaf_index: LeafIndex<SMT_DEPTH> = next_key.into();
|
||||
|
||||
if next_leaf_index != first_leaf_index {
|
||||
return Err(SmtLeafError::InconsistentKeys {
|
||||
entries,
|
||||
key_1: first_key,
|
||||
key_2: next_key,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self::Multiple(entries))
|
||||
}
|
||||
|
||||
// PUBLIC ACCESSORS
|
||||
// ---------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns true if the leaf is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
matches!(self, Self::Empty(_))
|
||||
}
|
||||
|
||||
/// Returns the leaf's index in the [`super::Smt`]
|
||||
pub fn index(&self) -> LeafIndex<SMT_DEPTH> {
|
||||
match self {
|
||||
SmtLeaf::Empty(leaf_index) => *leaf_index,
|
||||
SmtLeaf::Single((key, _)) => key.into(),
|
||||
SmtLeaf::Multiple(entries) => {
|
||||
// Note: All keys are guaranteed to have the same leaf index
|
||||
let (first_key, _) = entries[0];
|
||||
first_key.into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of entries stored in the leaf
|
||||
pub fn num_entries(&self) -> u64 {
|
||||
match self {
|
||||
SmtLeaf::Empty(_) => 0,
|
||||
SmtLeaf::Single(_) => 1,
|
||||
SmtLeaf::Multiple(entries) => {
|
||||
entries.len().try_into().expect("shouldn't have more than 2^64 entries")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the hash of the leaf
|
||||
pub fn hash(&self) -> RpoDigest {
|
||||
match self {
|
||||
SmtLeaf::Empty(_) => EMPTY_WORD.into(),
|
||||
SmtLeaf::Single((key, value)) => Rpo256::merge(&[*key, value.into()]),
|
||||
SmtLeaf::Multiple(kvs) => {
|
||||
let elements: Vec<Felt> = kvs.iter().copied().flat_map(kv_to_elements).collect();
|
||||
Rpo256::hash_elements(&elements)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ITERATORS
|
||||
// ---------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns the key-value pairs in the leaf
|
||||
pub fn entries(&self) -> Vec<&(RpoDigest, Word)> {
|
||||
match self {
|
||||
SmtLeaf::Empty(_) => Vec::new(),
|
||||
SmtLeaf::Single(kv_pair) => vec![kv_pair],
|
||||
SmtLeaf::Multiple(kv_pairs) => kv_pairs.iter().collect(),
|
||||
}
|
||||
}
|
||||
|
||||
// CONVERSIONS
|
||||
// ---------------------------------------------------------------------------------------------
|
||||
|
||||
/// Converts a leaf to a list of field elements
|
||||
pub fn to_elements(&self) -> Vec<Felt> {
|
||||
self.clone().into_elements()
|
||||
}
|
||||
|
||||
/// Converts a leaf to a list of field elements
|
||||
pub fn into_elements(self) -> Vec<Felt> {
|
||||
self.into_entries().into_iter().flat_map(kv_to_elements).collect()
|
||||
}
|
||||
|
||||
/// Converts a leaf the key-value pairs in the leaf
|
||||
pub fn into_entries(self) -> Vec<(RpoDigest, Word)> {
|
||||
match self {
|
||||
SmtLeaf::Empty(_) => Vec::new(),
|
||||
SmtLeaf::Single(kv_pair) => vec![kv_pair],
|
||||
SmtLeaf::Multiple(kv_pairs) => kv_pairs,
|
||||
}
|
||||
}
|
||||
|
||||
// HELPERS
|
||||
// ---------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns the value associated with `key` in the leaf, or `None` if `key` maps to another leaf.
|
||||
pub(super) fn get_value(&self, key: &RpoDigest) -> Option<Word> {
|
||||
// Ensure that `key` maps to this leaf
|
||||
if self.index() != key.into() {
|
||||
return None;
|
||||
}
|
||||
|
||||
match self {
|
||||
SmtLeaf::Empty(_) => Some(EMPTY_WORD),
|
||||
SmtLeaf::Single((key_in_leaf, value_in_leaf)) => {
|
||||
if key == key_in_leaf {
|
||||
Some(*value_in_leaf)
|
||||
} else {
|
||||
Some(EMPTY_WORD)
|
||||
}
|
||||
}
|
||||
SmtLeaf::Multiple(kv_pairs) => {
|
||||
for (key_in_leaf, value_in_leaf) in kv_pairs {
|
||||
if key == key_in_leaf {
|
||||
return Some(*value_in_leaf);
|
||||
}
|
||||
}
|
||||
|
||||
Some(EMPTY_WORD)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Inserts key-value pair into the leaf; returns the previous value associated with `key`, if
|
||||
/// any.
|
||||
///
|
||||
/// The caller needs to ensure that `key` has the same leaf index as all other keys in the leaf
|
||||
pub(super) fn insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
|
||||
match self {
|
||||
SmtLeaf::Empty(_) => {
|
||||
*self = SmtLeaf::new_single(key, value);
|
||||
None
|
||||
}
|
||||
SmtLeaf::Single(kv_pair) => {
|
||||
if kv_pair.0 == key {
|
||||
// the key is already in this leaf. Update the value and return the previous
|
||||
// value
|
||||
let old_value = kv_pair.1;
|
||||
kv_pair.1 = value;
|
||||
Some(old_value)
|
||||
} else {
|
||||
// Another entry is present in this leaf. Transform the entry into a list
|
||||
// entry, and make sure the key-value pairs are sorted by key
|
||||
let mut pairs = vec![*kv_pair, (key, value)];
|
||||
pairs.sort_by(|(key_1, _), (key_2, _)| cmp_keys(*key_1, *key_2));
|
||||
|
||||
*self = SmtLeaf::Multiple(pairs);
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
SmtLeaf::Multiple(kv_pairs) => {
|
||||
match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
|
||||
Ok(pos) => {
|
||||
let old_value = kv_pairs[pos].1;
|
||||
kv_pairs[pos].1 = value;
|
||||
|
||||
Some(old_value)
|
||||
}
|
||||
Err(pos) => {
|
||||
kv_pairs.insert(pos, (key, value));
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes key-value pair from the leaf stored at key; returns the previous value associated
|
||||
/// with `key`, if any. Also returns an `is_empty` flag, indicating whether the leaf became
|
||||
/// empty, and must be removed from the data structure it is contained in.
|
||||
pub(super) fn remove(&mut self, key: RpoDigest) -> (Option<Word>, bool) {
|
||||
match self {
|
||||
SmtLeaf::Empty(_) => (None, false),
|
||||
SmtLeaf::Single((key_at_leaf, value_at_leaf)) => {
|
||||
if *key_at_leaf == key {
|
||||
// our key was indeed stored in the leaf, so we return the value that was stored
|
||||
// in it, and indicate that the leaf should be removed
|
||||
let old_value = *value_at_leaf;
|
||||
|
||||
// Note: this is not strictly needed, since the caller is expected to drop this
|
||||
// `SmtLeaf` object.
|
||||
*self = SmtLeaf::new_empty(key.into());
|
||||
|
||||
(Some(old_value), true)
|
||||
} else {
|
||||
// another key is stored at leaf; nothing to update
|
||||
(None, false)
|
||||
}
|
||||
}
|
||||
SmtLeaf::Multiple(kv_pairs) => {
|
||||
match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
|
||||
Ok(pos) => {
|
||||
let old_value = kv_pairs[pos].1;
|
||||
|
||||
kv_pairs.remove(pos);
|
||||
debug_assert!(!kv_pairs.is_empty());
|
||||
|
||||
if kv_pairs.len() == 1 {
|
||||
// convert the leaf into `Single`
|
||||
*self = SmtLeaf::Single(kv_pairs[0]);
|
||||
}
|
||||
|
||||
(Some(old_value), false)
|
||||
}
|
||||
Err(_) => {
|
||||
// other keys are stored at leaf; nothing to update
|
||||
(None, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serializable for SmtLeaf {
|
||||
fn write_into<W: ByteWriter>(&self, target: &mut W) {
|
||||
// Write: num entries
|
||||
self.num_entries().write_into(target);
|
||||
|
||||
// Write: leaf index
|
||||
let leaf_index: u64 = self.index().value();
|
||||
leaf_index.write_into(target);
|
||||
|
||||
// Write: entries
|
||||
for (key, value) in self.entries() {
|
||||
key.write_into(target);
|
||||
value.write_into(target);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Deserializable for SmtLeaf {
|
||||
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
|
||||
// Read: num entries
|
||||
let num_entries = source.read_u64()?;
|
||||
|
||||
// Read: leaf index
|
||||
let leaf_index: LeafIndex<SMT_DEPTH> = {
|
||||
let value = source.read_u64()?;
|
||||
LeafIndex::new_max_depth(value)
|
||||
};
|
||||
|
||||
// Read: entries
|
||||
let mut entries: Vec<(RpoDigest, Word)> = Vec::new();
|
||||
for _ in 0..num_entries {
|
||||
let key: RpoDigest = source.read()?;
|
||||
let value: Word = source.read()?;
|
||||
|
||||
entries.push((key, value));
|
||||
}
|
||||
|
||||
Self::new(entries, leaf_index)
|
||||
.map_err(|err| DeserializationError::InvalidValue(err.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
// HELPER FUNCTIONS
|
||||
// ================================================================================================
|
||||
|
||||
/// Converts a key-value tuple to an iterator of `Felt`s
|
||||
fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator<Item = Felt> {
|
||||
let key_elements = key.into_iter();
|
||||
let value_elements = value.into_iter();
|
||||
|
||||
key_elements.chain(value_elements)
|
||||
}
|
||||
|
||||
/// Compares two keys, compared element-by-element using their integer representations starting with
|
||||
/// the most significant element.
|
||||
fn cmp_keys(key_1: RpoDigest, key_2: RpoDigest) -> Ordering {
|
||||
for (v1, v2) in key_1.iter().zip(key_2.iter()).rev() {
|
||||
let v1 = v1.as_int();
|
||||
let v2 = v2.as_int();
|
||||
if v1 != v2 {
|
||||
return v1.cmp(&v2);
|
||||
}
|
||||
}
|
||||
|
||||
Ordering::Equal
|
||||
}
|
||||
@@ -1,299 +0,0 @@
|
||||
use crate::hash::rpo::Rpo256;
|
||||
use crate::merkle::{EmptySubtreeRoots, InnerNodeInfo};
|
||||
use crate::utils::collections::{BTreeMap, BTreeSet};
|
||||
use crate::{Felt, EMPTY_WORD};
|
||||
|
||||
use super::{
|
||||
InnerNode, LeafIndex, MerkleError, MerklePath, NodeIndex, RpoDigest, SparseMerkleTree, Word,
|
||||
};
|
||||
|
||||
mod error;
|
||||
pub use error::{SmtLeafError, SmtProofError};
|
||||
|
||||
mod leaf;
|
||||
pub use leaf::SmtLeaf;
|
||||
|
||||
mod proof;
|
||||
pub use proof::SmtProof;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
// CONSTANTS
|
||||
// ================================================================================================
|
||||
|
||||
pub const SMT_DEPTH: u8 = 64;
|
||||
|
||||
// SMT
|
||||
// ================================================================================================
|
||||
|
||||
/// Sparse Merkle tree mapping 256-bit keys to 256-bit values. Both keys and values are represented
|
||||
/// by 4 field elements.
|
||||
///
|
||||
/// All leaves sit at depth 64. The most significant element of the key is used to identify the leaf to
|
||||
/// which the key maps.
|
||||
///
|
||||
/// A leaf is either empty, or holds one or more key-value pairs. An empty leaf hashes to the empty
|
||||
/// word. Otherwise, a leaf hashes to the hash of its key-value pairs, ordered by key first, value
|
||||
/// second.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct Smt {
|
||||
root: RpoDigest,
|
||||
leaves: BTreeMap<u64, SmtLeaf>,
|
||||
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
|
||||
}
|
||||
|
||||
impl Smt {
|
||||
// CONSTANTS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
/// The default value used to compute the hash of empty leaves
|
||||
pub const EMPTY_VALUE: Word = <Self as SparseMerkleTree<SMT_DEPTH>>::EMPTY_VALUE;
|
||||
|
||||
// CONSTRUCTORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a new [Smt].
|
||||
///
|
||||
/// All leaves in the returned tree are set to [Self::EMPTY_VALUE].
|
||||
pub fn new() -> Self {
|
||||
let root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
|
||||
|
||||
Self {
|
||||
root,
|
||||
leaves: BTreeMap::new(),
|
||||
inner_nodes: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new [Smt] instantiated with leaves set as specified by the provided entries.
|
||||
///
|
||||
/// All leaves omitted from the entries list are set to [Self::EMPTY_VALUE].
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the provided entries contain multiple values for the same key.
|
||||
pub fn with_entries(
|
||||
entries: impl IntoIterator<Item = (RpoDigest, Word)>,
|
||||
) -> Result<Self, MerkleError> {
|
||||
// create an empty tree
|
||||
let mut tree = Self::new();
|
||||
|
||||
// This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so
|
||||
// entries with the empty value need additional tracking.
|
||||
let mut key_set_to_zero = BTreeSet::new();
|
||||
|
||||
for (key, value) in entries {
|
||||
let old_value = tree.insert(key, value);
|
||||
|
||||
if old_value != EMPTY_WORD || key_set_to_zero.contains(&key) {
|
||||
return Err(MerkleError::DuplicateValuesForIndex(
|
||||
LeafIndex::<SMT_DEPTH>::from(key).value(),
|
||||
));
|
||||
}
|
||||
|
||||
if value == EMPTY_WORD {
|
||||
key_set_to_zero.insert(key);
|
||||
};
|
||||
}
|
||||
Ok(tree)
|
||||
}
|
||||
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns the depth of the tree
|
||||
pub const fn depth(&self) -> u8 {
|
||||
SMT_DEPTH
|
||||
}
|
||||
|
||||
/// Returns the root of the tree
|
||||
pub fn root(&self) -> RpoDigest {
|
||||
<Self as SparseMerkleTree<SMT_DEPTH>>::root(self)
|
||||
}
|
||||
|
||||
/// Returns the leaf to which `key` maps
|
||||
pub fn get_leaf(&self, key: &RpoDigest) -> SmtLeaf {
|
||||
<Self as SparseMerkleTree<SMT_DEPTH>>::get_leaf(self, key)
|
||||
}
|
||||
|
||||
/// Returns the value associated with `key`
|
||||
pub fn get_value(&self, key: &RpoDigest) -> Word {
|
||||
let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();
|
||||
|
||||
match self.leaves.get(&leaf_pos) {
|
||||
Some(leaf) => leaf.get_value(key).unwrap_or_default(),
|
||||
None => EMPTY_WORD,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
|
||||
/// path to the leaf, as well as the leaf itself.
|
||||
pub fn open(&self, key: &RpoDigest) -> SmtProof {
|
||||
<Self as SparseMerkleTree<SMT_DEPTH>>::open(self, key)
|
||||
}
|
||||
|
||||
// ITERATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns an iterator over the leaves of this [Smt].
|
||||
pub fn leaves(&self) -> impl Iterator<Item = (LeafIndex<SMT_DEPTH>, &SmtLeaf)> {
|
||||
self.leaves
|
||||
.iter()
|
||||
.map(|(leaf_index, leaf)| (LeafIndex::new_max_depth(*leaf_index), leaf))
|
||||
}
|
||||
|
||||
/// Returns an iterator over the key-value pairs of this [Smt].
|
||||
pub fn entries(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
|
||||
self.leaves().flat_map(|(_, leaf)| leaf.entries())
|
||||
}
|
||||
|
||||
/// Returns an iterator over the inner nodes of this [Smt].
|
||||
pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
|
||||
self.inner_nodes.values().map(|e| InnerNodeInfo {
|
||||
value: e.hash(),
|
||||
left: e.left,
|
||||
right: e.right,
|
||||
})
|
||||
}
|
||||
|
||||
// STATE MUTATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Inserts a value at the specified key, returning the previous value associated with that key.
|
||||
/// Recall that by definition, any key that hasn't been updated is associated with
|
||||
/// [`Self::EMPTY_VALUE`].
|
||||
///
|
||||
/// This also recomputes all hashes between the leaf (associated with the key) and the root,
|
||||
/// updating the root itself.
|
||||
pub fn insert(&mut self, key: RpoDigest, value: Word) -> Word {
|
||||
<Self as SparseMerkleTree<SMT_DEPTH>>::insert(self, key, value)
|
||||
}
|
||||
|
||||
// HELPERS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Inserts `value` at leaf index pointed to by `key`. `value` is guaranteed to not be the empty
|
||||
/// value, such that this is indeed an insertion.
|
||||
fn perform_insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
|
||||
debug_assert_ne!(value, Self::EMPTY_VALUE);
|
||||
|
||||
let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);
|
||||
|
||||
match self.leaves.get_mut(&leaf_index.value()) {
|
||||
Some(leaf) => leaf.insert(key, value),
|
||||
None => {
|
||||
self.leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value)));
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes key-value pair at leaf index pointed to by `key` if it exists.
|
||||
fn perform_remove(&mut self, key: RpoDigest) -> Option<Word> {
|
||||
let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);
|
||||
|
||||
if let Some(leaf) = self.leaves.get_mut(&leaf_index.value()) {
|
||||
let (old_value, is_empty) = leaf.remove(key);
|
||||
if is_empty {
|
||||
self.leaves.remove(&leaf_index.value());
|
||||
}
|
||||
old_value
|
||||
} else {
|
||||
// there's nothing stored at the leaf; nothing to update
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SparseMerkleTree<SMT_DEPTH> for Smt {
|
||||
type Key = RpoDigest;
|
||||
type Value = Word;
|
||||
type Leaf = SmtLeaf;
|
||||
type Opening = SmtProof;
|
||||
|
||||
const EMPTY_VALUE: Self::Value = EMPTY_WORD;
|
||||
|
||||
fn root(&self) -> RpoDigest {
|
||||
self.root
|
||||
}
|
||||
|
||||
fn set_root(&mut self, root: RpoDigest) {
|
||||
self.root = root;
|
||||
}
|
||||
|
||||
fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
|
||||
self.inner_nodes.get(&index).cloned().unwrap_or_else(|| {
|
||||
let node = EmptySubtreeRoots::entry(SMT_DEPTH, index.depth() + 1);
|
||||
|
||||
InnerNode { left: *node, right: *node }
|
||||
})
|
||||
}
|
||||
|
||||
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
|
||||
self.inner_nodes.insert(index, inner_node);
|
||||
}
|
||||
|
||||
fn remove_inner_node(&mut self, index: NodeIndex) {
|
||||
let _ = self.inner_nodes.remove(&index);
|
||||
}
|
||||
|
||||
fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value> {
|
||||
// inserting an `EMPTY_VALUE` is equivalent to removing any value associated with `key`
|
||||
if value != Self::EMPTY_VALUE {
|
||||
self.perform_insert(key, value)
|
||||
} else {
|
||||
self.perform_remove(key)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_leaf(&self, key: &RpoDigest) -> Self::Leaf {
|
||||
let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();
|
||||
|
||||
match self.leaves.get(&leaf_pos) {
|
||||
Some(leaf) => leaf.clone(),
|
||||
None => SmtLeaf::new_empty(key.into()),
|
||||
}
|
||||
}
|
||||
|
||||
fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest {
|
||||
leaf.hash()
|
||||
}
|
||||
|
||||
fn key_to_leaf_index(key: &RpoDigest) -> LeafIndex<SMT_DEPTH> {
|
||||
let most_significant_felt = key[3];
|
||||
LeafIndex::new_max_depth(most_significant_felt.as_int())
|
||||
}
|
||||
|
||||
fn path_and_leaf_to_opening(path: MerklePath, leaf: SmtLeaf) -> SmtProof {
|
||||
SmtProof::new_unchecked(path, leaf)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Smt {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// CONVERSIONS
|
||||
// ================================================================================================
|
||||
|
||||
impl From<Word> for LeafIndex<SMT_DEPTH> {
|
||||
fn from(value: Word) -> Self {
|
||||
// We use the most significant `Felt` of a `Word` as the leaf index.
|
||||
Self::new_max_depth(value[3].as_int())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RpoDigest> for LeafIndex<SMT_DEPTH> {
|
||||
fn from(value: RpoDigest) -> Self {
|
||||
Word::from(value).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&RpoDigest> for LeafIndex<SMT_DEPTH> {
|
||||
fn from(value: &RpoDigest) -> Self {
|
||||
Word::from(value).into()
|
||||
}
|
||||
}
|
||||
@@ -1,114 +0,0 @@
|
||||
use crate::utils::string::ToString;
|
||||
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
|
||||
|
||||
use super::{MerklePath, RpoDigest, SmtLeaf, SmtProofError, Word, SMT_DEPTH};
|
||||
|
||||
/// A proof which can be used to assert membership (or non-membership) of key-value pairs in a
|
||||
/// [`super::Smt`].
|
||||
///
|
||||
/// The proof consists of a Merkle path and leaf which describes the node located at the base of the
|
||||
/// path.
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct SmtProof {
|
||||
path: MerklePath,
|
||||
leaf: SmtLeaf,
|
||||
}
|
||||
|
||||
impl SmtProof {
|
||||
// CONSTRUCTOR
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a new instance of [`SmtProof`] instantiated from the specified path and leaf.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the path length is not [`SMT_DEPTH`].
|
||||
pub fn new(path: MerklePath, leaf: SmtLeaf) -> Result<Self, SmtProofError> {
|
||||
if path.len() != SMT_DEPTH.into() {
|
||||
return Err(SmtProofError::InvalidPathLength(path.len()));
|
||||
}
|
||||
|
||||
Ok(Self { path, leaf })
|
||||
}
|
||||
|
||||
/// Returns a new instance of [`SmtProof`] instantiated from the specified path and leaf.
|
||||
///
|
||||
/// The length of the path is not checked. Reserved for internal use.
|
||||
pub(super) fn new_unchecked(path: MerklePath, leaf: SmtLeaf) -> Self {
|
||||
Self { path, leaf }
|
||||
}
|
||||
|
||||
// PROOF VERIFIER
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns true if a [`super::Smt`] with the specified root contains the provided
|
||||
/// key-value pair.
|
||||
///
|
||||
/// Note: this method cannot be used to assert non-membership. That is, if false is returned,
|
||||
/// it does not mean that the provided key-value pair is not in the tree.
|
||||
pub fn verify_membership(&self, key: &RpoDigest, value: &Word, root: &RpoDigest) -> bool {
|
||||
let maybe_value_in_leaf = self.leaf.get_value(key);
|
||||
|
||||
match maybe_value_in_leaf {
|
||||
Some(value_in_leaf) => {
|
||||
// The value must match for the proof to be valid
|
||||
if value_in_leaf != *value {
|
||||
return false;
|
||||
}
|
||||
|
||||
// make sure the Merkle path resolves to the correct root
|
||||
self.compute_root() == *root
|
||||
}
|
||||
// If the key maps to a different leaf, the proof cannot verify membership of `value`
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns the value associated with the specific key according to this proof, or None if
|
||||
/// this proof does not contain a value for the specified key.
|
||||
///
|
||||
/// A key-value pair generated by using this method should pass the `verify_membership()` check.
|
||||
pub fn get(&self, key: &RpoDigest) -> Option<Word> {
|
||||
self.leaf.get_value(key)
|
||||
}
|
||||
|
||||
/// Computes the root of a [`super::Smt`] to which this proof resolves.
|
||||
pub fn compute_root(&self) -> RpoDigest {
|
||||
self.path
|
||||
.compute_root(self.leaf.index().value(), self.leaf.hash())
|
||||
.expect("failed to compute Merkle path root")
|
||||
}
|
||||
|
||||
/// Returns the proof's Merkle path.
|
||||
pub fn path(&self) -> &MerklePath {
|
||||
&self.path
|
||||
}
|
||||
|
||||
/// Returns the leaf associated with the proof.
|
||||
pub fn leaf(&self) -> &SmtLeaf {
|
||||
&self.leaf
|
||||
}
|
||||
|
||||
/// Consume the proof and returns its parts.
|
||||
pub fn into_parts(self) -> (MerklePath, SmtLeaf) {
|
||||
(self.path, self.leaf)
|
||||
}
|
||||
}
|
||||
|
||||
impl Serializable for SmtProof {
|
||||
fn write_into<W: ByteWriter>(&self, target: &mut W) {
|
||||
self.path.write_into(target);
|
||||
self.leaf.write_into(target);
|
||||
}
|
||||
}
|
||||
|
||||
impl Deserializable for SmtProof {
|
||||
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
|
||||
let path = MerklePath::read_from(source)?;
|
||||
let leaf = SmtLeaf::read_from(source)?;
|
||||
|
||||
Self::new(path, leaf).map_err(|err| DeserializationError::InvalidValue(err.to_string()))
|
||||
}
|
||||
}
|
||||
@@ -1,408 +0,0 @@
|
||||
use winter_utils::{Deserializable, Serializable};
|
||||
|
||||
use super::*;
|
||||
use crate::{
|
||||
merkle::{EmptySubtreeRoots, MerkleStore},
|
||||
utils::collections::Vec,
|
||||
ONE, WORD_SIZE,
|
||||
};
|
||||
|
||||
// SMT
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// This test checks that inserting twice at the same key functions as expected. The test covers
|
||||
/// only the case where the key is alone in its leaf
|
||||
#[test]
|
||||
fn test_smt_insert_at_same_key() {
|
||||
let mut smt = Smt::default();
|
||||
let mut store: MerkleStore = MerkleStore::default();
|
||||
|
||||
assert_eq!(smt.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));
|
||||
|
||||
let key_1: RpoDigest = {
|
||||
let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
|
||||
RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)])
|
||||
};
|
||||
let key_1_index: NodeIndex = LeafIndex::<SMT_DEPTH>::from(key_1).into();
|
||||
|
||||
let value_1 = [ONE; WORD_SIZE];
|
||||
let value_2 = [ONE + ONE; WORD_SIZE];
|
||||
|
||||
// Insert value 1 and ensure root is as expected
|
||||
{
|
||||
let leaf_node = build_empty_or_single_leaf_node(key_1, value_1);
|
||||
let tree_root = store.set_node(smt.root(), key_1_index, leaf_node).unwrap().root;
|
||||
|
||||
let old_value_1 = smt.insert(key_1, value_1);
|
||||
assert_eq!(old_value_1, EMPTY_WORD);
|
||||
|
||||
assert_eq!(smt.root(), tree_root);
|
||||
}
|
||||
|
||||
// Insert value 2 and ensure root is as expected
|
||||
{
|
||||
let leaf_node = build_empty_or_single_leaf_node(key_1, value_2);
|
||||
let tree_root = store.set_node(smt.root(), key_1_index, leaf_node).unwrap().root;
|
||||
|
||||
let old_value_2 = smt.insert(key_1, value_2);
|
||||
assert_eq!(old_value_2, value_1);
|
||||
|
||||
assert_eq!(smt.root(), tree_root);
|
||||
}
|
||||
}
|
||||
|
||||
/// This test checks that inserting twice at the same key functions as expected. The test covers
|
||||
/// only the case where the leaf type is `SmtLeaf::Multiple`
|
||||
#[test]
|
||||
fn test_smt_insert_at_same_key_2() {
|
||||
// The most significant u64 used for both keys (to ensure they map to the same leaf)
|
||||
let key_msb: u64 = 42;
|
||||
|
||||
let key_already_present: RpoDigest =
|
||||
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(key_msb)]);
|
||||
let key_already_present_index: NodeIndex =
|
||||
LeafIndex::<SMT_DEPTH>::from(key_already_present).into();
|
||||
let value_already_present = [ONE + ONE + ONE; WORD_SIZE];
|
||||
|
||||
let mut smt =
|
||||
Smt::with_entries(core::iter::once((key_already_present, value_already_present))).unwrap();
|
||||
let mut store: MerkleStore = {
|
||||
let mut store = MerkleStore::default();
|
||||
|
||||
let leaf_node = build_empty_or_single_leaf_node(key_already_present, value_already_present);
|
||||
store
|
||||
.set_node(*EmptySubtreeRoots::entry(SMT_DEPTH, 0), key_already_present_index, leaf_node)
|
||||
.unwrap();
|
||||
store
|
||||
};
|
||||
|
||||
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(key_msb)]);
|
||||
let key_1_index: NodeIndex = LeafIndex::<SMT_DEPTH>::from(key_1).into();
|
||||
|
||||
assert_eq!(key_1_index, key_already_present_index);
|
||||
|
||||
let value_1 = [ONE; WORD_SIZE];
|
||||
let value_2 = [ONE + ONE; WORD_SIZE];
|
||||
|
||||
// Insert value 1 and ensure root is as expected
|
||||
{
|
||||
// Note: key_1 comes first because it is smaller
|
||||
let leaf_node = build_multiple_leaf_node(&[
|
||||
(key_1, value_1),
|
||||
(key_already_present, value_already_present),
|
||||
]);
|
||||
let tree_root = store.set_node(smt.root(), key_1_index, leaf_node).unwrap().root;
|
||||
|
||||
let old_value_1 = smt.insert(key_1, value_1);
|
||||
assert_eq!(old_value_1, EMPTY_WORD);
|
||||
|
||||
assert_eq!(smt.root(), tree_root);
|
||||
}
|
||||
|
||||
// Insert value 2 and ensure root is as expected
|
||||
{
|
||||
let leaf_node = build_multiple_leaf_node(&[
|
||||
(key_1, value_2),
|
||||
(key_already_present, value_already_present),
|
||||
]);
|
||||
let tree_root = store.set_node(smt.root(), key_1_index, leaf_node).unwrap().root;
|
||||
|
||||
let old_value_2 = smt.insert(key_1, value_2);
|
||||
assert_eq!(old_value_2, value_1);
|
||||
|
||||
assert_eq!(smt.root(), tree_root);
|
||||
}
|
||||
}
|
||||
|
||||
/// This test ensures that the root of the tree is as expected when we add/remove 3 items at 3
|
||||
/// different keys. This also tests that the merkle paths produced are as expected.
|
||||
#[test]
|
||||
fn test_smt_insert_and_remove_multiple_values() {
|
||||
fn insert_values_and_assert_path(
|
||||
smt: &mut Smt,
|
||||
store: &mut MerkleStore,
|
||||
key_values: &[(RpoDigest, Word)],
|
||||
) {
|
||||
for &(key, value) in key_values {
|
||||
let key_index: NodeIndex = LeafIndex::<SMT_DEPTH>::from(key).into();
|
||||
|
||||
let leaf_node = build_empty_or_single_leaf_node(key, value);
|
||||
let tree_root = store.set_node(smt.root(), key_index, leaf_node).unwrap().root;
|
||||
|
||||
let _ = smt.insert(key, value);
|
||||
|
||||
assert_eq!(smt.root(), tree_root);
|
||||
|
||||
let expected_path = store.get_path(tree_root, key_index).unwrap();
|
||||
assert_eq!(smt.open(&key).into_parts().0, expected_path.path);
|
||||
}
|
||||
}
|
||||
let mut smt = Smt::default();
|
||||
let mut store: MerkleStore = MerkleStore::default();
|
||||
|
||||
assert_eq!(smt.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));
|
||||
|
||||
let key_1: RpoDigest = {
|
||||
let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
|
||||
RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)])
|
||||
};
|
||||
|
||||
let key_2: RpoDigest = {
|
||||
let raw = 0b_11111111_11111111_11111111_11111111_11111111_11111111_11111111_11111111_u64;
|
||||
|
||||
RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)])
|
||||
};
|
||||
|
||||
let key_3: RpoDigest = {
|
||||
let raw = 0b_00000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000_u64;
|
||||
|
||||
RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)])
|
||||
};
|
||||
|
||||
let value_1 = [ONE; WORD_SIZE];
|
||||
let value_2 = [ONE + ONE; WORD_SIZE];
|
||||
let value_3 = [ONE + ONE + ONE; WORD_SIZE];
|
||||
|
||||
// Insert values in the tree
|
||||
let key_values = [(key_1, value_1), (key_2, value_2), (key_3, value_3)];
|
||||
insert_values_and_assert_path(&mut smt, &mut store, &key_values);
|
||||
|
||||
// Remove values from the tree
|
||||
let key_empty_values = [(key_1, EMPTY_WORD), (key_2, EMPTY_WORD), (key_3, EMPTY_WORD)];
|
||||
insert_values_and_assert_path(&mut smt, &mut store, &key_empty_values);
|
||||
|
||||
let empty_root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
|
||||
assert_eq!(smt.root(), empty_root);
|
||||
|
||||
// an empty tree should have no leaves or inner nodes
|
||||
assert!(smt.leaves.is_empty());
|
||||
assert!(smt.inner_nodes.is_empty());
|
||||
}
|
||||
|
||||
/// This tests that inserting the empty value does indeed remove the key-value contained at the
|
||||
/// leaf. We insert & remove 3 values at the same leaf to ensure that all cases are covered (empty,
|
||||
/// single, multiple).
|
||||
#[test]
|
||||
fn test_smt_removal() {
|
||||
let mut smt = Smt::default();
|
||||
|
||||
let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
|
||||
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
|
||||
let key_2: RpoDigest =
|
||||
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]);
|
||||
let key_3: RpoDigest =
|
||||
RpoDigest::from([3_u32.into(), 3_u32.into(), 3_u32.into(), Felt::new(raw)]);
|
||||
|
||||
let value_1 = [ONE; WORD_SIZE];
|
||||
let value_2 = [2_u32.into(); WORD_SIZE];
|
||||
let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE];
|
||||
|
||||
// insert key-value 1
|
||||
{
|
||||
let old_value_1 = smt.insert(key_1, value_1);
|
||||
assert_eq!(old_value_1, EMPTY_WORD);
|
||||
|
||||
assert_eq!(smt.get_leaf(&key_1), SmtLeaf::Single((key_1, value_1)));
|
||||
}
|
||||
|
||||
// insert key-value 2
|
||||
{
|
||||
let old_value_2 = smt.insert(key_2, value_2);
|
||||
assert_eq!(old_value_2, EMPTY_WORD);
|
||||
|
||||
assert_eq!(
|
||||
smt.get_leaf(&key_2),
|
||||
SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2)])
|
||||
);
|
||||
}
|
||||
|
||||
// insert key-value 3
|
||||
{
|
||||
let old_value_3 = smt.insert(key_3, value_3);
|
||||
assert_eq!(old_value_3, EMPTY_WORD);
|
||||
|
||||
assert_eq!(
|
||||
smt.get_leaf(&key_3),
|
||||
SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2), (key_3, value_3)])
|
||||
);
|
||||
}
|
||||
|
||||
// remove key 3
|
||||
{
|
||||
let old_value_3 = smt.insert(key_3, EMPTY_WORD);
|
||||
assert_eq!(old_value_3, value_3);
|
||||
|
||||
assert_eq!(
|
||||
smt.get_leaf(&key_3),
|
||||
SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2)])
|
||||
);
|
||||
}
|
||||
|
||||
// remove key 2
|
||||
{
|
||||
let old_value_2 = smt.insert(key_2, EMPTY_WORD);
|
||||
assert_eq!(old_value_2, value_2);
|
||||
|
||||
assert_eq!(smt.get_leaf(&key_2), SmtLeaf::Single((key_1, value_1)));
|
||||
}
|
||||
|
||||
// remove key 1
|
||||
{
|
||||
let old_value_1 = smt.insert(key_1, EMPTY_WORD);
|
||||
assert_eq!(old_value_1, value_1);
|
||||
|
||||
assert_eq!(smt.get_leaf(&key_1), SmtLeaf::new_empty(key_1.into()));
|
||||
}
|
||||
}
|
||||
|
||||
/// Tests that 2 key-value pairs stored in the same leaf have the same path
|
||||
#[test]
|
||||
fn test_smt_path_to_keys_in_same_leaf_are_equal() {
|
||||
let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
|
||||
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
|
||||
let key_2: RpoDigest =
|
||||
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]);
|
||||
|
||||
let value_1 = [ONE; WORD_SIZE];
|
||||
let value_2 = [2_u32.into(); WORD_SIZE];
|
||||
|
||||
let smt = Smt::with_entries([(key_1, value_1), (key_2, value_2)]).unwrap();
|
||||
|
||||
assert_eq!(smt.open(&key_1), smt.open(&key_2));
|
||||
}
|
||||
|
||||
/// Tests that an empty leaf hashes to the empty word
|
||||
#[test]
|
||||
fn test_empty_leaf_hash() {
|
||||
let smt = Smt::default();
|
||||
|
||||
let leaf = smt.get_leaf(&RpoDigest::default());
|
||||
assert_eq!(leaf.hash(), EMPTY_WORD.into());
|
||||
}
|
||||
|
||||
/// Tests that `get_value()` works as expected
|
||||
#[test]
|
||||
fn test_smt_get_value() {
|
||||
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, ONE]);
|
||||
let key_2: RpoDigest =
|
||||
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), 2_u32.into()]);
|
||||
|
||||
let value_1 = [ONE; WORD_SIZE];
|
||||
let value_2 = [2_u32.into(); WORD_SIZE];
|
||||
|
||||
let smt = Smt::with_entries([(key_1, value_1), (key_2, value_2)]).unwrap();
|
||||
|
||||
let returned_value_1 = smt.get_value(&key_1);
|
||||
let returned_value_2 = smt.get_value(&key_2);
|
||||
|
||||
assert_eq!(value_1, returned_value_1);
|
||||
assert_eq!(value_2, returned_value_2);
|
||||
|
||||
// Check that a key with no inserted value returns the empty word
|
||||
let key_no_value =
|
||||
RpoDigest::from([42_u32.into(), 42_u32.into(), 42_u32.into(), 42_u32.into()]);
|
||||
|
||||
assert_eq!(EMPTY_WORD, smt.get_value(&key_no_value));
|
||||
}
|
||||
|
||||
/// Tests that `entries()` works as expected
|
||||
#[test]
|
||||
fn test_smt_entries() {
|
||||
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, ONE]);
|
||||
let key_2: RpoDigest =
|
||||
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), 2_u32.into()]);
|
||||
|
||||
let value_1 = [ONE; WORD_SIZE];
|
||||
let value_2 = [2_u32.into(); WORD_SIZE];
|
||||
|
||||
let smt = Smt::with_entries([(key_1, value_1), (key_2, value_2)]).unwrap();
|
||||
|
||||
let mut entries = smt.entries();
|
||||
|
||||
// Note: for simplicity, we assume the order `(k1,v1), (k2,v2)`. If a new implementation
|
||||
// switches the order, it is OK to modify the order here as well.
|
||||
assert_eq!(&(key_1, value_1), entries.next().unwrap());
|
||||
assert_eq!(&(key_2, value_2), entries.next().unwrap());
|
||||
assert!(entries.next().is_none());
|
||||
}
|
||||
|
||||
// SMT LEAF
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_empty_smt_leaf_serialization() {
|
||||
let empty_leaf = SmtLeaf::new_empty(LeafIndex::new_max_depth(42));
|
||||
|
||||
let mut serialized = empty_leaf.to_bytes();
|
||||
// extend buffer with random bytes
|
||||
serialized.extend([1, 2, 3, 4, 5]);
|
||||
let deserialized = SmtLeaf::read_from_bytes(&serialized).unwrap();
|
||||
|
||||
assert_eq!(empty_leaf, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_smt_leaf_serialization() {
|
||||
let single_leaf = SmtLeaf::new_single(
|
||||
RpoDigest::from([10_u32.into(), 11_u32.into(), 12_u32.into(), 13_u32.into()]),
|
||||
[1_u32.into(), 2_u32.into(), 3_u32.into(), 4_u32.into()],
|
||||
);
|
||||
|
||||
let mut serialized = single_leaf.to_bytes();
|
||||
// extend buffer with random bytes
|
||||
serialized.extend([1, 2, 3, 4, 5]);
|
||||
let deserialized = SmtLeaf::read_from_bytes(&serialized).unwrap();
|
||||
|
||||
assert_eq!(single_leaf, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_smt_leaf_serialization_success() {
|
||||
let multiple_leaf = SmtLeaf::new_multiple(vec![
|
||||
(
|
||||
RpoDigest::from([10_u32.into(), 11_u32.into(), 12_u32.into(), 13_u32.into()]),
|
||||
[1_u32.into(), 2_u32.into(), 3_u32.into(), 4_u32.into()],
|
||||
),
|
||||
(
|
||||
RpoDigest::from([100_u32.into(), 101_u32.into(), 102_u32.into(), 13_u32.into()]),
|
||||
[11_u32.into(), 12_u32.into(), 13_u32.into(), 14_u32.into()],
|
||||
),
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let mut serialized = multiple_leaf.to_bytes();
|
||||
// extend buffer with random bytes
|
||||
serialized.extend([1, 2, 3, 4, 5]);
|
||||
let deserialized = SmtLeaf::read_from_bytes(&serialized).unwrap();
|
||||
|
||||
assert_eq!(multiple_leaf, deserialized);
|
||||
}
|
||||
|
||||
// HELPERS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
fn build_empty_or_single_leaf_node(key: RpoDigest, value: Word) -> RpoDigest {
|
||||
if value == EMPTY_WORD {
|
||||
SmtLeaf::new_empty(key.into()).hash()
|
||||
} else {
|
||||
SmtLeaf::Single((key, value)).hash()
|
||||
}
|
||||
}
|
||||
|
||||
fn build_multiple_leaf_node(kv_pairs: &[(RpoDigest, Word)]) -> RpoDigest {
|
||||
let elements: Vec<Felt> = kv_pairs
|
||||
.iter()
|
||||
.flat_map(|(key, value)| {
|
||||
let key_elements = key.into_iter();
|
||||
let value_elements = (*value).into_iter();
|
||||
|
||||
key_elements.chain(value_elements)
|
||||
})
|
||||
.collect();
|
||||
|
||||
Rpo256::hash_elements(&elements)
|
||||
}
|
||||
@@ -1,245 +0,0 @@
|
||||
use crate::{
|
||||
hash::rpo::{Rpo256, RpoDigest},
|
||||
Word,
|
||||
};
|
||||
|
||||
use super::{EmptySubtreeRoots, MerkleError, MerklePath, NodeIndex, Vec};
|
||||
|
||||
mod full;
|
||||
pub use full::{Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH};
|
||||
|
||||
mod simple;
|
||||
pub use simple::SimpleSmt;
|
||||
|
||||
// CONSTANTS
|
||||
// ================================================================================================
|
||||
|
||||
/// Minimum supported depth.
|
||||
pub const SMT_MIN_DEPTH: u8 = 1;
|
||||
|
||||
/// Maximum supported depth.
|
||||
pub const SMT_MAX_DEPTH: u8 = 64;
|
||||
|
||||
// SPARSE MERKLE TREE
|
||||
// ================================================================================================
|
||||
|
||||
/// An abstract description of a sparse Merkle tree.
|
||||
///
|
||||
/// A sparse Merkle tree is a key-value map which also supports proving that a given value is indeed
|
||||
/// stored at a given key in the tree. It is viewed as always being fully populated. If a leaf's
|
||||
/// value was not explicitly set, then its value is the default value. Typically, the vast majority
|
||||
/// of leaves will store the default value (hence it is "sparse"), and therefore the internal
|
||||
/// representation of the tree will only keep track of the leaves that have a different value from
|
||||
/// the default.
|
||||
///
|
||||
/// All leaves sit at the same depth. The deeper the tree, the more leaves it has; but also the
|
||||
/// longer its proofs are - of exactly `log(depth)` size. A tree cannot have depth 0, since such a
|
||||
/// tree is just a single value, and is probably a programming mistake.
|
||||
///
|
||||
/// Every key maps to one leaf. If there are as many keys as there are leaves, then
|
||||
/// [Self::Leaf] should be the same type as [Self::Value], as is the case with
|
||||
/// [crate::merkle::SimpleSmt]. However, if there are more keys than leaves, then [`Self::Leaf`]
|
||||
/// must accomodate all keys that map to the same leaf.
|
||||
///
|
||||
/// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs.
|
||||
pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
|
||||
/// The type for a key
|
||||
type Key: Clone;
|
||||
/// The type for a value
|
||||
type Value: Clone + PartialEq;
|
||||
/// The type for a leaf
|
||||
type Leaf;
|
||||
/// The type for an opening (i.e. a "proof") of a leaf
|
||||
type Opening;
|
||||
|
||||
/// The default value used to compute the hash of empty leaves
|
||||
const EMPTY_VALUE: Self::Value;
|
||||
|
||||
// PROVIDED METHODS
|
||||
// ---------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
|
||||
/// path to the leaf, as well as the leaf itself.
|
||||
fn open(&self, key: &Self::Key) -> Self::Opening {
|
||||
let leaf = self.get_leaf(key);
|
||||
|
||||
let mut index: NodeIndex = {
|
||||
let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(key);
|
||||
leaf_index.into()
|
||||
};
|
||||
|
||||
let merkle_path = {
|
||||
let mut path = Vec::with_capacity(index.depth() as usize);
|
||||
for _ in 0..index.depth() {
|
||||
let is_right = index.is_value_odd();
|
||||
index.move_up();
|
||||
let InnerNode { left, right } = self.get_inner_node(index);
|
||||
let value = if is_right { left } else { right };
|
||||
path.push(value);
|
||||
}
|
||||
|
||||
MerklePath::new(path)
|
||||
};
|
||||
|
||||
Self::path_and_leaf_to_opening(merkle_path, leaf)
|
||||
}
|
||||
|
||||
/// Inserts a value at the specified key, returning the previous value associated with that key.
|
||||
/// Recall that by definition, any key that hasn't been updated is associated with
|
||||
/// [`Self::EMPTY_VALUE`].
|
||||
///
|
||||
/// This also recomputes all hashes between the leaf (associated with the key) and the root,
|
||||
/// updating the root itself.
|
||||
fn insert(&mut self, key: Self::Key, value: Self::Value) -> Self::Value {
|
||||
let old_value = self.insert_value(key.clone(), value.clone()).unwrap_or(Self::EMPTY_VALUE);
|
||||
|
||||
// if the old value and new value are the same, there is nothing to update
|
||||
if value == old_value {
|
||||
return value;
|
||||
}
|
||||
|
||||
let leaf = self.get_leaf(&key);
|
||||
let node_index = {
|
||||
let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(&key);
|
||||
leaf_index.into()
|
||||
};
|
||||
|
||||
self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf));
|
||||
|
||||
old_value
|
||||
}
|
||||
|
||||
/// Recomputes the branch nodes (including the root) from `index` all the way to the root.
|
||||
/// `node_hash_at_index` is the hash of the node stored at index.
|
||||
fn recompute_nodes_from_index_to_root(
|
||||
&mut self,
|
||||
mut index: NodeIndex,
|
||||
node_hash_at_index: RpoDigest,
|
||||
) {
|
||||
let mut node_hash = node_hash_at_index;
|
||||
for node_depth in (0..index.depth()).rev() {
|
||||
let is_right = index.is_value_odd();
|
||||
index.move_up();
|
||||
let InnerNode { left, right } = self.get_inner_node(index);
|
||||
let (left, right) = if is_right {
|
||||
(left, node_hash)
|
||||
} else {
|
||||
(node_hash, right)
|
||||
};
|
||||
node_hash = Rpo256::merge(&[left, right]);
|
||||
|
||||
if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) {
|
||||
// If a subtree is empty, when can remove the inner node, since it's equal to the
|
||||
// default value
|
||||
self.remove_inner_node(index)
|
||||
} else {
|
||||
self.insert_inner_node(index, InnerNode { left, right });
|
||||
}
|
||||
}
|
||||
self.set_root(node_hash);
|
||||
}
|
||||
|
||||
// REQUIRED METHODS
|
||||
// ---------------------------------------------------------------------------------------------
|
||||
|
||||
/// The root of the tree
|
||||
fn root(&self) -> RpoDigest;
|
||||
|
||||
/// Sets the root of the tree
|
||||
fn set_root(&mut self, root: RpoDigest);
|
||||
|
||||
/// Retrieves an inner node at the given index
|
||||
fn get_inner_node(&self, index: NodeIndex) -> InnerNode;
|
||||
|
||||
/// Inserts an inner node at the given index
|
||||
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode);
|
||||
|
||||
/// Removes an inner node at the given index
|
||||
fn remove_inner_node(&mut self, index: NodeIndex);
|
||||
|
||||
/// Inserts a leaf node, and returns the value at the key if already exists
|
||||
fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value>;
|
||||
|
||||
/// Returns the leaf at the specified index.
|
||||
fn get_leaf(&self, key: &Self::Key) -> Self::Leaf;
|
||||
|
||||
/// Returns the hash of a leaf
|
||||
fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest;
|
||||
|
||||
/// Maps a key to a leaf index
|
||||
fn key_to_leaf_index(key: &Self::Key) -> LeafIndex<DEPTH>;
|
||||
|
||||
/// Maps a (MerklePath, Self::Leaf) to an opening.
|
||||
///
|
||||
/// The length `path` is guaranteed to be equal to `DEPTH`
|
||||
fn path_and_leaf_to_opening(path: MerklePath, leaf: Self::Leaf) -> Self::Opening;
|
||||
}
|
||||
|
||||
// INNER NODE
|
||||
// ================================================================================================
|
||||
|
||||
#[derive(Debug, Default, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub(crate) struct InnerNode {
|
||||
pub left: RpoDigest,
|
||||
pub right: RpoDigest,
|
||||
}
|
||||
|
||||
impl InnerNode {
|
||||
pub fn hash(&self) -> RpoDigest {
|
||||
Rpo256::merge(&[self.left, self.right])
|
||||
}
|
||||
}
|
||||
|
||||
// LEAF INDEX
|
||||
// ================================================================================================
|
||||
|
||||
/// The index of a leaf, at a depth known at compile-time.
|
||||
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct LeafIndex<const DEPTH: u8> {
|
||||
index: NodeIndex,
|
||||
}
|
||||
|
||||
impl<const DEPTH: u8> LeafIndex<DEPTH> {
|
||||
pub fn new(value: u64) -> Result<Self, MerkleError> {
|
||||
if DEPTH < SMT_MIN_DEPTH {
|
||||
return Err(MerkleError::DepthTooSmall(DEPTH));
|
||||
}
|
||||
|
||||
Ok(LeafIndex { index: NodeIndex::new(DEPTH, value)? })
|
||||
}
|
||||
|
||||
pub fn value(&self) -> u64 {
|
||||
self.index.value()
|
||||
}
|
||||
}
|
||||
|
||||
impl LeafIndex<SMT_MAX_DEPTH> {
|
||||
pub const fn new_max_depth(value: u64) -> Self {
|
||||
LeafIndex {
|
||||
index: NodeIndex::new_unchecked(SMT_MAX_DEPTH, value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const DEPTH: u8> From<LeafIndex<DEPTH>> for NodeIndex {
|
||||
fn from(value: LeafIndex<DEPTH>) -> Self {
|
||||
value.index
|
||||
}
|
||||
}
|
||||
|
||||
impl<const DEPTH: u8> TryFrom<NodeIndex> for LeafIndex<DEPTH> {
|
||||
type Error = MerkleError;
|
||||
|
||||
fn try_from(node_index: NodeIndex) -> Result<Self, Self::Error> {
|
||||
if node_index.depth() != DEPTH {
|
||||
return Err(MerkleError::InvalidDepth {
|
||||
expected: DEPTH,
|
||||
provided: node_index.depth(),
|
||||
});
|
||||
}
|
||||
|
||||
Self::new(node_index.value())
|
||||
}
|
||||
}
|
||||
@@ -1,309 +0,0 @@
|
||||
use crate::{
|
||||
merkle::{EmptySubtreeRoots, InnerNodeInfo, MerklePath, ValuePath},
|
||||
EMPTY_WORD,
|
||||
};
|
||||
|
||||
use super::{
|
||||
InnerNode, LeafIndex, MerkleError, NodeIndex, RpoDigest, SparseMerkleTree, Word, SMT_MAX_DEPTH,
|
||||
SMT_MIN_DEPTH,
|
||||
};
|
||||
use crate::utils::collections::{BTreeMap, BTreeSet};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
// SPARSE MERKLE TREE
|
||||
// ================================================================================================
|
||||
|
||||
/// A sparse Merkle tree with 64-bit keys and 4-element leaf values, without compaction.
|
||||
///
|
||||
/// The root of the tree is recomputed on each new leaf update.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct SimpleSmt<const DEPTH: u8> {
|
||||
root: RpoDigest,
|
||||
leaves: BTreeMap<u64, Word>,
|
||||
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
|
||||
}
|
||||
|
||||
impl<const DEPTH: u8> SimpleSmt<DEPTH> {
|
||||
// CONSTANTS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// The default value used to compute the hash of empty leaves
|
||||
pub const EMPTY_VALUE: Word = <Self as SparseMerkleTree<DEPTH>>::EMPTY_VALUE;
|
||||
|
||||
// CONSTRUCTORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a new [SimpleSmt].
|
||||
///
|
||||
/// All leaves in the returned tree are set to [ZERO; 4].
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if DEPTH is 0 or is greater than 64.
|
||||
pub fn new() -> Result<Self, MerkleError> {
|
||||
// validate the range of the depth.
|
||||
if DEPTH < SMT_MIN_DEPTH {
|
||||
return Err(MerkleError::DepthTooSmall(DEPTH));
|
||||
} else if SMT_MAX_DEPTH < DEPTH {
|
||||
return Err(MerkleError::DepthTooBig(DEPTH as u64));
|
||||
}
|
||||
|
||||
let root = *EmptySubtreeRoots::entry(DEPTH, 0);
|
||||
|
||||
Ok(Self {
|
||||
root,
|
||||
leaves: BTreeMap::new(),
|
||||
inner_nodes: BTreeMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a new [SimpleSmt] instantiated with leaves set as specified by the provided entries.
|
||||
///
|
||||
/// All leaves omitted from the entries list are set to [ZERO; 4].
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// - If the depth is 0 or is greater than 64.
|
||||
/// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
|
||||
/// - The provided entries contain multiple values for the same key.
|
||||
pub fn with_leaves(
|
||||
entries: impl IntoIterator<Item = (u64, Word)>,
|
||||
) -> Result<Self, MerkleError> {
|
||||
// create an empty tree
|
||||
let mut tree = Self::new()?;
|
||||
|
||||
// compute the max number of entries. We use an upper bound of depth 63 because we consider
|
||||
// passing in a vector of size 2^64 infeasible.
|
||||
let max_num_entries = 2_usize.pow(DEPTH.min(63).into());
|
||||
|
||||
// This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so
|
||||
// entries with the empty value need additional tracking.
|
||||
let mut key_set_to_zero = BTreeSet::new();
|
||||
|
||||
for (idx, (key, value)) in entries.into_iter().enumerate() {
|
||||
if idx >= max_num_entries {
|
||||
return Err(MerkleError::InvalidNumEntries(max_num_entries));
|
||||
}
|
||||
|
||||
let old_value = tree.insert(LeafIndex::<DEPTH>::new(key)?, value);
|
||||
|
||||
if old_value != Self::EMPTY_VALUE || key_set_to_zero.contains(&key) {
|
||||
return Err(MerkleError::DuplicateValuesForIndex(key));
|
||||
}
|
||||
|
||||
if value == Self::EMPTY_VALUE {
|
||||
key_set_to_zero.insert(key);
|
||||
};
|
||||
}
|
||||
Ok(tree)
|
||||
}
|
||||
|
||||
/// Wrapper around [`SimpleSmt::with_leaves`] which inserts leaves at contiguous indices
|
||||
/// starting at index 0.
|
||||
pub fn with_contiguous_leaves(
|
||||
entries: impl IntoIterator<Item = Word>,
|
||||
) -> Result<Self, MerkleError> {
|
||||
Self::with_leaves(
|
||||
entries
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(idx, word)| (idx.try_into().expect("tree max depth is 2^8"), word)),
|
||||
)
|
||||
}
|
||||
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns the depth of the tree
|
||||
pub const fn depth(&self) -> u8 {
|
||||
DEPTH
|
||||
}
|
||||
|
||||
/// Returns the root of the tree
|
||||
pub fn root(&self) -> RpoDigest {
|
||||
<Self as SparseMerkleTree<DEPTH>>::root(self)
|
||||
}
|
||||
|
||||
/// Returns the leaf at the specified index.
|
||||
pub fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
|
||||
<Self as SparseMerkleTree<DEPTH>>::get_leaf(self, key)
|
||||
}
|
||||
|
||||
/// Returns a node at the specified index.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the specified index has depth set to 0 or the depth is greater than
|
||||
/// the depth of this Merkle tree.
|
||||
pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
|
||||
if index.is_root() {
|
||||
Err(MerkleError::DepthTooSmall(index.depth()))
|
||||
} else if index.depth() > DEPTH {
|
||||
Err(MerkleError::DepthTooBig(index.depth() as u64))
|
||||
} else if index.depth() == DEPTH {
|
||||
let leaf = self.get_leaf(&LeafIndex::<DEPTH>::try_from(index)?);
|
||||
|
||||
Ok(leaf.into())
|
||||
} else {
|
||||
Ok(self.get_inner_node(index).hash())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
|
||||
/// path to the leaf, as well as the leaf itself.
|
||||
pub fn open(&self, key: &LeafIndex<DEPTH>) -> ValuePath {
|
||||
<Self as SparseMerkleTree<DEPTH>>::open(self, key)
|
||||
}
|
||||
|
||||
// ITERATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns an iterator over the leaves of this [SimpleSmt].
|
||||
pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
|
||||
self.leaves.iter().map(|(i, w)| (*i, w))
|
||||
}
|
||||
|
||||
/// Returns an iterator over the inner nodes of this [SimpleSmt].
|
||||
pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
|
||||
self.inner_nodes.values().map(|e| InnerNodeInfo {
|
||||
value: e.hash(),
|
||||
left: e.left,
|
||||
right: e.right,
|
||||
})
|
||||
}
|
||||
|
||||
// STATE MUTATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Inserts a value at the specified key, returning the previous value associated with that key.
|
||||
/// Recall that by definition, any key that hasn't been updated is associated with
|
||||
/// [`EMPTY_WORD`].
|
||||
///
|
||||
/// This also recomputes all hashes between the leaf (associated with the key) and the root,
|
||||
/// updating the root itself.
|
||||
pub fn insert(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Word {
|
||||
<Self as SparseMerkleTree<DEPTH>>::insert(self, key, value)
|
||||
}
|
||||
|
||||
/// Inserts a subtree at the specified index. The depth at which the subtree is inserted is
|
||||
/// computed as `DEPTH - SUBTREE_DEPTH`.
|
||||
///
|
||||
/// Returns the new root.
|
||||
pub fn set_subtree<const SUBTREE_DEPTH: u8>(
|
||||
&mut self,
|
||||
subtree_insertion_index: u64,
|
||||
subtree: SimpleSmt<SUBTREE_DEPTH>,
|
||||
) -> Result<RpoDigest, MerkleError> {
|
||||
if SUBTREE_DEPTH > DEPTH {
|
||||
return Err(MerkleError::InvalidSubtreeDepth {
|
||||
subtree_depth: SUBTREE_DEPTH,
|
||||
tree_depth: DEPTH,
|
||||
});
|
||||
}
|
||||
|
||||
// Verify that `subtree_insertion_index` is valid.
|
||||
let subtree_root_insertion_depth = DEPTH - SUBTREE_DEPTH;
|
||||
let subtree_root_index =
|
||||
NodeIndex::new(subtree_root_insertion_depth, subtree_insertion_index)?;
|
||||
|
||||
// add leaves
|
||||
// --------------
|
||||
|
||||
// The subtree's leaf indices live in their own context - i.e. a subtree of depth `d`. If we
|
||||
// insert the subtree at `subtree_insertion_index = 0`, then the subtree leaf indices are
|
||||
// valid as they are. However, consider what happens when we insert at
|
||||
// `subtree_insertion_index = 1`. The first leaf of our subtree now will have index `2^d`;
|
||||
// you can see it as there's a full subtree sitting on its left. In general, for
|
||||
// `subtree_insertion_index = i`, there are `i` subtrees sitting before the subtree we want
|
||||
// to insert, so we need to adjust all its leaves by `i * 2^d`.
|
||||
let leaf_index_shift: u64 = subtree_insertion_index * 2_u64.pow(SUBTREE_DEPTH.into());
|
||||
for (subtree_leaf_idx, leaf_value) in subtree.leaves() {
|
||||
let new_leaf_idx = leaf_index_shift + subtree_leaf_idx;
|
||||
debug_assert!(new_leaf_idx < 2_u64.pow(DEPTH.into()));
|
||||
|
||||
self.leaves.insert(new_leaf_idx, *leaf_value);
|
||||
}
|
||||
|
||||
// add subtree's branch nodes (which includes the root)
|
||||
// --------------
|
||||
for (branch_idx, branch_node) in subtree.inner_nodes {
|
||||
let new_branch_idx = {
|
||||
let new_depth = subtree_root_insertion_depth + branch_idx.depth();
|
||||
let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into())
|
||||
+ branch_idx.value();
|
||||
|
||||
NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid")
|
||||
};
|
||||
|
||||
self.inner_nodes.insert(new_branch_idx, branch_node);
|
||||
}
|
||||
|
||||
// recompute nodes starting from subtree root
|
||||
// --------------
|
||||
self.recompute_nodes_from_index_to_root(subtree_root_index, subtree.root);
|
||||
|
||||
Ok(self.root)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
|
||||
type Key = LeafIndex<DEPTH>;
|
||||
type Value = Word;
|
||||
type Leaf = Word;
|
||||
type Opening = ValuePath;
|
||||
|
||||
const EMPTY_VALUE: Self::Value = EMPTY_WORD;
|
||||
|
||||
fn root(&self) -> RpoDigest {
|
||||
self.root
|
||||
}
|
||||
|
||||
fn set_root(&mut self, root: RpoDigest) {
|
||||
self.root = root;
|
||||
}
|
||||
|
||||
fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
|
||||
self.inner_nodes.get(&index).cloned().unwrap_or_else(|| {
|
||||
let node = EmptySubtreeRoots::entry(DEPTH, index.depth() + 1);
|
||||
|
||||
InnerNode { left: *node, right: *node }
|
||||
})
|
||||
}
|
||||
|
||||
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
|
||||
self.inner_nodes.insert(index, inner_node);
|
||||
}
|
||||
|
||||
fn remove_inner_node(&mut self, index: NodeIndex) {
|
||||
let _ = self.inner_nodes.remove(&index);
|
||||
}
|
||||
|
||||
fn insert_value(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Option<Word> {
|
||||
self.leaves.insert(key.value(), value)
|
||||
}
|
||||
|
||||
fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
|
||||
// the lookup in empty_hashes could fail only if empty_hashes were not built correctly
|
||||
// by the constructor as we check the depth of the lookup above.
|
||||
let leaf_pos = key.value();
|
||||
|
||||
match self.leaves.get(&leaf_pos) {
|
||||
Some(word) => *word,
|
||||
None => Word::from(*EmptySubtreeRoots::entry(DEPTH, DEPTH)),
|
||||
}
|
||||
}
|
||||
|
||||
fn hash_leaf(leaf: &Word) -> RpoDigest {
|
||||
// `SimpleSmt` takes the leaf value itself as the hash
|
||||
leaf.into()
|
||||
}
|
||||
|
||||
fn key_to_leaf_index(key: &LeafIndex<DEPTH>) -> LeafIndex<DEPTH> {
|
||||
*key
|
||||
}
|
||||
|
||||
fn path_and_leaf_to_opening(path: MerklePath, leaf: Word) -> ValuePath {
|
||||
(path, leaf).into()
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
use super::{
|
||||
mmr::Mmr, BTreeMap, EmptySubtreeRoots, InnerNodeInfo, KvMap, MerkleError, MerklePath,
|
||||
MerkleTree, NodeIndex, PartialMerkleTree, RecordingMap, RootPath, Rpo256, RpoDigest, SimpleSmt,
|
||||
Smt, ValuePath, Vec,
|
||||
MerkleStoreDelta, MerkleTree, NodeIndex, PartialMerkleTree, RecordingMap, RootPath, Rpo256,
|
||||
RpoDigest, SimpleSmt, TieredSmt, TryApplyDiff, ValuePath, Vec, EMPTY_WORD,
|
||||
};
|
||||
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
|
||||
use core::borrow::Borrow;
|
||||
@@ -173,7 +173,7 @@ impl<T: KvMap<RpoDigest, StoreNode>> MerkleStore<T> {
|
||||
// the path is computed from root to leaf, so it must be reversed
|
||||
path.reverse();
|
||||
|
||||
Ok(ValuePath::new(hash, MerklePath::new(path)))
|
||||
Ok(ValuePath::new(hash, path))
|
||||
}
|
||||
|
||||
// LEAF TRAVERSAL
|
||||
@@ -361,6 +361,9 @@ impl<T: KvMap<RpoDigest, StoreNode>> MerkleStore<T> {
|
||||
|
||||
// if the node is not in the store assume it is a leaf
|
||||
} else {
|
||||
// assert that if we have a leaf that is not at the max depth then it must be
|
||||
// at the depth of one of the tiers of an TSMT.
|
||||
debug_assert!(TieredSmt::TIER_DEPTHS[..3].contains(&index.depth()));
|
||||
return Some((index, node_hash));
|
||||
}
|
||||
}
|
||||
@@ -487,15 +490,8 @@ impl<T: KvMap<RpoDigest, StoreNode>> From<&MerkleTree> for MerkleStore<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: KvMap<RpoDigest, StoreNode>, const DEPTH: u8> From<&SimpleSmt<DEPTH>> for MerkleStore<T> {
|
||||
fn from(value: &SimpleSmt<DEPTH>) -> Self {
|
||||
let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
|
||||
Self { nodes }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: KvMap<RpoDigest, StoreNode>> From<&Smt> for MerkleStore<T> {
|
||||
fn from(value: &Smt) -> Self {
|
||||
impl<T: KvMap<RpoDigest, StoreNode>> From<&SimpleSmt> for MerkleStore<T> {
|
||||
fn from(value: &SimpleSmt) -> Self {
|
||||
let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
|
||||
Self { nodes }
|
||||
}
|
||||
@@ -508,6 +504,13 @@ impl<T: KvMap<RpoDigest, StoreNode>> From<&Mmr> for MerkleStore<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: KvMap<RpoDigest, StoreNode>> From<&TieredSmt> for MerkleStore<T> {
|
||||
fn from(value: &TieredSmt) -> Self {
|
||||
let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
|
||||
Self { nodes }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: KvMap<RpoDigest, StoreNode>> From<&PartialMerkleTree> for MerkleStore<T> {
|
||||
fn from(value: &PartialMerkleTree) -> Self {
|
||||
let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
|
||||
@@ -547,6 +550,39 @@ impl<T: KvMap<RpoDigest, StoreNode>> Extend<InnerNodeInfo> for MerkleStore<T> {
|
||||
}
|
||||
}
|
||||
|
||||
// DiffT & ApplyDiffT TRAIT IMPLEMENTATION
|
||||
// ================================================================================================
|
||||
impl<T: KvMap<RpoDigest, StoreNode>> TryApplyDiff<RpoDigest, StoreNode> for MerkleStore<T> {
|
||||
type Error = MerkleError;
|
||||
type DiffType = MerkleStoreDelta;
|
||||
|
||||
fn try_apply(&mut self, diff: Self::DiffType) -> Result<(), MerkleError> {
|
||||
for (root, delta) in diff.0 {
|
||||
let mut root = root;
|
||||
for cleared_slot in delta.cleared_slots() {
|
||||
root = self
|
||||
.set_node(
|
||||
root,
|
||||
NodeIndex::new(delta.depth(), *cleared_slot)?,
|
||||
EMPTY_WORD.into(),
|
||||
)?
|
||||
.root;
|
||||
}
|
||||
for (updated_slot, updated_value) in delta.updated_slots() {
|
||||
root = self
|
||||
.set_node(
|
||||
root,
|
||||
NodeIndex::new(delta.depth(), *updated_slot)?,
|
||||
(*updated_value).into(),
|
||||
)?
|
||||
.root;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// SERIALIZATION
|
||||
// ================================================================================================
|
||||
|
||||
|
||||
@@ -3,9 +3,7 @@ use super::{
|
||||
PartialMerkleTree, RecordingMerkleStore, Rpo256, RpoDigest,
|
||||
};
|
||||
use crate::{
|
||||
merkle::{
|
||||
digests_to_words, int_to_leaf, int_to_node, LeafIndex, MerkleTree, SimpleSmt, SMT_MAX_DEPTH,
|
||||
},
|
||||
merkle::{digests_to_words, int_to_leaf, int_to_node, MerkleTree, SimpleSmt},
|
||||
Felt, Word, ONE, WORD_SIZE, ZERO,
|
||||
};
|
||||
|
||||
@@ -15,8 +13,6 @@ use super::{Deserializable, Serializable};
|
||||
#[cfg(feature = "std")]
|
||||
use std::error::Error;
|
||||
|
||||
use seq_macro::seq;
|
||||
|
||||
// TEST DATA
|
||||
// ================================================================================================
|
||||
|
||||
@@ -107,7 +103,7 @@ fn test_merkle_tree() -> Result<(), MerkleError> {
|
||||
"node 3 must be the same for both MerkleTree and MerkleStore"
|
||||
);
|
||||
|
||||
// STORE MERKLE PATH MATCHES ==============================================================
|
||||
// STORE MERKLE PATH MATCHS ==============================================================
|
||||
// assert the merkle path returned by the store is the same as the one in the tree
|
||||
let result = store.get_path(mtree.root(), NodeIndex::make(mtree.depth(), 0)).unwrap();
|
||||
assert_eq!(
|
||||
@@ -177,12 +173,12 @@ fn test_leaf_paths_for_empty_trees() -> Result<(), MerkleError> {
|
||||
// Starts at 1 because leafs are not included in the store.
|
||||
// Ends at 64 because it is not possible to represent an index of a depth greater than 64,
|
||||
// because a u64 is used to index the leaf.
|
||||
seq!(DEPTH in 1_u8..64_u8 {
|
||||
let smt = SimpleSmt::<DEPTH>::new()?;
|
||||
for depth in 1..64 {
|
||||
let smt = SimpleSmt::new(depth)?;
|
||||
|
||||
let index = NodeIndex::make(DEPTH, 0);
|
||||
let index = NodeIndex::make(depth, 0);
|
||||
let store_path = store.get_path(smt.root(), index)?;
|
||||
let smt_path = smt.open(&LeafIndex::<DEPTH>::new(0)?).path;
|
||||
let smt_path = smt.get_path(index)?;
|
||||
assert_eq!(
|
||||
store_path.value,
|
||||
RpoDigest::default(),
|
||||
@@ -193,12 +189,11 @@ fn test_leaf_paths_for_empty_trees() -> Result<(), MerkleError> {
|
||||
"the returned merkle path does not match the computed values"
|
||||
);
|
||||
assert_eq!(
|
||||
store_path.path.compute_root(DEPTH.into(), RpoDigest::default()).unwrap(),
|
||||
store_path.path.compute_root(depth.into(), RpoDigest::default()).unwrap(),
|
||||
smt.root(),
|
||||
"computed root from the path must match the empty tree root"
|
||||
);
|
||||
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -215,7 +210,7 @@ fn test_get_invalid_node() {
|
||||
fn test_add_sparse_merkle_tree_one_level() -> Result<(), MerkleError> {
|
||||
let keys2: [u64; 2] = [0, 1];
|
||||
let leaves2: [Word; 2] = [int_to_leaf(1), int_to_leaf(2)];
|
||||
let smt = SimpleSmt::<1>::with_leaves(keys2.into_iter().zip(leaves2)).unwrap();
|
||||
let smt = SimpleSmt::with_leaves(1, keys2.into_iter().zip(leaves2.into_iter())).unwrap();
|
||||
let store = MerkleStore::from(&smt);
|
||||
|
||||
let idx = NodeIndex::make(1, 0);
|
||||
@@ -231,36 +226,38 @@ fn test_add_sparse_merkle_tree_one_level() -> Result<(), MerkleError> {
|
||||
|
||||
#[test]
|
||||
fn test_sparse_merkle_tree() -> Result<(), MerkleError> {
|
||||
let smt =
|
||||
SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(KEYS4.into_iter().zip(digests_to_words(&VALUES4)))
|
||||
.unwrap();
|
||||
let smt = SimpleSmt::with_leaves(
|
||||
SimpleSmt::MAX_DEPTH,
|
||||
KEYS4.into_iter().zip(digests_to_words(&VALUES4).into_iter()),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let store = MerkleStore::from(&smt);
|
||||
|
||||
// STORE LEAVES ARE CORRECT ==============================================================
|
||||
// checks the leaves in the store corresponds to the expected values
|
||||
assert_eq!(
|
||||
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 0)),
|
||||
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 0)),
|
||||
Ok(VALUES4[0]),
|
||||
"node 0 must be in the tree"
|
||||
);
|
||||
assert_eq!(
|
||||
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 1)),
|
||||
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 1)),
|
||||
Ok(VALUES4[1]),
|
||||
"node 1 must be in the tree"
|
||||
);
|
||||
assert_eq!(
|
||||
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 2)),
|
||||
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 2)),
|
||||
Ok(VALUES4[2]),
|
||||
"node 2 must be in the tree"
|
||||
);
|
||||
assert_eq!(
|
||||
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 3)),
|
||||
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 3)),
|
||||
Ok(VALUES4[3]),
|
||||
"node 3 must be in the tree"
|
||||
);
|
||||
assert_eq!(
|
||||
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 4)),
|
||||
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 4)),
|
||||
Ok(RpoDigest::default()),
|
||||
"unmodified node 4 must be ZERO"
|
||||
);
|
||||
@@ -268,86 +265,86 @@ fn test_sparse_merkle_tree() -> Result<(), MerkleError> {
|
||||
// STORE LEAVES MATCH TREE ===============================================================
|
||||
// sanity check the values returned by the store and the tree
|
||||
assert_eq!(
|
||||
smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 0)),
|
||||
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 0)),
|
||||
smt.get_node(NodeIndex::make(smt.depth(), 0)),
|
||||
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 0)),
|
||||
"node 0 must be the same for both SparseMerkleTree and MerkleStore"
|
||||
);
|
||||
assert_eq!(
|
||||
smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 1)),
|
||||
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 1)),
|
||||
smt.get_node(NodeIndex::make(smt.depth(), 1)),
|
||||
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 1)),
|
||||
"node 1 must be the same for both SparseMerkleTree and MerkleStore"
|
||||
);
|
||||
assert_eq!(
|
||||
smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 2)),
|
||||
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 2)),
|
||||
smt.get_node(NodeIndex::make(smt.depth(), 2)),
|
||||
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 2)),
|
||||
"node 2 must be the same for both SparseMerkleTree and MerkleStore"
|
||||
);
|
||||
assert_eq!(
|
||||
smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 3)),
|
||||
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 3)),
|
||||
smt.get_node(NodeIndex::make(smt.depth(), 3)),
|
||||
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 3)),
|
||||
"node 3 must be the same for both SparseMerkleTree and MerkleStore"
|
||||
);
|
||||
assert_eq!(
|
||||
smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 4)),
|
||||
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 4)),
|
||||
smt.get_node(NodeIndex::make(smt.depth(), 4)),
|
||||
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 4)),
|
||||
"node 4 must be the same for both SparseMerkleTree and MerkleStore"
|
||||
);
|
||||
|
||||
// STORE MERKLE PATH MATCHES ==============================================================
|
||||
// STORE MERKLE PATH MATCHS ==============================================================
|
||||
// assert the merkle path returned by the store is the same as the one in the tree
|
||||
let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 0)).unwrap();
|
||||
let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 0)).unwrap();
|
||||
assert_eq!(
|
||||
VALUES4[0], result.value,
|
||||
"Value for merkle path at index 0 must match leaf value"
|
||||
);
|
||||
assert_eq!(
|
||||
smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(0).unwrap()).path,
|
||||
result.path,
|
||||
smt.get_path(NodeIndex::make(smt.depth(), 0)),
|
||||
Ok(result.path),
|
||||
"merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
|
||||
);
|
||||
|
||||
let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 1)).unwrap();
|
||||
let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 1)).unwrap();
|
||||
assert_eq!(
|
||||
VALUES4[1], result.value,
|
||||
"Value for merkle path at index 1 must match leaf value"
|
||||
);
|
||||
assert_eq!(
|
||||
smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(1).unwrap()).path,
|
||||
result.path,
|
||||
smt.get_path(NodeIndex::make(smt.depth(), 1)),
|
||||
Ok(result.path),
|
||||
"merkle path for index 1 must be the same for the MerkleTree and MerkleStore"
|
||||
);
|
||||
|
||||
let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 2)).unwrap();
|
||||
let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 2)).unwrap();
|
||||
assert_eq!(
|
||||
VALUES4[2], result.value,
|
||||
"Value for merkle path at index 2 must match leaf value"
|
||||
);
|
||||
assert_eq!(
|
||||
smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(2).unwrap()).path,
|
||||
result.path,
|
||||
smt.get_path(NodeIndex::make(smt.depth(), 2)),
|
||||
Ok(result.path),
|
||||
"merkle path for index 2 must be the same for the MerkleTree and MerkleStore"
|
||||
);
|
||||
|
||||
let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 3)).unwrap();
|
||||
let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 3)).unwrap();
|
||||
assert_eq!(
|
||||
VALUES4[3], result.value,
|
||||
"Value for merkle path at index 3 must match leaf value"
|
||||
);
|
||||
assert_eq!(
|
||||
smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(3).unwrap()).path,
|
||||
result.path,
|
||||
smt.get_path(NodeIndex::make(smt.depth(), 3)),
|
||||
Ok(result.path),
|
||||
"merkle path for index 3 must be the same for the MerkleTree and MerkleStore"
|
||||
);
|
||||
|
||||
let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 4)).unwrap();
|
||||
let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 4)).unwrap();
|
||||
assert_eq!(
|
||||
RpoDigest::default(),
|
||||
result.value,
|
||||
"Value for merkle path at index 4 must match leaf value"
|
||||
);
|
||||
assert_eq!(
|
||||
smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(4).unwrap()).path,
|
||||
result.path,
|
||||
smt.get_path(NodeIndex::make(smt.depth(), 4)),
|
||||
Ok(result.path),
|
||||
"merkle path for index 4 must be the same for the MerkleTree and MerkleStore"
|
||||
);
|
||||
|
||||
@@ -428,7 +425,7 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> {
|
||||
"node 3 must be the same for both PartialMerkleTree and MerkleStore"
|
||||
);
|
||||
|
||||
// STORE MERKLE PATH MATCHES ==============================================================
|
||||
// STORE MERKLE PATH MATCHS ==============================================================
|
||||
// assert the merkle path returned by the store is the same as the one in the pmt
|
||||
let result = store.get_path(pmt.root(), NodeIndex::make(pmt.max_depth(), 0)).unwrap();
|
||||
assert_eq!(
|
||||
@@ -555,15 +552,19 @@ fn test_constructors() -> Result<(), MerkleError> {
|
||||
assert_eq!(mtree.get_path(index)?, value_path.path);
|
||||
}
|
||||
|
||||
const DEPTH: u8 = 32;
|
||||
let smt =
|
||||
SimpleSmt::<DEPTH>::with_leaves(KEYS4.into_iter().zip(digests_to_words(&VALUES4))).unwrap();
|
||||
let depth = 32;
|
||||
let smt = SimpleSmt::with_leaves(
|
||||
depth,
|
||||
KEYS4.into_iter().zip(digests_to_words(&VALUES4).into_iter()),
|
||||
)
|
||||
.unwrap();
|
||||
let store = MerkleStore::from(&smt);
|
||||
let depth = smt.depth();
|
||||
|
||||
for key in KEYS4 {
|
||||
let index = NodeIndex::make(DEPTH, key);
|
||||
let index = NodeIndex::make(depth, key);
|
||||
let value_path = store.get_path(smt.root(), index)?;
|
||||
assert_eq!(smt.open(&LeafIndex::<DEPTH>::new(key).unwrap()).path, value_path.path);
|
||||
assert_eq!(smt.get_path(index)?, value_path.path);
|
||||
}
|
||||
|
||||
let d = 2;
|
||||
@@ -651,7 +652,7 @@ fn get_leaf_depth_works_depth_64() {
|
||||
let index = NodeIndex::new(64, k).unwrap();
|
||||
|
||||
// assert the leaf doesn't exist before the insert. the returned depth should always
|
||||
// increment with the paths count of the set, as they are intersecting one another up to
|
||||
// increment with the paths count of the set, as they are insersecting one another up to
|
||||
// the first bits of the used key.
|
||||
assert_eq!(d, store.get_leaf_depth(root, 64, k).unwrap());
|
||||
|
||||
@@ -882,9 +883,8 @@ fn test_serialization() -> Result<(), Box<dyn Error>> {
|
||||
fn test_recorder() {
|
||||
// instantiate recorder from MerkleTree and SimpleSmt
|
||||
let mtree = MerkleTree::new(digests_to_words(&VALUES4)).unwrap();
|
||||
|
||||
const TREE_DEPTH: u8 = 64;
|
||||
let smtree = SimpleSmt::<TREE_DEPTH>::with_leaves(
|
||||
let smtree = SimpleSmt::with_leaves(
|
||||
64,
|
||||
KEYS8.into_iter().zip(VALUES8.into_iter().map(|x| x.into()).rev()),
|
||||
)
|
||||
.unwrap();
|
||||
@@ -897,13 +897,13 @@ fn test_recorder() {
|
||||
let node = recorder.get_node(mtree.root(), index_0).unwrap();
|
||||
assert_eq!(node, mtree.get_node(index_0).unwrap());
|
||||
|
||||
let index_1 = NodeIndex::new(TREE_DEPTH, 1).unwrap();
|
||||
let index_1 = NodeIndex::new(smtree.depth(), 1).unwrap();
|
||||
let node = recorder.get_node(smtree.root(), index_1).unwrap();
|
||||
assert_eq!(node, smtree.get_node(index_1).unwrap());
|
||||
|
||||
// insert a value and assert that when we request it next time it is accurate
|
||||
let new_value = [ZERO, ZERO, ONE, ONE].into();
|
||||
let index_2 = NodeIndex::new(TREE_DEPTH, 2).unwrap();
|
||||
let index_2 = NodeIndex::new(smtree.depth(), 2).unwrap();
|
||||
let root = recorder.set_node(smtree.root(), index_2, new_value).unwrap().root;
|
||||
assert_eq!(recorder.get_node(root, index_2).unwrap(), new_value);
|
||||
|
||||
@@ -920,13 +920,10 @@ fn test_recorder() {
|
||||
assert_eq!(node, smtree.get_node(index_1).unwrap());
|
||||
|
||||
let node = merkle_store.get_node(smtree.root(), index_2).unwrap();
|
||||
assert_eq!(
|
||||
node,
|
||||
smtree.get_leaf(&LeafIndex::<TREE_DEPTH>::try_from(index_2).unwrap()).into()
|
||||
);
|
||||
assert_eq!(node, smtree.get_leaf(index_2.value()).unwrap().into());
|
||||
|
||||
// assert that is doesnt contain nodes that were not recorded
|
||||
let not_recorded_index = NodeIndex::new(TREE_DEPTH, 4).unwrap();
|
||||
let not_recorded_index = NodeIndex::new(smtree.depth(), 4).unwrap();
|
||||
assert!(merkle_store.get_node(smtree.root(), not_recorded_index).is_err());
|
||||
assert!(smtree.get_node(not_recorded_index).is_ok());
|
||||
}
|
||||
|
||||
48
src/merkle/tiered_smt/error.rs
Normal file
48
src/merkle/tiered_smt/error.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
use core::fmt::Display;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum TieredSmtProofError {
|
||||
EntriesEmpty,
|
||||
EmptyValueNotAllowed,
|
||||
MismatchedPrefixes(u64, u64),
|
||||
MultipleEntriesOutsideLastTier,
|
||||
NotATierPath(u8),
|
||||
PathTooLong,
|
||||
}
|
||||
|
||||
impl Display for TieredSmtProofError {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
match self {
|
||||
TieredSmtProofError::EntriesEmpty => {
|
||||
write!(f, "Missing entries for tiered sparse merkle tree proof")
|
||||
}
|
||||
TieredSmtProofError::EmptyValueNotAllowed => {
|
||||
write!(
|
||||
f,
|
||||
"The empty value [0, 0, 0, 0] is not allowed inside a tiered sparse merkle tree"
|
||||
)
|
||||
}
|
||||
TieredSmtProofError::MismatchedPrefixes(first, second) => {
|
||||
write!(f, "Not all leaves have the same prefix. First {first} second {second}")
|
||||
}
|
||||
TieredSmtProofError::MultipleEntriesOutsideLastTier => {
|
||||
write!(f, "Multiple entries are only allowed for the last tier (depth 64)")
|
||||
}
|
||||
TieredSmtProofError::NotATierPath(got) => {
|
||||
write!(
|
||||
f,
|
||||
"Path length does not correspond to a tier. Got {got} Expected one of 16, 32, 48, 64"
|
||||
)
|
||||
}
|
||||
TieredSmtProofError::PathTooLong => {
|
||||
write!(
|
||||
f,
|
||||
"Path longer than maximum depth of 64 for tiered sparse merkle tree proof"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl std::error::Error for TieredSmtProofError {}
|
||||
509
src/merkle/tiered_smt/mod.rs
Normal file
509
src/merkle/tiered_smt/mod.rs
Normal file
@@ -0,0 +1,509 @@
|
||||
use super::{
|
||||
BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex,
|
||||
Rpo256, RpoDigest, StarkField, Vec, Word,
|
||||
};
|
||||
use crate::utils::vec;
|
||||
use core::{cmp, ops::Deref};
|
||||
|
||||
mod nodes;
|
||||
use nodes::NodeStore;
|
||||
|
||||
mod values;
|
||||
use values::ValueStore;
|
||||
|
||||
mod proof;
|
||||
pub use proof::TieredSmtProof;
|
||||
|
||||
mod error;
|
||||
pub use error::TieredSmtProofError;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
// TIERED SPARSE MERKLE TREE
|
||||
// ================================================================================================
|
||||
|
||||
/// Tiered (compacted) Sparse Merkle tree mapping 256-bit keys to 256-bit values. Both keys and
|
||||
/// values are represented by 4 field elements.
|
||||
///
|
||||
/// Leaves in the tree can exist only on specific depths called "tiers". These depths are: 16, 32,
|
||||
/// 48, and 64. Initially, when a tree is empty, it is equivalent to an empty Sparse Merkle tree
|
||||
/// of depth 64 (i.e., leaves at depth 64 are set to [ZERO; 4]). As non-empty values are inserted
|
||||
/// into the tree they are added to the first available tier.
|
||||
///
|
||||
/// For example, when the first key-value pair is inserted, it will be stored in a node at depth
|
||||
/// 16 such that the 16 most significant bits of the key determine the position of the node at
|
||||
/// depth 16. If another value with a key sharing the same 16-bit prefix is inserted, both values
|
||||
/// move into the next tier (depth 32). This process is repeated until values end up at the bottom
|
||||
/// tier (depth 64). If multiple values have keys with a common 64-bit prefix, such key-value pairs
|
||||
/// are stored in a sorted list at the bottom tier.
|
||||
///
|
||||
/// To differentiate between internal and leaf nodes, node values are computed as follows:
|
||||
/// - Internal nodes: hash(left_child, right_child).
|
||||
/// - Leaf node at depths 16, 32, or 64: hash(key, value, domain=depth).
|
||||
/// - Leaf node at depth 64: hash([key_0, value_0, ..., key_n, value_n], domain=64).
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct TieredSmt {
|
||||
root: RpoDigest,
|
||||
nodes: NodeStore,
|
||||
values: ValueStore,
|
||||
}
|
||||
|
||||
impl TieredSmt {
|
||||
// CONSTANTS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// The number of levels between tiers.
|
||||
pub const TIER_SIZE: u8 = 16;
|
||||
|
||||
/// Depths at which leaves can exist in a tiered SMT.
|
||||
pub const TIER_DEPTHS: [u8; 4] = [16, 32, 48, 64];
|
||||
|
||||
/// Maximum node depth. This is also the bottom tier of the tree.
|
||||
pub const MAX_DEPTH: u8 = 64;
|
||||
|
||||
/// Value of an empty leaf.
|
||||
pub const EMPTY_VALUE: Word = super::EMPTY_WORD;
|
||||
|
||||
// CONSTRUCTORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a new [TieredSmt] instantiated with the specified key-value pairs.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the provided entries contain multiple values for the same key.
|
||||
pub fn with_entries<R, I>(entries: R) -> Result<Self, MerkleError>
|
||||
where
|
||||
R: IntoIterator<IntoIter = I>,
|
||||
I: Iterator<Item = (RpoDigest, Word)> + ExactSizeIterator,
|
||||
{
|
||||
// create an empty tree
|
||||
let mut tree = Self::default();
|
||||
|
||||
// append leaves to the tree returning an error if a duplicate entry for the same key
|
||||
// is found
|
||||
let mut empty_entries = BTreeSet::new();
|
||||
for (key, value) in entries {
|
||||
let old_value = tree.insert(key, value);
|
||||
if old_value != Self::EMPTY_VALUE || empty_entries.contains(&key) {
|
||||
return Err(MerkleError::DuplicateValuesForKey(key));
|
||||
}
|
||||
// if we've processed an empty entry, add the key to the set of empty entry keys, and
|
||||
// if this key was already in the set, return an error
|
||||
if value == Self::EMPTY_VALUE && !empty_entries.insert(key) {
|
||||
return Err(MerkleError::DuplicateValuesForKey(key));
|
||||
}
|
||||
}
|
||||
Ok(tree)
|
||||
}
|
||||
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns the root of this Merkle tree.
|
||||
pub const fn root(&self) -> RpoDigest {
|
||||
self.root
|
||||
}
|
||||
|
||||
/// Returns a node at the specified index.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// - The specified index depth is 0 or greater than 64.
|
||||
/// - The node with the specified index does not exists in the Merkle tree. This is possible
|
||||
/// when a leaf node with the same index prefix exists at a tier higher than the requested
|
||||
/// node.
|
||||
pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
|
||||
self.nodes.get_node(index)
|
||||
}
|
||||
|
||||
/// Returns a Merkle path from the node at the specified index to the root.
|
||||
///
|
||||
/// The node itself is not included in the path.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// - The specified index depth is 0 or greater than 64.
|
||||
/// - The node with the specified index does not exists in the Merkle tree. This is possible
|
||||
/// when a leaf node with the same index prefix exists at a tier higher than the node to
|
||||
/// which the path is requested.
|
||||
pub fn get_path(&self, index: NodeIndex) -> Result<MerklePath, MerkleError> {
|
||||
self.nodes.get_path(index)
|
||||
}
|
||||
|
||||
/// Returns the value associated with the specified key.
|
||||
///
|
||||
/// If nothing was inserted into this tree for the specified key, [ZERO; 4] is returned.
|
||||
pub fn get_value(&self, key: RpoDigest) -> Word {
|
||||
match self.values.get(&key) {
|
||||
Some(value) => *value,
|
||||
None => Self::EMPTY_VALUE,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a proof for a key-value pair defined by the specified key.
|
||||
///
|
||||
/// The proof can be used to attest membership of this key-value pair in a Tiered Sparse Merkle
|
||||
/// Tree defined by the same root as this tree.
|
||||
pub fn prove(&self, key: RpoDigest) -> TieredSmtProof {
|
||||
let (path, index, leaf_exists) = self.nodes.get_proof(&key);
|
||||
|
||||
let entries = if index.depth() == Self::MAX_DEPTH {
|
||||
match self.values.get_all(index.value()) {
|
||||
Some(entries) => entries,
|
||||
None => vec![(key, Self::EMPTY_VALUE)],
|
||||
}
|
||||
} else if leaf_exists {
|
||||
let entry =
|
||||
self.values.get_first(index_to_prefix(&index)).expect("leaf entry not found");
|
||||
debug_assert_eq!(entry.0, key);
|
||||
vec![*entry]
|
||||
} else {
|
||||
vec![(key, Self::EMPTY_VALUE)]
|
||||
};
|
||||
|
||||
TieredSmtProof::new(path, entries).expect("Bug detected, TSMT produced invalid proof")
|
||||
}
|
||||
|
||||
// STATE MUTATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Inserts the provided value into the tree under the specified key and returns the value
|
||||
/// previously stored under this key.
|
||||
///
|
||||
/// If the value for the specified key was not previously set, [ZERO; 4] is returned.
|
||||
pub fn insert(&mut self, key: RpoDigest, value: Word) -> Word {
|
||||
// if an empty value is being inserted, remove the leaf node to make it look as if the
|
||||
// value was never inserted
|
||||
if value == Self::EMPTY_VALUE {
|
||||
return self.remove_leaf_node(key);
|
||||
}
|
||||
|
||||
// insert the value into the value store, and if the key was already in the store, update
|
||||
// it with the new value
|
||||
if let Some(old_value) = self.values.insert(key, value) {
|
||||
if old_value != value {
|
||||
// if the new value is different from the old value, determine the location of
|
||||
// the leaf node for this key, build the node, and update the root
|
||||
let (index, leaf_exists) = self.nodes.get_leaf_index(&key);
|
||||
debug_assert!(leaf_exists);
|
||||
let node = self.build_leaf_node(index, key, value);
|
||||
self.root = self.nodes.update_leaf_node(index, node);
|
||||
}
|
||||
return old_value;
|
||||
};
|
||||
|
||||
// determine the location for the leaf node; this index could have 3 different meanings:
|
||||
// - it points to a root of an empty subtree or an empty node at depth 64; in this case,
|
||||
// we can replace the node with the value node immediately.
|
||||
// - it points to an existing leaf at the bottom tier (i.e., depth = 64); in this case,
|
||||
// we need to process update the bottom leaf.
|
||||
// - it points to an existing leaf node for a different key with the same prefix (same
|
||||
// key case was handled above); in this case, we need to move the leaf to a lower tier
|
||||
let (index, leaf_exists) = self.nodes.get_leaf_index(&key);
|
||||
|
||||
self.root = if leaf_exists && index.depth() == Self::MAX_DEPTH {
|
||||
// returned index points to a leaf at the bottom tier
|
||||
let node = self.build_leaf_node(index, key, value);
|
||||
self.nodes.update_leaf_node(index, node)
|
||||
} else if leaf_exists {
|
||||
// returned index points to a leaf for a different key with the same prefix
|
||||
|
||||
// get the key-value pair for the key with the same prefix; since the key-value
|
||||
// pair has already been inserted into the value store, we need to filter it out
|
||||
// when looking for the other key-value pair
|
||||
let (other_key, other_value) = self
|
||||
.values
|
||||
.get_first_filtered(index_to_prefix(&index), &key)
|
||||
.expect("other key-value pair not found");
|
||||
|
||||
// determine how far down the tree should we move the leaves
|
||||
let common_prefix_len = get_common_prefix_tier_depth(&key, other_key);
|
||||
let depth = cmp::min(common_prefix_len + Self::TIER_SIZE, Self::MAX_DEPTH);
|
||||
|
||||
// compute node locations for new and existing key-value paris
|
||||
let new_index = LeafNodeIndex::from_key(&key, depth);
|
||||
let other_index = LeafNodeIndex::from_key(other_key, depth);
|
||||
|
||||
// compute node values for the new and existing key-value pairs
|
||||
let new_node = self.build_leaf_node(new_index, key, value);
|
||||
let other_node = self.build_leaf_node(other_index, *other_key, *other_value);
|
||||
|
||||
// replace the leaf located at index with a subtree containing nodes for new and
|
||||
// existing key-value paris
|
||||
self.nodes.replace_leaf_with_subtree(
|
||||
index,
|
||||
[(new_index, new_node), (other_index, other_node)],
|
||||
)
|
||||
} else {
|
||||
// returned index points to an empty subtree or an empty leaf at the bottom tier
|
||||
let node = self.build_leaf_node(index, key, value);
|
||||
self.nodes.insert_leaf_node(index, node)
|
||||
};
|
||||
|
||||
Self::EMPTY_VALUE
|
||||
}
|
||||
|
||||
// ITERATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns an iterator over all key-value pairs in this [TieredSmt].
|
||||
pub fn iter(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
|
||||
self.values.iter()
|
||||
}
|
||||
|
||||
/// Returns an iterator over all inner nodes of this [TieredSmt] (i.e., nodes not at depths 16
|
||||
/// 32, 48, or 64).
|
||||
///
|
||||
/// The iterator order is unspecified.
|
||||
pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
|
||||
self.nodes.inner_nodes()
|
||||
}
|
||||
|
||||
/// Returns an iterator over upper leaves (i.e., depth = 16, 32, or 48) for this [TieredSmt]
|
||||
/// where each yielded item is a (node, key, value) tuple.
|
||||
///
|
||||
/// The iterator order is unspecified.
|
||||
pub fn upper_leaves(&self) -> impl Iterator<Item = (RpoDigest, RpoDigest, Word)> + '_ {
|
||||
self.nodes.upper_leaves().map(|(index, node)| {
|
||||
let key_prefix = index_to_prefix(index);
|
||||
let (key, value) = self.values.get_first(key_prefix).expect("upper leaf not found");
|
||||
debug_assert_eq!(*index, LeafNodeIndex::from_key(key, index.depth()).into());
|
||||
(*node, *key, *value)
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns an iterator over upper leaves (i.e., depth = 16, 32, or 48) for this [TieredSmt]
|
||||
/// where each yielded item is a (node_index, value) tuple.
|
||||
pub fn upper_leaf_nodes(&self) -> impl Iterator<Item = (&NodeIndex, &RpoDigest)> {
|
||||
self.nodes.upper_leaves()
|
||||
}
|
||||
|
||||
/// Returns an iterator over bottom leaves (i.e., depth = 64) of this [TieredSmt].
|
||||
///
|
||||
/// Each yielded item consists of the hash of the leaf and its contents, where contents is
|
||||
/// a vector containing key-value pairs of entries storied in this leaf.
|
||||
///
|
||||
/// The iterator order is unspecified.
|
||||
pub fn bottom_leaves(&self) -> impl Iterator<Item = (RpoDigest, Vec<(RpoDigest, Word)>)> + '_ {
|
||||
self.nodes.bottom_leaves().map(|(&prefix, node)| {
|
||||
let values = self.values.get_all(prefix).expect("bottom leaf not found");
|
||||
(*node, values)
|
||||
})
|
||||
}
|
||||
|
||||
// HELPER METHODS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Removes the node holding the key-value pair for the specified key from this tree, and
|
||||
/// returns the value associated with the specified key.
|
||||
///
|
||||
/// If no value was associated with the specified key, [ZERO; 4] is returned.
|
||||
fn remove_leaf_node(&mut self, key: RpoDigest) -> Word {
|
||||
// remove the key-value pair from the value store; if no value was associated with the
|
||||
// specified key, return.
|
||||
let old_value = match self.values.remove(&key) {
|
||||
Some(old_value) => old_value,
|
||||
None => return Self::EMPTY_VALUE,
|
||||
};
|
||||
|
||||
// determine the location of the leaf holding the key-value pair to be removed
|
||||
let (index, leaf_exists) = self.nodes.get_leaf_index(&key);
|
||||
debug_assert!(leaf_exists);
|
||||
|
||||
// if the leaf is at the bottom tier and after removing the key-value pair from it, the
|
||||
// leaf is still not empty, we either just update it, or move it up to a higher tier (if
|
||||
// the leaf doesn't have siblings at lower tiers)
|
||||
if index.depth() == Self::MAX_DEPTH {
|
||||
if let Some(entries) = self.values.get_all(index.value()) {
|
||||
// if there is only one key-value pair left at the bottom leaf, and it can be
|
||||
// moved up to a higher tier, truncate the branch and return
|
||||
if entries.len() == 1 {
|
||||
let new_depth = self.nodes.get_last_single_child_parent_depth(index.value());
|
||||
if new_depth != Self::MAX_DEPTH {
|
||||
let node = hash_upper_leaf(entries[0].0, entries[0].1, new_depth);
|
||||
self.root = self.nodes.truncate_branch(index.value(), new_depth, node);
|
||||
return old_value;
|
||||
}
|
||||
}
|
||||
|
||||
// otherwise just recompute the leaf hash and update the leaf node
|
||||
let node = hash_bottom_leaf(&entries);
|
||||
self.root = self.nodes.update_leaf_node(index, node);
|
||||
return old_value;
|
||||
};
|
||||
}
|
||||
|
||||
// if the removed key-value pair has a lone sibling at the current tier with a root at
|
||||
// higher tier, we need to move the sibling to a higher tier
|
||||
if let Some((sib_key, sib_val, new_sib_index)) = self.values.get_lone_sibling(index) {
|
||||
// determine the current index of the sibling node
|
||||
let sib_index = LeafNodeIndex::from_key(sib_key, index.depth());
|
||||
debug_assert!(sib_index.depth() > new_sib_index.depth());
|
||||
|
||||
// compute node value for the new location of the sibling leaf and replace the subtree
|
||||
// with this leaf node
|
||||
let node = self.build_leaf_node(new_sib_index, *sib_key, *sib_val);
|
||||
let new_sib_depth = new_sib_index.depth();
|
||||
self.root = self.nodes.replace_subtree_with_leaf(index, sib_index, new_sib_depth, node);
|
||||
} else {
|
||||
// if the removed key-value pair did not have a sibling at the current tier with a
|
||||
// root at higher tiers, just clear the leaf node
|
||||
self.root = self.nodes.clear_leaf_node(index);
|
||||
}
|
||||
|
||||
old_value
|
||||
}
|
||||
|
||||
/// Builds and returns a leaf node value for the node located as the specified index.
|
||||
///
|
||||
/// This method assumes that the key-value pair for the node has already been inserted into
|
||||
/// the value store, however, for depths 16, 32, and 48, the node is computed directly from
|
||||
/// the passed-in values (for depth 64, the value store is queried to get all the key-value
|
||||
/// pairs located at the specified index).
|
||||
fn build_leaf_node(&self, index: LeafNodeIndex, key: RpoDigest, value: Word) -> RpoDigest {
|
||||
let depth = index.depth();
|
||||
|
||||
// insert the key into index-key map and compute the new value of the node
|
||||
if index.depth() == Self::MAX_DEPTH {
|
||||
// for the bottom tier, we add the key-value pair to the existing leaf, or create a
|
||||
// new leaf with this key-value pair
|
||||
let values = self.values.get_all(index.value()).unwrap();
|
||||
hash_bottom_leaf(&values)
|
||||
} else {
|
||||
debug_assert_eq!(self.values.get_first(index_to_prefix(&index)), Some(&(key, value)));
|
||||
hash_upper_leaf(key, value, depth)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TieredSmt {
|
||||
fn default() -> Self {
|
||||
let root = EmptySubtreeRoots::empty_hashes(Self::MAX_DEPTH)[0];
|
||||
Self {
|
||||
root,
|
||||
nodes: NodeStore::new(root),
|
||||
values: ValueStore::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LEAF NODE INDEX
|
||||
// ================================================================================================
|
||||
/// A wrapper around [NodeIndex] to provide type-safe references to nodes at depths 16, 32, 48, and
|
||||
/// 64.
|
||||
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
|
||||
pub struct LeafNodeIndex(NodeIndex);
|
||||
|
||||
impl LeafNodeIndex {
|
||||
/// Returns a new [LeafNodeIndex] instantiated from the provided [NodeIndex].
|
||||
///
|
||||
/// In debug mode, panics if index depth is not 16, 32, 48, or 64.
|
||||
pub fn new(index: NodeIndex) -> Self {
|
||||
// check if the depth is 16, 32, 48, or 64; this works because for a valid depth,
|
||||
// depth - 16, can be 0, 16, 32, or 48 - i.e., the value is either 0 or any of the 4th
|
||||
// or 5th bits are set. We can test for this by computing a bitwise AND with a value
|
||||
// which has all but the 4th and 5th bits set (which is !48).
|
||||
debug_assert_eq!(((index.depth() - 16) & !48), 0, "invalid tier depth {}", index.depth());
|
||||
Self(index)
|
||||
}
|
||||
|
||||
/// Returns a new [LeafNodeIndex] instantiated from the specified key inserted at the specified
|
||||
/// depth.
|
||||
///
|
||||
/// The value for the key is computed by taking n most significant bits from the most significant
|
||||
/// element of the key, where n is the specified depth.
|
||||
pub fn from_key(key: &RpoDigest, depth: u8) -> Self {
|
||||
let mse = get_key_prefix(key);
|
||||
Self::new(NodeIndex::new_unchecked(depth, mse >> (TieredSmt::MAX_DEPTH - depth)))
|
||||
}
|
||||
|
||||
/// Returns a new [LeafNodeIndex] instantiated for testing purposes.
|
||||
#[cfg(test)]
|
||||
pub fn make(depth: u8, value: u64) -> Self {
|
||||
Self::new(NodeIndex::make(depth, value))
|
||||
}
|
||||
|
||||
/// Traverses towards the root until the specified depth is reached.
|
||||
///
|
||||
/// The new depth must be a valid tier depth - i.e., 16, 32, 48, or 64.
|
||||
pub fn move_up_to(&mut self, depth: u8) {
|
||||
debug_assert_eq!(((depth - 16) & !48), 0, "invalid tier depth: {depth}");
|
||||
self.0.move_up_to(depth);
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for LeafNodeIndex {
|
||||
type Target = NodeIndex;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<NodeIndex> for LeafNodeIndex {
|
||||
fn from(value: NodeIndex) -> Self {
|
||||
Self::new(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<LeafNodeIndex> for NodeIndex {
|
||||
fn from(value: LeafNodeIndex) -> Self {
|
||||
value.0
|
||||
}
|
||||
}
|
||||
|
||||
// HELPER FUNCTIONS
|
||||
// ================================================================================================
|
||||
|
||||
/// Returns the value representing the 64 most significant bits of the specified key.
|
||||
fn get_key_prefix(key: &RpoDigest) -> u64 {
|
||||
Word::from(key)[3].as_int()
|
||||
}
|
||||
|
||||
/// Returns the index value shifted to be in the most significant bit positions of the returned
|
||||
/// u64 value.
|
||||
fn index_to_prefix(index: &NodeIndex) -> u64 {
|
||||
index.value() << (TieredSmt::MAX_DEPTH - index.depth())
|
||||
}
|
||||
|
||||
/// Returns tiered common prefix length between the most significant elements of the provided keys.
|
||||
///
|
||||
/// Specifically:
|
||||
/// - returns 64 if the most significant elements are equal.
|
||||
/// - returns 48 if the common prefix is between 48 and 63 bits.
|
||||
/// - returns 32 if the common prefix is between 32 and 47 bits.
|
||||
/// - returns 16 if the common prefix is between 16 and 31 bits.
|
||||
/// - returns 0 if the common prefix is fewer than 16 bits.
|
||||
fn get_common_prefix_tier_depth(key1: &RpoDigest, key2: &RpoDigest) -> u8 {
|
||||
let e1 = get_key_prefix(key1);
|
||||
let e2 = get_key_prefix(key2);
|
||||
let ex = (e1 ^ e2).leading_zeros() as u8;
|
||||
(ex / 16) * 16
|
||||
}
|
||||
|
||||
/// Computes node value for leaves at tiers 16, 32, or 48.
|
||||
///
|
||||
/// Node value is computed as: hash(key || value, domain = depth).
|
||||
pub fn hash_upper_leaf(key: RpoDigest, value: Word, depth: u8) -> RpoDigest {
|
||||
const NUM_UPPER_TIERS: usize = TieredSmt::TIER_DEPTHS.len() - 1;
|
||||
debug_assert!(TieredSmt::TIER_DEPTHS[..NUM_UPPER_TIERS].contains(&depth));
|
||||
Rpo256::merge_in_domain(&[key, value.into()], depth.into())
|
||||
}
|
||||
|
||||
/// Computes node value for leaves at the bottom tier (depth 64).
|
||||
///
|
||||
/// Node value is computed as: hash([key_0, value_0, ..., key_n, value_n], domain=64).
|
||||
///
|
||||
/// TODO: when hashing in domain is implemented for `hash_elements()`, combine this function with
|
||||
/// `hash_upper_leaf()` function.
|
||||
pub fn hash_bottom_leaf(values: &[(RpoDigest, Word)]) -> RpoDigest {
|
||||
let mut elements = Vec::with_capacity(values.len() * 8);
|
||||
for (key, val) in values.iter() {
|
||||
elements.extend_from_slice(key.as_elements());
|
||||
elements.extend_from_slice(val.as_slice());
|
||||
}
|
||||
// TODO: hash in domain
|
||||
Rpo256::hash_elements(&elements)
|
||||
}
|
||||
419
src/merkle/tiered_smt/nodes.rs
Normal file
419
src/merkle/tiered_smt/nodes.rs
Normal file
@@ -0,0 +1,419 @@
|
||||
use super::{
|
||||
BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, LeafNodeIndex, MerkleError, MerklePath,
|
||||
NodeIndex, Rpo256, RpoDigest, Vec,
|
||||
};
|
||||
|
||||
// CONSTANTS
|
||||
// ================================================================================================
|
||||
|
||||
/// The number of levels between tiers.
|
||||
const TIER_SIZE: u8 = super::TieredSmt::TIER_SIZE;
|
||||
|
||||
/// Depths at which leaves can exist in a tiered SMT.
|
||||
const TIER_DEPTHS: [u8; 4] = super::TieredSmt::TIER_DEPTHS;
|
||||
|
||||
/// Maximum node depth. This is also the bottom tier of the tree.
|
||||
const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;
|
||||
|
||||
// NODE STORE
|
||||
// ================================================================================================
|
||||
|
||||
/// A store of nodes for a Tiered Sparse Merkle tree.
|
||||
///
|
||||
/// The store contains information about all nodes as well as information about which of the nodes
|
||||
/// represent leaf nodes in a Tiered Sparse Merkle tree. In the current implementation, [BTreeSet]s
|
||||
/// are used to determine the position of the leaves in the tree.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct NodeStore {
|
||||
nodes: BTreeMap<NodeIndex, RpoDigest>,
|
||||
upper_leaves: BTreeSet<NodeIndex>,
|
||||
bottom_leaves: BTreeSet<u64>,
|
||||
}
|
||||
|
||||
impl NodeStore {
|
||||
// CONSTRUCTOR
|
||||
// --------------------------------------------------------------------------------------------
|
||||
/// Returns a new instance of [NodeStore] instantiated with the specified root node.
|
||||
///
|
||||
/// Root node is assumed to be a root of an empty sparse Merkle tree.
|
||||
pub fn new(root_node: RpoDigest) -> Self {
|
||||
let mut nodes = BTreeMap::default();
|
||||
nodes.insert(NodeIndex::root(), root_node);
|
||||
|
||||
Self {
|
||||
nodes,
|
||||
upper_leaves: BTreeSet::default(),
|
||||
bottom_leaves: BTreeSet::default(),
|
||||
}
|
||||
}
|
||||
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a node at the specified index.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// - The specified index depth is 0 or greater than 64.
|
||||
/// - The node with the specified index does not exists in the Merkle tree. This is possible
|
||||
/// when a leaf node with the same index prefix exists at a tier higher than the requested
|
||||
/// node.
|
||||
pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
|
||||
self.validate_node_access(index)?;
|
||||
Ok(self.get_node_unchecked(&index))
|
||||
}
|
||||
|
||||
/// Returns a Merkle path from the node at the specified index to the root.
|
||||
///
|
||||
/// The node itself is not included in the path.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// - The specified index depth is 0 or greater than 64.
|
||||
/// - The node with the specified index does not exists in the Merkle tree. This is possible
|
||||
/// when a leaf node with the same index prefix exists at a tier higher than the node to
|
||||
/// which the path is requested.
|
||||
pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
|
||||
self.validate_node_access(index)?;
|
||||
|
||||
let mut path = Vec::with_capacity(index.depth() as usize);
|
||||
for _ in 0..index.depth() {
|
||||
let node = self.get_node_unchecked(&index.sibling());
|
||||
path.push(node);
|
||||
index.move_up();
|
||||
}
|
||||
|
||||
Ok(path.into())
|
||||
}
|
||||
|
||||
/// Returns a Merkle path to the node specified by the key together with a flag indicating,
|
||||
/// whether this node is a leaf at depths 16, 32, or 48.
|
||||
pub fn get_proof(&self, key: &RpoDigest) -> (MerklePath, NodeIndex, bool) {
|
||||
let (index, leaf_exists) = self.get_leaf_index(key);
|
||||
let index: NodeIndex = index.into();
|
||||
let path = self.get_path(index).expect("failed to retrieve Merkle path for a node index");
|
||||
(path, index, leaf_exists)
|
||||
}
|
||||
|
||||
/// Returns an index at which a leaf node for the specified key should be inserted.
|
||||
///
|
||||
/// The second value in the returned tuple is set to true if the node at the returned index
|
||||
/// is already a leaf node.
|
||||
pub fn get_leaf_index(&self, key: &RpoDigest) -> (LeafNodeIndex, bool) {
|
||||
// traverse the tree from the root down checking nodes at tiers 16, 32, and 48. Return if
|
||||
// a node at any of the tiers is either a leaf or a root of an empty subtree.
|
||||
const NUM_UPPER_TIERS: usize = TIER_DEPTHS.len() - 1;
|
||||
for &tier_depth in TIER_DEPTHS[..NUM_UPPER_TIERS].iter() {
|
||||
let index = LeafNodeIndex::from_key(key, tier_depth);
|
||||
if self.upper_leaves.contains(&index) {
|
||||
return (index, true);
|
||||
} else if !self.nodes.contains_key(&index) {
|
||||
return (index, false);
|
||||
}
|
||||
}
|
||||
|
||||
// if we got here, that means all of the nodes checked so far are internal nodes, and
|
||||
// the new node would need to be inserted in the bottom tier.
|
||||
let index = LeafNodeIndex::from_key(key, MAX_DEPTH);
|
||||
(index, self.bottom_leaves.contains(&index.value()))
|
||||
}
|
||||
|
||||
/// Traverses the tree up from the bottom tier starting at the specified leaf index and
|
||||
/// returns the depth of the first node which hash more than one child. The returned depth
|
||||
/// is rounded up to the next tier.
|
||||
pub fn get_last_single_child_parent_depth(&self, leaf_index: u64) -> u8 {
|
||||
let mut index = NodeIndex::new_unchecked(MAX_DEPTH, leaf_index);
|
||||
|
||||
for _ in (TIER_DEPTHS[0]..MAX_DEPTH).rev() {
|
||||
let sibling_index = index.sibling();
|
||||
if self.nodes.contains_key(&sibling_index) {
|
||||
break;
|
||||
}
|
||||
index.move_up();
|
||||
}
|
||||
|
||||
let tier = (index.depth() - 1) / TIER_SIZE;
|
||||
TIER_DEPTHS[tier as usize]
|
||||
}
|
||||
|
||||
// ITERATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns an iterator over all inner nodes of the Tiered Sparse Merkle tree (i.e., nodes not
|
||||
/// at depths 16 32, 48, or 64).
|
||||
///
|
||||
/// The iterator order is unspecified.
|
||||
pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
|
||||
self.nodes.iter().filter_map(|(index, node)| {
|
||||
if self.is_internal_node(index) {
|
||||
Some(InnerNodeInfo {
|
||||
value: *node,
|
||||
left: self.get_node_unchecked(&index.left_child()),
|
||||
right: self.get_node_unchecked(&index.right_child()),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns an iterator over the upper leaves (i.e., leaves with depths 16, 32, 48) of the
|
||||
/// Tiered Sparse Merkle tree.
|
||||
pub fn upper_leaves(&self) -> impl Iterator<Item = (&NodeIndex, &RpoDigest)> {
|
||||
self.upper_leaves.iter().map(|index| (index, &self.nodes[index]))
|
||||
}
|
||||
|
||||
/// Returns an iterator over the bottom leaves (i.e., leaves with depth 64) of the Tiered
|
||||
/// Sparse Merkle tree.
|
||||
pub fn bottom_leaves(&self) -> impl Iterator<Item = (&u64, &RpoDigest)> {
|
||||
self.bottom_leaves.iter().map(|value| {
|
||||
let index = NodeIndex::new_unchecked(MAX_DEPTH, *value);
|
||||
(value, &self.nodes[&index])
|
||||
})
|
||||
}
|
||||
|
||||
// STATE MUTATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Replaces the leaf node at the specified index with a tree consisting of two leaves located
|
||||
/// at the specified indexes. Recomputes and returns the new root.
|
||||
pub fn replace_leaf_with_subtree(
|
||||
&mut self,
|
||||
leaf_index: LeafNodeIndex,
|
||||
subtree_leaves: [(LeafNodeIndex, RpoDigest); 2],
|
||||
) -> RpoDigest {
|
||||
debug_assert!(self.is_non_empty_leaf(&leaf_index));
|
||||
debug_assert!(!is_empty_root(&subtree_leaves[0].1));
|
||||
debug_assert!(!is_empty_root(&subtree_leaves[1].1));
|
||||
debug_assert_eq!(subtree_leaves[0].0.depth(), subtree_leaves[1].0.depth());
|
||||
debug_assert!(leaf_index.depth() < subtree_leaves[0].0.depth());
|
||||
|
||||
self.upper_leaves.remove(&leaf_index);
|
||||
|
||||
if subtree_leaves[0].0 == subtree_leaves[1].0 {
|
||||
// if the subtree is for a single node at depth 64, we only need to insert one node
|
||||
debug_assert_eq!(subtree_leaves[0].0.depth(), MAX_DEPTH);
|
||||
debug_assert_eq!(subtree_leaves[0].1, subtree_leaves[1].1);
|
||||
self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1)
|
||||
} else {
|
||||
self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1);
|
||||
self.insert_leaf_node(subtree_leaves[1].0, subtree_leaves[1].1)
|
||||
}
|
||||
}
|
||||
|
||||
/// Replaces a subtree containing the retained and the removed leaf nodes, with a leaf node
|
||||
/// containing the retained leaf.
|
||||
///
|
||||
/// This has the effect of deleting the the node at the `removed_leaf` index from the tree,
|
||||
/// moving the node at the `retained_leaf` index up to the tier specified by `new_depth`.
|
||||
pub fn replace_subtree_with_leaf(
|
||||
&mut self,
|
||||
removed_leaf: LeafNodeIndex,
|
||||
retained_leaf: LeafNodeIndex,
|
||||
new_depth: u8,
|
||||
node: RpoDigest,
|
||||
) -> RpoDigest {
|
||||
debug_assert!(!is_empty_root(&node));
|
||||
debug_assert!(self.is_non_empty_leaf(&removed_leaf));
|
||||
debug_assert!(self.is_non_empty_leaf(&retained_leaf));
|
||||
debug_assert_eq!(removed_leaf.depth(), retained_leaf.depth());
|
||||
debug_assert!(removed_leaf.depth() > new_depth);
|
||||
|
||||
// remove the branches leading up to the tier to which the retained leaf is to be moved
|
||||
self.remove_branch(removed_leaf, new_depth);
|
||||
self.remove_branch(retained_leaf, new_depth);
|
||||
|
||||
// compute the index of the common root for retained and removed leaves
|
||||
let mut new_index = retained_leaf;
|
||||
new_index.move_up_to(new_depth);
|
||||
|
||||
// insert the node at the root index
|
||||
self.insert_leaf_node(new_index, node)
|
||||
}
|
||||
|
||||
/// Inserts the specified node at the specified index; recomputes and returns the new root
|
||||
/// of the Tiered Sparse Merkle tree.
|
||||
///
|
||||
/// This method assumes that the provided node is a non-empty value, and that there is no node
|
||||
/// at the specified index.
|
||||
pub fn insert_leaf_node(&mut self, index: LeafNodeIndex, mut node: RpoDigest) -> RpoDigest {
|
||||
debug_assert!(!is_empty_root(&node));
|
||||
debug_assert_eq!(self.nodes.get(&index), None);
|
||||
|
||||
// mark the node as the leaf
|
||||
if index.depth() == MAX_DEPTH {
|
||||
self.bottom_leaves.insert(index.value());
|
||||
} else {
|
||||
self.upper_leaves.insert(index.into());
|
||||
};
|
||||
|
||||
// insert the node and update the path from the node to the root
|
||||
let mut index: NodeIndex = index.into();
|
||||
for _ in 0..index.depth() {
|
||||
self.nodes.insert(index, node);
|
||||
let sibling = self.get_node_unchecked(&index.sibling());
|
||||
node = Rpo256::merge(&index.build_node(node, sibling));
|
||||
index.move_up();
|
||||
}
|
||||
|
||||
// update the root
|
||||
self.nodes.insert(NodeIndex::root(), node);
|
||||
node
|
||||
}
|
||||
|
||||
/// Updates the node at the specified index with the specified node value; recomputes and
|
||||
/// returns the new root of the Tiered Sparse Merkle tree.
|
||||
///
|
||||
/// This method can accept `node` as either an empty or a non-empty value.
|
||||
pub fn update_leaf_node(&mut self, index: LeafNodeIndex, mut node: RpoDigest) -> RpoDigest {
|
||||
debug_assert!(self.is_non_empty_leaf(&index));
|
||||
|
||||
// if the value we are updating the node to is a root of an empty tree, clear the leaf
|
||||
// flag for this node
|
||||
if node == EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize] {
|
||||
if index.depth() == MAX_DEPTH {
|
||||
self.bottom_leaves.remove(&index.value());
|
||||
} else {
|
||||
self.upper_leaves.remove(&index);
|
||||
}
|
||||
} else {
|
||||
debug_assert!(!is_empty_root(&node));
|
||||
}
|
||||
|
||||
// update the path from the node to the root
|
||||
let mut index: NodeIndex = index.into();
|
||||
for _ in 0..index.depth() {
|
||||
if node == EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize] {
|
||||
self.nodes.remove(&index);
|
||||
} else {
|
||||
self.nodes.insert(index, node);
|
||||
}
|
||||
|
||||
let sibling = self.get_node_unchecked(&index.sibling());
|
||||
node = Rpo256::merge(&index.build_node(node, sibling));
|
||||
index.move_up();
|
||||
}
|
||||
|
||||
// update the root
|
||||
self.nodes.insert(NodeIndex::root(), node);
|
||||
node
|
||||
}
|
||||
|
||||
/// Replaces the leaf node at the specified index with a root of an empty subtree; recomputes
|
||||
/// and returns the new root of the Tiered Sparse Merkle tree.
|
||||
pub fn clear_leaf_node(&mut self, index: LeafNodeIndex) -> RpoDigest {
|
||||
debug_assert!(self.is_non_empty_leaf(&index));
|
||||
let node = EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize];
|
||||
self.update_leaf_node(index, node)
|
||||
}
|
||||
|
||||
/// Truncates a branch starting with specified leaf at the bottom tier to new depth.
|
||||
///
|
||||
/// This involves removing the part of the branch below the new depth, and then inserting a new
|
||||
/// // node at the new depth.
|
||||
pub fn truncate_branch(
|
||||
&mut self,
|
||||
leaf_index: u64,
|
||||
new_depth: u8,
|
||||
node: RpoDigest,
|
||||
) -> RpoDigest {
|
||||
debug_assert!(self.bottom_leaves.contains(&leaf_index));
|
||||
|
||||
let mut leaf_index = LeafNodeIndex::new(NodeIndex::new_unchecked(MAX_DEPTH, leaf_index));
|
||||
self.remove_branch(leaf_index, new_depth);
|
||||
|
||||
leaf_index.move_up_to(new_depth);
|
||||
self.insert_leaf_node(leaf_index, node)
|
||||
}
|
||||
|
||||
// HELPER METHODS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns true if the node at the specified index is a leaf node.
|
||||
fn is_non_empty_leaf(&self, index: &LeafNodeIndex) -> bool {
|
||||
if index.depth() == MAX_DEPTH {
|
||||
self.bottom_leaves.contains(&index.value())
|
||||
} else {
|
||||
self.upper_leaves.contains(index)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the node at the specified index is an internal node - i.e., there is
|
||||
/// no leaf at that node and the node does not belong to the bottom tier.
|
||||
fn is_internal_node(&self, index: &NodeIndex) -> bool {
|
||||
if index.depth() == MAX_DEPTH {
|
||||
false
|
||||
} else {
|
||||
!self.upper_leaves.contains(index)
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if the specified index is valid in the context of this Merkle tree.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// - The specified index depth is 0 or greater than 64.
|
||||
/// - The node for the specified index does not exists in the Merkle tree. This is possible
|
||||
/// when an ancestors of the specified index is a leaf node.
|
||||
fn validate_node_access(&self, index: NodeIndex) -> Result<(), MerkleError> {
|
||||
if index.is_root() {
|
||||
return Err(MerkleError::DepthTooSmall(index.depth()));
|
||||
} else if index.depth() > MAX_DEPTH {
|
||||
return Err(MerkleError::DepthTooBig(index.depth() as u64));
|
||||
} else {
|
||||
// make sure that there are no leaf nodes in the ancestors of the index; since leaf
|
||||
// nodes can live at specific depth, we just need to check these depths.
|
||||
let tier = ((index.depth() - 1) / TIER_SIZE) as usize;
|
||||
let mut tier_index = index;
|
||||
for &depth in TIER_DEPTHS[..tier].iter().rev() {
|
||||
tier_index.move_up_to(depth);
|
||||
if self.upper_leaves.contains(&tier_index) {
|
||||
return Err(MerkleError::NodeNotInSet(index));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns a node at the specified index. If the node does not exist at this index, a root
|
||||
/// for an empty subtree at the index's depth is returned.
|
||||
///
|
||||
/// Unlike [NodeStore::get_node()] this does not perform any checks to verify that the
|
||||
/// returned node is valid in the context of this tree.
|
||||
fn get_node_unchecked(&self, index: &NodeIndex) -> RpoDigest {
|
||||
match self.nodes.get(index) {
|
||||
Some(node) => *node,
|
||||
None => EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize],
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes a sequence of nodes starting at the specified index and traversing the tree up to
|
||||
/// the specified depth. The node at the `end_depth` is also removed, and the appropriate leaf
|
||||
/// flag is cleared.
|
||||
///
|
||||
/// This method does not update any other nodes and does not recompute the tree root.
|
||||
fn remove_branch(&mut self, index: LeafNodeIndex, end_depth: u8) {
|
||||
if index.depth() == MAX_DEPTH {
|
||||
self.bottom_leaves.remove(&index.value());
|
||||
} else {
|
||||
self.upper_leaves.remove(&index);
|
||||
}
|
||||
|
||||
let mut index: NodeIndex = index.into();
|
||||
assert!(index.depth() > end_depth);
|
||||
for _ in 0..(index.depth() - end_depth + 1) {
|
||||
self.nodes.remove(&index);
|
||||
index.move_up()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HELPER FUNCTIONS
|
||||
// ================================================================================================
|
||||
|
||||
/// Returns true if the specified node is a root of an empty tree or an empty value ([ZERO; 4]).
|
||||
fn is_empty_root(node: &RpoDigest) -> bool {
|
||||
EmptySubtreeRoots::empty_hashes(MAX_DEPTH).contains(node)
|
||||
}
|
||||
170
src/merkle/tiered_smt/proof.rs
Normal file
170
src/merkle/tiered_smt/proof.rs
Normal file
@@ -0,0 +1,170 @@
|
||||
use super::{
|
||||
get_common_prefix_tier_depth, get_key_prefix, hash_bottom_leaf, hash_upper_leaf,
|
||||
EmptySubtreeRoots, LeafNodeIndex, MerklePath, RpoDigest, TieredSmtProofError, Vec, Word,
|
||||
};
|
||||
|
||||
// CONSTANTS
|
||||
// ================================================================================================
|
||||
|
||||
/// Maximum node depth. This is also the bottom tier of the tree.
|
||||
const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;
|
||||
|
||||
/// Value of an empty leaf.
|
||||
pub const EMPTY_VALUE: Word = super::TieredSmt::EMPTY_VALUE;
|
||||
|
||||
/// Depths at which leaves can exist in a tiered SMT.
|
||||
pub const TIER_DEPTHS: [u8; 4] = super::TieredSmt::TIER_DEPTHS;
|
||||
|
||||
// TIERED SPARSE MERKLE TREE PROOF
|
||||
// ================================================================================================
|
||||
|
||||
/// A proof which can be used to assert membership (or non-membership) of key-value pairs in a
|
||||
/// Tiered Sparse Merkle tree.
|
||||
///
|
||||
/// The proof consists of a Merkle path and one or more key-value entries which describe the node
|
||||
/// located at the base of the path. If the node at the base of the path resolves to [ZERO; 4],
|
||||
/// the entries will contain a single item with value set to [ZERO; 4].
|
||||
#[derive(PartialEq, Eq, Debug)]
|
||||
pub struct TieredSmtProof {
|
||||
path: MerklePath,
|
||||
entries: Vec<(RpoDigest, Word)>,
|
||||
}
|
||||
|
||||
impl TieredSmtProof {
|
||||
// CONSTRUCTOR
|
||||
// --------------------------------------------------------------------------------------------
|
||||
/// Returns a new instance of [TieredSmtProof] instantiated from the specified path and entries.
|
||||
///
|
||||
/// # Panics
|
||||
/// Panics if:
|
||||
/// - The length of the path is greater than 64.
|
||||
/// - Entries is an empty vector.
|
||||
/// - Entries contains more than 1 item, but the length of the path is not 64.
|
||||
/// - Entries contains more than 1 item, and one of the items has value set to [ZERO; 4].
|
||||
/// - Entries contains multiple items with keys which don't share the same 64-bit prefix.
|
||||
pub fn new<I>(path: MerklePath, entries: I) -> Result<Self, TieredSmtProofError>
|
||||
where
|
||||
I: IntoIterator<Item = (RpoDigest, Word)>,
|
||||
{
|
||||
let entries: Vec<(RpoDigest, Word)> = entries.into_iter().collect();
|
||||
|
||||
if !TIER_DEPTHS.into_iter().any(|e| e == path.depth()) {
|
||||
return Err(TieredSmtProofError::NotATierPath(path.depth()));
|
||||
}
|
||||
|
||||
if entries.is_empty() {
|
||||
return Err(TieredSmtProofError::EntriesEmpty);
|
||||
}
|
||||
|
||||
if entries.len() > 1 {
|
||||
if path.depth() != MAX_DEPTH {
|
||||
return Err(TieredSmtProofError::MultipleEntriesOutsideLastTier);
|
||||
}
|
||||
|
||||
let prefix = get_key_prefix(&entries[0].0);
|
||||
for entry in entries.iter().skip(1) {
|
||||
if entry.1 == EMPTY_VALUE {
|
||||
return Err(TieredSmtProofError::EmptyValueNotAllowed);
|
||||
}
|
||||
let current = get_key_prefix(&entry.0);
|
||||
if prefix != current {
|
||||
return Err(TieredSmtProofError::MismatchedPrefixes(prefix, current));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self { path, entries })
|
||||
}
|
||||
|
||||
// PROOF VERIFIER
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns true if a Tiered Sparse Merkle tree with the specified root contains the provided
|
||||
/// key-value pair.
|
||||
///
|
||||
/// Note: this method cannot be used to assert non-membership. That is, if false is returned,
|
||||
/// it does not mean that the provided key-value pair is not in the tree.
|
||||
pub fn verify_membership(&self, key: &RpoDigest, value: &Word, root: &RpoDigest) -> bool {
|
||||
// Handles the following scenarios:
|
||||
// - the value is set
|
||||
// - empty leaf, there is an explicit entry for the key with the empty value
|
||||
// - shared 64-bit prefix, the target key is not included in the entries list, the value is implicitly the empty word
|
||||
let v = match self.entries.iter().find(|(k, _)| k == key) {
|
||||
Some((_, v)) => v,
|
||||
None => &EMPTY_VALUE,
|
||||
};
|
||||
|
||||
// The value must match for the proof to be valid
|
||||
if v != value {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If the proof is for an empty value, we can verify it against any key which has a common
|
||||
// prefix with the key storied in entries, but the prefix must be greater than the path
|
||||
// length
|
||||
if self.is_value_empty()
|
||||
&& get_common_prefix_tier_depth(key, &self.entries[0].0) < self.path.depth()
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// make sure the Merkle path resolves to the correct root
|
||||
root == &self.compute_root()
|
||||
}
|
||||
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns the value associated with the specific key according to this proof, or None if
|
||||
/// this proof does not contain a value for the specified key.
|
||||
///
|
||||
/// A key-value pair generated by using this method should pass the `verify_membership()` check.
|
||||
pub fn get(&self, key: &RpoDigest) -> Option<Word> {
|
||||
if self.is_value_empty() {
|
||||
let common_prefix_tier = get_common_prefix_tier_depth(key, &self.entries[0].0);
|
||||
if common_prefix_tier < self.path.depth() {
|
||||
None
|
||||
} else {
|
||||
Some(EMPTY_VALUE)
|
||||
}
|
||||
} else {
|
||||
self.entries.iter().find(|(k, _)| k == key).map(|(_, value)| *value)
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the root of a Tiered Sparse Merkle tree to which this proof resolve.
|
||||
pub fn compute_root(&self) -> RpoDigest {
|
||||
let node = self.build_node();
|
||||
let index = LeafNodeIndex::from_key(&self.entries[0].0, self.path.depth());
|
||||
self.path
|
||||
.compute_root(index.value(), node)
|
||||
.expect("failed to compute Merkle path root")
|
||||
}
|
||||
|
||||
/// Consume the proof and returns its parts.
|
||||
pub fn into_parts(self) -> (MerklePath, Vec<(RpoDigest, Word)>) {
|
||||
(self.path, self.entries)
|
||||
}
|
||||
|
||||
// HELPER METHODS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns true if the proof is for an empty value.
|
||||
fn is_value_empty(&self) -> bool {
|
||||
self.entries[0].1 == EMPTY_VALUE
|
||||
}
|
||||
|
||||
/// Converts the entries contained in this proof into a node value for node at the base of the
|
||||
/// path contained in this proof.
|
||||
fn build_node(&self) -> RpoDigest {
|
||||
let depth = self.path.depth();
|
||||
if self.is_value_empty() {
|
||||
EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[depth as usize]
|
||||
} else if depth == MAX_DEPTH {
|
||||
hash_bottom_leaf(&self.entries)
|
||||
} else {
|
||||
let (key, value) = self.entries[0];
|
||||
hash_upper_leaf(key, value, depth)
|
||||
}
|
||||
}
|
||||
}
|
||||
968
src/merkle/tiered_smt/tests.rs
Normal file
968
src/merkle/tiered_smt/tests.rs
Normal file
@@ -0,0 +1,968 @@
|
||||
use super::{
|
||||
super::{super::ONE, super::WORD_SIZE, Felt, MerkleStore, EMPTY_WORD, ZERO},
|
||||
EmptySubtreeRoots, InnerNodeInfo, NodeIndex, Rpo256, RpoDigest, TieredSmt, Vec, Word,
|
||||
};
|
||||
|
||||
// INSERTION TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[test]
|
||||
fn tsmt_insert_one() {
|
||||
let mut smt = TieredSmt::default();
|
||||
let mut store = MerkleStore::default();
|
||||
|
||||
let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
|
||||
let value = [ONE; WORD_SIZE];
|
||||
|
||||
// since the tree is empty, the first node will be inserted at depth 16 and the index will be
|
||||
// 16 most significant bits of the key
|
||||
let index = NodeIndex::make(16, raw >> 48);
|
||||
let leaf_node = build_leaf_node(key, value, 16);
|
||||
let tree_root = store.set_node(smt.root(), index, leaf_node).unwrap().root;
|
||||
|
||||
smt.insert(key, value);
|
||||
|
||||
assert_eq!(smt.root(), tree_root);
|
||||
|
||||
// make sure the value was inserted, and the node is at the expected index
|
||||
assert_eq!(smt.get_value(key), value);
|
||||
assert_eq!(smt.get_node(index).unwrap(), leaf_node);
|
||||
|
||||
// make sure the paths we get from the store and the tree match
|
||||
let expected_path = store.get_path(tree_root, index).unwrap();
|
||||
assert_eq!(smt.get_path(index).unwrap(), expected_path.path);
|
||||
|
||||
// make sure inner nodes match
|
||||
let expected_nodes = get_non_empty_nodes(&store);
|
||||
let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
|
||||
assert_eq!(actual_nodes.len(), expected_nodes.len());
|
||||
actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));
|
||||
|
||||
// make sure leaves are returned correctly
|
||||
let mut leaves = smt.upper_leaves();
|
||||
assert_eq!(leaves.next(), Some((leaf_node, key, value)));
|
||||
assert_eq!(leaves.next(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tsmt_insert_two_16() {
|
||||
let mut smt = TieredSmt::default();
|
||||
let mut store = MerkleStore::default();
|
||||
|
||||
// --- insert the first value ---------------------------------------------
|
||||
let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
|
||||
let val_a = [ONE; WORD_SIZE];
|
||||
smt.insert(key_a, val_a);
|
||||
|
||||
// --- insert the second value --------------------------------------------
|
||||
// the key for this value has the same 16-bit prefix as the key for the first value,
|
||||
// thus, on insertions, both values should be pushed to depth 32 tier
|
||||
let raw_b = 0b_10101010_10101010_10011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
|
||||
let val_b = [Felt::new(2); WORD_SIZE];
|
||||
smt.insert(key_b, val_b);
|
||||
|
||||
// --- build Merkle store with equivalent data ----------------------------
|
||||
let mut tree_root = get_init_root();
|
||||
let index_a = NodeIndex::make(32, raw_a >> 32);
|
||||
let leaf_node_a = build_leaf_node(key_a, val_a, 32);
|
||||
tree_root = store.set_node(tree_root, index_a, leaf_node_a).unwrap().root;
|
||||
|
||||
let index_b = NodeIndex::make(32, raw_b >> 32);
|
||||
let leaf_node_b = build_leaf_node(key_b, val_b, 32);
|
||||
tree_root = store.set_node(tree_root, index_b, leaf_node_b).unwrap().root;
|
||||
|
||||
// --- verify that data is consistent between store and tree --------------
|
||||
|
||||
assert_eq!(smt.root(), tree_root);
|
||||
|
||||
assert_eq!(smt.get_value(key_a), val_a);
|
||||
assert_eq!(smt.get_node(index_a).unwrap(), leaf_node_a);
|
||||
let expected_path = store.get_path(tree_root, index_a).unwrap().path;
|
||||
assert_eq!(smt.get_path(index_a).unwrap(), expected_path);
|
||||
|
||||
assert_eq!(smt.get_value(key_b), val_b);
|
||||
assert_eq!(smt.get_node(index_b).unwrap(), leaf_node_b);
|
||||
let expected_path = store.get_path(tree_root, index_b).unwrap().path;
|
||||
assert_eq!(smt.get_path(index_b).unwrap(), expected_path);
|
||||
|
||||
// make sure inner nodes match - the store contains more entries because it keeps track of
|
||||
// all prior state - so, we don't check that the number of inner nodes is the same in both
|
||||
let expected_nodes = get_non_empty_nodes(&store);
|
||||
let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
|
||||
actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));
|
||||
|
||||
// make sure leaves are returned correctly
|
||||
let mut leaves = smt.upper_leaves();
|
||||
assert_eq!(leaves.next(), Some((leaf_node_a, key_a, val_a)));
|
||||
assert_eq!(leaves.next(), Some((leaf_node_b, key_b, val_b)));
|
||||
assert_eq!(leaves.next(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tsmt_insert_two_32() {
|
||||
let mut smt = TieredSmt::default();
|
||||
let mut store = MerkleStore::default();
|
||||
|
||||
// --- insert the first value ---------------------------------------------
|
||||
let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
|
||||
let val_a = [ONE; WORD_SIZE];
|
||||
smt.insert(key_a, val_a);
|
||||
|
||||
// --- insert the second value --------------------------------------------
|
||||
// the key for this value has the same 32-bit prefix as the key for the first value,
|
||||
// thus, on insertions, both values should be pushed to depth 48 tier
|
||||
let raw_b = 0b_10101010_10101010_00011111_11111111_00010110_10010011_11100000_00000000_u64;
|
||||
let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
|
||||
let val_b = [Felt::new(2); WORD_SIZE];
|
||||
smt.insert(key_b, val_b);
|
||||
|
||||
// --- build Merkle store with equivalent data ----------------------------
|
||||
let mut tree_root = get_init_root();
|
||||
let index_a = NodeIndex::make(48, raw_a >> 16);
|
||||
let leaf_node_a = build_leaf_node(key_a, val_a, 48);
|
||||
tree_root = store.set_node(tree_root, index_a, leaf_node_a).unwrap().root;
|
||||
|
||||
let index_b = NodeIndex::make(48, raw_b >> 16);
|
||||
let leaf_node_b = build_leaf_node(key_b, val_b, 48);
|
||||
tree_root = store.set_node(tree_root, index_b, leaf_node_b).unwrap().root;
|
||||
|
||||
// --- verify that data is consistent between store and tree --------------
|
||||
|
||||
assert_eq!(smt.root(), tree_root);
|
||||
|
||||
assert_eq!(smt.get_value(key_a), val_a);
|
||||
assert_eq!(smt.get_node(index_a).unwrap(), leaf_node_a);
|
||||
let expected_path = store.get_path(tree_root, index_a).unwrap().path;
|
||||
assert_eq!(smt.get_path(index_a).unwrap(), expected_path);
|
||||
|
||||
assert_eq!(smt.get_value(key_b), val_b);
|
||||
assert_eq!(smt.get_node(index_b).unwrap(), leaf_node_b);
|
||||
let expected_path = store.get_path(tree_root, index_b).unwrap().path;
|
||||
assert_eq!(smt.get_path(index_b).unwrap(), expected_path);
|
||||
|
||||
// make sure inner nodes match - the store contains more entries because it keeps track of
|
||||
// all prior state - so, we don't check that the number of inner nodes is the same in both
|
||||
let expected_nodes = get_non_empty_nodes(&store);
|
||||
let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
|
||||
actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tsmt_insert_three() {
|
||||
let mut smt = TieredSmt::default();
|
||||
let mut store = MerkleStore::default();
|
||||
|
||||
// --- insert the first value ---------------------------------------------
|
||||
let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
|
||||
let val_a = [ONE; WORD_SIZE];
|
||||
smt.insert(key_a, val_a);
|
||||
|
||||
// --- insert the second value --------------------------------------------
|
||||
// the key for this value has the same 16-bit prefix as the key for the first value,
|
||||
// thus, on insertions, both values should be pushed to depth 32 tier
|
||||
let raw_b = 0b_10101010_10101010_10011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
|
||||
let val_b = [Felt::new(2); WORD_SIZE];
|
||||
smt.insert(key_b, val_b);
|
||||
|
||||
// --- insert the third value ---------------------------------------------
|
||||
// the key for this value has the same 16-bit prefix as the keys for the first two,
|
||||
// values; thus, on insertions, it will be inserted into depth 32 tier, but will not
|
||||
// affect locations of the other two values
|
||||
let raw_c = 0b_10101010_10101010_11011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
|
||||
let val_c = [Felt::new(3); WORD_SIZE];
|
||||
smt.insert(key_c, val_c);
|
||||
|
||||
// --- build Merkle store with equivalent data ----------------------------
|
||||
let mut tree_root = get_init_root();
|
||||
let index_a = NodeIndex::make(32, raw_a >> 32);
|
||||
let leaf_node_a = build_leaf_node(key_a, val_a, 32);
|
||||
tree_root = store.set_node(tree_root, index_a, leaf_node_a).unwrap().root;
|
||||
|
||||
let index_b = NodeIndex::make(32, raw_b >> 32);
|
||||
let leaf_node_b = build_leaf_node(key_b, val_b, 32);
|
||||
tree_root = store.set_node(tree_root, index_b, leaf_node_b).unwrap().root;
|
||||
|
||||
let index_c = NodeIndex::make(32, raw_c >> 32);
|
||||
let leaf_node_c = build_leaf_node(key_c, val_c, 32);
|
||||
tree_root = store.set_node(tree_root, index_c, leaf_node_c).unwrap().root;
|
||||
|
||||
// --- verify that data is consistent between store and tree --------------
|
||||
|
||||
assert_eq!(smt.root(), tree_root);
|
||||
|
||||
assert_eq!(smt.get_value(key_a), val_a);
|
||||
assert_eq!(smt.get_node(index_a).unwrap(), leaf_node_a);
|
||||
let expected_path = store.get_path(tree_root, index_a).unwrap().path;
|
||||
assert_eq!(smt.get_path(index_a).unwrap(), expected_path);
|
||||
|
||||
assert_eq!(smt.get_value(key_b), val_b);
|
||||
assert_eq!(smt.get_node(index_b).unwrap(), leaf_node_b);
|
||||
let expected_path = store.get_path(tree_root, index_b).unwrap().path;
|
||||
assert_eq!(smt.get_path(index_b).unwrap(), expected_path);
|
||||
|
||||
assert_eq!(smt.get_value(key_c), val_c);
|
||||
assert_eq!(smt.get_node(index_c).unwrap(), leaf_node_c);
|
||||
let expected_path = store.get_path(tree_root, index_c).unwrap().path;
|
||||
assert_eq!(smt.get_path(index_c).unwrap(), expected_path);
|
||||
|
||||
// make sure inner nodes match - the store contains more entries because it keeps track of
|
||||
// all prior state - so, we don't check that the number of inner nodes is the same in both
|
||||
let expected_nodes = get_non_empty_nodes(&store);
|
||||
let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
|
||||
actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));
|
||||
}
|
||||
|
||||
// UPDATE TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[test]
|
||||
fn tsmt_update() {
|
||||
let mut smt = TieredSmt::default();
|
||||
let mut store = MerkleStore::default();
|
||||
|
||||
// --- insert a value into the tree ---------------------------------------
|
||||
let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
|
||||
let value_a = [ONE; WORD_SIZE];
|
||||
smt.insert(key, value_a);
|
||||
|
||||
// --- update the value ---------------------------------------------------
|
||||
let value_b = [Felt::new(2); WORD_SIZE];
|
||||
smt.insert(key, value_b);
|
||||
|
||||
// --- verify consistency -------------------------------------------------
|
||||
let mut tree_root = get_init_root();
|
||||
let index = NodeIndex::make(16, raw >> 48);
|
||||
let leaf_node = build_leaf_node(key, value_b, 16);
|
||||
tree_root = store.set_node(tree_root, index, leaf_node).unwrap().root;
|
||||
|
||||
assert_eq!(smt.root(), tree_root);
|
||||
|
||||
assert_eq!(smt.get_value(key), value_b);
|
||||
assert_eq!(smt.get_node(index).unwrap(), leaf_node);
|
||||
let expected_path = store.get_path(tree_root, index).unwrap().path;
|
||||
assert_eq!(smt.get_path(index).unwrap(), expected_path);
|
||||
|
||||
// make sure inner nodes match - the store contains more entries because it keeps track of
|
||||
// all prior state - so, we don't check that the number of inner nodes is the same in both
|
||||
let expected_nodes = get_non_empty_nodes(&store);
|
||||
let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
|
||||
actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));
|
||||
}
|
||||
|
||||
// DELETION TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[test]
|
||||
fn tsmt_delete_16() {
|
||||
let mut smt = TieredSmt::default();
|
||||
|
||||
// --- insert a value into the tree ---------------------------------------
|
||||
let smt0 = smt.clone();
|
||||
let raw_a = 0b_01010101_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
|
||||
let value_a = [ONE, ONE, ONE, ONE];
|
||||
smt.insert(key_a, value_a);
|
||||
|
||||
// --- insert another value into the tree ---------------------------------
|
||||
let smt1 = smt.clone();
|
||||
let raw_b = 0b_01011111_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
|
||||
let value_b = [ONE, ONE, ONE, ZERO];
|
||||
smt.insert(key_b, value_b);
|
||||
|
||||
// --- delete the last inserted value -------------------------------------
|
||||
assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b);
|
||||
assert_eq!(smt, smt1);
|
||||
|
||||
// --- delete the first inserted value ------------------------------------
|
||||
assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a);
|
||||
assert_eq!(smt, smt0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tsmt_delete_32() {
|
||||
let mut smt = TieredSmt::default();
|
||||
|
||||
// --- insert a value into the tree ---------------------------------------
|
||||
let smt0 = smt.clone();
|
||||
let raw_a = 0b_01010101_01101100_01111111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
|
||||
let value_a = [ONE, ONE, ONE, ONE];
|
||||
smt.insert(key_a, value_a);
|
||||
|
||||
// --- insert another with the same 16-bit prefix into the tree -----------
|
||||
let smt1 = smt.clone();
|
||||
let raw_b = 0b_01010101_01101100_00111111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
|
||||
let value_b = [ONE, ONE, ONE, ZERO];
|
||||
smt.insert(key_b, value_b);
|
||||
|
||||
// --- insert the 3rd value with the same 16-bit prefix into the tree -----
|
||||
let smt2 = smt.clone();
|
||||
let raw_c = 0b_01010101_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
|
||||
let value_c = [ONE, ONE, ZERO, ZERO];
|
||||
smt.insert(key_c, value_c);
|
||||
|
||||
// --- delete the last inserted value -------------------------------------
|
||||
assert_eq!(smt.insert(key_c, EMPTY_WORD), value_c);
|
||||
assert_eq!(smt, smt2);
|
||||
|
||||
// --- delete the last inserted value -------------------------------------
|
||||
assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b);
|
||||
assert_eq!(smt, smt1);
|
||||
|
||||
// --- delete the first inserted value ------------------------------------
|
||||
assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a);
|
||||
assert_eq!(smt, smt0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tsmt_delete_48_same_32_bit_prefix() {
|
||||
let mut smt = TieredSmt::default();
|
||||
|
||||
// test the case when all values share the same 32-bit prefix
|
||||
|
||||
// --- insert a value into the tree ---------------------------------------
|
||||
let smt0 = smt.clone();
|
||||
let raw_a = 0b_01010101_01010101_11111111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
|
||||
let value_a = [ONE, ONE, ONE, ONE];
|
||||
smt.insert(key_a, value_a);
|
||||
|
||||
// --- insert another with the same 32-bit prefix into the tree -----------
|
||||
let smt1 = smt.clone();
|
||||
let raw_b = 0b_01010101_01010101_11111111_11111111_11010110_10010011_11100000_00000000_u64;
|
||||
let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
|
||||
let value_b = [ONE, ONE, ONE, ZERO];
|
||||
smt.insert(key_b, value_b);
|
||||
|
||||
// --- insert the 3rd value with the same 32-bit prefix into the tree -----
|
||||
let smt2 = smt.clone();
|
||||
let raw_c = 0b_01010101_01010101_11111111_11111111_11110110_10010011_11100000_00000000_u64;
|
||||
let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
|
||||
let value_c = [ONE, ONE, ZERO, ZERO];
|
||||
smt.insert(key_c, value_c);
|
||||
|
||||
// --- delete the last inserted value -------------------------------------
|
||||
assert_eq!(smt.insert(key_c, EMPTY_WORD), value_c);
|
||||
assert_eq!(smt, smt2);
|
||||
|
||||
// --- delete the last inserted value -------------------------------------
|
||||
assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b);
|
||||
assert_eq!(smt, smt1);
|
||||
|
||||
// --- delete the first inserted value ------------------------------------
|
||||
assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a);
|
||||
assert_eq!(smt, smt0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tsmt_delete_48_mixed_prefix() {
|
||||
let mut smt = TieredSmt::default();
|
||||
|
||||
// test the case when some values share a 32-bit prefix and others share a 16-bit prefix
|
||||
|
||||
// --- insert a value into the tree ---------------------------------------
|
||||
let smt0 = smt.clone();
|
||||
let raw_a = 0b_01010101_01010101_11111111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
|
||||
let value_a = [ONE, ONE, ONE, ONE];
|
||||
smt.insert(key_a, value_a);
|
||||
|
||||
// --- insert another with the same 16-bit prefix into the tree -----------
|
||||
let smt1 = smt.clone();
|
||||
let raw_b = 0b_01010101_01010101_01111111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
|
||||
let value_b = [ONE, ONE, ONE, ZERO];
|
||||
smt.insert(key_b, value_b);
|
||||
|
||||
// --- insert a value with the same 32-bit prefix as the first value -----
|
||||
let smt2 = smt.clone();
|
||||
let raw_c = 0b_01010101_01010101_11111111_11111111_11010110_10010011_11100000_00000000_u64;
|
||||
let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
|
||||
let value_c = [ONE, ONE, ZERO, ZERO];
|
||||
smt.insert(key_c, value_c);
|
||||
|
||||
// --- insert another value with the same 32-bit prefix as the first value
|
||||
let smt3 = smt.clone();
|
||||
let raw_d = 0b_01010101_01010101_11111111_11111111_11110110_10010011_11100000_00000000_u64;
|
||||
let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_d)]);
|
||||
let value_d = [ONE, ZERO, ZERO, ZERO];
|
||||
smt.insert(key_d, value_d);
|
||||
|
||||
// --- delete the inserted values one-by-one ------------------------------
|
||||
assert_eq!(smt.insert(key_d, EMPTY_WORD), value_d);
|
||||
assert_eq!(smt, smt3);
|
||||
|
||||
assert_eq!(smt.insert(key_c, EMPTY_WORD), value_c);
|
||||
assert_eq!(smt, smt2);
|
||||
|
||||
assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b);
|
||||
assert_eq!(smt, smt1);
|
||||
|
||||
assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a);
|
||||
assert_eq!(smt, smt0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tsmt_delete_64() {
|
||||
let mut smt = TieredSmt::default();
|
||||
|
||||
// test the case when all values share the same 48-bit prefix
|
||||
|
||||
// --- insert a value into the tree ---------------------------------------
|
||||
let smt0 = smt.clone();
|
||||
let raw_a = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000000_u64;
|
||||
let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
|
||||
let value_a = [ONE, ONE, ONE, ONE];
|
||||
smt.insert(key_a, value_a);
|
||||
|
||||
// --- insert a value with the same 48-bit prefix into the tree -----------
|
||||
let smt1 = smt.clone();
|
||||
let raw_b = 0b_01010101_01010101_11111111_11111111_10110101_10101010_10111100_00000000_u64;
|
||||
let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
|
||||
let value_b = [ONE, ONE, ONE, ZERO];
|
||||
smt.insert(key_b, value_b);
|
||||
|
||||
// --- insert a value with the same 32-bit prefix into the tree -----------
|
||||
let smt2 = smt.clone();
|
||||
let raw_c = 0b_01010101_01010101_11111111_11111111_11111101_10101010_10111100_00000000_u64;
|
||||
let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
|
||||
let value_c = [ONE, ONE, ZERO, ZERO];
|
||||
smt.insert(key_c, value_c);
|
||||
|
||||
let smt3 = smt.clone();
|
||||
let raw_d = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000000_u64;
|
||||
let key_d = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_d)]);
|
||||
let value_d = [ONE, ZERO, ZERO, ZERO];
|
||||
smt.insert(key_d, value_d);
|
||||
|
||||
// --- delete the last inserted value -------------------------------------
|
||||
assert_eq!(smt.insert(key_d, EMPTY_WORD), value_d);
|
||||
assert_eq!(smt, smt3);
|
||||
|
||||
assert_eq!(smt.insert(key_c, EMPTY_WORD), value_c);
|
||||
assert_eq!(smt, smt2);
|
||||
|
||||
assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b);
|
||||
assert_eq!(smt, smt1);
|
||||
|
||||
assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a);
|
||||
assert_eq!(smt, smt0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tsmt_delete_64_leaf_promotion() {
|
||||
let mut smt = TieredSmt::default();
|
||||
|
||||
// --- delete from bottom tier (no promotion to upper tiers) --------------
|
||||
|
||||
// insert a value into the tree
|
||||
let raw_a = 0b_01010101_01010101_11111111_11111111_10101010_10101010_11111111_00000000_u64;
|
||||
let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
|
||||
let value_a = [ONE, ONE, ONE, ONE];
|
||||
smt.insert(key_a, value_a);
|
||||
|
||||
// insert another value with a key having the same 64-bit prefix
|
||||
let key_b = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
|
||||
let value_b = [ONE, ONE, ONE, ZERO];
|
||||
smt.insert(key_b, value_b);
|
||||
|
||||
// insert a value with a key which shared the same 48-bit prefix
|
||||
let raw_c = 0b_01010101_01010101_11111111_11111111_10101010_10101010_00111111_00000000_u64;
|
||||
let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
|
||||
let value_c = [ONE, ONE, ZERO, ZERO];
|
||||
smt.insert(key_c, value_c);
|
||||
|
||||
// delete entry A and compare to the tree which was built from B and C
|
||||
smt.insert(key_a, EMPTY_WORD);
|
||||
|
||||
let mut expected_smt = TieredSmt::default();
|
||||
expected_smt.insert(key_b, value_b);
|
||||
expected_smt.insert(key_c, value_c);
|
||||
assert_eq!(smt, expected_smt);
|
||||
|
||||
// entries B and C should stay at depth 64
|
||||
assert_eq!(smt.nodes.get_leaf_index(&key_b).0.depth(), 64);
|
||||
assert_eq!(smt.nodes.get_leaf_index(&key_c).0.depth(), 64);
|
||||
|
||||
// --- delete from bottom tier (promotion to depth 48) --------------------
|
||||
|
||||
let mut smt = TieredSmt::default();
|
||||
smt.insert(key_a, value_a);
|
||||
smt.insert(key_b, value_b);
|
||||
|
||||
// insert a value with a key which shared the same 32-bit prefix
|
||||
let raw_c = 0b_01010101_01010101_11111111_11111111_11101010_10101010_11111111_00000000_u64;
|
||||
let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
|
||||
smt.insert(key_c, value_c);
|
||||
|
||||
// delete entry A and compare to the tree which was built from B and C
|
||||
smt.insert(key_a, EMPTY_WORD);
|
||||
|
||||
let mut expected_smt = TieredSmt::default();
|
||||
expected_smt.insert(key_b, value_b);
|
||||
expected_smt.insert(key_c, value_c);
|
||||
assert_eq!(smt, expected_smt);
|
||||
|
||||
// entry B moves to depth 48, entry C stays at depth 48
|
||||
assert_eq!(smt.nodes.get_leaf_index(&key_b).0.depth(), 48);
|
||||
assert_eq!(smt.nodes.get_leaf_index(&key_c).0.depth(), 48);
|
||||
|
||||
// --- delete from bottom tier (promotion to depth 32) --------------------
|
||||
|
||||
let mut smt = TieredSmt::default();
|
||||
smt.insert(key_a, value_a);
|
||||
smt.insert(key_b, value_b);
|
||||
|
||||
// insert a value with a key which shared the same 16-bit prefix
|
||||
let raw_c = 0b_01010101_01010101_01111111_11111111_10101010_10101010_11111111_00000000_u64;
|
||||
let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
|
||||
smt.insert(key_c, value_c);
|
||||
|
||||
// delete entry A and compare to the tree which was built from B and C
|
||||
smt.insert(key_a, EMPTY_WORD);
|
||||
|
||||
let mut expected_smt = TieredSmt::default();
|
||||
expected_smt.insert(key_b, value_b);
|
||||
expected_smt.insert(key_c, value_c);
|
||||
assert_eq!(smt, expected_smt);
|
||||
|
||||
// entry B moves to depth 32, entry C stays at depth 32
|
||||
assert_eq!(smt.nodes.get_leaf_index(&key_b).0.depth(), 32);
|
||||
assert_eq!(smt.nodes.get_leaf_index(&key_c).0.depth(), 32);
|
||||
|
||||
// --- delete from bottom tier (promotion to depth 16) --------------------
|
||||
|
||||
let mut smt = TieredSmt::default();
|
||||
smt.insert(key_a, value_a);
|
||||
smt.insert(key_b, value_b);
|
||||
|
||||
// insert a value with a key which shared prefix < 16 bits
|
||||
let raw_c = 0b_01010101_01010100_11111111_11111111_10101010_10101010_11111111_00000000_u64;
|
||||
let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
|
||||
smt.insert(key_c, value_c);
|
||||
|
||||
// delete entry A and compare to the tree which was built from B and C
|
||||
smt.insert(key_a, EMPTY_WORD);
|
||||
|
||||
let mut expected_smt = TieredSmt::default();
|
||||
expected_smt.insert(key_b, value_b);
|
||||
expected_smt.insert(key_c, value_c);
|
||||
assert_eq!(smt, expected_smt);
|
||||
|
||||
// entry B moves to depth 16, entry C stays at depth 16
|
||||
assert_eq!(smt.nodes.get_leaf_index(&key_b).0.depth(), 16);
|
||||
assert_eq!(smt.nodes.get_leaf_index(&key_c).0.depth(), 16);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_order_sensitivity() {
|
||||
let raw = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000001_u64;
|
||||
let value = [ONE; WORD_SIZE];
|
||||
|
||||
let key_1 = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
|
||||
let key_2 = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw)]);
|
||||
|
||||
let mut smt_1 = TieredSmt::default();
|
||||
|
||||
smt_1.insert(key_1, value);
|
||||
smt_1.insert(key_2, value);
|
||||
smt_1.insert(key_2, EMPTY_WORD);
|
||||
|
||||
let mut smt_2 = TieredSmt::default();
|
||||
smt_2.insert(key_1, value);
|
||||
|
||||
assert_eq!(smt_1.root(), smt_2.root());
|
||||
}
|
||||
|
||||
// BOTTOM TIER TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[test]
|
||||
fn tsmt_bottom_tier() {
|
||||
let mut smt = TieredSmt::default();
|
||||
let mut store = MerkleStore::default();
|
||||
|
||||
// common prefix for the keys
|
||||
let prefix = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
|
||||
// --- insert the first value ---------------------------------------------
|
||||
let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(prefix)]);
|
||||
let val_a = [ONE; WORD_SIZE];
|
||||
smt.insert(key_a, val_a);
|
||||
|
||||
// --- insert the second value --------------------------------------------
|
||||
// this key has the same 64-bit prefix and thus both values should end up in the same
|
||||
// node at depth 64
|
||||
let key_b = RpoDigest::from([ZERO, ONE, ONE, Felt::new(prefix)]);
|
||||
let val_b = [Felt::new(2); WORD_SIZE];
|
||||
smt.insert(key_b, val_b);
|
||||
|
||||
// --- build Merkle store with equivalent data ----------------------------
|
||||
let index = NodeIndex::make(64, prefix);
|
||||
// to build bottom leaf we sort by key starting with the least significant element, thus
|
||||
// key_b is smaller than key_a.
|
||||
let leaf_node = build_bottom_leaf_node(&[key_b, key_a], &[val_b, val_a]);
|
||||
let mut tree_root = get_init_root();
|
||||
tree_root = store.set_node(tree_root, index, leaf_node).unwrap().root;
|
||||
|
||||
// --- verify that data is consistent between store and tree --------------
|
||||
|
||||
assert_eq!(smt.root(), tree_root);
|
||||
|
||||
assert_eq!(smt.get_value(key_a), val_a);
|
||||
assert_eq!(smt.get_value(key_b), val_b);
|
||||
|
||||
assert_eq!(smt.get_node(index).unwrap(), leaf_node);
|
||||
let expected_path = store.get_path(tree_root, index).unwrap().path;
|
||||
assert_eq!(smt.get_path(index).unwrap(), expected_path);
|
||||
|
||||
// make sure inner nodes match - the store contains more entries because it keeps track of
|
||||
// all prior state - so, we don't check that the number of inner nodes is the same in both
|
||||
let expected_nodes = get_non_empty_nodes(&store);
|
||||
let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
|
||||
actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));
|
||||
|
||||
// make sure leaves are returned correctly
|
||||
let smt_clone = smt.clone();
|
||||
let mut leaves = smt_clone.bottom_leaves();
|
||||
assert_eq!(leaves.next(), Some((leaf_node, vec![(key_b, val_b), (key_a, val_a)])));
|
||||
assert_eq!(leaves.next(), None);
|
||||
|
||||
// --- update a leaf at the bottom tier -------------------------------------------------------
|
||||
|
||||
let val_a2 = [Felt::new(3); WORD_SIZE];
|
||||
assert_eq!(smt.insert(key_a, val_a2), val_a);
|
||||
|
||||
let leaf_node = build_bottom_leaf_node(&[key_b, key_a], &[val_b, val_a2]);
|
||||
store.set_node(tree_root, index, leaf_node).unwrap();
|
||||
|
||||
let expected_nodes = get_non_empty_nodes(&store);
|
||||
let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
|
||||
actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));
|
||||
|
||||
let mut leaves = smt.bottom_leaves();
|
||||
assert_eq!(leaves.next(), Some((leaf_node, vec![(key_b, val_b), (key_a, val_a2)])));
|
||||
assert_eq!(leaves.next(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tsmt_bottom_tier_two() {
|
||||
let mut smt = TieredSmt::default();
|
||||
let mut store = MerkleStore::default();
|
||||
|
||||
// --- insert the first value ---------------------------------------------
|
||||
let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
|
||||
let val_a = [ONE; WORD_SIZE];
|
||||
smt.insert(key_a, val_a);
|
||||
|
||||
// --- insert the second value --------------------------------------------
|
||||
// the key for this value has the same 48-bit prefix as the key for the first value,
|
||||
// thus, on insertions, both should end up in different nodes at depth 64
|
||||
let raw_b = 0b_10101010_10101010_00011111_11111111_10010110_10010011_01100000_00000000_u64;
|
||||
let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
|
||||
let val_b = [Felt::new(2); WORD_SIZE];
|
||||
smt.insert(key_b, val_b);
|
||||
|
||||
// --- build Merkle store with equivalent data ----------------------------
|
||||
let mut tree_root = get_init_root();
|
||||
let index_a = NodeIndex::make(64, raw_a);
|
||||
let leaf_node_a = build_bottom_leaf_node(&[key_a], &[val_a]);
|
||||
tree_root = store.set_node(tree_root, index_a, leaf_node_a).unwrap().root;
|
||||
|
||||
let index_b = NodeIndex::make(64, raw_b);
|
||||
let leaf_node_b = build_bottom_leaf_node(&[key_b], &[val_b]);
|
||||
tree_root = store.set_node(tree_root, index_b, leaf_node_b).unwrap().root;
|
||||
|
||||
// --- verify that data is consistent between store and tree --------------
|
||||
|
||||
assert_eq!(smt.root(), tree_root);
|
||||
|
||||
assert_eq!(smt.get_value(key_a), val_a);
|
||||
assert_eq!(smt.get_node(index_a).unwrap(), leaf_node_a);
|
||||
let expected_path = store.get_path(tree_root, index_a).unwrap().path;
|
||||
assert_eq!(smt.get_path(index_a).unwrap(), expected_path);
|
||||
|
||||
assert_eq!(smt.get_value(key_b), val_b);
|
||||
assert_eq!(smt.get_node(index_b).unwrap(), leaf_node_b);
|
||||
let expected_path = store.get_path(tree_root, index_b).unwrap().path;
|
||||
assert_eq!(smt.get_path(index_b).unwrap(), expected_path);
|
||||
|
||||
// make sure inner nodes match - the store contains more entries because it keeps track of
|
||||
// all prior state - so, we don't check that the number of inner nodes is the same in both
|
||||
let expected_nodes = get_non_empty_nodes(&store);
|
||||
let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
|
||||
actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));
|
||||
|
||||
// make sure leaves are returned correctly
|
||||
let mut leaves = smt.bottom_leaves();
|
||||
assert_eq!(leaves.next(), Some((leaf_node_b, vec![(key_b, val_b)])));
|
||||
assert_eq!(leaves.next(), Some((leaf_node_a, vec![(key_a, val_a)])));
|
||||
assert_eq!(leaves.next(), None);
|
||||
}
|
||||
|
||||
// GET PROOF TESTS
|
||||
// ================================================================================================
|
||||
|
||||
/// Tests the membership and non-membership proof for a single at depth 64
|
||||
#[test]
|
||||
fn tsmt_get_proof_single_element_64() {
|
||||
let mut smt = TieredSmt::default();
|
||||
|
||||
let raw_a = 0b_00000000_00000001_00000000_00000001_00000000_00000001_00000000_00000001_u64;
|
||||
let key_a = [ONE, ONE, ONE, raw_a.into()].into();
|
||||
let value_a = [ONE, ONE, ONE, ONE];
|
||||
smt.insert(key_a, value_a);
|
||||
|
||||
// push element `a` to depth 64, by inserting another value that shares the 48-bit prefix
|
||||
let raw_b = 0b_00000000_00000001_00000000_00000001_00000000_00000001_00000000_00000000_u64;
|
||||
let key_b = [ONE, ONE, ONE, raw_b.into()].into();
|
||||
smt.insert(key_b, [ONE, ONE, ONE, ONE]);
|
||||
|
||||
// verify the proof for element `a`
|
||||
let proof = smt.prove(key_a);
|
||||
assert!(proof.verify_membership(&key_a, &value_a, &smt.root()));
|
||||
|
||||
// check that a value that is not inserted in the tree produces a valid membership proof for the
|
||||
// empty word
|
||||
let key = [ZERO, ZERO, ZERO, ZERO].into();
|
||||
let proof = smt.prove(key);
|
||||
assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root()));
|
||||
|
||||
// check that a key that shared the 64-bit prefix with `a`, but is not inserted, also has a
|
||||
// valid membership proof for the empty word
|
||||
let key = [ONE, ONE, ZERO, raw_a.into()].into();
|
||||
let proof = smt.prove(key);
|
||||
assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tsmt_get_proof() {
|
||||
let mut smt = TieredSmt::default();
|
||||
|
||||
// --- insert a value into the tree ---------------------------------------
|
||||
let raw_a = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000000_u64;
|
||||
let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
|
||||
let value_a = [ONE, ONE, ONE, ONE];
|
||||
smt.insert(key_a, value_a);
|
||||
|
||||
// --- insert a value with the same 48-bit prefix into the tree -----------
|
||||
let raw_b = 0b_01010101_01010101_11111111_11111111_10110101_10101010_10111100_00000000_u64;
|
||||
let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
|
||||
let value_b = [ONE, ONE, ONE, ZERO];
|
||||
smt.insert(key_b, value_b);
|
||||
|
||||
let smt_alt = smt.clone();
|
||||
|
||||
// --- insert a value with the same 32-bit prefix into the tree -----------
|
||||
let raw_c = 0b_01010101_01010101_11111111_11111111_11111101_10101010_10111100_00000000_u64;
|
||||
let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
|
||||
let value_c = [ONE, ONE, ZERO, ZERO];
|
||||
smt.insert(key_c, value_c);
|
||||
|
||||
// --- insert a value with the same 64-bit prefix as A into the tree ------
|
||||
let raw_d = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000000_u64;
|
||||
let key_d = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_d)]);
|
||||
let value_d = [ONE, ZERO, ZERO, ZERO];
|
||||
smt.insert(key_d, value_d);
|
||||
|
||||
// at this point the tree looks as follows:
|
||||
// - A and D are located in the same node at depth 64.
|
||||
// - B is located at depth 64 and shares the same 48-bit prefix with A and D.
|
||||
// - C is located at depth 48 and shares the same 32-bit prefix with A, B, and D.
|
||||
|
||||
// --- generate proof for key A and test that it verifies correctly -------
|
||||
let proof = smt.prove(key_a);
|
||||
assert!(proof.verify_membership(&key_a, &value_a, &smt.root()));
|
||||
|
||||
assert!(!proof.verify_membership(&key_a, &value_b, &smt.root()));
|
||||
assert!(!proof.verify_membership(&key_a, &EMPTY_WORD, &smt.root()));
|
||||
assert!(!proof.verify_membership(&key_b, &value_a, &smt.root()));
|
||||
assert!(!proof.verify_membership(&key_a, &value_a, &smt_alt.root()));
|
||||
|
||||
assert_eq!(proof.get(&key_a), Some(value_a));
|
||||
assert_eq!(proof.get(&key_b), None);
|
||||
|
||||
// since A and D are stored in the same node, we should be able to use the proof to verify
|
||||
// membership of D
|
||||
assert!(proof.verify_membership(&key_d, &value_d, &smt.root()));
|
||||
assert_eq!(proof.get(&key_d), Some(value_d));
|
||||
|
||||
// --- generate proof for key B and test that it verifies correctly -------
|
||||
let proof = smt.prove(key_b);
|
||||
assert!(proof.verify_membership(&key_b, &value_b, &smt.root()));
|
||||
|
||||
assert!(!proof.verify_membership(&key_b, &value_a, &smt.root()));
|
||||
assert!(!proof.verify_membership(&key_b, &EMPTY_WORD, &smt.root()));
|
||||
assert!(!proof.verify_membership(&key_a, &value_b, &smt.root()));
|
||||
assert!(!proof.verify_membership(&key_b, &value_b, &smt_alt.root()));
|
||||
|
||||
assert_eq!(proof.get(&key_b), Some(value_b));
|
||||
assert_eq!(proof.get(&key_a), None);
|
||||
|
||||
// --- generate proof for key C and test that it verifies correctly -------
|
||||
let proof = smt.prove(key_c);
|
||||
assert!(proof.verify_membership(&key_c, &value_c, &smt.root()));
|
||||
|
||||
assert!(!proof.verify_membership(&key_c, &value_a, &smt.root()));
|
||||
assert!(!proof.verify_membership(&key_c, &EMPTY_WORD, &smt.root()));
|
||||
assert!(!proof.verify_membership(&key_a, &value_c, &smt.root()));
|
||||
assert!(!proof.verify_membership(&key_c, &value_c, &smt_alt.root()));
|
||||
|
||||
assert_eq!(proof.get(&key_c), Some(value_c));
|
||||
assert_eq!(proof.get(&key_b), None);
|
||||
|
||||
// --- generate proof for key D and test that it verifies correctly -------
|
||||
let proof = smt.prove(key_d);
|
||||
assert!(proof.verify_membership(&key_d, &value_d, &smt.root()));
|
||||
|
||||
assert!(!proof.verify_membership(&key_d, &value_b, &smt.root()));
|
||||
assert!(!proof.verify_membership(&key_d, &EMPTY_WORD, &smt.root()));
|
||||
assert!(!proof.verify_membership(&key_b, &value_d, &smt.root()));
|
||||
assert!(!proof.verify_membership(&key_d, &value_d, &smt_alt.root()));
|
||||
|
||||
assert_eq!(proof.get(&key_d), Some(value_d));
|
||||
assert_eq!(proof.get(&key_b), None);
|
||||
|
||||
// since A and D are stored in the same node, we should be able to use the proof to verify
|
||||
// membership of A
|
||||
assert!(proof.verify_membership(&key_a, &value_a, &smt.root()));
|
||||
assert_eq!(proof.get(&key_a), Some(value_a));
|
||||
|
||||
// --- generate proof for an empty key at depth 64 ------------------------
|
||||
// this key has the same 48-bit prefix as A but is different from B
|
||||
let raw = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000011_u64;
|
||||
let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
|
||||
|
||||
let proof = smt.prove(key);
|
||||
assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root()));
|
||||
|
||||
assert!(!proof.verify_membership(&key, &value_a, &smt.root()));
|
||||
assert!(!proof.verify_membership(&key, &EMPTY_WORD, &smt_alt.root()));
|
||||
|
||||
assert_eq!(proof.get(&key), Some(EMPTY_WORD));
|
||||
assert_eq!(proof.get(&key_b), None);
|
||||
|
||||
// the same proof should verify against any key with the same 64-bit prefix
|
||||
let key2 = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw)]);
|
||||
assert!(proof.verify_membership(&key2, &EMPTY_WORD, &smt.root()));
|
||||
assert_eq!(proof.get(&key2), Some(EMPTY_WORD));
|
||||
|
||||
// but verifying if against a key with the same 63-bit prefix (or smaller) should fail
|
||||
let raw3 = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000010_u64;
|
||||
let key3 = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw3)]);
|
||||
assert!(!proof.verify_membership(&key3, &EMPTY_WORD, &smt.root()));
|
||||
assert_eq!(proof.get(&key3), None);
|
||||
|
||||
// --- generate proof for an empty key at depth 48 ------------------------
|
||||
// this key has the same 32-prefix as A, B, C, and D, but is different from C
|
||||
let raw = 0b_01010101_01010101_11111111_11111111_00110101_10101010_11111100_00000000_u64;
|
||||
let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
|
||||
|
||||
let proof = smt.prove(key);
|
||||
assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root()));
|
||||
|
||||
assert!(!proof.verify_membership(&key, &value_a, &smt.root()));
|
||||
assert!(!proof.verify_membership(&key, &EMPTY_WORD, &smt_alt.root()));
|
||||
|
||||
assert_eq!(proof.get(&key), Some(EMPTY_WORD));
|
||||
assert_eq!(proof.get(&key_b), None);
|
||||
|
||||
// the same proof should verify against any key with the same 48-bit prefix
|
||||
let raw2 = 0b_01010101_01010101_11111111_11111111_00110101_10101010_01111100_00000000_u64;
|
||||
let key2 = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw2)]);
|
||||
assert!(proof.verify_membership(&key2, &EMPTY_WORD, &smt.root()));
|
||||
assert_eq!(proof.get(&key2), Some(EMPTY_WORD));
|
||||
|
||||
// but verifying against a key with the same 47-bit prefix (or smaller) should fail
|
||||
let raw3 = 0b_01010101_01010101_11111111_11111111_00110101_10101011_11111100_00000000_u64;
|
||||
let key3 = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw3)]);
|
||||
assert!(!proof.verify_membership(&key3, &EMPTY_WORD, &smt.root()));
|
||||
assert_eq!(proof.get(&key3), None);
|
||||
}
|
||||
|
||||
// ERROR TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[test]
|
||||
fn tsmt_node_not_available() {
|
||||
let mut smt = TieredSmt::default();
|
||||
|
||||
let raw = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
|
||||
let value = [ONE; WORD_SIZE];
|
||||
|
||||
// build an index which is just below the inserted leaf node
|
||||
let index = NodeIndex::make(17, raw >> 47);
|
||||
|
||||
// since we haven't inserted the node yet, we should be able to get node and path to this index
|
||||
assert!(smt.get_node(index).is_ok());
|
||||
assert!(smt.get_path(index).is_ok());
|
||||
|
||||
smt.insert(key, value);
|
||||
|
||||
// but once the node is inserted, everything under it should be unavailable
|
||||
assert!(smt.get_node(index).is_err());
|
||||
assert!(smt.get_path(index).is_err());
|
||||
|
||||
let index = NodeIndex::make(32, raw >> 32);
|
||||
assert!(smt.get_node(index).is_err());
|
||||
assert!(smt.get_path(index).is_err());
|
||||
|
||||
let index = NodeIndex::make(34, raw >> 30);
|
||||
assert!(smt.get_node(index).is_err());
|
||||
assert!(smt.get_path(index).is_err());
|
||||
|
||||
let index = NodeIndex::make(50, raw >> 14);
|
||||
assert!(smt.get_node(index).is_err());
|
||||
assert!(smt.get_path(index).is_err());
|
||||
|
||||
let index = NodeIndex::make(64, raw);
|
||||
assert!(smt.get_node(index).is_err());
|
||||
assert!(smt.get_path(index).is_err());
|
||||
}
|
||||
|
||||
// HELPER FUNCTIONS
|
||||
// ================================================================================================
|
||||
|
||||
fn get_init_root() -> RpoDigest {
|
||||
EmptySubtreeRoots::empty_hashes(64)[0]
|
||||
}
|
||||
|
||||
fn build_leaf_node(key: RpoDigest, value: Word, depth: u8) -> RpoDigest {
|
||||
Rpo256::merge_in_domain(&[key, value.into()], depth.into())
|
||||
}
|
||||
|
||||
fn build_bottom_leaf_node(keys: &[RpoDigest], values: &[Word]) -> RpoDigest {
|
||||
assert_eq!(keys.len(), values.len());
|
||||
|
||||
let mut elements = Vec::with_capacity(keys.len());
|
||||
for (key, val) in keys.iter().zip(values.iter()) {
|
||||
elements.extend_from_slice(key.as_elements());
|
||||
elements.extend_from_slice(val.as_slice());
|
||||
}
|
||||
|
||||
Rpo256::hash_elements(&elements)
|
||||
}
|
||||
|
||||
fn get_non_empty_nodes(store: &MerkleStore) -> Vec<InnerNodeInfo> {
|
||||
store
|
||||
.inner_nodes()
|
||||
.filter(|node| !is_empty_subtree(&node.value))
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn is_empty_subtree(node: &RpoDigest) -> bool {
|
||||
EmptySubtreeRoots::empty_hashes(255).contains(node)
|
||||
}
|
||||
584
src/merkle/tiered_smt/values.rs
Normal file
584
src/merkle/tiered_smt/values.rs
Normal file
@@ -0,0 +1,584 @@
|
||||
use super::{get_key_prefix, BTreeMap, LeafNodeIndex, RpoDigest, StarkField, Vec, Word};
|
||||
use crate::utils::vec;
|
||||
use core::{
|
||||
cmp::{Ord, Ordering},
|
||||
ops::RangeBounds,
|
||||
};
|
||||
use winter_utils::collections::btree_map::Entry;
|
||||
|
||||
// CONSTANTS
|
||||
// ================================================================================================
|
||||
|
||||
/// Depths at which leaves can exist in a tiered SMT.
|
||||
const TIER_DEPTHS: [u8; 4] = super::TieredSmt::TIER_DEPTHS;
|
||||
|
||||
/// Maximum node depth. This is also the bottom tier of the tree.
|
||||
const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;
|
||||
|
||||
// VALUE STORE
|
||||
// ================================================================================================
|
||||
/// A store for key-value pairs for a Tiered Sparse Merkle tree.
|
||||
///
|
||||
/// The store is organized in a [BTreeMap] where keys are 64 most significant bits of a key, and
|
||||
/// the values are the corresponding key-value pairs (or a list of key-value pairs if more that
|
||||
/// a single key-value pair shares the same 64-bit prefix).
|
||||
///
|
||||
/// The store supports lookup by the full key (i.e. [RpoDigest]) as well as by the 64-bit key
|
||||
/// prefix.
|
||||
#[derive(Debug, Default, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct ValueStore {
|
||||
values: BTreeMap<u64, StoreEntry>,
|
||||
}
|
||||
|
||||
impl ValueStore {
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a reference to the value stored under the specified key, or None if there is no
|
||||
/// value associated with the specified key.
|
||||
pub fn get(&self, key: &RpoDigest) -> Option<&Word> {
|
||||
let prefix = get_key_prefix(key);
|
||||
self.values.get(&prefix).and_then(|entry| entry.get(key))
|
||||
}
|
||||
|
||||
/// Returns the first key-value pair such that the key prefix is greater than or equal to the
|
||||
/// specified prefix.
|
||||
pub fn get_first(&self, prefix: u64) -> Option<&(RpoDigest, Word)> {
|
||||
self.range(prefix..).next()
|
||||
}
|
||||
|
||||
/// Returns the first key-value pair such that the key prefix is greater than or equal to the
|
||||
/// specified prefix and the key value is not equal to the exclude_key value.
|
||||
pub fn get_first_filtered(
|
||||
&self,
|
||||
prefix: u64,
|
||||
exclude_key: &RpoDigest,
|
||||
) -> Option<&(RpoDigest, Word)> {
|
||||
self.range(prefix..).find(|(key, _)| key != exclude_key)
|
||||
}
|
||||
|
||||
/// Returns a vector with key-value pairs for all keys with the specified 64-bit prefix, or
|
||||
/// None if no keys with the specified prefix are present in this store.
|
||||
pub fn get_all(&self, prefix: u64) -> Option<Vec<(RpoDigest, Word)>> {
|
||||
self.values.get(&prefix).map(|entry| match entry {
|
||||
StoreEntry::Single(kv_pair) => vec![*kv_pair],
|
||||
StoreEntry::List(kv_pairs) => kv_pairs.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns information about a sibling of a leaf node with the specified index, but only if
|
||||
/// this is the only sibling the leaf has in some subtree starting at the first tier.
|
||||
///
|
||||
/// For example, if `index` is an index at depth 32, and there is a leaf node at depth 32 with
|
||||
/// the same root at depth 16 as `index`, we say that this leaf is a lone sibling.
|
||||
///
|
||||
/// The returned tuple contains: they key-value pair of the sibling as well as the index of
|
||||
/// the node for the root of the common subtree in which both nodes are leaves.
|
||||
///
|
||||
/// This method assumes that the key-value pair for the specified index has already been
|
||||
/// removed from the store.
|
||||
pub fn get_lone_sibling(
|
||||
&self,
|
||||
index: LeafNodeIndex,
|
||||
) -> Option<(&RpoDigest, &Word, LeafNodeIndex)> {
|
||||
// iterate over tiers from top to bottom, looking at the tiers which are strictly above
|
||||
// the depth of the index. This implies that only tiers at depth 32 and 48 will be
|
||||
// considered. For each tier, check if the parent of the index at the higher tier
|
||||
// contains a single node. The fist tier (depth 16) is excluded because we cannot move
|
||||
// nodes at depth 16 to a higher tier. This implies that nodes at the first tier will
|
||||
// never have "lone siblings".
|
||||
for &tier_depth in TIER_DEPTHS.iter().filter(|&t| index.depth() > *t) {
|
||||
// compute the index of the root at a higher tier
|
||||
let mut parent_index = index;
|
||||
parent_index.move_up_to(tier_depth);
|
||||
|
||||
// find the lone sibling, if any; we need to handle the "last node" at a given tier
|
||||
// separately specify the bounds for the search correctly.
|
||||
let start_prefix = parent_index.value() << (MAX_DEPTH - tier_depth);
|
||||
let sibling = if start_prefix.leading_ones() as u8 == tier_depth {
|
||||
let mut iter = self.range(start_prefix..);
|
||||
iter.next().filter(|_| iter.next().is_none())
|
||||
} else {
|
||||
let end_prefix = (parent_index.value() + 1) << (MAX_DEPTH - tier_depth);
|
||||
let mut iter = self.range(start_prefix..end_prefix);
|
||||
iter.next().filter(|_| iter.next().is_none())
|
||||
};
|
||||
|
||||
if let Some((key, value)) = sibling {
|
||||
return Some((key, value, parent_index));
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns an iterator over all key-value pairs in this store.
|
||||
pub fn iter(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
|
||||
self.values.iter().flat_map(|(_, entry)| entry.iter())
|
||||
}
|
||||
|
||||
// STATE MUTATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Inserts the specified key-value pair into this store and returns the value previously
|
||||
/// associated with the specified key.
|
||||
///
|
||||
/// If no value was previously associated with the specified key, None is returned.
|
||||
pub fn insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
|
||||
let prefix = get_key_prefix(&key);
|
||||
match self.values.entry(prefix) {
|
||||
Entry::Occupied(mut entry) => entry.get_mut().insert(key, value),
|
||||
Entry::Vacant(entry) => {
|
||||
entry.insert(StoreEntry::new(key, value));
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes the key-value pair for the specified key from this store and returns the value
|
||||
/// associated with this key.
|
||||
///
|
||||
/// If no value was associated with the specified key, None is returned.
|
||||
pub fn remove(&mut self, key: &RpoDigest) -> Option<Word> {
|
||||
let prefix = get_key_prefix(key);
|
||||
match self.values.entry(prefix) {
|
||||
Entry::Occupied(mut entry) => {
|
||||
let (value, remove_entry) = entry.get_mut().remove(key);
|
||||
if remove_entry {
|
||||
entry.remove_entry();
|
||||
}
|
||||
value
|
||||
}
|
||||
Entry::Vacant(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
// HELPER METHODS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns an iterator over all key-value pairs contained in this store such that the most
|
||||
/// significant 64 bits of the key lay within the specified bounds.
|
||||
///
|
||||
/// The order of iteration is from the smallest to the largest key.
|
||||
fn range<R: RangeBounds<u64>>(&self, bounds: R) -> impl Iterator<Item = &(RpoDigest, Word)> {
|
||||
self.values.range(bounds).flat_map(|(_, entry)| entry.iter())
|
||||
}
|
||||
}
|
||||
|
||||
// VALUE NODE
|
||||
// ================================================================================================
|
||||
|
||||
/// An entry in the [ValueStore].
|
||||
///
|
||||
/// An entry can contain either a single key-value pair or a vector of key-value pairs sorted by
|
||||
/// key.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub enum StoreEntry {
|
||||
Single((RpoDigest, Word)),
|
||||
List(Vec<(RpoDigest, Word)>),
|
||||
}
|
||||
|
||||
impl StoreEntry {
|
||||
// CONSTRUCTOR
|
||||
// --------------------------------------------------------------------------------------------
|
||||
/// Returns a new [StoreEntry] instantiated with a single key-value pair.
|
||||
pub fn new(key: RpoDigest, value: Word) -> Self {
|
||||
Self::Single((key, value))
|
||||
}
|
||||
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns the value associated with the specified key, or None if this entry does not contain
|
||||
/// a value associated with the specified key.
|
||||
pub fn get(&self, key: &RpoDigest) -> Option<&Word> {
|
||||
match self {
|
||||
StoreEntry::Single(kv_pair) => {
|
||||
if kv_pair.0 == *key {
|
||||
Some(&kv_pair.1)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
StoreEntry::List(kv_pairs) => {
|
||||
match kv_pairs.binary_search_by(|kv_pair| cmp_digests(&kv_pair.0, key)) {
|
||||
Ok(pos) => Some(&kv_pairs[pos].1),
|
||||
Err(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an iterator over all key-value pairs in this entry.
|
||||
pub fn iter(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
|
||||
EntryIterator { entry: self, pos: 0 }
|
||||
}
|
||||
|
||||
// STATE MUTATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Inserts the specified key-value pair into this entry and returns the value previously
|
||||
/// associated with the specified key, or None if no value was associated with the specified
|
||||
/// key.
|
||||
///
|
||||
/// If a new key is inserted, this will also transform a `SingleEntry` into a `ListEntry`.
|
||||
pub fn insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
|
||||
match self {
|
||||
StoreEntry::Single(kv_pair) => {
|
||||
// if the key is already in this entry, update the value and return
|
||||
if kv_pair.0 == key {
|
||||
let old_value = kv_pair.1;
|
||||
kv_pair.1 = value;
|
||||
return Some(old_value);
|
||||
}
|
||||
|
||||
// transform the entry into a list entry, and make sure the key-value pairs
|
||||
// are sorted by key
|
||||
let mut pairs = vec![*kv_pair, (key, value)];
|
||||
pairs.sort_by(|a, b| cmp_digests(&a.0, &b.0));
|
||||
|
||||
*self = StoreEntry::List(pairs);
|
||||
None
|
||||
}
|
||||
StoreEntry::List(pairs) => {
|
||||
match pairs.binary_search_by(|kv_pair| cmp_digests(&kv_pair.0, &key)) {
|
||||
Ok(pos) => {
|
||||
let old_value = pairs[pos].1;
|
||||
pairs[pos].1 = value;
|
||||
Some(old_value)
|
||||
}
|
||||
Err(pos) => {
|
||||
pairs.insert(pos, (key, value));
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes the key-value pair with the specified key from this entry, and returns the value
|
||||
/// of the removed pair. If the entry did not contain a key-value pair for the specified key,
|
||||
/// None is returned.
|
||||
///
|
||||
/// If the last last key-value pair was removed from the entry, the second tuple value will
|
||||
/// be set to true.
|
||||
pub fn remove(&mut self, key: &RpoDigest) -> (Option<Word>, bool) {
|
||||
match self {
|
||||
StoreEntry::Single(kv_pair) => {
|
||||
if kv_pair.0 == *key {
|
||||
(Some(kv_pair.1), true)
|
||||
} else {
|
||||
(None, false)
|
||||
}
|
||||
}
|
||||
StoreEntry::List(kv_pairs) => {
|
||||
match kv_pairs.binary_search_by(|kv_pair| cmp_digests(&kv_pair.0, key)) {
|
||||
Ok(pos) => {
|
||||
let kv_pair = kv_pairs.remove(pos);
|
||||
if kv_pairs.len() == 1 {
|
||||
*self = StoreEntry::Single(kv_pairs[0]);
|
||||
}
|
||||
(Some(kv_pair.1), false)
|
||||
}
|
||||
Err(_) => (None, false),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A custom iterator over key-value pairs of a [StoreEntry].
|
||||
///
|
||||
/// For a `SingleEntry` this returns only one value, but for `ListEntry`, this iterates over the
|
||||
/// entire list of key-value pairs.
|
||||
pub struct EntryIterator<'a> {
|
||||
entry: &'a StoreEntry,
|
||||
pos: usize,
|
||||
}
|
||||
|
||||
impl<'a> Iterator for EntryIterator<'a> {
|
||||
type Item = &'a (RpoDigest, Word);
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
match self.entry {
|
||||
StoreEntry::Single(kv_pair) => {
|
||||
if self.pos == 0 {
|
||||
self.pos = 1;
|
||||
Some(kv_pair)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
StoreEntry::List(kv_pairs) => {
|
||||
if self.pos >= kv_pairs.len() {
|
||||
None
|
||||
} else {
|
||||
let kv_pair = &kv_pairs[self.pos];
|
||||
self.pos += 1;
|
||||
Some(kv_pair)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HELPER FUNCTIONS
|
||||
// ================================================================================================
|
||||
|
||||
/// Compares two digests element-by-element using their integer representations starting with the
|
||||
/// most significant element.
|
||||
fn cmp_digests(d1: &RpoDigest, d2: &RpoDigest) -> Ordering {
|
||||
let d1 = Word::from(d1);
|
||||
let d2 = Word::from(d2);
|
||||
|
||||
for (v1, v2) in d1.iter().zip(d2.iter()).rev() {
|
||||
let v1 = v1.as_int();
|
||||
let v2 = v2.as_int();
|
||||
if v1 != v2 {
|
||||
return v1.cmp(&v2);
|
||||
}
|
||||
}
|
||||
|
||||
Ordering::Equal
|
||||
}
|
||||
|
||||
// TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{LeafNodeIndex, RpoDigest, StoreEntry, ValueStore};
|
||||
use crate::{Felt, ONE, WORD_SIZE, ZERO};
|
||||
|
||||
#[test]
|
||||
fn test_insert() {
|
||||
let mut store = ValueStore::default();
|
||||
|
||||
// insert the first key-value pair into the store
|
||||
let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
|
||||
let value_a = [ONE; WORD_SIZE];
|
||||
|
||||
assert!(store.insert(key_a, value_a).is_none());
|
||||
assert_eq!(store.values.len(), 1);
|
||||
|
||||
let entry = store.values.get(&raw_a).unwrap();
|
||||
let expected_entry = StoreEntry::Single((key_a, value_a));
|
||||
assert_eq!(entry, &expected_entry);
|
||||
|
||||
// insert a key-value pair with a different key into the store; since the keys are
|
||||
// different, another entry is added to the values map
|
||||
let raw_b = 0b_11111110_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
|
||||
let value_b = [ONE, ZERO, ONE, ZERO];
|
||||
|
||||
assert!(store.insert(key_b, value_b).is_none());
|
||||
assert_eq!(store.values.len(), 2);
|
||||
|
||||
let entry1 = store.values.get(&raw_a).unwrap();
|
||||
let expected_entry1 = StoreEntry::Single((key_a, value_a));
|
||||
assert_eq!(entry1, &expected_entry1);
|
||||
|
||||
let entry2 = store.values.get(&raw_b).unwrap();
|
||||
let expected_entry2 = StoreEntry::Single((key_b, value_b));
|
||||
assert_eq!(entry2, &expected_entry2);
|
||||
|
||||
// insert a key-value pair with the same 64-bit key prefix as the first key; this should
|
||||
// transform the first entry into a List entry
|
||||
let key_c = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
|
||||
let value_c = [ONE, ONE, ZERO, ZERO];
|
||||
|
||||
assert!(store.insert(key_c, value_c).is_none());
|
||||
assert_eq!(store.values.len(), 2);
|
||||
|
||||
let entry1 = store.values.get(&raw_a).unwrap();
|
||||
let expected_entry1 = StoreEntry::List(vec![(key_c, value_c), (key_a, value_a)]);
|
||||
assert_eq!(entry1, &expected_entry1);
|
||||
|
||||
let entry2 = store.values.get(&raw_b).unwrap();
|
||||
let expected_entry2 = StoreEntry::Single((key_b, value_b));
|
||||
assert_eq!(entry2, &expected_entry2);
|
||||
|
||||
// replace values for keys a and b
|
||||
let value_a2 = [ONE, ONE, ONE, ZERO];
|
||||
let value_b2 = [ZERO, ZERO, ZERO, ONE];
|
||||
|
||||
assert_eq!(store.insert(key_a, value_a2), Some(value_a));
|
||||
assert_eq!(store.values.len(), 2);
|
||||
|
||||
assert_eq!(store.insert(key_b, value_b2), Some(value_b));
|
||||
assert_eq!(store.values.len(), 2);
|
||||
|
||||
let entry1 = store.values.get(&raw_a).unwrap();
|
||||
let expected_entry1 = StoreEntry::List(vec![(key_c, value_c), (key_a, value_a2)]);
|
||||
assert_eq!(entry1, &expected_entry1);
|
||||
|
||||
let entry2 = store.values.get(&raw_b).unwrap();
|
||||
let expected_entry2 = StoreEntry::Single((key_b, value_b2));
|
||||
assert_eq!(entry2, &expected_entry2);
|
||||
|
||||
// insert one more key-value pair with the same 64-bit key-prefix as the first key
|
||||
let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
|
||||
let value_d = [ZERO, ONE, ZERO, ZERO];
|
||||
|
||||
assert!(store.insert(key_d, value_d).is_none());
|
||||
assert_eq!(store.values.len(), 2);
|
||||
|
||||
let entry1 = store.values.get(&raw_a).unwrap();
|
||||
let expected_entry1 =
|
||||
StoreEntry::List(vec![(key_c, value_c), (key_a, value_a2), (key_d, value_d)]);
|
||||
assert_eq!(entry1, &expected_entry1);
|
||||
|
||||
let entry2 = store.values.get(&raw_b).unwrap();
|
||||
let expected_entry2 = StoreEntry::Single((key_b, value_b2));
|
||||
assert_eq!(entry2, &expected_entry2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_remove() {
|
||||
// populate the value store
|
||||
let mut store = ValueStore::default();
|
||||
|
||||
let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
|
||||
let value_a = [ONE; WORD_SIZE];
|
||||
store.insert(key_a, value_a);
|
||||
|
||||
let raw_b = 0b_11111110_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
|
||||
let value_b = [ONE, ZERO, ONE, ZERO];
|
||||
store.insert(key_b, value_b);
|
||||
|
||||
let key_c = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
|
||||
let value_c = [ONE, ONE, ZERO, ZERO];
|
||||
store.insert(key_c, value_c);
|
||||
|
||||
let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
|
||||
let value_d = [ZERO, ONE, ZERO, ZERO];
|
||||
store.insert(key_d, value_d);
|
||||
|
||||
assert_eq!(store.values.len(), 2);
|
||||
|
||||
let entry1 = store.values.get(&raw_a).unwrap();
|
||||
let expected_entry1 =
|
||||
StoreEntry::List(vec![(key_c, value_c), (key_a, value_a), (key_d, value_d)]);
|
||||
assert_eq!(entry1, &expected_entry1);
|
||||
|
||||
let entry2 = store.values.get(&raw_b).unwrap();
|
||||
let expected_entry2 = StoreEntry::Single((key_b, value_b));
|
||||
assert_eq!(entry2, &expected_entry2);
|
||||
|
||||
// remove non-existent keys
|
||||
let key_e = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_a)]);
|
||||
assert!(store.remove(&key_e).is_none());
|
||||
|
||||
let raw_f = 0b_11111110_11111111_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_f = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_f)]);
|
||||
assert!(store.remove(&key_f).is_none());
|
||||
|
||||
// remove keys from the list entry
|
||||
assert_eq!(store.remove(&key_c).unwrap(), value_c);
|
||||
let entry1 = store.values.get(&raw_a).unwrap();
|
||||
let expected_entry1 = StoreEntry::List(vec![(key_a, value_a), (key_d, value_d)]);
|
||||
assert_eq!(entry1, &expected_entry1);
|
||||
|
||||
assert_eq!(store.remove(&key_a).unwrap(), value_a);
|
||||
let entry1 = store.values.get(&raw_a).unwrap();
|
||||
let expected_entry1 = StoreEntry::Single((key_d, value_d));
|
||||
assert_eq!(entry1, &expected_entry1);
|
||||
|
||||
assert_eq!(store.remove(&key_d).unwrap(), value_d);
|
||||
assert!(store.values.get(&raw_a).is_none());
|
||||
assert_eq!(store.values.len(), 1);
|
||||
|
||||
// remove a key from a single entry
|
||||
assert_eq!(store.remove(&key_b).unwrap(), value_b);
|
||||
assert!(store.values.get(&raw_b).is_none());
|
||||
assert_eq!(store.values.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_range() {
|
||||
// populate the value store
|
||||
let mut store = ValueStore::default();
|
||||
|
||||
let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
|
||||
let value_a = [ONE; WORD_SIZE];
|
||||
store.insert(key_a, value_a);
|
||||
|
||||
let raw_b = 0b_11111110_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
|
||||
let value_b = [ONE, ZERO, ONE, ZERO];
|
||||
store.insert(key_b, value_b);
|
||||
|
||||
let key_c = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
|
||||
let value_c = [ONE, ONE, ZERO, ZERO];
|
||||
store.insert(key_c, value_c);
|
||||
|
||||
let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
|
||||
let value_d = [ZERO, ONE, ZERO, ZERO];
|
||||
store.insert(key_d, value_d);
|
||||
|
||||
let raw_e = 0b_10101000_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_e = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_e)]);
|
||||
let value_e = [ZERO, ZERO, ZERO, ONE];
|
||||
store.insert(key_e, value_e);
|
||||
|
||||
// check the entire range
|
||||
let mut iter = store.range(..u64::MAX);
|
||||
assert_eq!(iter.next(), Some(&(key_e, value_e)));
|
||||
assert_eq!(iter.next(), Some(&(key_c, value_c)));
|
||||
assert_eq!(iter.next(), Some(&(key_a, value_a)));
|
||||
assert_eq!(iter.next(), Some(&(key_d, value_d)));
|
||||
assert_eq!(iter.next(), Some(&(key_b, value_b)));
|
||||
assert_eq!(iter.next(), None);
|
||||
|
||||
// check all but e
|
||||
let mut iter = store.range(raw_a..u64::MAX);
|
||||
assert_eq!(iter.next(), Some(&(key_c, value_c)));
|
||||
assert_eq!(iter.next(), Some(&(key_a, value_a)));
|
||||
assert_eq!(iter.next(), Some(&(key_d, value_d)));
|
||||
assert_eq!(iter.next(), Some(&(key_b, value_b)));
|
||||
assert_eq!(iter.next(), None);
|
||||
|
||||
// check all but e and b
|
||||
let mut iter = store.range(raw_a..raw_b);
|
||||
assert_eq!(iter.next(), Some(&(key_c, value_c)));
|
||||
assert_eq!(iter.next(), Some(&(key_a, value_a)));
|
||||
assert_eq!(iter.next(), Some(&(key_d, value_d)));
|
||||
assert_eq!(iter.next(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_lone_sibling() {
|
||||
// populate the value store
|
||||
let mut store = ValueStore::default();
|
||||
|
||||
let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
|
||||
let value_a = [ONE; WORD_SIZE];
|
||||
store.insert(key_a, value_a);
|
||||
|
||||
let raw_b = 0b_11111111_11111111_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
|
||||
let value_b = [ONE, ZERO, ONE, ZERO];
|
||||
store.insert(key_b, value_b);
|
||||
|
||||
// check sibling node for `a`
|
||||
let index = LeafNodeIndex::make(32, 0b_10101010_10101010_00011111_11111110);
|
||||
let parent_index = LeafNodeIndex::make(16, 0b_10101010_10101010);
|
||||
assert_eq!(store.get_lone_sibling(index), Some((&key_a, &value_a, parent_index)));
|
||||
|
||||
// check sibling node for `b`
|
||||
let index = LeafNodeIndex::make(32, 0b_11111111_11111111_00011111_11111111);
|
||||
let parent_index = LeafNodeIndex::make(16, 0b_11111111_11111111);
|
||||
assert_eq!(store.get_lone_sibling(index), Some((&key_b, &value_b, parent_index)));
|
||||
|
||||
// check some other sibling for some other index
|
||||
let index = LeafNodeIndex::make(32, 0b_11101010_10101010);
|
||||
assert_eq!(store.get_lone_sibling(index), None);
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
pub use winter_crypto::{DefaultRandomCoin as WinterRandomCoin, RandomCoin, RandomCoinError};
|
||||
|
||||
use crate::{Felt, FieldElement, Word, ZERO};
|
||||
use crate::{Felt, FieldElement, StarkField, Word, ZERO};
|
||||
|
||||
mod rpo;
|
||||
pub use rpo::RpoRandomCoin;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::{Felt, FeltRng, FieldElement, Word, ZERO};
|
||||
use super::{Felt, FeltRng, FieldElement, StarkField, Word, ZERO};
|
||||
use crate::{
|
||||
hash::rpo::{Rpo256, RpoDigest},
|
||||
utils::{
|
||||
@@ -19,7 +19,7 @@ const HALF_RATE_WIDTH: usize = (Rpo256::RATE_RANGE.end - Rpo256::RATE_RANGE.star
|
||||
// RPO RANDOM COIN
|
||||
// ================================================================================================
|
||||
/// A simplified version of the `SPONGE_PRG` reseedable pseudo-random number generator algorithm
|
||||
/// described in <https://eprint.iacr.org/2011/499.pdf>.
|
||||
/// described in https://eprint.iacr.org/2011/499.pdf.
|
||||
///
|
||||
/// The simplification is related to the following facts:
|
||||
/// 1. A call to the reseed method implies one and only one call to the permutation function.
|
||||
|
||||
31
src/utils/diff.rs
Normal file
31
src/utils/diff.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
/// A trait for computing the difference between two objects.
|
||||
pub trait Diff<K: Ord + Clone, V: Clone> {
|
||||
/// The type that describes the difference between two objects.
|
||||
type DiffType;
|
||||
|
||||
/// Returns a [Self::DiffType] object that represents the difference between this object and
|
||||
/// other.
|
||||
fn diff(&self, other: &Self) -> Self::DiffType;
|
||||
}
|
||||
|
||||
/// A trait for applying the difference between two objects.
|
||||
pub trait ApplyDiff<K: Ord + Clone, V: Clone> {
|
||||
/// The type that describes the difference between two objects.
|
||||
type DiffType;
|
||||
|
||||
/// Applies the provided changes described by [Self::DiffType] to the object implementing this trait.
|
||||
fn apply(&mut self, diff: Self::DiffType);
|
||||
}
|
||||
|
||||
/// A trait for applying the difference between two objects with the possibility of failure.
|
||||
pub trait TryApplyDiff<K: Ord + Clone, V: Clone> {
|
||||
/// The type that describes the difference between two objects.
|
||||
type DiffType;
|
||||
|
||||
/// An error type that can be returned if the changes cannot be applied.
|
||||
type Error;
|
||||
|
||||
/// Applies the provided changes described by [Self::DiffType] to the object implementing this trait.
|
||||
/// Returns an error if the changes cannot be applied.
|
||||
fn try_apply(&mut self, diff: Self::DiffType) -> Result<(), Self::Error>;
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
use super::{collections::ApplyDiff, diff::Diff};
|
||||
use core::cell::RefCell;
|
||||
use winter_utils::{
|
||||
collections::{btree_map::IntoIter, BTreeMap, BTreeSet},
|
||||
@@ -208,6 +209,74 @@ impl<K: Clone + Ord, V: Clone> IntoIterator for RecordingMap<K, V> {
|
||||
}
|
||||
}
|
||||
|
||||
// KV MAP DIFF
|
||||
// ================================================================================================
|
||||
/// [KvMapDiff] stores the difference between two key-value maps.
|
||||
///
|
||||
/// The [KvMapDiff] is composed of two parts:
|
||||
/// - `updates` - a map of key-value pairs that were updated in the second map compared to the
|
||||
/// first map. This includes new key-value pairs.
|
||||
/// - `removed` - a set of keys that were removed from the second map compared to the first map.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KvMapDiff<K, V> {
|
||||
pub updated: BTreeMap<K, V>,
|
||||
pub removed: BTreeSet<K>,
|
||||
}
|
||||
|
||||
impl<K, V> KvMapDiff<K, V> {
|
||||
// CONSTRUCTOR
|
||||
// --------------------------------------------------------------------------------------------
|
||||
/// Creates a new [KvMapDiff] instance.
|
||||
pub fn new() -> Self {
|
||||
KvMapDiff {
|
||||
updated: BTreeMap::new(),
|
||||
removed: BTreeSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<K, V> Default for KvMapDiff<K, V> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Ord + Clone, V: Clone + PartialEq, T: KvMap<K, V>> Diff<K, V> for T {
|
||||
type DiffType = KvMapDiff<K, V>;
|
||||
|
||||
fn diff(&self, other: &T) -> Self::DiffType {
|
||||
let mut diff = KvMapDiff::default();
|
||||
for (k, v) in self.iter() {
|
||||
if let Some(other_value) = other.get(k) {
|
||||
if v != other_value {
|
||||
diff.updated.insert(k.clone(), other_value.clone());
|
||||
}
|
||||
} else {
|
||||
diff.removed.insert(k.clone());
|
||||
}
|
||||
}
|
||||
for (k, v) in other.iter() {
|
||||
if self.get(k).is_none() {
|
||||
diff.updated.insert(k.clone(), v.clone());
|
||||
}
|
||||
}
|
||||
diff
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Ord + Clone, V: Clone, T: KvMap<K, V>> ApplyDiff<K, V> for T {
|
||||
type DiffType = KvMapDiff<K, V>;
|
||||
|
||||
fn apply(&mut self, diff: Self::DiffType) {
|
||||
for (k, v) in diff.updated {
|
||||
self.insert(k, v);
|
||||
}
|
||||
for k in diff.removed {
|
||||
self.remove(&k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TESTS
|
||||
// ================================================================================================
|
||||
|
||||
@@ -401,4 +470,35 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_kv_map_diff() {
|
||||
let mut initial_state = ITEMS.into_iter().collect::<BTreeMap<_, _>>();
|
||||
let mut map = RecordingMap::new(initial_state.clone());
|
||||
|
||||
// remove an item that exists
|
||||
let key = 0;
|
||||
let _value = map.remove(&key).unwrap();
|
||||
|
||||
// add a new item
|
||||
let key = 100;
|
||||
let value = 100;
|
||||
map.insert(key, value);
|
||||
|
||||
// update an existing item
|
||||
let key = 1;
|
||||
let value = 100;
|
||||
map.insert(key, value);
|
||||
|
||||
// compute a diff
|
||||
let diff = initial_state.diff(map.inner());
|
||||
assert!(diff.updated.len() == 2);
|
||||
assert!(diff.updated.iter().all(|(k, v)| [(100, 100), (1, 100)].contains(&(*k, *v))));
|
||||
assert!(diff.removed.len() == 1);
|
||||
assert!(diff.removed.first() == Some(&0));
|
||||
|
||||
// apply the diff to the initial state and assert the contents are the same as the map
|
||||
initial_state.apply(diff);
|
||||
assert!(initial_state.iter().eq(map.iter()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ pub use alloc::{format, vec};
|
||||
#[cfg(feature = "std")]
|
||||
pub use std::{format, vec};
|
||||
|
||||
mod diff;
|
||||
mod kv_map;
|
||||
|
||||
// RE-EXPORTS
|
||||
@@ -19,6 +20,7 @@ pub use winter_utils::{
|
||||
};
|
||||
|
||||
pub mod collections {
|
||||
pub use super::diff::*;
|
||||
pub use super::kv_map::*;
|
||||
pub use winter_utils::collections::*;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user