40 Commits

Author SHA1 Message Date
Al-Kindi-0
0f06fa30a9 minor nits 2024-01-22 19:42:09 +01:00
Grzegorz Swirski
0b074a795d feat: use AVX2 instructions whenever available 2024-01-22 19:42:09 +01:00
Bobbin Threadbare
862ccf54dd Merge pull request #234 from reilabs/avx
feat: implement RPO hash using AVX2 instructions
2024-01-04 10:48:52 -08:00
Grzegorz Swirski
88bcdfd576 feat: use AVX2 instructions whenever available 2024-01-04 19:08:43 +01:00
Bobbin Threadbare
290894f497 Merge pull request #242 from 0xPolygonMiden/bobbin-partial-mmr-apply
Improvements to `PartialMmr::apply_delta()`
2023-12-24 13:58:03 -08:00
Bobbin Threadbare
4aac00884c fix: bugfix in PartialMmr apply delta 2023-12-23 20:38:08 -08:00
Bobbin Threadbare
2ef6f79656 Merge pull request #241 from 0xPolygonMiden/al-export-default-randcoin
Export default randomcoin
2023-12-21 11:21:56 -08:00
Al-Kindi-0
5142e2fd31 chore: export default Winterfell randomcoin 2023-12-21 14:26:23 +01:00
Bobbin Threadbare
9fb41337ec feat: add Clone derive to PartialMmr 2023-12-21 01:24:20 -08:00
Bobbin Threadbare
0296e05ccd refactor: return MmrPeaks from PartialMmr::peaks() 2023-12-21 01:00:52 -08:00
Bobbin Threadbare
499f97046d fix: typos 2023-12-21 00:17:41 -08:00
Bobbin Threadbare
600feafe53 feat: implement inner_nodes() iterator for PartialMmr 2023-12-21 00:16:36 -08:00
Bobbin Threadbare
9d854f1fcb feat: add serialization to RpoRandomCoin 2023-12-21 00:15:46 -08:00
Al-Kindi-0
af76cb10d0 feat: move RpoRandomCoin and define Rng trait
nits: minor

chore: update log and readme
2023-12-21 00:15:46 -08:00
Augusto F. Hack
4758e0672f serde: for MerklePath, ValuePath, and RootPath 2023-12-21 00:15:46 -08:00
Philippe Laferrière
8bb080a91d Implement SimpleSmt::set_subtree (#232)
* recompute_nodes_from_index_to_root

* MerkleError variant

* set_subtree

* test_simplesmt_set_subtree

* test_simplesmt_set_subtree_entire_tree

* test

* set_subtree: return root
2023-12-21 00:15:46 -08:00
Augusto F. Hack
e5f3b28645 bugfix: TSMT failed to verify empty word for depth 64.
When a prefix is pushed to the depth 64, the entry list includes only
the values different than ZERO. This is required, since each block
represents 2^192 values.

The bug was in the proof membership code, that failed to handle the case
of a key that was not in the list, because the depth is 64 and the value
was not set.
2023-12-21 00:15:46 -08:00
Philippe Laferrière
29e0d07129 MmrPeaks::hash_peaks() returns Digest (#230) 2023-12-21 00:15:46 -08:00
Philippe Laferrière
81a94ecbe7 Remove ExactSizeIterator constraint from SimpleSmt::with_leaves() (#228)
* Change InvalidNumEntries error

* max computation

* remove length check

* remove ExactSizeIterator constraint

* fix InvalidNumEntries error condition

* 2_usize
2023-12-21 00:15:46 -08:00
Augusto F. Hack
223fbf887d simplesmt: simplify duplicate check 2023-12-21 00:15:46 -08:00
Philippe Laferrière
9e77a7c9b7 Introduce SimpleSmt::with_contiguous_leaves() (#227)
* with_contiguous_leaves

* test
2023-12-21 00:15:46 -08:00
Augusto F. Hack
894e20fe0c simplesmt: bugfix, index must be validated before modifying the tree 2023-12-21 00:15:46 -08:00
Austin Abell
7ec7b06574 feat: memoize Signature polynomial decoding 2023-12-21 00:15:46 -08:00
Philippe Laferriere
2499a8a2dd Consuming iterator for RpoDigest 2023-12-21 00:15:46 -08:00
Augusto F. Hack
800994c69b mmr: add into_parts for the peaks 2023-12-21 00:15:46 -08:00
Augusto F. Hack
26560605bf simple_smt: reduce serialized size, use static hashes of the empty word 2023-12-21 00:15:46 -08:00
Augusto F. Hack
672340d0c2 mmr: support accumulator of older forest versions 2023-12-21 00:15:46 -08:00
Bobbin Threadbare
8083b02aef chore: update changelog 2023-12-21 00:15:46 -08:00
Al-Kindi-0
ecb8719d45 chore: bump winterfell release to .7 2023-12-21 00:15:46 -08:00
Bobbin Threadbare
4144f98560 docs: update bench readme 2023-12-21 00:15:46 -08:00
Augusto F. Hack
c726050957 mmr: support proofs with older forest versions 2023-12-21 00:15:46 -08:00
Augusto F. Hack
9239340888 mmr: support arbitrary from/to delta updates 2023-12-21 00:15:46 -08:00
Augusto F. Hack
97ee9298a4 mmr: publicly export MmrDelta 2023-12-21 00:15:46 -08:00
Bobbin Threadbare
bfae06e128 docs: update changelog 2023-12-21 00:15:46 -08:00
Al-Kindi-0
b4e2d63c10 docs: added RPX benchmarks 2023-12-21 00:15:46 -08:00
Al-Kindi-0
9679329746 feat: RPX (xHash12) hash function implementation 2023-12-21 00:15:45 -08:00
Augusto F. Hack
2bbea37dbe rpo: added conversions for digest 2023-12-21 00:14:28 -08:00
Bobbin Threadbare
83000940da chore: update main readme 2023-12-21 00:14:28 -08:00
Augusto F. Hack
f44175e7a9 config: add .editorconfig 2023-12-21 00:14:28 -08:00
Bobbin Threadbare
4cf8eebff5 chore: update crate version to v0.8 2023-12-21 00:14:28 -08:00
67 changed files with 6074 additions and 2686 deletions

View File

@@ -4,74 +4,46 @@ on:
branches:
- main
pull_request:
types: [opened, reopened, synchronize]
types: [opened, reopened, synchronize]
jobs:
rustfmt:
name: rustfmt ${{matrix.toolchain}} on ${{matrix.os}}
runs-on: ${{matrix.os}}-latest
strategy:
fail-fast: false
matrix:
toolchain: [nightly]
os: [ubuntu]
steps:
- uses: actions/checkout@v4
- name: Install minimal Rust with rustfmt
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: ${{matrix.toolchain}}
components: rustfmt
override: true
- name: fmt
uses: actions-rs/cargo@v1
with:
command: fmt
args: --all -- --check
clippy:
name: clippy ${{matrix.toolchain}} on ${{matrix.os}}
runs-on: ${{matrix.os}}-latest
strategy:
fail-fast: false
matrix:
toolchain: [nightly]
os: [ubuntu]
steps:
- uses: actions/checkout@v4
with:
submodules: recursive
- name: Install minimal Rust with clippy
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: ${{matrix.toolchain}}
components: clippy
override: true
- name: Clippy
uses: actions-rs/cargo@v1
with:
command: clippy
args: --all-targets -- -D clippy::all -D warnings
- name: Clippy all features
uses: actions-rs/cargo@v1
with:
command: clippy
args: --all-targets --all-features -- -D clippy::all -D warnings
test:
name: test ${{matrix.toolchain}} on ${{matrix.os}} with ${{matrix.features}}
build:
name: Build ${{matrix.toolchain}} on ${{matrix.os}} with ${{matrix.args}}
runs-on: ${{matrix.os}}-latest
strategy:
fail-fast: false
matrix:
toolchain: [stable, nightly]
os: [ubuntu]
features: ["--features default,serde", --no-default-features]
timeout-minutes: 30
target: [wasm32-unknown-unknown]
args: [--no-default-features --target wasm32-unknown-unknown]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@main
with:
submodules: recursive
- name: Install rust
uses: actions-rs/toolchain@v1
with:
toolchain: ${{matrix.toolchain}}
override: true
- run: rustup target add ${{matrix.target}}
- name: Test
uses: actions-rs/cargo@v1
with:
command: build
args: ${{matrix.args}}
test:
name: Test ${{matrix.toolchain}} on ${{matrix.os}} with ${{matrix.features}}
runs-on: ${{matrix.os}}-latest
strategy:
fail-fast: false
matrix:
toolchain: [stable, nightly]
os: [ubuntu]
features: ["--features default,std,serde", --no-default-features]
steps:
- uses: actions/checkout@main
with:
submodules: recursive
- name: Install rust
@@ -85,49 +57,45 @@ jobs:
command: test
args: ${{matrix.features}}
no-std:
name: build ${{matrix.toolchain}} no-std for wasm32-unknown-unknown
clippy:
name: Clippy with ${{matrix.features}}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
toolchain: [stable, nightly]
features: ["--features default,std,serde", --no-default-features]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@main
with:
submodules: recursive
- name: Install rust
- name: Install minimal nightly with clippy
uses: actions-rs/toolchain@v1
with:
toolchain: ${{matrix.toolchain}}
profile: minimal
toolchain: nightly
components: clippy
override: true
- run: rustup target add wasm32-unknown-unknown
- name: Build
- name: Clippy
uses: actions-rs/cargo@v1
with:
command: build
args: --no-default-features --target wasm32-unknown-unknown
command: clippy
args: --all ${{matrix.features}} -- -D clippy::all -D warnings
docs:
name: Verify the docs on ${{matrix.toolchain}}
rustfmt:
name: rustfmt
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
toolchain: [stable]
steps:
- uses: actions/checkout@v4
with:
submodules: recursive
- name: Install rust
- uses: actions/checkout@main
- name: Install minimal stable with rustfmt
uses: actions-rs/toolchain@v1
with:
toolchain: ${{matrix.toolchain}}
profile: minimal
toolchain: stable
components: rustfmt
override: true
- name: Check docs
- name: rustfmt
uses: actions-rs/cargo@v1
env:
RUSTDOCFLAGS: -D warnings
with:
command: doc
args: --verbose --all-features --keep-going
command: fmt
args: --all -- --check

3
.gitignore vendored
View File

@@ -11,6 +11,3 @@ Cargo.lock
# Generated by cmake
cmake-build-*
# VS Code
.vscode/

View File

@@ -1,14 +1,10 @@
## 0.8.0 (2024-02-14)
## 0.8.0 (TBD)
* Implemented the `PartialMmr` data structure (#195).
* Updated Winterfell dependency to v0.7 (#200)
* Implemented RPX hash function (#201).
* Added `FeltRng` and `RpoRandomCoin` (#237).
* Accelerated RPO/RPX hash functions using AVX2 instructions (#234).
* Added `inner_nodes()` method to `PartialMmr` (#238).
* Improved `PartialMmr::apply_delta()` (#242).
* Refactored `SimpleSmt` struct (#245).
* Replaced `TieredSmt` struct with `Smt` struct (#254, #277).
* Updated Winterfell dependency to v0.8 (#275).
## 0.7.1 (2023-10-10)

View File

@@ -72,7 +72,7 @@ For example, a new change to the AIR crate might have the following message: `fe
// ================================================================================
```
- [Rustfmt](https://github.com/rust-lang/rustfmt) and [Clippy](https://github.com/rust-lang/rust-clippy) linting is included in CI pipeline. Anyways it's preferable to run linting locally before push:
- [Rustfmt](https://github.com/rust-lang/rustfmt) and [Clippy](https://github.com/rust-lang/rust-clippy) linting is included in CI pipeline. Anyways it's preferable to run linting locally before push:
```
cargo fix --allow-staged --allow-dirty --all-targets --all-features; cargo fmt; cargo clippy --workspace --all-targets --all-features -- -D warnings
```

View File

@@ -10,7 +10,7 @@ documentation = "https://docs.rs/miden-crypto/0.8.0"
categories = ["cryptography", "no-std"]
keywords = ["miden", "crypto", "hash", "merkle"]
edition = "2021"
rust-version = "1.75"
rust-version = "1.73"
[[bin]]
name = "miden-crypto"
@@ -35,30 +35,26 @@ harness = false
default = ["std"]
executable = ["dep:clap", "dep:rand_utils", "std"]
serde = ["dep:serde", "serde?/alloc", "winter_math/serde"]
std = [
"blake3/std",
"dep:cc",
"dep:libc",
"winter_crypto/std",
"winter_math/std",
"winter_utils/std",
]
std = ["blake3/std", "dep:cc", "dep:libc", "winter_crypto/std", "winter_math/std", "winter_utils/std"]
sve = ["std"]
[dependencies]
blake3 = { version = "1.5", default-features = false }
clap = { version = "4.5", features = ["derive"], optional = true }
libc = { version = "0.2", default-features = false, optional = true }
rand_utils = { version = "0.8", package = "winter-rand-utils", optional = true }
serde = { version = "1.0", features = ["derive"], default-features = false, optional = true }
winter_crypto = { version = "0.8", package = "winter-crypto", default-features = false }
winter_math = { version = "0.8", package = "winter-math", default-features = false }
winter_utils = { version = "0.8", package = "winter-utils", default-features = false }
clap = { version = "4.4", features = ["derive"], optional = true }
libc = { version = "0.2", default-features = false, optional = true }
rand_utils = { version = "0.7", package = "winter-rand-utils", optional = true }
serde = { version = "1.0", features = [ "derive" ], default-features = false, optional = true }
winter_crypto = { version = "0.7", package = "winter-crypto", default-features = false }
winter_math = { version = "0.7", package = "winter-math", default-features = false }
winter_utils = { version = "0.7", package = "winter-utils", default-features = false }
rayon = "1.8.0"
rand = "0.8.4"
rand_core = { version = "0.5", default-features = false }
[dev-dependencies]
seq-macro = { version = "0.3" }
criterion = { version = "0.5", features = ["html_reports"] }
proptest = "1.4"
rand_utils = { version = "0.8", package = "winter-rand-utils" }
proptest = "1.3"
rand_utils = { version = "0.7", package = "winter-rand-utils" }
[build-dependencies]
cc = { version = "1.0", features = ["parallel"], optional = true }

View File

@@ -19,7 +19,7 @@ For performance benchmarks of these hash functions and their comparison to other
* `PartialMerkleTree`: a partial view of a Merkle tree where some sub-trees may not be known. This is similar to a collection of Merkle paths all resolving to the same root. The length of the paths can be at most 64.
* `PartialMmr`: a partial view of a Merkle mountain range structure.
* `SimpleSmt`: a Sparse Merkle Tree (with no compaction), mapping 64-bit keys to 4-element values.
* `Smt`: a Sparse Merkle tree (with compaction at depth 64), mapping 4-element keys to 4-element values.
* `TieredSmt`: a Sparse Merkle tree (with compaction), mapping 4-element keys to 4-element values.
The module also contains additional supporting components such as `NodeIndex`, `MerklePath`, and `MerkleError` to assist with tree indexation, opening proofs, and reporting inconsistent arguments/state.
@@ -46,16 +46,10 @@ Both of these features imply the use of [alloc](https://doc.rust-lang.org/alloc/
To compile with `no_std`, disable default features via `--no-default-features` flag.
### AVX2 acceleration
On platforms with [AVX2](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) support, RPO and RPX hash function can be accelerated by using the vector processing unit. To enable AVX2 acceleration, the code needs to be compiled with the `avx2` target feature enabled. For example:
```shell
RUSTFLAGS="-C target-feature=+avx2" cargo build --release
```
### SVE acceleration
On platforms with [SVE](https://en.wikipedia.org/wiki/AArch64#Scalable_Vector_Extension_(SVE)) support, RPO and RPX hash function can be accelerated by using the vector processing unit. To enable SVE acceleration, the code needs to be compiled with the `sve` target feature enabled. For example:
On platforms with [SVE](https://en.wikipedia.org/wiki/AArch64#Scalable_Vector_Extension_(SVE)) support, RPO hash function can be accelerated by using the vector processing unit. To enable SVE acceleration, the code needs to be compiled with the `sve` feature enabled. This feature has an effect only if the platform exposes `target-feature=sve` flag. On some platforms (e.g., Graviton 3), for this flag to be set, the compilation must be done in "native" mode. For example, to enable SVE acceleration on Graviton 3, we can execute the following:
```shell
RUSTFLAGS="-C target-feature=+sve" cargo build --release
RUSTFLAGS="-C target-cpu=native" cargo build --release --features sve
```
## Testing

View File

@@ -22,7 +22,6 @@ The second scenario is that of sequential hashing where we take a sequence of le
| Apple M2 Max | 71 ns | 233 ns | 1.3 µs | 7.9 µs | 4.6 µs | 2.4 µs |
| Amazon Graviton 3 | 108 ns | | | | 5.3 µs | 3.1 µs |
| AMD Ryzen 9 5950X | 64 ns | 273 ns | 1.2 µs | 9.1 µs | 5.5 µs | |
| AMD EPYC 9R14 | 83 ns | | | | 4.3 µs | 2.4 µs |
| Intel Core i5-8279U | 68 ns | 536 ns | 2.0 µs | 13.6 µs | 8.5 µs | 4.4 µs |
| Intel Xeon 8375C | 67 ns | | | | 8.2 µs | |
@@ -34,13 +33,11 @@ The second scenario is that of sequential hashing where we take a sequence of le
| Apple M2 Max | 0.9 µs | 1.5 µs | 17.4 µs | 103 µs | 60 µs | 31 µs |
| Amazon Graviton 3 | 1.4 µs | | | | 69 µs | 41 µs |
| AMD Ryzen 9 5950X | 0.8 µs | 1.7 µs | 15.7 µs | 120 µs | 72 µs | |
| AMD EPYC 9R14 | 0.9 µs | | | | 56 µs | 32 µs |
| Intel Core i5-8279U | 0.9 µs | | | | 107 µs | 56 µs |
| Intel Xeon 8375C | 0.8 µs | | | | 110 µs | |
Notes:
- On Graviton 3, RPO256 and RPX256 are run with SVE acceleration enabled.
- On AMD EPYC 9R14, RPO256 and RPX256 are run with AVX2 acceleration enabled.
### Instructions
Before you can run the benchmarks, you'll need to make sure you have Rust [installed](https://www.rust-lang.org/tools/install). After that, to run the benchmarks for RPO and BLAKE3, clone the current repository, and from the root directory of the repo run the following:

View File

@@ -32,6 +32,7 @@ fn rpo256_2to1(c: &mut Criterion) {
fn rpo256_sequential(c: &mut Criterion) {
let v: [Felt; 100] = (0..100)
.into_iter()
.map(Felt::new)
.collect::<Vec<Felt>>()
.try_into()
@@ -44,6 +45,7 @@ fn rpo256_sequential(c: &mut Criterion) {
bench.iter_batched(
|| {
let v: [Felt; 100] = (0..100)
.into_iter()
.map(|_| Felt::new(rand_value()))
.collect::<Vec<Felt>>()
.try_into()
@@ -78,6 +80,7 @@ fn rpx256_2to1(c: &mut Criterion) {
fn rpx256_sequential(c: &mut Criterion) {
let v: [Felt; 100] = (0..100)
.into_iter()
.map(Felt::new)
.collect::<Vec<Felt>>()
.try_into()
@@ -90,6 +93,7 @@ fn rpx256_sequential(c: &mut Criterion) {
bench.iter_batched(
|| {
let v: [Felt; 100] = (0..100)
.into_iter()
.map(|_| Felt::new(rand_value()))
.collect::<Vec<Felt>>()
.try_into()
@@ -125,6 +129,7 @@ fn blake3_2to1(c: &mut Criterion) {
fn blake3_sequential(c: &mut Criterion) {
let v: [Felt; 100] = (0..100)
.into_iter()
.map(Felt::new)
.collect::<Vec<Felt>>()
.try_into()
@@ -137,6 +142,7 @@ fn blake3_sequential(c: &mut Criterion) {
bench.iter_batched(
|| {
let v: [Felt; 100] = (0..100)
.into_iter()
.map(|_| Felt::new(rand_value()))
.collect::<Vec<Felt>>()
.try_into()

View File

@@ -1,20 +1,16 @@
use core::mem::swap;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use miden_crypto::{
merkle::{LeafIndex, SimpleSmt},
Felt, Word,
};
use miden_crypto::{merkle::SimpleSmt, Felt, Word};
use rand_utils::prng_array;
use seq_macro::seq;
fn smt_rpo(c: &mut Criterion) {
// setup trees
let mut seed = [0u8; 32];
let leaf = generate_word(&mut seed);
let mut trees = vec![];
seq!(DEPTH in 14..=20 {
let leaves = ((1 << DEPTH) - 1) as u64;
for depth in 14..=20 {
let leaves = ((1 << depth) - 1) as u64;
for count in [1, leaves / 2, leaves] {
let entries: Vec<_> = (0..count)
.map(|i| {
@@ -22,45 +18,50 @@ fn smt_rpo(c: &mut Criterion) {
(i, word)
})
.collect();
let mut tree = SimpleSmt::<DEPTH>::with_leaves(entries).unwrap();
// benchmark 1
let mut insert = c.benchmark_group("smt update_leaf".to_string());
{
let depth = DEPTH;
let key = count >> 2;
insert.bench_with_input(
format!("simple smt(depth:{depth},count:{count})"),
&(key, leaf),
|b, (key, leaf)| {
b.iter(|| {
tree.insert(black_box(LeafIndex::<DEPTH>::new(*key).unwrap()), black_box(*leaf));
});
},
);
}
insert.finish();
// benchmark 2
let mut path = c.benchmark_group("smt get_leaf_path".to_string());
{
let depth = DEPTH;
let key = count >> 2;
path.bench_with_input(
format!("simple smt(depth:{depth},count:{count})"),
&key,
|b, key| {
b.iter(|| {
tree.open(black_box(&LeafIndex::<DEPTH>::new(*key).unwrap()));
});
},
);
}
path.finish();
let tree = SimpleSmt::with_leaves(depth, entries).unwrap();
trees.push((tree, count));
}
});
}
let leaf = generate_word(&mut seed);
// benchmarks
let mut insert = c.benchmark_group(format!("smt update_leaf"));
for (tree, count) in trees.iter_mut() {
let depth = tree.depth();
let key = *count >> 2;
insert.bench_with_input(
format!("simple smt(depth:{depth},count:{count})"),
&(key, leaf),
|b, (key, leaf)| {
b.iter(|| {
tree.update_leaf(black_box(*key), black_box(*leaf)).unwrap();
});
},
);
}
insert.finish();
let mut path = c.benchmark_group(format!("smt get_leaf_path"));
for (tree, count) in trees.iter_mut() {
let depth = tree.depth();
let key = *count >> 2;
path.bench_with_input(
format!("simple smt(depth:{depth},count:{count})"),
&key,
|b, key| {
b.iter(|| {
tree.get_leaf_path(black_box(*key)).unwrap();
});
},
);
}
path.finish();
}
criterion_group!(smt_group, smt_rpo);

View File

@@ -1,7 +1,5 @@
use criterion::{black_box, criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use miden_crypto::merkle::{
DefaultMerkleStore as MerkleStore, LeafIndex, MerkleTree, NodeIndex, SimpleSmt, SMT_MAX_DEPTH,
};
use miden_crypto::merkle::{DefaultMerkleStore as MerkleStore, MerkleTree, NodeIndex, SimpleSmt};
use miden_crypto::Word;
use miden_crypto::{hash::rpo::RpoDigest, Felt};
use rand_utils::{rand_array, rand_value};
@@ -17,7 +15,7 @@ fn random_rpo_digest() -> RpoDigest {
/// Generates a random `Word`.
fn random_word() -> Word {
rand_array::<Felt, 4>()
rand_array::<Felt, 4>().into()
}
/// Generates an index at the specified depth in `0..range`.
@@ -30,26 +28,26 @@ fn random_index(range: u64, depth: u8) -> NodeIndex {
fn get_empty_leaf_simplesmt(c: &mut Criterion) {
let mut group = c.benchmark_group("get_empty_leaf_simplesmt");
const DEPTH: u8 = SMT_MAX_DEPTH;
let depth = SimpleSmt::MAX_DEPTH;
let size = u64::MAX;
// both SMT and the store are pre-populated with empty hashes, accessing these values is what is
// being benchmarked here, so no values are inserted into the backends
let smt = SimpleSmt::<DEPTH>::new().unwrap();
let smt = SimpleSmt::new(depth).unwrap();
let store = MerkleStore::from(&smt);
let root = smt.root();
group.bench_function(BenchmarkId::new("SimpleSmt", DEPTH), |b| {
group.bench_function(BenchmarkId::new("SimpleSmt", depth), |b| {
b.iter_batched(
|| random_index(size, DEPTH),
|| random_index(size, depth),
|index| black_box(smt.get_node(index)),
BatchSize::SmallInput,
)
});
group.bench_function(BenchmarkId::new("MerkleStore", DEPTH), |b| {
group.bench_function(BenchmarkId::new("MerkleStore", depth), |b| {
b.iter_batched(
|| random_index(size, DEPTH),
|| random_index(size, depth),
|index| black_box(store.get_node(root, index)),
BatchSize::SmallInput,
)
@@ -106,14 +104,15 @@ fn get_leaf_simplesmt(c: &mut Criterion) {
.enumerate()
.map(|(c, v)| (c.try_into().unwrap(), v.into()))
.collect::<Vec<(u64, Word)>>();
let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap();
let store = MerkleStore::from(&smt);
let depth = smt.depth();
let root = smt.root();
let size_u64 = size as u64;
group.bench_function(BenchmarkId::new("SimpleSmt", size), |b| {
b.iter_batched(
|| random_index(size_u64, SMT_MAX_DEPTH),
|| random_index(size_u64, depth),
|index| black_box(smt.get_node(index)),
BatchSize::SmallInput,
)
@@ -121,7 +120,7 @@ fn get_leaf_simplesmt(c: &mut Criterion) {
group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
b.iter_batched(
|| random_index(size_u64, SMT_MAX_DEPTH),
|| random_index(size_u64, depth),
|index| black_box(store.get_node(root, index)),
BatchSize::SmallInput,
)
@@ -133,18 +132,18 @@ fn get_leaf_simplesmt(c: &mut Criterion) {
fn get_node_of_empty_simplesmt(c: &mut Criterion) {
let mut group = c.benchmark_group("get_node_of_empty_simplesmt");
const DEPTH: u8 = SMT_MAX_DEPTH;
let depth = SimpleSmt::MAX_DEPTH;
// both SMT and the store are pre-populated with the empty hashes, accessing the internal nodes
// of these values is what is being benchmarked here, so no values are inserted into the
// backends.
let smt = SimpleSmt::<DEPTH>::new().unwrap();
let smt = SimpleSmt::new(depth).unwrap();
let store = MerkleStore::from(&smt);
let root = smt.root();
let half_depth = DEPTH / 2;
let half_depth = depth / 2;
let half_size = 2_u64.pow(half_depth as u32);
group.bench_function(BenchmarkId::new("SimpleSmt", DEPTH), |b| {
group.bench_function(BenchmarkId::new("SimpleSmt", depth), |b| {
b.iter_batched(
|| random_index(half_size, half_depth),
|index| black_box(smt.get_node(index)),
@@ -152,7 +151,7 @@ fn get_node_of_empty_simplesmt(c: &mut Criterion) {
)
});
group.bench_function(BenchmarkId::new("MerkleStore", DEPTH), |b| {
group.bench_function(BenchmarkId::new("MerkleStore", depth), |b| {
b.iter_batched(
|| random_index(half_size, half_depth),
|index| black_box(store.get_node(root, index)),
@@ -213,10 +212,10 @@ fn get_node_simplesmt(c: &mut Criterion) {
.enumerate()
.map(|(c, v)| (c.try_into().unwrap(), v.into()))
.collect::<Vec<(u64, Word)>>();
let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap();
let store = MerkleStore::from(&smt);
let root = smt.root();
let half_depth = SMT_MAX_DEPTH / 2;
let half_depth = smt.depth() / 2;
let half_size = 2_u64.pow(half_depth as u32);
group.bench_function(BenchmarkId::new("SimpleSmt", size), |b| {
@@ -287,24 +286,23 @@ fn get_leaf_path_simplesmt(c: &mut Criterion) {
.enumerate()
.map(|(c, v)| (c.try_into().unwrap(), v.into()))
.collect::<Vec<(u64, Word)>>();
let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap();
let store = MerkleStore::from(&smt);
let depth = smt.depth();
let root = smt.root();
let size_u64 = size as u64;
group.bench_function(BenchmarkId::new("SimpleSmt", size), |b| {
b.iter_batched(
|| random_index(size_u64, SMT_MAX_DEPTH),
|index| {
black_box(smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(index.value()).unwrap()))
},
|| random_index(size_u64, depth),
|index| black_box(smt.get_path(index)),
BatchSize::SmallInput,
)
});
group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
b.iter_batched(
|| random_index(size_u64, SMT_MAX_DEPTH),
|| random_index(size_u64, depth),
|index| black_box(store.get_path(root, index)),
BatchSize::SmallInput,
)
@@ -354,7 +352,7 @@ fn new(c: &mut Criterion) {
.map(|(c, v)| (c.try_into().unwrap(), v.into()))
.collect::<Vec<(u64, Word)>>()
},
|l| black_box(SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(l)),
|l| black_box(SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, l)),
BatchSize::SmallInput,
)
});
@@ -369,7 +367,7 @@ fn new(c: &mut Criterion) {
.collect::<Vec<(u64, Word)>>()
},
|l| {
let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(l).unwrap();
let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, l).unwrap();
black_box(MerkleStore::from(&smt));
},
BatchSize::SmallInput,
@@ -435,17 +433,16 @@ fn update_leaf_simplesmt(c: &mut Criterion) {
.enumerate()
.map(|(c, v)| (c.try_into().unwrap(), v.into()))
.collect::<Vec<(u64, Word)>>();
let mut smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
let mut smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap();
let mut store = MerkleStore::from(&smt);
let depth = smt.depth();
let root = smt.root();
let size_u64 = size as u64;
group.bench_function(BenchmarkId::new("SimpleSMT", size), |b| {
b.iter_batched(
|| (rand_value::<u64>() % size_u64, random_word()),
|(index, value)| {
black_box(smt.insert(LeafIndex::<SMT_MAX_DEPTH>::new(index).unwrap(), value))
},
|(index, value)| black_box(smt.update_leaf(index, value)),
BatchSize::SmallInput,
)
});
@@ -453,7 +450,7 @@ fn update_leaf_simplesmt(c: &mut Criterion) {
let mut store_root = root;
group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
b.iter_batched(
|| (random_index(size_u64, SMT_MAX_DEPTH), random_word()),
|| (random_index(size_u64, depth), random_word()),
|(index, value)| {
// The MerkleTree automatically updates its internal root, the Store maintains
// the old root and adds the new one. Here we update the root to have a fair

View File

@@ -2,7 +2,7 @@ fn main() {
#[cfg(feature = "std")]
compile_rpo_falcon();
#[cfg(target_feature = "sve")]
#[cfg(all(target_feature = "sve", feature = "sve"))]
compile_arch_arm64_sve();
}
@@ -34,7 +34,7 @@ fn compile_rpo_falcon() {
.compile("rpo_falcon512");
}
#[cfg(target_feature = "sve")]
#[cfg(all(target_feature = "sve", feature = "sve"))]
fn compile_arch_arm64_sve() {
const RPO_SVE_PATH: &str = "arch/arm64-sve/rpo";

View File

@@ -4,7 +4,7 @@ use super::{
};
#[cfg(feature = "std")]
use super::{ffi, NonceBytes, NONCE_LEN, PK_LEN, SIG_LEN, SK_LEN};
use super::{ffi, NonceBytes, StarkField, NONCE_LEN, PK_LEN, SIG_LEN, SK_LEN};
// PUBLIC KEY
// ================================================================================================

View File

@@ -4,7 +4,7 @@ use crate::{
collections::Vec, ByteReader, ByteWriter, Deserializable, DeserializationError,
Serializable,
},
Felt, Word, ZERO,
Felt, StarkField, Word, ZERO,
};
#[cfg(feature = "std")]
@@ -39,10 +39,10 @@ const NONCE_LEN: usize = 40;
const NONCE_ELEMENTS: usize = 8;
/// Public key length as a u8 vector.
pub const PK_LEN: usize = 897;
const PK_LEN: usize = 897;
/// Secret key length as a u8 vector.
pub const SK_LEN: usize = 1281;
const SK_LEN: usize = 1281;
/// Signature length as a u8 vector.
const SIG_LEN: usize = 626;

View File

@@ -4,7 +4,7 @@ use core::ops::{Add, Mul, Sub};
// FALCON POLYNOMIAL
// ================================================================================================
/// A polynomial over Z_p\[x\]/(phi) where phi := x^512 + 1
/// A polynomial over Z_p[x]/(phi) where phi := x^512 + 1
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Polynomial([u16; N]);
@@ -24,7 +24,7 @@ impl Polynomial {
Self(data)
}
/// Decodes raw bytes representing a public key into a polynomial in Z_p\[x\]/(phi).
/// Decodes raw bytes representing a public key into a polynomial in Z_p[x]/(phi).
///
/// # Errors
/// Returns an error if:
@@ -69,14 +69,14 @@ impl Polynomial {
}
}
/// Decodes the signature into the coefficients of a polynomial in Z_p\[x\]/(phi). It assumes
/// Decodes the signature into the coefficients of a polynomial in Z_p[x]/(phi). It assumes
/// that the signature has been encoded using the uncompressed format.
///
/// # Errors
/// Returns an error if:
/// - The signature has been encoded using a different algorithm than the reference compressed
/// encoding algorithm.
/// - The encoded signature polynomial is in Z_p\[x\]/(phi') where phi' = x^N' + 1 and N' != 512.
/// - The encoded signature polynomial is in Z_p[x]/(phi') where phi' = x^N' + 1 and N' != 512.
/// - While decoding the high bits of a coefficient, the current accumulated value of its
/// high bits is larger than 2048.
/// - The decoded coefficient is -0.
@@ -149,12 +149,12 @@ impl Polynomial {
// POLYNOMIAL OPERATIONS
// --------------------------------------------------------------------------------------------
/// Multiplies two polynomials over Z_p\[x\] without reducing modulo p. Given that the degrees
/// Multiplies two polynomials over Z_p[x] without reducing modulo p. Given that the degrees
/// of the input polynomials are less than 512 and their coefficients are less than the modulus
/// q equal to 12289, the resulting product polynomial is guaranteed to have coefficients less
/// than the Miden prime.
///
/// Note that this multiplication is not over Z_p\[x\]/(phi).
/// Note that this multiplication is not over Z_p[x]/(phi).
pub fn mul_modulo_p(a: &Self, b: &Self) -> [u64; 1024] {
let mut c = [0; 2 * N];
for i in 0..N {
@@ -166,8 +166,8 @@ impl Polynomial {
c
}
/// Reduces a polynomial, that is the product of two polynomials over Z_p\[x\], modulo
/// the irreducible polynomial phi. This results in an element in Z_p\[x\]/(phi).
/// Reduces a polynomial, that is the product of two polynomials over Z_p[x], modulo
/// the irreducible polynomial phi. This results in an element in Z_p[x]/(phi).
pub fn reduce_negacyclic(a: &[u64; 1024]) -> Self {
let mut c = [0; N];
for i in 0..N {
@@ -181,7 +181,7 @@ impl Polynomial {
Self(c)
}
/// Computes the norm squared of a polynomial in Z_p\[x\]/(phi) after normalizing its
/// Computes the norm squared of a polynomial in Z_p[x]/(phi) after normalizing its
/// coefficients to be in the interval (-p/2, p/2].
pub fn sq_norm(&self) -> u64 {
let mut res = 0;
@@ -203,7 +203,7 @@ impl Default for Polynomial {
}
}
/// Multiplication over Z_p\[x\]/(phi)
/// Multiplication over Z_p[x]/(phi)
impl Mul for Polynomial {
type Output = Self;
@@ -227,7 +227,7 @@ impl Mul for Polynomial {
}
}
/// Addition over Z_p\[x\]/(phi)
/// Addition over Z_p[x]/(phi)
impl Add for Polynomial {
type Output = Self;
@@ -239,7 +239,7 @@ impl Add for Polynomial {
}
}
/// Subtraction over Z_p\[x\]/(phi)
/// Subtraction over Z_p[x]/(phi)
impl Sub for Polynomial {
type Output = Self;

View File

@@ -1,6 +1,6 @@
use super::{
ByteReader, ByteWriter, Deserializable, DeserializationError, Felt, NonceBytes, NonceElements,
Polynomial, PublicKeyBytes, Rpo256, Serializable, SignatureBytes, Word, MODULUS, N,
ByteReader, ByteWriter, Deserializable, DeserializationError, NonceBytes, NonceElements,
Polynomial, PublicKeyBytes, Rpo256, Serializable, SignatureBytes, StarkField, Word, MODULUS, N,
SIG_L2_BOUND, ZERO,
};
use crate::utils::string::ToString;
@@ -11,7 +11,7 @@ use core::cell::OnceCell;
/// An RPO Falcon512 signature over a message.
///
/// The signature is a pair of polynomials (s1, s2) in (Z_p\[x\]/(phi))^2, where:
/// The signature is a pair of polynomials (s1, s2) in (Z_p[x]/(phi))^2, where:
/// - p := 12289
/// - phi := x^512 + 1
/// - s1 = c - s2 * h
@@ -41,7 +41,6 @@ use core::cell::OnceCell;
/// 3. 625 bytes encoding the `s2` polynomial above.
///
/// The total size of the signature (including the extended public key) is 1563 bytes.
#[derive(Debug, Clone)]
pub struct Signature {
pub(super) pk: PublicKeyBytes,
pub(super) sig: SignatureBytes,
@@ -87,7 +86,7 @@ impl Signature {
// HASH-TO-POINT
// --------------------------------------------------------------------------------------------
/// Returns a polynomial in Z_p\[x\]/(phi) representing the hash of the provided message.
/// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message.
pub fn hash_to_point(&self, message: Word) -> Polynomial {
hash_to_point(message, &self.nonce())
}
@@ -134,7 +133,7 @@ impl Deserializable for Signature {
let pk_polynomial = Polynomial::from_pub_key(&pk)
.map_err(|err| DeserializationError::InvalidValue(err.to_string()))?
.into();
let sig_polynomial = Polynomial::from_signature(&sig)
let sig_polynomial = Polynomial::from_signature(&sig[41..])
.map_err(|err| DeserializationError::InvalidValue(err.to_string()))?
.into();
@@ -182,9 +181,7 @@ fn decode_nonce(nonce: &NonceBytes) -> NonceElements {
let mut result = [ZERO; 8];
for (i, bytes) in nonce.chunks(5).enumerate() {
buffer[..5].copy_from_slice(bytes);
// we can safely (without overflow) create a new Felt from u64 value here since this value
// contains at most 5 bytes
result[i] = Felt::new(u64::from_le_bytes(buffer));
result[i] = u64::from_le_bytes(buffer).into();
}
result
@@ -196,7 +193,7 @@ fn decode_nonce(nonce: &NonceBytes) -> NonceElements {
#[cfg(all(test, feature = "std"))]
mod tests {
use super::{
super::{ffi::*, Felt, KeyPair},
super::{ffi::*, Felt},
*,
};
use libc::c_void;
@@ -271,14 +268,4 @@ mod tests {
let nonce = decode_nonce(&nonce);
assert_eq!(res, hash_to_point(msg_felts, &nonce).inner());
}
#[test]
fn test_serialization_round_trip() {
let key = KeyPair::new().unwrap();
let signature = key.sign(Word::default()).unwrap();
let serialized = signature.to_bytes();
let deserialized = Signature::read_from_bytes(&serialized).unwrap();
assert_eq!(signature.sig_poly(), deserialized.sig_poly());
assert_eq!(signature.pub_key_poly(), deserialized.pub_key_poly());
}
}

977
src/gkr/circuit/mod.rs Normal file
View File

@@ -0,0 +1,977 @@
use alloc::sync::Arc;
use winter_crypto::{ElementHasher, RandomCoin};
use winter_math::fields::f64::BaseElement;
use winter_math::FieldElement;
use crate::gkr::multivariate::{
ComposedMultiLinearsOracle, EqPolynomial, GkrCompositionVanilla, MultiLinearOracle,
};
use crate::gkr::sumcheck::{sum_check_verify, Claim};
use super::multivariate::{
gen_plain_gkr_oracle, gkr_composition_from_composition_polys, ComposedMultiLinears,
CompositionPolynomial, MultiLinear,
};
use super::sumcheck::{
sum_check_prove, sum_check_verify_and_reduce, FinalEvaluationClaim,
PartialProof as SumcheckInstanceProof, RoundProof as SumCheckRoundProof, Witness,
};
/// Layered circuit for computing a sum of fractions.
///
/// The circuit computes a sum of fractions based on the formula a / c + b / d = (a * d + b * c) / (c * d)
/// which defines a "gate" ((a, b), (c, d)) --> (a * d + b * c, c * d) upon which the `FractionalSumCircuit`
/// is built.
///
/// Each vector below holds one multilinear per circuit layer, index 0 being the input
/// layer and the last entry being the (constant) output layer.
/// TODO: Swap 1 and 0
#[derive(Debug)]
pub struct FractionalSumCircuit<E: FieldElement> {
    // Numerator halves per layer: `p_1` is the "left" half, `p_0` the "right" half.
    p_1_vec: Vec<MultiLinear<E>>,
    p_0_vec: Vec<MultiLinear<E>>,
    // Denominator halves per layer, split the same way as the numerators.
    q_1_vec: Vec<MultiLinear<E>>,
    q_0_vec: Vec<MultiLinear<E>>,
}
impl<E: FieldElement> FractionalSumCircuit<E> {
    /// Computes the values of the gate outputs for each of the layers of the fractional sum
    /// circuit, starting from four explicit input multilinears `[p_1, p_0, q_1, q_0]`.
    ///
    /// Takes a slice rather than `&Vec` (clippy `ptr_arg`); existing `&Vec<_>` call sites
    /// keep working via deref coercion.
    ///
    /// # Panics
    /// Panics if `num_den` has fewer than 4 entries or if the multilinears are empty.
    pub fn new_(num_den: &[MultiLinear<E>]) -> Self {
        let mut p_1_vec: Vec<MultiLinear<E>> = Vec::new();
        let mut p_0_vec: Vec<MultiLinear<E>> = Vec::new();
        let mut q_1_vec: Vec<MultiLinear<E>> = Vec::new();
        let mut q_0_vec: Vec<MultiLinear<E>> = Vec::new();
        let num_layers = num_den[0].len().ilog2() as usize;
        p_1_vec.push(num_den[0].to_owned());
        p_0_vec.push(num_den[1].to_owned());
        q_1_vec.push(num_den[2].to_owned());
        q_0_vec.push(num_den[3].to_owned());
        // Fold layer by layer: each layer halves the number of gates.
        for i in 0..num_layers {
            let (output_p_1, output_p_0, output_q_1, output_q_0) =
                FractionalSumCircuit::compute_layer(
                    &p_1_vec[i],
                    &p_0_vec[i],
                    &q_1_vec[i],
                    &q_0_vec[i],
                );
            p_1_vec.push(output_p_1);
            p_0_vec.push(output_p_0);
            q_1_vec.push(output_q_1);
            q_0_vec.push(output_q_0);
        }
        FractionalSumCircuit { p_1_vec, p_0_vec, q_1_vec, q_0_vec }
    }

    /// Compute the output values of the layer given a set of input values.
    ///
    /// Applies the fraction-addition gate (p_1/q_1) + (p_0/q_0) = (p_1*q_0 + p_0*q_1)/(q_1*q_0)
    /// point-wise; the first half of the outputs becomes the next layer's "1" half and the
    /// second half the "0" half.
    fn compute_layer(
        inp_p_1: &MultiLinear<E>,
        inp_p_0: &MultiLinear<E>,
        inp_q_1: &MultiLinear<E>,
        inp_q_0: &MultiLinear<E>,
    ) -> (MultiLinear<E>, MultiLinear<E>, MultiLinear<E>, MultiLinear<E>) {
        let len = inp_q_1.len();
        let outp_p_1 = (0..len / 2)
            .map(|i| inp_p_1[i] * inp_q_0[i] + inp_p_0[i] * inp_q_1[i])
            .collect::<Vec<E>>();
        let outp_p_0 = (len / 2..len)
            .map(|i| inp_p_1[i] * inp_q_0[i] + inp_p_0[i] * inp_q_1[i])
            .collect::<Vec<E>>();
        let outp_q_1 = (0..len / 2).map(|i| inp_q_1[i] * inp_q_0[i]).collect::<Vec<E>>();
        let outp_q_0 = (len / 2..len).map(|i| inp_q_1[i] * inp_q_0[i]).collect::<Vec<E>>();
        (
            MultiLinear::new(outp_p_1),
            MultiLinear::new(outp_p_0),
            MultiLinear::new(outp_q_1),
            MultiLinear::new(outp_q_0),
        )
    }

    /// Computes the values of the gate outputs for each of the layers of the fractional sum
    /// circuit, starting from a single multilinear whose four quarters are interpreted as
    /// `[p_1, p_0, q_1, q_0]`.
    ///
    /// # Panics
    /// Panics if `poly` has fewer than 4 evaluations.
    pub fn new(poly: &MultiLinear<E>) -> Self {
        let mut p_1_vec: Vec<MultiLinear<E>> = Vec::new();
        let mut p_0_vec: Vec<MultiLinear<E>> = Vec::new();
        let mut q_1_vec: Vec<MultiLinear<E>> = Vec::new();
        let mut q_0_vec: Vec<MultiLinear<E>> = Vec::new();
        let num_layers = poly.len().ilog2() as usize - 1;
        // Split the input into numerator (first half) and denominator (second half),
        // then each of those into their "1" and "0" halves.
        let (output_p, output_q) = poly.split(poly.len() / 2);
        let (output_p_1, output_p_0) = output_p.split(output_p.len() / 2);
        let (output_q_1, output_q_0) = output_q.split(output_q.len() / 2);
        p_1_vec.push(output_p_1);
        p_0_vec.push(output_p_0);
        q_1_vec.push(output_q_1);
        q_0_vec.push(output_q_0);
        for i in 0..num_layers - 1 {
            let (output_p_1, output_p_0, output_q_1, output_q_0) =
                FractionalSumCircuit::compute_layer(
                    &p_1_vec[i],
                    &p_0_vec[i],
                    &q_1_vec[i],
                    &q_0_vec[i],
                );
            p_1_vec.push(output_p_1);
            p_0_vec.push(output_p_0);
            q_1_vec.push(output_q_1);
            q_0_vec.push(output_q_0);
        }
        FractionalSumCircuit { p_1_vec, p_0_vec, q_1_vec, q_0_vec }
    }

    /// Given a value r, computes the evaluation of the last layer at r when interpreted as (two)
    /// multilinear polynomials.
    ///
    /// Returns `(p(r), q(r))` where `p` interpolates the two numerator constants and `q` the
    /// two denominator constants of the output layer.
    ///
    /// # Panics
    /// Panics if the last layer is not made of constant (0-variable) multilinears.
    pub fn evaluate(&self, r: E) -> (E, E) {
        let len = self.p_1_vec.len();
        assert_eq!(self.p_1_vec[len - 1].num_variables(), 0);
        assert_eq!(self.p_0_vec[len - 1].num_variables(), 0);
        assert_eq!(self.q_1_vec[len - 1].num_variables(), 0);
        assert_eq!(self.q_0_vec[len - 1].num_variables(), 0);
        let mut p_1 = self.p_1_vec[len - 1].clone();
        p_1.extend(&self.p_0_vec[len - 1]);
        let mut q_1 = self.q_1_vec[len - 1].clone();
        q_1.extend(&self.q_0_vec[len - 1]);
        (p_1.evaluate(&[r]), q_1.evaluate(&[r]))
    }
}
/// A proof for reducing a claim on the correctness of the output of a layer to that of:
///
/// 1. Correctness of a sumcheck proof on the claimed output.
/// 2. Correctness of the evaluation of the input (to the said layer) at a random point when
/// interpreted as multilinear polynomial.
///
/// The verifier will then have to work backward and:
///
/// 1. Verify that the sumcheck proof is valid.
/// 2. Recurse on the (claimed evaluations) using the same approach as above.
///
/// Note that the following struct batches proofs for many circuits of the same type that
/// are independent i.e., parallel.
#[derive(Debug)]
pub struct LayerProof<E: FieldElement> {
    // Sum-check round proofs for this layer.
    pub proof: SumcheckInstanceProof<E>,
    // Claimed evaluations of the layer's input multilinears (p_1, p_0, q_1, q_0) at the
    // random point produced by the sum-check protocol.
    pub claims_sum_p1: E,
    pub claims_sum_p0: E,
    pub claims_sum_q1: E,
    pub claims_sum_q0: E,
}
#[allow(dead_code)]
impl<E: FieldElement<BaseField = BaseElement> + 'static> LayerProof<E> {
    /// Checks the validity of a `LayerProof`.
    ///
    /// It first reduces the 2 claims to 1 claim using randomness and then checks that the sumcheck
    /// protocol was correctly executed.
    ///
    /// The method outputs:
    ///
    /// 1. A vector containing the randomness sent by the verifier throughout the course of the
    /// sum-check protocol.
    /// 2. The (claimed) evaluation of the inner polynomial (i.e., the one being summed) at the this random vector.
    /// 3. The random value used in the 2-to-1 reduction of the 2 sumchecks.
    ///
    /// NOTE(review): the reseed/draw ordering below mirrors the prover's transcript exactly;
    /// reordering any of these calls would break Fiat–Shamir consistency.
    pub fn verify_sum_check_before_last<
        C: RandomCoin<Hasher = H, BaseField = BaseElement>,
        H: ElementHasher<BaseField = BaseElement>,
    >(
        &self,
        claim: (E, E),
        num_rounds: usize,
        transcript: &mut C,
    ) -> ((E, Vec<E>), E) {
        // Absorb the claims
        let data = vec![claim.0, claim.1];
        transcript.reseed(H::hash_elements(&data));
        // Squeeze challenge to reduce two sumchecks to one
        let r_sum_check: E = transcript.draw().unwrap();
        // Run the sumcheck protocol
        // Given r_sum_check and claim, we create a Claim with the GKR composer and then call the generic sum-check verifier
        let reduced_claim = claim.0 + claim.1 * r_sum_check;
        // Create vanilla oracle
        let oracle = gen_plain_gkr_oracle(num_rounds, r_sum_check);
        // Create sum-check claim
        let transformed_claim = Claim {
            sum_value: reduced_claim,
            polynomial: oracle,
        };
        let reduced_gkr_claim =
            sum_check_verify_and_reduce(&transformed_claim, self.proof.clone(), transcript);
        (reduced_gkr_claim, r_sum_check)
    }
}
/// A claim produced by the GKR prover after all layers but the last: the input multilinears,
/// evaluated at `evaluation_point`, are claimed to equal `claimed_evaluation` = (p, q).
#[derive(Debug)]
pub struct GkrClaim<E: FieldElement + 'static> {
    // The random point accumulated over the per-layer sum-check runs.
    evaluation_point: Vec<E>,
    // The pair of claimed evaluations (numerator, denominator) at that point.
    claimed_evaluation: (E, E),
}
/// A GKR proof for a `FractionalSumCircuit`: one `LayerProof` per proved layer, ordered from
/// the output layer backwards towards the input layer.
#[derive(Debug)]
pub struct CircuitProof<E: FieldElement + 'static> {
    pub proof: Vec<LayerProof<E>>,
}
impl<E: FieldElement<BaseField = BaseElement> + 'static> CircuitProof<E> {
    /// Proves all layers of a `FractionalSumCircuit`, from the output layer backwards to the
    /// input layer, using one sum-check instance per layer.
    ///
    /// Returns the proof, the final random evaluation point, and the per-layer randomness.
    /// NOTE(review): mutates `circuit` in place (the per-layer multilinears are bound during
    /// sum-check), so the circuit is consumed for proving purposes.
    pub fn prove<
        C: RandomCoin<Hasher = H, BaseField = BaseElement>,
        H: ElementHasher<BaseField = BaseElement>,
    >(
        circuit: &mut FractionalSumCircuit<E>,
        transcript: &mut C,
    ) -> (Self, Vec<E>, Vec<Vec<E>>) {
        let mut proof_layers: Vec<LayerProof<E>> = Vec::new();
        let num_layers = circuit.p_0_vec.len();
        // Absorb the four circuit output constants into the transcript.
        let data = vec![
            circuit.p_1_vec[num_layers - 1][0],
            circuit.p_0_vec[num_layers - 1][0],
            circuit.q_1_vec[num_layers - 1][0],
            circuit.q_0_vec[num_layers - 1][0],
        ];
        transcript.reseed(H::hash_elements(&data));
        // Challenge to reduce p1, p0, q1, q0 to pr, qr
        let r_cord = transcript.draw().unwrap();
        // Compute the (2-to-1 folded) claim
        let mut claim = circuit.evaluate(r_cord);
        let mut all_rand = Vec::new();
        let mut rand = Vec::new();
        rand.push(r_cord);
        // Walk the layers from output towards input.
        for layer_id in (0..num_layers - 1).rev() {
            let len = circuit.p_0_vec[layer_id].len();
            // Construct the Lagrange kernel evaluated at previous GKR round randomness.
            // TODO: Treat the direction of doing sum-check more robustly.
            let mut rand_reversed = rand.clone();
            rand_reversed.reverse();
            let eq_evals = EqPolynomial::new(rand_reversed.clone()).evaluations();
            let mut poly_x = MultiLinear::from_values(&eq_evals);
            assert_eq!(poly_x.len(), len);
            let num_rounds = poly_x.len().ilog2() as usize;
            // 1. A is a polynomial containing the evaluations `p_1`.
            // 2. B is a polynomial containing the evaluations `p_0`.
            // 3. C is a polynomial containing the evaluations `q_1`.
            // 4. D is a polynomial containing the evaluations `q_0`.
            let poly_a: &mut MultiLinear<E>;
            let poly_b: &mut MultiLinear<E>;
            let poly_c: &mut MultiLinear<E>;
            let poly_d: &mut MultiLinear<E>;
            poly_a = &mut circuit.p_1_vec[layer_id];
            poly_b = &mut circuit.p_0_vec[layer_id];
            poly_c = &mut circuit.q_1_vec[layer_id];
            poly_d = &mut circuit.q_0_vec[layer_id];
            let poly_vec_par = (poly_a, poly_b, poly_c, poly_d, &mut poly_x);
            // The (non-linear) polynomial combining the multilinear polynomials
            let comb_func = |a: &E, b: &E, c: &E, d: &E, x: &E, rho: &E| -> E {
                (*a * *d + *b * *c + *rho * *c * *d) * *x
            };
            // Run the sumcheck protocol
            let (proof, rand_sumcheck, claims_sum) = sum_check_prover_gkr_before_last::<E, _, _>(
                claim,
                num_rounds,
                poly_vec_par,
                comb_func,
                transcript,
            );
            let (claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0, _claims_eq) =
                claims_sum;
            let data = vec![claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0];
            transcript.reseed(H::hash_elements(&data));
            // Produce a random challenge to condense claims into a single claim
            let r_layer = transcript.draw().unwrap();
            claim = (
                claims_sum_p1 + r_layer * (claims_sum_p0 - claims_sum_p1),
                claims_sum_q1 + r_layer * (claims_sum_q0 - claims_sum_q1),
            );
            // Collect the randomness used for the current layer in order to construct the random
            // point where the input multilinear polynomials were evaluated.
            let mut ext = rand_sumcheck;
            ext.push(r_layer);
            all_rand.push(rand);
            rand = ext;
            proof_layers.push(LayerProof {
                proof,
                claims_sum_p1,
                claims_sum_p0,
                claims_sum_q1,
                claims_sum_q0,
            });
        }
        (CircuitProof { proof: proof_layers }, rand, all_rand)
    }

    /// Proves a "virtual bus" relation: evaluates the numerator/denominator composition
    /// polynomials over the boolean hypercube, feeds them into a fractional sum circuit,
    /// proves all layers but the last via GKR, and finishes the input layer with a dedicated
    /// sum-check run over the composed multilinears.
    ///
    /// Returns the claimed circuit outputs, the layer proofs, and the final sum-check proof.
    pub fn prove_virtual_bus<
        C: RandomCoin<Hasher = H, BaseField = BaseElement>,
        H: ElementHasher<BaseField = BaseElement>,
    >(
        composition_polys: Vec<Vec<Arc<dyn CompositionPolynomial<E>>>>,
        mls: &mut Vec<MultiLinear<E>>,
        transcript: &mut C,
    ) -> (Vec<E>, Self, super::sumcheck::FullProof<E>) {
        let num_evaluations = 1 << mls[0].num_variables();
        // I) Evaluate the numerators and denominators over the boolean hyper-cube
        let mut num_den: Vec<Vec<E>> = vec![vec![]; 4];
        for i in 0..num_evaluations {
            for j in 0..4 {
                let query: Vec<E> = mls.iter().map(|ml| ml[i]).collect();
                composition_polys[j].iter().for_each(|c| {
                    let evaluation = c.as_ref().evaluate(&query);
                    num_den[j].push(evaluation);
                });
            }
        }
        // II) Evaluate the GKR fractional sum circuit
        let input: Vec<MultiLinear<E>> =
            (0..4).map(|i| MultiLinear::from_values(&num_den[i])).collect();
        let mut circuit = FractionalSumCircuit::new_(&input);
        // III) Run the GKR prover for all layers except the last one
        let (gkr_proofs, GkrClaim { evaluation_point, claimed_evaluation }) =
            CircuitProof::prove_before_final(&mut circuit, transcript);
        // IV) Run the sum-check prover for the last GKR layer counting backwards i.e., first layer
        // in the circuit.
        // 1) Build the EQ polynomial (Lagrange kernel) at the randomness sampled during the previous
        // sum-check protocol run
        let mut rand_reversed = evaluation_point.clone();
        rand_reversed.reverse();
        let eq_evals = EqPolynomial::new(rand_reversed.clone()).evaluations();
        let poly_x = MultiLinear::from_values(&eq_evals);
        // 2) Add the Lagrange kernel to the list of MLs
        mls.push(poly_x);
        // 3) Absorb the final sum-check claims and generate randomness for 2-to-1 sum-check reduction
        let data = vec![claimed_evaluation.0, claimed_evaluation.1];
        transcript.reseed(H::hash_elements(&data));
        // Squeeze challenge to reduce two sumchecks to one
        let r_sum_check = transcript.draw().unwrap();
        let reduced_claim = claimed_evaluation.0 + claimed_evaluation.1 * r_sum_check;
        // 4) Create the composed ML representing the numerators and denominators of the topmost GKR layer
        let gkr_final_composed_ml = gkr_composition_from_composition_polys(
            &composition_polys,
            r_sum_check,
            1 << mls[0].num_variables,
        );
        let composed_ml =
            ComposedMultiLinears::new(Arc::new(gkr_final_composed_ml.clone()), mls.to_vec());
        // 5) Create the composed ML oracle. This will be used for verifying the FinalEvaluationClaim downstream
        // TODO: This should be an input to the current function.
        // TODO: Make MultiLinearOracle a variant in an enum so that it is possible to capture other types of oracles.
        // For example, shifts of polynomials, Lagrange kernels at a random point or periodic (transparent) polynomials.
        let left_num_oracle = MultiLinearOracle { id: 0 };
        let right_num_oracle = MultiLinearOracle { id: 1 };
        let left_denom_oracle = MultiLinearOracle { id: 2 };
        let right_denom_oracle = MultiLinearOracle { id: 3 };
        let eq_oracle = MultiLinearOracle { id: 4 };
        let composed_ml_oracle = ComposedMultiLinearsOracle {
            composer: (Arc::new(gkr_final_composed_ml.clone())),
            multi_linears: vec![
                eq_oracle,
                left_num_oracle,
                right_num_oracle,
                left_denom_oracle,
                right_denom_oracle,
            ],
        };
        // 6) Create the claim for the final sum-check protocol.
        let claim = Claim {
            sum_value: reduced_claim,
            polynomial: composed_ml_oracle.clone(),
        };
        // 7) Create the witness for the sum-check claim.
        let witness = Witness { polynomial: composed_ml };
        let output = sum_check_prove(&claim, composed_ml_oracle, witness, transcript);
        // 8) Create the claimed output of the circuit.
        let circuit_outputs = vec![
            circuit.p_1_vec.last().unwrap()[0],
            circuit.p_0_vec.last().unwrap()[0],
            circuit.q_1_vec.last().unwrap()[0],
            circuit.q_0_vec.last().unwrap()[0],
        ];
        // 9) Return:
        // 1. The claimed circuit outputs.
        // 2. GKR proofs of all circuit layers except the initial layer.
        // 3. Output of the final sum-check protocol.
        (circuit_outputs, gkr_proofs, output)
    }

    /// Same as [`Self::prove`] but stops one layer short of the input layer: the loop runs over
    /// `(1..num_layers - 1)` instead of `(0..num_layers - 1)`, and the remaining claim is
    /// returned as a `GkrClaim` to be discharged by a separate final sum-check.
    pub fn prove_before_final<
        C: RandomCoin<Hasher = H, BaseField = BaseElement>,
        H: ElementHasher<BaseField = BaseElement>,
    >(
        sum_circuits: &mut FractionalSumCircuit<E>,
        transcript: &mut C,
    ) -> (Self, GkrClaim<E>) {
        let mut proof_layers: Vec<LayerProof<E>> = Vec::new();
        let num_layers = sum_circuits.p_0_vec.len();
        let data = vec![
            sum_circuits.p_1_vec[num_layers - 1][0],
            sum_circuits.p_0_vec[num_layers - 1][0],
            sum_circuits.q_1_vec[num_layers - 1][0],
            sum_circuits.q_0_vec[num_layers - 1][0],
        ];
        transcript.reseed(H::hash_elements(&data));
        // Challenge to reduce p1, p0, q1, q0 to pr, qr
        let r_cord = transcript.draw().unwrap();
        // Compute the (2-to-1 folded) claim
        let mut claims_to_verify = sum_circuits.evaluate(r_cord);
        let mut all_rand = Vec::new();
        let mut rand = Vec::new();
        rand.push(r_cord);
        // Stops at layer 1 (exclusive of the input layer 0), unlike `prove`.
        for layer_id in (1..num_layers - 1).rev() {
            let len = sum_circuits.p_0_vec[layer_id].len();
            // Construct the Lagrange kernel evaluated at previous GKR round randomness.
            // TODO: Treat the direction of doing sum-check more robustly.
            let mut rand_reversed = rand.clone();
            rand_reversed.reverse();
            let eq_evals = EqPolynomial::new(rand_reversed.clone()).evaluations();
            let mut poly_x = MultiLinear::from_values(&eq_evals);
            assert_eq!(poly_x.len(), len);
            let num_rounds = poly_x.len().ilog2() as usize;
            // 1. A is a polynomial containing the evaluations `p_1`.
            // 2. B is a polynomial containing the evaluations `p_0`.
            // 3. C is a polynomial containing the evaluations `q_1`.
            // 4. D is a polynomial containing the evaluations `q_0`.
            let poly_a: &mut MultiLinear<E>;
            let poly_b: &mut MultiLinear<E>;
            let poly_c: &mut MultiLinear<E>;
            let poly_d: &mut MultiLinear<E>;
            poly_a = &mut sum_circuits.p_1_vec[layer_id];
            poly_b = &mut sum_circuits.p_0_vec[layer_id];
            poly_c = &mut sum_circuits.q_1_vec[layer_id];
            poly_d = &mut sum_circuits.q_0_vec[layer_id];
            let poly_vec = (poly_a, poly_b, poly_c, poly_d, &mut poly_x);
            let claim = claims_to_verify;
            // The (non-linear) polynomial combining the multilinear polynomials
            let comb_func = |a: &E, b: &E, c: &E, d: &E, x: &E, rho: &E| -> E {
                (*a * *d + *b * *c + *rho * *c * *d) * *x
            };
            // Run the sumcheck protocol
            let (proof, rand_sumcheck, claims_sum) = sum_check_prover_gkr_before_last::<E, _, _>(
                claim, num_rounds, poly_vec, comb_func, transcript,
            );
            let (claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0, _claims_eq) =
                claims_sum;
            let data = vec![claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0];
            transcript.reseed(H::hash_elements(&data));
            // Produce a random challenge to condense claims into a single claim
            let r_layer = transcript.draw().unwrap();
            claims_to_verify = (
                claims_sum_p1 + r_layer * (claims_sum_p0 - claims_sum_p1),
                claims_sum_q1 + r_layer * (claims_sum_q0 - claims_sum_q1),
            );
            // Collect the randomness used for the current layer in order to construct the random
            // point where the input multilinear polynomials were evaluated.
            let mut ext = rand_sumcheck;
            ext.push(r_layer);
            all_rand.push(rand);
            rand = ext;
            proof_layers.push(LayerProof {
                proof,
                claims_sum_p1,
                claims_sum_p0,
                claims_sum_q1,
                claims_sum_q0,
            });
        }
        let gkr_claim = GkrClaim {
            evaluation_point: rand.clone(),
            claimed_evaluation: claims_to_verify,
        };
        (CircuitProof { proof: proof_layers }, gkr_claim)
    }

    /// Verifies a proof produced by [`Self::prove`] against the claimed circuit outputs
    /// `claims_sum_vec` = [p0, p1, q0, q1].
    ///
    /// Returns the reduced claim on the input layer together with the accumulated randomness.
    /// NOTE(review): failures surface as assertion panics rather than a `Result`; callers must
    /// treat a panic as verification failure.
    pub fn verify<
        C: RandomCoin<Hasher = H, BaseField = BaseElement>,
        H: ElementHasher<BaseField = BaseElement>,
    >(
        &self,
        claims_sum_vec: &[E],
        transcript: &mut C,
    ) -> ((E, E), Vec<E>) {
        let num_layers = self.proof.len() as usize - 1;
        let mut rand: Vec<E> = Vec::new();
        let data = claims_sum_vec;
        transcript.reseed(H::hash_elements(&data));
        let r_cord = transcript.draw().unwrap();
        // Fold the two numerator (resp. denominator) outputs into one claim at r_cord.
        let p_poly_coef = vec![claims_sum_vec[0], claims_sum_vec[1]];
        let q_poly_coef = vec![claims_sum_vec[2], claims_sum_vec[3]];
        let p_poly = MultiLinear::new(p_poly_coef);
        let q_poly = MultiLinear::new(q_poly_coef);
        let p_eval = p_poly.evaluate(&[r_cord]);
        let q_eval = q_poly.evaluate(&[r_cord]);
        let mut reduced_claim = (p_eval, q_eval);
        rand.push(r_cord);
        for (num_rounds, i) in (0..num_layers).enumerate() {
            let ((claim_last, rand_sumcheck), r_two_sumchecks) = self.proof[i]
                .verify_sum_check_before_last::<_, _>(reduced_claim, num_rounds + 1, transcript);
            let claims_sum_p1 = &self.proof[i].claims_sum_p1;
            let claims_sum_p0 = &self.proof[i].claims_sum_p0;
            let claims_sum_q1 = &self.proof[i].claims_sum_q1;
            let claims_sum_q0 = &self.proof[i].claims_sum_q0;
            let data = vec![
                claims_sum_p1.clone(),
                claims_sum_p0.clone(),
                claims_sum_q1.clone(),
                claims_sum_q0.clone(),
            ];
            transcript.reseed(H::hash_elements(&data));
            assert_eq!(rand.len(), rand_sumcheck.len());
            // eq(rand, rand_sumcheck): the Lagrange kernel evaluated at the two points.
            let eq: E = (0..rand.len())
                .map(|i| {
                    rand[i] * rand_sumcheck[i] + (E::ONE - rand[i]) * (E::ONE - rand_sumcheck[i])
                })
                .fold(E::ONE, |acc, term| acc * term);
            // Recompute the gate output from the claimed input evaluations and compare.
            let claim_expected: E = (*claims_sum_p1 * *claims_sum_q0
                + *claims_sum_p0 * *claims_sum_q1
                + r_two_sumchecks * *claims_sum_q1 * *claims_sum_q0)
                * eq;
            assert_eq!(claim_expected, claim_last);
            // Produce a random challenge to condense claims into a single claim
            let r_layer = transcript.draw().unwrap();
            reduced_claim = (
                *claims_sum_p1 + r_layer * (*claims_sum_p0 - *claims_sum_p1),
                *claims_sum_q1 + r_layer * (*claims_sum_q0 - *claims_sum_q1),
            );
            // Collect the randomness' used for the current layer in order to construct the random
            // point where the input multilinear polynomials were evaluated.
            let mut ext = rand_sumcheck;
            ext.push(r_layer);
            rand = ext;
        }
        (reduced_claim, rand)
    }

    /// Verifies a proof produced by [`Self::prove_virtual_bus`]: checks the claimed outputs sum
    /// to zero, verifies every layer proof, and finally verifies the last-layer sum-check,
    /// returning its `FinalEvaluationClaim` plus the accumulated randomness.
    ///
    /// NOTE(review): like `verify`, failures panic via assertions instead of returning `Err`.
    pub fn verify_virtual_bus<
        C: RandomCoin<Hasher = H, BaseField = BaseElement>,
        H: ElementHasher<BaseField = BaseElement>,
    >(
        &self,
        composition_polys: Vec<Vec<Arc<dyn CompositionPolynomial<E>>>>,
        final_layer_proof: super::sumcheck::FullProof<E>,
        claims_sum_vec: &[E],
        transcript: &mut C,
    ) -> (FinalEvaluationClaim<E>, Vec<E>) {
        let num_layers = self.proof.len() as usize;
        let mut rand: Vec<E> = Vec::new();
        // Check that a/b + d/e is equal to 0
        assert_ne!(claims_sum_vec[2], E::ZERO);
        assert_ne!(claims_sum_vec[3], E::ZERO);
        assert_eq!(
            claims_sum_vec[0] * claims_sum_vec[3] + claims_sum_vec[1] * claims_sum_vec[2],
            E::ZERO
        );
        let data = claims_sum_vec;
        transcript.reseed(H::hash_elements(&data));
        let r_cord = transcript.draw().unwrap();
        let p_poly_coef = vec![claims_sum_vec[0], claims_sum_vec[1]];
        let q_poly_coef = vec![claims_sum_vec[2], claims_sum_vec[3]];
        let p_poly = MultiLinear::new(p_poly_coef);
        let q_poly = MultiLinear::new(q_poly_coef);
        let p_eval = p_poly.evaluate(&[r_cord]);
        let q_eval = q_poly.evaluate(&[r_cord]);
        let mut reduced_claim = (p_eval, q_eval);
        // I) Verify all GKR layers but for the last one counting backwards.
        rand.push(r_cord);
        for (num_rounds, i) in (0..num_layers).enumerate() {
            let ((claim_last, rand_sumcheck), r_two_sumchecks) = self.proof[i]
                .verify_sum_check_before_last::<_, _>(reduced_claim, num_rounds + 1, transcript);
            let claims_sum_p1 = &self.proof[i].claims_sum_p1;
            let claims_sum_p0 = &self.proof[i].claims_sum_p0;
            let claims_sum_q1 = &self.proof[i].claims_sum_q1;
            let claims_sum_q0 = &self.proof[i].claims_sum_q0;
            let data = vec![
                claims_sum_p1.clone(),
                claims_sum_p0.clone(),
                claims_sum_q1.clone(),
                claims_sum_q0.clone(),
            ];
            transcript.reseed(H::hash_elements(&data));
            assert_eq!(rand.len(), rand_sumcheck.len());
            let eq: E = (0..rand.len())
                .map(|i| {
                    rand[i] * rand_sumcheck[i] + (E::ONE - rand[i]) * (E::ONE - rand_sumcheck[i])
                })
                .fold(E::ONE, |acc, term| acc * term);
            let claim_expected: E = (*claims_sum_p1 * *claims_sum_q0
                + *claims_sum_p0 * *claims_sum_q1
                + r_two_sumchecks * *claims_sum_q1 * *claims_sum_q0)
                * eq;
            assert_eq!(claim_expected, claim_last);
            // Produce a random challenge to condense claims into a single claim
            let r_layer = transcript.draw().unwrap();
            reduced_claim = (
                *claims_sum_p1 + r_layer * (*claims_sum_p0 - *claims_sum_p1),
                *claims_sum_q1 + r_layer * (*claims_sum_q0 - *claims_sum_q1),
            );
            let mut ext = rand_sumcheck;
            ext.push(r_layer);
            rand = ext;
        }
        // II) Verify the final GKR layer counting backwards.
        // Absorb the claims
        let data = vec![reduced_claim.0, reduced_claim.1];
        transcript.reseed(H::hash_elements(&data));
        // Squeeze challenge to reduce two sumchecks to one
        let r_sum_check = transcript.draw().unwrap();
        let reduced_claim = reduced_claim.0 + reduced_claim.1 * r_sum_check;
        let gkr_final_composed_ml = gkr_composition_from_composition_polys(
            &composition_polys,
            r_sum_check,
            1 << (num_layers + 1),
        );
        // TODO: refactor
        let composed_ml_oracle = {
            let left_num_oracle = MultiLinearOracle { id: 0 };
            let right_num_oracle = MultiLinearOracle { id: 1 };
            let left_denom_oracle = MultiLinearOracle { id: 2 };
            let right_denom_oracle = MultiLinearOracle { id: 3 };
            let eq_oracle = MultiLinearOracle { id: 4 };
            ComposedMultiLinearsOracle {
                composer: (Arc::new(gkr_final_composed_ml.clone())),
                multi_linears: vec![
                    eq_oracle,
                    left_num_oracle,
                    right_num_oracle,
                    left_denom_oracle,
                    right_denom_oracle,
                ],
            }
        };
        let claim = Claim {
            sum_value: reduced_claim,
            polynomial: composed_ml_oracle.clone(),
        };
        let final_eval_claim = sum_check_verify(&claim, final_layer_proof, transcript);
        (final_eval_claim, rand)
    }
}
/// Runs the sum-check prover for one GKR layer, reducing the pair of claims `claim = (p, q)`
/// to evaluations of the five input multilinears at a random point.
///
/// `ml_polys` is the tuple `(p_1, p_0, q_1, q_0, eq)`; each round the prover sends the
/// evaluations of the degree-3 round polynomial at 0, 2 and 3 (the evaluation at 1 is derived
/// from the running claim `e`), then binds the first variable of every multilinear to the
/// verifier's challenge.
///
/// Returns the round proofs, the per-round challenges, and the final (constant) evaluations
/// of the five multilinears. NOTE(review): the multilinears are consumed (bound in place).
fn sum_check_prover_gkr_before_last<
    E: FieldElement<BaseField = BaseElement>,
    C: RandomCoin<Hasher = H, BaseField = BaseElement>,
    H: ElementHasher<BaseField = BaseElement>,
>(
    claim: (E, E),
    num_rounds: usize,
    ml_polys: (
        &mut MultiLinear<E>,
        &mut MultiLinear<E>,
        &mut MultiLinear<E>,
        &mut MultiLinear<E>,
        &mut MultiLinear<E>,
    ),
    comb_func: impl Fn(&E, &E, &E, &E, &E, &E) -> E,
    transcript: &mut C,
) -> (SumcheckInstanceProof<E>, Vec<E>, (E, E, E, E, E)) {
    // Absorb the claims
    let data = vec![claim.0, claim.1];
    transcript.reseed(H::hash_elements(&data));
    // Squeeze challenge to reduce two sumchecks to one
    let r_sum_check = transcript.draw().unwrap();
    let (poly_a, poly_b, poly_c, poly_d, poly_x) = ml_polys;
    // Running claim: the value the current round polynomial must sum to.
    let mut e = claim.0 + claim.1 * r_sum_check;
    let mut r: Vec<E> = Vec::new();
    let mut round_proofs: Vec<SumCheckRoundProof<E>> = Vec::new();
    for _j in 0..num_rounds {
        let evals: (E, E, E) = {
            let mut eval_point_0 = E::ZERO;
            let mut eval_point_2 = E::ZERO;
            let mut eval_point_3 = E::ZERO;
            let len = poly_a.len() / 2;
            for i in 0..len {
                // The interpolation formula for a linear function is:
                // z * A(x) + (1 - z) * A (y)
                // z * A(1) + (1 - z) * A(0)
                // eval at z = 0: A(1)
                eval_point_0 += comb_func(
                    &poly_a[i << 1],
                    &poly_b[i << 1],
                    &poly_c[i << 1],
                    &poly_d[i << 1],
                    &poly_x[i << 1],
                    &r_sum_check,
                );
                let poly_a_u = poly_a[(i << 1) + 1];
                let poly_a_v = poly_a[i << 1];
                let poly_b_u = poly_b[(i << 1) + 1];
                let poly_b_v = poly_b[i << 1];
                let poly_c_u = poly_c[(i << 1) + 1];
                let poly_c_v = poly_c[i << 1];
                let poly_d_u = poly_d[(i << 1) + 1];
                let poly_d_v = poly_d[i << 1];
                let poly_x_u = poly_x[(i << 1) + 1];
                let poly_x_v = poly_x[i << 1];
                // eval at z = 2: 2 * A(1) - A(0)
                let poly_a_extrapolated_point = poly_a_u + poly_a_u - poly_a_v;
                let poly_b_extrapolated_point = poly_b_u + poly_b_u - poly_b_v;
                let poly_c_extrapolated_point = poly_c_u + poly_c_u - poly_c_v;
                let poly_d_extrapolated_point = poly_d_u + poly_d_u - poly_d_v;
                let poly_x_extrapolated_point = poly_x_u + poly_x_u - poly_x_v;
                eval_point_2 += comb_func(
                    &poly_a_extrapolated_point,
                    &poly_b_extrapolated_point,
                    &poly_c_extrapolated_point,
                    &poly_d_extrapolated_point,
                    &poly_x_extrapolated_point,
                    &r_sum_check,
                );
                // eval at z = 3: 3 * A(1) - 2 * A(0) = 2 * A(1) - A(0) + A(1) - A(0)
                // hence we can compute the evaluation at z + 1 from that of z for z > 1
                let poly_a_extrapolated_point = poly_a_extrapolated_point + poly_a_u - poly_a_v;
                let poly_b_extrapolated_point = poly_b_extrapolated_point + poly_b_u - poly_b_v;
                let poly_c_extrapolated_point = poly_c_extrapolated_point + poly_c_u - poly_c_v;
                let poly_d_extrapolated_point = poly_d_extrapolated_point + poly_d_u - poly_d_v;
                let poly_x_extrapolated_point = poly_x_extrapolated_point + poly_x_u - poly_x_v;
                eval_point_3 += comb_func(
                    &poly_a_extrapolated_point,
                    &poly_b_extrapolated_point,
                    &poly_c_extrapolated_point,
                    &poly_d_extrapolated_point,
                    &poly_x_extrapolated_point,
                    &r_sum_check,
                );
            }
            (eval_point_0, eval_point_2, eval_point_3)
        };
        let eval_0 = evals.0;
        let eval_2 = evals.1;
        let eval_3 = evals.2;
        // The evaluation at 1 is implied: g(1) = e - g(0).
        let evals = vec![e - eval_0, eval_2, eval_3];
        let compressed_poly = SumCheckRoundProof { poly_evals: evals };
        // append the prover's message to the transcript
        transcript.reseed(H::hash_elements(&compressed_poly.poly_evals));
        // derive the verifier's challenge for the next round
        let r_j = transcript.draw().unwrap();
        r.push(r_j);
        // Bind the first variable of every multilinear to the challenge, halving their size.
        poly_a.bind_assign(r_j);
        poly_b.bind_assign(r_j);
        poly_c.bind_assign(r_j);
        poly_d.bind_assign(r_j);
        poly_x.bind_assign(r_j);
        e = compressed_poly.evaluate(e, r_j);
        round_proofs.push(compressed_poly);
    }
    let claims_sum = (poly_a[0], poly_b[0], poly_c[0], poly_d[0], poly_x[0]);
    (SumcheckInstanceProof { round_proofs }, r, claims_sum)
}
#[cfg(test)]
mod sum_circuit_tests {
    use crate::rand::RpoRandomCoin;

    use super::*;
    use rand::Rng;
    use rand_utils::rand_value;
    use BaseElement as Felt;

    /// The following tests the fractional sum circuit to check that \sum_{i = 0}^{log(m)-1} m / 2^{i} = 2 * (m - 1)
    #[test]
    fn sum_circuit_example() {
        let n = 4; // n := log(m)
        // Numerators: n copies of m = 2^n; denominators: 2^0 .. 2^(n-1).
        let mut inp: Vec<Felt> = (0..n).map(|_| Felt::from(1_u64 << n)).collect();
        let inp_: Vec<Felt> = (0..n).map(|i| Felt::from(1_u64 << i)).collect();
        inp.extend(inp_.iter());
        let summation = MultiLinear::new(inp);
        let expected_output = Felt::from(2 * ((1_u64 << n) - 1));
        let mut circuit = FractionalSumCircuit::new(&summation);
        let seed = [BaseElement::ZERO; 4];
        let mut transcript = RpoRandomCoin::new(seed.into());
        let (proof, _evals, _) = CircuitProof::prove(&mut circuit, &mut transcript);
        // Recover the circuit output p/q from the two constant halves of the last layer.
        let (p1, q1) = circuit.evaluate(Felt::from(1_u8));
        let (p0, q0) = circuit.evaluate(Felt::from(0_u8));
        assert_eq!(expected_output, (p1 * q0 + q1 * p0) / (q1 * q0));
        // Verify with a fresh transcript seeded identically to the prover's.
        let seed = [BaseElement::ZERO; 4];
        let mut transcript = RpoRandomCoin::new(seed.into());
        let claims = vec![p0, p1, q0, q1];
        proof.verify(&claims, &mut transcript);
    }

    // Test the fractional sum GKR in the context of LogUp.
    #[test]
    fn log_up() {
        use rand::distributions::Slice;
        let n: usize = 16;
        let num_w: usize = 31; // This should be of the form 2^k - 1
        let rng = rand::thread_rng();
        // Lookup table values and their (to-be-computed) multiplicities.
        let t_table: Vec<u32> = (0..(1 << n)).collect();
        let mut m_table: Vec<u32> = (0..(1 << n)).map(|_| 0).collect();
        let t_table_slice = Slice::new(&t_table).unwrap();
        // Construct the witness columns. Uses sampling with replacement in order to have multiplicities
        // different from 1.
        let mut w_tables = Vec::new();
        for _ in 0..num_w {
            let wi_table: Vec<u32> =
                rng.clone().sample_iter(&t_table_slice).cloned().take(1 << n).collect();

            // Construct the multiplicities
            wi_table.iter().for_each(|w| {
                m_table[*w as usize] += 1;
            });
            w_tables.push(wi_table)
        }
        // The numerators
        let mut p: Vec<Felt> = m_table.iter().map(|m| Felt::from(*m as u32)).collect();
        p.extend((0..(num_w * (1 << n))).map(|_| Felt::from(1_u32)).collect::<Vec<Felt>>());
        // Sample the challenge alpha to construct the denominators.
        let alpha = rand_value();
        // Construct the denominators
        let mut q: Vec<Felt> = t_table.iter().map(|t| Felt::from(*t) - alpha).collect();
        for w_table in w_tables {
            q.extend(w_table.iter().map(|w| alpha - Felt::from(*w)).collect::<Vec<Felt>>());
        }
        // Build the input to the fractional sum GKR circuit
        p.extend(q);
        let input = p;
        let summation = MultiLinear::new(input);
        // LogUp soundness requires the signed fractional sum to vanish.
        let expected_output = Felt::from(0_u8);
        let mut circuit = FractionalSumCircuit::new(&summation);
        let seed = [BaseElement::ZERO; 4];
        let mut transcript = RpoRandomCoin::new(seed.into());
        let (proof, _evals, _) = CircuitProof::prove(&mut circuit, &mut transcript);
        let (p1, q1) = circuit.evaluate(Felt::from(1_u8));
        let (p0, q0) = circuit.evaluate(Felt::from(0_u8));
        assert_eq!(expected_output, (p1 * q0 + q1 * p0) / (q1 * q0)); // This check should be part of verification
        let seed = [BaseElement::ZERO; 4];
        let mut transcript = RpoRandomCoin::new(seed.into());
        let claims = vec![p0, p1, q0, q1];
        proof.verify(&claims, &mut transcript);
    }
}

7
src/gkr/mod.rs Normal file
View File

@@ -0,0 +1,7 @@
#![allow(unused_imports)]
#![allow(dead_code)]

mod sumcheck; // the sum-check sub-protocol (round proofs, prover, verifier)
mod multivariate; // multi-linear polynomials and composition polynomials
mod utils; // barycentric interpolation helpers
mod circuit; // the fractional-sum (GKR) circuit and its proof

View File

@@ -0,0 +1,34 @@
use super::FieldElement;
/// The multi-linear "equality" polynomial eq(r, x), fixed at a point `r`.
pub struct EqPolynomial<E> {
    r: Vec<E>,
}

impl<E: FieldElement> EqPolynomial<E> {
    /// Creates the equality polynomial anchored at the point `r`.
    pub fn new(r: Vec<E>) -> Self {
        Self { r }
    }

    /// Evaluates eq(r, rho) = prod_i (r_i * rho_i + (1 - r_i) * (1 - rho_i)).
    ///
    /// Panics if `rho` does not have the same length as `r`.
    pub fn evaluate(&self, rho: &[E]) -> E {
        assert_eq!(self.r.len(), rho.len());
        self.r
            .iter()
            .zip(rho.iter())
            .map(|(&r_i, &rho_i)| r_i * rho_i + (E::ONE - r_i) * (E::ONE - rho_i))
            .fold(E::ONE, |acc, factor| acc * factor)
    }

    /// Returns the table of eq(r, x) over all x in {0,1}^nu, where nu = r.len().
    pub fn evaluations(&self) -> Vec<E> {
        let nu = self.r.len();
        let mut table: Vec<E> = vec![E::ONE; 1 << nu];
        let mut filled = 1;
        for r_j in self.r.iter() {
            filled *= 2;
            // Expand each already-computed entry into its (1 - r_j) / r_j pair, walking
            // backwards so parents are read before they are overwritten.
            for hi in (0..filled).rev().step_by(2) {
                let parent = table[hi / 2];
                table[hi] = parent * *r_j;
                table[hi - 1] = parent - table[hi];
            }
        }
        table
    }
}

543
src/gkr/multivariate/mod.rs Normal file
View File

@@ -0,0 +1,543 @@
use core::ops::Index;
use alloc::sync::Arc;
use winter_math::{fields::f64::BaseElement, log2, FieldElement, StarkField};
mod eq_poly;
pub use eq_poly::EqPolynomial;
/// A multi-linear polynomial represented by its evaluations over the boolean hyper-cube
/// {0,1}^num_variables.
#[derive(Clone, Debug)]
pub struct MultiLinear<E: FieldElement> {
    // number of variables; equals log2 of the number of stored evaluations
    pub num_variables: usize,
    // evaluations over the hyper-cube
    pub evaluations: Vec<E>,
}

impl<E: FieldElement> MultiLinear<E> {
    /// Creates a multi-linear polynomial from its vector of evaluations.
    ///
    /// The length of `values` is expected to be a power of two; `num_variables` is derived
    /// from it via `log2`.
    pub fn new(values: Vec<E>) -> Self {
        Self {
            num_variables: log2(values.len()) as usize,
            evaluations: values,
        }
    }

    /// Creates a multi-linear polynomial from a slice of evaluations (see [Self::new]).
    pub fn from_values(values: &[E]) -> Self {
        // Delegate to `new` so the invariant between the two fields lives in one place.
        Self::new(values.to_owned())
    }

    /// Returns the number of variables.
    pub fn num_variables(&self) -> usize {
        self.num_variables
    }

    /// Returns the underlying evaluations.
    pub fn evaluations(&self) -> &[E] {
        &self.evaluations
    }

    /// Returns the number of stored evaluations, i.e. 2^num_variables.
    pub fn len(&self) -> usize {
        self.evaluations.len()
    }

    /// Returns true if no evaluations are stored.
    pub fn is_empty(&self) -> bool {
        self.evaluations.is_empty()
    }

    /// Evaluates the polynomial at an arbitrary `query` point by taking the inner product of
    /// the evaluations with the tensorized (Lagrange basis) form of the query.
    pub fn evaluate(&self, query: &[E]) -> E {
        let tensored_query = tensorize(query);
        inner_product(&self.evaluations, &tensored_query)
    }

    /// Fixes the first variable to `round_challenge`, returning a polynomial in one fewer
    /// variable. Assumes the polynomial has at least one variable.
    pub fn bind(&self, round_challenge: E) -> Self {
        let mut result = vec![E::ZERO; 1 << (self.num_variables() - 1)];
        for i in 0..(1 << (self.num_variables() - 1)) {
            // Linear interpolation between the two evaluations that share the suffix i.
            result[i] = self.evaluations[i << 1]
                + round_challenge * (self.evaluations[(i << 1) + 1] - self.evaluations[i << 1]);
        }
        Self::from_values(&result)
    }

    /// In-place version of [Self::bind].
    pub fn bind_assign(&mut self, round_challenge: E) {
        // Previously duplicated the body of `bind`; delegate instead.
        *self = self.bind(round_challenge);
    }

    /// Splits the evaluations into two polynomials built from `[..at]` and `[at..2*at]`.
    ///
    /// Note: the second half is `[at..2*at]`, so `at` is expected to be half the length for
    /// the split to cover all evaluations.
    pub fn split(&self, at: usize) -> (Self, Self) {
        assert!(at < self.len());
        (
            Self::new(self.evaluations[..at].to_vec()),
            Self::new(self.evaluations[at..2 * at].to_vec()),
        )
    }

    /// Appends the evaluations of `other` (which must have the same length), adding one
    /// variable to this polynomial.
    pub fn extend(&mut self, other: &MultiLinear<E>) {
        let other_vec = other.evaluations.to_vec();
        assert_eq!(other_vec.len(), self.len());
        self.evaluations.extend(other_vec);
        self.num_variables += 1;
    }
}

impl<E: FieldElement> Index<usize> for MultiLinear<E> {
    type Output = E;

    /// Returns the evaluation stored at position `index`.
    fn index(&self, index: usize) -> &E {
        &self.evaluations[index]
    }
}
/// A multi-variate polynomial for composing individual multi-linear polynomials.
pub trait CompositionPolynomial<E: FieldElement>: Sync + Send {
    /// The number of variables when interpreted as a multi-variate polynomial.
    fn num_variables(&self) -> usize;
    /// Maximum degree in all variables.
    fn max_degree(&self) -> usize;
    /// Given a query, of length equal the number of variables, evaluate [Self] at this query.
    fn evaluate(&self, query: &[E]) -> E;
}
/// A composition polynomial applied to a collection of multi-linear polynomials.
pub struct ComposedMultiLinears<E: FieldElement> {
    pub composer: Arc<dyn CompositionPolynomial<E>>,
    pub multi_linears: Vec<MultiLinear<E>>,
}

impl<E: FieldElement> ComposedMultiLinears<E> {
    /// Wraps the given multi-linears together with the composition polynomial.
    pub fn new(
        composer: Arc<dyn CompositionPolynomial<E>>,
        multi_linears: Vec<MultiLinear<E>>,
    ) -> Self {
        Self { composer, multi_linears }
    }

    /// Number of composed multi-linear polynomials.
    pub fn num_ml(&self) -> usize {
        self.multi_linears.len()
    }

    /// Number of variables of the composition polynomial.
    pub fn num_variables(&self) -> usize {
        self.composer.num_variables()
    }

    /// Number of variables of the multi-linears (taken from the first one).
    pub fn num_variables_ml(&self) -> usize {
        self.multi_linears[0].num_variables
    }

    /// Maximum degree of the composition polynomial.
    pub fn degree(&self) -> usize {
        self.composer.max_degree()
    }

    /// Fixes the first variable of every multi-linear to `round_challenge`.
    pub fn bind(&self, round_challenge: E) -> ComposedMultiLinears<E> {
        let bound: Vec<MultiLinear<E>> = self
            .multi_linears
            .iter()
            .map(|ml| ml.bind(round_challenge))
            .collect();
        Self {
            composer: Arc::clone(&self.composer),
            multi_linears: bound,
        }
    }
}
/// An oracle counterpart of [ComposedMultiLinears]: the composition polynomial together with
/// handles to the multi-linears instead of their evaluations.
#[derive(Clone)]
pub struct ComposedMultiLinearsOracle<E: FieldElement> {
    pub composer: Arc<dyn CompositionPolynomial<E>>,
    pub multi_linears: Vec<MultiLinearOracle>,
}

/// A handle identifying a multi-linear polynomial by its index.
#[derive(Debug, Clone)]
pub struct MultiLinearOracle {
    pub id: usize,
}
// Composition polynomials
/// The identity composition: a single-variable composition that returns its query unchanged.
pub struct IdentityComposition {
    num_variables: usize,
}

impl IdentityComposition {
    /// Creates the identity composition over a single variable.
    pub fn new() -> Self {
        Self { num_variables: 1 }
    }
}

impl<E> CompositionPolynomial<E> for IdentityComposition
where
    E: FieldElement,
{
    fn num_variables(&self) -> usize {
        self.num_variables
    }

    fn max_degree(&self) -> usize {
        self.num_variables
    }

    /// Returns the single query element; panics if the query is not of length one.
    fn evaluate(&self, query: &[E]) -> E {
        assert_eq!(query.len(), 1);
        query[0]
    }
}
/// A projection composition: selects one fixed coordinate out of the query.
pub struct ProjectionComposition {
    coordinate: usize,
}

impl ProjectionComposition {
    /// Creates a projection onto the given `coordinate`.
    pub fn new(coordinate: usize) -> Self {
        Self { coordinate }
    }
}

impl<E> CompositionPolynomial<E> for ProjectionComposition
where
    E: FieldElement,
{
    fn num_variables(&self) -> usize {
        1
    }

    fn max_degree(&self) -> usize {
        1
    }

    /// Returns the query element at the projection coordinate.
    fn evaluate(&self, query: &[E]) -> E {
        query[self.coordinate]
    }
}
/// LogUp denominator composition for table rows: maps the selected query coordinate `t` to
/// `t + alpha`.
pub struct LogUpDenominatorTableComposition<E>
where
    E: FieldElement,
{
    projection_coordinate: usize,
    alpha: E,
}

impl<E> LogUpDenominatorTableComposition<E>
where
    E: FieldElement,
{
    /// Creates the composition from the coordinate to project onto and the challenge `alpha`.
    pub fn new(projection_coordinate: usize, alpha: E) -> Self {
        Self { projection_coordinate, alpha }
    }
}

impl<E> CompositionPolynomial<E> for LogUpDenominatorTableComposition<E>
where
    E: FieldElement,
{
    fn num_variables(&self) -> usize {
        1
    }

    fn max_degree(&self) -> usize {
        1
    }

    /// Returns `query[coordinate] + alpha`.
    fn evaluate(&self, query: &[E]) -> E {
        self.alpha + query[self.projection_coordinate]
    }
}
/// LogUp denominator composition for witness rows: maps the selected query coordinate `w` to
/// `-(w + alpha)`, i.e. the negated table-style denominator.
pub struct LogUpDenominatorWitnessComposition<E>
where
    E: FieldElement,
{
    projection_coordinate: usize,
    alpha: E,
}

impl<E> LogUpDenominatorWitnessComposition<E>
where
    E: FieldElement,
{
    /// Creates the composition from the coordinate to project onto and the challenge `alpha`.
    pub fn new(projection_coordinate: usize, alpha: E) -> Self {
        Self { projection_coordinate, alpha }
    }
}

impl<E> CompositionPolynomial<E> for LogUpDenominatorWitnessComposition<E>
where
    E: FieldElement,
{
    fn num_variables(&self) -> usize {
        1
    }

    fn max_degree(&self) -> usize {
        1
    }

    /// Returns `-(query[coordinate] + alpha)`.
    fn evaluate(&self, query: &[E]) -> E {
        -(self.alpha + query[self.projection_coordinate])
    }
}
/// A composition that multiplies all of its query elements together.
pub struct ProductComposition {
    num_variables: usize,
}

impl ProductComposition {
    /// Creates a product composition over `num_variables` variables.
    pub fn new(num_variables: usize) -> Self {
        Self { num_variables }
    }
}

impl<E> CompositionPolynomial<E> for ProductComposition
where
    E: FieldElement,
{
    fn num_variables(&self) -> usize {
        self.num_variables
    }

    fn max_degree(&self) -> usize {
        self.num_variables
    }

    /// Returns the product of all query elements (ONE for an empty query).
    fn evaluate(&self, query: &[E]) -> E {
        query.iter().fold(E::ONE, |product, factor| product * *factor)
    }
}
/// A composition that adds all of its query elements together.
pub struct SumComposition {
    num_variables: usize,
}

impl SumComposition {
    /// Creates a sum composition over `num_variables` variables.
    pub fn new(num_variables: usize) -> Self {
        Self { num_variables }
    }
}

impl<E> CompositionPolynomial<E> for SumComposition
where
    E: FieldElement,
{
    fn num_variables(&self) -> usize {
        self.num_variables
    }

    fn max_degree(&self) -> usize {
        self.num_variables
    }

    /// Returns the sum of all query elements (ZERO for an empty query).
    fn evaluate(&self, query: &[E]) -> E {
        query.iter().fold(E::ZERO, |sum, term| sum + *term)
    }
}
/// A plain GKR sum-check composition: combines the four numerator/denominator evaluations and
/// the EQ value supplied in the query with the sum-check combining randomness.
pub struct GkrCompositionVanilla<E: 'static>
where
    E: FieldElement,
{
    num_variables_ml: usize,
    num_variables_merge: usize,
    combining_randomness: E,
    gkr_randomness: Vec<E>,
}

impl<E> GkrCompositionVanilla<E>
where
    E: FieldElement,
{
    /// Creates a new composition from the variable counts, the sum-check combining randomness
    /// and the GKR randomness.
    pub fn new(
        num_variables_ml: usize,
        num_variables_merge: usize,
        combining_randomness: E,
        gkr_randomness: Vec<E>,
    ) -> Self {
        Self {
            num_variables_ml,
            num_variables_merge,
            combining_randomness,
            gkr_randomness,
        }
    }
}

impl<E> CompositionPolynomial<E> for GkrCompositionVanilla<E>
where
    E: FieldElement,
{
    fn num_variables(&self) -> usize {
        self.num_variables_ml // + TODO
    }

    fn max_degree(&self) -> usize {
        self.num_variables_ml //TODO
    }

    // Query layout: [left_num, right_num, left_den, right_den, eq].
    fn evaluate(&self, query: &[E]) -> E {
        let p_left = query[0];
        let p_right = query[1];
        let q_left = query[2];
        let q_right = query[3];
        let eq = query[4];
        // eq * (fraction-addition numerator + lambda * product of denominators)
        eq * ((p_left * q_right + p_right * q_left)
            + q_left * q_right * self.combining_randomness)
    }
}
/// A GKR layer composition built from sub-compositions for each of the four
/// numerator/denominator positions plus an EQ composer.
#[derive(Clone)]
pub struct GkrComposition<E>
where
    E: FieldElement<BaseField = BaseElement>,
{
    // number of variables of the multi-linears being composed
    pub num_variables_ml: usize,
    // sum-check combining randomness (lambda)
    pub combining_randomness: E,
    // composer producing the EQ evaluation from the query
    eq_composer: Arc<dyn CompositionPolynomial<E>>,
    right_numerator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
    left_numerator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
    right_denominator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
    left_denominator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
}

impl<E> GkrComposition<E>
where
    E: FieldElement<BaseField = BaseElement>,
{
    /// Creates a new composition from the per-position sub-composers.
    pub fn new(
        num_variables_ml: usize,
        combining_randomness: E,
        eq_composer: Arc<dyn CompositionPolynomial<E>>,
        right_numerator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
        left_numerator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
        right_denominator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
        left_denominator_composer: Vec<Arc<dyn CompositionPolynomial<E>>>,
    ) -> Self {
        Self {
            num_variables_ml,
            combining_randomness,
            eq_composer,
            right_numerator_composer,
            left_numerator_composer,
            right_denominator_composer,
            left_denominator_composer,
        }
    }
}
impl<E> CompositionPolynomial<E> for GkrComposition<E>
where
    E: FieldElement<BaseField = BaseElement>,
{
    fn num_variables(&self) -> usize {
        self.num_variables_ml // + TODO
    }

    fn max_degree(&self) -> usize {
        3 // TODO
    }

    /// Evaluates eq * ((p_l * q_r + p_r * q_l) + lambda * q_l * q_r) where each factor is
    /// produced by the corresponding sub-composer applied to `query`.
    ///
    /// Only the first sub-composer of each position is used (see the TODO on
    /// `gkr_composition_from_composition_polys` about generalizing to merged compositions).
    fn evaluate(&self, query: &[E]) -> E {
        let eval_right_numerator = self.right_numerator_composer[0].evaluate(query);
        let eval_left_numerator = self.left_numerator_composer[0].evaluate(query);
        let eval_right_denominator = self.right_denominator_composer[0].evaluate(query);
        let eval_left_denominator = self.left_denominator_composer[0].evaluate(query);
        let eq_eval = self.eq_composer.evaluate(query);
        // Return the expression directly instead of binding it to a local first.
        eq_eval
            * ((eval_left_numerator * eval_right_denominator
                + eval_right_numerator * eval_left_denominator)
                + eval_left_denominator * eval_right_denominator * self.combining_randomness)
    }
}
/// Generates a composed ML polynomial for the initial GKR layer from a vector of composition
/// polynomials.
/// The composition polynomials are divided into LeftNumerator, RightNumerator, LeftDenominator
/// and RightDenominator.
/// TODO: Generalize this to the case where each numerator/denominator contains more than one
/// composition polynomial i.e., a merged composed ML polynomial.
///
/// Panics if `composition_polys` has fewer than four entries.
pub fn gkr_composition_from_composition_polys<
    E: FieldElement<BaseField = BaseElement> + 'static,
>(
    // Take a slice instead of `&Vec<...>` (clippy `ptr_arg`); existing `&vec` call sites
    // still work through deref coercion.
    composition_polys: &[Vec<Arc<dyn CompositionPolynomial<E>>>],
    combining_randomness: E,
    num_variables: usize,
) -> GkrComposition<E> {
    // Query position 4 holds the EQ evaluation (see `GkrCompositionVanilla::evaluate`).
    let eq_composer = Arc::new(ProjectionComposition::new(4));
    let left_numerator = composition_polys[0].to_owned();
    let right_numerator = composition_polys[1].to_owned();
    let left_denominator = composition_polys[2].to_owned();
    let right_denominator = composition_polys[3].to_owned();
    GkrComposition::new(
        num_variables,
        combining_randomness,
        eq_composer,
        right_numerator,
        left_numerator,
        right_denominator,
        left_denominator,
    )
}
/// Generates a plain oracle for the sum-check protocol except the final one.
///
/// The oracle composes five multi-linears: ids 0..=3 are the four numerator/denominator
/// columns and id 4 is the EQ polynomial (see `GkrCompositionVanilla::evaluate`).
pub fn gen_plain_gkr_oracle<E: FieldElement<BaseField = BaseElement> + 'static>(
    num_rounds: usize,
    r_sum_check: E,
) -> ComposedMultiLinearsOracle<E> {
    let gkr_composer = Arc::new(GkrCompositionVanilla::new(num_rounds, 0, r_sum_check, vec![]));
    // Build the five oracle handles from their ids instead of listing each literal.
    let ml_oracles = (0..5).map(|id| MultiLinearOracle { id }).collect();
    // Return the oracle directly (no intermediate binding).
    ComposedMultiLinearsOracle {
        composer: gkr_composer,
        multi_linears: ml_oracles,
    }
}
/// Interprets a sequence of field elements (assumed to be 0/1 bits, most-significant first —
/// TODO confirm at call sites) as the integer they encode.
fn to_index<E: FieldElement<BaseField = BaseElement>>(index: &[E]) -> usize {
    // Horner-style fold: acc = acc * 2 + bit.
    let res = index.iter().fold(E::ZERO, |acc, term| acc * E::ONE.double() + (*term));
    // Collapse to the base field and read it back as an integer.
    let res = res.base_element(0);
    res.as_int() as usize
}
/// Computes the inner product of two equal-length vectors of field elements.
///
/// Panics if the lengths differ.
fn inner_product<E: FieldElement>(evaluations: &[E], tensored_query: &[E]) -> E {
    assert_eq!(evaluations.len(), tensored_query.len());
    evaluations
        .iter()
        .zip(tensored_query.iter())
        .map(|(a, b)| *a * *b)
        .fold(E::ZERO, |acc, term| acc + term)
}
/// Expands a query point into its 2^nu Lagrange basis evaluations over the hyper-cube.
pub fn tensorize<E: FieldElement>(query: &[E]) -> Vec<E> {
    let num_points = 1 << query.len();
    (0..num_points).map(|i| lagrange_basis_eval(query, i)).collect()
}
/// Evaluates the i-th multi-linear Lagrange basis polynomial at `query`: for each bit j of i,
/// the factor is `x_j` when the bit is set and `1 - x_j` otherwise.
fn lagrange_basis_eval<E: FieldElement>(query: &[E], i: usize) -> E {
    query.iter().enumerate().fold(E::ONE, |acc, (j, x_j)| {
        let factor = if i & (1 << j) == 0 { E::ONE - *x_j } else { *x_j };
        acc * factor
    })
}
/// Computes the sum-check claim: the sum of the composed polynomial over the whole boolean
/// hyper-cube of the multi-linears.
pub fn compute_claim<E: FieldElement>(poly: &ComposedMultiLinears<E>) -> E {
    let cube_size = 1 << poly.num_variables_ml();
    (0..cube_size).fold(E::ZERO, |acc, i| {
        // Gather the i-th evaluation of every multi-linear and feed them to the composer.
        let eval_point: Vec<E> =
            poly.multi_linears.iter().map(|ml| ml.evaluations[i]).collect();
        acc + poly.composer.evaluate(&eval_point)
    })
}

108
src/gkr/sumcheck/mod.rs Normal file
View File

@@ -0,0 +1,108 @@
use super::{
multivariate::{ComposedMultiLinears, ComposedMultiLinearsOracle},
utils::{barycentric_weights, evaluate_barycentric},
};
use winter_math::FieldElement;
mod prover;
pub use prover::sum_check_prove;
mod verifier;
pub use verifier::{sum_check_verify, sum_check_verify_and_reduce};
mod tests;
/// A single sum-check round message: the evaluations of the round polynomial at 1, 2, ...
/// (the evaluation at 0 is implied by s(0) + s(1) = claim and is not transmitted).
#[derive(Debug, Clone)]
pub struct RoundProof<E> {
    pub poly_evals: Vec<E>,
}

impl<E: FieldElement> RoundProof<E> {
    /// Reconstructs all evaluations [s(0), s(1), ...] from the transmitted ones using the
    /// relation s(0) + s(1) = claim.
    pub fn to_evals(&self, claim: E) -> Vec<E> {
        // Reserve the final size up front instead of growing from empty.
        let mut result = Vec::with_capacity(self.poly_evals.len() + 1);
        // s(0) + s(1) = claim
        result.push(claim - self.poly_evals[0]);
        result.extend_from_slice(&self.poly_evals);
        result
    }

    /// Evaluates the round polynomial at `r` by barycentric interpolation over the points
    /// 0, 1, ..., d.
    // TODO: refactor once we move to coefficient form
    pub(crate) fn evaluate(&self, claim: E, r: E) -> E {
        let poly_evals = self.to_evals(claim);
        let points: Vec<E> = (0..poly_evals.len()).map(|i| E::from(i as u8)).collect();
        let point_evals: Vec<(E, E)> =
            points.iter().zip(poly_evals.iter()).map(|(x, y)| (*x, *y)).collect();
        let weights = barycentric_weights(&point_evals);
        // Return the interpolated value directly (no intermediate binding).
        evaluate_barycentric(&point_evals, r, &weights)
    }
}
/// The sequence of round proofs produced so far by the sum-check prover.
#[derive(Debug, Clone)]
pub struct PartialProof<E> {
    pub round_proofs: Vec<RoundProof<E>>,
}

/// The claim left at the end of the sum-check protocol: the composed polynomial oracle is
/// claimed to evaluate to `claimed_evaluation` at `evaluation_point`.
#[derive(Clone)]
pub struct FinalEvaluationClaim<E: FieldElement> {
    pub evaluation_point: Vec<E>,
    pub claimed_evaluation: E,
    pub polynomial: ComposedMultiLinearsOracle<E>,
}

/// A complete sum-check proof: the round proofs plus the final evaluation claim.
#[derive(Clone)]
pub struct FullProof<E: FieldElement> {
    pub sum_check_proof: PartialProof<E>,
    pub final_evaluation_claim: FinalEvaluationClaim<E>,
}

/// The initial sum-check claim: the oracle polynomial sums to `sum_value` over the hyper-cube.
pub struct Claim<E: FieldElement> {
    pub sum_value: E,
    pub polynomial: ComposedMultiLinearsOracle<E>,
}

/// The claim as reduced after some number of rounds: the challenges drawn so far and the
/// current claimed value.
#[derive(Debug)]
pub struct RoundClaim<E: FieldElement> {
    pub partial_eval_point: Vec<E>,
    pub current_claim: E,
}

/// The prover state carried between rounds: the proof accumulated so far and the (bound)
/// witness polynomial.
pub struct RoundOutput<E: FieldElement> {
    proof: PartialProof<E>,
    witness: Witness<E>,
}

impl<E: FieldElement> From<Claim<E>> for RoundClaim<E> {
    /// An initial claim corresponds to a round claim with no challenges drawn yet.
    fn from(value: Claim<E>) -> Self {
        Self {
            partial_eval_point: vec![],
            current_claim: value.sum_value,
        }
    }
}

/// The prover's witness: the composed multi-linears being summed.
pub struct Witness<E: FieldElement> {
    pub(crate) polynomial: ComposedMultiLinears<E>,
}
/// Reduces the current claim to the claim for the next sum-check round: evaluates the current
/// round polynomial at `round_challenge` and appends the challenge to the partial evaluation
/// point.
pub fn reduce_claim<E: FieldElement>(
    current_poly: RoundProof<E>,
    current_round_claim: RoundClaim<E>,
    round_challenge: E,
) -> RoundClaim<E> {
    // Delegate to `RoundProof::evaluate`, which performs exactly the barycentric
    // interpolation that used to be duplicated inline here.
    let new_claim = current_poly.evaluate(current_round_claim.current_claim, round_challenge);
    let mut new_partial_eval_point = current_round_claim.partial_eval_point;
    new_partial_eval_point.push(round_challenge);
    RoundClaim {
        partial_eval_point: new_partial_eval_point,
        current_claim: new_claim,
    }
}

109
src/gkr/sumcheck/prover.rs Normal file
View File

@@ -0,0 +1,109 @@
use super::{Claim, FullProof, RoundProof, Witness};
use crate::gkr::{
multivariate::{ComposedMultiLinears, ComposedMultiLinearsOracle},
sumcheck::{reduce_claim, FinalEvaluationClaim, PartialProof, RoundClaim, RoundOutput},
};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use winter_crypto::{ElementHasher, RandomCoin};
use winter_math::{fields::f64::BaseElement, FieldElement};
/// Runs the sum-check prover for a composed multi-linear polynomial.
///
/// Starting from `claim`, the prover emits one [RoundProof] per variable of the witness,
/// reseeding the transcript `coin` with each round polynomial before drawing the next round
/// challenge. Returns the round proofs bundled with the [FinalEvaluationClaim] that must be
/// checked against the multi-linear `oracle`.
pub fn sum_check_prove<
    E: FieldElement<BaseField = BaseElement>,
    C: RandomCoin<Hasher = H, BaseField = BaseElement>,
    H: ElementHasher<BaseField = BaseElement>,
>(
    claim: &Claim<E>,
    oracle: ComposedMultiLinearsOracle<E>,
    witness: Witness<E>,
    coin: &mut C,
) -> FullProof<E> {
    // Setup first round. Field elements are `Copy`, so no clone is needed.
    let mut prev_claim = RoundClaim {
        partial_eval_point: vec![],
        current_claim: claim.sum_value,
    };
    let prev_proof = PartialProof { round_proofs: vec![] };
    let num_vars = witness.polynomial.num_variables_ml();
    let prev_output = RoundOutput { proof: prev_proof, witness };

    let mut output = sumcheck_round(prev_output);
    let poly_evals = &output.proof.round_proofs[0].poly_evals;
    coin.reseed(H::hash_elements(poly_evals));

    for i in 1..num_vars {
        let round_challenge = coin.draw().unwrap();
        // Reduce the claim using the latest round polynomial and the drawn challenge.
        let new_claim = reduce_claim(
            output.proof.round_proofs.last().unwrap().clone(),
            prev_claim,
            round_challenge,
        );
        // Fix the next variable of the witness to the challenge and run the next round.
        output.witness.polynomial = output.witness.polynomial.bind(round_challenge);
        output = sumcheck_round(output);
        prev_claim = new_claim;

        // Absorb the new round polynomial before the next challenge is drawn.
        let poly_evals = &output.proof.round_proofs[i].poly_evals;
        coin.reseed(H::hash_elements(poly_evals));
    }

    // The final challenge reduces the last round claim to a single evaluation claim.
    let round_challenge = coin.draw().unwrap();
    let RoundClaim { partial_eval_point, current_claim } = reduce_claim(
        output.proof.round_proofs.last().unwrap().clone(),
        prev_claim,
        round_challenge,
    );
    let final_eval_claim = FinalEvaluationClaim {
        evaluation_point: partial_eval_point,
        claimed_evaluation: current_claim,
        polynomial: oracle,
    };
    FullProof {
        sum_check_proof: output.proof,
        final_evaluation_claim: final_eval_claim,
    }
}
/// Computes the round polynomial for one sum-check round and appends it to the proof.
///
/// For each point of the half-size cube, the multi-linears are evaluated with the bound
/// variable fixed to 0 and to 1; the composed polynomial is then evaluated at x = 1, 2, ...,
/// degree and the contributions summed over all points (the evaluation at 0 is implied by the
/// claim and reconstructed by the verifier, see [RoundProof::to_evals]).
fn sumcheck_round<E: FieldElement>(prev_proof: RoundOutput<E>) -> RoundOutput<E> {
    let RoundOutput { mut proof, witness } = prev_proof;
    let polynomial = witness.polynomial;
    let num_ml = polynomial.num_ml();
    let num_vars = polynomial.num_variables_ml();
    let num_rounds = num_vars - 1;

    // Scratch buffers reused across all points of the cube.
    let mut evals_zero = vec![E::ZERO; num_ml];
    let mut evals_one = vec![E::ZERO; num_ml];
    let mut deltas = vec![E::ZERO; num_ml];
    let mut evals_x = vec![E::ZERO; num_ml];

    let total_evals = (0..1 << num_rounds).into_iter().map(|i| {
        // Evaluations of each multi-linear with the bound variable set to 0 and to 1.
        for (j, ml) in polynomial.multi_linears.iter().enumerate() {
            evals_zero[j] = ml.evaluations[(i << 1) as usize];
            evals_one[j] = ml.evaluations[(i << 1) + 1];
        }
        let mut total_evals = vec![E::ZERO; polynomial.degree()];
        // Contribution at x = 1.
        total_evals[0] = polynomial.composer.evaluate(&evals_one);
        // Each multi-linear is affine in the bound variable, so its value at x + 1 is its
        // value at x plus delta = eval(1) - eval(0).
        evals_zero
            .iter()
            .zip(evals_one.iter().zip(deltas.iter_mut().zip(evals_x.iter_mut())))
            .for_each(|(a0, (a1, (delta, evx)))| {
                *delta = *a1 - *a0;
                *evx = *a1;
            });
        // Contributions at x = 2, ..., degree.
        total_evals.iter_mut().skip(1).for_each(|e| {
            evals_x.iter_mut().zip(deltas.iter()).for_each(|(evx, delta)| {
                *evx += *delta;
            });
            *e = polynomial.composer.evaluate(&evals_x);
        });
        total_evals
    });
    // Sum the per-point contributions into the round polynomial evaluations.
    let evaluations = total_evals.fold(vec![E::ZERO; polynomial.degree()], |mut acc, evals| {
        acc.iter_mut().zip(evals.iter()).for_each(|(a, ev)| *a += *ev);
        acc
    });
    let proof_update = RoundProof { poly_evals: evaluations };
    proof.round_proofs.push(proof_update);
    RoundOutput { proof, witness: Witness { polynomial } }
}

199
src/gkr/sumcheck/tests.rs Normal file
View File

@@ -0,0 +1,199 @@
use alloc::sync::Arc;
use rand::{distributions::Uniform, SeedableRng};
use winter_crypto::RandomCoin;
use winter_math::{fields::f64::BaseElement, FieldElement};
use crate::{
gkr::{
circuit::{CircuitProof, FractionalSumCircuit},
multivariate::{
compute_claim, gkr_composition_from_composition_polys, ComposedMultiLinears,
ComposedMultiLinearsOracle, CompositionPolynomial, EqPolynomial, GkrComposition,
GkrCompositionVanilla, LogUpDenominatorTableComposition,
LogUpDenominatorWitnessComposition, MultiLinear, MultiLinearOracle,
ProjectionComposition, SumComposition,
},
sumcheck::{
prover::sum_check_prove, verifier::sum_check_verify, Claim, FinalEvaluationClaim,
FullProof, Witness,
},
},
hash::rpo::Rpo256,
rand::RpoRandomCoin,
};
#[test]
fn gkr_workflow() {
    // generate the data witness for the LogUp argument
    let mut mls = generate_logup_witness::<BaseElement>(3);

    // the challenge alpha is sampled after receiving the main trace commitment
    let alpha = rand_utils::rand_value();

    // the composition polynomials defining the numerators/denominators
    let composition_polys: Vec<Vec<Arc<dyn CompositionPolynomial<BaseElement>>>> = vec![
        // left num
        vec![Arc::new(ProjectionComposition::new(0))],
        // right num
        vec![Arc::new(ProjectionComposition::new(1))],
        // left den
        vec![Arc::new(LogUpDenominatorTableComposition::new(2, alpha))],
        // right den
        vec![Arc::new(LogUpDenominatorWitnessComposition::new(3, alpha))],
    ];

    // run the GKR prover to obtain:
    // 1. The fractional sum circuit output.
    // 2. GKR proofs up to the last circuit layer counting backwards.
    // 3. GKR proof (i.e., a sum-check proof) for the last circuit layer counting backwards.
    let seed = [BaseElement::ZERO; 4];
    let mut transcript = RpoRandomCoin::new(seed.into());
    let (circuit_outputs, gkr_before_last_proof, final_layer_proof) =
        CircuitProof::prove_virtual_bus(composition_polys.clone(), &mut mls, &mut transcript);

    // reseed with the same value so the verifier draws the same challenges as the prover
    let seed = [BaseElement::ZERO; 4];
    let mut transcript = RpoRandomCoin::new(seed.into());

    // run the GKR verifier to obtain:
    // 1. A final evaluation claim.
    // 2. Randomness defining the Lagrange kernel in the final sum-check protocol. Note that this
    //    Lagrange kernel is different from the one used by the STARK (outer) prover to open the
    //    MLs at the evaluation point.
    let (final_eval_claim, gkr_lagrange_kernel_rand) = gkr_before_last_proof.verify_virtual_bus(
        composition_polys.clone(),
        final_layer_proof,
        &circuit_outputs,
        &mut transcript,
    );

    // the final verification step is composed of:
    // 1. Querying the oracles for the openings at the evaluation point. This will be done by the
    //    (outer) STARK prover using:
    //    a. The Lagrange kernel (auxiliary) column at the evaluation point.
    //    b. An extra (auxiliary) column to compute an inner product between two vectors. The first
    //       being the Lagrange kernel and the second being
    //       (\sum_{j=0}^3 mls[j][i] * \lambda_i)_{i\in\{0,..,n\}}
    // 2. Evaluating the composition polynomial at the previous openings and checking equality with
    //    the claimed evaluation.

    // 1. Querying the oracles
    let FinalEvaluationClaim {
        evaluation_point,
        claimed_evaluation,
        polynomial,
    } = final_eval_claim;

    // The evaluation of the EQ polynomial can be done by the verifier directly
    let eq = (0..gkr_lagrange_kernel_rand.len())
        .map(|i| {
            gkr_lagrange_kernel_rand[i] * evaluation_point[i]
                + (BaseElement::ONE - gkr_lagrange_kernel_rand[i])
                    * (BaseElement::ONE - evaluation_point[i])
        })
        .fold(BaseElement::ONE, |acc, term| acc * term);

    // These are the queries to the oracles.
    // They should be provided by the prover non-deterministically
    let left_num_eval = mls[0].evaluate(&evaluation_point);
    let right_num_eval = mls[1].evaluate(&evaluation_point);
    let left_den_eval = mls[2].evaluate(&evaluation_point);
    let right_den_eval = mls[3].evaluate(&evaluation_point);

    // The verifier absorbs the claimed openings and generates batching randomness lambda
    let mut query = vec![left_num_eval, right_num_eval, left_den_eval, right_den_eval];
    transcript.reseed(Rpo256::hash_elements(&query));
    let lambdas: Vec<BaseElement> = vec![
        transcript.draw().unwrap(),
        transcript.draw().unwrap(),
        transcript.draw().unwrap(),
    ];
    // batch the four openings into a single value using the lambdas
    let batched_query =
        query[0] + query[1] * lambdas[0] + query[2] * lambdas[1] + query[3] * lambdas[2];

    // The prover generates the Lagrange kernel as an auxiliary column
    let mut rev_evaluation_point = evaluation_point;
    rev_evaluation_point.reverse();
    let lagrange_kernel = EqPolynomial::new(rev_evaluation_point).evaluations();

    // The prover generates the additional auxiliary column for the inner product
    let tmp_col: Vec<BaseElement> = (0..mls[0].len())
        .map(|i| {
            mls[0][i] + mls[1][i] * lambdas[0] + mls[2][i] * lambdas[1] + mls[3][i] * lambdas[2]
        })
        .collect();

    // running sum: running_sum_col[i + 1] = running_sum_col[i] + tmp_col[i] * kernel[i]
    let mut running_sum_col = vec![BaseElement::ZERO; tmp_col.len() + 1];
    running_sum_col[0] = BaseElement::ZERO;
    for i in 1..(tmp_col.len() + 1) {
        running_sum_col[i] = running_sum_col[i - 1] + tmp_col[i - 1] * lagrange_kernel[i - 1];
    }

    // Boundary constraint to check correctness of openings
    assert_eq!(batched_query, *running_sum_col.last().unwrap());

    // 2) Final evaluation and check
    query.push(eq);
    let verifier_computed = polynomial.composer.evaluate(&query);
    assert_eq!(verifier_computed, claimed_evaluation);
}
/// Builds the four multi-linear witness columns (two numerator chunks followed by two
/// denominator chunks) for a LogUp instance with `trace_len` variables per column.
pub fn generate_logup_witness<E: FieldElement>(trace_len: usize) -> Vec<MultiLinear<E>> {
    let num_variables_ml = trace_len;
    let num_evaluations = 1 << num_variables_ml;
    let num_witnesses = 1;
    let (p, q) = generate_logup_data::<E>(num_variables_ml, num_witnesses);
    // Slice the flat numerator/denominator vectors into per-column chunks.
    let numerators: Vec<Vec<E>> = p.chunks(num_evaluations).map(|chunk| chunk.into()).collect();
    let denominators: Vec<Vec<E>> = q.chunks(num_evaluations).map(|chunk| chunk.into()).collect();
    let mut mls = Vec::with_capacity(4);
    for i in 0..2 {
        mls.push(MultiLinear::from_values(&numerators[i]));
    }
    for i in 0..2 {
        mls.push(MultiLinear::from_values(&denominators[i]));
    }
    mls
}
/// Generates deterministic LogUp test data: the numerator column (multiplicities followed by
/// ones) and the denominator column (table values followed by witness values).
pub fn generate_logup_data<E: FieldElement>(
    trace_len: usize,
    num_witnesses: usize,
) -> (Vec<E>, Vec<E>) {
    use rand::distributions::Slice;
    use rand::Rng;
    let n: usize = trace_len;
    let num_w: usize = num_witnesses; // This should be of the form 2^k - 1
    let mut rng = rand::rngs::StdRng::seed_from_u64(0);
    let t_table: Vec<u32> = (0..(1 << n)).collect();
    let mut m_table: Vec<u32> = vec![0; 1 << n];
    let t_table_slice = Slice::new(&t_table).unwrap();
    // Construct the witness columns. Uses sampling with replacement in order to have
    // multiplicities different from 1.
    let mut w_tables = Vec::new();
    for _ in 0..num_w {
        // Sample through `&mut rng` so the generator state advances between columns; the
        // previous `rng.clone().sample_iter(...)` restarted the seeded RNG from the same
        // state on every iteration, which made all witness columns identical.
        let wi_table: Vec<u32> =
            (&mut rng).sample_iter(&t_table_slice).cloned().take(1 << n).collect();
        // Construct the multiplicities
        wi_table.iter().for_each(|w| {
            m_table[*w as usize] += 1;
        });
        w_tables.push(wi_table);
    }
    // The numerators: multiplicities for the table rows, ones for the witness rows.
    let mut p: Vec<E> = m_table.iter().map(|m| E::from(*m)).collect();
    p.extend((0..(num_w * (1 << n))).map(|_| E::from(1_u32)));
    // Construct the denominators
    let mut q: Vec<E> = t_table.iter().map(|t| E::from(*t)).collect();
    for w_table in w_tables {
        q.extend(w_table.iter().map(|w| E::from(*w)));
    }
    (p, q)
}

View File

@@ -0,0 +1,71 @@
use winter_crypto::{ElementHasher, RandomCoin};
use winter_math::{fields::f64::BaseElement, FieldElement};
use crate::gkr::utils::{barycentric_weights, evaluate_barycentric};
use super::{Claim, FinalEvaluationClaim, FullProof, PartialProof};
/// Verifies the rounds of a partial sum-check proof and reduces the claim accordingly,
/// returning the reduced claim together with the round challenges drawn from the transcript.
///
/// NOTE(review): the round-polynomial degree is hard-coded to 3 here (it matches
/// `GkrComposition::max_degree`), while `sum_check_verify` reads it from the claim's
/// composer — confirm these stay in sync.
pub fn sum_check_verify_and_reduce<
    E: FieldElement<BaseField = BaseElement>,
    C: RandomCoin<Hasher = H, BaseField = BaseElement>,
    H: ElementHasher<BaseField = BaseElement>,
>(
    claim: &Claim<E>,
    proofs: PartialProof<E>,
    coin: &mut C,
) -> (E, Vec<E>) {
    let degree = 3;
    // Interpolation points 0, 1, ..., degree.
    let points: Vec<E> = (0..degree + 1).map(|x| E::from(x as u8)).collect();
    let mut sum_value = claim.sum_value.clone();
    let mut randomness = vec![];
    for proof in proofs.round_proofs {
        // Absorb the round polynomial before drawing the round challenge (must mirror the
        // prover's transcript usage).
        let partial_evals = proof.poly_evals.clone();
        coin.reseed(H::hash_elements(&partial_evals));
        // get r
        let r: E = coin.draw().unwrap();
        randomness.push(r);
        // Reduce the claim: evaluate the round polynomial at r by barycentric interpolation.
        let evals = proof.to_evals(sum_value);
        let point_evals: Vec<_> = points.iter().zip(evals.iter()).map(|(x, y)| (*x, *y)).collect();
        let weights = barycentric_weights(&point_evals);
        sum_value = evaluate_barycentric(&point_evals, r, &weights);
    }
    (sum_value, randomness)
}
/// Verifies a full sum-check proof against `claim` and returns the final evaluation claim.
///
/// NOTE(review): failures panic via `assert_eq!` rather than returning an error — confirm
/// this is acceptable for a verifier entry point.
pub fn sum_check_verify<
    E: FieldElement<BaseField = BaseElement>,
    C: RandomCoin<Hasher = H, BaseField = BaseElement>,
    H: ElementHasher<BaseField = BaseElement>,
>(
    claim: &Claim<E>,
    proofs: FullProof<E>,
    coin: &mut C,
) -> FinalEvaluationClaim<E> {
    let FullProof {
        sum_check_proof: proofs,
        final_evaluation_claim,
    } = proofs;
    // NOTE(review): `claim` is a shared reference, so this destructuring goes through match
    // ergonomics — confirm `sum_value` is bound by value (and reassignable) as intended.
    let Claim { mut sum_value, polynomial } = claim;
    let degree = polynomial.composer.max_degree();
    // Interpolation points 0, 1, ..., degree.
    let points: Vec<E> = (0..degree + 1).map(|x| E::from(x as u8)).collect();
    for proof in proofs.round_proofs {
        // Absorb the round polynomial before drawing the round challenge (must mirror the
        // prover's transcript usage).
        let partial_evals = proof.poly_evals.clone();
        coin.reseed(H::hash_elements(&partial_evals));
        // get r
        let r: E = coin.draw().unwrap();
        // Reduce the claim: evaluate the round polynomial at r by barycentric interpolation.
        let evals = proof.to_evals(sum_value);
        let point_evals: Vec<_> = points.iter().zip(evals.iter()).map(|(x, y)| (*x, *y)).collect();
        let weights = barycentric_weights(&point_evals);
        sum_value = evaluate_barycentric(&point_evals, r, &weights);
    }
    // The fully-reduced sum must match the prover's claimed final evaluation.
    assert_eq!(final_evaluation_claim.claimed_evaluation, sum_value);
    final_evaluation_claim
}

33
src/gkr/utils/mod.rs Normal file
View File

@@ -0,0 +1,33 @@
use winter_math::{FieldElement, batch_inversion};
/// Computes the barycentric weights w_i = 1 / prod_{j != i} (x_i - x_j) for the given
/// interpolation points (x, y).
pub fn barycentric_weights<E: FieldElement>(points: &[(E, E)]) -> Vec<E> {
    let n = points.len();
    // Denominators prod_{j != i} (x_i - x_j), inverted in a single batch.
    let denominators: Vec<E> = (0..n)
        .map(|i| {
            (0..n)
                .filter(|&j| j != i)
                .fold(E::ONE, |acc, j| acc * (points[i].0 - points[j].0))
        })
        .collect();
    batch_inversion(&denominators)
}
/// Evaluates at `x` the polynomial interpolating `points`, using the barycentric formula with
/// precomputed `barycentric_weights`.
pub fn evaluate_barycentric<E: FieldElement>(
    points: &[(E, E)],
    x: E,
    barycentric_weights: &[E],
) -> E {
    // If x coincides with an interpolation node, return the stored value directly (the
    // barycentric formula would otherwise divide by zero there).
    if let Some(&(_, y_i)) = points.iter().find(|&&(x_i, _)| x_i == x) {
        return y_i;
    }
    // l(x) = prod_i (x - x_i)
    let l_x: E = points.iter().fold(E::ONE, |acc, &(x_i, _)| acc * (x - x_i));
    // sum_i w_i / (x - x_i) * y_i
    let sum = points
        .iter()
        .zip(barycentric_weights.iter())
        .fold(E::ZERO, |acc, (&(x_i, y_i), &w_i)| acc + (w_i / (x - x_i) * y_i));
    l_x * sum
}

View File

@@ -1,4 +1,4 @@
use super::{Digest, ElementHasher, Felt, FieldElement, Hasher};
use super::{Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField};
use crate::utils::{
bytes_to_hex_string, hex_to_bytes, string::String, ByteReader, ByteWriter, Deserializable,
DeserializationError, HexParseError, Serializable,

View File

@@ -1,4 +1,4 @@
#[cfg(target_feature = "sve")]
#[cfg(all(target_feature = "sve", feature = "sve"))]
pub mod optimized {
use crate::hash::rescue::STATE_WIDTH;
use crate::Felt;
@@ -78,7 +78,7 @@ pub mod optimized {
}
}
#[cfg(not(any(target_feature = "avx2", target_feature = "sve")))]
#[cfg(not(any(target_feature = "avx2", all(target_feature = "sve", feature = "sve"))))]
pub mod optimized {
use crate::hash::rescue::STATE_WIDTH;
use crate::Felt;

View File

@@ -33,11 +33,6 @@ impl RpoDigest {
{
digests.flat_map(|d| d.0.iter())
}
/// Returns hexadecimal representation of this digest prefixed with `0x`.
pub fn to_hex(&self) -> String {
bytes_to_hex_string(self.as_bytes())
}
}
impl Digest for RpoDigest {
@@ -163,7 +158,7 @@ impl From<RpoDigest> for [u8; DIGEST_BYTES] {
impl From<RpoDigest> for String {
/// The returned string starts with `0x`.
fn from(value: RpoDigest) -> Self {
value.to_hex()
bytes_to_hex_string(value.as_bytes())
}
}
@@ -234,12 +229,15 @@ impl TryFrom<[u64; DIGEST_SIZE]> for RpoDigest {
type Error = RpoDigestError;
fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
Ok(Self([
value[0].try_into().map_err(|_| RpoDigestError::InvalidInteger)?,
value[1].try_into().map_err(|_| RpoDigestError::InvalidInteger)?,
value[2].try_into().map_err(|_| RpoDigestError::InvalidInteger)?,
value[3].try_into().map_err(|_| RpoDigestError::InvalidInteger)?,
]))
if value[0] >= Felt::MODULUS
|| value[1] >= Felt::MODULUS
|| value[2] >= Felt::MODULUS
|| value[3] >= Felt::MODULUS
{
return Err(RpoDigestError::InvalidInteger);
}
Ok(Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()]))
}
}

View File

@@ -27,7 +27,7 @@ mod tests;
/// * Number of rounds: 7.
/// * S-Box degree: 7.
///
/// The above parameters target a 128-bit security level. The digest consists of four field elements
/// The above parameters target 128-bit security level. The digest consists of four field elements
/// and it can be serialized into 32 bytes (256 bits).
///
/// ## Hash output consistency
@@ -55,7 +55,13 @@ mod tests;
pub struct Rpo256();
impl Hasher for Rpo256 {
/// Rpo256 collision resistance is 128-bits.
/// Rpo256 collision resistance is the same as the security level, that is 128-bits.
///
/// #### Collision resistance
///
/// However, our setup of the capacity registers might drop it to 126.
///
/// Related issue: [#69](https://github.com/0xPolygonMiden/crypto/issues/69)
const COLLISION_RESISTANCE: u32 = 128;
type Digest = RpoDigest;

View File

@@ -33,11 +33,6 @@ impl RpxDigest {
{
digests.flat_map(|d| d.0.iter())
}
/// Returns hexadecimal representation of this digest prefixed with `0x`.
pub fn to_hex(&self) -> String {
bytes_to_hex_string(self.as_bytes())
}
}
impl Digest for RpxDigest {
@@ -163,7 +158,7 @@ impl From<RpxDigest> for [u8; DIGEST_BYTES] {
impl From<RpxDigest> for String {
/// The returned string starts with `0x`.
fn from(value: RpxDigest) -> Self {
value.to_hex()
bytes_to_hex_string(value.as_bytes())
}
}
@@ -234,12 +229,15 @@ impl TryFrom<[u64; DIGEST_SIZE]> for RpxDigest {
type Error = RpxDigestError;
fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
Ok(Self([
value[0].try_into().map_err(|_| RpxDigestError::InvalidInteger)?,
value[1].try_into().map_err(|_| RpxDigestError::InvalidInteger)?,
value[2].try_into().map_err(|_| RpxDigestError::InvalidInteger)?,
value[3].try_into().map_err(|_| RpxDigestError::InvalidInteger)?,
]))
if value[0] >= Felt::MODULUS
|| value[1] >= Felt::MODULUS
|| value[2] >= Felt::MODULUS
|| value[3] >= Felt::MODULUS
{
return Err(RpxDigestError::InvalidInteger);
}
Ok(Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()]))
}
}

View File

@@ -2,8 +2,8 @@ use super::{
add_constants, add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox, apply_inv_sbox,
apply_mds, apply_sbox, CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher,
StarkField, ARK1, ARK2, BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE,
DIGEST_SIZE, INPUT1_RANGE, INPUT2_RANGE, MDS, NUM_ROUNDS, RATE_RANGE, RATE_WIDTH, STATE_WIDTH,
ZERO,
DIGEST_SIZE, INPUT1_RANGE, INPUT2_RANGE, MDS, NUM_ROUNDS, ONE, RATE_RANGE, RATE_WIDTH,
STATE_WIDTH, ZERO,
};
use core::{convert::TryInto, ops::Range};
@@ -30,7 +30,7 @@ pub type CubicExtElement = CubeExtension<Felt>;
/// - (M): `apply_mds` → `add_constants`.
/// * Permutation: (FB) (E) (FB) (E) (FB) (E) (M).
///
/// The above parameters target a 128-bit security level. The digest consists of four field elements
/// The above parameters target 128-bit security level. The digest consists of four field elements
/// and it can be serialized into 32 bytes (256 bits).
///
/// ## Hash output consistency
@@ -58,7 +58,13 @@ pub type CubicExtElement = CubeExtension<Felt>;
pub struct Rpx256();
impl Hasher for Rpx256 {
/// Rpx256 collision resistance is 128-bits.
/// Rpx256 collision resistance is the same as the security level, that is 128-bits.
///
/// #### Collision resistance
///
/// However, our setup of the capacity registers might drop it to 126.
///
/// Related issue: [#69](https://github.com/0xPolygonMiden/crypto/issues/69)
const COLLISION_RESISTANCE: u32 = 128;
type Digest = RpxDigest;
@@ -67,16 +73,14 @@ impl Hasher for Rpx256 {
// initialize the state with zeroes
let mut state = [ZERO; STATE_WIDTH];
// determine the number of field elements needed to encode `bytes` when each field element
// represents at most 7 bytes.
let num_field_elem = bytes.len().div_ceil(BINARY_CHUNK_SIZE);
// set the first capacity element to `RATE_WIDTH + (num_field_elem % RATE_WIDTH)`. We do
// this to achieve:
// 1. Domain separating hashing of `[u8]` from hashing of `[Felt]`.
// 2. Avoiding collisions at the `[Felt]` representation of the encoded bytes.
state[CAPACITY_RANGE.start] =
Felt::from((RATE_WIDTH + (num_field_elem % RATE_WIDTH)) as u8);
// set the capacity (first element) to a flag on whether or not the input length is evenly
// divided by the rate. this will prevent collisions between padded and non-padded inputs,
// and will rule out the need to perform an extra permutation in case of evenly divided
// inputs.
let is_rate_multiple = bytes.len() % RATE_WIDTH == 0;
if !is_rate_multiple {
state[CAPACITY_RANGE.start] = ONE;
}
// initialize a buffer to receive the little-endian elements.
let mut buf = [0_u8; 8];
@@ -88,12 +92,12 @@ impl Hasher for Rpx256 {
// `i` is not zero, then the chunks count wasn't enough to fill the state range, and an
// additional permutation must be performed.
let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| {
// copy the chunk into the buffer
if i != num_field_elem - 1 {
// the last element of the iteration may or may not be a full chunk. if it's not, then
// we need to pad the remainder bytes of the chunk with zeroes, separated by a `1`.
// this will avoid collisions.
if chunk.len() == BINARY_CHUNK_SIZE {
buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
} else {
// on the last iteration, we pad `buf` with a 1 followed by as many 0's as are
// needed to fill it
buf.fill(0);
buf[..chunk.len()].copy_from_slice(chunk);
buf[chunk.len()] = 1;
@@ -116,10 +120,10 @@ impl Hasher for Rpx256 {
// if we absorbed some elements but didn't apply a permutation to them (would happen when
// the number of elements is not a multiple of RATE_WIDTH), apply the RPX permutation. we
// don't need to apply any extra padding because the first capacity element contains a
// flag indicating the number of field elements constituting the last block when the latter
// is not divisible by `RATE_WIDTH`.
// flag indicating whether the input is evenly divisible by the rate.
if i != 0 {
state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO);
state[RATE_RANGE.start + i] = ONE;
Self::apply_permutation(&mut state);
}
@@ -144,20 +148,25 @@ impl Hasher for Rpx256 {
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
// initialize the state as follows:
// - seed is copied into the first 4 elements of the rate portion of the state.
// - if the value fits into a single field element, copy it into the fifth rate element and
// set the first capacity element to 5.
// - if the value fits into a single field element, copy it into the fifth rate element
// and set the sixth rate element to 1.
// - if the value doesn't fit into a single field element, split it into two field
// elements, copy them into rate elements 5 and 6 and set the first capacity element to 6.
// elements, copy them into rate elements 5 and 6, and set the seventh rate element
// to 1.
// - set the first capacity element to 1
let mut state = [ZERO; STATE_WIDTH];
state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
state[INPUT2_RANGE.start] = Felt::new(value);
if value < Felt::MODULUS {
state[CAPACITY_RANGE.start] = Felt::from(5_u8);
state[INPUT2_RANGE.start + 1] = ONE;
} else {
state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
state[CAPACITY_RANGE.start] = Felt::from(6_u8);
state[INPUT2_RANGE.start + 2] = ONE;
}
// common padding for both cases
state[CAPACITY_RANGE.start] = ONE;
// apply the RPX permutation and return the first four elements of the state
Self::apply_permutation(&mut state);
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
@@ -172,9 +181,11 @@ impl ElementHasher for Rpx256 {
let elements = E::slice_as_base_elements(elements);
// initialize state to all zeros, except for the first element of the capacity part, which
// is set to `elements.len() % RATE_WIDTH`.
// is set to 1 if the number of elements is not a multiple of RATE_WIDTH.
let mut state = [ZERO; STATE_WIDTH];
state[CAPACITY_RANGE.start] = Self::BaseField::from((elements.len() % RATE_WIDTH) as u8);
if elements.len() % RATE_WIDTH != 0 {
state[CAPACITY_RANGE.start] = ONE;
}
// absorb elements into the state one by one until the rate portion of the state is filled
// up; then apply the Rescue permutation and start absorbing again; repeat until all
@@ -191,8 +202,11 @@ impl ElementHasher for Rpx256 {
// if we absorbed some elements but didn't apply a permutation to them (would happen when
// the number of elements is not a multiple of RATE_WIDTH), apply the RPX permutation after
// padding by as many 0 as necessary to make the input length a multiple of the RATE_WIDTH.
// padding by appending a 1 followed by as many 0 as necessary to make the input length a
// multiple of the RATE_WIDTH.
if i > 0 {
state[RATE_RANGE.start + i] = ONE;
i += 1;
while i != RATE_WIDTH {
state[RATE_RANGE.start + i] = ZERO;
i += 1;
@@ -340,7 +354,7 @@ impl Rpx256 {
add_constants(state, &ARK1[round]);
}
/// Computes an exponentiation to the power 7 in cubic extension field.
/// Computes an exponentiation to the power 7 in cubic extension field
#[inline(always)]
pub fn exp7(x: CubeExtension<Felt>) -> CubeExtension<Felt> {
let x2 = x.square();

View File

@@ -1,7 +1,7 @@
#![cfg_attr(not(feature = "std"), no_std)]
#[cfg(not(feature = "std"))]
#[cfg_attr(test, macro_use)]
//#[cfg(not(feature = "std"))]
//#[cfg_attr(test, macro_use)]
extern crate alloc;
pub mod dsa;
@@ -9,6 +9,7 @@ pub mod hash;
pub mod merkle;
pub mod rand;
pub mod utils;
pub mod gkr;
// RE-EXPORTS
// ================================================================================================

View File

@@ -1,14 +1,19 @@
use clap::Parser;
use miden_crypto::{
hash::rpo::{Rpo256, RpoDigest},
merkle::{MerkleError, Smt},
merkle::{MerkleError, TieredSmt},
Felt, Word, ONE,
};
use rand_utils::rand_value;
use std::time::Instant;
#[derive(Parser, Debug)]
#[clap(name = "Benchmark", about = "SMT benchmark", version, rename_all = "kebab-case")]
#[clap(
name = "Benchmark",
about = "Tiered SMT benchmark",
version,
rename_all = "kebab-case"
)]
pub struct BenchmarkCmd {
/// Size of the tree
#[clap(short = 's', long = "size")]
@@ -16,11 +21,11 @@ pub struct BenchmarkCmd {
}
fn main() {
benchmark_smt();
benchmark_tsmt();
}
/// Run a benchmark for [`Smt`].
pub fn benchmark_smt() {
/// Run a benchmark for the Tiered SMT.
pub fn benchmark_tsmt() {
let args = BenchmarkCmd::parse();
let tree_size = args.size;
@@ -37,25 +42,38 @@ pub fn benchmark_smt() {
proof_generation(&mut tree, tree_size).unwrap();
}
/// Runs the construction benchmark for [`Smt`], returning the constructed tree.
pub fn construction(entries: Vec<(RpoDigest, Word)>, size: u64) -> Result<Smt, MerkleError> {
/// Runs the construction benchmark for the Tiered SMT, returning the constructed tree.
pub fn construction(entries: Vec<(RpoDigest, Word)>, size: u64) -> Result<TieredSmt, MerkleError> {
println!("Running a construction benchmark:");
let now = Instant::now();
let tree = Smt::with_entries(entries)?;
let tree = TieredSmt::with_entries(entries)?;
let elapsed = now.elapsed();
println!(
"Constructed a SMT with {} key-value pairs in {:.3} seconds",
"Constructed a TSMT with {} key-value pairs in {:.3} seconds",
size,
elapsed.as_secs_f32(),
);
println!("Number of leaf nodes: {}\n", tree.leaves().count());
// Count how many nodes end up at each tier
let mut nodes_num_16_32_48 = (0, 0, 0);
tree.upper_leaf_nodes().for_each(|(index, _)| match index.depth() {
16 => nodes_num_16_32_48.0 += 1,
32 => nodes_num_16_32_48.1 += 1,
48 => nodes_num_16_32_48.2 += 1,
_ => unreachable!(),
});
println!("Number of nodes on depth 16: {}", nodes_num_16_32_48.0);
println!("Number of nodes on depth 32: {}", nodes_num_16_32_48.1);
println!("Number of nodes on depth 48: {}", nodes_num_16_32_48.2);
println!("Number of nodes on depth 64: {}\n", tree.bottom_leaves().count());
Ok(tree)
}
/// Runs the insertion benchmark for the [`Smt`].
pub fn insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
/// Runs the insertion benchmark for the Tiered SMT.
pub fn insertion(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleError> {
println!("Running an insertion benchmark:");
let mut insertion_times = Vec::new();
@@ -71,7 +89,7 @@ pub fn insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
}
println!(
"An average insertion time measured by 20 inserts into a SMT with {} key-value pairs is {:.3} milliseconds\n",
"An average insertion time measured by 20 inserts into a TSMT with {} key-value pairs is {:.3} milliseconds\n",
size,
// calculate the average by dividing by 20 and convert to milliseconds by multiplying by
// 1000. As a result, we can only multiply by 50
@@ -81,8 +99,8 @@ pub fn insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
Ok(())
}
/// Runs the proof generation benchmark for the [`Smt`].
pub fn proof_generation(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
/// Runs the proof generation benchmark for the Tiered SMT.
pub fn proof_generation(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleError> {
println!("Running a proof generation benchmark:");
let mut insertion_times = Vec::new();
@@ -93,13 +111,13 @@ pub fn proof_generation(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
tree.insert(test_key, test_value);
let now = Instant::now();
let _proof = tree.open(&test_key);
let _proof = tree.prove(test_key);
let elapsed = now.elapsed();
insertion_times.push(elapsed.as_secs_f32());
}
println!(
"An average proving time measured by 20 value proofs in a SMT with {} key-value pairs in {:.3} microseconds",
"An average proving time measured by 20 value proofs in a TSMT with {} key-value pairs in {:.3} microseconds",
size,
// calculate the average by dividing by 20 and convert to microseconds by multiplying by
// 1000000. As a result, we can only multiply by 50000

156
src/merkle/delta.rs Normal file
View File

@@ -0,0 +1,156 @@
use super::{
BTreeMap, KvMap, MerkleError, MerkleStore, NodeIndex, RpoDigest, StoreNode, Vec, Word,
};
use crate::utils::collections::Diff;
#[cfg(test)]
use super::{super::ONE, Felt, SimpleSmt, EMPTY_WORD, ZERO};
// MERKLE STORE DELTA
// ================================================================================================
/// [MerkleStoreDelta] stores a vector of ([RpoDigest], [MerkleTreeDelta]) tuples where the
/// [RpoDigest] represents the root of the Merkle tree and [MerkleTreeDelta] represents the
/// differences between the initial and final Merkle tree states.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleStoreDelta(pub Vec<(RpoDigest, MerkleTreeDelta)>);
// MERKLE TREE DELTA
// ================================================================================================
/// [MerkleTreeDelta] stores the differences between the initial and final Merkle tree states.
///
/// The differences are represented as follows:
/// - depth: the depth of the merkle tree.
/// - cleared_slots: indexes of slots where values were set to [ZERO; 4].
/// - updated_slots: index-value pairs of slots where values were set to non [ZERO; 4] values.
#[cfg(not(test))]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleTreeDelta {
depth: u8,
cleared_slots: Vec<u64>,
updated_slots: Vec<(u64, Word)>,
}
impl MerkleTreeDelta {
// CONSTRUCTOR
// --------------------------------------------------------------------------------------------
pub fn new(depth: u8) -> Self {
Self {
depth,
cleared_slots: Vec::new(),
updated_slots: Vec::new(),
}
}
// ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns the depth of the Merkle tree the [MerkleTreeDelta] is associated with.
pub fn depth(&self) -> u8 {
self.depth
}
/// Returns the indexes of slots where values were set to [ZERO; 4].
pub fn cleared_slots(&self) -> &[u64] {
&self.cleared_slots
}
/// Returns the index-value pairs of slots where values were set to non [ZERO; 4] values.
pub fn updated_slots(&self) -> &[(u64, Word)] {
&self.updated_slots
}
// MODIFIERS
// --------------------------------------------------------------------------------------------
/// Adds a slot index to the list of cleared slots.
pub fn add_cleared_slot(&mut self, index: u64) {
self.cleared_slots.push(index);
}
/// Adds a slot index and a value to the list of updated slots.
pub fn add_updated_slot(&mut self, index: u64, value: Word) {
self.updated_slots.push((index, value));
}
}
/// Extracts a [MerkleTreeDelta] object by comparing the leaves of two Merkle trees specified by
/// their roots and depth.
pub fn merkle_tree_delta<T: KvMap<RpoDigest, StoreNode>>(
tree_root_1: RpoDigest,
tree_root_2: RpoDigest,
depth: u8,
merkle_store: &MerkleStore<T>,
) -> Result<MerkleTreeDelta, MerkleError> {
if tree_root_1 == tree_root_2 {
return Ok(MerkleTreeDelta::new(depth));
}
let tree_1_leaves: BTreeMap<NodeIndex, RpoDigest> =
merkle_store.non_empty_leaves(tree_root_1, depth).collect();
let tree_2_leaves: BTreeMap<NodeIndex, RpoDigest> =
merkle_store.non_empty_leaves(tree_root_2, depth).collect();
let diff = tree_1_leaves.diff(&tree_2_leaves);
// TODO: Refactor this diff implementation to prevent allocation of both BTree and Vec.
Ok(MerkleTreeDelta {
depth,
cleared_slots: diff.removed.into_iter().map(|index| index.value()).collect(),
updated_slots: diff
.updated
.into_iter()
.map(|(index, leaf)| (index.value(), *leaf))
.collect(),
})
}
// INTERNALS
// --------------------------------------------------------------------------------------------
#[cfg(test)]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleTreeDelta {
pub depth: u8,
pub cleared_slots: Vec<u64>,
pub updated_slots: Vec<(u64, Word)>,
}
// MERKLE DELTA
// ================================================================================================
#[test]
fn test_compute_merkle_delta() {
let entries = vec![
(10, [ZERO, ONE, Felt::new(2), Felt::new(3)]),
(15, [Felt::new(4), Felt::new(5), Felt::new(6), Felt::new(7)]),
(20, [Felt::new(8), Felt::new(9), Felt::new(10), Felt::new(11)]),
(31, [Felt::new(12), Felt::new(13), Felt::new(14), Felt::new(15)]),
];
let simple_smt = SimpleSmt::with_leaves(30, entries.clone()).unwrap();
let mut store: MerkleStore = (&simple_smt).into();
let root = simple_smt.root();
// add a new node
let new_value = [Felt::new(16), Felt::new(17), Felt::new(18), Felt::new(19)];
let new_index = NodeIndex::new(simple_smt.depth(), 32).unwrap();
let root = store.set_node(root, new_index, new_value.into()).unwrap().root;
// update an existing node
let update_value = [Felt::new(20), Felt::new(21), Felt::new(22), Felt::new(23)];
let update_idx = NodeIndex::new(simple_smt.depth(), entries[0].0).unwrap();
let root = store.set_node(root, update_idx, update_value.into()).unwrap().root;
// remove a node
let remove_idx = NodeIndex::new(simple_smt.depth(), entries[1].0).unwrap();
let root = store.set_node(root, remove_idx, EMPTY_WORD.into()).unwrap().root;
let merkle_delta =
merkle_tree_delta(simple_smt.root(), root, simple_smt.depth(), &store).unwrap();
let expected_merkle_delta = MerkleTreeDelta {
depth: simple_smt.depth(),
cleared_slots: vec![remove_idx.value()],
updated_slots: vec![(update_idx.value(), update_value), (new_index.value(), new_value)],
};
assert_eq!(merkle_delta, expected_merkle_delta);
}

View File

@@ -4,8 +4,6 @@ use crate::{
};
use core::fmt;
use super::smt::SmtLeafError;
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum MerkleError {
ConflictingRoots(Vec<RpoDigest>),
@@ -22,7 +20,6 @@ pub enum MerkleError {
NodeNotInStore(RpoDigest, NodeIndex),
NumLeavesNotPowerOfTwo(usize),
RootNotInStore(RpoDigest),
SmtLeaf(SmtLeafError),
}
impl fmt::Display for MerkleError {
@@ -53,16 +50,9 @@ impl fmt::Display for MerkleError {
write!(f, "the leaves count {leaves} is not a power of 2")
}
RootNotInStore(root) => write!(f, "the root {:?} is not in the store", root),
SmtLeaf(smt_leaf_error) => write!(f, "smt leaf error: {smt_leaf_error}"),
}
}
}
#[cfg(feature = "std")]
impl std::error::Error for MerkleError {}
impl From<SmtLeafError> for MerkleError {
fn from(value: SmtLeafError) -> Self {
Self::SmtLeaf(value)
}
}

View File

@@ -1,4 +1,4 @@
use super::{Felt, MerkleError, RpoDigest};
use super::{Felt, MerkleError, RpoDigest, StarkField};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
use core::fmt::Display;

View File

@@ -1,16 +1,16 @@
use super::super::{RpoDigest, Vec};
/// Container for the update data of a [super::PartialMmr]
/// Container for the update data of a [PartialMmr]
#[derive(Debug)]
pub struct MmrDelta {
/// The new version of the [super::Mmr]
/// The new version of the [Mmr]
pub forest: usize,
/// Update data.
///
/// The data is packed as follows:
/// 1. All the elements needed to perform authentication path updates. These are the right
/// siblings required to perform tree merges on the [super::PartialMmr].
/// siblings required to perform tree merges on the [PartialMmr].
/// 2. The new peaks.
pub data: Vec<RpoDigest>,
}

View File

@@ -280,7 +280,7 @@ impl Mmr {
// Update the depth of the tree to correspond to a subtree
forest_target >>= 1;
// compute the indices of the right and left subtrees based on the post-order
// compute the indices of the right and left subtrees based on the post-order
let right_offset = index - 1;
let left_offset = right_offset - nodes_in_forest(forest_target);

View File

@@ -10,11 +10,6 @@ use crate::{
},
};
// TYPE ALIASES
// ================================================================================================
type NodeMap = BTreeMap<InOrderIndex, RpoDigest>;
// PARTIAL MERKLE MOUNTAIN RANGE
// ================================================================================================
/// Partially materialized Merkle Mountain Range (MMR), used to efficiently store and update the
@@ -51,16 +46,16 @@ pub struct PartialMmr {
/// Authentication nodes used to construct merkle paths for a subset of the MMR's leaves.
///
/// This does not include the MMR's peaks nor the tracked nodes, only the elements required to
/// construct their authentication paths. This property is used to detect when elements can be
/// safely removed, because they are no longer required to authenticate any element in the
/// [PartialMmr].
/// This does not include the MMR's peaks nor the tracked nodes, only the elements required
/// to construct their authentication paths. This property is used to detect when elements can
/// be safely removed, because they are no longer required to authenticate any element in
/// the [PartialMmr].
///
/// The elements in the MMR are referenced using an in-order tree index. This indexing scheme
/// permits for easy computation of the relative nodes (left/right children, sibling, parent),
/// which is useful for traversal. The indexing is also stable, meaning that merges to the
/// trees in the MMR can be represented without rewrites of the indexes.
pub(crate) nodes: NodeMap,
pub(crate) nodes: BTreeMap<InOrderIndex, RpoDigest>,
/// Flag indicating if the odd element should be tracked.
///
@@ -73,27 +68,16 @@ impl PartialMmr {
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
/// Returns a new [PartialMmr] instantiated from the specified peaks.
/// Constructs a [PartialMmr] from the given [MmrPeaks].
pub fn from_peaks(peaks: MmrPeaks) -> Self {
let forest = peaks.num_leaves();
let peaks = peaks.into();
let peaks = peaks.peaks().to_vec();
let nodes = BTreeMap::new();
let track_latest = false;
Self { forest, peaks, nodes, track_latest }
}
/// Returns a new [PartialMmr] instantiated from the specified components.
///
/// This constructor does not check the consistency between peaks and nodes. If the specified
/// peaks and nodes are inconsistent, the returned partial MMR may exhibit undefined behavior.
pub fn from_parts(peaks: MmrPeaks, nodes: NodeMap, track_latest: bool) -> Self {
let forest = peaks.num_leaves();
let peaks = peaks.into();
Self { forest, peaks, nodes, track_latest }
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
@@ -117,31 +101,14 @@ impl PartialMmr {
MmrPeaks::new(self.forest, self.peaks.clone()).expect("invalid MMR peaks")
}
/// Returns true if this partial MMR tracks an authentication path for the leaf at the
/// specified position.
pub fn is_tracked(&self, pos: usize) -> bool {
if pos >= self.forest {
return false;
} else if pos == self.forest - 1 && self.forest & 1 != 0 {
// if the number of leaves in the MMR is odd and the position is for the last leaf
// whether the leaf is tracked is defined by the `track_latest` flag
return self.track_latest;
}
let leaf_index = InOrderIndex::from_leaf_pos(pos);
self.is_tracked_node(&leaf_index)
}
/// Given a leaf position, returns the Merkle path to its corresponding peak, or None if this
/// partial MMR does not track an authentication paths for the specified leaf.
/// Given a leaf position, returns the Merkle path to its corresponding peak.
///
/// If the position is greater-or-equal than the tree size an error is returned. If the
/// requested value is not tracked returns `None`.
///
/// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were
/// added, this corresponds to the MMR size _prior_ to adding the element. So the 1st element
/// has position 0, the second position 1, and so on.
///
/// # Errors
/// Returns an error if the specified position is greater-or-equal than the number of leaves
/// in the underlying MMR.
pub fn open(&self, pos: usize) -> Result<Option<MmrProof>, MmrError> {
let tree_bit =
leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::InvalidPosition(pos))?;
@@ -182,13 +149,13 @@ impl PartialMmr {
///
/// The order of iteration is not defined. If a leaf is not presented in this partial MMR it
/// is silently ignored.
pub fn inner_nodes<'a, I: Iterator<Item = (usize, RpoDigest)> + 'a>(
pub fn inner_nodes<'a, I: Iterator<Item = &'a (usize, RpoDigest)> + 'a>(
&'a self,
mut leaves: I,
) -> impl Iterator<Item = InnerNodeInfo> + '_ {
let stack = if let Some((pos, leaf)) = leaves.next() {
let idx = InOrderIndex::from_leaf_pos(pos);
vec![(idx, leaf)]
let idx = InOrderIndex::from_leaf_pos(*pos);
vec![(idx, *leaf)]
} else {
Vec::new()
};
@@ -204,93 +171,20 @@ impl PartialMmr {
// STATE MUTATORS
// --------------------------------------------------------------------------------------------
/// Adds a new peak and optionally track it. Returns a vector of the authentication nodes
/// inserted into this [PartialMmr] as a result of this operation.
/// Add the authentication path represented by [MerklePath] if it is valid.
///
/// When `track` is `true` the new leaf is tracked.
pub fn add(&mut self, leaf: RpoDigest, track: bool) -> Vec<(InOrderIndex, RpoDigest)> {
self.forest += 1;
let merges = self.forest.trailing_zeros() as usize;
let mut new_nodes = Vec::with_capacity(merges);
let peak = if merges == 0 {
self.track_latest = track;
leaf
} else {
let mut track_right = track;
let mut track_left = self.track_latest;
let mut right = leaf;
let mut right_idx = forest_to_rightmost_index(self.forest);
for _ in 0..merges {
let left = self.peaks.pop().expect("Missing peak");
let left_idx = right_idx.sibling();
if track_right {
let old = self.nodes.insert(left_idx, left);
new_nodes.push((left_idx, left));
debug_assert!(
old.is_none(),
"Idx {:?} already contained an element {:?}",
left_idx,
old
);
};
if track_left {
let old = self.nodes.insert(right_idx, right);
new_nodes.push((right_idx, right));
debug_assert!(
old.is_none(),
"Idx {:?} already contained an element {:?}",
right_idx,
old
);
};
// Update state for the next iteration.
// --------------------------------------------------------------------------------
// This layer is merged, go up one layer.
right_idx = right_idx.parent();
// Merge the current layer. The result is either the right element of the next
// merge, or a new peak.
right = Rpo256::merge(&[left, right]);
// This iteration merged the left and right nodes, the new value is always used as
// the next iteration's right node. Therefore the tracking flags of this iteration
// have to be merged into the right side only.
track_right = track_right || track_left;
// On the next iteration, a peak will be merged. If any of its children are tracked,
// then we have to track the left side
track_left = self.is_tracked_node(&right_idx.sibling());
}
right
};
self.peaks.push(peak);
new_nodes
}
/// Adds the authentication path represented by [MerklePath] if it is valid.
///
/// The `leaf_pos` refers to the global position of the leaf in the MMR, these are 0-indexed
/// The `index` refers to the global position of the leaf in the MMR, these are 0-indexed
/// values assigned in a strictly monotonic fashion as elements are inserted into the MMR,
/// this value corresponds to the values used in the MMR structure.
///
/// The `leaf` corresponds to the value at `leaf_pos`, and `path` is the authentication path for
/// that element up to its corresponding Mmr peak. The `leaf` is only used to compute the root
/// The `node` corresponds to the value at `index`, and `path` is the authentication path for
/// that element up to its corresponding Mmr peak. The `node` is only used to compute the root
/// from the authentication path to validate the data; only the authentication data is saved in
/// the structure. If the value is required it should be stored out-of-band.
pub fn track(
pub fn add(
&mut self,
leaf_pos: usize,
leaf: RpoDigest,
index: usize,
node: RpoDigest,
path: &MerklePath,
) -> Result<(), MmrError> {
// Checks there is a tree with same depth as the authentication path, if not the path is
@@ -300,42 +194,42 @@ impl PartialMmr {
return Err(MmrError::UnknownPeak);
};
if leaf_pos + 1 == self.forest
if index + 1 == self.forest
&& path.depth() == 0
&& self.peaks.last().map_or(false, |v| *v == leaf)
&& self.peaks.last().map_or(false, |v| *v == node)
{
self.track_latest = true;
return Ok(());
}
// ignore the trees smaller than the target (these elements are position after the current
// target and don't affect the target leaf_pos)
// target and don't affect the target index)
let target_forest = self.forest ^ (self.forest & (tree - 1));
let peak_pos = (target_forest.count_ones() - 1) as usize;
// translate from mmr leaf_pos to merkle path
let path_idx = leaf_pos - (target_forest ^ tree);
// translate from mmr index to merkle path
let path_idx = index - (target_forest ^ tree);
// Compute the root of the authentication path, and check it matches the current version of
// the PartialMmr.
let computed = path.compute_root(path_idx as u64, leaf).map_err(MmrError::MerkleError)?;
let computed = path.compute_root(path_idx as u64, node).map_err(MmrError::MerkleError)?;
if self.peaks[peak_pos] != computed {
return Err(MmrError::InvalidPeak);
}
let mut idx = InOrderIndex::from_leaf_pos(leaf_pos);
for leaf in path.nodes() {
self.nodes.insert(idx.sibling(), *leaf);
let mut idx = InOrderIndex::from_leaf_pos(index);
for node in path.nodes() {
self.nodes.insert(idx.sibling(), *node);
idx = idx.parent();
}
Ok(())
}
/// Removes a leaf of the [PartialMmr] and the unused nodes from the authentication path.
/// Remove a leaf of the [PartialMmr] and the unused nodes from the authentication path.
///
/// Note: `leaf_pos` corresponds to the position in the MMR and not on an individual tree.
pub fn untrack(&mut self, leaf_pos: usize) {
pub fn remove(&mut self, leaf_pos: usize) {
let mut idx = InOrderIndex::from_leaf_pos(leaf_pos);
self.nodes.remove(&idx.sibling());
@@ -531,14 +425,14 @@ impl From<&PartialMmr> for MmrPeaks {
// ================================================================================================
/// An iterator over every inner node of the [PartialMmr].
pub struct InnerNodeIterator<'a, I: Iterator<Item = (usize, RpoDigest)>> {
nodes: &'a NodeMap,
pub struct InnerNodeIterator<'a, I: Iterator<Item = &'a (usize, RpoDigest)>> {
nodes: &'a BTreeMap<InOrderIndex, RpoDigest>,
leaves: I,
stack: Vec<(InOrderIndex, RpoDigest)>,
seen_nodes: BTreeSet<InOrderIndex>,
}
impl<'a, I: Iterator<Item = (usize, RpoDigest)>> Iterator for InnerNodeIterator<'a, I> {
impl<'a, I: Iterator<Item = &'a (usize, RpoDigest)>> Iterator for InnerNodeIterator<'a, I> {
type Item = InnerNodeInfo;
fn next(&mut self) -> Option<Self::Item> {
@@ -565,8 +459,8 @@ impl<'a, I: Iterator<Item = (usize, RpoDigest)>> Iterator for InnerNodeIterator<
// the previous leaf has been processed, try to process the next leaf
if let Some((pos, leaf)) = self.leaves.next() {
let idx = InOrderIndex::from_leaf_pos(pos);
self.stack.push((idx, leaf));
let idx = InOrderIndex::from_leaf_pos(*pos);
self.stack.push((idx, *leaf));
}
}
@@ -596,29 +490,12 @@ fn forest_to_root_index(forest: usize) -> InOrderIndex {
InOrderIndex::new(idx.try_into().unwrap())
}
/// Given the description of a `forest`, returns the index of the right most element.
fn forest_to_rightmost_index(forest: usize) -> InOrderIndex {
// Count total size of all trees in the forest.
let nodes = nodes_in_forest(forest);
// Add the count for the parent nodes that separate each tree. These are allocated but
// currently empty, and correspond to the nodes that will be used once the trees are merged.
let open_trees = (forest.count_ones() - 1) as usize;
let idx = nodes + open_trees;
InOrderIndex::new(idx.try_into().unwrap())
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use super::{
forest_to_rightmost_index, forest_to_root_index, BTreeSet, InOrderIndex, MmrPeaks,
PartialMmr, RpoDigest, Vec,
};
use super::{forest_to_root_index, BTreeSet, InOrderIndex, PartialMmr, RpoDigest, Vec};
use crate::merkle::{int_to_node, MerkleStore, Mmr, NodeIndex};
const LEAVES: [RpoDigest; 7] = [
@@ -657,33 +534,6 @@ mod tests {
assert_eq!(forest_to_root_index(0b1110), idx(26));
}
#[test]
fn test_forest_to_rightmost_index() {
fn idx(pos: usize) -> InOrderIndex {
InOrderIndex::new(pos.try_into().unwrap())
}
for forest in 1..256 {
assert!(forest_to_rightmost_index(forest).inner() % 2 == 1, "Leaves are always odd");
}
assert_eq!(forest_to_rightmost_index(0b0001), idx(1));
assert_eq!(forest_to_rightmost_index(0b0010), idx(3));
assert_eq!(forest_to_rightmost_index(0b0011), idx(5));
assert_eq!(forest_to_rightmost_index(0b0100), idx(7));
assert_eq!(forest_to_rightmost_index(0b0101), idx(9));
assert_eq!(forest_to_rightmost_index(0b0110), idx(11));
assert_eq!(forest_to_rightmost_index(0b0111), idx(13));
assert_eq!(forest_to_rightmost_index(0b1000), idx(15));
assert_eq!(forest_to_rightmost_index(0b1001), idx(17));
assert_eq!(forest_to_rightmost_index(0b1010), idx(19));
assert_eq!(forest_to_rightmost_index(0b1011), idx(21));
assert_eq!(forest_to_rightmost_index(0b1100), idx(23));
assert_eq!(forest_to_rightmost_index(0b1101), idx(25));
assert_eq!(forest_to_rightmost_index(0b1110), idx(27));
assert_eq!(forest_to_rightmost_index(0b1111), idx(29));
}
#[test]
fn test_partial_mmr_apply_delta() {
// build an MMR with 10 nodes (2 peaks) and a partial MMR based on it
@@ -695,13 +545,13 @@ mod tests {
{
let node = mmr.get(1).unwrap();
let proof = mmr.open(1, mmr.forest()).unwrap();
partial_mmr.track(1, node, &proof.merkle_path).unwrap();
partial_mmr.add(1, node, &proof.merkle_path).unwrap();
}
{
let node = mmr.get(8).unwrap();
let proof = mmr.open(8, mmr.forest()).unwrap();
partial_mmr.track(8, node, &proof.merkle_path).unwrap();
partial_mmr.add(8, node, &proof.merkle_path).unwrap();
}
// add 2 more nodes into the MMR and validate apply_delta()
@@ -714,7 +564,7 @@ mod tests {
{
let node = mmr.get(12).unwrap();
let proof = mmr.open(12, mmr.forest()).unwrap();
partial_mmr.track(12, node, &proof.merkle_path).unwrap();
partial_mmr.add(12, node, &proof.merkle_path).unwrap();
assert!(partial_mmr.track_latest);
}
@@ -773,14 +623,14 @@ mod tests {
// create partial MMR and add authentication path to node at position 1
let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();
partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
partial_mmr.add(1, node1, &proof1.merkle_path).unwrap();
// empty iterator should have no nodes
assert_eq!(partial_mmr.inner_nodes([].iter().cloned()).next(), None);
assert_eq!(partial_mmr.inner_nodes([].iter()).next(), None);
// build Merkle store from authentication paths in partial MMR
let mut store: MerkleStore = MerkleStore::new();
store.extend(partial_mmr.inner_nodes([(1, node1)].iter().cloned()));
store.extend(partial_mmr.inner_nodes([(1, node1)].iter()));
let index1 = NodeIndex::new(2, 1).unwrap();
let path1 = store.get_path(first_peak, index1).unwrap().path;
@@ -798,19 +648,19 @@ mod tests {
let node2 = mmr.get(2).unwrap();
let proof2 = mmr.open(2, mmr.forest()).unwrap();
partial_mmr.track(0, node0, &proof0.merkle_path).unwrap();
partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
partial_mmr.track(2, node2, &proof2.merkle_path).unwrap();
partial_mmr.add(0, node0, &proof0.merkle_path).unwrap();
partial_mmr.add(1, node1, &proof1.merkle_path).unwrap();
partial_mmr.add(2, node2, &proof2.merkle_path).unwrap();
// make sure there are no duplicates
let leaves = [(0, node0), (1, node1), (2, node2)];
let mut nodes = BTreeSet::new();
for node in partial_mmr.inner_nodes(leaves.iter().cloned()) {
for node in partial_mmr.inner_nodes(leaves.iter()) {
assert!(nodes.insert(node.value));
}
// and also that the store is still be built correctly
store.extend(partial_mmr.inner_nodes(leaves.iter().cloned()));
store.extend(partial_mmr.inner_nodes(leaves.iter()));
let index0 = NodeIndex::new(2, 0).unwrap();
let index1 = NodeIndex::new(2, 1).unwrap();
@@ -832,12 +682,12 @@ mod tests {
let node5 = mmr.get(5).unwrap();
let proof5 = mmr.open(5, mmr.forest()).unwrap();
partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
partial_mmr.track(5, node5, &proof5.merkle_path).unwrap();
partial_mmr.add(1, node1, &proof1.merkle_path).unwrap();
partial_mmr.add(5, node5, &proof5.merkle_path).unwrap();
// build Merkle store from authentication paths in partial MMR
let mut store: MerkleStore = MerkleStore::new();
store.extend(partial_mmr.inner_nodes([(1, node1), (5, node5)].iter().cloned()));
store.extend(partial_mmr.inner_nodes([(1, node1), (5, node5)].iter()));
let index1 = NodeIndex::new(2, 1).unwrap();
let index5 = NodeIndex::new(1, 1).unwrap();
@@ -850,62 +700,4 @@ mod tests {
assert_eq!(path1, proof1.merkle_path);
assert_eq!(path5, proof5.merkle_path);
}
#[test]
fn test_partial_mmr_add_without_track() {
let mut mmr = Mmr::default();
let empty_peaks = MmrPeaks::new(0, vec![]).unwrap();
let mut partial_mmr = PartialMmr::from_peaks(empty_peaks);
for el in (0..256).map(int_to_node) {
mmr.add(el);
partial_mmr.add(el, false);
let mmr_peaks = mmr.peaks(mmr.forest()).unwrap();
assert_eq!(mmr_peaks, partial_mmr.peaks());
assert_eq!(mmr.forest(), partial_mmr.forest());
}
}
#[test]
fn test_partial_mmr_add_with_track() {
let mut mmr = Mmr::default();
let empty_peaks = MmrPeaks::new(0, vec![]).unwrap();
let mut partial_mmr = PartialMmr::from_peaks(empty_peaks);
for i in 0..256 {
let el = int_to_node(i);
mmr.add(el);
partial_mmr.add(el, true);
let mmr_peaks = mmr.peaks(mmr.forest()).unwrap();
assert_eq!(mmr_peaks, partial_mmr.peaks());
assert_eq!(mmr.forest(), partial_mmr.forest());
for pos in 0..i {
let mmr_proof = mmr.open(pos as usize, mmr.forest()).unwrap();
let partialmmr_proof = partial_mmr.open(pos as usize).unwrap().unwrap();
assert_eq!(mmr_proof, partialmmr_proof);
}
}
}
#[test]
fn test_partial_mmr_add_existing_track() {
let mut mmr = Mmr::from((0..7).map(int_to_node));
// derive a partial Mmr from it which tracks authentication path to leaf 5
let mut partial_mmr = PartialMmr::from_peaks(mmr.peaks(mmr.forest()).unwrap());
let path_to_5 = mmr.open(5, mmr.forest()).unwrap().merkle_path;
let leaf_at_5 = mmr.get(5).unwrap();
partial_mmr.track(5, leaf_at_5, &path_to_5).unwrap();
// add a new leaf to both Mmr and partial Mmr
let leaf_at_7 = int_to_node(7);
mmr.add(leaf_at_7);
partial_mmr.add(leaf_at_7, false);
// the openings should be the same
assert_eq!(mmr.open(5, mmr.forest()).unwrap(), partial_mmr.open(5).unwrap().unwrap());
}
}

View File

@@ -132,9 +132,3 @@ impl MmrPeaks {
elements
}
}
impl From<MmrPeaks> for Vec<RpoDigest> {
fn from(peaks: MmrPeaks) -> Self {
peaks.peaks
}
}

View File

@@ -2,9 +2,6 @@
use super::super::MerklePath;
use super::{full::high_bitmask, leaf_to_corresponding_tree};
// MMR PROOF
// ================================================================================================
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MmrProof {
@@ -29,78 +26,9 @@ impl MmrProof {
self.position - forest_before
}
/// Returns index of the MMR peak against which the Merkle path in this proof can be verified.
pub fn peak_index(&self) -> usize {
let root = leaf_to_corresponding_tree(self.position, self.forest)
.expect("position must be part of the forest");
let smaller_peak_mask = 2_usize.pow(root) as usize - 1;
let num_smaller_peaks = (self.forest & smaller_peak_mask).count_ones();
(self.forest.count_ones() - num_smaller_peaks - 1) as usize
}
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use super::{MerklePath, MmrProof};
#[test]
fn test_peak_index() {
// --- single peak forest ---------------------------------------------
let forest = 11;
// the first 4 leaves belong to peak 0
for position in 0..8 {
let proof = make_dummy_proof(forest, position);
assert_eq!(proof.peak_index(), 0);
}
// --- forest with non-consecutive peaks ------------------------------
let forest = 11;
// the first 8 leaves belong to peak 0
for position in 0..8 {
let proof = make_dummy_proof(forest, position);
assert_eq!(proof.peak_index(), 0);
}
// the next 2 leaves belong to peak 1
for position in 8..10 {
let proof = make_dummy_proof(forest, position);
assert_eq!(proof.peak_index(), 1);
}
// the last leaf is the peak 2
let proof = make_dummy_proof(forest, 10);
assert_eq!(proof.peak_index(), 2);
// --- forest with consecutive peaks ----------------------------------
let forest = 7;
// the first 4 leaves belong to peak 0
for position in 0..4 {
let proof = make_dummy_proof(forest, position);
assert_eq!(proof.peak_index(), 0);
}
// the next 2 leaves belong to peak 1
for position in 4..6 {
let proof = make_dummy_proof(forest, position);
assert_eq!(proof.peak_index(), 1);
}
// the last leaf is the peak 2
let proof = make_dummy_proof(forest, 6);
assert_eq!(proof.peak_index(), 2);
}
fn make_dummy_proof(forest: usize, position: usize) -> MmrProof {
MmrProof {
forest,
position,
merkle_path: MerklePath::default(),
}
(self.forest.count_ones() - root - 1) as usize
}
}

View File

@@ -114,14 +114,13 @@ const LEAVES: [RpoDigest; 7] = [
#[test]
fn test_mmr_simple() {
let mut postorder = vec![
LEAVES[0],
LEAVES[1],
merge(LEAVES[0], LEAVES[1]),
LEAVES[2],
LEAVES[3],
merge(LEAVES[2], LEAVES[3]),
];
let mut postorder = Vec::new();
postorder.push(LEAVES[0]);
postorder.push(LEAVES[1]);
postorder.push(merge(LEAVES[0], LEAVES[1]));
postorder.push(LEAVES[2]);
postorder.push(LEAVES[3]);
postorder.push(merge(LEAVES[2], LEAVES[3]));
postorder.push(merge(postorder[2], postorder[5]));
postorder.push(LEAVES[4]);
postorder.push(LEAVES[5]);
@@ -769,7 +768,7 @@ fn test_partial_mmr_simple() {
// check state after adding tracking one element
let proof1 = mmr.open(0, mmr.forest()).unwrap();
let el1 = mmr.get(proof1.position).unwrap();
partial.track(proof1.position, el1, &proof1.merkle_path).unwrap();
partial.add(proof1.position, el1, &proof1.merkle_path).unwrap();
// check the number of nodes increased by the number of nodes in the proof
assert_eq!(partial.nodes.len(), proof1.merkle_path.len());
@@ -781,7 +780,7 @@ fn test_partial_mmr_simple() {
let proof2 = mmr.open(1, mmr.forest()).unwrap();
let el2 = mmr.get(proof2.position).unwrap();
partial.track(proof2.position, el2, &proof2.merkle_path).unwrap();
partial.add(proof2.position, el2, &proof2.merkle_path).unwrap();
// check the number of nodes increased by a single element (the one that is not shared)
assert_eq!(partial.nodes.len(), 3);
@@ -800,7 +799,7 @@ fn test_partial_mmr_update_single() {
let mut partial: PartialMmr = full.peaks(full.forest()).unwrap().into();
let proof = full.open(0, full.forest()).unwrap();
partial.track(proof.position, zero, &proof.merkle_path).unwrap();
partial.add(proof.position, zero, &proof.merkle_path).unwrap();
for i in 1..100 {
let node = int_to_node(i);
@@ -812,7 +811,7 @@ fn test_partial_mmr_update_single() {
assert_eq!(partial.peaks(), full.peaks(full.forest()).unwrap());
let proof1 = full.open(i as usize, full.forest()).unwrap();
partial.track(proof1.position, node, &proof1.merkle_path).unwrap();
partial.add(proof1.position, node, &proof1.merkle_path).unwrap();
let proof2 = partial.open(proof1.position).unwrap().unwrap();
assert_eq!(proof1.merkle_path, proof2.merkle_path);
}
@@ -828,11 +827,11 @@ fn test_mmr_add_invalid_odd_leaf() {
// None of the other leaves should work
for node in LEAVES.iter().cloned().rev().skip(1) {
let result = partial.track(LEAVES.len() - 1, node, &empty);
let result = partial.add(LEAVES.len() - 1, node, &empty);
assert!(result.is_err());
}
let result = partial.track(LEAVES.len() - 1, LEAVES[6], &empty);
let result = partial.add(LEAVES.len() - 1, LEAVES[6], &empty);
assert!(result.is_ok());
}

View File

@@ -2,8 +2,8 @@
use super::{
hash::rpo::{Rpo256, RpoDigest},
utils::collections::{vec, BTreeMap, BTreeSet, KvMap, RecordingMap, Vec},
Felt, Word, EMPTY_WORD, ZERO,
utils::collections::{vec, BTreeMap, BTreeSet, KvMap, RecordingMap, TryApplyDiff, Vec},
Felt, StarkField, Word, EMPTY_WORD, ZERO,
};
// REEXPORTS
@@ -12,6 +12,9 @@ use super::{
mod empty_roots;
pub use empty_roots::EmptySubtreeRoots;
mod delta;
pub use delta::{merkle_tree_delta, MerkleStoreDelta, MerkleTreeDelta};
mod index;
pub use index::NodeIndex;
@@ -21,11 +24,11 @@ pub use merkle_tree::{path_to_text, tree_to_text, MerkleTree};
mod path;
pub use path::{MerklePath, RootPath, ValuePath};
mod smt;
pub use smt::{
LeafIndex, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH,
SMT_MAX_DEPTH, SMT_MIN_DEPTH,
};
mod simple_smt;
pub use simple_smt::SimpleSmt;
mod tiered_smt;
pub use tiered_smt::{TieredSmt, TieredSmtProof, TieredSmtProofError};
mod mmr;
pub use mmr::{InOrderIndex, Mmr, MmrDelta, MmrError, MmrPeaks, MmrProof, PartialMmr};

View File

@@ -179,7 +179,7 @@ impl PartialMerkleTree {
/// # Errors
/// Returns an error if the specified NodeIndex is not contained in the nodes map.
pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
self.nodes.get(&index).ok_or(MerkleError::NodeNotInSet(index)).copied()
self.nodes.get(&index).ok_or(MerkleError::NodeNotInSet(index)).map(|hash| *hash)
}
/// Returns true if provided index contains in the leaves set, false otherwise.

View File

@@ -209,7 +209,7 @@ fn get_paths() {
// Which have leaf nodes 20, 22, 23, 32 and 33. Hence overall we will have 5 paths -- one path
// for each leaf.
let leaves = [NODE20, NODE22, NODE23, NODE32, NODE33];
let leaves = vec![NODE20, NODE22, NODE23, NODE32, NODE33];
let expected_paths: Vec<(NodeIndex, ValuePath)> = leaves
.iter()
.map(|&leaf| {
@@ -257,7 +257,7 @@ fn leaves() {
let value32 = mt.get_node(NODE32).unwrap();
let value33 = mt.get_node(NODE33).unwrap();
let leaves = [(NODE11, value11), (NODE20, value20), (NODE32, value32), (NODE33, value33)];
let leaves = vec![(NODE11, value11), (NODE20, value20), (NODE32, value32), (NODE33, value33)];
let expected_leaves = leaves.iter().copied();
assert!(expected_leaves.eq(pmt.leaves()));

View File

@@ -1,5 +1,3 @@
use crate::Word;
use super::{vec, InnerNodeInfo, MerkleError, NodeIndex, Rpo256, RpoDigest, Vec};
use core::ops::{Deref, DerefMut};
use winter_utils::{ByteReader, Deserializable, DeserializationError, Serializable};
@@ -165,7 +163,7 @@ impl<'a> Iterator for InnerNodeIterator<'a> {
// MERKLE PATH CONTAINERS
// ================================================================================================
/// A container for a [crate::Word] value and its [MerklePath] opening.
/// A container for a [Word] value and its [MerklePath] opening.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ValuePath {
/// The node value opening for `path`.
@@ -176,18 +174,12 @@ pub struct ValuePath {
impl ValuePath {
/// Returns a new [ValuePath] instantiated from the specified value and path.
pub fn new(value: RpoDigest, path: MerklePath) -> Self {
Self { value, path }
pub fn new(value: RpoDigest, path: Vec<RpoDigest>) -> Self {
Self { value, path: MerklePath::new(path) }
}
}
impl From<(MerklePath, Word)> for ValuePath {
fn from((path, value): (MerklePath, Word)) -> Self {
ValuePath::new(value.into(), path)
}
}
/// A container for a [MerklePath] and its [crate::Word] root.
/// A container for a [MerklePath] and its [Word] root.
///
/// This structure does not provide any guarantees regarding the correctness of the path to the
/// root. For more information, check [MerklePath::verify].
@@ -206,14 +198,14 @@ impl Serializable for MerklePath {
fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
assert!(self.nodes.len() <= u8::MAX.into(), "Length enforced in the constructor");
target.write_u8(self.nodes.len() as u8);
target.write_many(&self.nodes);
self.nodes.write_into(target);
}
}
impl Deserializable for MerklePath {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let count = source.read_u8()?.into();
let nodes = source.read_many::<RpoDigest>(count)?;
let nodes = RpoDigest::read_batch_from(source, count)?;
Ok(Self { nodes })
}
}

View File

@@ -0,0 +1,389 @@
use super::{
BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, MerkleTreeDelta,
NodeIndex, Rpo256, RpoDigest, StoreNode, TryApplyDiff, Vec, Word,
};
#[cfg(test)]
mod tests;
// SPARSE MERKLE TREE
// ================================================================================================
/// A sparse Merkle tree with 64-bit keys and 4-element leaf values, without compaction.
///
/// The root of the tree is recomputed on each new leaf update.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct SimpleSmt {
    /// Depth of the tree; leaves live at this depth.
    depth: u8,
    /// Current root digest; kept in sync with `leaves`/`branches` on every mutation.
    root: RpoDigest,
    /// Explicitly-set leaves keyed by their index at `depth`; empty leaves are not stored
    /// (lookups fall back to the precomputed empty-subtree roots).
    leaves: BTreeMap<u64, Word>,
    /// Inner (branch) nodes keyed by index; missing entries denote empty subtrees.
    branches: BTreeMap<NodeIndex, BranchNode>,
}
impl SimpleSmt {
    // CONSTANTS
    // --------------------------------------------------------------------------------------------

    /// Minimum supported depth.
    pub const MIN_DEPTH: u8 = 1;

    /// Maximum supported depth.
    pub const MAX_DEPTH: u8 = 64;

    /// Value of an empty leaf.
    pub const EMPTY_VALUE: Word = super::EMPTY_WORD;

    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Returns a new [SimpleSmt] instantiated with the specified depth.
    ///
    /// All leaves in the returned tree are set to [ZERO; 4].
    ///
    /// # Errors
    /// Returns an error if the depth is 0 or is greater than 64.
    pub fn new(depth: u8) -> Result<Self, MerkleError> {
        // validate the range of the depth.
        if depth < Self::MIN_DEPTH {
            return Err(MerkleError::DepthTooSmall(depth));
        } else if Self::MAX_DEPTH < depth {
            return Err(MerkleError::DepthTooBig(depth as u64));
        }
        // an empty tree of a given depth hashes to a precomputed constant
        let root = *EmptySubtreeRoots::entry(depth, 0);
        Ok(Self {
            root,
            depth,
            leaves: BTreeMap::new(),
            branches: BTreeMap::new(),
        })
    }

    /// Returns a new [SimpleSmt] instantiated with the specified depth and with leaves
    /// set as specified by the provided entries.
    ///
    /// All leaves omitted from the entries list are set to [ZERO; 4].
    ///
    /// # Errors
    /// Returns an error if:
    /// - If the depth is 0 or is greater than 64.
    /// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
    /// - The provided entries contain multiple values for the same key.
    pub fn with_leaves(
        depth: u8,
        entries: impl IntoIterator<Item = (u64, Word)>,
    ) -> Result<Self, MerkleError> {
        // create an empty tree
        let mut tree = Self::new(depth)?;
        // compute the max number of entries. We use an upper bound of depth 63 because we consider
        // passing in a vector of size 2^64 infeasible.
        let max_num_entries = 2_usize.pow(tree.depth.min(63).into());
        // This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so
        // entries with the empty value need additional tracking.
        let mut key_set_to_zero = BTreeSet::new();
        for (idx, (key, value)) in entries.into_iter().enumerate() {
            if idx >= max_num_entries {
                return Err(MerkleError::InvalidNumEntries(max_num_entries));
            }
            // a non-empty previous value (or a key already recorded as explicitly set to the
            // empty value) means the same key was provided twice
            let old_value = tree.update_leaf(key, value)?;
            if old_value != Self::EMPTY_VALUE || key_set_to_zero.contains(&key) {
                return Err(MerkleError::DuplicateValuesForIndex(key));
            }
            if value == Self::EMPTY_VALUE {
                key_set_to_zero.insert(key);
            };
        }
        Ok(tree)
    }

    /// Wrapper around [`SimpleSmt::with_leaves`] which inserts leaves at contiguous indices
    /// starting at index 0.
    pub fn with_contiguous_leaves(
        depth: u8,
        entries: impl IntoIterator<Item = Word>,
    ) -> Result<Self, MerkleError> {
        Self::with_leaves(
            depth,
            entries
                .into_iter()
                .enumerate()
                // the enumerated position becomes the leaf key
                .map(|(idx, word)| (idx.try_into().expect("tree max depth is 2^8"), word)),
        )
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the root of this Merkle tree.
    pub const fn root(&self) -> RpoDigest {
        self.root
    }

    /// Returns the depth of this Merkle tree.
    pub const fn depth(&self) -> u8 {
        self.depth
    }

    /// Returns a node at the specified index.
    ///
    /// # Errors
    /// Returns an error if the specified index has depth set to 0 or the depth is greater than
    /// the depth of this Merkle tree.
    pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
        if index.is_root() {
            Err(MerkleError::DepthTooSmall(index.depth()))
        } else if index.depth() > self.depth() {
            Err(MerkleError::DepthTooBig(index.depth() as u64))
        } else if index.depth() == self.depth() {
            // the lookup in empty_hashes could fail only if empty_hashes were not built correctly
            // by the constructor as we check the depth of the lookup above.
            let leaf_pos = index.value();
            let leaf = match self.get_leaf_node(leaf_pos) {
                Some(word) => word.into(),
                // leaf was never set explicitly, so it is the empty-subtree value for this depth
                None => *EmptySubtreeRoots::entry(self.depth, index.depth()),
            };
            Ok(leaf)
        } else {
            Ok(self.get_branch_node(&index).parent())
        }
    }

    /// Returns a value of the leaf at the specified index.
    ///
    /// # Errors
    /// Returns an error if the index is greater than the maximum tree capacity, that is 2^{depth}.
    pub fn get_leaf(&self, index: u64) -> Result<Word, MerkleError> {
        let index = NodeIndex::new(self.depth, index)?;
        Ok(self.get_node(index)?.into())
    }

    /// Returns a Merkle path from the node at the specified index to the root.
    ///
    /// The node itself is not included in the path.
    ///
    /// # Errors
    /// Returns an error if the specified index has depth set to 0 or the depth is greater than
    /// the depth of this Merkle tree.
    pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
        if index.is_root() {
            return Err(MerkleError::DepthTooSmall(index.depth()));
        } else if index.depth() > self.depth() {
            return Err(MerkleError::DepthTooBig(index.depth() as u64));
        }
        let mut path = Vec::with_capacity(index.depth() as usize);
        for _ in 0..index.depth() {
            let is_right = index.is_value_odd();
            index.move_up();
            let BranchNode { left, right } = self.get_branch_node(&index);
            // the path stores the sibling at each level, so pick the opposite child
            let value = if is_right { left } else { right };
            path.push(value);
        }
        Ok(MerklePath::new(path))
    }

    /// Return a Merkle path from the leaf at the specified index to the root.
    ///
    /// The leaf itself is not included in the path.
    ///
    /// # Errors
    /// Returns an error if the index is greater than the maximum tree capacity, that is 2^{depth}.
    pub fn get_leaf_path(&self, index: u64) -> Result<MerklePath, MerkleError> {
        let index = NodeIndex::new(self.depth(), index)?;
        self.get_path(index)
    }

    // ITERATORS
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over the leaves of this [SimpleSmt].
    pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
        self.leaves.iter().map(|(i, w)| (*i, w))
    }

    /// Returns an iterator over the inner nodes of this Merkle tree.
    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
        self.branches.values().map(|e| InnerNodeInfo {
            value: e.parent(),
            left: e.left,
            right: e.right,
        })
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Updates value of the leaf at the specified index returning the old leaf value.
    ///
    /// This also recomputes all hashes between the leaf and the root, updating the root itself.
    ///
    /// # Errors
    /// Returns an error if the index is greater than the maximum tree capacity, that is 2^{depth}.
    pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<Word, MerkleError> {
        // validate the index before modifying the structure
        let idx = NodeIndex::new(self.depth(), index)?;
        let old_value = self.insert_leaf_node(index, value).unwrap_or(Self::EMPTY_VALUE);
        // if the old value and new value are the same, there is nothing to update
        if value == old_value {
            return Ok(value);
        }
        self.recompute_nodes_from_index_to_root(idx, RpoDigest::from(value));
        Ok(old_value)
    }

    /// Inserts a subtree at the specified index. The depth at which the subtree is inserted is
    /// computed as `self.depth() - subtree.depth()`.
    ///
    /// Returns the new root.
    pub fn set_subtree(
        &mut self,
        subtree_insertion_index: u64,
        subtree: SimpleSmt,
    ) -> Result<RpoDigest, MerkleError> {
        if subtree.depth() > self.depth() {
            return Err(MerkleError::InvalidSubtreeDepth {
                subtree_depth: subtree.depth(),
                tree_depth: self.depth(),
            });
        }

        // Verify that `subtree_insertion_index` is valid.
        let subtree_root_insertion_depth = self.depth() - subtree.depth();
        let subtree_root_index =
            NodeIndex::new(subtree_root_insertion_depth, subtree_insertion_index)?;

        // add leaves
        // --------------

        // The subtree's leaf indices live in their own context - i.e. a subtree of depth `d`. If we
        // insert the subtree at `subtree_insertion_index = 0`, then the subtree leaf indices are
        // valid as they are. However, consider what happens when we insert at
        // `subtree_insertion_index = 1`. The first leaf of our subtree now will have index `2^d`;
        // you can see it as there's a full subtree sitting on its left. In general, for
        // `subtree_insertion_index = i`, there are `i` subtrees sitting before the subtree we want
        // to insert, so we need to adjust all its leaves by `i * 2^d`.
        let leaf_index_shift: u64 = subtree_insertion_index * 2_u64.pow(subtree.depth().into());
        for (subtree_leaf_idx, leaf_value) in subtree.leaves() {
            let new_leaf_idx = leaf_index_shift + subtree_leaf_idx;
            debug_assert!(new_leaf_idx < 2_u64.pow(self.depth().into()));
            self.insert_leaf_node(new_leaf_idx, *leaf_value);
        }

        // add subtree's branch nodes (which includes the root)
        // --------------
        for (branch_idx, branch_node) in subtree.branches {
            let new_branch_idx = {
                let new_depth = subtree_root_insertion_depth + branch_idx.depth();
                // same shift as applied to the leaves, taken at the branch node's depth
                let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into())
                    + branch_idx.value();
                NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid")
            };
            self.branches.insert(new_branch_idx, branch_node);
        }

        // recompute nodes starting from subtree root
        // --------------
        self.recompute_nodes_from_index_to_root(subtree_root_index, subtree.root);

        Ok(self.root)
    }

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Recomputes the branch nodes (including the root) from `index` all the way to the root.
    /// `node_hash_at_index` is the hash of the node stored at index.
    fn recompute_nodes_from_index_to_root(
        &mut self,
        mut index: NodeIndex,
        node_hash_at_index: RpoDigest,
    ) {
        let mut value = node_hash_at_index;
        for _ in 0..index.depth() {
            let is_right = index.is_value_odd();
            index.move_up();
            let BranchNode { left, right } = self.get_branch_node(&index);
            // replace the child on the side we came up from with the freshly computed hash
            let (left, right) = if is_right { (left, value) } else { (value, right) };
            self.insert_branch_node(index, left, right);
            value = Rpo256::merge(&[left, right]);
        }
        self.root = value;
    }

    /// Returns the value of the leaf at `key`, or `None` if the leaf was never explicitly set.
    fn get_leaf_node(&self, key: u64) -> Option<Word> {
        self.leaves.get(&key).copied()
    }

    /// Stores `node` as the leaf value at `key`, returning the previously stored value, if any.
    fn insert_leaf_node(&mut self, key: u64, node: Word) -> Option<Word> {
        self.leaves.insert(key, node)
    }

    /// Returns the branch node at `index`; if the node is not stored, returns a branch whose
    /// children are both the empty-subtree digest for the level below `index`.
    fn get_branch_node(&self, index: &NodeIndex) -> BranchNode {
        self.branches.get(index).cloned().unwrap_or_else(|| {
            let node = EmptySubtreeRoots::entry(self.depth, index.depth() + 1);
            BranchNode { left: *node, right: *node }
        })
    }

    /// Stores a branch node with the specified children at `index`.
    fn insert_branch_node(&mut self, index: NodeIndex, left: RpoDigest, right: RpoDigest) {
        let branch = BranchNode { left, right };
        self.branches.insert(index, branch);
    }
}
// BRANCH NODE
// ================================================================================================
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
struct BranchNode {
    /// Digest of the left child.
    left: RpoDigest,
    /// Digest of the right child.
    right: RpoDigest,
}
impl BranchNode {
    /// Returns the digest of this branch node's parent, obtained by merging the digests of the
    /// left and right children.
    fn parent(&self) -> RpoDigest {
        let children = [self.left, self.right];
        Rpo256::merge(&children)
    }
}
// TRY APPLY DIFF
// ================================================================================================
impl TryApplyDiff<RpoDigest, StoreNode> for SimpleSmt {
    type Error = MerkleError;
    type DiffType = MerkleTreeDelta;

    /// Applies a [MerkleTreeDelta] to this tree: first resets every cleared slot to the empty
    /// value, then writes each updated slot, recomputing the root after every change.
    ///
    /// # Errors
    /// Returns an error if the delta was computed for a tree of a different depth, or if any
    /// individual leaf update fails.
    fn try_apply(&mut self, diff: MerkleTreeDelta) -> Result<(), MerkleError> {
        // a delta computed for a different depth cannot be applied to this tree
        let expected = self.depth();
        let provided = diff.depth();
        if provided != expected {
            return Err(MerkleError::InvalidDepth { expected, provided });
        }

        // clear slots before applying the updates
        for slot in diff.cleared_slots() {
            self.update_leaf(*slot, Self::EMPTY_VALUE)?;
        }
        for (slot, value) in diff.updated_slots() {
            self.update_leaf(*slot, *value)?;
        }

        Ok(())
    }
}

View File

@@ -1,15 +1,10 @@
use super::{
super::{MerkleError, RpoDigest, SimpleSmt},
NodeIndex,
super::{InnerNodeInfo, MerkleError, MerkleTree, RpoDigest, SimpleSmt, EMPTY_WORD},
NodeIndex, Rpo256, Vec,
};
use crate::{
hash::rpo::Rpo256,
merkle::{
digests_to_words, int_to_leaf, int_to_node, smt::SparseMerkleTree, EmptySubtreeRoots,
InnerNodeInfo, LeafIndex, MerkleTree,
},
utils::collections::Vec,
Word, EMPTY_WORD,
merkle::{digests_to_words, int_to_leaf, int_to_node, EmptySubtreeRoots},
Word,
};
// TEST DATA
@@ -39,27 +34,26 @@ const ZERO_VALUES8: [Word; 8] = [int_to_leaf(0); 8];
#[test]
fn build_empty_tree() {
// tree of depth 3
let smt = SimpleSmt::<3>::new().unwrap();
let mt = MerkleTree::new(ZERO_VALUES8).unwrap();
let smt = SimpleSmt::new(3).unwrap();
let mt = MerkleTree::new(ZERO_VALUES8.to_vec()).unwrap();
assert_eq!(mt.root(), smt.root());
}
#[test]
fn build_sparse_tree() {
const DEPTH: u8 = 3;
let mut smt = SimpleSmt::<DEPTH>::new().unwrap();
let mut smt = SimpleSmt::new(3).unwrap();
let mut values = ZERO_VALUES8.to_vec();
// insert single value
let key = 6;
let new_node = int_to_leaf(7);
values[key as usize] = new_node;
let old_value = smt.insert(LeafIndex::<DEPTH>::new(key).unwrap(), new_node);
let old_value = smt.update_leaf(key, new_node).expect("Failed to update leaf");
let mt2 = MerkleTree::new(values.clone()).unwrap();
assert_eq!(mt2.root(), smt.root());
assert_eq!(
mt2.get_path(NodeIndex::make(3, 6)).unwrap(),
smt.open(&LeafIndex::<3>::new(6).unwrap()).path
smt.get_path(NodeIndex::make(3, 6)).unwrap()
);
assert_eq!(old_value, EMPTY_WORD);
@@ -67,12 +61,12 @@ fn build_sparse_tree() {
let key = 2;
let new_node = int_to_leaf(3);
values[key as usize] = new_node;
let old_value = smt.insert(LeafIndex::<DEPTH>::new(key).unwrap(), new_node);
let old_value = smt.update_leaf(key, new_node).expect("Failed to update leaf");
let mt3 = MerkleTree::new(values).unwrap();
assert_eq!(mt3.root(), smt.root());
assert_eq!(
mt3.get_path(NodeIndex::make(3, 2)).unwrap(),
smt.open(&LeafIndex::<3>::new(2).unwrap()).path
smt.get_path(NodeIndex::make(3, 2)).unwrap()
);
assert_eq!(old_value, EMPTY_WORD);
}
@@ -80,12 +74,14 @@ fn build_sparse_tree() {
/// Tests that [`SimpleSmt::with_contiguous_leaves`] works as expected
#[test]
fn build_contiguous_tree() {
let tree_with_leaves =
SimpleSmt::<2>::with_leaves([0, 1, 2, 3].into_iter().zip(digests_to_words(&VALUES4)))
.unwrap();
let tree_with_leaves = SimpleSmt::with_leaves(
2,
[0, 1, 2, 3].into_iter().zip(digests_to_words(&VALUES4).into_iter()),
)
.unwrap();
let tree_with_contiguous_leaves =
SimpleSmt::<2>::with_contiguous_leaves(digests_to_words(&VALUES4)).unwrap();
SimpleSmt::with_contiguous_leaves(2, digests_to_words(&VALUES4).into_iter()).unwrap();
assert_eq!(tree_with_leaves, tree_with_contiguous_leaves);
}
@@ -93,7 +89,8 @@ fn build_contiguous_tree() {
#[test]
fn test_depth2_tree() {
let tree =
SimpleSmt::<2>::with_leaves(KEYS4.into_iter().zip(digests_to_words(&VALUES4))).unwrap();
SimpleSmt::with_leaves(2, KEYS4.into_iter().zip(digests_to_words(&VALUES4).into_iter()))
.unwrap();
// check internal structure
let (root, node2, node3) = compute_internal_nodes();
@@ -108,16 +105,21 @@ fn test_depth2_tree() {
assert_eq!(VALUES4[3], tree.get_node(NodeIndex::make(2, 3)).unwrap());
// check get_path(): depth 2
assert_eq!(vec![VALUES4[1], node3], *tree.open(&LeafIndex::<2>::new(0).unwrap()).path);
assert_eq!(vec![VALUES4[0], node3], *tree.open(&LeafIndex::<2>::new(1).unwrap()).path);
assert_eq!(vec![VALUES4[3], node2], *tree.open(&LeafIndex::<2>::new(2).unwrap()).path);
assert_eq!(vec![VALUES4[2], node2], *tree.open(&LeafIndex::<2>::new(3).unwrap()).path);
assert_eq!(vec![VALUES4[1], node3], *tree.get_path(NodeIndex::make(2, 0)).unwrap());
assert_eq!(vec![VALUES4[0], node3], *tree.get_path(NodeIndex::make(2, 1)).unwrap());
assert_eq!(vec![VALUES4[3], node2], *tree.get_path(NodeIndex::make(2, 2)).unwrap());
assert_eq!(vec![VALUES4[2], node2], *tree.get_path(NodeIndex::make(2, 3)).unwrap());
// check get_path(): depth 1
assert_eq!(vec![node3], *tree.get_path(NodeIndex::make(1, 0)).unwrap());
assert_eq!(vec![node2], *tree.get_path(NodeIndex::make(1, 1)).unwrap());
}
#[test]
fn test_inner_node_iterator() -> Result<(), MerkleError> {
let tree =
SimpleSmt::<2>::with_leaves(KEYS4.into_iter().zip(digests_to_words(&VALUES4))).unwrap();
SimpleSmt::with_leaves(2, KEYS4.into_iter().zip(digests_to_words(&VALUES4).into_iter()))
.unwrap();
// check depth 2
assert_eq!(VALUES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap());
@@ -147,9 +149,9 @@ fn test_inner_node_iterator() -> Result<(), MerkleError> {
#[test]
fn update_leaf() {
const DEPTH: u8 = 3;
let mut tree =
SimpleSmt::<DEPTH>::with_leaves(KEYS8.into_iter().zip(digests_to_words(&VALUES8))).unwrap();
SimpleSmt::with_leaves(3, KEYS8.into_iter().zip(digests_to_words(&VALUES8).into_iter()))
.unwrap();
// update one value
let key = 3;
@@ -158,7 +160,7 @@ fn update_leaf() {
expected_values[key] = new_node;
let expected_tree = MerkleTree::new(expected_values.clone()).unwrap();
let old_leaf = tree.insert(LeafIndex::<DEPTH>::new(key as u64).unwrap(), new_node);
let old_leaf = tree.update_leaf(key as u64, new_node).unwrap();
assert_eq!(expected_tree.root(), tree.root);
assert_eq!(old_leaf, *VALUES8[key]);
@@ -168,7 +170,7 @@ fn update_leaf() {
expected_values[key] = new_node;
let expected_tree = MerkleTree::new(expected_values.clone()).unwrap();
let old_leaf = tree.insert(LeafIndex::<DEPTH>::new(key as u64).unwrap(), new_node);
let old_leaf = tree.update_leaf(key as u64, new_node).unwrap();
assert_eq!(expected_tree.root(), tree.root);
assert_eq!(old_leaf, *VALUES8[key]);
}
@@ -200,22 +202,29 @@ fn small_tree_opening_is_consistent() {
let k = Rpo256::merge(&[i, j]);
let depth = 3;
let entries = vec![(0, a), (1, b), (4, c), (7, d)];
let tree = SimpleSmt::<3>::with_leaves(entries).unwrap();
let tree = SimpleSmt::with_leaves(depth, entries).unwrap();
assert_eq!(tree.root(), k);
let cases: Vec<(u64, Vec<RpoDigest>)> = vec![
(0, vec![b.into(), f, j]),
(1, vec![a.into(), f, j]),
(4, vec![z.into(), h, i]),
(7, vec![z.into(), g, i]),
let cases: Vec<(u8, u64, Vec<RpoDigest>)> = vec![
(3, 0, vec![b.into(), f, j]),
(3, 1, vec![a.into(), f, j]),
(3, 4, vec![z.into(), h, i]),
(3, 7, vec![z.into(), g, i]),
(2, 0, vec![f, j]),
(2, 1, vec![e, j]),
(2, 2, vec![h, i]),
(2, 3, vec![g, i]),
(1, 0, vec![j]),
(1, 1, vec![i]),
];
for (key, path) in cases {
let opening = tree.open(&LeafIndex::<3>::new(key).unwrap());
for (depth, key, path) in cases {
let opening = tree.get_path(NodeIndex::make(depth, key)).unwrap();
assert_eq!(path, *opening.path);
assert_eq!(path, *opening);
}
}
@@ -237,12 +246,12 @@ fn test_simplesmt_fail_on_duplicates() {
for (first, second) in values.iter() {
// consecutive
let entries = [(1, *first), (1, *second)];
let smt = SimpleSmt::<64>::with_leaves(entries);
let smt = SimpleSmt::with_leaves(64, entries);
assert_eq!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));
// not consecutive
let entries = [(1, *first), (5, int_to_leaf(5)), (1, *second)];
let smt = SimpleSmt::<64>::with_leaves(entries);
let smt = SimpleSmt::with_leaves(64, entries);
assert_eq!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));
}
}
@@ -250,10 +259,56 @@ fn test_simplesmt_fail_on_duplicates() {
#[test]
fn with_no_duplicates_empty_node() {
let entries = [(1_u64, int_to_leaf(0)), (5, int_to_leaf(2))];
let smt = SimpleSmt::<64>::with_leaves(entries);
let smt = SimpleSmt::with_leaves(64, entries);
assert!(smt.is_ok());
}
#[test]
fn test_simplesmt_update_nonexisting_leaf_with_zero() {
// TESTING WITH EMPTY WORD
// --------------------------------------------------------------------------------------------
// Depth 1 has 2 leaf. Position is 0-indexed, position 2 doesn't exist.
let mut smt = SimpleSmt::new(1).unwrap();
let result = smt.update_leaf(2, EMPTY_WORD);
assert!(!smt.leaves.contains_key(&2));
assert!(result.is_err());
// Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
let mut smt = SimpleSmt::new(2).unwrap();
let result = smt.update_leaf(4, EMPTY_WORD);
assert!(!smt.leaves.contains_key(&4));
assert!(result.is_err());
// Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
let mut smt = SimpleSmt::new(3).unwrap();
let result = smt.update_leaf(8, EMPTY_WORD);
assert!(!smt.leaves.contains_key(&8));
assert!(result.is_err());
// TESTING WITH A VALUE
// --------------------------------------------------------------------------------------------
let value = int_to_node(1);
// Depth 1 has 2 leaves. Position is 0-indexed, position 1 doesn't exist.
let mut smt = SimpleSmt::new(1).unwrap();
let result = smt.update_leaf(2, *value);
assert!(!smt.leaves.contains_key(&2));
assert!(result.is_err());
// Depth 2 has 4 leaves. Position is 0-indexed, position 2 doesn't exist.
let mut smt = SimpleSmt::new(2).unwrap();
let result = smt.update_leaf(4, *value);
assert!(!smt.leaves.contains_key(&4));
assert!(result.is_err());
// Depth 3 has 8 leaves. Position is 0-indexed, position 4 doesn't exist.
let mut smt = SimpleSmt::new(3).unwrap();
let result = smt.update_leaf(8, *value);
assert!(!smt.leaves.contains_key(&8));
assert!(result.is_err());
}
#[test]
fn test_simplesmt_with_leaves_nonexisting_leaf() {
// TESTING WITH EMPTY WORD
@@ -261,17 +316,17 @@ fn test_simplesmt_with_leaves_nonexisting_leaf() {
// Depth 1 has 2 leaf. Position is 0-indexed, position 2 doesn't exist.
let leaves = [(2, EMPTY_WORD)];
let result = SimpleSmt::<1>::with_leaves(leaves);
let result = SimpleSmt::with_leaves(1, leaves);
assert!(result.is_err());
// Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
let leaves = [(4, EMPTY_WORD)];
let result = SimpleSmt::<2>::with_leaves(leaves);
let result = SimpleSmt::with_leaves(2, leaves);
assert!(result.is_err());
// Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
let leaves = [(8, EMPTY_WORD)];
let result = SimpleSmt::<3>::with_leaves(leaves);
let result = SimpleSmt::with_leaves(3, leaves);
assert!(result.is_err());
// TESTING WITH A VALUE
@@ -280,17 +335,17 @@ fn test_simplesmt_with_leaves_nonexisting_leaf() {
// Depth 1 has 2 leaves. Position is 0-indexed, position 2 doesn't exist.
let leaves = [(2, *value)];
let result = SimpleSmt::<1>::with_leaves(leaves);
let result = SimpleSmt::with_leaves(1, leaves);
assert!(result.is_err());
// Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
let leaves = [(4, *value)];
let result = SimpleSmt::<2>::with_leaves(leaves);
let result = SimpleSmt::with_leaves(2, leaves);
assert!(result.is_err());
// Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
let leaves = [(8, *value)];
let result = SimpleSmt::<3>::with_leaves(leaves);
let result = SimpleSmt::with_leaves(3, leaves);
assert!(result.is_err());
}
@@ -328,15 +383,16 @@ fn test_simplesmt_set_subtree() {
// / \
// c 0
let subtree = {
let depth = 1;
let entries = vec![(0, c)];
SimpleSmt::<1>::with_leaves(entries).unwrap()
SimpleSmt::with_leaves(depth, entries).unwrap()
};
// insert subtree
const TREE_DEPTH: u8 = 3;
let tree = {
let depth = 3;
let entries = vec![(0, a), (1, b), (7, d)];
let mut tree = SimpleSmt::<TREE_DEPTH>::with_leaves(entries).unwrap();
let mut tree = SimpleSmt::with_leaves(depth, entries).unwrap();
tree.set_subtree(2, subtree).unwrap();
@@ -344,8 +400,8 @@ fn test_simplesmt_set_subtree() {
};
assert_eq!(tree.root(), k);
assert_eq!(tree.get_leaf(&LeafIndex::<TREE_DEPTH>::new(4).unwrap()), c);
assert_eq!(tree.get_inner_node(NodeIndex::new_unchecked(2, 2)).hash(), g);
assert_eq!(tree.get_leaf(4).unwrap(), c);
assert_eq!(tree.get_branch_node(&NodeIndex::new_unchecked(2, 2)).parent(), g);
}
/// Ensures that an invalid input node index into `set_subtree()` incurs no mutation of the tree
@@ -373,13 +429,15 @@ fn test_simplesmt_set_subtree_unchanged_for_wrong_index() {
// / \
// c 0
let subtree = {
let depth = 1;
let entries = vec![(0, c)];
SimpleSmt::<1>::with_leaves(entries).unwrap()
SimpleSmt::with_leaves(depth, entries).unwrap()
};
let mut tree = {
let depth = 3;
let entries = vec![(0, a), (1, b), (7, d)];
SimpleSmt::<3>::with_leaves(entries).unwrap()
SimpleSmt::with_leaves(depth, entries).unwrap()
};
let tree_root_before_insertion = tree.root();
@@ -409,20 +467,21 @@ fn test_simplesmt_set_subtree_entire_tree() {
let c = Word::from(Rpo256::merge(&[b.into(); 2]));
let d = Word::from(Rpo256::merge(&[c.into(); 2]));
let depth = 3;
// subtree: E3
const DEPTH: u8 = 3;
let subtree = { SimpleSmt::<DEPTH>::with_leaves(Vec::new()).unwrap() };
assert_eq!(subtree.root(), *EmptySubtreeRoots::entry(DEPTH, 0));
let subtree = { SimpleSmt::with_leaves(depth, Vec::new()).unwrap() };
assert_eq!(subtree.root(), *EmptySubtreeRoots::entry(depth, 0));
// insert subtree
let mut tree = {
let entries = vec![(0, a), (1, b), (4, c), (7, d)];
SimpleSmt::<3>::with_leaves(entries).unwrap()
SimpleSmt::with_leaves(depth, entries).unwrap()
};
tree.set_subtree(0, subtree).unwrap();
assert_eq!(tree.root(), *EmptySubtreeRoots::entry(DEPTH, 0));
assert_eq!(tree.root(), *EmptySubtreeRoots::entry(depth, 0));
}
// HELPER FUNCTIONS

View File

@@ -1,86 +0,0 @@
use core::fmt;
use crate::{
hash::rpo::RpoDigest,
merkle::{LeafIndex, SMT_DEPTH},
utils::collections::Vec,
Word,
};
// SMT LEAF ERROR
// =================================================================================================

/// Errors which can arise while constructing or validating an SMT leaf.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SmtLeafError {
    /// Two keys in the supplied entries map to different leaf indices; `key_1` and `key_2` are
    /// the first offending pair found.
    InconsistentKeys {
        entries: Vec<(RpoDigest, Word)>,
        key_1: RpoDigest,
        key_2: RpoDigest,
    },
    /// A `Multiple` leaf was requested with fewer than 2 entries; payload is the actual count.
    InvalidNumEntriesForMultiple(usize),
    /// The single entry's key maps to a leaf index other than the one supplied.
    SingleKeyInconsistentWithLeafIndex {
        key: RpoDigest,
        leaf_index: LeafIndex<SMT_DEPTH>,
    },
    /// The entries' keys agree with each other, but map to a leaf index other than the one
    /// supplied.
    MultipleKeysInconsistentWithLeafIndex {
        leaf_index_from_keys: LeafIndex<SMT_DEPTH>,
        leaf_index_supplied: LeafIndex<SMT_DEPTH>,
    },
}

#[cfg(feature = "std")]
impl std::error::Error for SmtLeafError {}
impl fmt::Display for SmtLeafError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use SmtLeafError::*;
match self {
InvalidNumEntriesForMultiple(num_entries) => {
write!(f, "Multiple leaf requires 2 or more entries. Got: {num_entries}")
}
InconsistentKeys { entries, key_1, key_2 } => {
write!(f, "Multiple leaf requires all keys to map to the same leaf index. Offending keys: {key_1} and {key_2}. Entries: {entries:?}.")
}
SingleKeyInconsistentWithLeafIndex { key, leaf_index } => {
write!(
f,
"Single key in leaf inconsistent with leaf index. Key: {key}, leaf index: {}",
leaf_index.value()
)
}
MultipleKeysInconsistentWithLeafIndex {
leaf_index_from_keys,
leaf_index_supplied,
} => {
write!(
f,
"Keys in entries map to leaf index {}, but leaf index {} was supplied",
leaf_index_from_keys.value(),
leaf_index_supplied.value()
)
}
}
}
}
// SMT PROOF ERROR
// =================================================================================================

/// Errors which can arise while constructing an SMT opening proof.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SmtProofError {
    /// The supplied Merkle path does not have length `SMT_DEPTH`; payload is the actual length.
    InvalidPathLength(usize),
}

#[cfg(feature = "std")]
impl std::error::Error for SmtProofError {}
impl fmt::Display for SmtProofError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use SmtProofError::*;
match self {
InvalidPathLength(path_length) => {
write!(f, "Invalid Merkle path length. Expected {SMT_DEPTH}, got {path_length}")
}
}
}
}

View File

@@ -1,372 +0,0 @@
use core::cmp::Ordering;
use crate::utils::{collections::Vec, string::ToString, vec};
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
use super::{Felt, LeafIndex, Rpo256, RpoDigest, SmtLeafError, Word, EMPTY_WORD, SMT_DEPTH};
/// A leaf of the [`super::Smt`]: either empty, a single key-value pair, or multiple pairs whose
/// keys all map to the same leaf index.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum SmtLeaf {
    /// An empty leaf; carries its index explicitly, since there is no key to derive it from.
    Empty(LeafIndex<SMT_DEPTH>),
    /// A leaf holding exactly one key-value pair.
    Single((RpoDigest, Word)),
    /// A leaf holding two or more key-value pairs, kept sorted by key (see `insert`).
    Multiple(Vec<(RpoDigest, Word)>),
}
impl SmtLeaf {
    // CONSTRUCTORS
    // ---------------------------------------------------------------------------------------------

    /// Returns a new leaf with the specified entries
    ///
    /// # Errors
    /// - Returns an error if 2 keys in `entries` map to a different leaf index
    /// - Returns an error if 1 or more keys in `entries` map to a leaf index
    ///   different from `leaf_index`
    pub fn new(
        entries: Vec<(RpoDigest, Word)>,
        leaf_index: LeafIndex<SMT_DEPTH>,
    ) -> Result<Self, SmtLeafError> {
        match entries.len() {
            0 => Ok(Self::new_empty(leaf_index)),
            1 => {
                let (key, value) = entries[0];

                if LeafIndex::<SMT_DEPTH>::from(key) != leaf_index {
                    return Err(SmtLeafError::SingleKeyInconsistentWithLeafIndex {
                        key,
                        leaf_index,
                    });
                }

                Ok(Self::new_single(key, value))
            }
            _ => {
                let leaf = Self::new_multiple(entries)?;

                // `new_multiple()` checked that all keys map to the same leaf index. We still need
                // to ensure that that leaf index is `leaf_index`.
                if leaf.index() != leaf_index {
                    Err(SmtLeafError::MultipleKeysInconsistentWithLeafIndex {
                        leaf_index_from_keys: leaf.index(),
                        leaf_index_supplied: leaf_index,
                    })
                } else {
                    Ok(leaf)
                }
            }
        }
    }

    /// Returns a new empty leaf with the specified leaf index
    pub fn new_empty(leaf_index: LeafIndex<SMT_DEPTH>) -> Self {
        Self::Empty(leaf_index)
    }

    /// Returns a new single leaf with the specified entry. The leaf index is derived from the
    /// entry's key.
    pub fn new_single(key: RpoDigest, value: Word) -> Self {
        Self::Single((key, value))
    }

    /// Returns a new multiple leaf with the specified entries. The leaf index is derived from the
    /// entries' keys.
    ///
    /// # Errors
    /// - Returns an error if fewer than 2 entries are provided
    /// - Returns an error if 2 keys in `entries` map to a different leaf index
    pub fn new_multiple(entries: Vec<(RpoDigest, Word)>) -> Result<Self, SmtLeafError> {
        if entries.len() < 2 {
            return Err(SmtLeafError::InvalidNumEntriesForMultiple(entries.len()));
        }

        // Check that all keys map to the same leaf index
        {
            let mut keys = entries.iter().map(|(key, _)| key);

            let first_key = *keys.next().expect("ensured at least 2 entries");
            let first_leaf_index: LeafIndex<SMT_DEPTH> = first_key.into();

            for &next_key in keys {
                let next_leaf_index: LeafIndex<SMT_DEPTH> = next_key.into();

                if next_leaf_index != first_leaf_index {
                    return Err(SmtLeafError::InconsistentKeys {
                        entries,
                        key_1: first_key,
                        key_2: next_key,
                    });
                }
            }
        }

        Ok(Self::Multiple(entries))
    }

    // PUBLIC ACCESSORS
    // ---------------------------------------------------------------------------------------------

    /// Returns true if the leaf is empty
    pub fn is_empty(&self) -> bool {
        matches!(self, Self::Empty(_))
    }

    /// Returns the leaf's index in the [`super::Smt`]
    pub fn index(&self) -> LeafIndex<SMT_DEPTH> {
        match self {
            SmtLeaf::Empty(leaf_index) => *leaf_index,
            SmtLeaf::Single((key, _)) => key.into(),
            SmtLeaf::Multiple(entries) => {
                // Note: All keys are guaranteed to have the same leaf index
                let (first_key, _) = entries[0];
                first_key.into()
            }
        }
    }

    /// Returns the number of entries stored in the leaf
    pub fn num_entries(&self) -> u64 {
        match self {
            SmtLeaf::Empty(_) => 0,
            SmtLeaf::Single(_) => 1,
            SmtLeaf::Multiple(entries) => {
                entries.len().try_into().expect("shouldn't have more than 2^64 entries")
            }
        }
    }

    /// Computes the hash of the leaf
    ///
    /// An empty leaf hashes to the empty word; a single leaf hashes to merge(key, value); a
    /// multiple leaf hashes over all its key-value pairs flattened to field elements.
    pub fn hash(&self) -> RpoDigest {
        match self {
            SmtLeaf::Empty(_) => EMPTY_WORD.into(),
            SmtLeaf::Single((key, value)) => Rpo256::merge(&[*key, value.into()]),
            SmtLeaf::Multiple(kvs) => {
                let elements: Vec<Felt> = kvs.iter().copied().flat_map(kv_to_elements).collect();
                Rpo256::hash_elements(&elements)
            }
        }
    }

    // ITERATORS
    // ---------------------------------------------------------------------------------------------

    /// Returns the key-value pairs in the leaf
    pub fn entries(&self) -> Vec<&(RpoDigest, Word)> {
        match self {
            SmtLeaf::Empty(_) => Vec::new(),
            SmtLeaf::Single(kv_pair) => vec![kv_pair],
            SmtLeaf::Multiple(kv_pairs) => kv_pairs.iter().collect(),
        }
    }

    // CONVERSIONS
    // ---------------------------------------------------------------------------------------------

    /// Converts a leaf to a list of field elements
    pub fn to_elements(&self) -> Vec<Felt> {
        self.clone().into_elements()
    }

    /// Converts a leaf to a list of field elements
    pub fn into_elements(self) -> Vec<Felt> {
        self.into_entries().into_iter().flat_map(kv_to_elements).collect()
    }

    /// Converts a leaf into the key-value pairs it contains
    pub fn into_entries(self) -> Vec<(RpoDigest, Word)> {
        match self {
            SmtLeaf::Empty(_) => Vec::new(),
            SmtLeaf::Single(kv_pair) => vec![kv_pair],
            SmtLeaf::Multiple(kv_pairs) => kv_pairs,
        }
    }

    // HELPERS
    // ---------------------------------------------------------------------------------------------

    /// Returns the value associated with `key` in the leaf, or `None` if `key` maps to another leaf.
    ///
    /// Note: a key that maps to this leaf but is not stored in it yields `Some(EMPTY_WORD)`, not
    /// `None` — `None` strictly means "wrong leaf".
    pub(super) fn get_value(&self, key: &RpoDigest) -> Option<Word> {
        // Ensure that `key` maps to this leaf
        if self.index() != key.into() {
            return None;
        }

        match self {
            SmtLeaf::Empty(_) => Some(EMPTY_WORD),
            SmtLeaf::Single((key_in_leaf, value_in_leaf)) => {
                if key == key_in_leaf {
                    Some(*value_in_leaf)
                } else {
                    Some(EMPTY_WORD)
                }
            }
            SmtLeaf::Multiple(kv_pairs) => {
                for (key_in_leaf, value_in_leaf) in kv_pairs {
                    if key == key_in_leaf {
                        return Some(*value_in_leaf);
                    }
                }

                Some(EMPTY_WORD)
            }
        }
    }

    /// Inserts key-value pair into the leaf; returns the previous value associated with `key`, if
    /// any.
    ///
    /// The caller needs to ensure that `key` has the same leaf index as all other keys in the leaf
    pub(super) fn insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
        match self {
            SmtLeaf::Empty(_) => {
                *self = SmtLeaf::new_single(key, value);
                None
            }
            SmtLeaf::Single(kv_pair) => {
                if kv_pair.0 == key {
                    // the key is already in this leaf. Update the value and return the previous
                    // value
                    let old_value = kv_pair.1;
                    kv_pair.1 = value;
                    Some(old_value)
                } else {
                    // Another entry is present in this leaf. Transform the entry into a list
                    // entry, and make sure the key-value pairs are sorted by key
                    let mut pairs = vec![*kv_pair, (key, value)];
                    pairs.sort_by(|(key_1, _), (key_2, _)| cmp_keys(*key_1, *key_2));

                    *self = SmtLeaf::Multiple(pairs);
                    None
                }
            }
            SmtLeaf::Multiple(kv_pairs) => {
                // `kv_pairs` is kept sorted by key, so binary search both finds an existing entry
                // and gives the insertion point that preserves the order.
                match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
                    Ok(pos) => {
                        let old_value = kv_pairs[pos].1;
                        kv_pairs[pos].1 = value;
                        Some(old_value)
                    }
                    Err(pos) => {
                        kv_pairs.insert(pos, (key, value));
                        None
                    }
                }
            }
        }
    }

    /// Removes key-value pair from the leaf stored at key; returns the previous value associated
    /// with `key`, if any. Also returns an `is_empty` flag, indicating whether the leaf became
    /// empty, and must be removed from the data structure it is contained in.
    pub(super) fn remove(&mut self, key: RpoDigest) -> (Option<Word>, bool) {
        match self {
            SmtLeaf::Empty(_) => (None, false),
            SmtLeaf::Single((key_at_leaf, value_at_leaf)) => {
                if *key_at_leaf == key {
                    // our key was indeed stored in the leaf, so we return the value that was stored
                    // in it, and indicate that the leaf should be removed
                    let old_value = *value_at_leaf;

                    // Note: this is not strictly needed, since the caller is expected to drop this
                    // `SmtLeaf` object.
                    *self = SmtLeaf::new_empty(key.into());

                    (Some(old_value), true)
                } else {
                    // another key is stored at leaf; nothing to update
                    (None, false)
                }
            }
            SmtLeaf::Multiple(kv_pairs) => {
                match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
                    Ok(pos) => {
                        let old_value = kv_pairs[pos].1;

                        kv_pairs.remove(pos);
                        debug_assert!(!kv_pairs.is_empty());

                        if kv_pairs.len() == 1 {
                            // convert the leaf into `Single`
                            *self = SmtLeaf::Single(kv_pairs[0]);
                        }

                        (Some(old_value), false)
                    }
                    Err(_) => {
                        // other keys are stored at leaf; nothing to update
                        (None, false)
                    }
                }
            }
        }
    }
}
impl Serializable for SmtLeaf {
    /// Serializes the leaf as: entry count (u64) | leaf index (u64) | entries, where each entry
    /// is serialized as key followed by value.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        // Write: num entries
        self.num_entries().write_into(target);

        // Write: leaf index
        let leaf_index: u64 = self.index().value();
        leaf_index.write_into(target);

        // Write: entries
        for (key, value) in self.entries() {
            key.write_into(target);
            value.write_into(target);
        }
    }
}
impl Deserializable for SmtLeaf {
    /// Deserializes a leaf in the format written by [`Serializable::write_into`]. The entries are
    /// validated against the recorded leaf index via `SmtLeaf::new`, so a corrupted or
    /// inconsistent payload is rejected rather than producing an invalid leaf.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        // Read: num entries
        let num_entries = source.read_u64()?;

        // Read: leaf index
        let leaf_index: LeafIndex<SMT_DEPTH> = {
            let value = source.read_u64()?;
            LeafIndex::new_max_depth(value)
        };

        // Read: entries
        let mut entries: Vec<(RpoDigest, Word)> = Vec::new();
        for _ in 0..num_entries {
            let key: RpoDigest = source.read()?;
            let value: Word = source.read()?;

            entries.push((key, value));
        }

        Self::new(entries, leaf_index)
            .map_err(|err| DeserializationError::InvalidValue(err.to_string()))
    }
}
// HELPER FUNCTIONS
// ================================================================================================

/// Converts a key-value tuple to an iterator of `Felt`s: the key's elements followed by the
/// value's elements.
fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator<Item = Felt> {
    key.into_iter().chain(value.into_iter())
}
/// Compares two keys, compared element-by-element using their integer representations starting with
/// the most significant element.
fn cmp_keys(key_1: RpoDigest, key_2: RpoDigest) -> Ordering {
    // `rev()` walks the elements from last to first, i.e. most significant element first.
    key_1
        .iter()
        .zip(key_2.iter())
        .rev()
        .map(|(v1, v2)| v1.as_int().cmp(&v2.as_int()))
        .find(|ord| *ord != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}

View File

@@ -1,299 +0,0 @@
use crate::hash::rpo::Rpo256;
use crate::merkle::{EmptySubtreeRoots, InnerNodeInfo};
use crate::utils::collections::{BTreeMap, BTreeSet};
use crate::{Felt, EMPTY_WORD};
use super::{
InnerNode, LeafIndex, MerkleError, MerklePath, NodeIndex, RpoDigest, SparseMerkleTree, Word,
};
mod error;
pub use error::{SmtLeafError, SmtProofError};
mod leaf;
pub use leaf::SmtLeaf;
mod proof;
pub use proof::SmtProof;
#[cfg(test)]
mod tests;
// CONSTANTS
// ================================================================================================

/// Depth of the sparse Merkle tree: all leaves of an [Smt] sit at depth 64.
pub const SMT_DEPTH: u8 = 64;
// SMT
// ================================================================================================

/// Sparse Merkle tree mapping 256-bit keys to 256-bit values. Both keys and values are represented
/// by 4 field elements.
///
/// All leaves sit at depth 64. The most significant element of the key is used to identify the leaf to
/// which the key maps.
///
/// A leaf is either empty, or holds one or more key-value pairs. An empty leaf hashes to the empty
/// word. Otherwise, a leaf hashes to the hash of its key-value pairs, ordered by key first, value
/// second.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Smt {
    // Current root digest of the tree.
    root: RpoDigest,
    // Non-empty leaves keyed by leaf index; leaves that become empty are removed from the map.
    leaves: BTreeMap<u64, SmtLeaf>,
    // Non-empty inner nodes keyed by node index; nodes absent from the map are roots of empty
    // subtrees (see `get_inner_node`).
    inner_nodes: BTreeMap<NodeIndex, InnerNode>,
}
impl Smt {
    // CONSTANTS
    // --------------------------------------------------------------------------------------------
    /// The default value used to compute the hash of empty leaves
    pub const EMPTY_VALUE: Word = <Self as SparseMerkleTree<SMT_DEPTH>>::EMPTY_VALUE;

    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Returns a new [Smt].
    ///
    /// All leaves in the returned tree are set to [Self::EMPTY_VALUE].
    pub fn new() -> Self {
        let root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);

        Self {
            root,
            leaves: BTreeMap::new(),
            inner_nodes: BTreeMap::new(),
        }
    }

    /// Returns a new [Smt] instantiated with leaves set as specified by the provided entries.
    ///
    /// All leaves omitted from the entries list are set to [Self::EMPTY_VALUE].
    ///
    /// # Errors
    /// Returns an error if the provided entries contain multiple values for the same key.
    pub fn with_entries(
        entries: impl IntoIterator<Item = (RpoDigest, Word)>,
    ) -> Result<Self, MerkleError> {
        // create an empty tree
        let mut tree = Self::new();

        // This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so
        // entries with the empty value need additional tracking.
        let mut key_set_to_zero = BTreeSet::new();

        for (key, value) in entries {
            let old_value = tree.insert(key, value);

            // a non-empty old value, or a key previously set to zero, means this key was already
            // present in the entries list
            if old_value != EMPTY_WORD || key_set_to_zero.contains(&key) {
                return Err(MerkleError::DuplicateValuesForIndex(
                    LeafIndex::<SMT_DEPTH>::from(key).value(),
                ));
            }

            if value == EMPTY_WORD {
                key_set_to_zero.insert(key);
            };
        }
        Ok(tree)
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the depth of the tree
    pub const fn depth(&self) -> u8 {
        SMT_DEPTH
    }

    /// Returns the root of the tree
    pub fn root(&self) -> RpoDigest {
        <Self as SparseMerkleTree<SMT_DEPTH>>::root(self)
    }

    /// Returns the leaf to which `key` maps
    pub fn get_leaf(&self, key: &RpoDigest) -> SmtLeaf {
        <Self as SparseMerkleTree<SMT_DEPTH>>::get_leaf(self, key)
    }

    /// Returns the value associated with `key`, or [Self::EMPTY_VALUE] if the key is not set
    pub fn get_value(&self, key: &RpoDigest) -> Word {
        let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();

        match self.leaves.get(&leaf_pos) {
            Some(leaf) => leaf.get_value(key).unwrap_or_default(),
            None => EMPTY_WORD,
        }
    }

    /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
    /// path to the leaf, as well as the leaf itself.
    pub fn open(&self, key: &RpoDigest) -> SmtProof {
        <Self as SparseMerkleTree<SMT_DEPTH>>::open(self, key)
    }

    // ITERATORS
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over the leaves of this [Smt].
    pub fn leaves(&self) -> impl Iterator<Item = (LeafIndex<SMT_DEPTH>, &SmtLeaf)> {
        self.leaves
            .iter()
            .map(|(leaf_index, leaf)| (LeafIndex::new_max_depth(*leaf_index), leaf))
    }

    /// Returns an iterator over the key-value pairs of this [Smt].
    pub fn entries(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
        self.leaves().flat_map(|(_, leaf)| leaf.entries())
    }

    /// Returns an iterator over the inner nodes of this [Smt].
    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
        self.inner_nodes.values().map(|e| InnerNodeInfo {
            value: e.hash(),
            left: e.left,
            right: e.right,
        })
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Inserts a value at the specified key, returning the previous value associated with that key.
    /// Recall that by definition, any key that hasn't been updated is associated with
    /// [`Self::EMPTY_VALUE`].
    ///
    /// This also recomputes all hashes between the leaf (associated with the key) and the root,
    /// updating the root itself.
    pub fn insert(&mut self, key: RpoDigest, value: Word) -> Word {
        <Self as SparseMerkleTree<SMT_DEPTH>>::insert(self, key, value)
    }

    // HELPERS
    // --------------------------------------------------------------------------------------------

    /// Inserts `value` at leaf index pointed to by `key`. `value` is guaranteed to not be the empty
    /// value, such that this is indeed an insertion.
    fn perform_insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
        debug_assert_ne!(value, Self::EMPTY_VALUE);

        let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);

        match self.leaves.get_mut(&leaf_index.value()) {
            Some(leaf) => leaf.insert(key, value),
            None => {
                self.leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value)));

                None
            }
        }
    }

    /// Removes key-value pair at leaf index pointed to by `key` if it exists.
    fn perform_remove(&mut self, key: RpoDigest) -> Option<Word> {
        let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);

        if let Some(leaf) = self.leaves.get_mut(&leaf_index.value()) {
            let (old_value, is_empty) = leaf.remove(key);
            if is_empty {
                // keep the map sparse: drop the leaf entirely once it holds no entries
                self.leaves.remove(&leaf_index.value());
            }
            old_value
        } else {
            // there's nothing stored at the leaf; nothing to update
            None
        }
    }
}
impl SparseMerkleTree<SMT_DEPTH> for Smt {
    type Key = RpoDigest;
    type Value = Word;
    type Leaf = SmtLeaf;
    type Opening = SmtProof;

    const EMPTY_VALUE: Self::Value = EMPTY_WORD;

    fn root(&self) -> RpoDigest {
        self.root
    }

    fn set_root(&mut self, root: RpoDigest) {
        self.root = root;
    }

    fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
        // Nodes absent from the map are roots of empty subtrees; synthesize such a node on the
        // fly from the precomputed empty-subtree roots one level below `index`.
        self.inner_nodes.get(&index).cloned().unwrap_or_else(|| {
            let node = EmptySubtreeRoots::entry(SMT_DEPTH, index.depth() + 1);

            InnerNode { left: *node, right: *node }
        })
    }

    fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
        self.inner_nodes.insert(index, inner_node);
    }

    fn remove_inner_node(&mut self, index: NodeIndex) {
        let _ = self.inner_nodes.remove(&index);
    }

    fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value> {
        // inserting an `EMPTY_VALUE` is equivalent to removing any value associated with `key`
        if value != Self::EMPTY_VALUE {
            self.perform_insert(key, value)
        } else {
            self.perform_remove(key)
        }
    }

    fn get_leaf(&self, key: &RpoDigest) -> Self::Leaf {
        let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();

        match self.leaves.get(&leaf_pos) {
            Some(leaf) => leaf.clone(),
            None => SmtLeaf::new_empty(key.into()),
        }
    }

    fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest {
        leaf.hash()
    }

    fn key_to_leaf_index(key: &RpoDigest) -> LeafIndex<SMT_DEPTH> {
        // The most significant element of the key identifies the leaf.
        let most_significant_felt = key[3];
        LeafIndex::new_max_depth(most_significant_felt.as_int())
    }

    fn path_and_leaf_to_opening(path: MerklePath, leaf: SmtLeaf) -> SmtProof {
        SmtProof::new_unchecked(path, leaf)
    }
}
impl Default for Smt {
    /// Returns an empty [Smt]; equivalent to [Smt::new].
    fn default() -> Self {
        Self::new()
    }
}
// CONVERSIONS
// ================================================================================================
impl From<Word> for LeafIndex<SMT_DEPTH> {
    fn from(value: Word) -> Self {
        // We use the most significant `Felt` of a `Word` as the leaf index.
        let most_significant_felt = value[3];
        Self::new_max_depth(most_significant_felt.as_int())
    }
}
impl From<RpoDigest> for LeafIndex<SMT_DEPTH> {
fn from(value: RpoDigest) -> Self {
Word::from(value).into()
}
}
impl From<&RpoDigest> for LeafIndex<SMT_DEPTH> {
    fn from(value: &RpoDigest) -> Self {
        // delegate to the owned `Word` conversion
        Self::from(Word::from(value))
    }
}

View File

@@ -1,114 +0,0 @@
use crate::utils::string::ToString;
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
use super::{MerklePath, RpoDigest, SmtLeaf, SmtProofError, Word, SMT_DEPTH};
/// A proof which can be used to assert membership (or non-membership) of key-value pairs in a
/// [`super::Smt`].
///
/// The proof consists of a Merkle path and leaf which describes the node located at the base of the
/// path.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SmtProof {
    /// The Merkle path authenticating `leaf`.
    path: MerklePath,
    /// The leaf located at the base of `path`.
    leaf: SmtLeaf,
}
impl SmtProof {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------

    /// Returns a new instance of [`SmtProof`] instantiated from the specified path and leaf.
    ///
    /// # Errors
    /// Returns an error if the path length is not [`SMT_DEPTH`].
    pub fn new(path: MerklePath, leaf: SmtLeaf) -> Result<Self, SmtProofError> {
        let expected_length: usize = SMT_DEPTH.into();
        if path.len() == expected_length {
            Ok(Self { path, leaf })
        } else {
            Err(SmtProofError::InvalidPathLength(path.len()))
        }
    }

    /// Returns a new instance of [`SmtProof`] instantiated from the specified path and leaf.
    ///
    /// The length of the path is not checked. Reserved for internal use.
    pub(super) fn new_unchecked(path: MerklePath, leaf: SmtLeaf) -> Self {
        Self { path, leaf }
    }

    // PROOF VERIFIER
    // --------------------------------------------------------------------------------------------

    /// Returns true if a [`super::Smt`] with the specified root contains the provided
    /// key-value pair.
    ///
    /// Note: this method cannot be used to assert non-membership. That is, if false is returned,
    /// it does not mean that the provided key-value pair is not in the tree.
    pub fn verify_membership(&self, key: &RpoDigest, value: &Word, root: &RpoDigest) -> bool {
        match self.leaf.get_value(key) {
            // the stored value must match, and the Merkle path must resolve to the given root
            Some(value_in_leaf) => value_in_leaf == *value && self.compute_root() == *root,
            // the key maps to a different leaf, so membership of `value` cannot be established
            None => false,
        }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the value associated with the specific key according to this proof, or None if
    /// this proof does not contain a value for the specified key.
    ///
    /// A key-value pair generated by using this method should pass the `verify_membership()` check.
    pub fn get(&self, key: &RpoDigest) -> Option<Word> {
        self.leaf.get_value(key)
    }

    /// Computes the root of a [`super::Smt`] to which this proof resolves.
    pub fn compute_root(&self) -> RpoDigest {
        let leaf_hash = self.leaf.hash();
        let leaf_pos = self.leaf.index().value();
        self.path
            .compute_root(leaf_pos, leaf_hash)
            .expect("failed to compute Merkle path root")
    }

    /// Returns the proof's Merkle path.
    pub fn path(&self) -> &MerklePath {
        &self.path
    }

    /// Returns the leaf associated with the proof.
    pub fn leaf(&self) -> &SmtLeaf {
        &self.leaf
    }

    /// Consumes the proof and returns its parts.
    pub fn into_parts(self) -> (MerklePath, SmtLeaf) {
        (self.path, self.leaf)
    }
}
impl Serializable for SmtProof {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        // the path is written first, then the leaf; `read_from` relies on this exact order
        self.path.write_into(target);
        self.leaf.write_into(target);
    }
}
impl Deserializable for SmtProof {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        // fields are read in the same order `write_into` wrote them: path first, then leaf
        let path = MerklePath::read_from(source)?;
        let leaf = SmtLeaf::read_from(source)?;

        // re-validate the path length so malformed inputs are rejected
        match Self::new(path, leaf) {
            Ok(proof) => Ok(proof),
            Err(err) => Err(DeserializationError::InvalidValue(err.to_string())),
        }
    }
}

View File

@@ -1,408 +0,0 @@
use winter_utils::{Deserializable, Serializable};
use super::*;
use crate::{
merkle::{EmptySubtreeRoots, MerkleStore},
utils::collections::Vec,
ONE, WORD_SIZE,
};
// SMT
// --------------------------------------------------------------------------------------------
/// This test checks that inserting twice at the same key functions as expected. The test covers
/// only the case where the key is alone in its leaf
#[test]
fn test_smt_insert_at_same_key() {
    let mut smt = Smt::default();
    let mut store: MerkleStore = MerkleStore::default();

    // a freshly created tree must have the root of a fully empty tree
    assert_eq!(smt.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));

    let key_1: RpoDigest = {
        let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;

        RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)])
    };
    let key_1_index: NodeIndex = LeafIndex::<SMT_DEPTH>::from(key_1).into();

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [ONE + ONE; WORD_SIZE];

    // Insert value 1 and ensure root is as expected
    {
        // compute the expected root independently via the reference MerkleStore
        let leaf_node = build_empty_or_single_leaf_node(key_1, value_1);
        let tree_root = store.set_node(smt.root(), key_1_index, leaf_node).unwrap().root;

        let old_value_1 = smt.insert(key_1, value_1);
        // the key was previously unset, so the returned "old" value is the empty word
        assert_eq!(old_value_1, EMPTY_WORD);

        assert_eq!(smt.root(), tree_root);
    }

    // Insert value 2 and ensure root is as expected
    {
        let leaf_node = build_empty_or_single_leaf_node(key_1, value_2);
        let tree_root = store.set_node(smt.root(), key_1_index, leaf_node).unwrap().root;

        let old_value_2 = smt.insert(key_1, value_2);
        // overwriting returns the previously stored value
        assert_eq!(old_value_2, value_1);

        assert_eq!(smt.root(), tree_root);
    }
}
/// This test checks that inserting twice at the same key functions as expected. The test covers
/// only the case where the leaf type is `SmtLeaf::Multiple`
#[test]
fn test_smt_insert_at_same_key_2() {
    // The most significant u64 used for both keys (to ensure they map to the same leaf)
    let key_msb: u64 = 42;

    let key_already_present: RpoDigest =
        RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(key_msb)]);
    let key_already_present_index: NodeIndex =
        LeafIndex::<SMT_DEPTH>::from(key_already_present).into();
    let value_already_present = [ONE + ONE + ONE; WORD_SIZE];

    let mut smt =
        Smt::with_entries(core::iter::once((key_already_present, value_already_present))).unwrap();
    let mut store: MerkleStore = {
        let mut store = MerkleStore::default();

        // mirror the pre-existing entry into the reference store
        let leaf_node = build_empty_or_single_leaf_node(key_already_present, value_already_present);
        store
            .set_node(*EmptySubtreeRoots::entry(SMT_DEPTH, 0), key_already_present_index, leaf_node)
            .unwrap();
        store
    };

    let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(key_msb)]);
    let key_1_index: NodeIndex = LeafIndex::<SMT_DEPTH>::from(key_1).into();

    // both keys share the same most significant felt, so they must map to the same leaf
    assert_eq!(key_1_index, key_already_present_index);

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [ONE + ONE; WORD_SIZE];

    // Insert value 1 and ensure root is as expected
    {
        // Note: key_1 comes first because it is smaller
        let leaf_node = build_multiple_leaf_node(&[
            (key_1, value_1),
            (key_already_present, value_already_present),
        ]);
        let tree_root = store.set_node(smt.root(), key_1_index, leaf_node).unwrap().root;

        let old_value_1 = smt.insert(key_1, value_1);
        assert_eq!(old_value_1, EMPTY_WORD);

        assert_eq!(smt.root(), tree_root);
    }

    // Insert value 2 and ensure root is as expected
    {
        let leaf_node = build_multiple_leaf_node(&[
            (key_1, value_2),
            (key_already_present, value_already_present),
        ]);
        let tree_root = store.set_node(smt.root(), key_1_index, leaf_node).unwrap().root;

        let old_value_2 = smt.insert(key_1, value_2);
        assert_eq!(old_value_2, value_1);

        assert_eq!(smt.root(), tree_root);
    }
}
/// This test ensures that the root of the tree is as expected when we add/remove 3 items at 3
/// different keys. This also tests that the merkle paths produced are as expected.
#[test]
fn test_smt_insert_and_remove_multiple_values() {
    // Inserts all `key_values` into `smt` and, after each insertion, checks both the tree root
    // and the opening of the key against the reference `store`.
    fn insert_values_and_assert_path(
        smt: &mut Smt,
        store: &mut MerkleStore,
        key_values: &[(RpoDigest, Word)],
    ) {
        for &(key, value) in key_values {
            let key_index: NodeIndex = LeafIndex::<SMT_DEPTH>::from(key).into();

            let leaf_node = build_empty_or_single_leaf_node(key, value);
            let tree_root = store.set_node(smt.root(), key_index, leaf_node).unwrap().root;

            let _ = smt.insert(key, value);

            assert_eq!(smt.root(), tree_root);

            // the path reported by the tree must match the one derived from the store
            let expected_path = store.get_path(tree_root, key_index).unwrap();
            assert_eq!(smt.open(&key).into_parts().0, expected_path.path);
        }
    }

    let mut smt = Smt::default();
    let mut store: MerkleStore = MerkleStore::default();

    assert_eq!(smt.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));

    // the three keys differ in their most significant felt, placing them in distinct leaves
    let key_1: RpoDigest = {
        let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;

        RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)])
    };

    let key_2: RpoDigest = {
        let raw = 0b_11111111_11111111_11111111_11111111_11111111_11111111_11111111_11111111_u64;

        RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)])
    };

    let key_3: RpoDigest = {
        let raw = 0b_00000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000_u64;

        RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)])
    };

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [ONE + ONE; WORD_SIZE];
    let value_3 = [ONE + ONE + ONE; WORD_SIZE];

    // Insert values in the tree
    let key_values = [(key_1, value_1), (key_2, value_2), (key_3, value_3)];
    insert_values_and_assert_path(&mut smt, &mut store, &key_values);

    // Remove values from the tree (inserting the empty word removes the key)
    let key_empty_values = [(key_1, EMPTY_WORD), (key_2, EMPTY_WORD), (key_3, EMPTY_WORD)];
    insert_values_and_assert_path(&mut smt, &mut store, &key_empty_values);

    let empty_root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
    assert_eq!(smt.root(), empty_root);

    // an empty tree should have no leaves or inner nodes
    assert!(smt.leaves.is_empty());
    assert!(smt.inner_nodes.is_empty());
}
/// This tests that inserting the empty value does indeed remove the key-value contained at the
/// leaf. We insert & remove 3 values at the same leaf to ensure that all cases are covered (empty,
/// single, multiple).
#[test]
fn test_smt_removal() {
    let mut smt = Smt::default();

    // all three keys share the same most significant felt, so they map to the same leaf
    let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;

    let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
    let key_2: RpoDigest =
        RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]);
    let key_3: RpoDigest =
        RpoDigest::from([3_u32.into(), 3_u32.into(), 3_u32.into(), Felt::new(raw)]);

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [2_u32.into(); WORD_SIZE];
    let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE];

    // insert key-value 1
    {
        let old_value_1 = smt.insert(key_1, value_1);
        assert_eq!(old_value_1, EMPTY_WORD);

        // a single pair in a leaf is stored as `SmtLeaf::Single`
        assert_eq!(smt.get_leaf(&key_1), SmtLeaf::Single((key_1, value_1)));
    }

    // insert key-value 2
    {
        let old_value_2 = smt.insert(key_2, value_2);
        assert_eq!(old_value_2, EMPTY_WORD);

        // a second pair in the same leaf upgrades it to `SmtLeaf::Multiple`
        assert_eq!(
            smt.get_leaf(&key_2),
            SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2)])
        );
    }

    // insert key-value 3
    {
        let old_value_3 = smt.insert(key_3, value_3);
        assert_eq!(old_value_3, EMPTY_WORD);

        assert_eq!(
            smt.get_leaf(&key_3),
            SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2), (key_3, value_3)])
        );
    }

    // remove key 3
    {
        let old_value_3 = smt.insert(key_3, EMPTY_WORD);
        assert_eq!(old_value_3, value_3);

        assert_eq!(
            smt.get_leaf(&key_3),
            SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2)])
        );
    }

    // remove key 2; the leaf degrades back to `SmtLeaf::Single`
    {
        let old_value_2 = smt.insert(key_2, EMPTY_WORD);
        assert_eq!(old_value_2, value_2);

        assert_eq!(smt.get_leaf(&key_2), SmtLeaf::Single((key_1, value_1)));
    }

    // remove key 1; the leaf is now empty
    {
        let old_value_1 = smt.insert(key_1, EMPTY_WORD);
        assert_eq!(old_value_1, value_1);

        assert_eq!(smt.get_leaf(&key_1), SmtLeaf::new_empty(key_1.into()));
    }
}
/// Tests that 2 key-value pairs stored in the same leaf have the same path
#[test]
fn test_smt_path_to_keys_in_same_leaf_are_equal() {
    // both keys share the same most significant felt, so they land in the same leaf
    let shared_msb = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;

    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(shared_msb)]);
    let key_b = RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(shared_msb)]);

    let value_a = [ONE; WORD_SIZE];
    let value_b = [2_u32.into(); WORD_SIZE];

    let smt = Smt::with_entries([(key_a, value_a), (key_b, value_b)]).unwrap();

    // openings of co-located keys must carry identical Merkle paths
    assert_eq!(smt.open(&key_a), smt.open(&key_b));
}
/// Tests that an empty leaf hashes to the empty word
#[test]
fn test_empty_leaf_hash() {
    let tree = Smt::default();

    // no value was ever inserted, so this lookup yields an empty leaf
    let leaf = tree.get_leaf(&RpoDigest::default());

    let expected: RpoDigest = EMPTY_WORD.into();
    assert_eq!(leaf.hash(), expected);
}
/// Tests that `get_value()` works as expected
#[test]
fn test_smt_get_value() {
    let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, ONE]);
    let key_2: RpoDigest =
        RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), 2_u32.into()]);

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [2_u32.into(); WORD_SIZE];

    let smt = Smt::with_entries([(key_1, value_1), (key_2, value_2)]).unwrap();

    // every inserted key must map back to its value
    for (key, value) in [(key_1, value_1), (key_2, value_2)] {
        assert_eq!(smt.get_value(&key), value);
    }

    // Check that a key with no inserted value returns the empty word
    let key_no_value =
        RpoDigest::from([42_u32.into(), 42_u32.into(), 42_u32.into(), 42_u32.into()]);

    assert_eq!(smt.get_value(&key_no_value), EMPTY_WORD);
}
/// Tests that `entries()` works as expected
#[test]
fn test_smt_entries() {
    let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, ONE]);
    let key_2: RpoDigest =
        RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), 2_u32.into()]);

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [2_u32.into(); WORD_SIZE];

    let smt = Smt::with_entries([(key_1, value_1), (key_2, value_2)]).unwrap();

    // Note: for simplicity, we assume the order `(k1,v1), (k2,v2)`. If a new implementation
    // switches the order, it is OK to modify the order here as well.
    let collected: Vec<_> = smt.entries().collect();
    assert_eq!(collected, vec![&(key_1, value_1), &(key_2, value_2)]);
}
// SMT LEAF
// --------------------------------------------------------------------------------------------
/// An empty leaf must round-trip through (de)serialization, ignoring trailing garbage bytes.
#[test]
fn test_empty_smt_leaf_serialization() {
    let empty_leaf = SmtLeaf::new_empty(LeafIndex::new_max_depth(42));

    // append garbage to ensure deserialization reads only the encoded prefix
    let mut bytes = empty_leaf.to_bytes();
    bytes.extend([1, 2, 3, 4, 5]);

    assert_eq!(SmtLeaf::read_from_bytes(&bytes).unwrap(), empty_leaf);
}
/// A single-pair leaf must round-trip through (de)serialization, ignoring trailing garbage bytes.
#[test]
fn test_single_smt_leaf_serialization() {
    let leaf_key = RpoDigest::from([10_u32.into(), 11_u32.into(), 12_u32.into(), 13_u32.into()]);
    let leaf_value = [1_u32.into(), 2_u32.into(), 3_u32.into(), 4_u32.into()];
    let single_leaf = SmtLeaf::new_single(leaf_key, leaf_value);

    // append garbage to ensure deserialization reads only the encoded prefix
    let mut bytes = single_leaf.to_bytes();
    bytes.extend([1, 2, 3, 4, 5]);

    assert_eq!(SmtLeaf::read_from_bytes(&bytes).unwrap(), single_leaf);
}
/// A multi-pair leaf must round-trip through (de)serialization, ignoring trailing garbage bytes.
#[test]
fn test_multiple_smt_leaf_serialization_success() {
    // both pairs share the same most significant felt, i.e. they belong to the same leaf
    let pair_1 = (
        RpoDigest::from([10_u32.into(), 11_u32.into(), 12_u32.into(), 13_u32.into()]),
        [1_u32.into(), 2_u32.into(), 3_u32.into(), 4_u32.into()],
    );
    let pair_2 = (
        RpoDigest::from([100_u32.into(), 101_u32.into(), 102_u32.into(), 13_u32.into()]),
        [11_u32.into(), 12_u32.into(), 13_u32.into(), 14_u32.into()],
    );
    let multiple_leaf = SmtLeaf::new_multiple(vec![pair_1, pair_2]).unwrap();

    // append garbage to ensure deserialization reads only the encoded prefix
    let mut bytes = multiple_leaf.to_bytes();
    bytes.extend([1, 2, 3, 4, 5]);

    assert_eq!(SmtLeaf::read_from_bytes(&bytes).unwrap(), multiple_leaf);
}
// HELPERS
// --------------------------------------------------------------------------------------------
/// Builds the digest of the leaf that would store `(key, value)`: an empty leaf when `value` is
/// the empty word, and a single-pair leaf otherwise.
fn build_empty_or_single_leaf_node(key: RpoDigest, value: Word) -> RpoDigest {
    let leaf = if value == EMPTY_WORD {
        SmtLeaf::new_empty(key.into())
    } else {
        SmtLeaf::Single((key, value))
    };

    leaf.hash()
}
/// Builds the digest of a multi-pair leaf by hashing the flat sequence of field elements
/// `k0..k3, v0..v3` for every key-value pair, in order.
fn build_multiple_leaf_node(kv_pairs: &[(RpoDigest, Word)]) -> RpoDigest {
    let mut elements: Vec<Felt> = Vec::new();
    for (key, value) in kv_pairs {
        elements.extend(key.into_iter());
        elements.extend((*value).into_iter());
    }

    Rpo256::hash_elements(&elements)
}

View File

@@ -1,245 +0,0 @@
use crate::{
hash::rpo::{Rpo256, RpoDigest},
Word,
};
use super::{EmptySubtreeRoots, MerkleError, MerklePath, NodeIndex, Vec};
mod full;
pub use full::{Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH};
mod simple;
pub use simple::SimpleSmt;
// CONSTANTS
// ================================================================================================
/// Minimum supported depth of a sparse Merkle tree.
pub const SMT_MIN_DEPTH: u8 = 1;

/// Maximum supported depth of a sparse Merkle tree.
pub const SMT_MAX_DEPTH: u8 = 64;
// SPARSE MERKLE TREE
// ================================================================================================
/// An abstract description of a sparse Merkle tree.
///
/// A sparse Merkle tree is a key-value map which also supports proving that a given value is indeed
/// stored at a given key in the tree. It is viewed as always being fully populated. If a leaf's
/// value was not explicitly set, then its value is the default value. Typically, the vast majority
/// of leaves will store the default value (hence it is "sparse"), and therefore the internal
/// representation of the tree will only keep track of the leaves that have a different value from
/// the default.
///
/// All leaves sit at the same depth. The deeper the tree, the more leaves it has; but also the
/// longer its proofs are - of exactly `DEPTH` entries (i.e. logarithmic in the number of leaves).
/// A tree cannot have depth 0, since such a tree is just a single value, and is probably a
/// programming mistake.
///
/// Every key maps to one leaf. If there are as many keys as there are leaves, then
/// [Self::Leaf] should be the same type as [Self::Value], as is the case with
/// [crate::merkle::SimpleSmt]. However, if there are more keys than leaves, then [`Self::Leaf`]
/// must accommodate all keys that map to the same leaf.
///
/// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs.
pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
    /// The type for a key
    type Key: Clone;
    /// The type for a value
    type Value: Clone + PartialEq;
    /// The type for a leaf
    type Leaf;
    /// The type for an opening (i.e. a "proof") of a leaf
    type Opening;

    /// The default value used to compute the hash of empty leaves
    const EMPTY_VALUE: Self::Value;

    // PROVIDED METHODS
    // ---------------------------------------------------------------------------------------------

    /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
    /// path to the leaf, as well as the leaf itself.
    fn open(&self, key: &Self::Key) -> Self::Opening {
        let leaf = self.get_leaf(key);

        let mut index: NodeIndex = {
            let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(key);
            leaf_index.into()
        };

        let merkle_path = {
            let mut path = Vec::with_capacity(index.depth() as usize);
            // walk from the leaf up to the root, collecting the sibling digest at every level
            for _ in 0..index.depth() {
                let is_right = index.is_value_odd();
                index.move_up();
                let InnerNode { left, right } = self.get_inner_node(index);
                // the sibling is the child on the opposite side of the node we came from
                let value = if is_right { left } else { right };
                path.push(value);
            }

            MerklePath::new(path)
        };

        Self::path_and_leaf_to_opening(merkle_path, leaf)
    }

    /// Inserts a value at the specified key, returning the previous value associated with that key.
    /// Recall that by definition, any key that hasn't been updated is associated with
    /// [`Self::EMPTY_VALUE`].
    ///
    /// This also recomputes all hashes between the leaf (associated with the key) and the root,
    /// updating the root itself.
    fn insert(&mut self, key: Self::Key, value: Self::Value) -> Self::Value {
        let old_value = self.insert_value(key.clone(), value.clone()).unwrap_or(Self::EMPTY_VALUE);

        // if the old value and new value are the same, there is nothing to update
        if value == old_value {
            return value;
        }

        let leaf = self.get_leaf(&key);
        let node_index = {
            let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(&key);
            leaf_index.into()
        };

        self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf));

        old_value
    }

    /// Recomputes the branch nodes (including the root) from `index` all the way to the root.
    /// `node_hash_at_index` is the hash of the node stored at index.
    fn recompute_nodes_from_index_to_root(
        &mut self,
        mut index: NodeIndex,
        node_hash_at_index: RpoDigest,
    ) {
        let mut node_hash = node_hash_at_index;
        for node_depth in (0..index.depth()).rev() {
            let is_right = index.is_value_odd();
            index.move_up();
            let InnerNode { left, right } = self.get_inner_node(index);
            // substitute the freshly computed hash for the child we came from
            let (left, right) = if is_right {
                (left, node_hash)
            } else {
                (node_hash, right)
            };
            node_hash = Rpo256::merge(&[left, right]);

            if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) {
                // If a subtree is empty, we can remove the inner node, since it's equal to the
                // default value
                self.remove_inner_node(index)
            } else {
                self.insert_inner_node(index, InnerNode { left, right });
            }
        }
        self.set_root(node_hash);
    }

    // REQUIRED METHODS
    // ---------------------------------------------------------------------------------------------

    /// The root of the tree
    fn root(&self) -> RpoDigest;

    /// Sets the root of the tree
    fn set_root(&mut self, root: RpoDigest);

    /// Retrieves an inner node at the given index
    fn get_inner_node(&self, index: NodeIndex) -> InnerNode;

    /// Inserts an inner node at the given index
    fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode);

    /// Removes an inner node at the given index
    fn remove_inner_node(&mut self, index: NodeIndex);

    /// Inserts a leaf node, and returns the value at the key if already exists
    fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value>;

    /// Returns the leaf at the specified index.
    fn get_leaf(&self, key: &Self::Key) -> Self::Leaf;

    /// Returns the hash of a leaf
    fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest;

    /// Maps a key to a leaf index
    fn key_to_leaf_index(key: &Self::Key) -> LeafIndex<DEPTH>;

    /// Maps a (MerklePath, Self::Leaf) to an opening.
    ///
    /// The length `path` is guaranteed to be equal to `DEPTH`
    fn path_and_leaf_to_opening(path: MerklePath, leaf: Self::Leaf) -> Self::Opening;
}
// INNER NODE
// ================================================================================================
/// An inner (branch) node of a sparse Merkle tree, holding the digests of its two children.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub(crate) struct InnerNode {
    /// Digest of the left child.
    pub left: RpoDigest,
    /// Digest of the right child.
    pub right: RpoDigest,
}
impl InnerNode {
    /// Returns the hash of this inner node, i.e. the merge of its two child digests.
    pub fn hash(&self) -> RpoDigest {
        let Self { left, right } = self;
        Rpo256::merge(&[*left, *right])
    }
}
// LEAF INDEX
// ================================================================================================
/// The index of a leaf, at a depth known at compile-time.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct LeafIndex<const DEPTH: u8> {
    // underlying node index; its depth is always `DEPTH`
    index: NodeIndex,
}
impl<const DEPTH: u8> LeafIndex<DEPTH> {
    /// Returns a new leaf index for `value`.
    ///
    /// # Errors
    /// Returns an error if `DEPTH` is below the minimum supported depth, or if `value` is not a
    /// valid node index at `DEPTH`.
    pub fn new(value: u64) -> Result<Self, MerkleError> {
        if DEPTH < SMT_MIN_DEPTH {
            Err(MerkleError::DepthTooSmall(DEPTH))
        } else {
            let index = NodeIndex::new(DEPTH, value)?;
            Ok(LeafIndex { index })
        }
    }

    /// Returns the numeric value of this leaf index.
    pub fn value(&self) -> u64 {
        self.index.value()
    }
}
impl LeafIndex<SMT_MAX_DEPTH> {
    /// Returns a new leaf index at the maximum supported depth. At depth 64 every `u64` value is
    /// a valid index, so no validation is required.
    pub const fn new_max_depth(value: u64) -> Self {
        Self {
            index: NodeIndex::new_unchecked(SMT_MAX_DEPTH, value),
        }
    }
}
impl<const DEPTH: u8> From<LeafIndex<DEPTH>> for NodeIndex {
fn from(value: LeafIndex<DEPTH>) -> Self {
value.index
}
}
impl<const DEPTH: u8> TryFrom<NodeIndex> for LeafIndex<DEPTH> {
    type Error = MerkleError;

    /// Converts a node index into a leaf index, failing when the node does not sit at `DEPTH`.
    fn try_from(node_index: NodeIndex) -> Result<Self, Self::Error> {
        let provided = node_index.depth();
        if provided == DEPTH {
            Self::new(node_index.value())
        } else {
            Err(MerkleError::InvalidDepth { expected: DEPTH, provided })
        }
    }
}

View File

@@ -1,309 +0,0 @@
use crate::{
merkle::{EmptySubtreeRoots, InnerNodeInfo, MerklePath, ValuePath},
EMPTY_WORD,
};
use super::{
InnerNode, LeafIndex, MerkleError, NodeIndex, RpoDigest, SparseMerkleTree, Word, SMT_MAX_DEPTH,
SMT_MIN_DEPTH,
};
use crate::utils::collections::{BTreeMap, BTreeSet};
#[cfg(test)]
mod tests;
// SPARSE MERKLE TREE
// ================================================================================================
/// A sparse Merkle tree with 64-bit keys and 4-element leaf values, without compaction.
///
/// The root of the tree is recomputed on each new leaf update.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct SimpleSmt<const DEPTH: u8> {
    // current root digest of the tree
    root: RpoDigest,
    // maps a leaf index to the word stored there; leaves holding the empty word are not stored
    leaves: BTreeMap<u64, Word>,
    // maps a node index to the inner node stored there; roots of empty subtrees are not stored
    inner_nodes: BTreeMap<NodeIndex, InnerNode>,
}
impl<const DEPTH: u8> SimpleSmt<DEPTH> {
    // CONSTANTS
    // --------------------------------------------------------------------------------------------

    /// The default value used to compute the hash of empty leaves
    pub const EMPTY_VALUE: Word = <Self as SparseMerkleTree<DEPTH>>::EMPTY_VALUE;

    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Returns a new [SimpleSmt].
    ///
    /// All leaves in the returned tree are set to [ZERO; 4].
    ///
    /// # Errors
    /// Returns an error if DEPTH is 0 or is greater than 64.
    pub fn new() -> Result<Self, MerkleError> {
        // validate the range of the depth.
        if DEPTH < SMT_MIN_DEPTH {
            return Err(MerkleError::DepthTooSmall(DEPTH));
        } else if SMT_MAX_DEPTH < DEPTH {
            return Err(MerkleError::DepthTooBig(DEPTH as u64));
        }

        // the root of an empty tree is the precomputed empty-subtree digest at depth 0
        let root = *EmptySubtreeRoots::entry(DEPTH, 0);

        Ok(Self {
            root,
            leaves: BTreeMap::new(),
            inner_nodes: BTreeMap::new(),
        })
    }

    /// Returns a new [SimpleSmt] instantiated with leaves set as specified by the provided entries.
    ///
    /// All leaves omitted from the entries list are set to [ZERO; 4].
    ///
    /// # Errors
    /// Returns an error if:
    /// - If the depth is 0 or is greater than 64.
    /// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
    /// - The provided entries contain multiple values for the same key.
    pub fn with_leaves(
        entries: impl IntoIterator<Item = (u64, Word)>,
    ) -> Result<Self, MerkleError> {
        // create an empty tree
        let mut tree = Self::new()?;

        // compute the max number of entries. We use an upper bound of depth 63 because we consider
        // passing in a vector of size 2^64 infeasible.
        let max_num_entries = 2_usize.pow(DEPTH.min(63).into());

        // This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so
        // entries with the empty value need additional tracking.
        let mut key_set_to_zero = BTreeSet::new();

        for (idx, (key, value)) in entries.into_iter().enumerate() {
            if idx >= max_num_entries {
                return Err(MerkleError::InvalidNumEntries(max_num_entries));
            }

            let old_value = tree.insert(LeafIndex::<DEPTH>::new(key)?, value);

            // `insert()` returns EMPTY_VALUE for keys not previously set, so a non-empty old
            // value (or an empty one we tracked earlier) indicates a duplicate key
            if old_value != Self::EMPTY_VALUE || key_set_to_zero.contains(&key) {
                return Err(MerkleError::DuplicateValuesForIndex(key));
            }

            if value == Self::EMPTY_VALUE {
                key_set_to_zero.insert(key);
            };
        }
        Ok(tree)
    }

    /// Wrapper around [`SimpleSmt::with_leaves`] which inserts leaves at contiguous indices
    /// starting at index 0.
    pub fn with_contiguous_leaves(
        entries: impl IntoIterator<Item = Word>,
    ) -> Result<Self, MerkleError> {
        Self::with_leaves(
            entries
                .into_iter()
                .enumerate()
                .map(|(idx, word)| (idx.try_into().expect("tree max depth is 2^8"), word)),
        )
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the depth of the tree
    pub const fn depth(&self) -> u8 {
        DEPTH
    }

    /// Returns the root of the tree
    pub fn root(&self) -> RpoDigest {
        <Self as SparseMerkleTree<DEPTH>>::root(self)
    }

    /// Returns the leaf at the specified index.
    pub fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
        <Self as SparseMerkleTree<DEPTH>>::get_leaf(self, key)
    }

    /// Returns a node at the specified index.
    ///
    /// # Errors
    /// Returns an error if the specified index has depth set to 0 or the depth is greater than
    /// the depth of this Merkle tree.
    pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
        if index.is_root() {
            Err(MerkleError::DepthTooSmall(index.depth()))
        } else if index.depth() > DEPTH {
            Err(MerkleError::DepthTooBig(index.depth() as u64))
        } else if index.depth() == DEPTH {
            // at leaf depth, the node is the leaf value itself
            let leaf = self.get_leaf(&LeafIndex::<DEPTH>::try_from(index)?);

            Ok(leaf.into())
        } else {
            Ok(self.get_inner_node(index).hash())
        }
    }

    /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
    /// path to the leaf, as well as the leaf itself.
    pub fn open(&self, key: &LeafIndex<DEPTH>) -> ValuePath {
        <Self as SparseMerkleTree<DEPTH>>::open(self, key)
    }

    // ITERATORS
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over the leaves of this [SimpleSmt].
    pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
        self.leaves.iter().map(|(i, w)| (*i, w))
    }

    /// Returns an iterator over the inner nodes of this [SimpleSmt].
    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
        self.inner_nodes.values().map(|e| InnerNodeInfo {
            value: e.hash(),
            left: e.left,
            right: e.right,
        })
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Inserts a value at the specified key, returning the previous value associated with that key.
    /// Recall that by definition, any key that hasn't been updated is associated with
    /// [`EMPTY_WORD`].
    ///
    /// This also recomputes all hashes between the leaf (associated with the key) and the root,
    /// updating the root itself.
    pub fn insert(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Word {
        <Self as SparseMerkleTree<DEPTH>>::insert(self, key, value)
    }

    /// Inserts a subtree at the specified index. The depth at which the subtree is inserted is
    /// computed as `DEPTH - SUBTREE_DEPTH`.
    ///
    /// Returns the new root.
    ///
    /// # Errors
    /// Returns an error if `SUBTREE_DEPTH` exceeds `DEPTH`, or if `subtree_insertion_index` is
    /// not a valid node index at the insertion depth.
    pub fn set_subtree<const SUBTREE_DEPTH: u8>(
        &mut self,
        subtree_insertion_index: u64,
        subtree: SimpleSmt<SUBTREE_DEPTH>,
    ) -> Result<RpoDigest, MerkleError> {
        if SUBTREE_DEPTH > DEPTH {
            return Err(MerkleError::InvalidSubtreeDepth {
                subtree_depth: SUBTREE_DEPTH,
                tree_depth: DEPTH,
            });
        }

        // Verify that `subtree_insertion_index` is valid.
        let subtree_root_insertion_depth = DEPTH - SUBTREE_DEPTH;
        let subtree_root_index =
            NodeIndex::new(subtree_root_insertion_depth, subtree_insertion_index)?;

        // add leaves
        // --------------

        // The subtree's leaf indices live in their own context - i.e. a subtree of depth `d`. If we
        // insert the subtree at `subtree_insertion_index = 0`, then the subtree leaf indices are
        // valid as they are. However, consider what happens when we insert at
        // `subtree_insertion_index = 1`. The first leaf of our subtree now will have index `2^d`;
        // you can see it as there's a full subtree sitting on its left. In general, for
        // `subtree_insertion_index = i`, there are `i` subtrees sitting before the subtree we want
        // to insert, so we need to adjust all its leaves by `i * 2^d`.
        let leaf_index_shift: u64 = subtree_insertion_index * 2_u64.pow(SUBTREE_DEPTH.into());
        for (subtree_leaf_idx, leaf_value) in subtree.leaves() {
            let new_leaf_idx = leaf_index_shift + subtree_leaf_idx;
            debug_assert!(new_leaf_idx < 2_u64.pow(DEPTH.into()));

            self.leaves.insert(new_leaf_idx, *leaf_value);
        }

        // add subtree's branch nodes (which includes the root)
        // --------------
        for (branch_idx, branch_node) in subtree.inner_nodes {
            let new_branch_idx = {
                // shift the subtree node's index into the coordinate space of the outer tree
                let new_depth = subtree_root_insertion_depth + branch_idx.depth();
                let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into())
                    + branch_idx.value();

                NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid")
            };

            self.inner_nodes.insert(new_branch_idx, branch_node);
        }

        // recompute nodes starting from subtree root
        // --------------
        self.recompute_nodes_from_index_to_root(subtree_root_index, subtree.root);

        Ok(self.root)
    }
}
impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
    type Key = LeafIndex<DEPTH>;
    type Value = Word;
    type Leaf = Word;
    type Opening = ValuePath;

    const EMPTY_VALUE: Self::Value = EMPTY_WORD;

    fn root(&self) -> RpoDigest {
        self.root
    }

    fn set_root(&mut self, root: RpoDigest) {
        self.root = root;
    }

    fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
        match self.inner_nodes.get(&index) {
            Some(node) => node.clone(),
            None => {
                // a node that is not stored is the root of an empty subtree: both of its
                // children are the empty-subtree digest one level below `index`
                let child = EmptySubtreeRoots::entry(DEPTH, index.depth() + 1);
                InnerNode { left: *child, right: *child }
            }
        }
    }

    fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
        self.inner_nodes.insert(index, inner_node);
    }

    fn remove_inner_node(&mut self, index: NodeIndex) {
        let _ = self.inner_nodes.remove(&index);
    }

    fn insert_value(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Option<Word> {
        self.leaves.insert(key.value(), value)
    }

    fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
        // a leaf that is not stored holds the empty-subtree word at leaf depth
        self.leaves
            .get(&key.value())
            .copied()
            .unwrap_or_else(|| Word::from(*EmptySubtreeRoots::entry(DEPTH, DEPTH)))
    }

    fn hash_leaf(leaf: &Word) -> RpoDigest {
        // `SimpleSmt` takes the leaf value itself as the hash
        leaf.into()
    }

    fn key_to_leaf_index(key: &LeafIndex<DEPTH>) -> LeafIndex<DEPTH> {
        *key
    }

    fn path_and_leaf_to_opening(path: MerklePath, leaf: Word) -> ValuePath {
        (path, leaf).into()
    }
}

View File

@@ -1,7 +1,7 @@
use super::{
mmr::Mmr, BTreeMap, EmptySubtreeRoots, InnerNodeInfo, KvMap, MerkleError, MerklePath,
MerkleTree, NodeIndex, PartialMerkleTree, RecordingMap, RootPath, Rpo256, RpoDigest, SimpleSmt,
Smt, ValuePath, Vec,
MerkleStoreDelta, MerkleTree, NodeIndex, PartialMerkleTree, RecordingMap, RootPath, Rpo256,
RpoDigest, SimpleSmt, TieredSmt, TryApplyDiff, ValuePath, Vec, EMPTY_WORD,
};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
use core::borrow::Borrow;
@@ -173,7 +173,7 @@ impl<T: KvMap<RpoDigest, StoreNode>> MerkleStore<T> {
// the path is computed from root to leaf, so it must be reversed
path.reverse();
Ok(ValuePath::new(hash, MerklePath::new(path)))
Ok(ValuePath::new(hash, path))
}
// LEAF TRAVERSAL
@@ -361,6 +361,9 @@ impl<T: KvMap<RpoDigest, StoreNode>> MerkleStore<T> {
// if the node is not in the store assume it is a leaf
} else {
// assert that if we have a leaf that is not at the max depth then it must be
// at the depth of one of the tiers of an TSMT.
debug_assert!(TieredSmt::TIER_DEPTHS[..3].contains(&index.depth()));
return Some((index, node_hash));
}
}
@@ -487,15 +490,8 @@ impl<T: KvMap<RpoDigest, StoreNode>> From<&MerkleTree> for MerkleStore<T> {
}
}
impl<T: KvMap<RpoDigest, StoreNode>, const DEPTH: u8> From<&SimpleSmt<DEPTH>> for MerkleStore<T> {
fn from(value: &SimpleSmt<DEPTH>) -> Self {
let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
Self { nodes }
}
}
impl<T: KvMap<RpoDigest, StoreNode>> From<&Smt> for MerkleStore<T> {
fn from(value: &Smt) -> Self {
impl<T: KvMap<RpoDigest, StoreNode>> From<&SimpleSmt> for MerkleStore<T> {
fn from(value: &SimpleSmt) -> Self {
let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
Self { nodes }
}
@@ -508,6 +504,13 @@ impl<T: KvMap<RpoDigest, StoreNode>> From<&Mmr> for MerkleStore<T> {
}
}
impl<T: KvMap<RpoDigest, StoreNode>> From<&TieredSmt> for MerkleStore<T> {
fn from(value: &TieredSmt) -> Self {
let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
Self { nodes }
}
}
impl<T: KvMap<RpoDigest, StoreNode>> From<&PartialMerkleTree> for MerkleStore<T> {
fn from(value: &PartialMerkleTree) -> Self {
let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
@@ -547,6 +550,39 @@ impl<T: KvMap<RpoDigest, StoreNode>> Extend<InnerNodeInfo> for MerkleStore<T> {
}
}
// DiffT & ApplyDiffT TRAIT IMPLEMENTATION
// ================================================================================================
impl<T: KvMap<RpoDigest, StoreNode>> TryApplyDiff<RpoDigest, StoreNode> for MerkleStore<T> {
type Error = MerkleError;
type DiffType = MerkleStoreDelta;
fn try_apply(&mut self, diff: Self::DiffType) -> Result<(), MerkleError> {
for (root, delta) in diff.0 {
let mut root = root;
for cleared_slot in delta.cleared_slots() {
root = self
.set_node(
root,
NodeIndex::new(delta.depth(), *cleared_slot)?,
EMPTY_WORD.into(),
)?
.root;
}
for (updated_slot, updated_value) in delta.updated_slots() {
root = self
.set_node(
root,
NodeIndex::new(delta.depth(), *updated_slot)?,
(*updated_value).into(),
)?
.root;
}
}
Ok(())
}
}
// SERIALIZATION
// ================================================================================================

View File

@@ -3,9 +3,7 @@ use super::{
PartialMerkleTree, RecordingMerkleStore, Rpo256, RpoDigest,
};
use crate::{
merkle::{
digests_to_words, int_to_leaf, int_to_node, LeafIndex, MerkleTree, SimpleSmt, SMT_MAX_DEPTH,
},
merkle::{digests_to_words, int_to_leaf, int_to_node, MerkleTree, SimpleSmt},
Felt, Word, ONE, WORD_SIZE, ZERO,
};
@@ -15,8 +13,6 @@ use super::{Deserializable, Serializable};
#[cfg(feature = "std")]
use std::error::Error;
use seq_macro::seq;
// TEST DATA
// ================================================================================================
@@ -107,7 +103,7 @@ fn test_merkle_tree() -> Result<(), MerkleError> {
"node 3 must be the same for both MerkleTree and MerkleStore"
);
// STORE MERKLE PATH MATCHES ==============================================================
// STORE MERKLE PATH MATCHS ==============================================================
// assert the merkle path returned by the store is the same as the one in the tree
let result = store.get_path(mtree.root(), NodeIndex::make(mtree.depth(), 0)).unwrap();
assert_eq!(
@@ -177,12 +173,12 @@ fn test_leaf_paths_for_empty_trees() -> Result<(), MerkleError> {
// Starts at 1 because leafs are not included in the store.
// Ends at 64 because it is not possible to represent an index of a depth greater than 64,
// because a u64 is used to index the leaf.
seq!(DEPTH in 1_u8..64_u8 {
let smt = SimpleSmt::<DEPTH>::new()?;
for depth in 1..64 {
let smt = SimpleSmt::new(depth)?;
let index = NodeIndex::make(DEPTH, 0);
let index = NodeIndex::make(depth, 0);
let store_path = store.get_path(smt.root(), index)?;
let smt_path = smt.open(&LeafIndex::<DEPTH>::new(0)?).path;
let smt_path = smt.get_path(index)?;
assert_eq!(
store_path.value,
RpoDigest::default(),
@@ -193,12 +189,11 @@ fn test_leaf_paths_for_empty_trees() -> Result<(), MerkleError> {
"the returned merkle path does not match the computed values"
);
assert_eq!(
store_path.path.compute_root(DEPTH.into(), RpoDigest::default()).unwrap(),
store_path.path.compute_root(depth.into(), RpoDigest::default()).unwrap(),
smt.root(),
"computed root from the path must match the empty tree root"
);
});
}
Ok(())
}
@@ -215,7 +210,7 @@ fn test_get_invalid_node() {
fn test_add_sparse_merkle_tree_one_level() -> Result<(), MerkleError> {
let keys2: [u64; 2] = [0, 1];
let leaves2: [Word; 2] = [int_to_leaf(1), int_to_leaf(2)];
let smt = SimpleSmt::<1>::with_leaves(keys2.into_iter().zip(leaves2)).unwrap();
let smt = SimpleSmt::with_leaves(1, keys2.into_iter().zip(leaves2.into_iter())).unwrap();
let store = MerkleStore::from(&smt);
let idx = NodeIndex::make(1, 0);
@@ -231,36 +226,38 @@ fn test_add_sparse_merkle_tree_one_level() -> Result<(), MerkleError> {
#[test]
fn test_sparse_merkle_tree() -> Result<(), MerkleError> {
let smt =
SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(KEYS4.into_iter().zip(digests_to_words(&VALUES4)))
.unwrap();
let smt = SimpleSmt::with_leaves(
SimpleSmt::MAX_DEPTH,
KEYS4.into_iter().zip(digests_to_words(&VALUES4).into_iter()),
)
.unwrap();
let store = MerkleStore::from(&smt);
// STORE LEAVES ARE CORRECT ==============================================================
// checks the leaves in the store corresponds to the expected values
assert_eq!(
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 0)),
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 0)),
Ok(VALUES4[0]),
"node 0 must be in the tree"
);
assert_eq!(
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 1)),
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 1)),
Ok(VALUES4[1]),
"node 1 must be in the tree"
);
assert_eq!(
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 2)),
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 2)),
Ok(VALUES4[2]),
"node 2 must be in the tree"
);
assert_eq!(
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 3)),
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 3)),
Ok(VALUES4[3]),
"node 3 must be in the tree"
);
assert_eq!(
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 4)),
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 4)),
Ok(RpoDigest::default()),
"unmodified node 4 must be ZERO"
);
@@ -268,86 +265,86 @@ fn test_sparse_merkle_tree() -> Result<(), MerkleError> {
// STORE LEAVES MATCH TREE ===============================================================
// sanity check the values returned by the store and the tree
assert_eq!(
smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 0)),
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 0)),
smt.get_node(NodeIndex::make(smt.depth(), 0)),
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 0)),
"node 0 must be the same for both SparseMerkleTree and MerkleStore"
);
assert_eq!(
smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 1)),
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 1)),
smt.get_node(NodeIndex::make(smt.depth(), 1)),
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 1)),
"node 1 must be the same for both SparseMerkleTree and MerkleStore"
);
assert_eq!(
smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 2)),
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 2)),
smt.get_node(NodeIndex::make(smt.depth(), 2)),
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 2)),
"node 2 must be the same for both SparseMerkleTree and MerkleStore"
);
assert_eq!(
smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 3)),
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 3)),
smt.get_node(NodeIndex::make(smt.depth(), 3)),
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 3)),
"node 3 must be the same for both SparseMerkleTree and MerkleStore"
);
assert_eq!(
smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 4)),
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 4)),
smt.get_node(NodeIndex::make(smt.depth(), 4)),
store.get_node(smt.root(), NodeIndex::make(smt.depth(), 4)),
"node 4 must be the same for both SparseMerkleTree and MerkleStore"
);
// STORE MERKLE PATH MATCHES ==============================================================
// STORE MERKLE PATH MATCHS ==============================================================
// assert the merkle path returned by the store is the same as the one in the tree
let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 0)).unwrap();
let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 0)).unwrap();
assert_eq!(
VALUES4[0], result.value,
"Value for merkle path at index 0 must match leaf value"
);
assert_eq!(
smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(0).unwrap()).path,
result.path,
smt.get_path(NodeIndex::make(smt.depth(), 0)),
Ok(result.path),
"merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
);
let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 1)).unwrap();
let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 1)).unwrap();
assert_eq!(
VALUES4[1], result.value,
"Value for merkle path at index 1 must match leaf value"
);
assert_eq!(
smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(1).unwrap()).path,
result.path,
smt.get_path(NodeIndex::make(smt.depth(), 1)),
Ok(result.path),
"merkle path for index 1 must be the same for the MerkleTree and MerkleStore"
);
let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 2)).unwrap();
let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 2)).unwrap();
assert_eq!(
VALUES4[2], result.value,
"Value for merkle path at index 2 must match leaf value"
);
assert_eq!(
smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(2).unwrap()).path,
result.path,
smt.get_path(NodeIndex::make(smt.depth(), 2)),
Ok(result.path),
"merkle path for index 2 must be the same for the MerkleTree and MerkleStore"
);
let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 3)).unwrap();
let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 3)).unwrap();
assert_eq!(
VALUES4[3], result.value,
"Value for merkle path at index 3 must match leaf value"
);
assert_eq!(
smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(3).unwrap()).path,
result.path,
smt.get_path(NodeIndex::make(smt.depth(), 3)),
Ok(result.path),
"merkle path for index 3 must be the same for the MerkleTree and MerkleStore"
);
let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 4)).unwrap();
let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 4)).unwrap();
assert_eq!(
RpoDigest::default(),
result.value,
"Value for merkle path at index 4 must match leaf value"
);
assert_eq!(
smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(4).unwrap()).path,
result.path,
smt.get_path(NodeIndex::make(smt.depth(), 4)),
Ok(result.path),
"merkle path for index 4 must be the same for the MerkleTree and MerkleStore"
);
@@ -428,7 +425,7 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> {
"node 3 must be the same for both PartialMerkleTree and MerkleStore"
);
// STORE MERKLE PATH MATCHES ==============================================================
// STORE MERKLE PATH MATCHS ==============================================================
// assert the merkle path returned by the store is the same as the one in the pmt
let result = store.get_path(pmt.root(), NodeIndex::make(pmt.max_depth(), 0)).unwrap();
assert_eq!(
@@ -555,15 +552,19 @@ fn test_constructors() -> Result<(), MerkleError> {
assert_eq!(mtree.get_path(index)?, value_path.path);
}
const DEPTH: u8 = 32;
let smt =
SimpleSmt::<DEPTH>::with_leaves(KEYS4.into_iter().zip(digests_to_words(&VALUES4))).unwrap();
let depth = 32;
let smt = SimpleSmt::with_leaves(
depth,
KEYS4.into_iter().zip(digests_to_words(&VALUES4).into_iter()),
)
.unwrap();
let store = MerkleStore::from(&smt);
let depth = smt.depth();
for key in KEYS4 {
let index = NodeIndex::make(DEPTH, key);
let index = NodeIndex::make(depth, key);
let value_path = store.get_path(smt.root(), index)?;
assert_eq!(smt.open(&LeafIndex::<DEPTH>::new(key).unwrap()).path, value_path.path);
assert_eq!(smt.get_path(index)?, value_path.path);
}
let d = 2;
@@ -651,7 +652,7 @@ fn get_leaf_depth_works_depth_64() {
let index = NodeIndex::new(64, k).unwrap();
// assert the leaf doesn't exist before the insert. the returned depth should always
// increment with the paths count of the set, as they are intersecting one another up to
// increment with the paths count of the set, as they are insersecting one another up to
// the first bits of the used key.
assert_eq!(d, store.get_leaf_depth(root, 64, k).unwrap());
@@ -882,9 +883,8 @@ fn test_serialization() -> Result<(), Box<dyn Error>> {
fn test_recorder() {
// instantiate recorder from MerkleTree and SimpleSmt
let mtree = MerkleTree::new(digests_to_words(&VALUES4)).unwrap();
const TREE_DEPTH: u8 = 64;
let smtree = SimpleSmt::<TREE_DEPTH>::with_leaves(
let smtree = SimpleSmt::with_leaves(
64,
KEYS8.into_iter().zip(VALUES8.into_iter().map(|x| x.into()).rev()),
)
.unwrap();
@@ -897,13 +897,13 @@ fn test_recorder() {
let node = recorder.get_node(mtree.root(), index_0).unwrap();
assert_eq!(node, mtree.get_node(index_0).unwrap());
let index_1 = NodeIndex::new(TREE_DEPTH, 1).unwrap();
let index_1 = NodeIndex::new(smtree.depth(), 1).unwrap();
let node = recorder.get_node(smtree.root(), index_1).unwrap();
assert_eq!(node, smtree.get_node(index_1).unwrap());
// insert a value and assert that when we request it next time it is accurate
let new_value = [ZERO, ZERO, ONE, ONE].into();
let index_2 = NodeIndex::new(TREE_DEPTH, 2).unwrap();
let index_2 = NodeIndex::new(smtree.depth(), 2).unwrap();
let root = recorder.set_node(smtree.root(), index_2, new_value).unwrap().root;
assert_eq!(recorder.get_node(root, index_2).unwrap(), new_value);
@@ -920,13 +920,10 @@ fn test_recorder() {
assert_eq!(node, smtree.get_node(index_1).unwrap());
let node = merkle_store.get_node(smtree.root(), index_2).unwrap();
assert_eq!(
node,
smtree.get_leaf(&LeafIndex::<TREE_DEPTH>::try_from(index_2).unwrap()).into()
);
assert_eq!(node, smtree.get_leaf(index_2.value()).unwrap().into());
// assert that is doesnt contain nodes that were not recorded
let not_recorded_index = NodeIndex::new(TREE_DEPTH, 4).unwrap();
let not_recorded_index = NodeIndex::new(smtree.depth(), 4).unwrap();
assert!(merkle_store.get_node(smtree.root(), not_recorded_index).is_err());
assert!(smtree.get_node(not_recorded_index).is_ok());
}

View File

@@ -0,0 +1,48 @@
use core::fmt::Display;
/// Errors which can occur when constructing or validating a tiered sparse Merkle tree proof.
#[derive(Debug, PartialEq, Eq)]
pub enum TieredSmtProofError {
    /// The proof contains no entries.
    EntriesEmpty,
    /// An entry carries the empty value [0, 0, 0, 0], which is not allowed in the tree.
    EmptyValueNotAllowed,
    /// Not all entry keys share the same prefix; carries the two mismatched prefixes.
    MismatchedPrefixes(u64, u64),
    /// Multiple entries were provided for a tier above the last tier (depth 64).
    MultipleEntriesOutsideLastTier,
    /// The proof path length (carried value) does not correspond to a tier depth
    /// (expected one of 16, 32, 48, 64).
    NotATierPath(u8),
    /// The proof path is longer than the maximum tree depth of 64.
    PathTooLong,
}
impl Display for TieredSmtProofError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
TieredSmtProofError::EntriesEmpty => {
write!(f, "Missing entries for tiered sparse merkle tree proof")
}
TieredSmtProofError::EmptyValueNotAllowed => {
write!(
f,
"The empty value [0, 0, 0, 0] is not allowed inside a tiered sparse merkle tree"
)
}
TieredSmtProofError::MismatchedPrefixes(first, second) => {
write!(f, "Not all leaves have the same prefix. First {first} second {second}")
}
TieredSmtProofError::MultipleEntriesOutsideLastTier => {
write!(f, "Multiple entries are only allowed for the last tier (depth 64)")
}
TieredSmtProofError::NotATierPath(got) => {
write!(
f,
"Path length does not correspond to a tier. Got {got} Expected one of 16, 32, 48, 64"
)
}
TieredSmtProofError::PathTooLong => {
write!(
f,
"Path longer than maximum depth of 64 for tiered sparse merkle tree proof"
)
}
}
}
}
// `std::error::Error` lives in `std` (not `core`), so the impl is gated on the `std` feature.
#[cfg(feature = "std")]
impl std::error::Error for TieredSmtProofError {}

View File

@@ -0,0 +1,509 @@
use super::{
BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex,
Rpo256, RpoDigest, StarkField, Vec, Word,
};
use crate::utils::vec;
use core::{cmp, ops::Deref};
mod nodes;
use nodes::NodeStore;
mod values;
use values::ValueStore;
mod proof;
pub use proof::TieredSmtProof;
mod error;
pub use error::TieredSmtProofError;
#[cfg(test)]
mod tests;
// TIERED SPARSE MERKLE TREE
// ================================================================================================

/// Tiered (compacted) Sparse Merkle tree mapping 256-bit keys to 256-bit values. Both keys and
/// values are represented by 4 field elements.
///
/// Leaves in the tree can exist only on specific depths called "tiers". These depths are: 16, 32,
/// 48, and 64. Initially, when a tree is empty, it is equivalent to an empty Sparse Merkle tree
/// of depth 64 (i.e., leaves at depth 64 are set to [ZERO; 4]). As non-empty values are inserted
/// into the tree they are added to the first available tier.
///
/// For example, when the first key-value pair is inserted, it will be stored in a node at depth
/// 16 such that the 16 most significant bits of the key determine the position of the node at
/// depth 16. If another value with a key sharing the same 16-bit prefix is inserted, both values
/// move into the next tier (depth 32). This process is repeated until values end up at the bottom
/// tier (depth 64). If multiple values have keys with a common 64-bit prefix, such key-value pairs
/// are stored in a sorted list at the bottom tier.
///
/// To differentiate between internal and leaf nodes, node values are computed as follows:
/// - Internal nodes: hash(left_child, right_child).
/// - Leaf node at depths 16, 32, or 48: hash(key, value, domain=depth).
/// - Leaf node at depth 64: hash([key_0, value_0, ..., key_n, value_n], domain=64).
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct TieredSmt {
    /// Current root of the tree.
    root: RpoDigest,
    /// Store of the tree's inner nodes and leaf node hashes.
    nodes: NodeStore,
    /// Store of the key-value pairs held in the tree's leaves.
    values: ValueStore,
}
impl TieredSmt {
    // CONSTANTS
    // --------------------------------------------------------------------------------------------

    /// The number of levels between tiers.
    pub const TIER_SIZE: u8 = 16;

    /// Depths at which leaves can exist in a tiered SMT.
    pub const TIER_DEPTHS: [u8; 4] = [16, 32, 48, 64];

    /// Maximum node depth. This is also the bottom tier of the tree.
    pub const MAX_DEPTH: u8 = 64;

    /// Value of an empty leaf.
    pub const EMPTY_VALUE: Word = super::EMPTY_WORD;

    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Returns a new [TieredSmt] instantiated with the specified key-value pairs.
    ///
    /// # Errors
    /// Returns an error if the provided entries contain multiple values for the same key.
    pub fn with_entries<R, I>(entries: R) -> Result<Self, MerkleError>
    where
        R: IntoIterator<IntoIter = I>,
        I: Iterator<Item = (RpoDigest, Word)> + ExactSizeIterator,
    {
        // create an empty tree
        let mut tree = Self::default();

        // append leaves to the tree returning an error if a duplicate entry for the same key
        // is found
        let mut empty_entries = BTreeSet::new();
        for (key, value) in entries {
            let old_value = tree.insert(key, value);
            if old_value != Self::EMPTY_VALUE || empty_entries.contains(&key) {
                return Err(MerkleError::DuplicateValuesForKey(key));
            }
            // inserting an empty value leaves no trace in the tree (insert() treats it as a
            // removal), so duplicate empty entries must be tracked separately: if we've
            // processed an empty entry, add the key to the set of empty entry keys, and if
            // this key was already in the set, return an error
            if value == Self::EMPTY_VALUE && !empty_entries.insert(key) {
                return Err(MerkleError::DuplicateValuesForKey(key));
            }
        }
        Ok(tree)
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the root of this Merkle tree.
    pub const fn root(&self) -> RpoDigest {
        self.root
    }

    /// Returns a node at the specified index.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The specified index depth is 0 or greater than 64.
    /// - The node with the specified index does not exist in the Merkle tree. This is possible
    ///   when a leaf node with the same index prefix exists at a tier higher than the requested
    ///   node.
    pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
        self.nodes.get_node(index)
    }

    /// Returns a Merkle path from the node at the specified index to the root.
    ///
    /// The node itself is not included in the path.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The specified index depth is 0 or greater than 64.
    /// - The node with the specified index does not exist in the Merkle tree. This is possible
    ///   when a leaf node with the same index prefix exists at a tier higher than the node to
    ///   which the path is requested.
    pub fn get_path(&self, index: NodeIndex) -> Result<MerklePath, MerkleError> {
        self.nodes.get_path(index)
    }

    /// Returns the value associated with the specified key.
    ///
    /// If nothing was inserted into this tree for the specified key, [ZERO; 4] is returned.
    pub fn get_value(&self, key: RpoDigest) -> Word {
        match self.values.get(&key) {
            Some(value) => *value,
            None => Self::EMPTY_VALUE,
        }
    }

    /// Returns a proof for a key-value pair defined by the specified key.
    ///
    /// The proof can be used to attest membership of this key-value pair in a Tiered Sparse Merkle
    /// Tree defined by the same root as this tree.
    pub fn prove(&self, key: RpoDigest) -> TieredSmtProof {
        let (path, index, leaf_exists) = self.nodes.get_proof(&key);

        let entries = if index.depth() == Self::MAX_DEPTH {
            // bottom-tier leaves hold all key-value pairs sharing the 64-bit prefix; if no
            // such leaf exists, fall back to a single (key, EMPTY_VALUE) entry
            match self.values.get_all(index.value()) {
                Some(entries) => entries,
                None => vec![(key, Self::EMPTY_VALUE)],
            }
        } else if leaf_exists {
            let entry =
                self.values.get_first(index_to_prefix(&index)).expect("leaf entry not found");
            debug_assert_eq!(entry.0, key);
            vec![*entry]
        } else {
            vec![(key, Self::EMPTY_VALUE)]
        };

        TieredSmtProof::new(path, entries).expect("Bug detected, TSMT produced invalid proof")
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Inserts the provided value into the tree under the specified key and returns the value
    /// previously stored under this key.
    ///
    /// If the value for the specified key was not previously set, [ZERO; 4] is returned.
    pub fn insert(&mut self, key: RpoDigest, value: Word) -> Word {
        // if an empty value is being inserted, remove the leaf node to make it look as if the
        // value was never inserted
        if value == Self::EMPTY_VALUE {
            return self.remove_leaf_node(key);
        }

        // insert the value into the value store, and if the key was already in the store, update
        // it with the new value
        if let Some(old_value) = self.values.insert(key, value) {
            if old_value != value {
                // if the new value is different from the old value, determine the location of
                // the leaf node for this key, build the node, and update the root
                let (index, leaf_exists) = self.nodes.get_leaf_index(&key);
                debug_assert!(leaf_exists);
                let node = self.build_leaf_node(index, key, value);
                self.root = self.nodes.update_leaf_node(index, node);
            }
            return old_value;
        };

        // determine the location for the leaf node; this index could have 3 different meanings:
        // - it points to a root of an empty subtree or an empty node at depth 64; in this case,
        //   we can replace the node with the value node immediately.
        // - it points to an existing leaf at the bottom tier (i.e., depth = 64); in this case,
        //   we need to update the bottom leaf.
        // - it points to an existing leaf node for a different key with the same prefix (same
        //   key case was handled above); in this case, we need to move the leaf to a lower tier
        let (index, leaf_exists) = self.nodes.get_leaf_index(&key);

        self.root = if leaf_exists && index.depth() == Self::MAX_DEPTH {
            // returned index points to a leaf at the bottom tier
            let node = self.build_leaf_node(index, key, value);
            self.nodes.update_leaf_node(index, node)
        } else if leaf_exists {
            // returned index points to a leaf for a different key with the same prefix

            // get the key-value pair for the key with the same prefix; since the key-value
            // pair has already been inserted into the value store, we need to filter it out
            // when looking for the other key-value pair
            let (other_key, other_value) = self
                .values
                .get_first_filtered(index_to_prefix(&index), &key)
                .expect("other key-value pair not found");

            // determine how far down the tree should we move the leaves: one tier below the
            // tier at which the two keys' common prefix ends, capped at the bottom tier
            let common_prefix_len = get_common_prefix_tier_depth(&key, other_key);
            let depth = cmp::min(common_prefix_len + Self::TIER_SIZE, Self::MAX_DEPTH);

            // compute node locations for new and existing key-value pairs
            let new_index = LeafNodeIndex::from_key(&key, depth);
            let other_index = LeafNodeIndex::from_key(other_key, depth);

            // compute node values for the new and existing key-value pairs
            let new_node = self.build_leaf_node(new_index, key, value);
            let other_node = self.build_leaf_node(other_index, *other_key, *other_value);

            // replace the leaf located at index with a subtree containing nodes for new and
            // existing key-value pairs
            self.nodes.replace_leaf_with_subtree(
                index,
                [(new_index, new_node), (other_index, other_node)],
            )
        } else {
            // returned index points to an empty subtree or an empty leaf at the bottom tier
            let node = self.build_leaf_node(index, key, value);
            self.nodes.insert_leaf_node(index, node)
        };

        Self::EMPTY_VALUE
    }

    // ITERATORS
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over all key-value pairs in this [TieredSmt].
    pub fn iter(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
        self.values.iter()
    }

    /// Returns an iterator over all inner nodes of this [TieredSmt] (i.e., nodes not at depths
    /// 16, 32, 48, or 64).
    ///
    /// The iterator order is unspecified.
    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
        self.nodes.inner_nodes()
    }

    /// Returns an iterator over upper leaves (i.e., depth = 16, 32, or 48) for this [TieredSmt]
    /// where each yielded item is a (node, key, value) tuple.
    ///
    /// The iterator order is unspecified.
    pub fn upper_leaves(&self) -> impl Iterator<Item = (RpoDigest, RpoDigest, Word)> + '_ {
        self.nodes.upper_leaves().map(|(index, node)| {
            // each upper leaf holds exactly one key-value pair, looked up by its index prefix
            let key_prefix = index_to_prefix(index);
            let (key, value) = self.values.get_first(key_prefix).expect("upper leaf not found");
            debug_assert_eq!(*index, LeafNodeIndex::from_key(key, index.depth()).into());
            (*node, *key, *value)
        })
    }

    /// Returns an iterator over upper leaves (i.e., depth = 16, 32, or 48) for this [TieredSmt]
    /// where each yielded item is a (node_index, value) tuple.
    pub fn upper_leaf_nodes(&self) -> impl Iterator<Item = (&NodeIndex, &RpoDigest)> {
        self.nodes.upper_leaves()
    }

    /// Returns an iterator over bottom leaves (i.e., depth = 64) of this [TieredSmt].
    ///
    /// Each yielded item consists of the hash of the leaf and its contents, where contents is
    /// a vector containing key-value pairs of entries stored in this leaf.
    ///
    /// The iterator order is unspecified.
    pub fn bottom_leaves(&self) -> impl Iterator<Item = (RpoDigest, Vec<(RpoDigest, Word)>)> + '_ {
        self.nodes.bottom_leaves().map(|(&prefix, node)| {
            let values = self.values.get_all(prefix).expect("bottom leaf not found");
            (*node, values)
        })
    }

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Removes the node holding the key-value pair for the specified key from this tree, and
    /// returns the value associated with the specified key.
    ///
    /// If no value was associated with the specified key, [ZERO; 4] is returned.
    fn remove_leaf_node(&mut self, key: RpoDigest) -> Word {
        // remove the key-value pair from the value store; if no value was associated with the
        // specified key, return.
        let old_value = match self.values.remove(&key) {
            Some(old_value) => old_value,
            None => return Self::EMPTY_VALUE,
        };

        // determine the location of the leaf holding the key-value pair to be removed
        let (index, leaf_exists) = self.nodes.get_leaf_index(&key);
        debug_assert!(leaf_exists);

        // if the leaf is at the bottom tier and after removing the key-value pair from it, the
        // leaf is still not empty, we either just update it, or move it up to a higher tier (if
        // the leaf doesn't have siblings at lower tiers)
        if index.depth() == Self::MAX_DEPTH {
            if let Some(entries) = self.values.get_all(index.value()) {
                // if there is only one key-value pair left at the bottom leaf, and it can be
                // moved up to a higher tier, truncate the branch and return
                if entries.len() == 1 {
                    let new_depth = self.nodes.get_last_single_child_parent_depth(index.value());
                    if new_depth != Self::MAX_DEPTH {
                        let node = hash_upper_leaf(entries[0].0, entries[0].1, new_depth);
                        self.root = self.nodes.truncate_branch(index.value(), new_depth, node);
                        return old_value;
                    }
                }

                // otherwise just recompute the leaf hash and update the leaf node
                let node = hash_bottom_leaf(&entries);
                self.root = self.nodes.update_leaf_node(index, node);
                return old_value;
            };
        }

        // if the removed key-value pair has a lone sibling at the current tier with a root at
        // higher tier, we need to move the sibling to a higher tier
        if let Some((sib_key, sib_val, new_sib_index)) = self.values.get_lone_sibling(index) {
            // determine the current index of the sibling node
            let sib_index = LeafNodeIndex::from_key(sib_key, index.depth());
            debug_assert!(sib_index.depth() > new_sib_index.depth());

            // compute node value for the new location of the sibling leaf and replace the subtree
            // with this leaf node
            let node = self.build_leaf_node(new_sib_index, *sib_key, *sib_val);
            let new_sib_depth = new_sib_index.depth();
            self.root = self.nodes.replace_subtree_with_leaf(index, sib_index, new_sib_depth, node);
        } else {
            // if the removed key-value pair did not have a sibling at the current tier with a
            // root at higher tiers, just clear the leaf node
            self.root = self.nodes.clear_leaf_node(index);
        }

        old_value
    }

    /// Builds and returns a leaf node value for the node located at the specified index.
    ///
    /// This method assumes that the key-value pair for the node has already been inserted into
    /// the value store, however, for depths 16, 32, and 48, the node is computed directly from
    /// the passed-in values (for depth 64, the value store is queried to get all the key-value
    /// pairs located at the specified index).
    fn build_leaf_node(&self, index: LeafNodeIndex, key: RpoDigest, value: Word) -> RpoDigest {
        let depth = index.depth();

        // insert the key into index-key map and compute the new value of the node
        if index.depth() == Self::MAX_DEPTH {
            // for the bottom tier, we add the key-value pair to the existing leaf, or create a
            // new leaf with this key-value pair
            let values = self.values.get_all(index.value()).unwrap();
            hash_bottom_leaf(&values)
        } else {
            debug_assert_eq!(self.values.get_first(index_to_prefix(&index)), Some(&(key, value)));
            hash_upper_leaf(key, value, depth)
        }
    }
}
impl Default for TieredSmt {
    /// Returns a [TieredSmt] instantiated with the root of an empty tree of depth 64.
    fn default() -> Self {
        let empty_root = EmptySubtreeRoots::empty_hashes(Self::MAX_DEPTH)[0];
        Self {
            root: empty_root,
            nodes: NodeStore::new(empty_root),
            values: ValueStore::default(),
        }
    }
}
// LEAF NODE INDEX
// ================================================================================================

/// A wrapper around [NodeIndex] to provide type-safe references to nodes at depths 16, 32, 48, and
/// 64 - i.e., at the tiers where leaves may exist in a Tiered Sparse Merkle tree.
///
/// NOTE(review): the derived [Default] impl produces an index at depth 0, which does not satisfy
/// the tier-depth invariant enforced by [LeafNodeIndex::new()] - confirm this is intentional.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub struct LeafNodeIndex(NodeIndex);
impl LeafNodeIndex {
    /// Returns a new [LeafNodeIndex] instantiated from the provided [NodeIndex].
    ///
    /// In debug mode, panics if index depth is not 16, 32, 48, or 64.
    pub fn new(index: NodeIndex) -> Self {
        // a valid tier depth minus 16 is 0, 16, 32, or 48 - that is, a value with no bits set
        // outside of bits 4 and 5; masking with !48 (all bits set except bits 4 and 5) must
        // therefore yield zero
        debug_assert_eq!(((index.depth() - 16) & !48), 0, "invalid tier depth {}", index.depth());
        Self(index)
    }

    /// Returns a new [LeafNodeIndex] instantiated from the specified key inserted at the
    /// specified depth.
    ///
    /// The index value is defined by the n most significant bits of the most significant element
    /// of the key, where n is the specified depth.
    pub fn from_key(key: &RpoDigest, depth: u8) -> Self {
        let prefix = get_key_prefix(key);
        let value = prefix >> (TieredSmt::MAX_DEPTH - depth);
        Self::new(NodeIndex::new_unchecked(depth, value))
    }

    /// Returns a new [LeafNodeIndex] instantiated for testing purposes.
    #[cfg(test)]
    pub fn make(depth: u8, value: u64) -> Self {
        Self::new(NodeIndex::make(depth, value))
    }

    /// Traverses towards the root until the specified depth is reached.
    ///
    /// The new depth must be a valid tier depth - i.e., 16, 32, 48, or 64.
    pub fn move_up_to(&mut self, depth: u8) {
        // same tier-depth validity check as in `new()`
        debug_assert_eq!(((depth - 16) & !48), 0, "invalid tier depth: {depth}");
        self.0.move_up_to(depth);
    }
}
impl Deref for LeafNodeIndex {
    type Target = NodeIndex;

    /// Provides read-only access to the wrapped [NodeIndex].
    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}
impl From<NodeIndex> for LeafNodeIndex {
    /// Converts a [NodeIndex] into a [LeafNodeIndex]; in debug mode, panics if the index depth
    /// is not a valid tier depth.
    fn from(value: NodeIndex) -> Self {
        Self::new(value)
    }
}
impl From<LeafNodeIndex> for NodeIndex {
fn from(value: LeafNodeIndex) -> Self {
value.0
}
}
// HELPER FUNCTIONS
// ================================================================================================
/// Returns the value representing the 64 most significant bits of the specified key.
fn get_key_prefix(key: &RpoDigest) -> u64 {
    // the most significant element of the key is the last element of the underlying word
    let elements = Word::from(key);
    elements[3].as_int()
}
/// Returns the index value shifted to be in the most significant bit positions of the returned
/// u64 value.
fn index_to_prefix(index: &NodeIndex) -> u64 {
    let shift = TieredSmt::MAX_DEPTH - index.depth();
    index.value() << shift
}
/// Returns tiered common prefix length between the most significant elements of the provided keys.
///
/// Specifically:
/// - returns 64 if the most significant elements are equal.
/// - returns 48 if the common prefix is between 48 and 63 bits.
/// - returns 32 if the common prefix is between 32 and 47 bits.
/// - returns 16 if the common prefix is between 16 and 31 bits.
/// - returns 0 if the common prefix is fewer than 16 bits.
fn get_common_prefix_tier_depth(key1: &RpoDigest, key2: &RpoDigest) -> u8 {
    let prefix1 = get_key_prefix(key1);
    let prefix2 = get_key_prefix(key2);
    // XOR-ing the prefixes makes the number of leading zeros equal to the length of the common
    // prefix (64 when the prefixes are identical); then round down to a tier boundary
    let common_bits = (prefix1 ^ prefix2).leading_zeros() as u8;
    common_bits - (common_bits % 16)
}
/// Computes node value for leaves at tiers 16, 32, or 48.
///
/// Node value is computed as: hash(key || value, domain = depth).
pub fn hash_upper_leaf(key: RpoDigest, value: Word, depth: u8) -> RpoDigest {
    // the bottom tier (the last entry of TIER_DEPTHS) is handled by hash_bottom_leaf()
    const NUM_UPPER_TIERS: usize = TieredSmt::TIER_DEPTHS.len() - 1;
    debug_assert!(TieredSmt::TIER_DEPTHS[..NUM_UPPER_TIERS].iter().any(|&d| d == depth));
    Rpo256::merge_in_domain(&[key, value.into()], depth.into())
}
/// Computes node value for leaves at the bottom tier (depth 64).
///
/// Node value is computed as: hash([key_0, value_0, ..., key_n, value_n], domain=64).
///
/// TODO: when hashing in domain is implemented for `hash_elements()`, combine this function with
/// `hash_upper_leaf()` function.
pub fn hash_bottom_leaf(values: &[(RpoDigest, Word)]) -> RpoDigest {
    // each entry contributes 4 key elements and 4 value elements
    let mut elements = Vec::with_capacity(values.len() * 8);
    for (key, val) in values {
        elements.extend_from_slice(key.as_elements());
        elements.extend_from_slice(val.as_slice());
    }

    // TODO: hash in domain
    Rpo256::hash_elements(&elements)
}

View File

@@ -0,0 +1,419 @@
use super::{
BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, LeafNodeIndex, MerkleError, MerklePath,
NodeIndex, Rpo256, RpoDigest, Vec,
};
// CONSTANTS
// ================================================================================================

/// The number of levels between tiers (re-exported from [super::TieredSmt] for brevity).
const TIER_SIZE: u8 = super::TieredSmt::TIER_SIZE;

/// Depths at which leaves can exist in a tiered SMT - i.e., 16, 32, 48, and 64.
const TIER_DEPTHS: [u8; 4] = super::TieredSmt::TIER_DEPTHS;

/// Maximum node depth. This is also the bottom tier of the tree.
const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;
// NODE STORE
// ================================================================================================

/// A store of nodes for a Tiered Sparse Merkle tree.
///
/// The store contains information about all nodes as well as information about which of the nodes
/// represent leaf nodes in a Tiered Sparse Merkle tree. In the current implementation, [BTreeSet]s
/// are used to determine the position of the leaves in the tree.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct NodeStore {
    // all non-empty nodes in the tree, keyed by their index; nodes which resolve to roots of
    // empty subtrees are not stored here
    nodes: BTreeMap<NodeIndex, RpoDigest>,
    // indexes of leaf nodes located at the upper tiers (depths 16, 32, and 48)
    upper_leaves: BTreeSet<NodeIndex>,
    // index values of leaf nodes located at the bottom tier (depth 64)
    bottom_leaves: BTreeSet<u64>,
}
impl NodeStore {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------
    /// Returns a new instance of [NodeStore] instantiated with the specified root node.
    ///
    /// Root node is assumed to be a root of an empty sparse Merkle tree.
    pub fn new(root_node: RpoDigest) -> Self {
        let mut nodes = BTreeMap::default();
        nodes.insert(NodeIndex::root(), root_node);

        Self {
            nodes,
            upper_leaves: BTreeSet::default(),
            bottom_leaves: BTreeSet::default(),
        }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns a node at the specified index.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The specified index depth is 0 or greater than 64.
    /// - The node with the specified index does not exist in the Merkle tree. This is possible
    ///   when a leaf node with the same index prefix exists at a tier higher than the requested
    ///   node.
    pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
        self.validate_node_access(index)?;
        Ok(self.get_node_unchecked(&index))
    }

    /// Returns a Merkle path from the node at the specified index to the root.
    ///
    /// The node itself is not included in the path.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The specified index depth is 0 or greater than 64.
    /// - The node with the specified index does not exist in the Merkle tree. This is possible
    ///   when a leaf node with the same index prefix exists at a tier higher than the node to
    ///   which the path is requested.
    pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
        self.validate_node_access(index)?;

        // walk from the node up to the root, collecting the sibling of every visited node
        let mut path = Vec::with_capacity(index.depth() as usize);
        for _ in 0..index.depth() {
            let node = self.get_node_unchecked(&index.sibling());
            path.push(node);
            index.move_up();
        }

        Ok(path.into())
    }

    /// Returns a Merkle path to the node specified by the key together with a flag indicating,
    /// whether this node is a leaf at depths 16, 32, or 48.
    pub fn get_proof(&self, key: &RpoDigest) -> (MerklePath, NodeIndex, bool) {
        let (index, leaf_exists) = self.get_leaf_index(key);
        let index: NodeIndex = index.into();
        // the index returned by get_leaf_index() is always a valid node index for this tree,
        // so the path lookup is expected to succeed
        let path = self.get_path(index).expect("failed to retrieve Merkle path for a node index");
        (path, index, leaf_exists)
    }

    /// Returns an index at which a leaf node for the specified key should be inserted.
    ///
    /// The second value in the returned tuple is set to true if the node at the returned index
    /// is already a leaf node.
    pub fn get_leaf_index(&self, key: &RpoDigest) -> (LeafNodeIndex, bool) {
        // traverse the tree from the root down checking nodes at tiers 16, 32, and 48. Return if
        // a node at any of the tiers is either a leaf or a root of an empty subtree.
        const NUM_UPPER_TIERS: usize = TIER_DEPTHS.len() - 1;
        for &tier_depth in TIER_DEPTHS[..NUM_UPPER_TIERS].iter() {
            let index = LeafNodeIndex::from_key(key, tier_depth);
            if self.upper_leaves.contains(&index) {
                return (index, true);
            } else if !self.nodes.contains_key(&index) {
                // an absent node is a root of an empty subtree - a leaf can be placed here
                return (index, false);
            }
        }

        // if we got here, that means all of the nodes checked so far are internal nodes, and
        // the new node would need to be inserted in the bottom tier.
        let index = LeafNodeIndex::from_key(key, MAX_DEPTH);
        (index, self.bottom_leaves.contains(&index.value()))
    }

    /// Traverses the tree up from the bottom tier starting at the specified leaf index and
    /// returns the depth of the first node which has more than one child. The returned depth
    /// is rounded up to the next tier.
    pub fn get_last_single_child_parent_depth(&self, leaf_index: u64) -> u8 {
        let mut index = NodeIndex::new_unchecked(MAX_DEPTH, leaf_index);

        // climb from depth 64 towards depth 16; a present sibling means the parent of the
        // current node has two children, so the climb stops there
        for _ in (TIER_DEPTHS[0]..MAX_DEPTH).rev() {
            let sibling_index = index.sibling();
            if self.nodes.contains_key(&sibling_index) {
                break;
            }
            index.move_up();
        }

        // round the reached depth up to the nearest tier boundary
        let tier = (index.depth() - 1) / TIER_SIZE;
        TIER_DEPTHS[tier as usize]
    }

    // ITERATORS
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over all inner nodes of the Tiered Sparse Merkle tree (i.e., nodes not
    /// at depths 16, 32, 48, or 64).
    ///
    /// The iterator order is unspecified.
    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
        self.nodes.iter().filter_map(|(index, node)| {
            if self.is_internal_node(index) {
                Some(InnerNodeInfo {
                    value: *node,
                    left: self.get_node_unchecked(&index.left_child()),
                    right: self.get_node_unchecked(&index.right_child()),
                })
            } else {
                None
            }
        })
    }

    /// Returns an iterator over the upper leaves (i.e., leaves with depths 16, 32, 48) of the
    /// Tiered Sparse Merkle tree.
    pub fn upper_leaves(&self) -> impl Iterator<Item = (&NodeIndex, &RpoDigest)> {
        // every index in upper_leaves is expected to have a corresponding entry in the nodes map
        self.upper_leaves.iter().map(|index| (index, &self.nodes[index]))
    }

    /// Returns an iterator over the bottom leaves (i.e., leaves with depth 64) of the Tiered
    /// Sparse Merkle tree.
    pub fn bottom_leaves(&self) -> impl Iterator<Item = (&u64, &RpoDigest)> {
        self.bottom_leaves.iter().map(|value| {
            let index = NodeIndex::new_unchecked(MAX_DEPTH, *value);
            (value, &self.nodes[&index])
        })
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Replaces the leaf node at the specified index with a tree consisting of two leaves located
    /// at the specified indexes. Recomputes and returns the new root.
    pub fn replace_leaf_with_subtree(
        &mut self,
        leaf_index: LeafNodeIndex,
        subtree_leaves: [(LeafNodeIndex, RpoDigest); 2],
    ) -> RpoDigest {
        debug_assert!(self.is_non_empty_leaf(&leaf_index));
        debug_assert!(!is_empty_root(&subtree_leaves[0].1));
        debug_assert!(!is_empty_root(&subtree_leaves[1].1));
        debug_assert_eq!(subtree_leaves[0].0.depth(), subtree_leaves[1].0.depth());
        debug_assert!(leaf_index.depth() < subtree_leaves[0].0.depth());

        // the node at leaf_index stops being a leaf; its slot in the nodes map is overwritten
        // by the insertions below
        self.upper_leaves.remove(&leaf_index);

        if subtree_leaves[0].0 == subtree_leaves[1].0 {
            // if the subtree is for a single node at depth 64, we only need to insert one node
            debug_assert_eq!(subtree_leaves[0].0.depth(), MAX_DEPTH);
            debug_assert_eq!(subtree_leaves[0].1, subtree_leaves[1].1);
            self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1)
        } else {
            self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1);
            self.insert_leaf_node(subtree_leaves[1].0, subtree_leaves[1].1)
        }
    }

    /// Replaces a subtree containing the retained and the removed leaf nodes, with a leaf node
    /// containing the retained leaf.
    ///
    /// This has the effect of deleting the node at the `removed_leaf` index from the tree,
    /// moving the node at the `retained_leaf` index up to the tier specified by `new_depth`.
    pub fn replace_subtree_with_leaf(
        &mut self,
        removed_leaf: LeafNodeIndex,
        retained_leaf: LeafNodeIndex,
        new_depth: u8,
        node: RpoDigest,
    ) -> RpoDigest {
        debug_assert!(!is_empty_root(&node));
        debug_assert!(self.is_non_empty_leaf(&removed_leaf));
        debug_assert!(self.is_non_empty_leaf(&retained_leaf));
        debug_assert_eq!(removed_leaf.depth(), retained_leaf.depth());
        debug_assert!(removed_leaf.depth() > new_depth);

        // remove the branches leading up to the tier to which the retained leaf is to be moved
        self.remove_branch(removed_leaf, new_depth);
        self.remove_branch(retained_leaf, new_depth);

        // compute the index of the common root for retained and removed leaves
        let mut new_index = retained_leaf;
        new_index.move_up_to(new_depth);

        // insert the node at the root index
        self.insert_leaf_node(new_index, node)
    }

    /// Inserts the specified node at the specified index; recomputes and returns the new root
    /// of the Tiered Sparse Merkle tree.
    ///
    /// This method assumes that the provided node is a non-empty value, and that there is no node
    /// at the specified index.
    pub fn insert_leaf_node(&mut self, index: LeafNodeIndex, mut node: RpoDigest) -> RpoDigest {
        debug_assert!(!is_empty_root(&node));
        debug_assert_eq!(self.nodes.get(&index), None);

        // mark the node as the leaf
        if index.depth() == MAX_DEPTH {
            self.bottom_leaves.insert(index.value());
        } else {
            self.upper_leaves.insert(index.into());
        };

        // insert the node and update the path from the node to the root; at each step, the
        // parent node is recomputed by merging the current node with its sibling
        let mut index: NodeIndex = index.into();
        for _ in 0..index.depth() {
            self.nodes.insert(index, node);
            let sibling = self.get_node_unchecked(&index.sibling());
            node = Rpo256::merge(&index.build_node(node, sibling));
            index.move_up();
        }

        // update the root
        self.nodes.insert(NodeIndex::root(), node);
        node
    }

    /// Updates the node at the specified index with the specified node value; recomputes and
    /// returns the new root of the Tiered Sparse Merkle tree.
    ///
    /// This method can accept `node` as either an empty or a non-empty value.
    pub fn update_leaf_node(&mut self, index: LeafNodeIndex, mut node: RpoDigest) -> RpoDigest {
        debug_assert!(self.is_non_empty_leaf(&index));

        // if the value we are updating the node to is a root of an empty tree, clear the leaf
        // flag for this node
        if node == EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize] {
            if index.depth() == MAX_DEPTH {
                self.bottom_leaves.remove(&index.value());
            } else {
                self.upper_leaves.remove(&index);
            }
        } else {
            debug_assert!(!is_empty_root(&node));
        }

        // update the path from the node to the root; nodes which become roots of empty
        // subtrees are removed from the map rather than stored
        let mut index: NodeIndex = index.into();
        for _ in 0..index.depth() {
            if node == EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize] {
                self.nodes.remove(&index);
            } else {
                self.nodes.insert(index, node);
            }
            let sibling = self.get_node_unchecked(&index.sibling());
            node = Rpo256::merge(&index.build_node(node, sibling));
            index.move_up();
        }

        // update the root
        self.nodes.insert(NodeIndex::root(), node);
        node
    }

    /// Replaces the leaf node at the specified index with a root of an empty subtree; recomputes
    /// and returns the new root of the Tiered Sparse Merkle tree.
    pub fn clear_leaf_node(&mut self, index: LeafNodeIndex) -> RpoDigest {
        debug_assert!(self.is_non_empty_leaf(&index));
        let node = EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize];
        self.update_leaf_node(index, node)
    }

    /// Truncates a branch starting with specified leaf at the bottom tier to new depth.
    ///
    /// This involves removing the part of the branch below the new depth, and then inserting a
    /// new node at the new depth.
    pub fn truncate_branch(
        &mut self,
        leaf_index: u64,
        new_depth: u8,
        node: RpoDigest,
    ) -> RpoDigest {
        debug_assert!(self.bottom_leaves.contains(&leaf_index));

        let mut leaf_index = LeafNodeIndex::new(NodeIndex::new_unchecked(MAX_DEPTH, leaf_index));
        self.remove_branch(leaf_index, new_depth);

        leaf_index.move_up_to(new_depth);
        self.insert_leaf_node(leaf_index, node)
    }

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Returns true if the node at the specified index is a leaf node.
    fn is_non_empty_leaf(&self, index: &LeafNodeIndex) -> bool {
        if index.depth() == MAX_DEPTH {
            self.bottom_leaves.contains(&index.value())
        } else {
            self.upper_leaves.contains(index)
        }
    }

    /// Returns true if the node at the specified index is an internal node - i.e., there is
    /// no leaf at that node and the node does not belong to the bottom tier.
    fn is_internal_node(&self, index: &NodeIndex) -> bool {
        if index.depth() == MAX_DEPTH {
            false
        } else {
            !self.upper_leaves.contains(index)
        }
    }

    /// Checks if the specified index is valid in the context of this Merkle tree.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The specified index depth is 0 or greater than 64.
    /// - The node for the specified index does not exist in the Merkle tree. This is possible
    ///   when an ancestor of the specified index is a leaf node.
    fn validate_node_access(&self, index: NodeIndex) -> Result<(), MerkleError> {
        if index.is_root() {
            return Err(MerkleError::DepthTooSmall(index.depth()));
        } else if index.depth() > MAX_DEPTH {
            return Err(MerkleError::DepthTooBig(index.depth() as u64));
        } else {
            // make sure that there are no leaf nodes in the ancestors of the index; since leaf
            // nodes can live only at tier depths, we just need to check these depths.
            let tier = ((index.depth() - 1) / TIER_SIZE) as usize;
            let mut tier_index = index;
            // walk the tiers above the index from the deepest to the shallowest
            for &depth in TIER_DEPTHS[..tier].iter().rev() {
                tier_index.move_up_to(depth);
                if self.upper_leaves.contains(&tier_index) {
                    return Err(MerkleError::NodeNotInSet(index));
                }
            }
        }

        Ok(())
    }

    /// Returns a node at the specified index. If the node does not exist at this index, a root
    /// for an empty subtree at the index's depth is returned.
    ///
    /// Unlike [NodeStore::get_node()] this does not perform any checks to verify that the
    /// returned node is valid in the context of this tree.
    fn get_node_unchecked(&self, index: &NodeIndex) -> RpoDigest {
        match self.nodes.get(index) {
            Some(node) => *node,
            None => EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize],
        }
    }

    /// Removes a sequence of nodes starting at the specified index and traversing the tree up to
    /// the specified depth. The node at the `end_depth` is also removed, and the appropriate leaf
    /// flag is cleared.
    ///
    /// This method does not update any other nodes and does not recompute the tree root.
    fn remove_branch(&mut self, index: LeafNodeIndex, end_depth: u8) {
        // clear the leaf flag for the starting index
        if index.depth() == MAX_DEPTH {
            self.bottom_leaves.remove(&index.value());
        } else {
            self.upper_leaves.remove(&index);
        }

        // remove nodes at depths index.depth() down to end_depth, inclusive on both ends
        let mut index: NodeIndex = index.into();
        assert!(index.depth() > end_depth);
        for _ in 0..(index.depth() - end_depth + 1) {
            self.nodes.remove(&index);
            index.move_up()
        }
    }
}
// HELPER FUNCTIONS
// ================================================================================================
/// Returns true if the specified node is a root of an empty tree or an empty value ([ZERO; 4]).
fn is_empty_root(node: &RpoDigest) -> bool {
    EmptySubtreeRoots::empty_hashes(MAX_DEPTH).iter().any(|root| root == node)
}

View File

@@ -0,0 +1,170 @@
use super::{
get_common_prefix_tier_depth, get_key_prefix, hash_bottom_leaf, hash_upper_leaf,
EmptySubtreeRoots, LeafNodeIndex, MerklePath, RpoDigest, TieredSmtProofError, Vec, Word,
};
// CONSTANTS
// ================================================================================================

/// Maximum node depth. This is also the bottom tier of the tree.
const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;

/// Value of an empty leaf - i.e., [ZERO; 4].
pub const EMPTY_VALUE: Word = super::TieredSmt::EMPTY_VALUE;

/// Depths at which leaves can exist in a tiered SMT.
pub const TIER_DEPTHS: [u8; 4] = super::TieredSmt::TIER_DEPTHS;
// TIERED SPARSE MERKLE TREE PROOF
// ================================================================================================

/// A proof which can be used to assert membership (or non-membership) of key-value pairs in a
/// Tiered Sparse Merkle tree.
///
/// The proof consists of a Merkle path and one or more key-value entries which describe the node
/// located at the base of the path. If the node at the base of the path resolves to [ZERO; 4],
/// the entries will contain a single item with value set to [ZERO; 4].
#[derive(PartialEq, Eq, Debug)]
pub struct TieredSmtProof {
    // Merkle path from the node covered by this proof to the tree root
    path: MerklePath,
    // key-value pairs stored in the node at the base of the path; when there is more than one
    // entry, all keys share the same 64-bit prefix (enforced by the constructor)
    entries: Vec<(RpoDigest, Word)>,
}
impl TieredSmtProof {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------
    /// Returns a new instance of [TieredSmtProof] instantiated from the specified path and entries.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The length of the path is not a valid tier depth (i.e., 16, 32, 48, or 64).
    /// - Entries is an empty vector.
    /// - Entries contains more than 1 item, but the length of the path is not 64.
    /// - Entries contains more than 1 item, and one of the items (other than the first one) has
    ///   value set to [ZERO; 4].
    /// - Entries contains multiple items with keys which don't share the same 64-bit prefix.
    pub fn new<I>(path: MerklePath, entries: I) -> Result<Self, TieredSmtProofError>
    where
        I: IntoIterator<Item = (RpoDigest, Word)>,
    {
        let entries: Vec<(RpoDigest, Word)> = entries.into_iter().collect();

        if !TIER_DEPTHS.into_iter().any(|e| e == path.depth()) {
            return Err(TieredSmtProofError::NotATierPath(path.depth()));
        }

        if entries.is_empty() {
            return Err(TieredSmtProofError::EntriesEmpty);
        }

        if entries.len() > 1 {
            // multiple entries can describe only a bottom-tier (depth 64) leaf
            if path.depth() != MAX_DEPTH {
                return Err(TieredSmtProofError::MultipleEntriesOutsideLastTier);
            }

            // all entries must share the 64-bit key prefix of the first entry, and none of the
            // subsequent entries may carry an empty value
            let prefix = get_key_prefix(&entries[0].0);
            for entry in entries.iter().skip(1) {
                if entry.1 == EMPTY_VALUE {
                    return Err(TieredSmtProofError::EmptyValueNotAllowed);
                }
                let current = get_key_prefix(&entry.0);
                if prefix != current {
                    return Err(TieredSmtProofError::MismatchedPrefixes(prefix, current));
                }
            }
        }

        Ok(Self { path, entries })
    }

    // PROOF VERIFIER
    // --------------------------------------------------------------------------------------------

    /// Returns true if a Tiered Sparse Merkle tree with the specified root contains the provided
    /// key-value pair.
    ///
    /// Note: this method cannot be used to assert non-membership. That is, if false is returned,
    /// it does not mean that the provided key-value pair is not in the tree.
    pub fn verify_membership(&self, key: &RpoDigest, value: &Word, root: &RpoDigest) -> bool {
        // Handles the following scenarios:
        // - the value is set
        // - empty leaf, there is an explicit entry for the key with the empty value
        // - shared 64-bit prefix, the target key is not included in the entries list, the value
        //   is implicitly the empty word
        let v = match self.entries.iter().find(|(k, _)| k == key) {
            Some((_, v)) => v,
            None => &EMPTY_VALUE,
        };

        // The value must match for the proof to be valid
        if v != value {
            return false;
        }

        // If the proof is for an empty value, we can verify it against any key which has a common
        // prefix with the key stored in entries, but the prefix must be greater than the path
        // length
        if self.is_value_empty()
            && get_common_prefix_tier_depth(key, &self.entries[0].0) < self.path.depth()
        {
            return false;
        }

        // make sure the Merkle path resolves to the correct root
        root == &self.compute_root()
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the value associated with the specific key according to this proof, or None if
    /// this proof does not contain a value for the specified key.
    ///
    /// A key-value pair generated by using this method should pass the `verify_membership()`
    /// check.
    pub fn get(&self, key: &RpoDigest) -> Option<Word> {
        if self.is_value_empty() {
            // an empty-value proof covers any key sharing at least path-depth bits of prefix
            // with the key recorded in the entries
            let common_prefix_tier = get_common_prefix_tier_depth(key, &self.entries[0].0);
            if common_prefix_tier < self.path.depth() {
                None
            } else {
                Some(EMPTY_VALUE)
            }
        } else {
            self.entries.iter().find(|(k, _)| k == key).map(|(_, value)| *value)
        }
    }

    /// Computes the root of a Tiered Sparse Merkle tree to which this proof resolves.
    pub fn compute_root(&self) -> RpoDigest {
        let node = self.build_node();
        let index = LeafNodeIndex::from_key(&self.entries[0].0, self.path.depth());
        self.path
            .compute_root(index.value(), node)
            .expect("failed to compute Merkle path root")
    }

    /// Consumes the proof and returns its parts.
    pub fn into_parts(self) -> (MerklePath, Vec<(RpoDigest, Word)>) {
        (self.path, self.entries)
    }

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Returns true if the proof is for an empty value.
    fn is_value_empty(&self) -> bool {
        self.entries[0].1 == EMPTY_VALUE
    }

    /// Converts the entries contained in this proof into a node value for node at the base of the
    /// path contained in this proof.
    fn build_node(&self) -> RpoDigest {
        let depth = self.path.depth();
        if self.is_value_empty() {
            // an empty value resolves to a root of an empty subtree at the path depth
            EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[depth as usize]
        } else if depth == MAX_DEPTH {
            hash_bottom_leaf(&self.entries)
        } else {
            let (key, value) = self.entries[0];
            hash_upper_leaf(key, value, depth)
        }
    }
}

View File

@@ -0,0 +1,968 @@
use super::{
super::{super::ONE, super::WORD_SIZE, Felt, MerkleStore, EMPTY_WORD, ZERO},
EmptySubtreeRoots, InnerNodeInfo, NodeIndex, Rpo256, RpoDigest, TieredSmt, Vec, Word,
};
// INSERTION TESTS
// ================================================================================================
#[test]
fn tsmt_insert_one() {
    let mut tree = TieredSmt::default();
    let mut reference_store = MerkleStore::default();

    let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
    let value = [ONE; WORD_SIZE];

    // the tree is empty, so the first insertion lands at depth 16; the index is defined by the
    // 16 most significant bits of the key
    let index = NodeIndex::make(16, raw >> 48);
    let leaf_node = build_leaf_node(key, value, 16);
    let expected_root = reference_store.set_node(tree.root(), index, leaf_node).unwrap().root;

    tree.insert(key, value);
    assert_eq!(tree.root(), expected_root);

    // the inserted value should be retrievable, and the leaf should sit at the expected index
    assert_eq!(tree.get_value(key), value);
    assert_eq!(tree.get_node(index).unwrap(), leaf_node);

    // Merkle paths from the reference store and from the tree should agree
    let expected_path = reference_store.get_path(expected_root, index).unwrap();
    assert_eq!(tree.get_path(index).unwrap(), expected_path.path);

    // inner nodes should agree as well
    let expected_nodes = get_non_empty_nodes(&reference_store);
    let actual_nodes = tree.inner_nodes().collect::<Vec<_>>();
    assert_eq!(actual_nodes.len(), expected_nodes.len());
    actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));

    // the upper-leaf iterator should yield exactly the single inserted leaf
    let mut leaves = tree.upper_leaves();
    assert_eq!(leaves.next(), Some((leaf_node, key, value)));
    assert_eq!(leaves.next(), None);
}
#[test]
fn tsmt_insert_two_16() {
    let mut tree = TieredSmt::default();
    let mut reference_store = MerkleStore::default();

    // --- insert the first value ---------------------------------------------
    let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let val_a = [ONE; WORD_SIZE];
    tree.insert(key_a, val_a);

    // --- insert the second value --------------------------------------------
    // this key shares a 16-bit prefix with the first key, so on insertion both values must be
    // pushed down to the depth 32 tier
    let raw_b = 0b_10101010_10101010_10011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let val_b = [Felt::new(2); WORD_SIZE];
    tree.insert(key_b, val_b);

    // --- build a Merkle store holding the equivalent data -------------------
    let mut expected_root = get_init_root();

    let index_a = NodeIndex::make(32, raw_a >> 32);
    let leaf_node_a = build_leaf_node(key_a, val_a, 32);
    expected_root = reference_store.set_node(expected_root, index_a, leaf_node_a).unwrap().root;

    let index_b = NodeIndex::make(32, raw_b >> 32);
    let leaf_node_b = build_leaf_node(key_b, val_b, 32);
    expected_root = reference_store.set_node(expected_root, index_b, leaf_node_b).unwrap().root;

    // --- check that the store and the tree agree ----------------------------
    assert_eq!(tree.root(), expected_root);

    assert_eq!(tree.get_value(key_a), val_a);
    assert_eq!(tree.get_node(index_a).unwrap(), leaf_node_a);
    let expected_path = reference_store.get_path(expected_root, index_a).unwrap().path;
    assert_eq!(tree.get_path(index_a).unwrap(), expected_path);

    assert_eq!(tree.get_value(key_b), val_b);
    assert_eq!(tree.get_node(index_b).unwrap(), leaf_node_b);
    let expected_path = reference_store.get_path(expected_root, index_b).unwrap().path;
    assert_eq!(tree.get_path(index_b).unwrap(), expected_path);

    // inner nodes must match; the store keeps all prior states and thus contains more entries,
    // so only subset containment is checked (not equal counts)
    let expected_nodes = get_non_empty_nodes(&reference_store);
    let actual_nodes = tree.inner_nodes().collect::<Vec<_>>();
    actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));

    // make sure leaves are returned correctly
    let mut leaves = tree.upper_leaves();
    assert_eq!(leaves.next(), Some((leaf_node_a, key_a, val_a)));
    assert_eq!(leaves.next(), Some((leaf_node_b, key_b, val_b)));
    assert_eq!(leaves.next(), None);
}
#[test]
fn tsmt_insert_two_32() {
    let mut tree = TieredSmt::default();
    let mut reference_store = MerkleStore::default();

    // --- insert the first value ---------------------------------------------
    let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let val_a = [ONE; WORD_SIZE];
    tree.insert(key_a, val_a);

    // --- insert the second value --------------------------------------------
    // this key shares a 32-bit prefix with the first key, so on insertion both values must be
    // pushed down to the depth 48 tier
    let raw_b = 0b_10101010_10101010_00011111_11111111_00010110_10010011_11100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let val_b = [Felt::new(2); WORD_SIZE];
    tree.insert(key_b, val_b);

    // --- build a Merkle store holding the equivalent data -------------------
    let mut expected_root = get_init_root();

    let index_a = NodeIndex::make(48, raw_a >> 16);
    let leaf_node_a = build_leaf_node(key_a, val_a, 48);
    expected_root = reference_store.set_node(expected_root, index_a, leaf_node_a).unwrap().root;

    let index_b = NodeIndex::make(48, raw_b >> 16);
    let leaf_node_b = build_leaf_node(key_b, val_b, 48);
    expected_root = reference_store.set_node(expected_root, index_b, leaf_node_b).unwrap().root;

    // --- check that the store and the tree agree ----------------------------
    assert_eq!(tree.root(), expected_root);

    assert_eq!(tree.get_value(key_a), val_a);
    assert_eq!(tree.get_node(index_a).unwrap(), leaf_node_a);
    let expected_path = reference_store.get_path(expected_root, index_a).unwrap().path;
    assert_eq!(tree.get_path(index_a).unwrap(), expected_path);

    assert_eq!(tree.get_value(key_b), val_b);
    assert_eq!(tree.get_node(index_b).unwrap(), leaf_node_b);
    let expected_path = reference_store.get_path(expected_root, index_b).unwrap().path;
    assert_eq!(tree.get_path(index_b).unwrap(), expected_path);

    // inner nodes must match; the store keeps all prior states and thus contains more entries,
    // so only subset containment is checked (not equal counts)
    let expected_nodes = get_non_empty_nodes(&reference_store);
    let actual_nodes = tree.inner_nodes().collect::<Vec<_>>();
    actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));
}
#[test]
fn tsmt_insert_three() {
    let mut smt = TieredSmt::default();
    // reference store built in parallel to cross-check roots, nodes, and paths
    let mut store = MerkleStore::default();

    // --- insert the first value ---------------------------------------------
    let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let val_a = [ONE; WORD_SIZE];
    smt.insert(key_a, val_a);

    // --- insert the second value --------------------------------------------
    // the key for this value has the same 16-bit prefix as the key for the first value,
    // thus, on insertions, both values should be pushed to depth 32 tier
    let raw_b = 0b_10101010_10101010_10011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let val_b = [Felt::new(2); WORD_SIZE];
    smt.insert(key_b, val_b);

    // --- insert the third value ---------------------------------------------
    // the key for this value has the same 16-bit prefix as the keys for the first two
    // values; thus, on insertion, it will be inserted into depth 32 tier, but will not
    // affect locations of the other two values
    let raw_c = 0b_10101010_10101010_11011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    let val_c = [Felt::new(3); WORD_SIZE];
    smt.insert(key_c, val_c);

    // --- build Merkle store with equivalent data ----------------------------
    // at depth 32 a leaf index is the 32 most significant bits of the key's last element
    let mut tree_root = get_init_root();
    let index_a = NodeIndex::make(32, raw_a >> 32);
    let leaf_node_a = build_leaf_node(key_a, val_a, 32);
    tree_root = store.set_node(tree_root, index_a, leaf_node_a).unwrap().root;
    let index_b = NodeIndex::make(32, raw_b >> 32);
    let leaf_node_b = build_leaf_node(key_b, val_b, 32);
    tree_root = store.set_node(tree_root, index_b, leaf_node_b).unwrap().root;
    let index_c = NodeIndex::make(32, raw_c >> 32);
    let leaf_node_c = build_leaf_node(key_c, val_c, 32);
    tree_root = store.set_node(tree_root, index_c, leaf_node_c).unwrap().root;

    // --- verify that data is consistent between store and tree --------------
    assert_eq!(smt.root(), tree_root);

    assert_eq!(smt.get_value(key_a), val_a);
    assert_eq!(smt.get_node(index_a).unwrap(), leaf_node_a);
    let expected_path = store.get_path(tree_root, index_a).unwrap().path;
    assert_eq!(smt.get_path(index_a).unwrap(), expected_path);

    assert_eq!(smt.get_value(key_b), val_b);
    assert_eq!(smt.get_node(index_b).unwrap(), leaf_node_b);
    let expected_path = store.get_path(tree_root, index_b).unwrap().path;
    assert_eq!(smt.get_path(index_b).unwrap(), expected_path);

    assert_eq!(smt.get_value(key_c), val_c);
    assert_eq!(smt.get_node(index_c).unwrap(), leaf_node_c);
    let expected_path = store.get_path(tree_root, index_c).unwrap().path;
    assert_eq!(smt.get_path(index_c).unwrap(), expected_path);

    // make sure inner nodes match - the store contains more entries because it keeps track of
    // all prior state - so, we don't check that the number of inner nodes is the same in both
    let expected_nodes = get_non_empty_nodes(&store);
    let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
    actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));
}
// UPDATE TESTS
// ================================================================================================
#[test]
fn tsmt_update() {
    let mut smt = TieredSmt::default();
    let mut store = MerkleStore::default();

    // --- insert a value into the tree ---------------------------------------
    let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
    let initial_value = [ONE; WORD_SIZE];
    smt.insert(key, initial_value);

    // --- overwrite the value under the same key ------------------------------
    let updated_value = [Felt::new(2); WORD_SIZE];
    smt.insert(key, updated_value);

    // --- verify consistency against an equivalent Merkle store ---------------
    // the single entry sits at depth 16; its index is the top 16 bits of the raw key
    let leaf_index = NodeIndex::make(16, raw >> 48);
    let leaf = build_leaf_node(key, updated_value, 16);
    let expected_root = store.set_node(get_init_root(), leaf_index, leaf).unwrap().root;

    assert_eq!(smt.root(), expected_root);
    assert_eq!(smt.get_value(key), updated_value);
    assert_eq!(smt.get_node(leaf_index).unwrap(), leaf);
    assert_eq!(
        smt.get_path(leaf_index).unwrap(),
        store.get_path(expected_root, leaf_index).unwrap().path
    );

    // make sure inner nodes match - the store contains more entries because it keeps track of
    // all prior state - so, we don't check that the number of inner nodes is the same in both
    let store_nodes = get_non_empty_nodes(&store);
    for node in smt.inner_nodes() {
        assert!(store_nodes.contains(&node));
    }
}
// DELETION TESTS
// ================================================================================================
#[test]
fn tsmt_delete_16() {
    let mut smt = TieredSmt::default();

    // --- insert a value into the tree ---------------------------------------
    // snapshot the empty tree so deletions can be checked against it later
    let snapshot_empty = smt.clone();
    let raw_a = 0b_01010101_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // --- insert another value into the tree ---------------------------------
    let snapshot_one = smt.clone();
    let raw_b = 0b_01011111_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let value_b = [ONE, ONE, ONE, ZERO];
    smt.insert(key_b, value_b);

    // --- delete the second inserted value; tree must equal the one-entry snapshot
    let removed_b = smt.insert(key_b, EMPTY_WORD);
    assert_eq!(removed_b, value_b);
    assert_eq!(smt, snapshot_one);

    // --- delete the first inserted value; tree must be empty again ----------
    let removed_a = smt.insert(key_a, EMPTY_WORD);
    assert_eq!(removed_a, value_a);
    assert_eq!(smt, snapshot_empty);
}
#[test]
fn tsmt_delete_32() {
    let mut smt = TieredSmt::default();

    // --- insert a value into the tree ---------------------------------------
    // snapshots taken before each insertion let us verify that each deletion fully
    // rolls the tree back to its prior state
    let smt0 = smt.clone();
    let raw_a = 0b_01010101_01101100_01111111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // --- insert another with the same 16-bit prefix into the tree -----------
    let smt1 = smt.clone();
    let raw_b = 0b_01010101_01101100_00111111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let value_b = [ONE, ONE, ONE, ZERO];
    smt.insert(key_b, value_b);

    // --- insert the 3rd value with the same 16-bit prefix into the tree -----
    let smt2 = smt.clone();
    let raw_c = 0b_01010101_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    let value_c = [ONE, ONE, ZERO, ZERO];
    smt.insert(key_c, value_c);

    // --- delete the last inserted value -------------------------------------
    assert_eq!(smt.insert(key_c, EMPTY_WORD), value_c);
    assert_eq!(smt, smt2);

    // --- delete the second inserted value -----------------------------------
    assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b);
    assert_eq!(smt, smt1);

    // --- delete the first inserted value ------------------------------------
    assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a);
    assert_eq!(smt, smt0);
}
#[test]
fn tsmt_delete_48_same_32_bit_prefix() {
    let mut smt = TieredSmt::default();

    // test the case when all values share the same 32-bit prefix

    // --- insert a value into the tree ---------------------------------------
    // snapshots taken before each insertion let us verify that each deletion fully
    // rolls the tree back to its prior state
    let smt0 = smt.clone();
    let raw_a = 0b_01010101_01010101_11111111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // --- insert another with the same 32-bit prefix into the tree -----------
    let smt1 = smt.clone();
    let raw_b = 0b_01010101_01010101_11111111_11111111_11010110_10010011_11100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let value_b = [ONE, ONE, ONE, ZERO];
    smt.insert(key_b, value_b);

    // --- insert the 3rd value with the same 32-bit prefix into the tree -----
    let smt2 = smt.clone();
    let raw_c = 0b_01010101_01010101_11111111_11111111_11110110_10010011_11100000_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    let value_c = [ONE, ONE, ZERO, ZERO];
    smt.insert(key_c, value_c);

    // --- delete the last inserted value -------------------------------------
    assert_eq!(smt.insert(key_c, EMPTY_WORD), value_c);
    assert_eq!(smt, smt2);

    // --- delete the second inserted value -----------------------------------
    assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b);
    assert_eq!(smt, smt1);

    // --- delete the first inserted value ------------------------------------
    assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a);
    assert_eq!(smt, smt0);
}
#[test]
fn tsmt_delete_48_mixed_prefix() {
    let mut smt = TieredSmt::default();

    // test the case when some values share a 32-bit prefix and others share a 16-bit prefix

    // --- insert a value into the tree ---------------------------------------
    // snapshots taken before each insertion let us verify that each deletion fully
    // rolls the tree back to its prior state
    let smt0 = smt.clone();
    let raw_a = 0b_01010101_01010101_11111111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // --- insert another with the same 16-bit prefix into the tree -----------
    let smt1 = smt.clone();
    let raw_b = 0b_01010101_01010101_01111111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let value_b = [ONE, ONE, ONE, ZERO];
    smt.insert(key_b, value_b);

    // --- insert a value with the same 32-bit prefix as the first value -----
    let smt2 = smt.clone();
    let raw_c = 0b_01010101_01010101_11111111_11111111_11010110_10010011_11100000_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    let value_c = [ONE, ONE, ZERO, ZERO];
    smt.insert(key_c, value_c);

    // --- insert another value with the same 32-bit prefix as the first value
    let smt3 = smt.clone();
    let raw_d = 0b_01010101_01010101_11111111_11111111_11110110_10010011_11100000_00000000_u64;
    let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_d)]);
    let value_d = [ONE, ZERO, ZERO, ZERO];
    smt.insert(key_d, value_d);

    // --- delete the inserted values one-by-one ------------------------------
    // each deletion must return the deleted value and restore the pre-insertion state
    assert_eq!(smt.insert(key_d, EMPTY_WORD), value_d);
    assert_eq!(smt, smt3);
    assert_eq!(smt.insert(key_c, EMPTY_WORD), value_c);
    assert_eq!(smt, smt2);
    assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b);
    assert_eq!(smt, smt1);
    assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a);
    assert_eq!(smt, smt0);
}
#[test]
fn tsmt_delete_64() {
    let mut smt = TieredSmt::default();

    // test deletion when values share 48-bit, 32-bit, and full 64-bit prefixes

    // --- insert a value into the tree ---------------------------------------
    // snapshots taken before each insertion let us verify that each deletion fully
    // rolls the tree back to its prior state
    let smt0 = smt.clone();
    let raw_a = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // --- insert a value with the same 48-bit prefix into the tree -----------
    let smt1 = smt.clone();
    let raw_b = 0b_01010101_01010101_11111111_11111111_10110101_10101010_10111100_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let value_b = [ONE, ONE, ONE, ZERO];
    smt.insert(key_b, value_b);

    // --- insert a value with the same 32-bit prefix into the tree -----------
    let smt2 = smt.clone();
    let raw_c = 0b_01010101_01010101_11111111_11111111_11111101_10101010_10111100_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    let value_c = [ONE, ONE, ZERO, ZERO];
    smt.insert(key_c, value_c);

    // --- insert a value sharing the full 64-bit prefix with the first value -
    // raw_d is identical to raw_a; the keys differ only in their other elements
    let smt3 = smt.clone();
    let raw_d = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000000_u64;
    let key_d = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_d)]);
    let value_d = [ONE, ZERO, ZERO, ZERO];
    smt.insert(key_d, value_d);

    // --- delete the inserted values one-by-one ------------------------------
    // each deletion must return the deleted value and restore the pre-insertion state
    assert_eq!(smt.insert(key_d, EMPTY_WORD), value_d);
    assert_eq!(smt, smt3);
    assert_eq!(smt.insert(key_c, EMPTY_WORD), value_c);
    assert_eq!(smt, smt2);
    assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b);
    assert_eq!(smt, smt1);
    assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a);
    assert_eq!(smt, smt0);
}
#[test]
fn tsmt_delete_64_leaf_promotion() {
    let mut smt = TieredSmt::default();

    // --- delete from bottom tier (no promotion to upper tiers) --------------

    // insert a value into the tree
    let raw_a = 0b_01010101_01010101_11111111_11111111_10101010_10101010_11111111_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // insert another value with a key having the same 64-bit prefix
    let key_b = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
    let value_b = [ONE, ONE, ONE, ZERO];
    smt.insert(key_b, value_b);

    // insert a value with a key which shares the same 48-bit prefix
    let raw_c = 0b_01010101_01010101_11111111_11111111_10101010_10101010_00111111_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    let value_c = [ONE, ONE, ZERO, ZERO];
    smt.insert(key_c, value_c);

    // delete entry A and compare to the tree which was built from B and C
    smt.insert(key_a, EMPTY_WORD);
    let mut expected_smt = TieredSmt::default();
    expected_smt.insert(key_b, value_b);
    expected_smt.insert(key_c, value_c);
    assert_eq!(smt, expected_smt);

    // entries B and C should stay at depth 64
    assert_eq!(smt.nodes.get_leaf_index(&key_b).0.depth(), 64);
    assert_eq!(smt.nodes.get_leaf_index(&key_c).0.depth(), 64);

    // --- delete from bottom tier (promotion to depth 48) --------------------
    let mut smt = TieredSmt::default();
    smt.insert(key_a, value_a);
    smt.insert(key_b, value_b);

    // insert a value with a key which shares the same 32-bit prefix
    let raw_c = 0b_01010101_01010101_11111111_11111111_11101010_10101010_11111111_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    smt.insert(key_c, value_c);

    // delete entry A and compare to the tree which was built from B and C
    smt.insert(key_a, EMPTY_WORD);
    let mut expected_smt = TieredSmt::default();
    expected_smt.insert(key_b, value_b);
    expected_smt.insert(key_c, value_c);
    assert_eq!(smt, expected_smt);

    // entry B moves to depth 48, entry C stays at depth 48
    assert_eq!(smt.nodes.get_leaf_index(&key_b).0.depth(), 48);
    assert_eq!(smt.nodes.get_leaf_index(&key_c).0.depth(), 48);

    // --- delete from bottom tier (promotion to depth 32) --------------------
    let mut smt = TieredSmt::default();
    smt.insert(key_a, value_a);
    smt.insert(key_b, value_b);

    // insert a value with a key which shares the same 16-bit prefix
    let raw_c = 0b_01010101_01010101_01111111_11111111_10101010_10101010_11111111_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    smt.insert(key_c, value_c);

    // delete entry A and compare to the tree which was built from B and C
    smt.insert(key_a, EMPTY_WORD);
    let mut expected_smt = TieredSmt::default();
    expected_smt.insert(key_b, value_b);
    expected_smt.insert(key_c, value_c);
    assert_eq!(smt, expected_smt);

    // entry B moves to depth 32, entry C stays at depth 32
    assert_eq!(smt.nodes.get_leaf_index(&key_b).0.depth(), 32);
    assert_eq!(smt.nodes.get_leaf_index(&key_c).0.depth(), 32);

    // --- delete from bottom tier (promotion to depth 16) --------------------
    let mut smt = TieredSmt::default();
    smt.insert(key_a, value_a);
    smt.insert(key_b, value_b);

    // insert a value with a key which shares a prefix of fewer than 16 bits
    let raw_c = 0b_01010101_01010100_11111111_11111111_10101010_10101010_11111111_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    smt.insert(key_c, value_c);

    // delete entry A and compare to the tree which was built from B and C
    smt.insert(key_a, EMPTY_WORD);
    let mut expected_smt = TieredSmt::default();
    expected_smt.insert(key_b, value_b);
    expected_smt.insert(key_c, value_c);
    assert_eq!(smt, expected_smt);

    // entry B moves to depth 16, entry C stays at depth 16
    assert_eq!(smt.nodes.get_leaf_index(&key_b).0.depth(), 16);
    assert_eq!(smt.nodes.get_leaf_index(&key_c).0.depth(), 16);
}
#[test]
fn test_order_sensitivity() {
    let raw = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000001_u64;
    let value = [ONE; WORD_SIZE];

    let key_1 = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
    let key_2 = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw)]);

    // insert both keys, then delete the second one again
    let mut smt_with_history = TieredSmt::default();
    smt_with_history.insert(key_1, value);
    smt_with_history.insert(key_2, value);
    smt_with_history.insert(key_2, EMPTY_WORD);

    // insert only the first key
    let mut smt_direct = TieredSmt::default();
    smt_direct.insert(key_1, value);

    // both trees must arrive at the same root regardless of insertion history
    assert_eq!(smt_with_history.root(), smt_direct.root());
}
// BOTTOM TIER TESTS
// ================================================================================================
#[test]
fn tsmt_bottom_tier() {
    let mut smt = TieredSmt::default();
    // reference store built in parallel to cross-check roots, nodes, and paths
    let mut store = MerkleStore::default();

    // common prefix for the keys
    let prefix = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;

    // --- insert the first value ---------------------------------------------
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(prefix)]);
    let val_a = [ONE; WORD_SIZE];
    smt.insert(key_a, val_a);

    // --- insert the second value --------------------------------------------
    // this key has the same 64-bit prefix and thus both values should end up in the same
    // node at depth 64
    let key_b = RpoDigest::from([ZERO, ONE, ONE, Felt::new(prefix)]);
    let val_b = [Felt::new(2); WORD_SIZE];
    smt.insert(key_b, val_b);

    // --- build Merkle store with equivalent data ----------------------------
    let index = NodeIndex::make(64, prefix);
    // to build bottom leaf we sort by key starting with the least significant element, thus
    // key_b is smaller than key_a.
    let leaf_node = build_bottom_leaf_node(&[key_b, key_a], &[val_b, val_a]);
    let mut tree_root = get_init_root();
    tree_root = store.set_node(tree_root, index, leaf_node).unwrap().root;

    // --- verify that data is consistent between store and tree --------------
    assert_eq!(smt.root(), tree_root);
    assert_eq!(smt.get_value(key_a), val_a);
    assert_eq!(smt.get_value(key_b), val_b);
    assert_eq!(smt.get_node(index).unwrap(), leaf_node);
    let expected_path = store.get_path(tree_root, index).unwrap().path;
    assert_eq!(smt.get_path(index).unwrap(), expected_path);

    // make sure inner nodes match - the store contains more entries because it keeps track of
    // all prior state - so, we don't check that the number of inner nodes is the same in both
    let expected_nodes = get_non_empty_nodes(&store);
    let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
    actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));

    // make sure leaves are returned correctly; take a snapshot so the tree can be
    // mutated by the update below
    let smt_clone = smt.clone();
    let mut leaves = smt_clone.bottom_leaves();
    assert_eq!(leaves.next(), Some((leaf_node, vec![(key_b, val_b), (key_a, val_a)])));
    assert_eq!(leaves.next(), None);

    // --- update a leaf at the bottom tier -----------------------------------
    let val_a2 = [Felt::new(3); WORD_SIZE];
    assert_eq!(smt.insert(key_a, val_a2), val_a);

    let leaf_node = build_bottom_leaf_node(&[key_b, key_a], &[val_b, val_a2]);
    store.set_node(tree_root, index, leaf_node).unwrap();

    let expected_nodes = get_non_empty_nodes(&store);
    let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
    actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));

    let mut leaves = smt.bottom_leaves();
    assert_eq!(leaves.next(), Some((leaf_node, vec![(key_b, val_b), (key_a, val_a2)])));
    assert_eq!(leaves.next(), None);
}
#[test]
fn tsmt_bottom_tier_two() {
    let mut smt = TieredSmt::default();
    // reference store built in parallel to cross-check roots, nodes, and paths
    let mut store = MerkleStore::default();

    // --- insert the first value ---------------------------------------------
    let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let val_a = [ONE; WORD_SIZE];
    smt.insert(key_a, val_a);

    // --- insert the second value --------------------------------------------
    // the key for this value has the same 48-bit prefix as the key for the first value,
    // thus, on insertions, both should end up in different nodes at depth 64
    let raw_b = 0b_10101010_10101010_00011111_11111111_10010110_10010011_01100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let val_b = [Felt::new(2); WORD_SIZE];
    smt.insert(key_b, val_b);

    // --- build Merkle store with equivalent data ----------------------------
    // at depth 64 a leaf index is the full 64-bit value of the key's last element
    let mut tree_root = get_init_root();
    let index_a = NodeIndex::make(64, raw_a);
    let leaf_node_a = build_bottom_leaf_node(&[key_a], &[val_a]);
    tree_root = store.set_node(tree_root, index_a, leaf_node_a).unwrap().root;
    let index_b = NodeIndex::make(64, raw_b);
    let leaf_node_b = build_bottom_leaf_node(&[key_b], &[val_b]);
    tree_root = store.set_node(tree_root, index_b, leaf_node_b).unwrap().root;

    // --- verify that data is consistent between store and tree --------------
    assert_eq!(smt.root(), tree_root);

    assert_eq!(smt.get_value(key_a), val_a);
    assert_eq!(smt.get_node(index_a).unwrap(), leaf_node_a);
    let expected_path = store.get_path(tree_root, index_a).unwrap().path;
    assert_eq!(smt.get_path(index_a).unwrap(), expected_path);

    assert_eq!(smt.get_value(key_b), val_b);
    assert_eq!(smt.get_node(index_b).unwrap(), leaf_node_b);
    let expected_path = store.get_path(tree_root, index_b).unwrap().path;
    assert_eq!(smt.get_path(index_b).unwrap(), expected_path);

    // make sure inner nodes match - the store contains more entries because it keeps track of
    // all prior state - so, we don't check that the number of inner nodes is the same in both
    let expected_nodes = get_non_empty_nodes(&store);
    let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
    actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));

    // make sure leaves are returned correctly; note that leaf B (raw_b < raw_a) comes
    // before leaf A
    let mut leaves = smt.bottom_leaves();
    assert_eq!(leaves.next(), Some((leaf_node_b, vec![(key_b, val_b)])));
    assert_eq!(leaves.next(), Some((leaf_node_a, vec![(key_a, val_a)])));
    assert_eq!(leaves.next(), None);
}
// GET PROOF TESTS
// ================================================================================================
/// Tests the membership and non-membership proof for a single element at depth 64
#[test]
fn tsmt_get_proof_single_element_64() {
    let mut smt = TieredSmt::default();

    let raw_a = 0b_00000000_00000001_00000000_00000001_00000000_00000001_00000000_00000001_u64;
    let key_a = [ONE, ONE, ONE, raw_a.into()].into();
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // push element `a` to depth 64, by inserting another value that shares the 48-bit prefix
    let raw_b = 0b_00000000_00000001_00000000_00000001_00000000_00000001_00000000_00000000_u64;
    let key_b = [ONE, ONE, ONE, raw_b.into()].into();
    smt.insert(key_b, [ONE, ONE, ONE, ONE]);

    // verify the proof for element `a`
    let proof = smt.prove(key_a);
    assert!(proof.verify_membership(&key_a, &value_a, &smt.root()));

    // check that a value that is not inserted in the tree produces a valid membership proof
    // for the empty word
    let key = [ZERO, ZERO, ZERO, ZERO].into();
    let proof = smt.prove(key);
    assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root()));

    // check that a key that shares the 64-bit prefix with `a`, but is not inserted, also
    // has a valid membership proof for the empty word
    let key = [ONE, ONE, ZERO, raw_a.into()].into();
    let proof = smt.prove(key);
    assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root()));
}
#[test]
fn tsmt_get_proof() {
    let mut smt = TieredSmt::default();

    // --- insert a value into the tree ---------------------------------------
    let raw_a = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // --- insert a value with the same 48-bit prefix into the tree -----------
    let raw_b = 0b_01010101_01010101_11111111_11111111_10110101_10101010_10111100_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let value_b = [ONE, ONE, ONE, ZERO];
    smt.insert(key_b, value_b);

    // snapshot of the tree before the remaining insertions; used below to check that
    // proofs do not verify against a different (older) root
    let smt_alt = smt.clone();

    // --- insert a value with the same 32-bit prefix into the tree -----------
    let raw_c = 0b_01010101_01010101_11111111_11111111_11111101_10101010_10111100_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    let value_c = [ONE, ONE, ZERO, ZERO];
    smt.insert(key_c, value_c);

    // --- insert a value with the same 64-bit prefix as A into the tree ------
    let raw_d = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000000_u64;
    let key_d = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_d)]);
    let value_d = [ONE, ZERO, ZERO, ZERO];
    smt.insert(key_d, value_d);

    // at this point the tree looks as follows:
    // - A and D are located in the same node at depth 64.
    // - B is located at depth 64 and shares the same 48-bit prefix with A and D.
    // - C is located at depth 48 and shares the same 32-bit prefix with A, B, and D.

    // --- generate proof for key A and test that it verifies correctly -------
    let proof = smt.prove(key_a);
    assert!(proof.verify_membership(&key_a, &value_a, &smt.root()));

    // the proof must not verify for a wrong value, an empty value, a wrong key, or a
    // wrong root
    assert!(!proof.verify_membership(&key_a, &value_b, &smt.root()));
    assert!(!proof.verify_membership(&key_a, &EMPTY_WORD, &smt.root()));
    assert!(!proof.verify_membership(&key_b, &value_a, &smt.root()));
    assert!(!proof.verify_membership(&key_a, &value_a, &smt_alt.root()));

    assert_eq!(proof.get(&key_a), Some(value_a));
    assert_eq!(proof.get(&key_b), None);

    // since A and D are stored in the same node, we should be able to use the proof to verify
    // membership of D
    assert!(proof.verify_membership(&key_d, &value_d, &smt.root()));
    assert_eq!(proof.get(&key_d), Some(value_d));

    // --- generate proof for key B and test that it verifies correctly -------
    let proof = smt.prove(key_b);
    assert!(proof.verify_membership(&key_b, &value_b, &smt.root()));
    assert!(!proof.verify_membership(&key_b, &value_a, &smt.root()));
    assert!(!proof.verify_membership(&key_b, &EMPTY_WORD, &smt.root()));
    assert!(!proof.verify_membership(&key_a, &value_b, &smt.root()));
    assert!(!proof.verify_membership(&key_b, &value_b, &smt_alt.root()));
    assert_eq!(proof.get(&key_b), Some(value_b));
    assert_eq!(proof.get(&key_a), None);

    // --- generate proof for key C and test that it verifies correctly -------
    let proof = smt.prove(key_c);
    assert!(proof.verify_membership(&key_c, &value_c, &smt.root()));
    assert!(!proof.verify_membership(&key_c, &value_a, &smt.root()));
    assert!(!proof.verify_membership(&key_c, &EMPTY_WORD, &smt.root()));
    assert!(!proof.verify_membership(&key_a, &value_c, &smt.root()));
    assert!(!proof.verify_membership(&key_c, &value_c, &smt_alt.root()));
    assert_eq!(proof.get(&key_c), Some(value_c));
    assert_eq!(proof.get(&key_b), None);

    // --- generate proof for key D and test that it verifies correctly -------
    let proof = smt.prove(key_d);
    assert!(proof.verify_membership(&key_d, &value_d, &smt.root()));
    assert!(!proof.verify_membership(&key_d, &value_b, &smt.root()));
    assert!(!proof.verify_membership(&key_d, &EMPTY_WORD, &smt.root()));
    assert!(!proof.verify_membership(&key_b, &value_d, &smt.root()));
    assert!(!proof.verify_membership(&key_d, &value_d, &smt_alt.root()));
    assert_eq!(proof.get(&key_d), Some(value_d));
    assert_eq!(proof.get(&key_b), None);

    // since A and D are stored in the same node, we should be able to use the proof to verify
    // membership of A
    assert!(proof.verify_membership(&key_a, &value_a, &smt.root()));
    assert_eq!(proof.get(&key_a), Some(value_a));

    // --- generate proof for an empty key at depth 64 ------------------------
    // this key has the same 48-bit prefix as A but is different from B
    let raw = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000011_u64;
    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
    let proof = smt.prove(key);
    assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root()));
    assert!(!proof.verify_membership(&key, &value_a, &smt.root()));
    assert!(!proof.verify_membership(&key, &EMPTY_WORD, &smt_alt.root()));
    assert_eq!(proof.get(&key), Some(EMPTY_WORD));
    assert_eq!(proof.get(&key_b), None);

    // the same proof should verify against any key with the same 64-bit prefix
    let key2 = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw)]);
    assert!(proof.verify_membership(&key2, &EMPTY_WORD, &smt.root()));
    assert_eq!(proof.get(&key2), Some(EMPTY_WORD));

    // but verifying it against a key with the same 63-bit prefix (or smaller) should fail
    let raw3 = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000010_u64;
    let key3 = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw3)]);
    assert!(!proof.verify_membership(&key3, &EMPTY_WORD, &smt.root()));
    assert_eq!(proof.get(&key3), None);

    // --- generate proof for an empty key at depth 48 ------------------------
    // this key has the same 32-bit prefix as A, B, C, and D, but is different from C
    let raw = 0b_01010101_01010101_11111111_11111111_00110101_10101010_11111100_00000000_u64;
    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
    let proof = smt.prove(key);
    assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root()));
    assert!(!proof.verify_membership(&key, &value_a, &smt.root()));
    assert!(!proof.verify_membership(&key, &EMPTY_WORD, &smt_alt.root()));
    assert_eq!(proof.get(&key), Some(EMPTY_WORD));
    assert_eq!(proof.get(&key_b), None);

    // the same proof should verify against any key with the same 48-bit prefix
    let raw2 = 0b_01010101_01010101_11111111_11111111_00110101_10101010_01111100_00000000_u64;
    let key2 = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw2)]);
    assert!(proof.verify_membership(&key2, &EMPTY_WORD, &smt.root()));
    assert_eq!(proof.get(&key2), Some(EMPTY_WORD));

    // but verifying against a key with the same 47-bit prefix (or smaller) should fail
    let raw3 = 0b_01010101_01010101_11111111_11111111_00110101_10101011_11111100_00000000_u64;
    let key3 = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw3)]);
    assert!(!proof.verify_membership(&key3, &EMPTY_WORD, &smt.root()));
    assert_eq!(proof.get(&key3), None);
}
// ERROR TESTS
// ================================================================================================
#[test]
fn tsmt_node_not_available() {
    let mut smt = TieredSmt::default();

    let raw = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
    let value = [ONE; WORD_SIZE];

    // build an index which is just below the inserted leaf node
    let index = NodeIndex::make(17, raw >> 47);

    // since we haven't inserted the node yet, we should be able to get node and path to
    // this index
    assert!(smt.get_node(index).is_ok());
    assert!(smt.get_path(index).is_ok());

    smt.insert(key, value);

    // but once the node is inserted, everything under it should be unavailable
    assert!(smt.get_node(index).is_err());
    assert!(smt.get_path(index).is_err());

    // check a few more indexes at various depths beneath the inserted leaf
    let index = NodeIndex::make(32, raw >> 32);
    assert!(smt.get_node(index).is_err());
    assert!(smt.get_path(index).is_err());

    let index = NodeIndex::make(34, raw >> 30);
    assert!(smt.get_node(index).is_err());
    assert!(smt.get_path(index).is_err());

    let index = NodeIndex::make(50, raw >> 14);
    assert!(smt.get_node(index).is_err());
    assert!(smt.get_path(index).is_err());

    let index = NodeIndex::make(64, raw);
    assert!(smt.get_node(index).is_err());
    assert!(smt.get_path(index).is_err());
}
// HELPER FUNCTIONS
// ================================================================================================
/// Returns the root of an empty tiered SMT, i.e. the empty-subtree hash at depth 0.
fn get_init_root() -> RpoDigest {
    let hashes = EmptySubtreeRoots::empty_hashes(64);
    hashes[0]
}
/// Computes the expected hash of a leaf node by merging the key and value digests in a domain
/// defined by the leaf's depth.
fn build_leaf_node(key: RpoDigest, value: Word, depth: u8) -> RpoDigest {
    Rpo256::merge_in_domain(&[key, value.into()], depth.into())
}
/// Computes the expected hash of a bottom-tier leaf node by hashing the flattened sequence of
/// the provided key and value elements, interleaved pairwise.
fn build_bottom_leaf_node(keys: &[RpoDigest], values: &[Word]) -> RpoDigest {
    assert_eq!(keys.len(), values.len());
    let elements = keys
        .iter()
        .zip(values.iter())
        .flat_map(|(key, val)| key.as_elements().iter().chain(val.as_slice().iter()))
        .copied()
        .collect::<Vec<_>>();
    Rpo256::hash_elements(&elements)
}
/// Collects all inner nodes of the store whose value is not a root of an empty subtree.
fn get_non_empty_nodes(store: &MerkleStore) -> Vec<InnerNodeInfo> {
    let mut result = Vec::new();
    for node in store.inner_nodes() {
        if !is_empty_subtree(&node.value) {
            result.push(node);
        }
    }
    result
}
/// Returns true if the provided digest is the root of an empty subtree at any depth.
fn is_empty_subtree(node: &RpoDigest) -> bool {
    EmptySubtreeRoots::empty_hashes(255).iter().any(|hash| hash == node)
}

View File

@@ -0,0 +1,584 @@
use super::{get_key_prefix, BTreeMap, LeafNodeIndex, RpoDigest, StarkField, Vec, Word};
use crate::utils::vec;
use core::{
cmp::{Ord, Ordering},
ops::RangeBounds,
};
use winter_utils::collections::btree_map::Entry;
// CONSTANTS
// ================================================================================================
/// Depths at which leaves can exist in a tiered SMT.
const TIER_DEPTHS: [u8; 4] = super::TieredSmt::TIER_DEPTHS;
/// Maximum node depth. This is also the bottom tier of the tree.
const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;
// VALUE STORE
// ================================================================================================
/// A store for key-value pairs for a Tiered Sparse Merkle tree.
///
/// The store is organized in a [BTreeMap] where keys are 64 most significant bits of a key, and
/// the values are the corresponding key-value pairs (or a list of key-value pairs if more than
/// a single key-value pair shares the same 64-bit prefix).
///
/// The store supports lookup by the full key (i.e. [RpoDigest]) as well as by the 64-bit key
/// prefix.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct ValueStore {
    // map from a 64-bit key prefix to the entry (single pair or sorted list of pairs) whose
    // keys share that prefix
    values: BTreeMap<u64, StoreEntry>,
}
impl ValueStore {
    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns a reference to the value stored under the specified key, or None if there is no
    /// value associated with the specified key.
    pub fn get(&self, key: &RpoDigest) -> Option<&Word> {
        let prefix = get_key_prefix(key);
        self.values.get(&prefix).and_then(|entry| entry.get(key))
    }

    /// Returns the first key-value pair such that the key prefix is greater than or equal to the
    /// specified prefix.
    pub fn get_first(&self, prefix: u64) -> Option<&(RpoDigest, Word)> {
        self.range(prefix..).next()
    }

    /// Returns the first key-value pair such that the key prefix is greater than or equal to the
    /// specified prefix and the key value is not equal to the exclude_key value.
    pub fn get_first_filtered(
        &self,
        prefix: u64,
        exclude_key: &RpoDigest,
    ) -> Option<&(RpoDigest, Word)> {
        self.range(prefix..).find(|(key, _)| key != exclude_key)
    }

    /// Returns a vector with key-value pairs for all keys with the specified 64-bit prefix, or
    /// None if no keys with the specified prefix are present in this store.
    pub fn get_all(&self, prefix: u64) -> Option<Vec<(RpoDigest, Word)>> {
        self.values.get(&prefix).map(|entry| match entry {
            StoreEntry::Single(kv_pair) => vec![*kv_pair],
            StoreEntry::List(kv_pairs) => kv_pairs.clone(),
        })
    }

    /// Returns information about a sibling of a leaf node with the specified index, but only if
    /// this is the only sibling the leaf has in some subtree starting at the first tier.
    ///
    /// For example, if `index` is an index at depth 32, and there is a leaf node at depth 32 with
    /// the same root at depth 16 as `index`, we say that this leaf is a lone sibling.
    ///
    /// The returned tuple contains: the key-value pair of the sibling as well as the index of
    /// the node for the root of the common subtree in which both nodes are leaves.
    ///
    /// This method assumes that the key-value pair for the specified index has already been
    /// removed from the store.
    pub fn get_lone_sibling(
        &self,
        index: LeafNodeIndex,
    ) -> Option<(&RpoDigest, &Word, LeafNodeIndex)> {
        // iterate over tiers from top to bottom, looking at the tiers which are strictly above
        // the depth of the index. This implies that only tiers at depth 32 and 48 will be
        // considered. For each tier, check if the parent of the index at the higher tier
        // contains a single node. The first tier (depth 16) is excluded because we cannot move
        // nodes at depth 16 to a higher tier. This implies that nodes at the first tier will
        // never have "lone siblings".
        for &tier_depth in TIER_DEPTHS.iter().filter(|&t| index.depth() > *t) {
            // compute the index of the root at a higher tier
            let mut parent_index = index;
            parent_index.move_up_to(tier_depth);

            // find the lone sibling, if any; we need to handle the "last node" at a given tier
            // separately to specify the bounds for the search correctly.
            let start_prefix = parent_index.value() << (MAX_DEPTH - tier_depth);
            let sibling = if start_prefix.leading_ones() as u8 == tier_depth {
                // the parent is the last node at this tier (its index bits are all ones), so an
                // exclusive upper bound would overflow; search to the end of the key space
                let mut iter = self.range(start_prefix..);
                iter.next().filter(|_| iter.next().is_none())
            } else {
                let end_prefix = (parent_index.value() + 1) << (MAX_DEPTH - tier_depth);
                let mut iter = self.range(start_prefix..end_prefix);
                iter.next().filter(|_| iter.next().is_none())
            };

            if let Some((key, value)) = sibling {
                return Some((key, value, parent_index));
            }
        }
        None
    }

    /// Returns an iterator over all key-value pairs in this store.
    pub fn iter(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
        self.values.iter().flat_map(|(_, entry)| entry.iter())
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Inserts the specified key-value pair into this store and returns the value previously
    /// associated with the specified key.
    ///
    /// If no value was previously associated with the specified key, None is returned.
    pub fn insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
        let prefix = get_key_prefix(&key);
        match self.values.entry(prefix) {
            Entry::Occupied(mut entry) => entry.get_mut().insert(key, value),
            Entry::Vacant(entry) => {
                entry.insert(StoreEntry::new(key, value));
                None
            }
        }
    }

    /// Removes the key-value pair for the specified key from this store and returns the value
    /// associated with this key.
    ///
    /// If no value was associated with the specified key, None is returned.
    pub fn remove(&mut self, key: &RpoDigest) -> Option<Word> {
        let prefix = get_key_prefix(key);
        match self.values.entry(prefix) {
            Entry::Occupied(mut entry) => {
                let (value, remove_entry) = entry.get_mut().remove(key);
                // drop the whole map entry once the last pair under this prefix is removed
                if remove_entry {
                    entry.remove_entry();
                }
                value
            }
            Entry::Vacant(_) => None,
        }
    }

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over all key-value pairs contained in this store such that the most
    /// significant 64 bits of the key lay within the specified bounds.
    ///
    /// The order of iteration is from the smallest to the largest key.
    fn range<R: RangeBounds<u64>>(&self, bounds: R) -> impl Iterator<Item = &(RpoDigest, Word)> {
        self.values.range(bounds).flat_map(|(_, entry)| entry.iter())
    }
}
// VALUE NODE
// ================================================================================================
/// An entry in the [ValueStore].
///
/// An entry can contain either a single key-value pair or a vector of key-value pairs sorted by
/// key.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum StoreEntry {
    // exactly one key-value pair under a given 64-bit key prefix
    Single((RpoDigest, Word)),
    // two or more key-value pairs sharing the same 64-bit key prefix, kept sorted by key
    List(Vec<(RpoDigest, Word)>),
}
impl StoreEntry {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------

    /// Returns a new [StoreEntry] instantiated with a single key-value pair.
    pub fn new(key: RpoDigest, value: Word) -> Self {
        Self::Single((key, value))
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the value associated with the specified key, or None if this entry does not contain
    /// a value associated with the specified key.
    pub fn get(&self, key: &RpoDigest) -> Option<&Word> {
        match self {
            StoreEntry::Single(kv_pair) => {
                if kv_pair.0 == *key {
                    Some(&kv_pair.1)
                } else {
                    None
                }
            }
            StoreEntry::List(kv_pairs) => {
                // the list is kept sorted by key, so a binary search can be used
                match kv_pairs.binary_search_by(|kv_pair| cmp_digests(&kv_pair.0, key)) {
                    Ok(pos) => Some(&kv_pairs[pos].1),
                    Err(_) => None,
                }
            }
        }
    }

    /// Returns an iterator over all key-value pairs in this entry.
    pub fn iter(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
        EntryIterator { entry: self, pos: 0 }
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Inserts the specified key-value pair into this entry and returns the value previously
    /// associated with the specified key, or None if no value was associated with the specified
    /// key.
    ///
    /// If a new key is inserted, this will also transform a `Single` entry into a `List` entry.
    pub fn insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
        match self {
            StoreEntry::Single(kv_pair) => {
                // if the key is already in this entry, update the value and return
                if kv_pair.0 == key {
                    let old_value = kv_pair.1;
                    kv_pair.1 = value;
                    return Some(old_value);
                }

                // transform the entry into a list entry, and make sure the key-value pairs
                // are sorted by key
                let mut pairs = vec![*kv_pair, (key, value)];
                pairs.sort_by(|a, b| cmp_digests(&a.0, &b.0));

                *self = StoreEntry::List(pairs);
                None
            }
            StoreEntry::List(pairs) => {
                match pairs.binary_search_by(|kv_pair| cmp_digests(&kv_pair.0, &key)) {
                    Ok(pos) => {
                        let old_value = pairs[pos].1;
                        pairs[pos].1 = value;
                        Some(old_value)
                    }
                    Err(pos) => {
                        // insert at the position reported by the binary search so that the
                        // list stays sorted by key
                        pairs.insert(pos, (key, value));
                        None
                    }
                }
            }
        }
    }

    /// Removes the key-value pair with the specified key from this entry, and returns the value
    /// of the removed pair. If the entry did not contain a key-value pair for the specified key,
    /// None is returned.
    ///
    /// If the last key-value pair was removed from the entry, the second tuple value will
    /// be set to true.
    pub fn remove(&mut self, key: &RpoDigest) -> (Option<Word>, bool) {
        match self {
            StoreEntry::Single(kv_pair) => {
                if kv_pair.0 == *key {
                    (Some(kv_pair.1), true)
                } else {
                    (None, false)
                }
            }
            StoreEntry::List(kv_pairs) => {
                match kv_pairs.binary_search_by(|kv_pair| cmp_digests(&kv_pair.0, key)) {
                    Ok(pos) => {
                        let kv_pair = kv_pairs.remove(pos);
                        // when only one pair remains, collapse the entry back into the `Single`
                        // variant; the entry itself still exists, so the flag stays false
                        if kv_pairs.len() == 1 {
                            *self = StoreEntry::Single(kv_pairs[0]);
                        }
                        (Some(kv_pair.1), false)
                    }
                    Err(_) => (None, false),
                }
            }
        }
    }
}
/// A custom iterator over key-value pairs of a [StoreEntry].
///
/// For a `Single` entry this returns only one value, but for a `List` entry, this iterates over
/// the entire list of key-value pairs.
pub struct EntryIterator<'a> {
    // the entry being iterated over
    entry: &'a StoreEntry,
    // position of the next pair to yield (always 0 or 1 for a `Single` entry)
    pos: usize,
}
impl<'a> Iterator for EntryIterator<'a> {
    type Item = &'a (RpoDigest, Word);

    fn next(&mut self) -> Option<Self::Item> {
        // look up the pair at the current position; a `Single` entry behaves like a
        // one-element list
        let item = match self.entry {
            StoreEntry::Single(kv_pair) if self.pos == 0 => Some(kv_pair),
            StoreEntry::Single(_) => None,
            StoreEntry::List(kv_pairs) => kv_pairs.get(self.pos),
        };
        // advance the cursor only when something was yielded
        if item.is_some() {
            self.pos += 1;
        }
        item
    }
}
// HELPER FUNCTIONS
// ================================================================================================
/// Compares two digests element-by-element using their integer representations starting with the
/// most significant element.
fn cmp_digests(d1: &RpoDigest, d2: &RpoDigest) -> Ordering {
    let w1 = Word::from(d1);
    let w2 = Word::from(d2);
    // walk the elements from most to least significant and stop at the first inequality
    w1.iter()
        .zip(w2.iter())
        .rev()
        .map(|(e1, e2)| e1.as_int().cmp(&e2.as_int()))
        .find(|ord| *ord != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
    use super::{LeafNodeIndex, RpoDigest, StoreEntry, ValueStore};
    use crate::{Felt, ONE, WORD_SIZE, ZERO};

    #[test]
    fn test_insert() {
        let mut store = ValueStore::default();

        // insert the first key-value pair into the store
        let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
        let value_a = [ONE; WORD_SIZE];
        assert!(store.insert(key_a, value_a).is_none());
        assert_eq!(store.values.len(), 1);
        let entry = store.values.get(&raw_a).unwrap();
        let expected_entry = StoreEntry::Single((key_a, value_a));
        assert_eq!(entry, &expected_entry);

        // insert a key-value pair with a different key into the store; since the keys are
        // different, another entry is added to the values map
        let raw_b = 0b_11111110_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
        let value_b = [ONE, ZERO, ONE, ZERO];
        assert!(store.insert(key_b, value_b).is_none());
        assert_eq!(store.values.len(), 2);
        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 = StoreEntry::Single((key_a, value_a));
        assert_eq!(entry1, &expected_entry1);
        let entry2 = store.values.get(&raw_b).unwrap();
        let expected_entry2 = StoreEntry::Single((key_b, value_b));
        assert_eq!(entry2, &expected_entry2);

        // insert a key-value pair with the same 64-bit key prefix as the first key; this should
        // transform the first entry into a List entry
        let key_c = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
        let value_c = [ONE, ONE, ZERO, ZERO];
        assert!(store.insert(key_c, value_c).is_none());
        assert_eq!(store.values.len(), 2);
        // key_c sorts before key_a within the list entry (lists are ordered by full key)
        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 = StoreEntry::List(vec![(key_c, value_c), (key_a, value_a)]);
        assert_eq!(entry1, &expected_entry1);
        let entry2 = store.values.get(&raw_b).unwrap();
        let expected_entry2 = StoreEntry::Single((key_b, value_b));
        assert_eq!(entry2, &expected_entry2);

        // replace values for keys a and b; insert() should return the previous values
        let value_a2 = [ONE, ONE, ONE, ZERO];
        let value_b2 = [ZERO, ZERO, ZERO, ONE];
        assert_eq!(store.insert(key_a, value_a2), Some(value_a));
        assert_eq!(store.values.len(), 2);
        assert_eq!(store.insert(key_b, value_b2), Some(value_b));
        assert_eq!(store.values.len(), 2);
        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 = StoreEntry::List(vec![(key_c, value_c), (key_a, value_a2)]);
        assert_eq!(entry1, &expected_entry1);
        let entry2 = store.values.get(&raw_b).unwrap();
        let expected_entry2 = StoreEntry::Single((key_b, value_b2));
        assert_eq!(entry2, &expected_entry2);

        // insert one more key-value pair with the same 64-bit key-prefix as the first key
        let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
        let value_d = [ZERO, ONE, ZERO, ZERO];
        assert!(store.insert(key_d, value_d).is_none());
        assert_eq!(store.values.len(), 2);
        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 =
            StoreEntry::List(vec![(key_c, value_c), (key_a, value_a2), (key_d, value_d)]);
        assert_eq!(entry1, &expected_entry1);
        let entry2 = store.values.get(&raw_b).unwrap();
        let expected_entry2 = StoreEntry::Single((key_b, value_b2));
        assert_eq!(entry2, &expected_entry2);
    }

    #[test]
    fn test_remove() {
        // populate the value store
        let mut store = ValueStore::default();

        let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
        let value_a = [ONE; WORD_SIZE];
        store.insert(key_a, value_a);

        let raw_b = 0b_11111110_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
        let value_b = [ONE, ZERO, ONE, ZERO];
        store.insert(key_b, value_b);

        // key_c and key_d share raw_a's 64-bit prefix, so they land in the same entry as key_a
        let key_c = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
        let value_c = [ONE, ONE, ZERO, ZERO];
        store.insert(key_c, value_c);

        let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
        let value_d = [ZERO, ONE, ZERO, ZERO];
        store.insert(key_d, value_d);

        assert_eq!(store.values.len(), 2);
        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 =
            StoreEntry::List(vec![(key_c, value_c), (key_a, value_a), (key_d, value_d)]);
        assert_eq!(entry1, &expected_entry1);
        let entry2 = store.values.get(&raw_b).unwrap();
        let expected_entry2 = StoreEntry::Single((key_b, value_b));
        assert_eq!(entry2, &expected_entry2);

        // remove non-existent keys
        let key_e = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_a)]);
        assert!(store.remove(&key_e).is_none());

        let raw_f = 0b_11111110_11111111_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_f = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_f)]);
        assert!(store.remove(&key_f).is_none());

        // remove keys from the list entry; the entry should collapse back to Single once only
        // one pair remains, and disappear entirely when the last pair is removed
        assert_eq!(store.remove(&key_c).unwrap(), value_c);
        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 = StoreEntry::List(vec![(key_a, value_a), (key_d, value_d)]);
        assert_eq!(entry1, &expected_entry1);

        assert_eq!(store.remove(&key_a).unwrap(), value_a);
        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 = StoreEntry::Single((key_d, value_d));
        assert_eq!(entry1, &expected_entry1);

        assert_eq!(store.remove(&key_d).unwrap(), value_d);
        assert!(store.values.get(&raw_a).is_none());
        assert_eq!(store.values.len(), 1);

        // remove a key from a single entry
        assert_eq!(store.remove(&key_b).unwrap(), value_b);
        assert!(store.values.get(&raw_b).is_none());
        assert_eq!(store.values.len(), 0);
    }

    #[test]
    fn test_range() {
        // populate the value store
        let mut store = ValueStore::default();

        let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
        let value_a = [ONE; WORD_SIZE];
        store.insert(key_a, value_a);

        let raw_b = 0b_11111110_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
        let value_b = [ONE, ZERO, ONE, ZERO];
        store.insert(key_b, value_b);

        let key_c = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
        let value_c = [ONE, ONE, ZERO, ZERO];
        store.insert(key_c, value_c);

        let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
        let value_d = [ZERO, ONE, ZERO, ZERO];
        store.insert(key_d, value_d);

        // raw_e is the smallest prefix, so e is expected to come first in iteration order
        let raw_e = 0b_10101000_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_e = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_e)]);
        let value_e = [ZERO, ZERO, ZERO, ONE];
        store.insert(key_e, value_e);

        // check the entire range
        let mut iter = store.range(..u64::MAX);
        assert_eq!(iter.next(), Some(&(key_e, value_e)));
        assert_eq!(iter.next(), Some(&(key_c, value_c)));
        assert_eq!(iter.next(), Some(&(key_a, value_a)));
        assert_eq!(iter.next(), Some(&(key_d, value_d)));
        assert_eq!(iter.next(), Some(&(key_b, value_b)));
        assert_eq!(iter.next(), None);

        // check all but e
        let mut iter = store.range(raw_a..u64::MAX);
        assert_eq!(iter.next(), Some(&(key_c, value_c)));
        assert_eq!(iter.next(), Some(&(key_a, value_a)));
        assert_eq!(iter.next(), Some(&(key_d, value_d)));
        assert_eq!(iter.next(), Some(&(key_b, value_b)));
        assert_eq!(iter.next(), None);

        // check all but e and b
        let mut iter = store.range(raw_a..raw_b);
        assert_eq!(iter.next(), Some(&(key_c, value_c)));
        assert_eq!(iter.next(), Some(&(key_a, value_a)));
        assert_eq!(iter.next(), Some(&(key_d, value_d)));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn test_get_lone_sibling() {
        // populate the value store
        let mut store = ValueStore::default();

        let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
        let value_a = [ONE; WORD_SIZE];
        store.insert(key_a, value_a);

        let raw_b = 0b_11111111_11111111_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
        let value_b = [ONE, ZERO, ONE, ZERO];
        store.insert(key_b, value_b);

        // check sibling node for `a`
        let index = LeafNodeIndex::make(32, 0b_10101010_10101010_00011111_11111110);
        let parent_index = LeafNodeIndex::make(16, 0b_10101010_10101010);
        assert_eq!(store.get_lone_sibling(index), Some((&key_a, &value_a, parent_index)));

        // check sibling node for `b`; raw_b's 16-bit prefix is all ones, which exercises the
        // "last node at the tier" branch of get_lone_sibling()
        let index = LeafNodeIndex::make(32, 0b_11111111_11111111_00011111_11111111);
        let parent_index = LeafNodeIndex::make(16, 0b_11111111_11111111);
        assert_eq!(store.get_lone_sibling(index), Some((&key_b, &value_b, parent_index)));

        // check some other sibling for some other index
        let index = LeafNodeIndex::make(32, 0b_11101010_10101010);
        assert_eq!(store.get_lone_sibling(index), None);
    }
}

View File

@@ -2,7 +2,7 @@
pub use winter_crypto::{DefaultRandomCoin as WinterRandomCoin, RandomCoin, RandomCoinError};
use crate::{Felt, FieldElement, Word, ZERO};
use crate::{Felt, FieldElement, StarkField, Word, ZERO};
mod rpo;
pub use rpo::RpoRandomCoin;

View File

@@ -1,4 +1,4 @@
use super::{Felt, FeltRng, FieldElement, Word, ZERO};
use super::{Felt, FeltRng, FieldElement, StarkField, Word, ZERO};
use crate::{
hash::rpo::{Rpo256, RpoDigest},
utils::{
@@ -19,7 +19,7 @@ const HALF_RATE_WIDTH: usize = (Rpo256::RATE_RANGE.end - Rpo256::RATE_RANGE.star
// RPO RANDOM COIN
// ================================================================================================
/// A simplified version of the `SPONGE_PRG` reseedable pseudo-random number generator algorithm
/// described in <https://eprint.iacr.org/2011/499.pdf>.
/// described in https://eprint.iacr.org/2011/499.pdf.
///
/// The simplification is related to the following facts:
/// 1. A call to the reseed method implies one and only one call to the permutation function.

31
src/utils/diff.rs Normal file
View File

@@ -0,0 +1,31 @@
/// A trait for computing the difference between two objects.
pub trait Diff<K: Ord + Clone, V: Clone> {
    /// The type that describes the difference between two objects.
    type DiffType;

    /// Returns a [Self::DiffType] object that represents the difference between this object and
    /// other.
    fn diff(&self, other: &Self) -> Self::DiffType;
}

/// A trait for applying the difference between two objects.
pub trait ApplyDiff<K: Ord + Clone, V: Clone> {
    /// The type that describes the difference between two objects.
    type DiffType;

    /// Applies the provided changes described by [Self::DiffType] to the object implementing
    /// this trait.
    fn apply(&mut self, diff: Self::DiffType);
}

/// A trait for applying the difference between two objects with the possibility of failure.
pub trait TryApplyDiff<K: Ord + Clone, V: Clone> {
    /// The type that describes the difference between two objects.
    type DiffType;

    /// An error type that can be returned if the changes cannot be applied.
    type Error;

    /// Applies the provided changes described by [Self::DiffType] to the object implementing
    /// this trait. Returns an error if the changes cannot be applied.
    fn try_apply(&mut self, diff: Self::DiffType) -> Result<(), Self::Error>;
}

View File

@@ -1,3 +1,4 @@
use super::{collections::ApplyDiff, diff::Diff};
use core::cell::RefCell;
use winter_utils::{
collections::{btree_map::IntoIter, BTreeMap, BTreeSet},
@@ -208,6 +209,74 @@ impl<K: Clone + Ord, V: Clone> IntoIterator for RecordingMap<K, V> {
}
}
// KV MAP DIFF
// ================================================================================================
/// [KvMapDiff] stores the difference between two key-value maps.
///
/// The [KvMapDiff] is composed of two parts:
/// - `updated` - a map of key-value pairs that were updated in the second map compared to the
///   first map. This includes new key-value pairs.
/// - `removed` - a set of keys that were removed from the second map compared to the first map.
#[derive(Debug, Clone)]
pub struct KvMapDiff<K, V> {
    pub updated: BTreeMap<K, V>,
    pub removed: BTreeSet<K>,
}
impl<K, V> KvMapDiff<K, V> {
// CONSTRUCTOR
// --------------------------------------------------------------------------------------------
/// Creates a new [KvMapDiff] instance.
pub fn new() -> Self {
KvMapDiff {
updated: BTreeMap::new(),
removed: BTreeSet::new(),
}
}
}
impl<K, V> Default for KvMapDiff<K, V> {
    /// Returns an empty diff (no updated pairs, no removed keys).
    fn default() -> Self {
        Self::new()
    }
}
impl<K: Ord + Clone, V: Clone + PartialEq, T: KvMap<K, V>> Diff<K, V> for T {
    type DiffType = KvMapDiff<K, V>;

    /// Computes the diff that transforms `self` into `other`: keys whose values changed (or are
    /// new in `other`) go into `updated`; keys absent from `other` go into `removed`.
    fn diff(&self, other: &T) -> Self::DiffType {
        let mut result = KvMapDiff::default();

        // classify every key of `self`: changed value -> updated, missing in other -> removed
        for (key, value) in self.iter() {
            match other.get(key) {
                Some(other_value) => {
                    if value != other_value {
                        result.updated.insert(key.clone(), other_value.clone());
                    }
                }
                None => {
                    result.removed.insert(key.clone());
                }
            }
        }

        // keys that exist only in `other` are also updates (new key-value pairs)
        for (key, value) in other.iter() {
            if self.get(key).is_none() {
                result.updated.insert(key.clone(), value.clone());
            }
        }
        result
    }
}
impl<K: Ord + Clone, V: Clone, T: KvMap<K, V>> ApplyDiff<K, V> for T {
    type DiffType = KvMapDiff<K, V>;

    /// Applies the diff to this map: inserted/updated pairs first, then removals.
    fn apply(&mut self, diff: Self::DiffType) {
        diff.updated.into_iter().for_each(|(key, value)| {
            self.insert(key, value);
        });
        diff.removed.into_iter().for_each(|key| {
            self.remove(&key);
        });
    }
}
// TESTS
// ================================================================================================
@@ -401,4 +470,35 @@ mod tests {
}
}
}
#[test]
fn test_kv_map_diff() {
let mut initial_state = ITEMS.into_iter().collect::<BTreeMap<_, _>>();
let mut map = RecordingMap::new(initial_state.clone());
// remove an item that exists
let key = 0;
let _value = map.remove(&key).unwrap();
// add a new item
let key = 100;
let value = 100;
map.insert(key, value);
// update an existing item
let key = 1;
let value = 100;
map.insert(key, value);
// compute a diff
let diff = initial_state.diff(map.inner());
assert!(diff.updated.len() == 2);
assert!(diff.updated.iter().all(|(k, v)| [(100, 100), (1, 100)].contains(&(*k, *v))));
assert!(diff.removed.len() == 1);
assert!(diff.removed.first() == Some(&0));
// apply the diff to the initial state and assert the contents are the same as the map
initial_state.apply(diff);
assert!(initial_state.iter().eq(map.iter()));
}
}

View File

@@ -9,6 +9,7 @@ pub use alloc::{format, vec};
#[cfg(feature = "std")]
pub use std::{format, vec};
mod diff;
mod kv_map;
// RE-EXPORTS
@@ -19,6 +20,7 @@ pub use winter_utils::{
};
pub mod collections {
pub use super::diff::*;
pub use super::kv_map::*;
pub use winter_utils::collections::*;
}