4 Commits

Author SHA1 Message Date
Al-Kindi-0
d2a6739605 feat: derandomize RPO-STARK DSA (#358) 2025-01-08 11:42:23 -08:00
Al-Kindi-0
cae87a2790 chore: add signatures benchmarks (#354) 2024-12-12 19:58:33 -08:00
Al-Kindi-0
335c50f54d feat: implement RPO STARK-based signature DSA (with zero knowledge) (#349) 2024-12-12 19:33:24 -08:00
Qyriad
b151773b0d feat: implement concurrent Smt construction (#341)
* merkle: add parent() helper function on NodeIndex
* smt: add pairs_to_leaf() to trait
* smt: add sorted_pairs_to_leaves() and test for it
* smt: implement single subtree-8 hashing, w/ benchmarks & tests

This will be composed into depth-8-subtree-based computation of entire
sparse Merkle trees.

* merkle: add a benchmark for constructing 256-balanced trees

This is intended for comparison with the benchmarks from the previous
commit. This benchmark represents the theoretical perfect-efficiency
performance we could possibly (but impractically) get for computing
depth-8 sparse Merkle subtrees.

* smt: test that SparseMerkleTree::build_subtree() is composable

* smt: test that subtree logic can correctly construct an entire tree

This commit ensures that `SparseMerkleTree::build_subtree()` can
correctly compose into building an entire sparse Merkle tree, without
yet getting into potential complications concurrency introduces.

* smt: implement test for basic parallelized subtree computation w/ rayon

Building on the previous commit, this commit implements a test proving
that `SparseMerkleTree::build_subtree()` can be composed into itself not
just concurrently, but in parallel, without issue.

* smt: add from_raw_parts() to trait interface

This commit adds a new required method to the SparseMerkleTree trait,
to allow generic construction from pre-computed parts.

This will be used to add a generic version of `with_entries()` in a
later commit.

* smt: add parallel constructors to Smt and SimpleSmt

What the previous few commits have been leading up to: SparseMerkleTree
now has a function to construct the tree from existing data in parallel.
This is significantly faster than the single-threaded equivalent.
Benchmarks incoming!
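
The pattern these commits converge on, computing each depth-8 subtree independently and feeding the resulting roots into the next level up, parallelizes cleanly because sibling hashes within a level never depend on each other. A minimal sketch of that level-by-level idea (illustrative only: it uses rayon and `Rpo256::merge` directly, not the PR's internal `build_subtree()`, whose exact signature is not shown in this excerpt):

```rust
use miden_crypto::hash::rpo::{Rpo256, RpoDigest};
use rayon::prelude::*;
use winter_crypto::Hasher; // brings `hash` and `merge` into scope

/// Reduce one tree level in parallel: hash each sibling pair into its parent.
/// Assumes `level.len()` is even (true for a fully balanced tree).
fn parent_level(level: &[RpoDigest]) -> Vec<RpoDigest> {
    level
        .par_chunks(2)
        .map(|pair| Rpo256::merge(&[pair[0], pair[1]]))
        .collect()
}

fn main() {
    // 256 leaves -> a depth-8 tree, folded one level at a time.
    let mut level: Vec<RpoDigest> =
        (0..256u64).map(|i| Rpo256::hash(&i.to_le_bytes())).collect();
    while level.len() > 1 {
        level = parent_level(&level);
    }
    println!("root: {:?}", level[0]);
}
```

Each `parent_level` call is embarrassingly parallel; the only sequential dependency is between levels, which is why depth-8 subtrees make a convenient unit of work.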

---------

Co-authored-by: krushimir <krushimir@reilabs.co>
Co-authored-by: krushimir <kresimir.grofelnik@reilabs.io>
2024-12-04 10:54:41 -08:00
33 changed files with 2283 additions and 517 deletions

CHANGELOG.md

@@ -1,17 +1,16 @@
+## 0.13.1 (2024-12-26)
+- Generate reverse mutations set on applying of mutations set, implemented serialization of `MutationsSet` (#355).
 ## 0.13.0 (2024-11-24)
 - Fixed a bug in the implementation of `draw_integers` for `RpoRandomCoin` (#343).
 - [BREAKING] Refactor error messages and use `thiserror` to derive errors (#344).
 - [BREAKING] Updated Winterfell dependency to v0.11 (#346).
+- Added RPO-STARK based DSA (#349).
+- Added benchmarks for DSA implementations (#354).
+- Implemented deterministic RPO-STARK based DSA (#358).
 ## 0.12.0 (2024-10-30)
 - [BREAKING] Updated Winterfell dependency to v0.10 (#338).
+- Added parallel implementation of `Smt::with_entries()` with significantly better performance when the `concurrent` feature is enabled (#341).
 ## 0.11.0 (2024-10-17)

Cargo.lock

@@ -92,18 +92,18 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
 [[package]]
 name = "bit-set"
-version = "0.8.0"
+version = "0.5.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3"
+checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
 dependencies = [
  "bit-vec",
 ]

 [[package]]
 name = "bit-vec"
-version = "0.8.0"
+version = "0.6.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7"
+checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"

 [[package]]
 name = "bitflags"

@@ -153,9 +153,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
 [[package]]
 name = "cc"
-version = "1.2.5"
+version = "1.2.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c31a0499c1dc64f458ad13872de75c0eb7e3fdb0e67964610c914b034fc5956e"
+checksum = "27f657647bcff5394bf56c7317665bbf790a137a50eaaa5c6bfbb9e27a518f2d"
 dependencies = [
  "jobserver",
  "libc",

@@ -294,9 +294,9 @@ dependencies = [
 [[package]]
 name = "crossbeam-deque"
-version = "0.8.6"
+version = "0.8.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
+checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
 dependencies = [
  "crossbeam-epoch",
  "crossbeam-utils",

@@ -313,9 +313,9 @@ dependencies = [
 [[package]]
 name = "crossbeam-utils"
-version = "0.8.21"
+version = "0.8.20"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
+checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80"

 [[package]]
 name = "crunchy"

@@ -496,9 +496,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
 [[package]]
 name = "libc"
-version = "0.2.169"
+version = "0.2.168"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
+checksum = "5aaeb2981e0606ca11d79718f8bb01164f1d6ed75080182d3abf017e6d244b6d"

 [[package]]
 name = "libm"

@@ -526,7 +526,7 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
 [[package]]
 name = "miden-crypto"
-version = "0.13.1"
+version = "0.14.0"
 dependencies = [
  "assert_matches",
  "blake3",

@@ -542,14 +542,18 @@ dependencies = [
  "rand",
  "rand_chacha",
  "rand_core",
+ "rayon",
  "seq-macro",
  "serde",
  "sha3",
  "thiserror",
+ "winter-air",
  "winter-crypto",
  "winter-math",
+ "winter-prover",
  "winter-rand-utils",
  "winter-utils",
+ "winter-verifier",
 ]

@@ -638,6 +642,12 @@ version = "11.1.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9"

+[[package]]
+name = "pin-project-lite"
+version = "0.2.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff"
+
 [[package]]
 name = "plotters"
 version = "0.3.7"

@@ -686,9 +696,9 @@ dependencies = [
 [[package]]
 name = "proptest"
-version = "1.6.0"
+version = "1.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "14cae93065090804185d3b75f0bf93b8eeda30c7a9b4a33d3bdb3988d6229e50"
+checksum = "b4c2511913b88df1637da85cc8d96ec8e43a3f8bb8ccb71ee1ac240d6f3df58d"
 dependencies = [
  "bit-set",
  "bit-vec",

@@ -712,9 +722,9 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
 [[package]]
 name = "quote"
-version = "1.0.38"
+version = "1.0.37"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc"
+checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af"
 dependencies = [
  "proc-macro2",
 ]

@@ -875,9 +885,9 @@ dependencies = [
 [[package]]
 name = "serde_json"
-version = "1.0.134"
+version = "1.0.133"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d00f4175c42ee48b15416f6193a959ba3a0d67fc699a0db9ad12df9f83991c7d"
+checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377"
 dependencies = [
  "itoa",
  "memchr",

@@ -909,9 +919,9 @@ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
 [[package]]
 name = "syn"
-version = "2.0.92"
+version = "2.0.90"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "70ae51629bf965c5c098cc9e87908a3df5301051a9e087d6f9bef5c9771ed126"
+checksum = "919d3b74a5dd0ccd15aeb8f93e7006bd9e14c295087c9896a110f490752bcf31"
 dependencies = [
  "proc-macro2",
  "quote",

@@ -933,18 +943,18 @@ dependencies = [
 [[package]]
 name = "thiserror"
-version = "2.0.9"
+version = "2.0.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f072643fd0190df67a8bab670c20ef5d8737177d6ac6b2e9a236cb096206b2cc"
+checksum = "8fec2a1820ebd077e2b90c4df007bebf344cd394098a13c563957d0afc83ea47"
 dependencies = [
  "thiserror-impl",
 ]

 [[package]]
 name = "thiserror-impl"
-version = "2.0.9"
+version = "2.0.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4"
+checksum = "d65750cab40f4ff1929fb1ba509e9914eb756131cef4210da8d5d700d26f6312"
 dependencies = [
  "proc-macro2",
  "quote",

@@ -961,6 +971,34 @@ dependencies = [
  "serde_json",
 ]

+[[package]]
+name = "tracing"
+version = "0.1.41"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0"
+dependencies = [
+ "pin-project-lite",
+ "tracing-attributes",
+ "tracing-core",
+]
+
+[[package]]
+name = "tracing-attributes"
+version = "0.1.28"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "tracing-core"
+version = "0.1.33"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c"
+
 [[package]]
 name = "typenum"
 version = "1.17.0"

@@ -1171,33 +1209,82 @@ version = "0.52.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"

+[[package]]
+name = "winter-air"
+version = "0.11.0"
+source = "git+https://github.com/Al-Kindi-0/winterfell?branch=al-zk#5bafedbc2ba00cf85c6182725754547f6cddafc3"
+dependencies = [
+ "libm",
+ "winter-crypto",
+ "winter-fri",
+ "winter-math",
+ "winter-utils",
+]
+
 [[package]]
 name = "winter-crypto"
 version = "0.11.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "67c57748fd2da77742be601f03eda639ff6046879738fd1faae86e80018263cb"
+source = "git+https://github.com/Al-Kindi-0/winterfell?branch=al-zk#5bafedbc2ba00cf85c6182725754547f6cddafc3"
 dependencies = [
  "blake3",
+ "rand",
+ "rand_chacha",
  "sha3",
  "winter-math",
  "winter-utils",
 ]

+[[package]]
+name = "winter-fri"
+version = "0.11.0"
+source = "git+https://github.com/Al-Kindi-0/winterfell?branch=al-zk#5bafedbc2ba00cf85c6182725754547f6cddafc3"
+dependencies = [
+ "rand",
+ "rand_chacha",
+ "winter-crypto",
+ "winter-math",
+ "winter-utils",
+]
+
 [[package]]
 name = "winter-math"
 version = "0.11.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6020c17839fa107ce4a7cc178e407ebbc24adfac1980f4fa2111198e052700ab"
+source = "git+https://github.com/Al-Kindi-0/winterfell?branch=al-zk#5bafedbc2ba00cf85c6182725754547f6cddafc3"
 dependencies = [
  "serde",
  "winter-utils",
 ]

+[[package]]
+name = "winter-maybe-async"
+version = "0.11.0"
+source = "git+https://github.com/Al-Kindi-0/winterfell?branch=al-zk#5bafedbc2ba00cf85c6182725754547f6cddafc3"
+dependencies = [
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "winter-prover"
+version = "0.11.0"
+source = "git+https://github.com/Al-Kindi-0/winterfell?branch=al-zk#5bafedbc2ba00cf85c6182725754547f6cddafc3"
+dependencies = [
+ "rand",
+ "rand_chacha",
+ "tracing",
+ "winter-air",
+ "winter-crypto",
+ "winter-fri",
+ "winter-math",
+ "winter-maybe-async",
+ "winter-rand-utils",
+ "winter-utils",
+]
+
 [[package]]
 name = "winter-rand-utils"
 version = "0.11.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "226e4c455f6eb72f64ac6eeb7642df25e21ff2280a4f6b09db75392ad6b390ef"
+source = "git+https://github.com/Al-Kindi-0/winterfell?branch=al-zk#5bafedbc2ba00cf85c6182725754547f6cddafc3"
 dependencies = [
  "rand",
  "winter-utils",

@@ -1206,8 +1293,19 @@ dependencies = [
 [[package]]
 name = "winter-utils"
 version = "0.11.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1507ef312ea5569d54c2c7446a18b82143eb2a2e21f5c3ec7cfbe8200c03bd7c"
+source = "git+https://github.com/Al-Kindi-0/winterfell?branch=al-zk#5bafedbc2ba00cf85c6182725754547f6cddafc3"
+
+[[package]]
+name = "winter-verifier"
+version = "0.11.0"
+source = "git+https://github.com/Al-Kindi-0/winterfell?branch=al-zk#5bafedbc2ba00cf85c6182725754547f6cddafc3"
+dependencies = [
+ "winter-air",
+ "winter-crypto",
+ "winter-fri",
+ "winter-math",
+ "winter-utils",
+]
+
 [[package]]
 name = "zerocopy"

Cargo.toml

@@ -1,12 +1,12 @@
 [package]
 name = "miden-crypto"
-version = "0.13.1"
+version = "0.14.0"
 description = "Miden Cryptographic primitives"
 authors = ["miden contributors"]
 readme = "README.md"
 license = "MIT"
 repository = "https://github.com/0xPolygonMiden/crypto"
-documentation = "https://docs.rs/miden-crypto/0.13.1"
+documentation = "https://docs.rs/miden-crypto/0.14.0"
 categories = ["cryptography", "no-std"]
 keywords = ["miden", "crypto", "hash", "merkle"]
 edition = "2021"

@@ -19,6 +19,10 @@ bench = false
 doctest = false
 required-features = ["executable"]

+[[bench]]
+name = "dsa"
+harness = false
+
 [[bench]]
 name = "hash"
 harness = false

@@ -27,13 +31,28 @@ harness = false
 name = "smt"
 harness = false

+[[bench]]
+name = "smt-subtree"
+harness = false
+required-features = ["internal"]
+
+[[bench]]
+name = "merkle"
+harness = false
+
+[[bench]]
+name = "smt-with-entries"
+harness = false
+
 [[bench]]
 name = "store"
 harness = false

 [features]
-default = ["std"]
+concurrent = ["dep:rayon"]
+default = ["std", "concurrent"]
 executable = ["dep:clap", "dep:rand-utils", "std"]
+internal = []
 serde = ["dep:serde", "serde?/alloc", "winter-math/serde"]
 std = [
  "blake3/std",

@@ -48,26 +67,30 @@
 [dependencies]
 blake3 = { version = "1.5", default-features = false }
 clap = { version = "4.5", optional = true, features = ["derive"] }
+getrandom = { version = "0.2", features = ["js"] }
 num = { version = "0.4", default-features = false, features = ["alloc", "libm"] }
 num-complex = { version = "0.4", default-features = false }
 rand = { version = "0.8", default-features = false }
+rand_chacha = { version = "0.3", default-features = false }
 rand_core = { version = "0.6", default-features = false }
-rand-utils = { version = "0.11", package = "winter-rand-utils", optional = true }
+rand-utils = { git = 'https://github.com/Al-Kindi-0/winterfell', package = "winter-rand-utils", branch = 'al-zk', optional = true }
+rayon = { version = "1.10", optional = true }
 serde = { version = "1.0", default-features = false, optional = true, features = ["derive"] }
 sha3 = { version = "0.10", default-features = false }
 thiserror = { version = "2.0", default-features = false }
-winter-crypto = { version = "0.11", default-features = false }
-winter-math = { version = "0.11", default-features = false }
-winter-utils = { version = "0.11", default-features = false }
+winter-air = { git = 'https://github.com/Al-Kindi-0/winterfell', branch = 'al-zk' }
+winter-crypto = { git = 'https://github.com/Al-Kindi-0/winterfell', branch = 'al-zk' }
+winter-prover = { git = 'https://github.com/Al-Kindi-0/winterfell', branch = 'al-zk' }
+winter-verifier = { git = 'https://github.com/Al-Kindi-0/winterfell', branch = 'al-zk' }
+winter-math = { git = 'https://github.com/Al-Kindi-0/winterfell', branch = 'al-zk' }
+winter-utils = { git = 'https://github.com/Al-Kindi-0/winterfell', branch = 'al-zk' }

 [dev-dependencies]
 assert_matches = { version = "1.5", default-features = false }
 criterion = { version = "0.5", features = ["html_reports"] }
-getrandom = { version = "0.2", features = ["js"] }
 hex = { version = "0.4", default-features = false, features = ["alloc"] }
-proptest = "1.6"
-rand_chacha = { version = "0.3", default-features = false }
-rand-utils = { version = "0.11", package = "winter-rand-utils" }
+proptest = "1.5"
+rand-utils = { git = 'https://github.com/Al-Kindi-0/winterfell', package = "winter-rand-utils", branch = 'al-zk' }
 seq-macro = { version = "0.3" }

 [build-dependencies]

Makefile

@@ -81,10 +81,6 @@ build-sve: ## Build with sve support
 # --- benchmarking --------------------------------------------------------------------------------
 
-.PHONY: bench
-bench: ## Run crypto benchmarks
-	cargo bench
+.PHONY: bench-tx
+bench-tx: ## Run crypto benchmarks
+	cargo bench --features="concurrent"
+
+.PHONY: bench-smt-concurrent
+bench-smt-concurrent: ## Run SMT benchmarks with concurrent feature
+	cargo run --release --features executable -- --size 1000000

README.md

@@ -60,10 +60,11 @@ make
 This crate can be compiled with the following features:
 
+- `concurrent` - enabled by default; enables multi-threaded implementation of `Smt::with_entries()` which significantly improves performance on multi-core CPUs.
 - `std` - enabled by default and relies on the Rust standard library.
 - `no_std` does not rely on the Rust standard library and enables compilation to WebAssembly.
 
-Both of these features imply the use of [alloc](https://doc.rust-lang.org/alloc/) to support heap-allocated collections.
+All of these features imply the use of [alloc](https://doc.rust-lang.org/alloc/) to support heap-allocated collections.
 
 To compile with `no_std`, disable default features via `--no-default-features` flag or using the following command:
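
As a usage sketch (based on the `Smt` API exercised in `benches/smt-with-entries.rs` below), the parallelism is transparent to callers: with `concurrent` enabled, `Smt::with_entries()` simply takes the multi-threaded path.

```rust
use miden_crypto::{hash::rpo::RpoDigest, merkle::Smt, Felt, Word, ONE};

fn main() {
    // 1024 key-value pairs; the last key element drives the leaf index, so
    // distinct values of `i` land in distinct leaves.
    let entries: Vec<(RpoDigest, Word)> = (0..1024u64)
        .map(|i| {
            let key = RpoDigest::new([ONE, ONE, Felt::new(i), Felt::new(i)]);
            let value: Word = [Felt::new(i), ONE, ONE, Felt::new(i)];
            (key, value)
        })
        .collect();

    // With the `concurrent` feature (now a default), construction runs in parallel.
    let tree = Smt::with_entries(entries).expect("keys are unique");
    println!("root: {:?}", tree.root());
}
```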


@@ -1,4 +1,6 @@
-# Miden VM Hash Functions
+# Benchmarks
+
+## Miden VM Hash Functions
 
 In the Miden VM, we make use of different hash functions. Some of these are "traditional" hash functions, like `BLAKE3`, which are optimized for out-of-STARK performance, while others are algebraic hash functions, like `Rescue Prime`, and are more optimized for a better performance inside the STARK. In what follows, we benchmark several such hash functions and compare against other constructions that are used by other proving systems. More precisely, we benchmark:
 * **BLAKE3** as specified [here](https://github.com/BLAKE3-team/BLAKE3-specs/blob/master/blake3.pdf) and implemented [here](https://github.com/BLAKE3-team/BLAKE3) (with a wrapper exposed via this crate).

@@ -8,13 +10,13 @@ In the Miden VM, we make use of different hash functions. Some of these are "tra
 * **Rescue Prime Optimized (RPO)** as specified [here](https://eprint.iacr.org/2022/1577) and implemented in this crate.
 * **Rescue Prime Extended (RPX)** a variant of the [xHash](https://eprint.iacr.org/2023/1045) hash function as implemented in this crate.
 
-## Comparison and Instructions
+### Comparison and Instructions
 
-### Comparison
+#### Comparison
 
 We benchmark the above hash functions using two scenarios. The first is a 2-to-1 $(a,b)\mapsto h(a,b)$ hashing where both $a$, $b$ and $h(a,b)$ are the digests corresponding to each of the hash functions.
 The second scenario is that of sequential hashing where we take a sequence of length $100$ field elements and hash these to produce a single digest. The digests are $4$ field elements in a prime field with modulus $2^{64} - 2^{32} + 1$ (i.e., 32 bytes) for Poseidon, Rescue Prime and RPO, and an array `[u8; 32]` for SHA3 and BLAKE3.
 
-#### Scenario 1: 2-to-1 hashing `h(a,b)`
+##### Scenario 1: 2-to-1 hashing `h(a,b)`
 
 | Function            | BLAKE3 | SHA3    | Poseidon  | Rp64_256  | RPO_256 | RPX_256 |
 | ------------------- | ------ | ------- | --------- | --------- | ------- | ------- |

@@ -26,7 +28,7 @@ The second scenario is that of sequential hashing where we take a sequence of le
 | Intel Core i5-8279U | 68 ns  | 536 ns  | 2.0 µs    | 13.6 µs   | 8.5 µs  | 4.4 µs  |
 | Intel Xeon 8375C    | 67 ns  |         |           |           | 8.2 µs  |         |
 
-#### Scenario 2: Sequential hashing of 100 elements `h([a_0,...,a_99])`
+##### Scenario 2: Sequential hashing of 100 elements `h([a_0,...,a_99])`
 
 | Function            | BLAKE3 | SHA3    | Poseidon  | Rp64_256  | RPO_256 | RPX_256 |
 | ------------------- | -------| ------- | --------- | --------- | ------- | ------- |

@@ -42,7 +44,7 @@ Notes:
 - On Graviton 3, RPO256 and RPX256 are run with SVE acceleration enabled.
 - On AMD EPYC 9R14, RPO256 and RPX256 are run with AVX2 acceleration enabled.
 
-### Instructions
+#### Instructions
 
 Before you can run the benchmarks, you'll need to make sure you have Rust [installed](https://www.rust-lang.org/tools/install). After that, to run the benchmarks for RPO and BLAKE3, clone the current repository, and from the root directory of the repo run the following:

@@ -54,3 +56,47 @@ To run the benchmarks for Rescue Prime, Poseidon and SHA3, clone the following [
 ```
 cargo bench hash
 ```
## Miden VM DSA
We make use of the following digital signature algorithms (DSA) in the Miden VM:
* **RPO-Falcon512** as specified [here](https://falcon-sign.info/falcon.pdf) with the one difference being the use of the RPO hash function for the hash-to-point algorithm (Algorithm 3 in the previous reference) instead of SHAKE256.
* **RPO-STARK** as specified [here](https://eprint.iacr.org/2024/1553), where the parameters are the ones for the unique-decoding regime (UDR), with the following two differences:
* We rely on Conjecture 1 in the [ethSTARK](https://eprint.iacr.org/2021/582) paper.
* The number of FRI queries is $30$ and the grinding factor is $12$ bits. Thus, using the previous point, we can argue that the modified version achieves at least $102$ bits of average-case existential unforgeability security against $2^{113}$-query bound adversaries that can obtain up to $2^{64}$ signatures under the same public key.
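
As a rough sanity check on that figure (reading the ethSTARK conjecture as approximately the FRI query count times $\log_2$ of the blowup factor, plus the grinding bits, with blowup $8$ as set in `PROOF_OPTIONS`):

$$30 \cdot \log_2(8) + 12 = 90 + 12 = 102 \text{ bits}.$$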
### Comparison and Instructions
#### Comparison
##### Key Generation
| DSA | RPO-Falcon512 | RPO-STARK |
| ------------------- | :-----------: | :-------: |
| Apple M1 Pro | 590 ms | 6 µs |
| Intel Core i5-8279U | 585 ms | 10 µs |
##### Signature Generation
| DSA | RPO-Falcon512 | RPO-STARK |
| ------------------- | :-----------: | :-------: |
| Apple M1 Pro | 1.5 ms | 78 ms |
| Intel Core i5-8279U | 1.8 ms | 130 ms |
##### Signature Verification
| DSA | RPO-Falcon512 | RPO-STARK |
| ------------------- | :-----------: | :-------: |
| Apple M1 Pro | 0.7 ms | 4.5 ms |
| Intel Core i5-8279U | 1.2 ms | 7.9 ms |
#### Instructions
Before you can run the benchmarks, you'll need to make sure you have Rust [installed](https://www.rust-lang.org/tools/install). After that, to run the benchmarks, clone the current repository, and from the root directory of the repo run the following:
```
cargo bench --bench dsa
```
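
For orientation, the RPO-STARK flow benchmarked above reduces to a handful of calls (a sketch against the `dsa::rpo_stark` API introduced in this changeset; `SecretKey::random()` needs the `std` feature for OS randomness):

```rust
use miden_crypto::{dsa::rpo_stark::SecretKey, Felt, Word, ONE};

fn main() {
    // Key generation: the public key is the RPO hash of the secret key.
    let sk = SecretKey::random();
    let pk = sk.public_key();

    // The signature is a (zero-knowledge) STARK proof of knowledge of the
    // secret key; the message enters through the Fiat-Shamir transcript.
    let message: Word = [ONE, ONE, Felt::new(2), Felt::new(3)];
    let signature = sk.sign(message);
    assert!(pk.verify(message, &signature));
}
```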

benches/dsa.rs

@@ -0,0 +1,88 @@
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use miden_crypto::dsa::{
rpo_falcon512::SecretKey as FalconSecretKey, rpo_stark::SecretKey as RpoStarkSecretKey,
};
use rand_utils::rand_array;
fn key_gen_falcon(c: &mut Criterion) {
c.bench_function("Falcon public key generation", |bench| {
bench.iter_batched(|| FalconSecretKey::new(), |sk| sk.public_key(), BatchSize::SmallInput)
});
c.bench_function("Falcon secret key generation", |bench| {
bench.iter_batched(|| {}, |_| FalconSecretKey::new(), BatchSize::SmallInput)
});
}
fn key_gen_rpo_stark(c: &mut Criterion) {
c.bench_function("RPO-STARK public key generation", |bench| {
bench.iter_batched(
|| RpoStarkSecretKey::random(),
|sk| sk.public_key(),
BatchSize::SmallInput,
)
});
c.bench_function("RPO-STARK secret key generation", |bench| {
bench.iter_batched(|| {}, |_| RpoStarkSecretKey::random(), BatchSize::SmallInput)
});
}
fn signature_gen_falcon(c: &mut Criterion) {
c.bench_function("Falcon signature generation", |bench| {
bench.iter_batched(
|| (FalconSecretKey::new(), rand_array().into()),
|(sk, msg)| sk.sign(msg),
BatchSize::SmallInput,
)
});
}
fn signature_gen_rpo_stark(c: &mut Criterion) {
c.bench_function("RPO-STARK signature generation", |bench| {
bench.iter_batched(
|| (RpoStarkSecretKey::random(), rand_array().into()),
|(sk, msg)| sk.sign(msg),
BatchSize::SmallInput,
)
});
}
fn signature_ver_falcon(c: &mut Criterion) {
c.bench_function("Falcon signature verification", |bench| {
bench.iter_batched(
|| {
let sk = FalconSecretKey::new();
let msg = rand_array().into();
(sk.public_key(), msg, sk.sign(msg))
},
|(pk, msg, sig)| pk.verify(msg, &sig),
BatchSize::SmallInput,
)
});
}
fn signature_ver_rpo_stark(c: &mut Criterion) {
c.bench_function("RPO-STARK signature verification", |bench| {
bench.iter_batched(
|| {
let sk = RpoStarkSecretKey::random();
let msg = rand_array().into();
(sk.public_key(), msg, sk.sign(msg))
},
|(pk, msg, sig)| pk.verify(msg, &sig),
BatchSize::SmallInput,
)
});
}
criterion_group!(
dsa_group,
key_gen_falcon,
key_gen_rpo_stark,
signature_gen_falcon,
signature_gen_rpo_stark,
signature_ver_falcon,
signature_ver_rpo_stark
);
criterion_main!(dsa_group);

benches/merkle.rs

@@ -0,0 +1,66 @@
//! Benchmark for building a [`miden_crypto::merkle::MerkleTree`]. This is intended to be compared
//! with the results from `benches/smt-subtree.rs`, as building a fully balanced Merkle tree with
//! 256 leaves should indicate the *absolute best* performance we could *possibly* get for building
//! a depth-8 sparse Merkle subtree, though practically speaking building a fully balanced Merkle
//! tree will perform better than the sparse version. At the time of this writing (2024/11/24), this
//! benchmark is about four times more efficient than the equivalent benchmark in
//! `benches/smt-subtree.rs`.
use std::{hint, mem, time::Duration};
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use miden_crypto::{merkle::MerkleTree, Felt, Word, ONE};
use rand_utils::prng_array;
fn balanced_merkle_even(c: &mut Criterion) {
c.bench_function("balanced-merkle-even", |b| {
b.iter_batched(
|| {
let entries: Vec<Word> =
(0..256).map(|i| [Felt::new(i), ONE, ONE, Felt::new(i)]).collect();
assert_eq!(entries.len(), 256);
entries
},
|leaves| {
let tree = MerkleTree::new(hint::black_box(leaves)).unwrap();
assert_eq!(tree.depth(), 8);
},
BatchSize::SmallInput,
);
});
}
fn balanced_merkle_rand(c: &mut Criterion) {
let mut seed = [0u8; 32];
c.bench_function("balanced-merkle-rand", |b| {
b.iter_batched(
|| {
let entries: Vec<Word> = (0..256).map(|_| generate_word(&mut seed)).collect();
assert_eq!(entries.len(), 256);
entries
},
|leaves| {
let tree = MerkleTree::new(hint::black_box(leaves)).unwrap();
assert_eq!(tree.depth(), 8);
},
BatchSize::SmallInput,
);
});
}
criterion_group! {
name = smt_subtree_group;
config = Criterion::default()
.measurement_time(Duration::from_secs(20))
.configure_from_args();
targets = balanced_merkle_even, balanced_merkle_rand
}
criterion_main!(smt_subtree_group);
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------
fn generate_word(seed: &mut [u8; 32]) -> Word {
mem::swap(seed, &mut prng_array(*seed));
let nums: [u64; 4] = prng_array(*seed);
[Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])]
}

benches/smt-subtree.rs

@@ -0,0 +1,142 @@
use std::{fmt::Debug, hint, mem, time::Duration};
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use miden_crypto::{
hash::rpo::RpoDigest,
merkle::{build_subtree_for_bench, NodeIndex, SmtLeaf, SubtreeLeaf, SMT_DEPTH},
Felt, Word, ONE,
};
use rand_utils::prng_array;
use winter_utils::Randomizable;
const PAIR_COUNTS: [u64; 5] = [1, 64, 128, 192, 256];
fn smt_subtree_even(c: &mut Criterion) {
let mut seed = [0u8; 32];
let mut group = c.benchmark_group("subtree8-even");
for pair_count in PAIR_COUNTS {
let bench_id = BenchmarkId::from_parameter(pair_count);
group.bench_with_input(bench_id, &pair_count, |b, &pair_count| {
b.iter_batched(
|| {
// Setup.
let entries: Vec<(RpoDigest, Word)> = (0..pair_count)
.map(|n| {
// A single depth-8 subtree can have a maximum of 255 leaves.
let leaf_index = ((n as f64 / pair_count as f64) * 255.0) as u64;
let key = RpoDigest::new([
generate_value(&mut seed),
ONE,
Felt::new(n),
Felt::new(leaf_index),
]);
let value = generate_word(&mut seed);
(key, value)
})
.collect();
let mut leaves: Vec<_> = entries
.iter()
.map(|(key, value)| {
let leaf = SmtLeaf::new_single(*key, *value);
let col = NodeIndex::from(leaf.index()).value();
let hash = leaf.hash();
SubtreeLeaf { col, hash }
})
.collect();
leaves.sort();
leaves.dedup_by_key(|leaf| leaf.col);
leaves
},
|leaves| {
// Benchmarked function.
let (subtree, _) = build_subtree_for_bench(
hint::black_box(leaves),
hint::black_box(SMT_DEPTH),
hint::black_box(SMT_DEPTH),
);
assert!(!subtree.is_empty());
},
BatchSize::SmallInput,
);
});
}
}
fn smt_subtree_random(c: &mut Criterion) {
let mut seed = [0u8; 32];
let mut group = c.benchmark_group("subtree8-rand");
for pair_count in PAIR_COUNTS {
let bench_id = BenchmarkId::from_parameter(pair_count);
group.bench_with_input(bench_id, &pair_count, |b, &pair_count| {
b.iter_batched(
|| {
// Setup.
let entries: Vec<(RpoDigest, Word)> = (0..pair_count)
.map(|i| {
let leaf_index: u8 = generate_value(&mut seed);
let key = RpoDigest::new([
ONE,
ONE,
Felt::new(i),
Felt::new(leaf_index as u64),
]);
let value = generate_word(&mut seed);
(key, value)
})
.collect();
let mut leaves: Vec<_> = entries
.iter()
.map(|(key, value)| {
let leaf = SmtLeaf::new_single(*key, *value);
let col = NodeIndex::from(leaf.index()).value();
let hash = leaf.hash();
SubtreeLeaf { col, hash }
})
.collect();
leaves.sort();
leaves
},
|leaves| {
let (subtree, _) = build_subtree_for_bench(
hint::black_box(leaves),
hint::black_box(SMT_DEPTH),
hint::black_box(SMT_DEPTH),
);
assert!(!subtree.is_empty());
},
BatchSize::SmallInput,
);
});
}
}
criterion_group! {
name = smt_subtree_group;
config = Criterion::default()
.measurement_time(Duration::from_secs(40))
.sample_size(60)
.configure_from_args();
targets = smt_subtree_even, smt_subtree_random
}
criterion_main!(smt_subtree_group);
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------
fn generate_value<T: Copy + Debug + Randomizable>(seed: &mut [u8; 32]) -> T {
mem::swap(seed, &mut prng_array(*seed));
let value: [T; 1] = rand_utils::prng_array(*seed);
value[0]
}
fn generate_word(seed: &mut [u8; 32]) -> Word {
mem::swap(seed, &mut prng_array(*seed));
let nums: [u64; 4] = prng_array(*seed);
[Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])]
}

benches/smt-with-entries.rs

@@ -0,0 +1,71 @@
use std::{fmt::Debug, hint, mem, time::Duration};
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use miden_crypto::{hash::rpo::RpoDigest, merkle::Smt, Felt, Word, ONE};
use rand_utils::prng_array;
use winter_utils::Randomizable;
// 2^0, 2^4, 2^8, 2^12, 2^16, 2^20
const PAIR_COUNTS: [u64; 6] = [1, 16, 256, 4096, 65536, 1_048_576];
fn smt_with_entries(c: &mut Criterion) {
let mut seed = [0u8; 32];
let mut group = c.benchmark_group("smt-with-entries");
for pair_count in PAIR_COUNTS {
let bench_id = BenchmarkId::from_parameter(pair_count);
group.bench_with_input(bench_id, &pair_count, |b, &pair_count| {
b.iter_batched(
|| {
// Setup.
prepare_entries(pair_count, &mut seed)
},
|entries| {
// Benchmarked function.
Smt::with_entries(hint::black_box(entries)).unwrap();
},
BatchSize::SmallInput,
);
});
}
}
criterion_group! {
name = smt_with_entries_group;
config = Criterion::default()
//.measurement_time(Duration::from_secs(960))
.measurement_time(Duration::from_secs(60))
.sample_size(10)
.configure_from_args();
targets = smt_with_entries
}
criterion_main!(smt_with_entries_group);
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------
fn prepare_entries(pair_count: u64, seed: &mut [u8; 32]) -> Vec<(RpoDigest, [Felt; 4])> {
let entries: Vec<(RpoDigest, Word)> = (0..pair_count)
.map(|i| {
let count = pair_count as f64;
let idx = ((i as f64 / count) * (count)) as u64;
let key = RpoDigest::new([generate_value(seed), ONE, Felt::new(i), Felt::new(idx)]);
let value = generate_word(seed);
(key, value)
})
.collect();
entries
}
fn generate_value<T: Copy + Debug + Randomizable>(seed: &mut [u8; 32]) -> T {
mem::swap(seed, &mut prng_array(*seed));
let value: [T; 1] = rand_utils::prng_array(*seed);
value[0]
}
fn generate_word(seed: &mut [u8; 32]) -> Word {
mem::swap(seed, &mut prng_array(*seed));
let nums: [u64; 4] = prng_array(*seed);
[Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])]
}

scripts/check-changelog.sh

@@ -13,7 +13,7 @@ else
 	if git diff --exit-code "origin/${BASE_REF}" -- "${CHANGELOG_FILE}"; then
 		>&2 echo "Changes should come with an entry in the \"CHANGELOG.md\" file. This behavior
 		can be overridden by using the \"no changelog\" label, which is used for changes
-		that are trivial / explicitly stated not to require a changelog entry."
+		that are trivial / explicitely stated not to require a changelog entry."
 		exit 1
 	fi

src/dsa/mod.rs

@@ -1,3 +1,5 @@
 //! Digital signature schemes supported by default in the Miden VM.
 
 pub mod rpo_falcon512;
+
+pub mod rpo_stark;

src/dsa/rpo_stark/mod.rs

@@ -0,0 +1,24 @@
mod signature;
pub use signature::{PublicKey, SecretKey, Signature};
mod stark;
pub use stark::{PublicInputs, RescueAir};
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use super::SecretKey;
use crate::Word;
#[test]
fn test_signature() {
let sk = SecretKey::new(Word::default());
let message = Word::default();
let signature = sk.sign(message);
let pk = sk.public_key();
assert!(pk.verify(message, &signature))
}
}

src/dsa/rpo_stark/signature.rs

@@ -0,0 +1,173 @@
use rand::{distributions::Uniform, prelude::Distribution, Rng};
use winter_air::{FieldExtension, ProofOptions};
use winter_math::{fields::f64::BaseElement, FieldElement};
use winter_prover::Proof;
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
use crate::{
dsa::rpo_stark::stark::RpoSignatureScheme,
hash::{rpo::Rpo256, DIGEST_SIZE},
StarkField, Word, ZERO,
};
// CONSTANTS
// ================================================================================================
/// Specifies the parameters of the STARK underlying the signature scheme. These parameters provide
/// at least 102 bits of security under the conjectured security of the toy protocol in
/// the ethSTARK paper [1].
///
/// [1]: https://eprint.iacr.org/2021/582
pub const PROOF_OPTIONS: ProofOptions =
ProofOptions::new(30, 8, 12, FieldExtension::Quadratic, 4, 7, true);
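// The argument order above is: number of FRI queries (30), blowup factor (8), grinding bits (12),
// field extension (quadratic), FRI folding factor (4), max degree of the FRI remainder polynomial
// (7), and a flag enabling zero-knowledge, matching the parameter list spelled out in the seed
// context string in `stark/mod.rs`.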
// PUBLIC KEY
// ================================================================================================
/// A public key for verifying signatures.
///
/// The public key is a [Word] (i.e., 4 field elements) that is the hash of the secret key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PublicKey(Word);
impl PublicKey {
/// Returns the [Word] defining the public key.
pub fn inner(&self) -> Word {
self.0
}
}
impl PublicKey {
/// Verifies the provided signature against the provided message and this public key.
pub fn verify(&self, message: Word, signature: &Signature) -> bool {
signature.verify(message, *self)
}
}
impl Serializable for PublicKey {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.0.write_into(target);
}
}
impl Deserializable for PublicKey {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let pk = <Word>::read_from(source)?;
Ok(Self(pk))
}
}
// SECRET KEY
// ================================================================================================
/// A secret key for generating signatures.
///
/// The secret key is a [Word] (i.e., 4 field elements).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SecretKey(Word);
impl SecretKey {
/// Generates a secret key from the provided [Word].
pub fn new(word: Word) -> Self {
Self(word)
}
/// Generates a secret key from OS-provided randomness.
#[cfg(feature = "std")]
pub fn random() -> Self {
use rand::{rngs::StdRng, SeedableRng};
let mut rng = StdRng::from_entropy();
Self::with_rng(&mut rng)
}
/// Generates a secret key using the provided random number generator `Rng`.
pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
let mut sk = [ZERO; 4];
let uni_dist = Uniform::from(0..BaseElement::MODULUS);
for s in sk.iter_mut() {
let sampled_integer = uni_dist.sample(rng);
*s = BaseElement::new(sampled_integer);
}
Self(sk)
}
/// Computes the public key corresponding to this secret key.
pub fn public_key(&self) -> PublicKey {
let mut elements = [BaseElement::ZERO; 8];
elements[..DIGEST_SIZE].copy_from_slice(&self.0);
let pk = Rpo256::hash_elements(&elements);
PublicKey(pk.into())
}
/// Signs a message with this secret key.
pub fn sign(&self, message: Word) -> Signature {
let signature: RpoSignatureScheme<Rpo256> = RpoSignatureScheme::new(PROOF_OPTIONS);
let proof = signature.sign(self.0, message);
Signature { proof }
}
}
impl Serializable for SecretKey {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.0.write_into(target);
}
}
impl Deserializable for SecretKey {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let sk = <Word>::read_from(source)?;
Ok(Self(sk))
}
}
// SIGNATURE
// ================================================================================================
/// An RPO STARK-based signature over a message.
///
/// The signature is a STARK proof of knowledge of a pre-image given an image where the map is
/// the RPO permutation, the pre-image is the secret key and the image is the public key.
/// The current implementation follows the description in [1] but relies on the conjectured security
/// of the toy protocol in the ethSTARK paper [2]. Using the parameter set given in `PROOF_OPTIONS`,
/// this yields a signature with $102$ bits of average-case existential unforgeability security
/// against $2^{113}$-query bound adversaries that can obtain up to $2^{64}$ signatures under the
/// same public key.
///
/// [1]: https://eprint.iacr.org/2024/1553
/// [2]: https://eprint.iacr.org/2021/582
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Signature {
proof: Proof,
}
impl Signature {
/// Returns the STARK proof constituting the signature.
pub fn inner(&self) -> Proof {
self.proof.clone()
}
/// Returns true if this signature is a valid signature for the specified message generated
/// against the secret key matching the specified public key.
pub fn verify(&self, message: Word, pk: PublicKey) -> bool {
let signature: RpoSignatureScheme<Rpo256> = RpoSignatureScheme::new(PROOF_OPTIONS);
let res = signature.verify(pk.inner(), message, self.proof.clone());
res.is_ok()
}
}
impl Serializable for Signature {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.proof.write_into(target);
}
}
impl Deserializable for Signature {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let proof = Proof::read_from(source)?;
Ok(Self { proof })
}
}

src/dsa/rpo_stark/stark/air.rs

@@ -0,0 +1,198 @@
use alloc::vec::Vec;
use winter_math::{fields::f64::BaseElement, FieldElement, ToElements};
use winter_prover::{
Air, AirContext, Assertion, EvaluationFrame, ProofOptions, TraceInfo,
TransitionConstraintDegree,
};
use crate::{
hash::{ARK1, ARK2, MDS, STATE_WIDTH},
Word, ZERO,
};
// CONSTANTS
// ================================================================================================
pub const HASH_CYCLE_LEN: usize = 8;
// AIR
// ================================================================================================
pub struct RescueAir {
context: AirContext<BaseElement>,
pub_key: Word,
}
impl Air for RescueAir {
type BaseField = BaseElement;
type PublicInputs = PublicInputs;
type GkrProof = ();
type GkrVerifier = ();
// CONSTRUCTOR
// --------------------------------------------------------------------------------------------
fn new(trace_info: TraceInfo, pub_inputs: PublicInputs, options: ProofOptions) -> Self {
let degrees = vec![
// Apply RPO rounds.
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
];
assert_eq!(STATE_WIDTH, trace_info.width());
let context = AirContext::new(trace_info, degrees, 12, options);
let context = context.set_num_transition_exemptions(1);
RescueAir { context, pub_key: pub_inputs.pub_key }
}
fn context(&self) -> &AirContext<Self::BaseField> {
&self.context
}
fn evaluate_transition<E: FieldElement + From<Self::BaseField>>(
&self,
frame: &EvaluationFrame<E>,
periodic_values: &[E],
result: &mut [E],
) {
let current = frame.current();
let next = frame.next();
// expected state width is 12 field elements
debug_assert_eq!(STATE_WIDTH, current.len());
debug_assert_eq!(STATE_WIDTH, next.len());
enforce_rpo_round(frame, result, periodic_values);
}
fn get_assertions(&self) -> Vec<Assertion<Self::BaseField>> {
let initial_step = 0;
let last_step = self.trace_length() - 1;
vec![
// Assert that the capacity as well as the second half of the rate portion of the state
// are initialized to `ZERO`. The first half of the rate is unconstrained as it will
// contain the secret key.
Assertion::single(0, initial_step, Self::BaseField::ZERO),
Assertion::single(1, initial_step, Self::BaseField::ZERO),
Assertion::single(2, initial_step, Self::BaseField::ZERO),
Assertion::single(3, initial_step, Self::BaseField::ZERO),
Assertion::single(8, initial_step, Self::BaseField::ZERO),
Assertion::single(9, initial_step, Self::BaseField::ZERO),
Assertion::single(10, initial_step, Self::BaseField::ZERO),
Assertion::single(11, initial_step, Self::BaseField::ZERO),
// Assert that the public key is the correct one
Assertion::single(4, last_step, self.pub_key[0]),
Assertion::single(5, last_step, self.pub_key[1]),
Assertion::single(6, last_step, self.pub_key[2]),
Assertion::single(7, last_step, self.pub_key[3]),
]
}
fn get_periodic_column_values(&self) -> Vec<Vec<Self::BaseField>> {
get_round_constants()
}
}
pub struct PublicInputs {
pub(crate) pub_key: Word,
pub(crate) msg: Word,
}
impl PublicInputs {
pub fn new(pub_key: Word, msg: Word) -> Self {
Self { pub_key, msg }
}
}
impl ToElements<BaseElement> for PublicInputs {
fn to_elements(&self) -> Vec<BaseElement> {
let mut res = self.pub_key.to_vec();
res.extend_from_slice(self.msg.as_ref());
res
}
}
// HELPER EVALUATORS
// ------------------------------------------------------------------------------------------------
/// Enforces constraints for a single round of the Rescue Prime Optimized hash functions.
pub fn enforce_rpo_round<E: FieldElement + From<BaseElement>>(
frame: &EvaluationFrame<E>,
result: &mut [E],
ark: &[E],
) {
// compute the state that should result from applying the first 5 operations of the RPO round to
// the current hash state.
let mut step1 = [E::ZERO; STATE_WIDTH];
step1.copy_from_slice(frame.current());
apply_mds(&mut step1);
// add constants
for i in 0..STATE_WIDTH {
step1[i] += ark[i];
}
apply_sbox(&mut step1);
apply_mds(&mut step1);
// add constants
for i in 0..STATE_WIDTH {
step1[i] += ark[STATE_WIDTH + i];
}
// compute the state that should result from applying the inverse of the last operation of the
// RPO round to the next step of the computation.
let mut step2 = [E::ZERO; STATE_WIDTH];
step2.copy_from_slice(frame.next());
apply_sbox(&mut step2);
// make sure that the results are equal.
for i in 0..STATE_WIDTH {
result[i] = step2[i] - step1[i]
}
}
#[inline(always)]
fn apply_sbox<E: FieldElement + From<BaseElement>>(state: &mut [E; STATE_WIDTH]) {
state.iter_mut().for_each(|v| {
let t2 = v.square();
let t4 = t2.square();
*v *= t2 * t4;
});
}
#[inline(always)]
fn apply_mds<E: FieldElement + From<BaseElement>>(state: &mut [E; STATE_WIDTH]) {
let mut result = [E::ZERO; STATE_WIDTH];
result.iter_mut().zip(MDS).for_each(|(r, mds_row)| {
state.iter().zip(mds_row).for_each(|(&s, m)| {
*r += E::from(m) * s;
});
});
*state = result
}
/// Returns RPO round constants arranged in column-major form.
pub fn get_round_constants() -> Vec<Vec<BaseElement>> {
let mut constants = Vec::new();
for _ in 0..(STATE_WIDTH * 2) {
constants.push(vec![ZERO; HASH_CYCLE_LEN]);
}
#[allow(clippy::needless_range_loop)]
for i in 0..HASH_CYCLE_LEN - 1 {
for j in 0..STATE_WIDTH {
constants[j][i] = ARK1[i][j];
constants[j + STATE_WIDTH][i] = ARK2[i][j];
}
}
constants
}

src/dsa/rpo_stark/stark/mod.rs

@@ -0,0 +1,98 @@
use alloc::vec::Vec;
use core::marker::PhantomData;
use prover::RpoSignatureProver;
use rand_chacha::ChaCha20Rng;
use winter_crypto::{ElementHasher, SaltedMerkleTree};
use winter_math::fields::f64::BaseElement;
use winter_prover::{Proof, ProofOptions, Prover};
use winter_utils::Serializable;
use winter_verifier::{verify, AcceptableOptions, VerifierError};
use crate::{
hash::{rpo::Rpo256, DIGEST_SIZE},
rand::RpoRandomCoin,
};
mod air;
pub use air::{PublicInputs, RescueAir};
mod prover;
/// Represents an abstract STARK-based signature scheme with knowledge of RPO pre-image as
/// the hard relation.
pub struct RpoSignatureScheme<H: ElementHasher> {
options: ProofOptions,
_h: PhantomData<H>,
}
impl<H: ElementHasher<BaseField = BaseElement> + Sync> RpoSignatureScheme<H> {
pub fn new(options: ProofOptions) -> Self {
RpoSignatureScheme { options, _h: PhantomData }
}
pub fn sign(&self, sk: [BaseElement; DIGEST_SIZE], msg: [BaseElement; DIGEST_SIZE]) -> Proof {
// create a prover
let prover = RpoSignatureProver::<H>::new(msg, self.options.clone());
// generate execution trace
let trace = prover.build_trace(sk);
// generate the initial seed for the PRNG used for zero-knowledge
let seed: [u8; 32] = generate_seed(sk, msg);
// generate the proof
prover.prove(trace, Some(seed)).expect("failed to generate the signature")
}
pub fn verify(
&self,
pub_key: [BaseElement; DIGEST_SIZE],
msg: [BaseElement; DIGEST_SIZE],
proof: Proof,
) -> Result<(), VerifierError> {
// we make sure that the parameters used in generating the proof match the expected ones
if *proof.options() != self.options {
return Err(VerifierError::UnacceptableProofOptions);
}
let pub_inputs = PublicInputs { pub_key, msg };
let acceptable_options = AcceptableOptions::OptionSet(vec![proof.options().clone()]);
verify::<RescueAir, Rpo256, RpoRandomCoin, SaltedMerkleTree<Rpo256, ChaCha20Rng>>(
proof,
pub_inputs,
&acceptable_options,
)
}
}
/// Deterministically generates a seed for seeding the PRNG used for zero-knowledge.
///
/// This uses the argument described in [RFC 6979](https://datatracker.ietf.org/doc/html/rfc6979#section-3.5)
/// § 3.5 where the concatenation of the private key and the hashed message, i.e., sk || H(m), is
/// used in order to construct the initial seed of a PRNG.
///
/// Note that we also hash in a context string in order to domain-separate between different
/// instantiations of the signature scheme.
#[inline]
pub fn generate_seed(sk: [BaseElement; DIGEST_SIZE], msg: [BaseElement; DIGEST_SIZE]) -> [u8; 32] {
let context_bytes = "
Seed for PRNG used for Zero-knowledge in RPO-STARK signature scheme:
1. Version: Conjectured security
2. FRI queries: 30
3. Blowup factor: 8
4. Grinding bits: 12
5. Field extension degree: 2
6. FRI folding factor: 4
7. FRI remainder polynomial max degree: 7
"
.to_bytes();
let sk_bytes = sk.to_bytes();
let msg_bytes = msg.to_bytes();
let total_length = context_bytes.len() + sk_bytes.len() + msg_bytes.len();
let mut buffer = Vec::with_capacity(total_length);
buffer.extend_from_slice(&context_bytes);
buffer.extend_from_slice(&sk_bytes);
buffer.extend_from_slice(&msg_bytes);
blake3::hash(&buffer).into()
}

src/dsa/rpo_stark/stark/prover.rs

@@ -0,0 +1,148 @@
use core::marker::PhantomData;
use rand_chacha::ChaCha20Rng;
use winter_air::{
AuxRandElements, ConstraintCompositionCoefficients, PartitionOptions, ZkParameters,
};
use winter_crypto::{ElementHasher, SaltedMerkleTree};
use winter_math::{fields::f64::BaseElement, FieldElement};
use winter_prover::{
matrix::ColMatrix, CompositionPoly, CompositionPolyTrace, DefaultConstraintCommitment,
DefaultConstraintEvaluator, DefaultTraceLde, ProofOptions, Prover, StarkDomain, Trace,
TraceInfo, TracePolyTable, TraceTable,
};
use super::air::{PublicInputs, RescueAir, HASH_CYCLE_LEN};
use crate::{
hash::{rpo::Rpo256, STATE_WIDTH},
rand::RpoRandomCoin,
Word, ZERO,
};
// PROVER
// ================================================================================================
/// A prover for the RPO STARK-based signature scheme.
///
/// The signature is based on the one-wayness of the RPO hash function, but it is generic over
/// the hash function used for instantiating the random oracle for the BCS transform.
pub(crate) struct RpoSignatureProver<H: ElementHasher + Sync> {
message: Word,
options: ProofOptions,
_hasher: PhantomData<H>,
}
impl<H: ElementHasher + Sync> RpoSignatureProver<H> {
pub(crate) fn new(message: Word, options: ProofOptions) -> Self {
Self { message, options, _hasher: PhantomData }
}
pub(crate) fn build_trace(&self, sk: Word) -> TraceTable<BaseElement> {
let mut trace = TraceTable::new(STATE_WIDTH, HASH_CYCLE_LEN);
trace.fill(
|state| {
// initialize first half of the rate portion of the state with the secret key
state[0] = ZERO;
state[1] = ZERO;
state[2] = ZERO;
state[3] = ZERO;
state[4] = sk[0];
state[5] = sk[1];
state[6] = sk[2];
state[7] = sk[3];
state[8] = ZERO;
state[9] = ZERO;
state[10] = ZERO;
state[11] = ZERO;
},
|step, state| {
Rpo256::apply_round(
state.try_into().expect("should not fail given the size of the array"),
step,
);
},
);
trace
}
}
impl<H: ElementHasher> Prover for RpoSignatureProver<H>
where
H: ElementHasher<BaseField = BaseElement> + Sync,
{
type BaseField = BaseElement;
type Air = RescueAir;
type Trace = TraceTable<BaseElement>;
type HashFn = Rpo256;
type VC = SaltedMerkleTree<Self::HashFn, Self::ZkPrng>;
type RandomCoin = RpoRandomCoin;
type TraceLde<E: FieldElement<BaseField = Self::BaseField>> =
DefaultTraceLde<E, Self::HashFn, Self::VC>;
type ConstraintCommitment<E: FieldElement<BaseField = Self::BaseField>> =
DefaultConstraintCommitment<E, Self::HashFn, Self::ZkPrng, Self::VC>;
type ConstraintEvaluator<'a, E: FieldElement<BaseField = Self::BaseField>> =
DefaultConstraintEvaluator<'a, Self::Air, E>;
type ZkPrng = ChaCha20Rng;
fn get_pub_inputs(&self, trace: &Self::Trace) -> PublicInputs {
let last_step = trace.length() - 1;
// Note that the message is not part of the execution trace but is part of the public
// inputs. This is explained in the reference description of the DSA; intuitively, it
// ensures that the message enters the Fiat-Shamir transcript and hence binds the
// proof/signature to the message.
PublicInputs {
pub_key: [
trace.get(4, last_step),
trace.get(5, last_step),
trace.get(6, last_step),
trace.get(7, last_step),
],
msg: self.message,
}
}
fn options(&self) -> &ProofOptions {
&self.options
}
fn new_trace_lde<E: FieldElement<BaseField = Self::BaseField>>(
&self,
trace_info: &TraceInfo,
main_trace: &ColMatrix<Self::BaseField>,
domain: &StarkDomain<Self::BaseField>,
partition_option: PartitionOptions,
zk_parameters: Option<ZkParameters>,
prng: &mut Option<Self::ZkPrng>,
) -> (Self::TraceLde<E>, TracePolyTable<E>) {
DefaultTraceLde::new(trace_info, main_trace, domain, partition_option, zk_parameters, prng)
}
fn new_evaluator<'a, E: FieldElement<BaseField = Self::BaseField>>(
&self,
air: &'a Self::Air,
aux_rand_elements: Option<AuxRandElements<E>>,
composition_coefficients: ConstraintCompositionCoefficients<E>,
) -> Self::ConstraintEvaluator<'a, E> {
DefaultConstraintEvaluator::new(air, aux_rand_elements, composition_coefficients)
}
fn build_constraint_commitment<E: FieldElement<BaseField = Self::BaseField>>(
&self,
composition_poly_trace: CompositionPolyTrace<E>,
num_constraint_composition_columns: usize,
domain: &StarkDomain<Self::BaseField>,
partition_options: PartitionOptions,
zk_parameters: Option<ZkParameters>,
prng: &mut Option<Self::ZkPrng>,
) -> (Self::ConstraintCommitment<E>, CompositionPoly<E>) {
DefaultConstraintCommitment::new(
composition_poly_trace,
num_constraint_composition_columns,
domain,
partition_options,
zk_parameters,
prng,
)
}
}
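
A minimal driving sketch (hypothetical; the actual sign/verify wiring lives elsewhere in
this changeset, and `sk`, `msg`, and `options` are assumed to be in scope):

// build the one-permutation trace for a secret key `sk: Word`
let prover: RpoSignatureProver<Rpo256> = RpoSignatureProver::new(msg, options);
let trace = prover.build_trace(sk);
// columns 4..8 of the last trace row now hold the public key, i.e. the RPO digest of `sk`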

View File

@@ -5,6 +5,7 @@ use super::{CubeExtension, Felt, FieldElement, StarkField, ZERO};
pub mod blake; pub mod blake;
mod rescue; mod rescue;
pub(crate) use rescue::{ARK1, ARK2, DIGEST_SIZE, MDS, STATE_WIDTH};
pub mod rpo { pub mod rpo {
pub use super::rescue::{Rpo256, RpoDigest, RpoDigestError}; pub use super::rescue::{Rpo256, RpoDigest, RpoDigestError};
} }

View File

@@ -6,7 +6,7 @@ mod arch;
pub use arch::optimized::{add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox}; pub use arch::optimized::{add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox};
mod mds; mod mds;
use mds::{apply_mds, MDS}; pub(crate) use mds::{apply_mds, MDS};
mod rpo; mod rpo;
pub use rpo::{Rpo256, RpoDigest, RpoDigestError}; pub use rpo::{Rpo256, RpoDigest, RpoDigestError};
@@ -26,7 +26,7 @@ const NUM_ROUNDS: usize = 7;
/// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and /// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
/// the remaining 4 elements are reserved for capacity. /// the remaining 4 elements are reserved for capacity.
const STATE_WIDTH: usize = 12; pub(crate) const STATE_WIDTH: usize = 12;
/// The rate portion of the state is located in elements 4 through 11. /// The rate portion of the state is located in elements 4 through 11.
const RATE_RANGE: Range<usize> = 4..12; const RATE_RANGE: Range<usize> = 4..12;
@@ -42,8 +42,8 @@ const CAPACITY_RANGE: Range<usize> = 0..4;
/// ///
/// The digest is returned from state elements 4, 5, 6, and 7 (the first four elements of the /// The digest is returned from state elements 4, 5, 6, and 7 (the first four elements of the
/// rate portion). /// rate portion).
const DIGEST_RANGE: Range<usize> = 4..8; pub(crate) const DIGEST_RANGE: Range<usize> = 4..8;
const DIGEST_SIZE: usize = DIGEST_RANGE.end - DIGEST_RANGE.start; pub(crate) const DIGEST_SIZE: usize = DIGEST_RANGE.end - DIGEST_RANGE.start;
/// The number of bytes needed to encode a digest /// The number of bytes needed to encode a digest
const DIGEST_BYTES: usize = 32; const DIGEST_BYTES: usize = 32;
@@ -144,7 +144,7 @@ fn add_constants(state: &mut [Felt; STATE_WIDTH], ark: &[Felt; STATE_WIDTH]) {
/// ///
/// The constants are broken up into two arrays ARK1 and ARK2; ARK1 contains the constants for the /// The constants are broken up into two arrays ARK1 and ARK2; ARK1 contains the constants for the
/// first half of RPO round, and ARK2 contains constants for the second half of RPO round. /// first half of RPO round, and ARK2 contains constants for the second half of RPO round.
const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [ pub(crate) const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
[ [
Felt::new(5789762306288267392), Felt::new(5789762306288267392),
Felt::new(6522564764413701783), Felt::new(6522564764413701783),
@@ -245,7 +245,7 @@ const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
], ],
]; ];
const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [ pub(crate) const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
[ [
Felt::new(6077062762357204287), Felt::new(6077062762357204287),
Felt::new(15277620170502011191), Felt::new(15277620170502011191),

View File

@@ -1,6 +1,10 @@
use alloc::string::String; use alloc::string::String;
use core::{cmp::Ordering, fmt::Display, ops::Deref, slice}; use core::{cmp::Ordering, fmt::Display, ops::Deref, slice};
use rand::{
distributions::{Standard, Uniform},
prelude::Distribution,
};
use thiserror::Error; use thiserror::Error;
use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO}; use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO};
@@ -126,6 +130,18 @@ impl Randomizable for RpoDigest {
} }
} }
impl Distribution<RpoDigest> for Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> RpoDigest {
let mut res = [ZERO; DIGEST_SIZE];
let uni_dist = Uniform::from(0..Felt::MODULUS);
for r in res.iter_mut() {
let sampled_integer = uni_dist.sample(rng);
*r = Felt::new(sampled_integer);
}
RpoDigest::new(res)
}
}
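
A short usage sketch for the impl above (illustrative; assumes a rand 0.8-style `Rng` in
scope, and `random_digest` is a hypothetical helper):

use rand::Rng;

fn random_digest(rng: &mut impl Rng) -> RpoDigest {
    // routed through `Distribution<RpoDigest> for Standard`; each of the four felts is
    // drawn uniformly from 0..Felt::MODULUS, so the digest is uniform over valid values
    rng.gen()
}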
// CONVERSIONS: FROM RPO DIGEST // CONVERSIONS: FROM RPO DIGEST
// ================================================================================================ // ================================================================================================

View File

@@ -4,9 +4,8 @@ use clap::Parser;
use miden_crypto::{ use miden_crypto::{
hash::rpo::{Rpo256, RpoDigest}, hash::rpo::{Rpo256, RpoDigest},
merkle::{MerkleError, Smt}, merkle::{MerkleError, Smt},
Felt, Word, EMPTY_WORD, ONE, Felt, Word, ONE,
}; };
use rand::{prelude::IteratorRandom, thread_rng, Rng};
use rand_utils::rand_value; use rand_utils::rand_value;
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@@ -14,7 +13,7 @@ use rand_utils::rand_value;
pub struct BenchmarkCmd { pub struct BenchmarkCmd {
/// Size of the tree /// Size of the tree
#[clap(short = 's', long = "size")] #[clap(short = 's', long = "size")]
size: usize, size: u64,
} }
fn main() { fn main() {
@@ -30,153 +29,101 @@ pub fn benchmark_smt() {
     let mut entries = Vec::new();
     for i in 0..tree_size {
         let key = rand_value::<RpoDigest>();
-        let value = [ONE, ONE, ONE, Felt::new(i as u64)];
+        let value = [ONE, ONE, ONE, Felt::new(i)];
         entries.push((key, value));
     }
-    let mut tree = construction(entries.clone(), tree_size).unwrap();
-    insertion(&mut tree).unwrap();
-    batched_insertion(&mut tree).unwrap();
-    batched_update(&mut tree, entries).unwrap();
-    proof_generation(&mut tree).unwrap();
+    let mut tree = construction(entries, tree_size).unwrap();
+    insertion(&mut tree, tree_size).unwrap();
+    batched_insertion(&mut tree, tree_size).unwrap();
+    proof_generation(&mut tree, tree_size).unwrap();
 }
 /// Runs the construction benchmark for [`Smt`], returning the constructed tree.
-pub fn construction(entries: Vec<(RpoDigest, Word)>, size: usize) -> Result<Smt, MerkleError> {
+pub fn construction(entries: Vec<(RpoDigest, Word)>, size: u64) -> Result<Smt, MerkleError> {
     println!("Running a construction benchmark:");
     let now = Instant::now();
     let tree = Smt::with_entries(entries)?;
-    let elapsed = now.elapsed().as_secs_f32();
-    println!("Constructed a SMT with {size} key-value pairs in {elapsed:.1} seconds");
+    let elapsed = now.elapsed();
+    println!(
+        "Constructed a SMT with {} key-value pairs in {:.3} seconds",
+        size,
+        elapsed.as_secs_f32(),
+    );
     println!("Number of leaf nodes: {}\n", tree.leaves().count());
     Ok(tree)
 }
 /// Runs the insertion benchmark for the [`Smt`].
-pub fn insertion(tree: &mut Smt) -> Result<(), MerkleError> {
-    const NUM_INSERTIONS: usize = 1_000;
+pub fn insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
     println!("Running an insertion benchmark:");
-    let size = tree.num_leaves();
     let mut insertion_times = Vec::new();
-    for i in 0..NUM_INSERTIONS {
+    for i in 0..20 {
         let test_key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
-        let test_value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];
+        let test_value = [ONE, ONE, ONE, Felt::new(size + i)];
         let now = Instant::now();
         tree.insert(test_key, test_value);
         let elapsed = now.elapsed();
-        insertion_times.push(elapsed.as_micros());
+        insertion_times.push(elapsed.as_secs_f32());
     }
     println!(
-        "An average insertion time measured by {NUM_INSERTIONS} inserts into an SMT with {size} leaves is {:.0} μs\n",
-        // calculate the average
-        insertion_times.iter().sum::<u128>() as f64 / (NUM_INSERTIONS as f64),
+        "An average insertion time measured by 20 inserts into a SMT with {} key-value pairs is {:.3} milliseconds\n",
+        size,
+        // calculate the average by dividing by 20 and convert to milliseconds by multiplying
+        // by 1000; as a result, we can simply multiply by 50
+        insertion_times.iter().sum::<f32>() * 50f32,
     );
     Ok(())
 }
-pub fn batched_insertion(tree: &mut Smt) -> Result<(), MerkleError> {
-    const NUM_INSERTIONS: usize = 1_000;
+pub fn batched_insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
     println!("Running a batched insertion benchmark:");
-    let size = tree.num_leaves();
-    let new_pairs: Vec<(RpoDigest, Word)> = (0..NUM_INSERTIONS)
+    let new_pairs: Vec<(RpoDigest, Word)> = (0..1000)
         .map(|i| {
             let key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
-            let value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];
+            let value = [ONE, ONE, ONE, Felt::new(size + i)];
             (key, value)
         })
         .collect();
     let now = Instant::now();
     let mutations = tree.compute_mutations(new_pairs);
-    let compute_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
+    let compute_elapsed = now.elapsed();
     let now = Instant::now();
-    tree.apply_mutations(mutations)?;
-    let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
+    tree.apply_mutations(mutations).unwrap();
+    let apply_elapsed = now.elapsed();
     println!(
-        "An average insert-batch computation time measured by a {NUM_INSERTIONS}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
-        compute_elapsed,
-        compute_elapsed * 1000_f64 / NUM_INSERTIONS as f64, // time in μs
+        "An average batch computation time measured by a 1k-batch into an SMT with {} key-value pairs over {:.3} milliseconds is {:.3} milliseconds",
+        size,
+        compute_elapsed.as_secs_f32() * 1000f32,
+        // Dividing by the number of iterations, 1000, and then multiplying by 1000 to get
+        // milliseconds, cancels out.
+        compute_elapsed.as_secs_f32(),
     );
     println!(
-        "An average insert-batch application time measured by a {NUM_INSERTIONS}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
-        apply_elapsed,
-        apply_elapsed * 1000_f64 / NUM_INSERTIONS as f64, // time in μs
+        "An average batch application time measured by a 1k-batch into an SMT with {} key-value pairs over {:.3} milliseconds is {:.3} milliseconds",
+        size,
+        apply_elapsed.as_secs_f32() * 1000f32,
+        // Dividing by the number of iterations, 1000, and then multiplying by 1000 to get
+        // milliseconds, cancels out.
+        apply_elapsed.as_secs_f32(),
     );
     println!(
-        "An average batch insertion time measured by a 1k-batch into an SMT with {size} leaves totals to {:.1} ms",
-        (compute_elapsed + apply_elapsed),
+        "An average batch insertion time measured by a 1k-batch into an SMT with {} key-value pairs totals to {:.3} milliseconds",
+        size,
+        (compute_elapsed + apply_elapsed).as_secs_f32() * 1000f32,
     );
     println!();
     Ok(())
 }
-pub fn batched_update(tree: &mut Smt, entries: Vec<(RpoDigest, Word)>) -> Result<(), MerkleError> {
-    const NUM_UPDATES: usize = 1_000;
-    const REMOVAL_PROBABILITY: f64 = 0.2;
-    println!("Running a batched update benchmark:");
-    let size = tree.num_leaves();
-    let mut rng = thread_rng();
-    let new_pairs =
-        entries
-            .into_iter()
-            .choose_multiple(&mut rng, NUM_UPDATES)
-            .into_iter()
-            .map(|(key, _)| {
-                let value = if rng.gen_bool(REMOVAL_PROBABILITY) {
-                    EMPTY_WORD
-                } else {
-                    [ONE, ONE, ONE, Felt::new(rng.gen())]
-                };
-                (key, value)
-            });
-    assert_eq!(new_pairs.len(), NUM_UPDATES);
-    let now = Instant::now();
-    let mutations = tree.compute_mutations(new_pairs);
-    let compute_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
-    let now = Instant::now();
-    tree.apply_mutations(mutations)?;
-    let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
-    println!(
-        "An average update-batch computation time measured by a {NUM_UPDATES}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
-        compute_elapsed,
-        compute_elapsed * 1000_f64 / NUM_UPDATES as f64, // time in μs
-    );
-    println!(
-        "An average update-batch application time measured by a {NUM_UPDATES}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
-        apply_elapsed,
-        apply_elapsed * 1000_f64 / NUM_UPDATES as f64, // time in μs
-    );
-    println!(
-        "An average batch update time measured by a 1k-batch into an SMT with {size} leaves totals to {:.1} ms",
-        (compute_elapsed + apply_elapsed),
-    );
-    println!();
@@ -185,29 +132,28 @@ pub fn batched_update(tree: &mut Smt, entries: Vec<(RpoDigest, Word)>) -> Result
-    Ok(())
-}
 /// Runs the proof generation benchmark for the [`Smt`].
-pub fn proof_generation(tree: &mut Smt) -> Result<(), MerkleError> {
-    const NUM_PROOFS: usize = 100;
+pub fn proof_generation(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
     println!("Running a proof generation benchmark:");
     let mut insertion_times = Vec::new();
-    let size = tree.num_leaves();
-    for i in 0..NUM_PROOFS {
+    for i in 0..20 {
         let test_key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
-        let test_value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];
+        let test_value = [ONE, ONE, ONE, Felt::new(size + i)];
         tree.insert(test_key, test_value);
         let now = Instant::now();
         let _proof = tree.open(&test_key);
-        insertion_times.push(now.elapsed().as_micros());
+        let elapsed = now.elapsed();
+        insertion_times.push(elapsed.as_secs_f32());
     }
     println!(
-        "An average proving time measured by {NUM_PROOFS} value proofs in an SMT with {size} leaves in {:.0} μs",
-        // calculate the average
-        insertion_times.iter().sum::<u128>() as f64 / (NUM_PROOFS as f64),
+        "An average proving time measured by 20 value proofs in a SMT with {} key-value pairs in {:.3} microseconds",
+        size,
+        // calculate the average by dividing by 20 and convert to microseconds by multiplying
+        // by 1,000,000; as a result, we can simply multiply by 50000
+        insertion_times.iter().sum::<f32>() * 50000f32,
    );
    Ok(())
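
(As a check on the scaling used above: with 20 samples recorded in seconds, the mean in
milliseconds is (Σt / 20) · 1000 = Σt · 50, and the mean in microseconds is
(Σt / 20) · 1,000,000 = Σt · 50,000.)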

View File

@@ -97,6 +97,14 @@ impl NodeIndex {
self self
} }
/// Returns the parent of the current node. This is the same as [`Self::move_up()`], but returns
/// a new value instead of mutating `self`.
pub const fn parent(mut self) -> Self {
self.depth = self.depth.saturating_sub(1);
self.value >>= 1;
self
}
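
A quick illustration of the semantics (hypothetical assertions, not part of this diff):

// let idx = NodeIndex::new_unchecked(8, 13);
// assert_eq!(idx.parent(), NodeIndex::new_unchecked(7, 6)); // 13 >> 1 == 6
// assert_eq!(NodeIndex::root().parent(), NodeIndex::root()); // depth saturates at 0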
// PROVIDERS // PROVIDERS
// -------------------------------------------------------------------------------------------- // --------------------------------------------------------------------------------------------
@@ -128,7 +136,7 @@ impl NodeIndex {
self.value self.value
} }
/// Returns `true` if the current instance points to a right sibling node. /// Returns true if the current instance points to a right sibling node.
pub const fn is_value_odd(&self) -> bool { pub const fn is_value_odd(&self) -> bool {
(self.value & 1) == 1 (self.value & 1) == 1
} }

View File

@@ -303,7 +303,7 @@ impl PartialMmr {
if leaf_pos + 1 == self.forest if leaf_pos + 1 == self.forest
&& path.depth() == 0 && path.depth() == 0
&& self.peaks.last().is_some_and(|v| *v == leaf) && self.peaks.last().map_or(false, |v| *v == leaf)
{ {
self.track_latest = true; self.track_latest = true;
return Ok(()); return Ok(());

View File

@@ -21,9 +21,11 @@ mod path;
pub use path::{MerklePath, RootPath, ValuePath}; pub use path::{MerklePath, RootPath, ValuePath};
mod smt; mod smt;
#[cfg(feature = "internal")]
pub use smt::build_subtree_for_bench;
pub use smt::{ pub use smt::{
LeafIndex, MutationSet, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, LeafIndex, MutationSet, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError,
SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH, SubtreeLeaf, SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH,
}; };
mod mmr; mod mmr;

View File

@@ -70,7 +70,7 @@ impl SmtLeaf {
Self::Single((key, value)) Self::Single((key, value))
} }
/// Returns a new multiple leaf with the specified entries. The leaf index is derived from the /// Returns a new single leaf with the specified entry. The leaf index is derived from the
/// entries' keys. /// entries' keys.
/// ///
/// # Errors /// # Errors

View File

@@ -71,12 +71,51 @@ impl Smt {
/// Returns a new [Smt] instantiated with leaves set as specified by the provided entries. /// Returns a new [Smt] instantiated with leaves set as specified by the provided entries.
/// ///
/// If the `concurrent` feature is enabled, this function uses a parallel implementation to
/// process the entries efficiently, otherwise it defaults to the sequential implementation.
///
/// All leaves omitted from the entries list are set to [Self::EMPTY_VALUE]. /// All leaves omitted from the entries list are set to [Self::EMPTY_VALUE].
/// ///
/// # Errors /// # Errors
/// Returns an error if the provided entries contain multiple values for the same key. /// Returns an error if the provided entries contain multiple values for the same key.
pub fn with_entries( pub fn with_entries(
entries: impl IntoIterator<Item = (RpoDigest, Word)>, entries: impl IntoIterator<Item = (RpoDigest, Word)>,
) -> Result<Self, MerkleError> {
#[cfg(feature = "concurrent")]
{
let mut seen_keys = BTreeSet::new();
let entries: Vec<_> = entries
.into_iter()
.map(|(key, value)| {
if seen_keys.insert(key) {
Ok((key, value))
} else {
Err(MerkleError::DuplicateValuesForIndex(
LeafIndex::<SMT_DEPTH>::from(key).value(),
))
}
})
.collect::<Result<_, _>>()?;
if entries.is_empty() {
return Ok(Self::default());
}
<Self as SparseMerkleTree<SMT_DEPTH>>::with_entries_par(entries)
}
#[cfg(not(feature = "concurrent"))]
{
Self::with_entries_sequential(entries)
}
}
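
A usage sketch for the constructor above (hypothetical snippet; the key and value are
arbitrary):

let key = RpoDigest::from([ONE, ONE, ONE, ONE]);
let value = [ONE, ONE, ONE, ONE];
let smt = Smt::with_entries([(key, value)])?;
// duplicate keys are rejected rather than silently deduplicated, even for equal values:
assert!(Smt::with_entries([(key, value), (key, value)]).is_err());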
/// Returns a new [Smt] instantiated with leaves set as specified by the provided entries.
///
/// This sequential implementation processes entries one at a time to build the tree.
/// All leaves omitted from the entries list are set to [Self::EMPTY_VALUE].
///
/// # Errors
/// Returns an error if the provided entries contain multiple values for the same key.
pub fn with_entries_sequential(
entries: impl IntoIterator<Item = (RpoDigest, Word)>,
) -> Result<Self, MerkleError> { ) -> Result<Self, MerkleError> {
// create an empty tree // create an empty tree
let mut tree = Self::new(); let mut tree = Self::new();
@@ -101,6 +140,23 @@ impl Smt {
Ok(tree) Ok(tree)
} }
/// Returns a new [`Smt`] instantiated from already computed leaves and nodes.
///
/// This function performs minimal consistency checking. It is the caller's responsibility to
/// ensure the passed arguments are correct and consistent with each other.
///
/// # Panics
/// With debug assertions on, this function panics if `root` does not match the root node in
/// `inner_nodes`.
pub fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, SmtLeaf>,
root: RpoDigest,
) -> Self {
// Our particular implementation of `from_raw_parts()` never returns `Err`.
<Self as SparseMerkleTree<SMT_DEPTH>>::from_raw_parts(inner_nodes, leaves, root).unwrap()
}
// PUBLIC ACCESSORS // PUBLIC ACCESSORS
// -------------------------------------------------------------------------------------------- // --------------------------------------------------------------------------------------------
@@ -114,11 +170,6 @@ impl Smt {
<Self as SparseMerkleTree<SMT_DEPTH>>::root(self) <Self as SparseMerkleTree<SMT_DEPTH>>::root(self)
} }
/// Returns the number of non-empty leaves in this tree.
pub fn num_leaves(&self) -> usize {
self.leaves.len()
}
/// Returns the leaf to which `key` maps /// Returns the leaf to which `key` maps
pub fn get_leaf(&self, key: &RpoDigest) -> SmtLeaf { pub fn get_leaf(&self, key: &RpoDigest) -> SmtLeaf {
<Self as SparseMerkleTree<SMT_DEPTH>>::get_leaf(self, key) <Self as SparseMerkleTree<SMT_DEPTH>>::get_leaf(self, key)
@@ -205,7 +256,7 @@ impl Smt {
<Self as SparseMerkleTree<SMT_DEPTH>>::compute_mutations(self, kv_pairs) <Self as SparseMerkleTree<SMT_DEPTH>>::compute_mutations(self, kv_pairs)
} }
/// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to this tree. /// Apply the prospective mutations computed with [`Smt::compute_mutations()`] to this tree.
/// ///
/// # Errors /// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns /// If `mutations` was computed on a tree with a different root than this one, returns
@@ -219,23 +270,6 @@ impl Smt {
<Self as SparseMerkleTree<SMT_DEPTH>>::apply_mutations(self, mutations) <Self as SparseMerkleTree<SMT_DEPTH>>::apply_mutations(self, mutations)
} }
/// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to this tree
/// and returns the reverse mutation set.
///
/// Applying the reverse mutation sets to the updated tree will revert the changes.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash
/// the `mutations` were computed against, and the second item is the actual current root of
/// this tree.
pub fn apply_mutations_with_reversion(
&mut self,
mutations: MutationSet<SMT_DEPTH, RpoDigest, Word>,
) -> Result<MutationSet<SMT_DEPTH, RpoDigest, Word>, MerkleError> {
<Self as SparseMerkleTree<SMT_DEPTH>>::apply_mutations_with_reversion(self, mutations)
}
// HELPERS // HELPERS
// -------------------------------------------------------------------------------------------- // --------------------------------------------------------------------------------------------
@@ -282,6 +316,19 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
const EMPTY_VALUE: Self::Value = EMPTY_WORD; const EMPTY_VALUE: Self::Value = EMPTY_WORD;
const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(SMT_DEPTH, 0); const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, SmtLeaf>,
root: RpoDigest,
) -> Result<Self, MerkleError> {
if cfg!(debug_assertions) {
let root_node = inner_nodes.get(&NodeIndex::root()).unwrap();
assert_eq!(root_node.hash(), root);
}
Ok(Self { root, inner_nodes, leaves })
}
fn root(&self) -> RpoDigest { fn root(&self) -> RpoDigest {
self.root self.root
} }
@@ -297,12 +344,12 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
.unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth())) .unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth()))
} }
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode> { fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
self.inner_nodes.insert(index, inner_node) self.inner_nodes.insert(index, inner_node);
} }
fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode> { fn remove_inner_node(&mut self, index: NodeIndex) {
self.inner_nodes.remove(&index) let _ = self.inner_nodes.remove(&index);
} }
fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value> { fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value> {
@@ -366,6 +413,23 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
fn path_and_leaf_to_opening(path: MerklePath, leaf: SmtLeaf) -> SmtProof { fn path_and_leaf_to_opening(path: MerklePath, leaf: SmtLeaf) -> SmtProof {
SmtProof::new_unchecked(path, leaf) SmtProof::new_unchecked(path, leaf)
} }
fn pairs_to_leaf(mut pairs: Vec<(RpoDigest, Word)>) -> SmtLeaf {
assert!(!pairs.is_empty());
if pairs.len() > 1 {
SmtLeaf::new_multiple(pairs).unwrap()
} else {
let (key, value) = pairs.pop().unwrap();
// TODO: should we ever be constructing empty leaves from pairs?
if value == Self::EMPTY_VALUE {
let index = Self::key_to_leaf_index(&key);
SmtLeaf::new_empty(index)
} else {
SmtLeaf::new_single(key, value)
}
}
}
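    // Illustrative behavior of the three branches above (not part of this diff):
    //   pairs_to_leaf(vec![(k1, v1), (k2, v2)])  -> SmtLeaf::Multiple(..)
    //   pairs_to_leaf(vec![(k, v)])              -> SmtLeaf::Single((k, v))
    //   pairs_to_leaf(vec![(k, EMPTY_WORD)])     -> an empty leaf at k's leaf index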
} }
impl Default for Smt { impl Default for Smt {

View File

@@ -1,14 +1,12 @@
use alloc::{collections::BTreeMap, vec::Vec}; use alloc::vec::Vec;
use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH}; use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH};
use crate::{ use crate::{
merkle::{ merkle::{smt::SparseMerkleTree, EmptySubtreeRoots, MerkleStore},
smt::{NodeMutation, SparseMerkleTree},
EmptySubtreeRoots, MerkleStore, MutationSet,
},
utils::{Deserializable, Serializable}, utils::{Deserializable, Serializable},
Word, ONE, WORD_SIZE, Word, ONE, WORD_SIZE,
}; };
// SMT // SMT
// -------------------------------------------------------------------------------------------- // --------------------------------------------------------------------------------------------
@@ -414,49 +412,21 @@ fn test_prospective_insertion() {
     let mutations = smt.compute_mutations(vec![(key_1, value_1)]);
     assert_eq!(mutations.root(), root_1, "prospective root 1 did not match actual root 1");
-    let revert = apply_mutations(&mut smt, mutations);
+    smt.apply_mutations(mutations).unwrap();
     assert_eq!(smt.root(), root_1, "mutations before and after apply did not match");
-    assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
-    assert_eq!(revert.root(), root_empty, "reverse mutations new root did not match");
-    assert_eq!(
-        revert.new_pairs,
-        BTreeMap::from_iter([(key_1, EMPTY_WORD)]),
-        "reverse mutations pairs did not match"
-    );
-    assert_eq!(
-        revert.node_mutations,
-        smt.inner_nodes.keys().map(|key| (*key, NodeMutation::Removal)).collect(),
-        "reverse mutations inner nodes did not match"
-    );
     let mutations = smt.compute_mutations(vec![(key_2, value_2)]);
     assert_eq!(mutations.root(), root_2, "prospective root 2 did not match actual root 2");
     let mutations =
         smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_2, value_2), (key_3, value_3)]);
     assert_eq!(mutations.root(), root_3, "mutations before and after apply did not match");
-    let old_root = smt.root();
-    let revert = apply_mutations(&mut smt, mutations);
-    assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
-    assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
-    assert_eq!(
-        revert.new_pairs,
-        BTreeMap::from_iter([(key_2, EMPTY_WORD), (key_3, EMPTY_WORD)]),
-        "reverse mutations pairs did not match"
-    );
+    smt.apply_mutations(mutations).unwrap();
     // Edge case: multiple values at the same key, where a later pair restores the original value.
     let mutations = smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_3, value_3)]);
     assert_eq!(mutations.root(), root_3);
-    let old_root = smt.root();
-    let revert = apply_mutations(&mut smt, mutations);
+    smt.apply_mutations(mutations).unwrap();
     assert_eq!(smt.root(), root_3);
-    assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
-    assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
-    assert_eq!(
-        revert.new_pairs,
-        BTreeMap::from_iter([(key_3, value_3)]),
-        "reverse mutations pairs did not match"
-    );
     // Test batch updates, and that the order doesn't matter.
     let pairs =
@@ -467,16 +437,8 @@
         root_empty,
         "prospective root for batch removal did not match actual root",
     );
-    let old_root = smt.root();
-    let revert = apply_mutations(&mut smt, mutations);
+    smt.apply_mutations(mutations).unwrap();
     assert_eq!(smt.root(), root_empty, "mutations before and after apply did not match");
-    assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
-    assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
-    assert_eq!(
-        revert.new_pairs,
-        BTreeMap::from_iter([(key_1, value_1), (key_2, value_2), (key_3, value_3)]),
-        "reverse mutations pairs did not match"
-    );
     let pairs = vec![(key_3, value_3), (key_1, value_1), (key_2, value_2)];
     let mutations = smt.compute_mutations(pairs);
@@ -485,72 +447,6 @@
     assert_eq!(smt.root(), root_3);
 }
-
-#[test]
-fn test_mutations_revert() {
-    let mut smt = Smt::default();
-    let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(1)]);
-    let key_2: RpoDigest =
-        RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(2)]);
-    let key_3: RpoDigest =
-        RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(3)]);
-    let value_1 = [ONE; WORD_SIZE];
-    let value_2 = [2_u32.into(); WORD_SIZE];
-    let value_3 = [3_u32.into(); WORD_SIZE];
-    smt.insert(key_1, value_1);
-    smt.insert(key_2, value_2);
-    let mutations =
-        smt.compute_mutations(vec![(key_1, EMPTY_WORD), (key_2, value_1), (key_3, value_3)]);
-    let original = smt.clone();
-    let revert = smt.apply_mutations_with_reversion(mutations).unwrap();
-    assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
-    assert_eq!(revert.root(), original.root(), "reverse mutations new root did not match");
-    smt.apply_mutations(revert).unwrap();
-    assert_eq!(smt, original, "SMT with applied revert mutations did not match original SMT");
-}
-
-#[test]
-fn test_mutation_set_serialization() {
-    let mut smt = Smt::default();
-    let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(1)]);
-    let key_2: RpoDigest =
-        RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(2)]);
-    let key_3: RpoDigest =
-        RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(3)]);
-    let value_1 = [ONE; WORD_SIZE];
-    let value_2 = [2_u32.into(); WORD_SIZE];
-    let value_3 = [3_u32.into(); WORD_SIZE];
-    smt.insert(key_1, value_1);
-    smt.insert(key_2, value_2);
-    let mutations =
-        smt.compute_mutations(vec![(key_1, EMPTY_WORD), (key_2, value_1), (key_3, value_3)]);
-    let serialized = mutations.to_bytes();
-    let deserialized =
-        MutationSet::<SMT_DEPTH, RpoDigest, Word>::read_from_bytes(&serialized).unwrap();
-    assert_eq!(deserialized, mutations, "deserialized mutations did not match original");
-    let revert = smt.apply_mutations_with_reversion(mutations).unwrap();
-    let serialized = revert.to_bytes();
-    let deserialized =
-        MutationSet::<SMT_DEPTH, RpoDigest, Word>::read_from_bytes(&serialized).unwrap();
-    assert_eq!(deserialized, revert, "deserialized mutations did not match original");
-}
-
 /// Tests that 2 key-value pairs stored in the same leaf have the same path
 #[test]
 fn test_smt_path_to_keys_in_same_leaf_are_equal() {
@@ -706,19 +602,3 @@
     Rpo256::hash_elements(&elements)
 }
-
-/// Applies mutations with and without reversion to the given SMT, comparing resulting SMTs,
-/// returning mutation set for reversion.
-fn apply_mutations(
-    smt: &mut Smt,
-    mutation_set: MutationSet<SMT_DEPTH, RpoDigest, Word>,
-) -> MutationSet<SMT_DEPTH, RpoDigest, Word> {
-    let mut smt2 = smt.clone();
-    let reversion = smt.apply_mutations_with_reversion(mutation_set.clone()).unwrap();
-    smt2.apply_mutations(mutation_set).unwrap();
-    assert_eq!(&smt2, smt);
-    reversion
-}

View File

@@ -1,6 +1,7 @@
use alloc::{collections::BTreeMap, vec::Vec}; use alloc::{collections::BTreeMap, vec::Vec};
use core::mem;
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use num::Integer;
use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex}; use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex};
use crate::{ use crate::{
@@ -42,7 +43,7 @@ pub const SMT_MAX_DEPTH: u8 = 64;
/// Every key maps to one leaf. If there are as many keys as there are leaves, then /// Every key maps to one leaf. If there are as many keys as there are leaves, then
/// [Self::Leaf] should be the same type as [Self::Value], as is the case with /// [Self::Leaf] should be the same type as [Self::Value], as is the case with
/// [crate::merkle::SimpleSmt]. However, if there are more keys than leaves, then [`Self::Leaf`] /// [crate::merkle::SimpleSmt]. However, if there are more keys than leaves, then [`Self::Leaf`]
/// must accommodate all keys that map to the same leaf. /// must accomodate all keys that map to the same leaf.
/// ///
/// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs. /// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs.
pub(crate) trait SparseMerkleTree<const DEPTH: u8> { pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
@@ -64,6 +65,17 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
// PROVIDED METHODS // PROVIDED METHODS
// --------------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------------
/// Creates a new sparse Merkle tree from an existing set of key-value pairs, in parallel.
#[cfg(feature = "concurrent")]
fn with_entries_par(entries: Vec<(Self::Key, Self::Value)>) -> Result<Self, MerkleError>
where
Self: Sized,
{
let (inner_nodes, leaves) = Self::build_subtrees(entries);
let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash();
Self::from_raw_parts(inner_nodes, leaves, root)
}
/// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
/// path to the leaf, as well as the leaf itself. /// path to the leaf, as well as the leaf itself.
fn open(&self, key: &Self::Key) -> Self::Opening { fn open(&self, key: &Self::Key) -> Self::Opening {
@@ -135,9 +147,9 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
node_hash = Rpo256::merge(&[left, right]); node_hash = Rpo256::merge(&[left, right]);
if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) { if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) {
// If a subtree is empty, then can remove the inner node, since it's equal to the // If a subtree is empty, when can remove the inner node, since it's equal to the
// default value // default value
self.remove_inner_node(index); self.remove_inner_node(index)
} else { } else {
self.insert_inner_node(index, InnerNode { left, right }); self.insert_inner_node(index, InnerNode { left, right });
} }
@@ -243,7 +255,7 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
} }
} }
/// Applies the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to /// Apply the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to
/// this tree. /// this tree.
/// ///
/// # Errors /// # Errors
@@ -277,12 +289,8 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
for (index, mutation) in node_mutations { for (index, mutation) in node_mutations {
match mutation { match mutation {
Removal => { Removal => self.remove_inner_node(index),
self.remove_inner_node(index); Addition(node) => self.insert_inner_node(index, node),
},
Addition(node) => {
self.insert_inner_node(index, node);
},
} }
} }
@@ -295,79 +303,19 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
Ok(()) Ok(())
} }
/// Applies the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to
/// this tree and returns the reverse mutation set. Applying the reverse mutation sets to the
/// updated tree will revert the changes.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash
/// the `mutations` were computed against, and the second item is the actual current root of
/// this tree.
fn apply_mutations_with_reversion(
&mut self,
mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
) -> Result<MutationSet<DEPTH, Self::Key, Self::Value>, MerkleError>
where
Self: Sized,
{
use NodeMutation::*;
let MutationSet {
old_root,
node_mutations,
new_pairs,
new_root,
} = mutations;
// Guard against accidentally trying to apply mutations that were computed against a
// different tree, including a stale version of this tree.
if old_root != self.root() {
return Err(MerkleError::ConflictingRoots {
expected_root: self.root(),
actual_root: old_root,
});
}
let mut reverse_mutations = BTreeMap::new();
for (index, mutation) in node_mutations {
match mutation {
Removal => {
if let Some(node) = self.remove_inner_node(index) {
reverse_mutations.insert(index, Addition(node));
}
},
Addition(node) => {
if let Some(old_node) = self.insert_inner_node(index, node) {
reverse_mutations.insert(index, Addition(old_node));
} else {
reverse_mutations.insert(index, Removal);
}
},
}
}
let mut reverse_pairs = BTreeMap::new();
for (key, value) in new_pairs {
if let Some(old_value) = self.insert_value(key.clone(), value) {
reverse_pairs.insert(key, old_value);
} else {
reverse_pairs.insert(key, Self::EMPTY_VALUE);
}
}
self.set_root(new_root);
Ok(MutationSet {
old_root: new_root,
node_mutations: reverse_mutations,
new_pairs: reverse_pairs,
new_root: old_root,
})
}
// REQUIRED METHODS // REQUIRED METHODS
// --------------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------------
/// Construct this type from already computed leaves and nodes. The caller ensures passed
/// arguments are correct and consistent with each other.
fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, Self::Leaf>,
root: RpoDigest,
) -> Result<Self, MerkleError>
where
Self: Sized;
/// The root of the tree /// The root of the tree
fn root(&self) -> RpoDigest; fn root(&self) -> RpoDigest;
@@ -378,10 +326,10 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
fn get_inner_node(&self, index: NodeIndex) -> InnerNode; fn get_inner_node(&self, index: NodeIndex) -> InnerNode;
/// Inserts an inner node at the given index /// Inserts an inner node at the given index
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode>; fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode);
/// Removes an inner node at the given index /// Removes an inner node at the given index
fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode>; fn remove_inner_node(&mut self, index: NodeIndex);
/// Inserts a leaf node, and returns the value at the key if already exists /// Inserts a leaf node, and returns the value at the key if already exists
fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value>; fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value>;
@@ -417,15 +365,134 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
/// Maps a key to a leaf index /// Maps a key to a leaf index
fn key_to_leaf_index(key: &Self::Key) -> LeafIndex<DEPTH>; fn key_to_leaf_index(key: &Self::Key) -> LeafIndex<DEPTH>;
/// Constructs a single leaf from an arbitrary amount of key-value pairs.
/// Those pairs must all have the same leaf index.
fn pairs_to_leaf(pairs: Vec<(Self::Key, Self::Value)>) -> Self::Leaf;
/// Maps a (MerklePath, Self::Leaf) to an opening. /// Maps a (MerklePath, Self::Leaf) to an opening.
/// ///
/// The length `path` is guaranteed to be equal to `DEPTH` /// The length `path` is guaranteed to be equal to `DEPTH`
fn path_and_leaf_to_opening(path: MerklePath, leaf: Self::Leaf) -> Self::Opening; fn path_and_leaf_to_opening(path: MerklePath, leaf: Self::Leaf) -> Self::Opening;
/// Performs the initial transforms for constructing a [`SparseMerkleTree`] by composing
/// subtrees. In other words, this function takes the key-value inputs to the tree, and produces
/// the inputs to feed into [`build_subtree()`].
///
/// `pairs` *must* already be sorted **by leaf index column**, not simply sorted by key. If
/// `pairs` is not correctly sorted, the returned computations will be incorrect.
///
/// # Panics
/// With debug assertions on, this function panics if it detects that `pairs` is not correctly
/// sorted. Without debug assertions, the returned computations will be incorrect.
fn sorted_pairs_to_leaves(
pairs: Vec<(Self::Key, Self::Value)>,
) -> PairComputations<u64, Self::Leaf> {
debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value()));
let mut accumulator: PairComputations<u64, Self::Leaf> = Default::default();
let mut accumulated_leaves: Vec<SubtreeLeaf> = Vec::with_capacity(pairs.len() / 2);
// As we iterate, we'll keep track of the kv-pairs we've seen so far that correspond to a
// single leaf. When we see a pair that's in a different leaf, we'll swap these pairs
// out and store them in our accumulated leaves.
let mut current_leaf_buffer: Vec<(Self::Key, Self::Value)> = Default::default();
let mut iter = pairs.into_iter().peekable();
while let Some((key, value)) = iter.next() {
let col = Self::key_to_leaf_index(&key).index.value();
let peeked_col = iter.peek().map(|(key, _v)| {
let index = Self::key_to_leaf_index(key);
let next_col = index.index.value();
// With debug assertions on, we panic if `pairs` is not sorted by column.
debug_assert!(next_col >= col);
next_col
});
current_leaf_buffer.push((key, value));
// If the next pair is the same column as this one, then we're done after adding this
// pair to the buffer.
if peeked_col == Some(col) {
continue;
}
// Otherwise, the next pair is a different column, or there is no next pair. Either way
// it's time to swap out our buffer.
let leaf_pairs = mem::take(&mut current_leaf_buffer);
let leaf = Self::pairs_to_leaf(leaf_pairs);
let hash = Self::hash_leaf(&leaf);
accumulator.nodes.insert(col, leaf);
accumulated_leaves.push(SubtreeLeaf { col, hash });
debug_assert!(current_leaf_buffer.is_empty());
}
// TODO: determine if there is any notable performance difference between computing
// subtree boundaries after the fact as an iterator adapter (like this), versus computing
// subtree boundaries as we go. Either way this function is only used at the beginning of a
// parallel construction, so it should not be a critical path.
accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect();
accumulator
}
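
    // Worked example of the buffering above (illustrative, not part of this diff): for
    // sorted pairs hitting columns [3, 3, 7], the buffer is swapped out twice, yielding a
    // two-pair leaf at column 3 and a single-pair leaf at column 7, i.e. two SubtreeLeaf
    // entries that are then grouped into depth-8 subtrees.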
/// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs.
///
/// `entries` need not be sorted. This function will sort them.
#[cfg(feature = "concurrent")]
fn build_subtrees(
mut entries: Vec<(Self::Key, Self::Value)>,
) -> (BTreeMap<NodeIndex, InnerNode>, BTreeMap<u64, Self::Leaf>) {
entries.sort_by_key(|item| {
let index = Self::key_to_leaf_index(&item.0);
index.value()
});
Self::build_subtrees_from_sorted_entries(entries)
}
/// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs.
///
/// This function is mostly an implementation detail of
/// [`SparseMerkleTree::with_entries_par()`].
#[cfg(feature = "concurrent")]
fn build_subtrees_from_sorted_entries(
entries: Vec<(Self::Key, Self::Value)>,
) -> (BTreeMap<NodeIndex, InnerNode>, BTreeMap<u64, Self::Leaf>) {
use rayon::prelude::*;
let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
let PairComputations {
leaves: mut leaf_subtrees,
nodes: initial_leaves,
} = Self::sorted_pairs_to_leaves(entries);
for current_depth in (SUBTREE_DEPTH..=DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
let (nodes, mut subtree_roots): (Vec<BTreeMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
.into_par_iter()
.map(|subtree| {
debug_assert!(subtree.is_sorted());
debug_assert!(!subtree.is_empty());
let (nodes, subtree_root) = build_subtree(subtree, DEPTH, current_depth);
(nodes, subtree_root)
})
.unzip();
leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
accumulated_nodes.extend(nodes.into_iter().flatten());
debug_assert!(!leaf_subtrees.is_empty());
}
(accumulated_nodes, initial_leaves)
}
} }
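
To make the loop bounds concrete, a worked example for the standard depth-64 `Smt`
(following the code above; illustrative only):

// (SUBTREE_DEPTH..=DEPTH).step_by(8).rev() with DEPTH == 64 visits
//     current_depth = 64, 56, 48, 40, 32, 24, 16, 8
// Each pass hashes depth-8 subtrees in parallel, emits one SubtreeLeaf per non-empty
// subtree root, and so shrinks the working set by up to 256x per pass until only the
// root's subtree remains.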
// INNER NODE // INNER NODE
// ================================================================================================ // ================================================================================================
/// This struct is public so functions returning it can be used in `benches/`, but is otherwise not
/// part of the public API.
#[doc(hidden)]
#[derive(Debug, Default, Clone, PartialEq, Eq)] #[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct InnerNode { pub struct InnerNode {
@@ -499,7 +566,7 @@ impl<const DEPTH: u8> TryFrom<NodeIndex> for LeafIndex<DEPTH> {
/// [`MutationSet`] stores this type in relation to a [`NodeIndex`] to keep track of what changes /// [`MutationSet`] stores this type in relation to a [`NodeIndex`] to keep track of what changes
/// need to occur at which node indices. /// need to occur at which node indices.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum NodeMutation { pub(crate) enum NodeMutation {
/// Corresponds to [`SparseMerkleTree::remove_inner_node()`]. /// Corresponds to [`SparseMerkleTree::remove_inner_node()`].
Removal, Removal,
/// Corresponds to [`SparseMerkleTree::insert_inner_node()`]. /// Corresponds to [`SparseMerkleTree::insert_inner_node()`].
@@ -532,94 +599,203 @@ pub struct MutationSet<const DEPTH: u8, K, V> {
} }
impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> { impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
/// Returns the SMT root that was calculated during `SparseMerkleTree::compute_mutations()`. See /// Queries the root that was calculated during `SparseMerkleTree::compute_mutations()`. See
/// that method for more information. /// that method for more information.
pub fn root(&self) -> RpoDigest { pub fn root(&self) -> RpoDigest {
self.new_root self.new_root
} }
/// Returns the SMT root before the mutations were applied.
pub fn old_root(&self) -> RpoDigest {
self.old_root
} }
/// Returns the set of inner nodes that need to be removed or added. // SUBTREES
pub fn node_mutations(&self) -> &BTreeMap<NodeIndex, NodeMutation> { // ================================================================================================
&self.node_mutations /// A subtree is of depth 8.
const SUBTREE_DEPTH: u8 = 8;
/// A depth-8 subtree contains 256 "columns" that can possibly be occupied.
const COLS_PER_SUBTREE: u64 = u64::pow(2, SUBTREE_DEPTH as u32);
/// Helper struct for organizing the data we care about when computing Merkle subtrees.
///
/// Note that these represet "conceptual" leaves of some subtree, not necessarily
/// the leaf type for the sparse Merkle tree.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)]
pub struct SubtreeLeaf {
/// The 'value' field of [`NodeIndex`]. When computing a subtree, the depth is already known.
pub col: u64,
/// The hash of the node this `SubtreeLeaf` represents.
pub hash: RpoDigest,
} }
/// Returns the set of top-level key-value pairs that need to be added, updated or deleted /// Helper struct to organize the return value of [`SparseMerkleTree::sorted_pairs_to_leaves()`].
/// (i.e. set to `EMPTY_WORD`). #[derive(Debug, Clone, PartialEq, Eq)]
pub fn new_pairs(&self) -> &BTreeMap<K, V> { pub(crate) struct PairComputations<K, L> {
&self.new_pairs /// Literal leaves to be added to the sparse Merkle tree's internal mapping.
pub nodes: BTreeMap<K, L>,
/// "Conceptual" leaves that will be used for computations.
pub leaves: Vec<Vec<SubtreeLeaf>>,
}
// Derive requires `L` to impl Default, even though we don't actually need that.
impl<K, L> Default for PairComputations<K, L> {
fn default() -> Self {
Self {
nodes: Default::default(),
leaves: Default::default(),
}
} }
} }
// SERIALIZATION #[derive(Debug)]
struct SubtreeLeavesIter<'s> {
leaves: core::iter::Peekable<alloc::vec::Drain<'s, SubtreeLeaf>>,
}
impl<'s> SubtreeLeavesIter<'s> {
fn from_leaves(leaves: &'s mut Vec<SubtreeLeaf>) -> Self {
// TODO: determine if there is any notable performance difference between taking a Vec,
// which many need flattening first, vs storing a `Box<dyn Iterator<Item = SubtreeLeaf>>`.
// The latter may have self-referential properties that are impossible to express in purely
// safe Rust Rust.
Self { leaves: leaves.drain(..).peekable() }
}
}
impl core::iter::Iterator for SubtreeLeavesIter<'_> {
type Item = Vec<SubtreeLeaf>;
/// Each `next()` collects an entire subtree.
fn next(&mut self) -> Option<Vec<SubtreeLeaf>> {
let mut subtree: Vec<SubtreeLeaf> = Default::default();
let mut last_subtree_col = 0;
while let Some(leaf) = self.leaves.peek() {
last_subtree_col = u64::max(1, last_subtree_col);
let is_exact_multiple = Integer::is_multiple_of(&last_subtree_col, &COLS_PER_SUBTREE);
let next_subtree_col = if is_exact_multiple {
u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE)
} else {
last_subtree_col.next_multiple_of(COLS_PER_SUBTREE)
};
last_subtree_col = leaf.col;
if leaf.col < next_subtree_col {
subtree.push(self.leaves.next().unwrap());
} else if subtree.is_empty() {
continue;
} else {
break;
}
}
if subtree.is_empty() {
debug_assert!(self.leaves.peek().is_none());
return None;
}
Some(subtree)
}
}
// HELPER FUNCTIONS
// ================================================================================================ // ================================================================================================
impl Serializable for InnerNode { /// Builds Merkle nodes from a bottom layer of "leaves" -- represented by a horizontal index and
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        self.left.write_into(target);
        self.right.write_into(target);
    }
}

impl Deserializable for InnerNode {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let left = source.read()?;
        let right = source.read()?;

        Ok(Self { left, right })
    }
}

impl Serializable for NodeMutation {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        match self {
            NodeMutation::Removal => target.write_bool(false),
            NodeMutation::Addition(inner_node) => {
                target.write_bool(true);
                inner_node.write_into(target);
            },
        }
    }
}

impl Deserializable for NodeMutation {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        if source.read_bool()? {
            let inner_node = source.read()?;
            return Ok(NodeMutation::Addition(inner_node));
        }

        Ok(NodeMutation::Removal)
    }
}

impl<const DEPTH: u8, K: Serializable, V: Serializable> Serializable for MutationSet<DEPTH, K, V> {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write(self.old_root);
        target.write(self.new_root);
        self.node_mutations.write_into(target);
        self.new_pairs.write_into(target);
    }
}

impl<const DEPTH: u8, K: Deserializable + Ord, V: Deserializable> Deserializable
    for MutationSet<DEPTH, K, V>
{
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let old_root = source.read()?;
        let new_root = source.read()?;
        let node_mutations = source.read()?;
        let new_pairs = source.read()?;

        Ok(Self {
            old_root,
            node_mutations,
            new_pairs,
            new_root,
        })
    }
}

/// the hash of the leaf at that index. `leaves` *must* be sorted by horizontal index, and
/// `leaves` must not contain more than one depth-8 subtree's worth of leaves.
///
/// This function will then calculate the inner nodes above each leaf for 8 layers, as well as
/// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into
/// itself.
///
/// # Panics
/// With debug assertions on, this function panics under invalid inputs: if `leaves` contains
/// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to
/// different depth-8 subtrees, if `bottom_depth` is lower in the tree than the specified
/// maximum depth (`tree_depth`), or if `leaves` is not sorted.
fn build_subtree(
    mut leaves: Vec<SubtreeLeaf>,
    tree_depth: u8,
    bottom_depth: u8,
) -> (BTreeMap<NodeIndex, InnerNode>, SubtreeLeaf) {
    debug_assert!(bottom_depth <= tree_depth);
    debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH));
    debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32));

    let subtree_root = bottom_depth - SUBTREE_DEPTH;
    let mut inner_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
    let mut next_leaves: Vec<SubtreeLeaf> = Vec::with_capacity(leaves.len() / 2);

    for next_depth in (subtree_root..bottom_depth).rev() {
        debug_assert!(next_depth <= bottom_depth);

        // `next_depth` is the depth we are computing; `current_depth` is the depth we have.
        let current_depth = next_depth + 1;

        let mut iter = leaves.drain(..).peekable();
        while let Some(first) = iter.next() {
            // On discontinuous iterations, including the very first one, `first` may be either a
            // left or a right node. On continuous iterations, the sibling is consumed with a
            // second `iter.next()` call below.
            let is_right = first.col.is_odd();
            let (left, right) = if is_right {
                // Discontinuous iteration: we have no left node, so it must be empty.
                let left = SubtreeLeaf {
                    col: first.col - 1,
                    hash: *EmptySubtreeRoots::entry(tree_depth, current_depth),
                };
                let right = first;
                (left, right)
            } else {
                let left = first;
                let right_col = first.col + 1;
                let right = match iter.peek().copied() {
                    Some(SubtreeLeaf { col, .. }) if col == right_col => {
                        // Our inputs must be sorted.
                        debug_assert!(left.col <= col);
                        // The next leaf in the iterator is our sibling. Use it and consume it!
                        iter.next().unwrap()
                    },
                    // Otherwise, the leaves don't contain our sibling, so our sibling must be
                    // empty.
                    _ => SubtreeLeaf {
                        col: right_col,
                        hash: *EmptySubtreeRoots::entry(tree_depth, current_depth),
                    },
                };
                (left, right)
            };
            let index = NodeIndex::new_unchecked(current_depth, left.col).parent();
            let node = InnerNode { left: left.hash, right: right.hash };
            let hash = node.hash();
            let &equivalent_empty_hash = EmptySubtreeRoots::entry(tree_depth, next_depth);
            // If this hash is empty, then it doesn't become a new inner node, nor does it count
            // as a leaf for the next depth.
            if hash != equivalent_empty_hash {
                inner_nodes.insert(index, node);
                next_leaves.push(SubtreeLeaf { col: index.value(), hash });
            }
        }

        // Stop borrowing `leaves`, so we can swap it. The iterator is empty at this point anyway.
        drop(iter);

        // After each depth, consider the nodes we just made the new "leaves", and empty the
        // other collection.
        mem::swap(&mut leaves, &mut next_leaves);
    }

    debug_assert_eq!(leaves.len(), 1);
    let root = leaves.pop().unwrap();
    (inner_nodes, root)
}

#[cfg(feature = "internal")]
pub fn build_subtree_for_bench(
    leaves: Vec<SubtreeLeaf>,
    tree_depth: u8,
    bottom_depth: u8,
) -> (BTreeMap<NodeIndex, InnerNode>, SubtreeLeaf) {
    build_subtree(leaves, tree_depth, bottom_depth)
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests;
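For illustration, a minimal sketch (not part of the diff) of how `build_subtree` chains across depths, condensing the pattern the new tests below verify; it assumes `entries` is a sorted `Vec<(RpoDigest, Word)>` and the same crate items the tests import:

// A minimal sketch, assuming `entries` and the imports used by the tests below.
let PairComputations { leaves: mut leaf_subtrees, .. } = Smt::sorted_pairs_to_leaves(entries);
let mut all_nodes: BTreeMap<NodeIndex, InnerNode> = BTreeMap::new();
for bottom_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
    let mut next_roots: Vec<SubtreeLeaf> = Vec::new();
    for subtree in leaf_subtrees {
        // Hash 8 levels; each returned root becomes a "leaf" for the next, shallower pass.
        let (nodes, subtree_root) = build_subtree(subtree, SMT_DEPTH, bottom_depth);
        all_nodes.extend(nodes);
        next_roots.push(subtree_root);
    }
    // Regroup the new roots into per-subtree chunks for the next pass.
    leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut next_roots).collect();
}
// Exactly one subtree containing one leaf remains: its hash is the tree root.

Because each `build_subtree` call touches a disjoint range of columns, the inner loop is exactly what the parallel constructor hands to rayon's `into_par_iter()` in the tests further down.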

View File

@@ -1,4 +1,7 @@
-use alloc::collections::{BTreeMap, BTreeSet};
+use alloc::{
+    collections::{BTreeMap, BTreeSet},
+    vec::Vec,
+};

use super::{
    super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, LeafIndex, MerkleError,
@@ -97,6 +100,23 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
Ok(tree)
}
/// Returns a new [`SimpleSmt`] instantiated from already computed leaves and nodes.
///
/// This function performs minimal consistency checking. It is the caller's responsibility to
/// ensure the passed arguments are correct and consistent with each other.
///
/// # Panics
/// With debug assertions on, this function panics if `root` does not match the root node in
/// `inner_nodes`.
pub fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, Word>,
root: RpoDigest,
) -> Self {
// Our particular implementation of `from_raw_parts()` never returns `Err`.
<Self as SparseMerkleTree<DEPTH>>::from_raw_parts(inner_nodes, leaves, root).unwrap()
}
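To make the contract concrete, a hedged usage sketch of `from_raw_parts`; `inner_nodes`, `leaves`, and `root` here are hypothetical and assumed to come from a consistent precomputation such as the parallel subtree build:

// Minimal sketch: the maps are trusted as-is; only root/inner-node agreement is
// checked, and only when debug assertions are enabled.
let tree = SimpleSmt::<64>::from_raw_parts(inner_nodes, leaves, root);
assert_eq!(tree.root(), root);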
/// Wrapper around [`SimpleSmt::with_leaves`] which inserts leaves at contiguous indices
/// starting at index 0.
pub fn with_contiguous_leaves(
@@ -221,7 +241,7 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
<Self as SparseMerkleTree<DEPTH>>::compute_mutations(self, kv_pairs)
}

-/// Applies the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to this
+/// Apply the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to this
/// tree.
///
/// # Errors
@@ -236,23 +256,6 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
<Self as SparseMerkleTree<DEPTH>>::apply_mutations(self, mutations)
}
/// Applies the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to
/// this tree and returns the reverse mutation set.
///
/// Applying the reverse mutation sets to the updated tree will revert the changes.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`] with a two-item [`alloc::vec::Vec`]. The first item is the
/// root hash the `mutations` were computed against, and the second item is the actual
/// current root of this tree.
pub fn apply_mutations_with_reversion(
&mut self,
mutations: MutationSet<DEPTH, LeafIndex<DEPTH>, Word>,
) -> Result<MutationSet<DEPTH, LeafIndex<DEPTH>, Word>, MerkleError> {
<Self as SparseMerkleTree<DEPTH>>::apply_mutations_with_reversion(self, mutations)
}
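For illustration, the round-trip this method enables, sketched with hypothetical `smt`, `key`, and `new_value` (error handling elided):

let original_root = smt.root();
let mutations = smt.compute_mutations(vec![(key, new_value)]);
let reversion = smt.apply_mutations_with_reversion(mutations)?;
// Applying the reverse set rolls the tree back to its previous state.
smt.apply_mutations(reversion)?;
assert_eq!(smt.root(), original_root);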
/// Inserts a subtree at the specified index. The depth at which the subtree is inserted is
/// computed as `DEPTH - SUBTREE_DEPTH`.
///
@@ -323,6 +326,19 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
const EMPTY_VALUE: Self::Value = EMPTY_WORD;
const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(DEPTH, 0);
fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, Word>,
root: RpoDigest,
) -> Result<Self, MerkleError> {
if cfg!(debug_assertions) {
let root_node = inner_nodes.get(&NodeIndex::root()).unwrap();
assert_eq!(root_node.hash(), root);
}
Ok(Self { root, inner_nodes, leaves })
}
fn root(&self) -> RpoDigest {
    self.root
}
@@ -338,12 +354,12 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
.unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(DEPTH, index.depth()))
}

-fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode> {
-    self.inner_nodes.insert(index, inner_node)
+fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
+    self.inner_nodes.insert(index, inner_node);
}

-fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode> {
-    self.inner_nodes.remove(&index)
+fn remove_inner_node(&mut self, index: NodeIndex) {
+    let _ = self.inner_nodes.remove(&index);
}

fn insert_value(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Option<Word> {
@@ -387,4 +403,11 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
fn path_and_leaf_to_opening(path: MerklePath, leaf: Word) -> ValuePath {
    (path, leaf).into()
}
fn pairs_to_leaf(mut pairs: Vec<(LeafIndex<DEPTH>, Word)>) -> Word {
// SimpleSmt can't have more than one value per key.
assert_eq!(pairs.len(), 1);
let (_key, value) = pairs.pop().unwrap();
value
}
}

src/merkle/smt/tests.rs (new file, 417 lines)
View File

@@ -0,0 +1,417 @@
use alloc::{collections::BTreeMap, vec::Vec};
use super::{
build_subtree, InnerNode, LeafIndex, NodeIndex, PairComputations, SmtLeaf, SparseMerkleTree,
SubtreeLeaf, SubtreeLeavesIter, COLS_PER_SUBTREE, SUBTREE_DEPTH,
};
use crate::{
hash::rpo::RpoDigest,
merkle::{Smt, SMT_DEPTH},
Felt, Word, ONE,
};
fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf {
SubtreeLeaf {
col: leaf.index().index.value(),
hash: leaf.hash(),
}
}
#[test]
fn test_sorted_pairs_to_leaves() {
let entries: Vec<(RpoDigest, Word)> = vec![
// Subtree 0.
(RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]), [ONE; 4]),
(RpoDigest::new([ONE, ONE, ONE, Felt::new(17)]), [ONE; 4]),
// Leaf index collision.
(RpoDigest::new([ONE, ONE, Felt::new(10), Felt::new(20)]), [ONE; 4]),
(RpoDigest::new([ONE, ONE, Felt::new(20), Felt::new(20)]), [ONE; 4]),
// Subtree 1. Normal single leaf again.
(RpoDigest::new([ONE, ONE, ONE, Felt::new(400)]), [ONE; 4]), // Subtree boundary.
(RpoDigest::new([ONE, ONE, ONE, Felt::new(401)]), [ONE; 4]),
// Subtree 2. Another normal leaf.
(RpoDigest::new([ONE, ONE, ONE, Felt::new(1024)]), [ONE; 4]),
];
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
let control_leaves: Vec<SmtLeaf> = {
let mut entries_iter = entries.iter().cloned();
let mut next_entry = || entries_iter.next().unwrap();
let control_leaves = vec![
// Subtree 0.
SmtLeaf::Single(next_entry()),
SmtLeaf::Single(next_entry()),
SmtLeaf::new_multiple(vec![next_entry(), next_entry()]).unwrap(),
// Subtree 1.
SmtLeaf::Single(next_entry()),
SmtLeaf::Single(next_entry()),
// Subtree 2.
SmtLeaf::Single(next_entry()),
];
assert_eq!(entries_iter.next(), None);
control_leaves
};
let control_subtree_leaves: Vec<Vec<SubtreeLeaf>> = {
let mut control_leaves_iter = control_leaves.iter();
let mut next_leaf = || control_leaves_iter.next().unwrap();
let control_subtree_leaves: Vec<Vec<SubtreeLeaf>> = [
// Subtree 0.
vec![next_leaf(), next_leaf(), next_leaf()],
// Subtree 1.
vec![next_leaf(), next_leaf()],
// Subtree 2.
vec![next_leaf()],
]
.map(|subtree| subtree.into_iter().map(smtleaf_to_subtree_leaf).collect())
.to_vec();
assert_eq!(control_leaves_iter.next(), None);
control_subtree_leaves
};
let subtrees: PairComputations<u64, SmtLeaf> = Smt::sorted_pairs_to_leaves(entries);
// This will check that the hashes, columns, and subtree assignments all match.
assert_eq!(subtrees.leaves, control_subtree_leaves);
// Flattening and re-separating out the leaves into subtrees should have the same result.
let mut all_leaves: Vec<SubtreeLeaf> = subtrees.leaves.clone().into_iter().flatten().collect();
let re_grouped: Vec<Vec<_>> = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect();
assert_eq!(subtrees.leaves, re_grouped);
// Then finally we might as well check the computed leaf nodes too.
let control_leaves: BTreeMap<u64, SmtLeaf> = control
.leaves()
.map(|(index, value)| (index.index.value(), value.clone()))
.collect();
for (column, test_leaf) in subtrees.nodes {
if test_leaf.is_empty() {
continue;
}
let control_leaf = control_leaves
.get(&column)
.unwrap_or_else(|| panic!("no leaf node found for column {column}"));
assert_eq!(control_leaf, &test_leaf);
}
}
// Helper for the below tests.
fn generate_entries(pair_count: u64) -> Vec<(RpoDigest, Word)> {
(0..pair_count)
.map(|i| {
let leaf_index = ((i as f64 / pair_count as f64) * (pair_count as f64)) as u64;
let key = RpoDigest::new([ONE, ONE, Felt::new(i), Felt::new(leaf_index)]);
let value = [ONE, ONE, ONE, Felt::new(i)];
(key, value)
})
.collect()
}
#[test]
fn test_single_subtree() {
// A single subtree's worth of leaves.
const PAIR_COUNT: u64 = COLS_PER_SUBTREE;
let entries = generate_entries(PAIR_COUNT);
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
// `entries` should already be sorted by nature of how we constructed it.
let leaves = Smt::sorted_pairs_to_leaves(entries).leaves;
let leaves = leaves.into_iter().next().unwrap();
let (first_subtree, subtree_root) = build_subtree(leaves, SMT_DEPTH, SMT_DEPTH);
assert!(!first_subtree.is_empty());
// The inner nodes computed from that subtree should match the nodes in our control tree.
for (index, node) in first_subtree.into_iter() {
let control = control.get_inner_node(index);
assert_eq!(
control, node,
"subtree-computed node at index {index:?} does not match control",
);
}
// The root returned should also match the equivalent node in the control tree.
let control_root_index =
NodeIndex::new(SMT_DEPTH - SUBTREE_DEPTH, subtree_root.col).expect("Valid root index");
let control_root_node = control.get_inner_node(control_root_index);
let control_hash = control_root_node.hash();
assert_eq!(
control_hash, subtree_root.hash,
"Subtree-computed root at index {control_root_index:?} does not match control"
);
}
// Test that we can not only compute a subtree correctly, but also feed the results of one
// subtree into computing another. In other words, test that `build_subtree()` is correctly
// composable.
#[test]
fn test_two_subtrees() {
// Two subtrees' worth of leaves.
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 2;
let entries = generate_entries(PAIR_COUNT);
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries);
// With two subtrees' worth of leaves, we should have exactly two subtrees.
let [first, second]: [Vec<_>; 2] = leaves.try_into().unwrap();
assert_eq!(first.len() as u64, PAIR_COUNT / 2);
assert_eq!(first.len(), second.len());
let mut current_depth = SMT_DEPTH;
let mut next_leaves: Vec<SubtreeLeaf> = Default::default();
let (first_nodes, first_root) = build_subtree(first, SMT_DEPTH, current_depth);
next_leaves.push(first_root);
let (second_nodes, second_root) = build_subtree(second, SMT_DEPTH, current_depth);
next_leaves.push(second_root);
// All new inner nodes + the new subtree-leaves should be 512, for one depth-cycle.
let total_computed = first_nodes.len() + second_nodes.len() + next_leaves.len();
assert_eq!(total_computed as u64, PAIR_COUNT);
// Verify the computed nodes of both subtrees.
let computed_nodes = first_nodes.clone().into_iter().chain(second_nodes);
for (index, test_node) in computed_nodes {
let control_node = control.get_inner_node(index);
assert_eq!(
control_node, test_node,
"subtree-computed node at index {index:?} does not match control",
);
}
current_depth -= SUBTREE_DEPTH;
let (nodes, root_leaf) = build_subtree(next_leaves, SMT_DEPTH, current_depth);
assert_eq!(nodes.len(), SUBTREE_DEPTH as usize);
assert_eq!(root_leaf.col, 0);
for (index, test_node) in nodes {
let control_node = control.get_inner_node(index);
assert_eq!(
control_node, test_node,
"subtree-computed node at index {index:?} does not match control",
);
}
let index = NodeIndex::new(current_depth - SUBTREE_DEPTH, root_leaf.col).unwrap();
let control_root = control.get_inner_node(index).hash();
assert_eq!(control_root, root_leaf.hash, "Root mismatch");
}
#[test]
fn test_singlethreaded_subtrees() {
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;
let entries = generate_entries(PAIR_COUNT);
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
let PairComputations {
leaves: mut leaf_subtrees,
nodes: test_leaves,
} = Smt::sorted_pairs_to_leaves(entries);
for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
// There's no flat_map_unzip(), so this is the best we can do.
let (nodes, mut subtree_roots): (Vec<BTreeMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
.into_iter()
.enumerate()
.map(|(i, subtree)| {
// Pre-assertions.
assert!(
subtree.is_sorted(),
"subtree {i} at bottom-depth {current_depth} is not sorted",
);
assert!(
!subtree.is_empty(),
"subtree {i} at bottom-depth {current_depth} is empty!",
);
// Do actual things.
let (nodes, subtree_root) = build_subtree(subtree, SMT_DEPTH, current_depth);
// Post-assertions.
for (&index, test_node) in nodes.iter() {
let control_node = control.get_inner_node(index);
assert_eq!(
test_node, &control_node,
"depth {} subtree {}: test node does not match control at index {:?}",
current_depth, i, index,
);
}
(nodes, subtree_root)
})
.unzip();
// Update state between each depth iteration.
leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
accumulated_nodes.extend(nodes.into_iter().flatten());
assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}");
}
// Make sure the true leaves match, first checking length and then checking each individual
// leaf.
let control_leaves: BTreeMap<_, _> = control.leaves().collect();
let control_leaves_len = control_leaves.len();
let test_leaves_len = test_leaves.len();
assert_eq!(test_leaves_len, control_leaves_len);
for (col, ref test_leaf) in test_leaves {
let index = LeafIndex::new_max_depth(col);
let &control_leaf = control_leaves.get(&index).unwrap();
assert_eq!(test_leaf, control_leaf, "test leaf at column {col} does not match control");
}
// Make sure the inner nodes match, checking length first and then each individual leaf.
let control_nodes_len = control.inner_nodes().count();
let test_nodes_len = accumulated_nodes.len();
assert_eq!(test_nodes_len, control_nodes_len);
for (index, test_node) in accumulated_nodes.clone() {
let control_node = control.get_inner_node(index);
assert_eq!(test_node, control_node, "test node does not match control at {index:?}");
}
// After the last iteration of the above for loop, we should have the new root node actually
// in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from
// `build_subtree()`. So let's check both!
let control_root = control.get_inner_node(NodeIndex::root());
// That for loop should have left us with only one leaf subtree...
let [leaf_subtree]: [Vec<_>; 1] = leaf_subtrees.try_into().unwrap();
// which itself contains only one 'leaf'...
let [root_leaf]: [SubtreeLeaf; 1] = leaf_subtree.try_into().unwrap();
// which matches the expected root.
assert_eq!(control.root(), root_leaf.hash);
// Likewise `accumulated_nodes` should contain a node at the root index...
assert!(accumulated_nodes.contains_key(&NodeIndex::root()));
// and it should match our actual root.
let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap();
assert_eq!(control_root, *test_root);
// And of course the root we got from each place should match.
assert_eq!(control.root(), root_leaf.hash);
}
/// The parallel version of `test_singlethreaded_subtrees()`.
#[test]
#[cfg(feature = "concurrent")]
fn test_multithreaded_subtrees() {
use rayon::prelude::*;
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;
let entries = generate_entries(PAIR_COUNT);
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
let PairComputations {
leaves: mut leaf_subtrees,
nodes: test_leaves,
} = Smt::sorted_pairs_to_leaves(entries);
for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
let (nodes, mut subtree_roots): (Vec<BTreeMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
.into_par_iter()
.enumerate()
.map(|(i, subtree)| {
// Pre-assertions.
assert!(
subtree.is_sorted(),
"subtree {i} at bottom-depth {current_depth} is not sorted",
);
assert!(
!subtree.is_empty(),
"subtree {i} at bottom-depth {current_depth} is empty!",
);
let (nodes, subtree_root) = build_subtree(subtree, SMT_DEPTH, current_depth);
// Post-assertions.
for (&index, test_node) in nodes.iter() {
let control_node = control.get_inner_node(index);
assert_eq!(
test_node, &control_node,
"depth {} subtree {}: test node does not match control at index {:?}",
current_depth, i, index,
);
}
(nodes, subtree_root)
})
.unzip();
leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
accumulated_nodes.extend(nodes.into_iter().flatten());
assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}");
}
// Make sure the true leaves match, checking length first and then each individual leaf.
let control_leaves: BTreeMap<_, _> = control.leaves().collect();
let control_leaves_len = control_leaves.len();
let test_leaves_len = test_leaves.len();
assert_eq!(test_leaves_len, control_leaves_len);
for (col, ref test_leaf) in test_leaves {
let index = LeafIndex::new_max_depth(col);
let &control_leaf = control_leaves.get(&index).unwrap();
assert_eq!(test_leaf, control_leaf);
}
// Make sure the inner nodes match, checking length first and then each individual leaf.
let control_nodes_len = control.inner_nodes().count();
let test_nodes_len = accumulated_nodes.len();
assert_eq!(test_nodes_len, control_nodes_len);
for (index, test_node) in accumulated_nodes.clone() {
let control_node = control.get_inner_node(index);
assert_eq!(test_node, control_node, "test node does not match control at {index:?}");
}
// After the last iteration of the above for loop, we should have the new root node actually
// in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from
// `build_subtree()`. So let's check both!
let control_root = control.get_inner_node(NodeIndex::root());
// That for loop should have left us with only one leaf subtree...
let [leaf_subtree]: [_; 1] = leaf_subtrees.try_into().unwrap();
// which itself contains only one 'leaf'...
let [root_leaf]: [_; 1] = leaf_subtree.try_into().unwrap();
// which matches the expected root.
assert_eq!(control.root(), root_leaf.hash);
// Likewise `accumulated_nodes` should contain a node at the root index...
assert!(accumulated_nodes.contains_key(&NodeIndex::root()));
// and it should match our actual root.
let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap();
assert_eq!(control_root, *test_root);
// And of course the root we got from each place should match.
assert_eq!(control.root(), root_leaf.hash);
}
#[test]
#[cfg(feature = "concurrent")]
fn test_with_entries_parallel() {
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;
let entries = generate_entries(PAIR_COUNT);
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
let smt = Smt::with_entries(entries.clone()).unwrap();
assert_eq!(smt.root(), control.root());
assert_eq!(smt, control);
}

View File

@@ -725,7 +725,7 @@ fn get_leaf_depth_works_with_depth_8() {
assert_eq!(8, store.get_leaf_depth(root, 8, k).unwrap());
}

// flip last bit of a and expect it to return the same depth, but for an empty node
assert_eq!(8, store.get_leaf_depth(root, 8, 0b01101000_u64).unwrap());

// flip fourth bit of a and expect an empty node on depth 4

View File

@@ -174,6 +174,36 @@ impl RandomCoin for RpoRandomCoin {
Ok(values)
}
fn reseed_with_salt(
&mut self,
data: <Self::Hasher as winter_crypto::Hasher>::Digest,
salt: Option<<Self::Hasher as winter_crypto::Hasher>::Digest>,
) {
// Reset buffer
self.current = RATE_START;
// Add the new seed material to the first half of the rate portion of the RPO state
let data: Word = data.into();
self.state[RATE_START] += data[0];
self.state[RATE_START + 1] += data[1];
self.state[RATE_START + 2] += data[2];
self.state[RATE_START + 3] += data[3];
if let Some(salt) = salt {
// Add the salt to the second half of the rate portion of the RPO state
let data: Word = salt.into();
self.state[RATE_START + 4] += data[0];
self.state[RATE_START + 5] += data[1];
self.state[RATE_START + 6] += data[2];
self.state[RATE_START + 7] += data[3];
}
// Absorb
Rpo256::apply_permutation(&mut self.state);
}
}
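A minimal usage sketch with hypothetical `message_digest` and `salt_digest` (`FeltRng` assumed in scope for `draw_word`); because the salt is absorbed into the second half of the rate, the same seed material with different salts yields independent output streams:

let mut coin = RpoRandomCoin::new(Word::default());
coin.reseed_with_salt(message_digest, Some(salt_digest));
let nonce: Word = coin.draw_word();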
// FELT RNG IMPLEMENTATION

View File

@@ -172,6 +172,36 @@ impl RandomCoin for RpxRandomCoin {
Ok(values)
}
fn reseed_with_salt(
&mut self,
data: <Self::Hasher as winter_crypto::Hasher>::Digest,
salt: Option<<Self::Hasher as winter_crypto::Hasher>::Digest>,
) {
// Reset buffer
self.current = RATE_START;
// Add the new seed material to the first half of the rate portion of the RPX state
let data: Word = data.into();
self.state[RATE_START] += data[0];
self.state[RATE_START + 1] += data[1];
self.state[RATE_START + 2] += data[2];
self.state[RATE_START + 3] += data[3];
if let Some(salt) = salt {
// Add the salt to the second half of the rate portion of the RPX state
let data: Word = salt.into();
self.state[RATE_START + 4] += data[0];
self.state[RATE_START + 5] += data[1];
self.state[RATE_START + 6] += data[2];
self.state[RATE_START + 7] += data[3];
}
// Absorb
Rpx256::apply_permutation(&mut self.state);
}
}

// FELT RNG IMPLEMENTATION