Mirror of https://github.com/arnaucube/miden-crypto.git (synced 2026-01-11 08:31:30 +01:00)

Compare commits: al-gkr-bas...v0.9.0 (77 commits)
Commits (77):

4885f885a4, 5a2e917dd5, 2be17b74fb, b35e99c390, 4c8a9809ed, ce9b45fe77,
56d014898d, 8e81ccdb68, 999a64fca6, 4bc4bea0db, dbab0e9aa9, 24f72c986b,
cd4525c7ad, 552d90429b, 119c7e2b6d, 45e7e78118, a9475b2a2d, e55b3ed2ce,
61a0764a61, 3d71a9b59b, da12fd258a, 5fcf98669d, 1cdd3dbbfa, d59ffe274a,
727ed8fb3e, 0acceaa526, 3882e0f719, 70e39e7b39, 5596db7868, a933ff2fa0,
8ea37904e3, 1004246bfe, dae9de9068, 7e9d4a4316, c9ab3beccc, 260592f8e7,
6b5db8a6db, 3ebee98b0f, 457c985a92, f894ed9cde, ac7593a13c, 004a3bc7a8,
479fe5e649, a0f533241f, 05309b19bb, be1d631630, 4d0d8d3058, 59d93cb8ba,
9baddfd138, 8f92f44a55, 36d3b8dc46, 7e13346e04, 9a18ed6749, df2650eb1f,
18310a89f0, d719cc2663, fa475d1929, 25b8cb64ba, 389fcb03c2, b7cb346e22,
fd480f827a, 9f95582654, 1f92d5417a, 9b0ce0810b, 938250453a, 9ccac2baf0,
525062d023, 3a5264c428, a8acc0b39d, 5f2d170435, 9d52958f64, a2a26e2aba,
3125144445, f33a982f29, 41f03fbe91, 65495aeb18, 0a2d440524
.editorconfig (new file, +20 lines)
@@ -0,0 +1,20 @@
# Documentation available at editorconfig.org

root=true

[*]
indent_style = space
indent_size = 4
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true

[*.rs]
max_line_length = 100

[*.md]
trim_trailing_whitespace = false

[*.yml]
indent_size = 2
.github/workflows/ci.yml (vendored, deleted, -101 lines)
@@ -1,101 +0,0 @@
name: CI
on:
  push:
    branches:
      - main
  pull_request:
    types: [opened, reopened, synchronize]

jobs:
  build:
    name: Build ${{matrix.toolchain}} on ${{matrix.os}} with ${{matrix.args}}
    runs-on: ${{matrix.os}}-latest
    strategy:
      fail-fast: false
      matrix:
        toolchain: [stable, nightly]
        os: [ubuntu]
        target: [wasm32-unknown-unknown]
        args: [--no-default-features --target wasm32-unknown-unknown]
    steps:
      - uses: actions/checkout@main
        with:
          submodules: recursive
      - name: Install rust
        uses: actions-rs/toolchain@v1
        with:
          toolchain: ${{matrix.toolchain}}
          override: true
      - run: rustup target add ${{matrix.target}}
      - name: Test
        uses: actions-rs/cargo@v1
        with:
          command: build
          args: ${{matrix.args}}

  test:
    name: Test ${{matrix.toolchain}} on ${{matrix.os}} with ${{matrix.features}}
    runs-on: ${{matrix.os}}-latest
    strategy:
      fail-fast: false
      matrix:
        toolchain: [stable, nightly]
        os: [ubuntu]
        features: ["--features default,std,serde", --no-default-features]
    steps:
      - uses: actions/checkout@main
        with:
          submodules: recursive
      - name: Install rust
        uses: actions-rs/toolchain@v1
        with:
          toolchain: ${{matrix.toolchain}}
          override: true
      - name: Test
        uses: actions-rs/cargo@v1
        with:
          command: test
          args: ${{matrix.features}}

  clippy:
    name: Clippy with ${{matrix.features}}
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        features: ["--features default,std,serde", --no-default-features]
    steps:
      - uses: actions/checkout@main
        with:
          submodules: recursive
      - name: Install minimal nightly with clippy
        uses: actions-rs/toolchain@v1
        with:
          profile: minimal
          toolchain: nightly
          components: clippy
          override: true
      - name: Clippy
        uses: actions-rs/cargo@v1
        with:
          command: clippy
          args: --all ${{matrix.features}} -- -D clippy::all -D warnings

  rustfmt:
    name: rustfmt
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@main
      - name: Install minimal stable with rustfmt
        uses: actions-rs/toolchain@v1
        with:
          profile: minimal
          toolchain: stable
          components: rustfmt
          override: true

      - name: rustfmt
        uses: actions-rs/cargo@v1
        with:
          command: fmt
          args: --all -- --check
.github/workflows/doc.yml (vendored, new file, +31 lines)
@@ -0,0 +1,31 @@
# Runs documentation related jobs.

name: doc

on:
  push:
    branches:
      - main
  pull_request:
    types: [opened, reopened, synchronize]

jobs:
  docs:
    name: Verify the docs on ${{matrix.toolchain}}
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        toolchain: [stable]
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: Install rust
        uses: actions-rs/toolchain@v1
        with:
          toolchain: ${{matrix.toolchain}}
          override: true
      - uses: davidB/rust-cargo-make@v1
      - name: cargo make - doc
        run: cargo make doc
.github/workflows/lint.yml (vendored, new file, +66 lines)
@@ -0,0 +1,66 @@
# Runs linting related jobs.

name: lint

on:
  push:
    branches:
      - main
  pull_request:
    types: [opened, reopened, synchronize]

jobs:
  version:
    name: check rust version consistency
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          profile: minimal
          override: true
      - name: check rust versions
        run: ./scripts/check-rust-version.sh

  rustfmt:
    name: rustfmt ${{matrix.toolchain}} on ${{matrix.os}}
    runs-on: ${{matrix.os}}-latest
    strategy:
      fail-fast: false
      matrix:
        toolchain: [nightly]
        os: [ubuntu]
    steps:
      - uses: actions/checkout@v4
      - name: Install minimal Rust with rustfmt
        uses: actions-rs/toolchain@v1
        with:
          profile: minimal
          toolchain: ${{matrix.toolchain}}
          components: rustfmt
          override: true
      - uses: davidB/rust-cargo-make@v1
      - name: cargo make - format-check
        run: cargo make format-check

  clippy:
    name: clippy ${{matrix.toolchain}} on ${{matrix.os}}
    runs-on: ${{matrix.os}}-latest
    strategy:
      fail-fast: false
      matrix:
        toolchain: [stable]
        os: [ubuntu]
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: Install minimal Rust with clippy
        uses: actions-rs/toolchain@v1
        with:
          profile: minimal
          toolchain: ${{matrix.toolchain}}
          components: clippy
          override: true
      - uses: davidB/rust-cargo-make@v1
      - name: cargo make - clippy
        run: cargo make clippy
.github/workflows/no-std.yml (vendored, new file, +32 lines)
@@ -0,0 +1,32 @@
# Runs no-std related jobs.

name: no-std

on:
  push:
    branches:
      - main
  pull_request:
    types: [opened, reopened, synchronize]

jobs:
  no-std:
    name: build ${{matrix.toolchain}} no-std for wasm32-unknown-unknown
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        toolchain: [stable, nightly]
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: Install rust
        uses: actions-rs/toolchain@v1
        with:
          toolchain: ${{matrix.toolchain}}
          override: true
      - run: rustup target add wasm32-unknown-unknown
      - uses: davidB/rust-cargo-make@v1
      - name: cargo make - build-no-std
        run: cargo make build-no-std
.github/workflows/test.yml (vendored, new file, +34 lines)
@@ -0,0 +1,34 @@
# Runs testing related jobs

name: test

on:
  push:
    branches:
      - main
  pull_request:
    types: [opened, reopened, synchronize]

jobs:
  test:
    name: test ${{matrix.toolchain}} on ${{matrix.os}} with ${{matrix.features}}
    runs-on: ${{matrix.os}}-latest
    strategy:
      fail-fast: false
      matrix:
        toolchain: [stable, nightly]
        os: [ubuntu]
        features: ["test", "test-no-default-features"]
    timeout-minutes: 30
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: Install rust
        uses: actions-rs/toolchain@v1
        with:
          toolchain: ${{matrix.toolchain}}
          override: true
      - uses: davidB/rust-cargo-make@v1
      - name: cargo make - test
        run: cargo make ${{matrix.features}}
.gitignore (vendored, 7 lines changed)
@@ -2,12 +2,11 @@
 # will have compiled files and executables
 /target/

-# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
-# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
-Cargo.lock

 # These are backup files generated by rustfmt
 **/*.rs.bk

 # Generated by cmake
 cmake-build-*
+
+# VS Code
+.vscode/
.gitmodules (vendored, deleted, -3 lines)
@@ -1,3 +0,0 @@
[submodule "PQClean"]
	path = PQClean
	url = https://github.com/PQClean/PQClean.git
CHANGELOG.md (35 lines changed)
@@ -1,3 +1,37 @@
## 0.9.0 (2024-03-24)

* [BREAKING] Removed deprecated re-exports from liballoc/libstd (#290).
* [BREAKING] Refactored RpoFalcon512 signature to work with pure Rust (#285).
* [BREAKING] Added `RngCore` as supertrait for `FeltRng` (#299).

## 0.8.4 (2024-03-17)

* Re-added unintentionally removed re-exported liballoc macros (`vec` and `format` macros).

## 0.8.3 (2024-03-17)

* Re-added unintentionally removed re-exported liballoc macros (#292).

## 0.8.2 (2024-03-17)

* Updated `no-std` approach to be in sync with winterfell v0.8.3 release (#290).

## 0.8.1 (2024-02-21)

* Fixed clippy warnings (#280)

## 0.8.0 (2024-02-14)

* Implemented the `PartialMmr` data structure (#195).
* Implemented RPX hash function (#201).
* Added `FeltRng` and `RpoRandomCoin` (#237).
* Accelerated RPO/RPX hash functions using AVX512 instructions (#234).
* Added `inner_nodes()` method to `PartialMmr` (#238).
* Improved `PartialMmr::apply_delta()` (#242).
* Refactored `SimpleSmt` struct (#245).
* Replaced `TieredSmt` struct with `Smt` struct (#254, #277).
* Updated Winterfell dependency to v0.8 (#275).

## 0.7.1 (2023-10-10)

* Fixed RPO Falcon signature build on Windows.

@@ -12,7 +46,6 @@
 * Implemented benchmarking for `TieredSmt` (#182).
 * Added more leaf traversal methods for `MerkleStore` (#185).
 * Added SVE acceleration for RPO hash function (#189).
-* Implemented the `PartialMmr` datastructure (#195).

 ## 0.6.0 (2023-06-25)

@@ -72,7 +72,7 @@ For example, a new change to the AIR crate might have the following message: `fe
 // ================================================================================
 ```

-- [Rustfmt](https://github.com/rust-lang/rustfmt) and [Clippy](https://github.com/rust-lang/rust-clippy) linting is included in CI pipeline. Anyways it's prefferable to run linting locally before push:
+- [Rustfmt](https://github.com/rust-lang/rustfmt) and [Clippy](https://github.com/rust-lang/rust-clippy) linting is included in the CI pipeline. Even so, it is preferable to run linting locally before pushing:
 ```
 cargo fix --allow-staged --allow-dirty --all-targets --all-features; cargo fmt; cargo clippy --workspace --all-targets --all-features -- -D warnings
 ```
Cargo.lock (generated, new file, +1164 lines)
File diff suppressed because it is too large
Cargo.toml (49 lines changed)
@@ -1,16 +1,16 @@
 [package]
 name = "miden-crypto"
-version = "0.7.1"
+version = "0.9.0"
 description = "Miden Cryptographic primitives"
 authors = ["miden contributors"]
 readme = "README.md"
 license = "MIT"
 repository = "https://github.com/0xPolygonMiden/crypto"
-documentation = "https://docs.rs/miden-crypto/0.7.1"
+documentation = "https://docs.rs/miden-crypto/0.9.0"
 categories = ["cryptography", "no-std"]
 keywords = ["miden", "crypto", "hash", "merkle"]
 edition = "2021"
-rust-version = "1.73"
+rust-version = "1.75"

 [[bin]]
 name = "miden-crypto"

@@ -33,26 +33,41 @@ harness = false

 [features]
 default = ["std"]
-executable = ["dep:clap", "dep:rand_utils", "std"]
-serde = ["dep:serde", "serde?/alloc", "winter_math/serde"]
-std = ["blake3/std", "dep:cc", "dep:libc", "winter_crypto/std", "winter_math/std", "winter_utils/std"]
-sve = ["std"]
+executable = ["dep:clap", "dep:rand-utils", "std"]
+serde = ["dep:serde", "serde?/alloc", "winter-math/serde"]
+std = [
+    "blake3/std",
+    "dep:cc",
+    "rand/std",
+    "rand/std_rng",
+    "winter-crypto/std",
+    "winter-math/std",
+    "winter-utils/std",
+]

 [dependencies]
 blake3 = { version = "1.5", default-features = false }
-clap = { version = "4.4", features = ["derive"], optional = true }
-libc = { version = "0.2", default-features = false, optional = true }
-rand_utils = { version = "0.6", package = "winter-rand-utils", optional = true }
-serde = { version = "1.0", features = [ "derive" ], default-features = false, optional = true }
-winter_crypto = { version = "0.6", package = "winter-crypto", default-features = false }
-winter_math = { version = "0.6", package = "winter-math", default-features = false }
-winter_utils = { version = "0.6", package = "winter-utils", default-features = false }
+clap = { version = "4.5", optional = true, features = ["derive"] }
+num = { version = "0.4", default-features = false, features = ["alloc", "libm"] }
+num-complex = { version = "0.4.4", default-features = false }
+rand = { version = "0.8", default-features = false }
+rand_core = { version = "0.6", default-features = false }
+rand-utils = { version = "0.8", package = "winter-rand-utils", optional = true }
+serde = { version = "1.0", default-features = false, optional = true, features = ["derive"] }
+sha3 = { version = "0.10", default-features = false }
+winter-crypto = { version = "0.8", default-features = false }
+winter-math = { version = "0.8", default-features = false }
+winter-utils = { version = "0.8", default-features = false }

 [dev-dependencies]
 criterion = { version = "0.5", features = ["html_reports"] }
-proptest = "1.3"
-rand_utils = { version = "0.6", package = "winter-rand-utils" }
+getrandom = { version = "0.2", features = ["js"] }
+hex = { version = "0.4", default-features = false, features = ["alloc"] }
+proptest = "1.4"
+rand_chacha = { version = "0.3", default-features = false }
+rand-utils = { version = "0.8", package = "winter-rand-utils" }
+seq-macro = { version = "0.3" }

 [build-dependencies]
-cc = { version = "1.0", features = ["parallel"], optional = true }
+cc = { version = "1.0", optional = true, features = ["parallel"] }
 glob = "0.3"
LICENSE (2 lines changed)
@@ -1,6 +1,6 @@
 MIT License

-Copyright (c) 2023 Polygon Miden
+Copyright (c) 2024 Polygon Miden

 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
Makefile.toml (new file, +86 lines)
@@ -0,0 +1,86 @@
# Cargo Makefile

# -- linting --------------------------------------------------------------------------------------
[tasks.format]
toolchain = "nightly"
command = "cargo"
args = ["fmt", "--all"]

[tasks.format-check]
toolchain = "nightly"
command = "cargo"
args = ["fmt", "--all", "--", "--check"]

[tasks.clippy-default]
command = "cargo"
args = ["clippy", "--workspace", "--all-targets", "--", "-D", "clippy::all", "-D", "warnings"]

[tasks.clippy-all-features]
command = "cargo"
args = ["clippy", "--workspace", "--all-targets", "--all-features", "--", "-D", "clippy::all", "-D", "warnings"]

[tasks.clippy]
dependencies = [
  "clippy-default",
  "clippy-all-features"
]

[tasks.fix]
description = "Runs Fix"
command = "cargo"
toolchain = "nightly"
args = ["fix", "--allow-staged", "--allow-dirty", "--all-targets", "--all-features"]

[tasks.lint]
description = "Runs all linting tasks (Clippy, fixing, formatting)"
run_task = { name = ["format", "format-check", "clippy", "docs"] }

# --- docs ----------------------------------------------------------------------------------------
[tasks.doc]
env = { "RUSTDOCFLAGS" = "-D warnings" }
command = "cargo"
args = ["doc", "--all-features", "--keep-going", "--release"]

# --- testing -------------------------------------------------------------------------------------
[tasks.test]
description = "Run tests with default features"
env = { "RUSTFLAGS" = "-C debug-assertions -C overflow-checks -C debuginfo=2" }
workspace = false
command = "cargo"
args = ["test", "--release"]

[tasks.test-no-default-features]
description = "Run tests with no-default-features"
env = { "RUSTFLAGS" = "-C debug-assertions -C overflow-checks -C debuginfo=2" }
workspace = false
command = "cargo"
args = ["test", "--release", "--no-default-features"]

[tasks.test-all]
description = "Run all tests"
workspace = false
run_task = { name = ["test", "test-no-default-features"], parallel = true }

# --- building ------------------------------------------------------------------------------------
[tasks.build]
description = "Build in release mode"
command = "cargo"
args = ["build", "--release"]

[tasks.build-no-std]
description = "Build using no-std"
command = "cargo"
args = ["build", "--release", "--no-default-features", "--target", "wasm32-unknown-unknown"]

[tasks.build-avx2]
description = "Build using AVX2 acceleration"
env = { "RUSTFLAGS" = "-C target-feature=+avx2" }
command = "cargo"
args = ["build", "--release"]

[tasks.build-sve]
description = "Build with SVE acceleration"
env = { "RUSTFLAGS" = "-C target-feature=+sve" }
command = "cargo"
args = ["build", "--release"]
PQClean (submodule, 1 line changed)
Submodule PQClean deleted from c3abebf4ab
README.md (41 lines changed)
@@ -1,4 +1,11 @@
 # Miden Crypto

+[](https://github.com/0xPolygonMiden/crypto/blob/main/LICENSE)
+[](https://github.com/0xPolygonMiden/crypto/actions/workflows/test.yml)
+[](https://github.com/0xPolygonMiden/crypto/actions/workflows/no-std.yml)
+[]()
+[](https://crates.io/crates/miden-crypto)
+
 This crate contains cryptographic primitives used in Polygon Miden.

 ## Hash
@@ -6,6 +13,7 @@ This crate contains cryptographic primitives used in Polygon Miden.

 * [BLAKE3](https://github.com/BLAKE3-team/BLAKE3) hash function with 256-bit, 192-bit, or 160-bit output. The 192-bit and 160-bit outputs are obtained by truncating the 256-bit output of the standard BLAKE3.
 * [RPO](https://eprint.iacr.org/2022/1577) hash function with 256-bit output. This hash function is an algebraic hash function suitable for recursive STARKs.
+* [RPX](https://eprint.iacr.org/2023/1045) hash function with 256-bit output. Similar to RPO, this hash function is suitable for recursive STARKs, but it is about 2x faster than RPO.

 For performance benchmarks of these hash functions and their comparison to other popular hash functions, please see [here](./benches/).
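Taken together with the benchmark code further down in this diff (`Rpo256::hash`, `Rpo256::merge`, `Rpo256::hash_elements`), a minimal usage sketch looks like the following; the surrounding `main` scaffolding and sample values are our own illustration, not part of the crate:

```rust
use miden_crypto::{
    hash::rpo::{Rpo256, RpoDigest},
    Felt,
};

fn main() {
    // Sequential hashing of 100 field elements (mirrors the benchmark setup).
    let elements: Vec<Felt> = (0..100).map(Felt::new).collect();
    let seq_digest: RpoDigest = Rpo256::hash_elements(&elements);

    // 2-to-1 hashing h(a, b), the core operation when building Merkle trees.
    let a = Rpo256::hash(&[1_u8]);
    let b = Rpo256::hash(&[2_u8]);
    let parent = Rpo256::merge(&[a, b]);

    println!("sequential: {seq_digest:?}\nmerged: {parent:?}");
}
```

`Rpx256` exposes the same `hash`/`merge`/`hash_elements` entry points, as the RPX benchmarks below show.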
@@ -16,18 +24,25 @@ For performance benchmarks of these hash functions and their comparison to other

 * `MerkleTree`: a regular fully-balanced binary Merkle tree. The depth of this tree can be at most 64.
 * `Mmr`: a Merkle mountain range structure designed to function as an append-only log.
 * `PartialMerkleTree`: a partial view of a Merkle tree where some sub-trees may not be known. This is similar to a collection of Merkle paths all resolving to the same root. The length of the paths can be at most 64.
+* `PartialMmr`: a partial view of a Merkle mountain range structure.
 * `SimpleSmt`: a Sparse Merkle tree (with no compaction), mapping 64-bit keys to 4-element values.
-* `TieredSmt`: a Sparse Merkle tree (with compaction), mapping 4-element keys to 4-element values.
+* `Smt`: a Sparse Merkle tree (with compaction at depth 64), mapping 4-element keys to 4-element values.

 The module also contains additional supporting components such as `NodeIndex`, `MerklePath`, and `MerkleError` to assist with tree indexation, opening proofs, and reporting inconsistent arguments/state.
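The const-generic `SimpleSmt` API can be seen in the updated benchmarks below; as a quick orientation, here is a minimal sketch assembled from those same calls (`with_leaves`, `insert`, `open`), with the depth and sample values being our own illustrative choices:

```rust
use miden_crypto::{
    merkle::{LeafIndex, SimpleSmt},
    Felt, Word,
};

// The tree depth is a const-generic parameter (at most 64, i.e. SMT_MAX_DEPTH).
const DEPTH: u8 = 16;

fn main() {
    // Build a sparse Merkle tree from (index, value) pairs.
    let leaf: Word = [Felt::new(1), Felt::new(2), Felt::new(3), Felt::new(4)];
    let entries = vec![(0_u64, leaf), (42_u64, leaf)];
    let mut tree = SimpleSmt::<DEPTH>::with_leaves(entries).unwrap();

    // Update a leaf, then open a Merkle proof for it.
    tree.insert(LeafIndex::<DEPTH>::new(42).unwrap(), leaf);
    let opening = tree.open(&LeafIndex::<DEPTH>::new(42).unwrap());
    let _ = opening; // the opening resolves to tree.root()
    println!("root: {:?}", tree.root());
}
```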
 ## Signatures
-[DAS module](./src/dsa) provides a set of digital signature schemes supported by default in the Miden VM. Currently, these schemes are:
+[DSA module](./src/dsa) provides a set of digital signature schemes supported by default in the Miden VM. Currently, these schemes are:

 * `RPO Falcon512`: a variant of the [Falcon](https://falcon-sign.info/) signature scheme. This variant differs from the standard in that instead of using the SHAKE256 hash function in the *hash-to-point* algorithm we use RPO256. This makes the signature more efficient to verify in the Miden VM.

-For the above signatures, key generation and signing is available only in the `std` context (see [crate features](#crate-features) below), while signature verification is available in `no_std` context as well.
+For the above signatures, key generation, signing, and signature verification are available in both `std` and `no_std` contexts (see [crate features](#crate-features) below). However, in the `no_std` context, the user is responsible for supplying the key generation and signing procedures with a random number generator.
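For orientation, a hypothetical sign/verify sketch follows. The type and method names (`SecretKey`, `sign`, `public_key`, `verify`) are assumptions made purely for illustration; consult the [DSA module](./src/dsa) for the actual interface:

```rust
// Hypothetical sketch -- names below are assumptions, not the confirmed crate API.
use miden_crypto::{dsa::rpo_falcon512::SecretKey, Felt, Word};

fn main() {
    // Key generation and signing need a randomness source; under `std` the
    // crate can supply one, under `no_std` the caller provides an RNG.
    let sk = SecretKey::new();                 // assumed keypair generation
    let message: Word = [Felt::new(1); 4];     // assumed 4-element message shape
    let signature = sk.sign(message);          // assumed signing entry point
    let pk = sk.public_key();                  // assumed public-key accessor
    assert!(pk.verify(message, &signature));   // assumed boolean verification
}
```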
 ## Pseudo-Random Element Generator
 [Pseudo random element generator module](./src/rand/) provides a set of traits and data structures that facilitate generating pseudo-random elements in the context of Miden VM and Miden rollup. The module currently includes:

 * `FeltRng`: a trait for generating random field elements and random arrays of 4 field elements.
 * `RpoRandomCoin`: a struct implementing `FeltRng` as well as the [`RandomCoin`](https://github.com/facebook/winterfell/blob/main/crypto/src/random/mod.rs) trait.
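A minimal sketch of how these pieces are intended to fit together; `RpoRandomCoin::new` and the `draw_*` method names are assumptions based on the descriptions above, while the `RngCore` supertrait is per the v0.9.0 changelog:

```rust
use miden_crypto::{
    rand::{FeltRng, RpoRandomCoin},
    Felt, Word,
};
use rand_core::RngCore;

fn main() {
    // Seed the coin with a word of 4 field elements (assumed constructor shape).
    let seed: Word = [Felt::new(7); 4];
    let mut coin = RpoRandomCoin::new(seed);

    // FeltRng: draw a single field element, or a word of 4 elements.
    let e: Felt = coin.draw_element();
    let w: Word = coin.draw_word();

    // RngCore is a supertrait of FeltRng as of v0.9.0, so raw integers work too.
    let raw = coin.next_u64();
    println!("{e:?} {w:?} {raw}");
}
```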
 ## Crate features
 This crate can be compiled with the following features:

@@ -38,23 +53,29 @@ Both of these features imply the use of [alloc](https://doc.rust-lang.org/alloc/

 To compile with `no_std`, disable default features via the `--no-default-features` flag.

-### SVE acceleration
-On platforms with [SVE](https://en.wikipedia.org/wiki/AArch64#Scalable_Vector_Extension_(SVE)) support, RPO hash function can be accelerated by using the vector processing unit. To enable SVE acceleration, the code needs to be compiled with the `sve` feature enabled. This feature has an effect only if the platform exposes `target-feature=sve` flag. On some platforms (e.g., Graviton 3), for this flag to be set, the compilation must be done in "native" mode. For example, to enable SVE acceleration on Graviton 3, we can execute the following:
+### AVX2 acceleration
+On platforms with [AVX2](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) support, the RPO and RPX hash functions can be accelerated by using the vector processing unit. To enable AVX2 acceleration, the code needs to be compiled with the `avx2` target feature enabled. For example:
 ```shell
-RUSTFLAGS="-C target-cpu=native" cargo build --release --features sve
+cargo make build-avx2
 ```

+### SVE acceleration
+On platforms with [SVE](https://en.wikipedia.org/wiki/AArch64#Scalable_Vector_Extension_(SVE)) support, the RPO and RPX hash functions can be accelerated by using the vector processing unit. To enable SVE acceleration, the code needs to be compiled with the `sve` target feature enabled. For example:
+```shell
+cargo make build-sve
+```
+
 ## Testing

-You can use cargo defaults to test the library:
+The best way to test the library is to use our `Makefile.toml` together with [cargo-make](https://github.com/sagiegurari/cargo-make); this will enable you to use our pre-defined, optimized testing commands:

 ```shell
-cargo test
+cargo make test-all
 ```

-However, some of the functions are heavy and might take a while for the tests to complete. In order to test in release mode, we have to replicate the test conditions of the development mode so all debug assertions can be verified.
+For example, some of the functions are heavy and the tests might take a while to complete when simply using `cargo test`. In order to test in release (optimized) mode, we have to replicate the test conditions of development mode so that all debug assertions can be verified.

-We do that by enabling some special [flags](https://doc.rust-lang.org/cargo/reference/profiles.html) for the compilation.
+We do that by enabling some special [flags](https://doc.rust-lang.org/cargo/reference/profiles.html) for the compilation (which we have set as a default in our [Makefile.toml](Makefile.toml)):

 ```shell
 RUSTFLAGS="-C debug-assertions -C overflow-checks -C debuginfo=2" cargo test --release
@@ -6,6 +6,7 @@ In the Miden VM, we make use of different hash functions. Some of these are "tra
 * **Poseidon** as specified [here](https://eprint.iacr.org/2019/458.pdf) and implemented [here](https://github.com/mir-protocol/plonky2/blob/806b88d7d6e69a30dc0b4775f7ba275c45e8b63b/plonky2/src/hash/poseidon_goldilocks.rs) (but in pure Rust, without vectorized instructions).
 * **Rescue Prime (RP)** as specified [here](https://eprint.iacr.org/2020/1143) and implemented [here](https://github.com/novifinancial/winterfell/blob/46dce1adf0/crypto/src/hash/rescue/rp64_256/mod.rs).
 * **Rescue Prime Optimized (RPO)** as specified [here](https://eprint.iacr.org/2022/1577) and implemented in this crate.
+* **Rescue Prime Extended (RPX)**, a variant of the [xHash](https://eprint.iacr.org/2023/1045) hash function, as implemented in this crate.

 ## Comparison and Instructions

@@ -15,28 +16,31 @@ The second scenario is that of sequential hashing where we take a sequence of le

 #### Scenario 1: 2-to-1 hashing `h(a,b)`

-| Function            | BLAKE3 | SHA3    | Poseidon  | Rp64_256  | RPO_256 |
-| ------------------- | ------ | ------- | --------- | --------- | ------- |
-| Apple M1 Pro        | 80 ns  | 245 ns  | 1.5 us    | 9.1 us    | 5.4 us  |
-| Apple M2            | 76 ns  | 233 ns  | 1.3 us    | 7.9 us    | 5.0 us  |
-| Amazon Graviton 3   | 108 ns |         |           |           | 5.3 us  |
-| AMD Ryzen 9 5950X   | 64 ns  | 273 ns  | 1.2 us    | 9.1 us    | 5.5 us  |
-| Intel Core i5-8279U | 80 ns  |         |           |           | 8.7 us  |
-| Intel Xeon 8375C    | 67 ns  |         |           |           | 8.2 us  |
+| Function            | BLAKE3 | SHA3    | Poseidon  | Rp64_256  | RPO_256 | RPX_256 |
+| ------------------- | ------ | ------- | --------- | --------- | ------- | ------- |
+| Apple M1 Pro        | 76 ns  | 245 ns  | 1.5 µs    | 9.1 µs    | 5.2 µs  | 2.7 µs  |
+| Apple M2 Max        | 71 ns  | 233 ns  | 1.3 µs    | 7.9 µs    | 4.6 µs  | 2.4 µs  |
+| Amazon Graviton 3   | 108 ns |         |           |           | 5.3 µs  | 3.1 µs  |
+| AMD Ryzen 9 5950X   | 64 ns  | 273 ns  | 1.2 µs    | 9.1 µs    | 5.5 µs  |         |
+| AMD EPYC 9R14       | 83 ns  |         |           |           | 4.3 µs  | 2.4 µs  |
+| Intel Core i5-8279U | 68 ns  | 536 ns  | 2.0 µs    | 13.6 µs   | 8.5 µs  | 4.4 µs  |
+| Intel Xeon 8375C    | 67 ns  |         |           |           | 8.2 µs  |         |

 #### Scenario 2: Sequential hashing of 100 elements `h([a_0,...,a_99])`

-| Function            | BLAKE3 | SHA3    | Poseidon  | Rp64_256  | RPO_256 |
-| ------------------- | -------| ------- | --------- | --------- | ------- |
-| Apple M1 Pro        | 1.0 us | 1.5 us  | 19.4 us   | 118 us    | 70 us   |
-| Apple M2            | 1.0 us | 1.5 us  | 17.4 us   | 103 us    | 65 us   |
-| Amazon Graviton 3   | 1.4 us |         |           |           | 69 us   |
-| AMD Ryzen 9 5950X   | 0.8 us | 1.7 us  | 15.7 us   | 120 us    | 72 us   |
-| Intel Core i5-8279U | 1.0 us |         |           |           | 116 us  |
-| Intel Xeon 8375C    | 0.8 ns |         |           |           | 110 us  |
+| Function            | BLAKE3 | SHA3    | Poseidon  | Rp64_256  | RPO_256 | RPX_256 |
+| ------------------- | -------| ------- | --------- | --------- | ------- | ------- |
+| Apple M1 Pro        | 1.0 µs | 1.5 µs  | 19.4 µs   | 118 µs    | 69 µs   | 35 µs   |
+| Apple M2 Max        | 0.9 µs | 1.5 µs  | 17.4 µs   | 103 µs    | 60 µs   | 31 µs   |
+| Amazon Graviton 3   | 1.4 µs |         |           |           | 69 µs   | 41 µs   |
+| AMD Ryzen 9 5950X   | 0.8 µs | 1.7 µs  | 15.7 µs   | 120 µs    | 72 µs   |         |
+| AMD EPYC 9R14       | 0.9 µs |         |           |           | 56 µs   | 32 µs   |
+| Intel Core i5-8279U | 0.9 µs |         |           |           | 107 µs  | 56 µs   |
+| Intel Xeon 8375C    | 0.8 µs |         |           |           | 110 µs  |         |

 Notes:
-- On Graviton 3, RPO256 is run with SVE acceleration enabled.
+- On Graviton 3, RPO256 and RPX256 are run with SVE acceleration enabled.
+- On AMD EPYC 9R14, RPO256 and RPX256 are run with AVX2 acceleration enabled.

 ### Instructions
 Before you can run the benchmarks, you'll need to make sure you have Rust [installed](https://www.rust-lang.org/tools/install). After that, to run the benchmarks for RPO and BLAKE3, clone the current repository, and from the root directory of the repo run the following:
@@ -3,6 +3,7 @@ use miden_crypto::{
     hash::{
         blake::Blake3_256,
         rpo::{Rpo256, RpoDigest},
+        rpx::{Rpx256, RpxDigest},
     },
     Felt,
 };
@@ -31,7 +32,6 @@ fn rpo256_2to1(c: &mut Criterion) {

 fn rpo256_sequential(c: &mut Criterion) {
     let v: [Felt; 100] = (0..100)
-        .into_iter()
         .map(Felt::new)
         .collect::<Vec<Felt>>()
         .try_into()
@@ -44,7 +44,6 @@ fn rpo256_sequential(c: &mut Criterion) {
     bench.iter_batched(
         || {
             let v: [Felt; 100] = (0..100)
-                .into_iter()
                 .map(|_| Felt::new(rand_value()))
                 .collect::<Vec<Felt>>()
                 .try_into()
@@ -57,6 +56,52 @@ fn rpo256_sequential(c: &mut Criterion) {
     });
 }

+fn rpx256_2to1(c: &mut Criterion) {
+    let v: [RpxDigest; 2] = [Rpx256::hash(&[1_u8]), Rpx256::hash(&[2_u8])];
+    c.bench_function("RPX256 2-to-1 hashing (cached)", |bench| {
+        bench.iter(|| Rpx256::merge(black_box(&v)))
+    });
+
+    c.bench_function("RPX256 2-to-1 hashing (random)", |bench| {
+        bench.iter_batched(
+            || {
+                [
+                    Rpx256::hash(&rand_value::<u64>().to_le_bytes()),
+                    Rpx256::hash(&rand_value::<u64>().to_le_bytes()),
+                ]
+            },
+            |state| Rpx256::merge(&state),
+            BatchSize::SmallInput,
+        )
+    });
+}
+
+fn rpx256_sequential(c: &mut Criterion) {
+    let v: [Felt; 100] = (0..100)
+        .map(Felt::new)
+        .collect::<Vec<Felt>>()
+        .try_into()
+        .expect("should not fail");
+    c.bench_function("RPX256 sequential hashing (cached)", |bench| {
+        bench.iter(|| Rpx256::hash_elements(black_box(&v)))
+    });
+
+    c.bench_function("RPX256 sequential hashing (random)", |bench| {
+        bench.iter_batched(
+            || {
+                let v: [Felt; 100] = (0..100)
+                    .map(|_| Felt::new(rand_value()))
+                    .collect::<Vec<Felt>>()
+                    .try_into()
+                    .expect("should not fail");
+                v
+            },
+            |state| Rpx256::hash_elements(&state),
+            BatchSize::SmallInput,
+        )
+    });
+}
+
 fn blake3_2to1(c: &mut Criterion) {
     let v: [<Blake3_256 as Hasher>::Digest; 2] =
         [Blake3_256::hash(&[1_u8]), Blake3_256::hash(&[2_u8])];
@@ -80,7 +125,6 @@ fn blake3_2to1(c: &mut Criterion) {

 fn blake3_sequential(c: &mut Criterion) {
     let v: [Felt; 100] = (0..100)
-        .into_iter()
         .map(Felt::new)
         .collect::<Vec<Felt>>()
         .try_into()
@@ -93,7 +137,6 @@ fn blake3_sequential(c: &mut Criterion) {
     bench.iter_batched(
         || {
             let v: [Felt; 100] = (0..100)
-                .into_iter()
                 .map(|_| Felt::new(rand_value()))
                 .collect::<Vec<Felt>>()
                 .try_into()
@@ -106,5 +149,13 @@ fn blake3_sequential(c: &mut Criterion) {
     });
 }

-criterion_group!(hash_group, rpo256_2to1, rpo256_sequential, blake3_2to1, blake3_sequential);
+criterion_group!(
+    hash_group,
+    rpx256_2to1,
+    rpx256_sequential,
+    rpo256_2to1,
+    rpo256_sequential,
+    blake3_2to1,
+    blake3_sequential
+);
 criterion_main!(hash_group);
@@ -1,16 +1,21 @@
-use core::mem::swap;
-
 use criterion::{black_box, criterion_group, criterion_main, Criterion};
-use miden_crypto::{merkle::SimpleSmt, Felt, Word};
+use miden_crypto::{
+    merkle::{LeafIndex, SimpleSmt},
+    Felt, Word,
+};
 use rand_utils::prng_array;
+use seq_macro::seq;

 fn smt_rpo(c: &mut Criterion) {
     // setup trees
-
     let mut seed = [0u8; 32];
-    let mut trees = vec![];
+    let leaf = generate_word(&mut seed);

-    for depth in 14..=20 {
-        let leaves = ((1 << depth) - 1) as u64;
+    seq!(DEPTH in 14..=20 {
+        let leaves = ((1 << DEPTH) - 1) as u64;
         for count in [1, leaves / 2, leaves] {
             let entries: Vec<_> = (0..count)
                 .map(|i| {
@@ -18,50 +23,45 @@ fn smt_rpo(c: &mut Criterion) {
                     (i, word)
                 })
                 .collect();
-            let tree = SimpleSmt::with_leaves(depth, entries).unwrap();
-            trees.push((tree, count));
+            let mut tree = SimpleSmt::<DEPTH>::with_leaves(entries).unwrap();
+
+            // benchmark 1
+            let mut insert = c.benchmark_group("smt update_leaf".to_string());
+            {
+                let depth = DEPTH;
+                let key = count >> 2;
+                insert.bench_with_input(
+                    format!("simple smt(depth:{depth},count:{count})"),
+                    &(key, leaf),
+                    |b, (key, leaf)| {
+                        b.iter(|| {
+                            tree.insert(black_box(LeafIndex::<DEPTH>::new(*key).unwrap()), black_box(*leaf));
+                        });
+                    },
+                );
+            }
+            insert.finish();
+
+            // benchmark 2
+            let mut path = c.benchmark_group("smt get_leaf_path".to_string());
+            {
+                let depth = DEPTH;
+                let key = count >> 2;
+                path.bench_with_input(
+                    format!("simple smt(depth:{depth},count:{count})"),
+                    &key,
+                    |b, key| {
+                        b.iter(|| {
+                            tree.open(black_box(&LeafIndex::<DEPTH>::new(*key).unwrap()));
+                        });
+                    },
+                );
+            }
+            path.finish();
         }
-    }
-
-    let leaf = generate_word(&mut seed);
-
-    // benchmarks
-
-    let mut insert = c.benchmark_group(format!("smt update_leaf"));
-
-    for (tree, count) in trees.iter_mut() {
-        let depth = tree.depth();
-        let key = *count >> 2;
-        insert.bench_with_input(
-            format!("simple smt(depth:{depth},count:{count})"),
-            &(key, leaf),
-            |b, (key, leaf)| {
-                b.iter(|| {
-                    tree.update_leaf(black_box(*key), black_box(*leaf)).unwrap();
-                });
-            },
-        );
-    }
-
-    insert.finish();
-
-    let mut path = c.benchmark_group(format!("smt get_leaf_path"));
-
-    for (tree, count) in trees.iter_mut() {
-        let depth = tree.depth();
-        let key = *count >> 2;
-        path.bench_with_input(
-            format!("simple smt(depth:{depth},count:{count})"),
-            &key,
-            |b, key| {
-                b.iter(|| {
-                    tree.get_leaf_path(black_box(*key)).unwrap();
-                });
-            },
-        );
-    }
-
-    path.finish();
+    });
 }

 criterion_group!(smt_group, smt_rpo);
@@ -1,7 +1,12 @@
 use criterion::{black_box, criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
-use miden_crypto::merkle::{DefaultMerkleStore as MerkleStore, MerkleTree, NodeIndex, SimpleSmt};
-use miden_crypto::Word;
-use miden_crypto::{hash::rpo::RpoDigest, Felt};
+use miden_crypto::{
+    hash::rpo::RpoDigest,
+    merkle::{
+        DefaultMerkleStore as MerkleStore, LeafIndex, MerkleTree, NodeIndex, SimpleSmt,
+        SMT_MAX_DEPTH,
+    },
+    Felt, Word,
+};
 use rand_utils::{rand_array, rand_value};

 /// Since MerkleTree can only be created when a power-of-two number of elements is used, the sample
@@ -15,7 +20,7 @@ fn random_rpo_digest() -> RpoDigest {

 /// Generates a random `Word`.
 fn random_word() -> Word {
-    rand_array::<Felt, 4>().into()
+    rand_array::<Felt, 4>()
 }

 /// Generates an index at the specified depth in `0..range`.
@@ -28,26 +33,26 @@ fn random_index(range: u64, depth: u8) -> NodeIndex {
 fn get_empty_leaf_simplesmt(c: &mut Criterion) {
     let mut group = c.benchmark_group("get_empty_leaf_simplesmt");

-    let depth = SimpleSmt::MAX_DEPTH;
+    const DEPTH: u8 = SMT_MAX_DEPTH;
     let size = u64::MAX;

     // both SMT and the store are pre-populated with empty hashes, accessing these values is what is
     // being benchmarked here, so no values are inserted into the backends
-    let smt = SimpleSmt::new(depth).unwrap();
+    let smt = SimpleSmt::<DEPTH>::new().unwrap();
     let store = MerkleStore::from(&smt);
     let root = smt.root();

-    group.bench_function(BenchmarkId::new("SimpleSmt", depth), |b| {
+    group.bench_function(BenchmarkId::new("SimpleSmt", DEPTH), |b| {
         b.iter_batched(
-            || random_index(size, depth),
+            || random_index(size, DEPTH),
             |index| black_box(smt.get_node(index)),
             BatchSize::SmallInput,
         )
     });

-    group.bench_function(BenchmarkId::new("MerkleStore", depth), |b| {
+    group.bench_function(BenchmarkId::new("MerkleStore", DEPTH), |b| {
         b.iter_batched(
-            || random_index(size, depth),
+            || random_index(size, DEPTH),
             |index| black_box(store.get_node(root, index)),
             BatchSize::SmallInput,
         )
@@ -104,15 +109,14 @@ fn get_leaf_simplesmt(c: &mut Criterion) {
         .enumerate()
         .map(|(c, v)| (c.try_into().unwrap(), v.into()))
         .collect::<Vec<(u64, Word)>>();
-    let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap();
+    let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
     let store = MerkleStore::from(&smt);
-    let depth = smt.depth();
     let root = smt.root();
     let size_u64 = size as u64;

     group.bench_function(BenchmarkId::new("SimpleSmt", size), |b| {
         b.iter_batched(
-            || random_index(size_u64, depth),
+            || random_index(size_u64, SMT_MAX_DEPTH),
             |index| black_box(smt.get_node(index)),
             BatchSize::SmallInput,
         )
@@ -120,7 +124,7 @@ fn get_leaf_simplesmt(c: &mut Criterion) {

     group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
         b.iter_batched(
-            || random_index(size_u64, depth),
+            || random_index(size_u64, SMT_MAX_DEPTH),
             |index| black_box(store.get_node(root, index)),
             BatchSize::SmallInput,
         )
@@ -132,18 +136,18 @@ fn get_leaf_simplesmt(c: &mut Criterion) {
 fn get_node_of_empty_simplesmt(c: &mut Criterion) {
     let mut group = c.benchmark_group("get_node_of_empty_simplesmt");

-    let depth = SimpleSmt::MAX_DEPTH;
+    const DEPTH: u8 = SMT_MAX_DEPTH;

     // both SMT and the store are pre-populated with the empty hashes, accessing the internal nodes
     // of these values is what is being benchmarked here, so no values are inserted into the
     // backends.
-    let smt = SimpleSmt::new(depth).unwrap();
+    let smt = SimpleSmt::<DEPTH>::new().unwrap();
     let store = MerkleStore::from(&smt);
     let root = smt.root();
-    let half_depth = depth / 2;
+    let half_depth = DEPTH / 2;
     let half_size = 2_u64.pow(half_depth as u32);

-    group.bench_function(BenchmarkId::new("SimpleSmt", depth), |b| {
+    group.bench_function(BenchmarkId::new("SimpleSmt", DEPTH), |b| {
         b.iter_batched(
             || random_index(half_size, half_depth),
             |index| black_box(smt.get_node(index)),
@@ -151,7 +155,7 @@ fn get_node_of_empty_simplesmt(c: &mut Criterion) {
         )
     });

-    group.bench_function(BenchmarkId::new("MerkleStore", depth), |b| {
+    group.bench_function(BenchmarkId::new("MerkleStore", DEPTH), |b| {
         b.iter_batched(
             || random_index(half_size, half_depth),
             |index| black_box(store.get_node(root, index)),
@@ -212,10 +216,10 @@ fn get_node_simplesmt(c: &mut Criterion) {
         .enumerate()
         .map(|(c, v)| (c.try_into().unwrap(), v.into()))
         .collect::<Vec<(u64, Word)>>();
-    let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap();
+    let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
     let store = MerkleStore::from(&smt);
     let root = smt.root();
-    let half_depth = smt.depth() / 2;
+    let half_depth = SMT_MAX_DEPTH / 2;
     let half_size = 2_u64.pow(half_depth as u32);

     group.bench_function(BenchmarkId::new("SimpleSmt", size), |b| {
@@ -286,23 +290,24 @@ fn get_leaf_path_simplesmt(c: &mut Criterion) {
         .enumerate()
         .map(|(c, v)| (c.try_into().unwrap(), v.into()))
         .collect::<Vec<(u64, Word)>>();
-    let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap();
+    let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
     let store = MerkleStore::from(&smt);
-    let depth = smt.depth();
     let root = smt.root();
     let size_u64 = size as u64;

     group.bench_function(BenchmarkId::new("SimpleSmt", size), |b| {
         b.iter_batched(
-            || random_index(size_u64, depth),
-            |index| black_box(smt.get_path(index)),
+            || random_index(size_u64, SMT_MAX_DEPTH),
+            |index| {
+                black_box(smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(index.value()).unwrap()))
+            },
             BatchSize::SmallInput,
         )
     });

     group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
         b.iter_batched(
-            || random_index(size_u64, depth),
+            || random_index(size_u64, SMT_MAX_DEPTH),
             |index| black_box(store.get_path(root, index)),
             BatchSize::SmallInput,
         )
@@ -352,7 +357,7 @@ fn new(c: &mut Criterion) {
             .map(|(c, v)| (c.try_into().unwrap(), v.into()))
             .collect::<Vec<(u64, Word)>>()
         },
-        |l| black_box(SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, l)),
+        |l| black_box(SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(l)),
         BatchSize::SmallInput,
     )
     });
@@ -367,7 +372,7 @@ fn new(c: &mut Criterion) {
             .collect::<Vec<(u64, Word)>>()
         },
         |l| {
-            let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, l).unwrap();
+            let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(l).unwrap();
             black_box(MerkleStore::from(&smt));
         },
         BatchSize::SmallInput,
@@ -433,16 +438,17 @@ fn update_leaf_simplesmt(c: &mut Criterion) {
         .enumerate()
         .map(|(c, v)| (c.try_into().unwrap(), v.into()))
         .collect::<Vec<(u64, Word)>>();
-    let mut smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap();
+    let mut smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
     let mut store = MerkleStore::from(&smt);
-    let depth = smt.depth();
     let root = smt.root();
     let size_u64 = size as u64;

     group.bench_function(BenchmarkId::new("SimpleSMT", size), |b| {
         b.iter_batched(
             || (rand_value::<u64>() % size_u64, random_word()),
-            |(index, value)| black_box(smt.update_leaf(index, value)),
+            |(index, value)| {
+                black_box(smt.insert(LeafIndex::<SMT_MAX_DEPTH>::new(index).unwrap(), value))
+            },
             BatchSize::SmallInput,
         )
     });
@@ -450,7 +456,7 @@ fn update_leaf_simplesmt(c: &mut Criterion) {
     let mut store_root = root;
     group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
         b.iter_batched(
-            || (random_index(size_u64, depth), random_word()),
+            || (random_index(size_u64, SMT_MAX_DEPTH), random_word()),
             |(index, value)| {
                 // The MerkleTree automatically updates its internal root, the Store maintains
                 // the old root and adds the new one. Here we update the root to have a fair
build.rs (35 lines changed)
@@ -1,40 +1,9 @@
 fn main() {
-    #[cfg(feature = "std")]
-    compile_rpo_falcon();
-
-    #[cfg(all(target_feature = "sve", feature = "sve"))]
+    #[cfg(target_feature = "sve")]
     compile_arch_arm64_sve();
 }

-#[cfg(feature = "std")]
-fn compile_rpo_falcon() {
-    use std::path::PathBuf;
-
-    const RPO_FALCON_PATH: &str = "src/dsa/rpo_falcon512/falcon_c";
-
-    println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/falcon.h");
-    println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/falcon.c");
-    println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/rpo.h");
-    println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/rpo.c");
-
-    let target_dir: PathBuf = ["PQClean", "crypto_sign", "falcon-512", "clean"].iter().collect();
-    let common_dir: PathBuf = ["PQClean", "common"].iter().collect();
-
-    let scheme_files = glob::glob(target_dir.join("*.c").to_str().unwrap()).unwrap();
-    let common_files = glob::glob(common_dir.join("*.c").to_str().unwrap()).unwrap();
-
-    cc::Build::new()
-        .include(&common_dir)
-        .include(target_dir)
-        .files(scheme_files.into_iter().map(|p| p.unwrap().to_string_lossy().into_owned()))
-        .files(common_files.into_iter().map(|p| p.unwrap().to_string_lossy().into_owned()))
-        .file(format!("{RPO_FALCON_PATH}/falcon.c"))
-        .file(format!("{RPO_FALCON_PATH}/rpo.c"))
-        .flag("-O3")
-        .compile("rpo_falcon512");
-}
-
-#[cfg(all(target_feature = "sve", feature = "sve"))]
+#[cfg(target_feature = "sve")]
 fn compile_arch_arm64_sve() {
     const RPO_SVE_PATH: &str = "arch/arm64-sve/rpo";
rust-toolchain (new file, +1 line)
@@ -0,0 +1 @@
1.75
scripts/check-rust-version.sh (new executable file, +13 lines)
@@ -0,0 +1,13 @@
#!/bin/bash

# Check rust-toolchain file
TOOLCHAIN_VERSION=$(cat rust-toolchain)

# Check workspace Cargo.toml file
CARGO_VERSION=$(cat Cargo.toml | grep "rust-version" | cut -d '"' -f 2)
if [ "$CARGO_VERSION" != "$TOOLCHAIN_VERSION" ]; then
    echo "Mismatch in Cargo.toml: Expected $TOOLCHAIN_VERSION, found $CARGO_VERSION"
    exit 1
fi

echo "Rust versions match ✅"
@@ -1,55 +0,0 @@
use super::{LOG_N, MODULUS, PK_LEN};
use core::fmt;

#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FalconError {
    KeyGenerationFailed,
    PubKeyDecodingExtraData,
    PubKeyDecodingInvalidCoefficient(u32),
    PubKeyDecodingInvalidLength(usize),
    PubKeyDecodingInvalidTag(u8),
    SigDecodingTooBigHighBits(u32),
    SigDecodingInvalidRemainder,
    SigDecodingNonZeroUnusedBitsLastByte,
    SigDecodingMinusZero,
    SigDecodingIncorrectEncodingAlgorithm,
    SigDecodingNotSupportedDegree(u8),
    SigGenerationFailed,
}

impl fmt::Display for FalconError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use FalconError::*;
        match self {
            KeyGenerationFailed => write!(f, "Failed to generate a private-public key pair"),
            PubKeyDecodingExtraData => {
                write!(f, "Failed to decode public key: input not fully consumed")
            }
            PubKeyDecodingInvalidCoefficient(val) => {
                write!(f, "Failed to decode public key: coefficient {val} is greater than or equal to the field modulus {MODULUS}")
            }
            PubKeyDecodingInvalidLength(len) => {
                write!(f, "Failed to decode public key: expected {PK_LEN} bytes but received {len}")
            }
            PubKeyDecodingInvalidTag(byte) => {
                write!(f, "Failed to decode public key: expected the first byte to be {LOG_N} but was {byte}")
            }
            SigDecodingTooBigHighBits(m) => {
                write!(f, "Failed to decode signature: high bits {m} exceed 2048")
            }
            SigDecodingInvalidRemainder => {
                write!(f, "Failed to decode signature: incorrect remaining data")
            }
            SigDecodingNonZeroUnusedBitsLastByte => {
                write!(f, "Failed to decode signature: Non-zero unused bits in the last byte")
            }
            SigDecodingMinusZero => write!(f, "Failed to decode signature: -0 is forbidden"),
            SigDecodingIncorrectEncodingAlgorithm => write!(f, "Failed to decode signature: not supported encoding algorithm"),
            SigDecodingNotSupportedDegree(log_n) => write!(f, "Failed to decode signature: only supported irreducible polynomial degree is 512, 2^{log_n} was provided"),
            SigGenerationFailed => write!(f, "Failed to generate a signature"),
        }
    }
}

#[cfg(feature = "std")]
impl std::error::Error for FalconError {}
@@ -1,402 +0,0 @@
|
||||
/*
|
||||
* Wrapper for implementing the PQClean API.
|
||||
*/
|
||||
|
||||
#include <string.h>
|
||||
#include "randombytes.h"
|
||||
#include "falcon.h"
|
||||
#include "inner.h"
|
||||
#include "rpo.h"
|
||||
|
||||
#define NONCELEN 40
|
||||
|
||||
/*
|
||||
* Encoding formats (nnnn = log of degree, 9 for Falcon-512, 10 for Falcon-1024)
|
||||
*
|
||||
* private key:
|
||||
* header byte: 0101nnnn
|
||||
* private f (6 or 5 bits by element, depending on degree)
|
||||
* private g (6 or 5 bits by element, depending on degree)
|
||||
* private F (8 bits by element)
|
||||
*
|
||||
* public key:
|
||||
* header byte: 0000nnnn
|
||||
* public h (14 bits by element)
|
||||
*
|
||||
* signature:
|
||||
* header byte: 0011nnnn
|
||||
* nonce 40 bytes
|
||||
* value (12 bits by element)
|
||||
*
|
||||
* message + signature:
|
||||
* signature length (2 bytes, big-endian)
|
||||
* nonce 40 bytes
|
||||
* message
|
||||
* header byte: 0010nnnn
|
||||
* value (12 bits by element)
|
||||
* (signature length is 1+len(value), not counting the nonce)
|
||||
*/
|
||||
|
||||
/* see falcon.h */
|
||||
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
|
||||
uint8_t *pk,
|
||||
uint8_t *sk,
|
||||
unsigned char *seed
|
||||
) {
|
||||
union
|
||||
{
|
||||
uint8_t b[FALCON_KEYGEN_TEMP_9];
|
||||
uint64_t dummy_u64;
|
||||
fpr dummy_fpr;
|
||||
} tmp;
|
||||
int8_t f[512], g[512], F[512];
|
||||
uint16_t h[512];
|
||||
inner_shake256_context rng;
|
||||
size_t u, v;
|
||||
|
||||
/*
|
||||
* Generate key pair.
|
||||
*/
|
||||
inner_shake256_init(&rng);
|
||||
inner_shake256_inject(&rng, seed, sizeof seed);
|
||||
inner_shake256_flip(&rng);
|
||||
PQCLEAN_FALCON512_CLEAN_keygen(&rng, f, g, F, NULL, h, 9, tmp.b);
|
||||
inner_shake256_ctx_release(&rng);
|
||||
|
||||
/*
|
||||
* Encode private key.
|
||||
*/
|
||||
sk[0] = 0x50 + 9;
|
||||
u = 1;
|
||||
v = PQCLEAN_FALCON512_CLEAN_trim_i8_encode(
|
||||
sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u,
|
||||
f, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9]);
|
||||
if (v == 0)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
u += v;
|
||||
v = PQCLEAN_FALCON512_CLEAN_trim_i8_encode(
|
||||
sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u,
|
||||
g, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9]);
|
||||
if (v == 0)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
u += v;
|
||||
v = PQCLEAN_FALCON512_CLEAN_trim_i8_encode(
|
||||
sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u,
|
||||
F, 9, PQCLEAN_FALCON512_CLEAN_max_FG_bits[9]);
|
||||
if (v == 0)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
u += v;
|
||||
if (u != PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
/*
|
||||
* Encode public key.
|
||||
*/
|
||||
pk[0] = 0x00 + 9;
|
||||
v = PQCLEAN_FALCON512_CLEAN_modq_encode(
|
||||
pk + 1, PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1,
|
||||
h, 9);
|
||||
if (v != PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
|
||||
uint8_t *pk,
|
||||
uint8_t *sk
|
||||
) {
|
||||
unsigned char seed[48];
|
||||
|
||||
/*
|
||||
* Generate a random seed.
|
||||
*/
|
||||
randombytes(seed, sizeof seed);
|
||||
|
||||
return PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(pk, sk, seed);
|
||||
}

/*
 * Compute the signature. nonce[] receives the nonce and must have length
 * NONCELEN bytes. sigbuf[] receives the signature value (without nonce
 * or header byte), with *sigbuflen providing the maximum value length and
 * receiving the actual value length.
 *
 * If a signature could be computed but not encoded because it would
 * exceed the output buffer size, then a new signature is computed. If
 * the provided buffer size is too low, this could loop indefinitely, so
 * the caller must provide a size that can accommodate signatures with a
 * large enough probability.
 *
 * Return value: 0 on success, -1 on error.
 */
static int do_sign(
    uint8_t *nonce,
    uint8_t *sigbuf,
    size_t *sigbuflen,
    const uint8_t *m,
    size_t mlen,
    const uint8_t *sk
) {
    union
    {
        uint8_t b[72 * 512];
        uint64_t dummy_u64;
        fpr dummy_fpr;
    } tmp;
    int8_t f[512], g[512], F[512], G[512];
    struct
    {
        int16_t sig[512];
        uint16_t hm[512];
    } r;
    unsigned char seed[48];
    inner_shake256_context sc;
    rpo128_context rc;
    size_t u, v;

    /*
     * Decode the private key.
     */
    if (sk[0] != 0x50 + 9)
    {
        return -1;
    }
    u = 1;
    v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode(
        f, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9],
        sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u);
    if (v == 0)
    {
        return -1;
    }
    u += v;
    v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode(
        g, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9],
        sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u);
    if (v == 0)
    {
        return -1;
    }
    u += v;
    v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode(
        F, 9, PQCLEAN_FALCON512_CLEAN_max_FG_bits[9],
        sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u);
    if (v == 0)
    {
        return -1;
    }
    u += v;
    if (u != PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES)
    {
        return -1;
    }
    if (!PQCLEAN_FALCON512_CLEAN_complete_private(G, f, g, F, 9, tmp.b))
    {
        return -1;
    }

    /*
     * Create a random nonce (40 bytes).
     */
    randombytes(nonce, NONCELEN);

    /* ==== Start: Deviation from the reference implementation ================================= */

    // Split the nonce into 8 chunks of 5 bytes each, placing each chunk in the low-order
    // bytes of an 8-byte slot. Each slot then encodes a value below 2^40 < p, so the
    // conversion to field elements is guaranteed to succeed.
    uint8_t buffer[64];
    memset(buffer, 0, 64);
    for (size_t i = 0; i < 8; i++)
    {
        buffer[8 * i] = nonce[5 * i];
        buffer[8 * i + 1] = nonce[5 * i + 1];
        buffer[8 * i + 2] = nonce[5 * i + 2];
        buffer[8 * i + 3] = nonce[5 * i + 3];
        buffer[8 * i + 4] = nonce[5 * i + 4];
    }

    /*
     * Hash the nonce and the message into a vector.
     */
    rpo128_init(&rc);
    rpo128_absorb(&rc, buffer, NONCELEN + 24);
    rpo128_absorb(&rc, m, mlen);
    rpo128_finalize(&rc);
    PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(&rc, r.hm, 9);
    rpo128_release(&rc);

    /* ==== End: Deviation from the reference implementation =================================== */

    /*
     * Initialize a RNG.
     */
    randombytes(seed, sizeof seed);
    inner_shake256_init(&sc);
    inner_shake256_inject(&sc, seed, sizeof seed);
    inner_shake256_flip(&sc);

    /*
     * Compute and return the signature. This loops until a signature
     * value is found that fits in the provided buffer.
     */
    for (;;)
    {
        PQCLEAN_FALCON512_CLEAN_sign_dyn(r.sig, &sc, f, g, F, G, r.hm, 9, tmp.b);
        v = PQCLEAN_FALCON512_CLEAN_comp_encode(sigbuf, *sigbuflen, r.sig, 9);
        if (v != 0)
        {
            inner_shake256_ctx_release(&sc);
            *sigbuflen = v;
            return 0;
        }
    }
}

/*
 * Verify a signature. The nonce has size NONCELEN bytes. sigbuf[]
 * (of size sigbuflen) contains the signature value, not including the
 * header byte or nonce. Return value is 0 on success, -1 on error.
 */
static int do_verify(
    const uint8_t *nonce,
    const uint8_t *sigbuf,
    size_t sigbuflen,
    const uint8_t *m,
    size_t mlen,
    const uint8_t *pk
) {
    union
    {
        uint8_t b[2 * 512];
        uint64_t dummy_u64;
        fpr dummy_fpr;
    } tmp;
    uint16_t h[512], hm[512];
    int16_t sig[512];
    rpo128_context rc;

    /*
     * Decode public key.
     */
    if (pk[0] != 0x00 + 9)
    {
        return -1;
    }
    if (PQCLEAN_FALCON512_CLEAN_modq_decode(h, 9,
            pk + 1, PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1)
        != PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1)
    {
        return -1;
    }
    PQCLEAN_FALCON512_CLEAN_to_ntt_monty(h, 9);

    /*
     * Decode signature.
     */
    if (sigbuflen == 0)
    {
        return -1;
    }
    if (PQCLEAN_FALCON512_CLEAN_comp_decode(sig, 9, sigbuf, sigbuflen) != sigbuflen)
    {
        return -1;
    }

    /* ==== Start: Deviation from the reference implementation ================================= */

    /*
     * Hash the nonce and the message into a vector.
     */

    // Split the nonce into 8 chunks of 5 bytes each, placing each chunk in the low-order
    // bytes of an 8-byte slot, so the conversion to field elements is guaranteed to
    // succeed (each slot encodes a value below 2^40 < p).
    uint8_t buffer[64];
    memset(buffer, 0, 64);
    for (size_t i = 0; i < 8; i++)
    {
        buffer[8 * i] = nonce[5 * i];
        buffer[8 * i + 1] = nonce[5 * i + 1];
        buffer[8 * i + 2] = nonce[5 * i + 2];
        buffer[8 * i + 3] = nonce[5 * i + 3];
        buffer[8 * i + 4] = nonce[5 * i + 4];
    }

    rpo128_init(&rc);
    rpo128_absorb(&rc, buffer, NONCELEN + 24);
    rpo128_absorb(&rc, m, mlen);
    rpo128_finalize(&rc);
    PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(&rc, hm, 9);
    rpo128_release(&rc);

    /* ==== End: Deviation from the reference implementation =================================== */

    /*
     * Verify signature.
     */
    if (!PQCLEAN_FALCON512_CLEAN_verify_raw(hm, sig, h, 9, tmp.b))
    {
        return -1;
    }
    return 0;
}

/* see falcon.h */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
    uint8_t *sig,
    size_t *siglen,
    const uint8_t *m,
    size_t mlen,
    const uint8_t *sk
) {
    /*
     * The PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES constant is used for
     * the signed message object (as produced by crypto_sign())
     * and includes a two-byte length value, so we take care here
     * to only generate signatures that are two bytes shorter than
     * the maximum. This is done to ensure that crypto_sign()
     * and crypto_sign_signature() produce the exact same signature
     * value, if used on the same message, with the same private key,
     * and using the same output from randombytes() (this is for
     * reproducibility of tests).
     */
    size_t vlen;

    vlen = PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES - NONCELEN - 3;
    if (do_sign(sig + 1, sig + 1 + NONCELEN, &vlen, m, mlen, sk) < 0)
    {
        return -1;
    }
    sig[0] = 0x30 + 9;
    *siglen = 1 + NONCELEN + vlen;
    return 0;
}

/* see falcon.h */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
    const uint8_t *sig,
    size_t siglen,
    const uint8_t *m,
    size_t mlen,
    const uint8_t *pk
) {
    if (siglen < 1 + NONCELEN)
    {
        return -1;
    }
    if (sig[0] != 0x30 + 9)
    {
        return -1;
    }
    return do_verify(sig + 1, sig + 1 + NONCELEN, siglen - 1 - NONCELEN, m, mlen, pk);
}
@@ -1,66 +0,0 @@
#include <stddef.h>
#include <stdint.h>

#define PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES 1281
#define PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES 897
#define PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES 666

/*
 * Generate a new key pair. Public key goes into pk[], private key in sk[].
 * Key sizes are exact (in bytes):
 *   public (pk): PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES
 *   private (sk): PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES
 *
 * Return value: 0 on success, -1 on error.
 *
 * Note: This implementation follows the reference implementation in PQClean
 * https://github.com/PQClean/PQClean/tree/master/crypto_sign/falcon-512
 * verbatim except for the sections that are marked otherwise.
 */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
    uint8_t *pk, uint8_t *sk);

/*
 * Generate a new key pair from seed. Public key goes into pk[], private key in sk[].
 * Key sizes are exact (in bytes):
 *   public (pk): PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES
 *   private (sk): PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES
 *
 * Return value: 0 on success, -1 on error.
 */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
    uint8_t *pk, uint8_t *sk, unsigned char *seed);

/*
 * Compute a signature on a provided message (m, mlen), with a given
 * private key (sk). Signature is written in sig[], with length written
 * into *siglen. Signature length is variable; maximum signature length
 * (in bytes) is PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES.
 *
 * sig[], m[] and sk[] may overlap each other arbitrarily.
 *
 * Return value: 0 on success, -1 on error.
 *
 * Note: This implementation follows the reference implementation in PQClean
 * https://github.com/PQClean/PQClean/tree/master/crypto_sign/falcon-512
 * verbatim except for the sections that are marked otherwise.
 */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
    uint8_t *sig, size_t *siglen,
    const uint8_t *m, size_t mlen, const uint8_t *sk);

/*
 * Verify a signature (sig, siglen) on a message (m, mlen) with a given
 * public key (pk).
 *
 * sig[], m[] and pk[] may overlap each other arbitrarily.
 *
 * Return value: 0 on success, -1 on error.
 *
 * Note: This implementation follows the reference implementation in PQClean
 * https://github.com/PQClean/PQClean/tree/master/crypto_sign/falcon-512
 * verbatim except for the sections that are marked otherwise.
 */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
    const uint8_t *sig, size_t siglen,
    const uint8_t *m, size_t mlen, const uint8_t *pk);
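
/*
 * Editorial note: a minimal, illustrative driver tying the declarations above
 * together (not part of the original header). The 32-byte message below is an
 * assumption for illustration; the Rust wrapper always passes a message of
 * 4 field elements (32 bytes).
 */
static int example_keygen_sign_verify(void)
{
    uint8_t pk[PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES];
    uint8_t sk[PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES];
    uint8_t sig[PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES];
    size_t siglen = 0;
    const uint8_t msg[32] = {0};  /* 4 field elements, 8 bytes each */

    if (PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(pk, sk) != 0)
    {
        return -1;
    }
    if (PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
            sig, &siglen, msg, sizeof msg, sk) != 0)
    {
        return -1;
    }
    /* returns 0 on success, -1 on error */
    return PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
        sig, siglen, msg, sizeof msg, pk);
}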

@@ -1,582 +0,0 @@
/*
 * RPO implementation.
 */

#include <stdint.h>
#include <string.h>
#include <stdlib.h>

/* ================================================================================================
 * Modular Arithmetic
 */

#define P 0xFFFFFFFF00000001
#define M 12289

// From https://github.com/ncw/iprime/blob/master/mod_math_noasm.go
static uint64_t add_mod_p(uint64_t a, uint64_t b)
{
    a = P - a;
    uint64_t res = b - a;
    if (b < a)
        res += P;
    return res;
}

static uint64_t sub_mod_p(uint64_t a, uint64_t b)
{
    uint64_t r = a - b;
    if (a < b)
        r += P;
    return r;
}

static uint64_t reduce_mod_p(uint64_t b, uint64_t a)
{
    uint32_t d = b >> 32,
             c = b;
    if (a >= P)
        a -= P;
    a = sub_mod_p(a, c);
    a = sub_mod_p(a, d);
    a = add_mod_p(a, ((uint64_t)c) << 32);
    return a;
}

static uint64_t mult_mod_p(uint64_t x, uint64_t y)
{
    uint32_t a = x,
             b = x >> 32,
             c = y,
             d = y >> 32;

    /* first synthesize the product using 32*32 -> 64 bit multiplies */
    x = b * (uint64_t)c; /* b*c */
    y = a * (uint64_t)d; /* a*d */
    uint64_t e = a * (uint64_t)c, /* a*c */
             f = b * (uint64_t)d, /* b*d */
             t;

    x += y; /* b*c + a*d */
    /* carry? */
    if (x < y)
        f += 1LL << 32; /* carry into upper 32 bits - can't overflow */

    t = x << 32;
    e += t; /* a*c + LSW(b*c + a*d) */
    /* carry? */
    if (e < t)
        f += 1; /* carry into upper 64 bits - can't overflow */
    t = x >> 32;
    f += t; /* b*d + MSW(b*c + a*d) */
    /* can't overflow */

    /* now reduce: (b*d + MSW(b*c + a*d), a*c + LSW(b*c + a*d)) */
    return reduce_mod_p(f, e);
}
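
/*
 * Editorial note (illustrative, not in the original file): since
 * P = 2^64 - 2^32 + 1, we have 2^64 = 2^32 - 1 (mod P), which is exactly the
 * identity reduce_mod_p exploits to fold a 128-bit product back into 64 bits.
 * A hypothetical sanity check along these lines:
 */
#include <assert.h>
static void check_goldilocks_reduction(void)
{
    /* 2^32 * 2^32 = 2^64 = 2^32 - 1 (mod P) */
    assert(mult_mod_p(1ULL << 32, 1ULL << 32) == 0xFFFFFFFFULL);
    /* addition wraps around the modulus */
    assert(add_mod_p(P - 1, 1) == 0);
}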

/* ================================================================================================
 * RPO128 Permutation
 */

#define STATE_WIDTH 12
#define NUM_ROUNDS 7

/*
 * MDS matrix
 */
static const uint64_t MDS[12][12] = {
    {  7, 23,  8, 26, 13, 10,  9,  7,  6, 22, 21,  8 },
    {  8,  7, 23,  8, 26, 13, 10,  9,  7,  6, 22, 21 },
    { 21,  8,  7, 23,  8, 26, 13, 10,  9,  7,  6, 22 },
    { 22, 21,  8,  7, 23,  8, 26, 13, 10,  9,  7,  6 },
    {  6, 22, 21,  8,  7, 23,  8, 26, 13, 10,  9,  7 },
    {  7,  6, 22, 21,  8,  7, 23,  8, 26, 13, 10,  9 },
    {  9,  7,  6, 22, 21,  8,  7, 23,  8, 26, 13, 10 },
    { 10,  9,  7,  6, 22, 21,  8,  7, 23,  8, 26, 13 },
    { 13, 10,  9,  7,  6, 22, 21,  8,  7, 23,  8, 26 },
    { 26, 13, 10,  9,  7,  6, 22, 21,  8,  7, 23,  8 },
    {  8, 26, 13, 10,  9,  7,  6, 22, 21,  8,  7, 23 },
    { 23,  8, 26, 13, 10,  9,  7,  6, 22, 21,  8,  7 },
};
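
/*
 * Editorial note: the matrix above is circulant -- every row is the previous
 * row rotated right by one position. A hypothetical helper that rebuilds it
 * from the first row (shown for clarity only, not in the original sources):
 */
static void build_mds_from_first_row(uint64_t out[12][12])
{
    static const uint64_t row0[12] = { 7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8 };
    for (int i = 0; i < 12; i++)
    {
        for (int j = 0; j < 12; j++)
        {
            /* row i is row 0 rotated right by i positions */
            out[i][j] = row0[(j - i + 12) % 12];
        }
    }
}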

/*
 * Round constants.
 */
static const uint64_t ARK1[7][12] = {
    {
        5789762306288267392ULL, 6522564764413701783ULL, 17809893479458208203ULL,
        107145243989736508ULL, 6388978042437517382ULL, 15844067734406016715ULL,
        9975000513555218239ULL, 3344984123768313364ULL, 9959189626657347191ULL,
        12960773468763563665ULL, 9602914297752488475ULL, 16657542370200465908ULL,
    },
    {
        12987190162843096997ULL, 653957632802705281ULL, 4441654670647621225ULL,
        4038207883745915761ULL, 5613464648874830118ULL, 13222989726778338773ULL,
        3037761201230264149ULL, 16683759727265180203ULL, 8337364536491240715ULL,
        3227397518293416448ULL, 8110510111539674682ULL, 2872078294163232137ULL,
    },
    {
        18072785500942327487ULL, 6200974112677013481ULL, 17682092219085884187ULL,
        10599526828986756440ULL, 975003873302957338ULL, 8264241093196931281ULL,
        10065763900435475170ULL, 2181131744534710197ULL, 6317303992309418647ULL,
        1401440938888741532ULL, 8884468225181997494ULL, 13066900325715521532ULL,
    },
    {
        5674685213610121970ULL, 5759084860419474071ULL, 13943282657648897737ULL,
        1352748651966375394ULL, 17110913224029905221ULL, 1003883795902368422ULL,
        4141870621881018291ULL, 8121410972417424656ULL, 14300518605864919529ULL,
        13712227150607670181ULL, 17021852944633065291ULL, 6252096473787587650ULL,
    },
    {
        4887609836208846458ULL, 3027115137917284492ULL, 9595098600469470675ULL,
        10528569829048484079ULL, 7864689113198939815ULL, 17533723827845969040ULL,
        5781638039037710951ULL, 17024078752430719006ULL, 109659393484013511ULL,
        7158933660534805869ULL, 2955076958026921730ULL, 7433723648458773977ULL,
    },
    {
        16308865189192447297ULL, 11977192855656444890ULL, 12532242556065780287ULL,
        14594890931430968898ULL, 7291784239689209784ULL, 5514718540551361949ULL,
        10025733853830934803ULL, 7293794580341021693ULL, 6728552937464861756ULL,
        6332385040983343262ULL, 13277683694236792804ULL, 2600778905124452676ULL,
    },
    {
        7123075680859040534ULL, 1034205548717903090ULL, 7717824418247931797ULL,
        3019070937878604058ULL, 11403792746066867460ULL, 10280580802233112374ULL,
        337153209462421218ULL, 13333398568519923717ULL, 3596153696935337464ULL,
        8104208463525993784ULL, 14345062289456085693ULL, 17036731477169661256ULL,
    },
};

const uint64_t ARK2[7][12] = {
    {
        6077062762357204287ULL, 15277620170502011191ULL, 5358738125714196705ULL,
        14233283787297595718ULL, 13792579614346651365ULL, 11614812331536767105ULL,
        14871063686742261166ULL, 10148237148793043499ULL, 4457428952329675767ULL,
        15590786458219172475ULL, 10063319113072092615ULL, 14200078843431360086ULL,
    },
    {
        6202948458916099932ULL, 17690140365333231091ULL, 3595001575307484651ULL,
        373995945117666487ULL, 1235734395091296013ULL, 14172757457833931602ULL,
        707573103686350224ULL, 15453217512188187135ULL, 219777875004506018ULL,
        17876696346199469008ULL, 17731621626449383378ULL, 2897136237748376248ULL,
    },
    {
        8023374565629191455ULL, 15013690343205953430ULL, 4485500052507912973ULL,
        12489737547229155153ULL, 9500452585969030576ULL, 2054001340201038870ULL,
        12420704059284934186ULL, 355990932618543755ULL, 9071225051243523860ULL,
        12766199826003448536ULL, 9045979173463556963ULL, 12934431667190679898ULL,
    },
    {
        18389244934624494276ULL, 16731736864863925227ULL, 4440209734760478192ULL,
        17208448209698888938ULL, 8739495587021565984ULL, 17000774922218161967ULL,
        13533282547195532087ULL, 525402848358706231ULL, 16987541523062161972ULL,
        5466806524462797102ULL, 14512769585918244983ULL, 10973956031244051118ULL,
    },
    {
        6982293561042362913ULL, 14065426295947720331ULL, 16451845770444974180ULL,
        7139138592091306727ULL, 9012006439959783127ULL, 14619614108529063361ULL,
        1394813199588124371ULL, 4635111139507788575ULL, 16217473952264203365ULL,
        10782018226466330683ULL, 6844229992533662050ULL, 7446486531695178711ULL,
    },
    {
        3736792340494631448ULL, 577852220195055341ULL, 6689998335515779805ULL,
        13886063479078013492ULL, 14358505101923202168ULL, 7744142531772274164ULL,
        16135070735728404443ULL, 12290902521256031137ULL, 12059913662657709804ULL,
        16456018495793751911ULL, 4571485474751953524ULL, 17200392109565783176ULL,
    },
    {
        17130398059294018733ULL, 519782857322261988ULL, 9625384390925085478ULL,
        1664893052631119222ULL, 7629576092524553570ULL, 3485239601103661425ULL,
        9755891797164033838ULL, 15218148195153269027ULL, 16460604813734957368ULL,
        9643968136937729763ULL, 3611348709641382851ULL, 18256379591337759196ULL,
    },
};

static void apply_sbox(uint64_t *const state)
{
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        uint64_t t2 = mult_mod_p(*(state + i), *(state + i));
        uint64_t t4 = mult_mod_p(t2, t2);

        *(state + i) = mult_mod_p(*(state + i), mult_mod_p(t2, t4));
    }
}

static void apply_mds(uint64_t *state)
{
    uint64_t res[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        res[i] = 0;
    }
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        for (uint64_t j = 0; j < STATE_WIDTH; j++)
        {
            res[i] = add_mod_p(res[i], mult_mod_p(MDS[i][j], *(state + j)));
        }
    }

    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        *(state + i) = res[i];
    }
}

static void apply_constants(uint64_t *const state, const uint64_t *ark)
{
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        *(state + i) = add_mod_p(*(state + i), *(ark + i));
    }
}

static void exp_acc(const uint64_t m, const uint64_t *base, const uint64_t *tail, uint64_t *const res)
{
    for (uint64_t i = 0; i < m; i++)
    {
        for (uint64_t j = 0; j < STATE_WIDTH; j++)
        {
            if (i == 0)
            {
                *(res + j) = mult_mod_p(*(base + j), *(base + j));
            }
            else
            {
                *(res + j) = mult_mod_p(*(res + j), *(res + j));
            }
        }
    }

    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        *(res + i) = mult_mod_p(*(res + i), *(tail + i));
    }
}

static void apply_inv_sbox(uint64_t *const state)
{
    uint64_t t1[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t1[i] = 0;
    }
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t1[i] = mult_mod_p(*(state + i), *(state + i));
    }

    uint64_t t2[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t2[i] = 0;
    }
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t2[i] = mult_mod_p(t1[i], t1[i]);
    }

    uint64_t t3[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t3[i] = 0;
    }
    exp_acc(3, t2, t2, t3);

    uint64_t t4[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t4[i] = 0;
    }
    exp_acc(6, t3, t3, t4);

    uint64_t tmp[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        tmp[i] = 0;
    }
    exp_acc(12, t4, t4, tmp);

    uint64_t t5[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t5[i] = 0;
    }
    exp_acc(6, tmp, t3, t5);

    uint64_t t6[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t6[i] = 0;
    }
    exp_acc(31, t5, t5, t6);

    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        uint64_t a = mult_mod_p(mult_mod_p(t6[i], t6[i]), t5[i]);
        a = mult_mod_p(a, a);
        a = mult_mod_p(a, a);
        uint64_t b = mult_mod_p(mult_mod_p(t1[i], t2[i]), *(state + i));

        *(state + i) = mult_mod_p(a, b);
    }
}
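
/*
 * Editorial note (illustrative, not in the original file): the chain of
 * squarings and multiplications above computes x^(1/7) in the field. Since
 * gcd(7, p - 1) = 1, the map x -> x^7 is a bijection and its inverse is
 * exponentiation by k = 7^(-1) mod (p - 1) = 10540996611094048183; one can
 * check that 7 * k = 4 * (p - 1) + 1.
 */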

static void apply_round(uint64_t *const state, const uint64_t round)
{
    apply_mds(state);
    apply_constants(state, ARK1[round]);
    apply_sbox(state);

    apply_mds(state);
    apply_constants(state, ARK2[round]);
    apply_inv_sbox(state);
}

static void apply_permutation(uint64_t *state)
{
    for (uint64_t i = 0; i < NUM_ROUNDS; i++)
    {
        apply_round(state, i);
    }
}

/* ================================================================================================
 * RPO128 implementation. This substitutes for SHAKE256 in the hash-to-point algorithm.
 */

#include "rpo.h"

void rpo128_init(rpo128_context *rc)
{
    rc->dptr = 32;

    memset(rc->st.A, 0, sizeof rc->st.A);
}

void rpo128_absorb(rpo128_context *rc, const uint8_t *in, size_t len)
{
    size_t dptr;

    dptr = (size_t)rc->dptr;
    while (len > 0)
    {
        size_t clen, u;

        /* SHAKE256 has a rate portion of 136 * 8 = 1088 bits.
         * For RPO128, the rate is 64 * 8 = 512 bits.
         * The capacity of SHAKE256 sits at the end of the state, while for
         * RPO128 it sits at the beginning (bytes 0..31 of dbuf).
         */
        clen = 96 - dptr;
        if (clen > len)
        {
            clen = len;
        }

        for (u = 0; u < clen; u++)
        {
            rc->st.dbuf[dptr + u] = in[u];
        }

        dptr += clen;
        in += clen;
        len -= clen;
        if (dptr == 96)
        {
            apply_permutation(rc->st.A);
            dptr = 32;
        }
    }
    rc->dptr = dptr;
}

void rpo128_finalize(rpo128_context *rc)
{
    // Set dptr to the end of the buffer, so that the first call to squeeze
    // applies the permutation.
    rc->dptr = 96;
}

void rpo128_squeeze(rpo128_context *rc, uint8_t *out, size_t len)
{
    size_t dptr;

    dptr = (size_t)rc->dptr;
    while (len > 0)
    {
        size_t clen;

        if (dptr == 96)
        {
            apply_permutation(rc->st.A);
            dptr = 32;
        }
        clen = 96 - dptr;
        if (clen > len)
        {
            clen = len;
        }
        len -= clen;

        memcpy(out, rc->st.dbuf + dptr, clen);
        dptr += clen;
        out += clen;
    }
    rc->dptr = dptr;
}

void rpo128_release(rpo128_context *rc)
{
    memset(rc->st.A, 0, sizeof rc->st.A);
    rc->dptr = 32;
}

/* ================================================================================================
 * Hash-to-Point algorithm implementation based on RPO128
 */

void PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(rpo128_context *rc, uint16_t *x, unsigned logn)
{
    /*
     * This implementation avoids the rejection-sampling step needed in the
     * per-the-spec implementation. It relies on a remark in
     * https://falcon-sign.info/falcon.pdf (page 31), which argues that this
     * variant is secure for the parameter sets selected by NIST. Avoiding the
     * rejection-sampling step leads to an implementation that is constant-time.
     * TODO: Check that the current implementation is indeed constant-time.
     */
    size_t n;

    n = (size_t)1 << logn;
    while (n > 0)
    {
        uint8_t buf[8];
        uint64_t w;

        rpo128_squeeze(rc, (void *)buf, sizeof buf);
        w = ((uint64_t)(buf[7]) << 56) |
            ((uint64_t)(buf[6]) << 48) |
            ((uint64_t)(buf[5]) << 40) |
            ((uint64_t)(buf[4]) << 32) |
            ((uint64_t)(buf[3]) << 24) |
            ((uint64_t)(buf[2]) << 16) |
            ((uint64_t)(buf[1]) << 8) |
            ((uint64_t)(buf[0]));

        w %= M;

        *x++ = (uint16_t)w;
        n--;
    }
}
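
/*
 * Editorial note (illustrative, not in the original file): reducing a uniform
 * 64-bit value w modulo M leaves only a negligible bias. Each residue class
 * occurs either floor(2^64 / M) or ceil(2^64 / M) times among the 2^64
 * possible values of w, so the statistical distance from uniform is at most
 * M / 2^64 = 12289 / 2^64 ~ 2^-50 per coefficient, consistent with the remark
 * in the Falcon specification cited above.
 */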

@@ -1,83 +0,0 @@
#include <stdint.h>
#include <string.h>

/* ================================================================================================
 * RPO hashing algorithm related structs and methods.
 */

/*
 * RPO128 context.
 *
 * This structure is used by the hashing API. It is composed of an internal state that can be
 * viewed as either:
 *   1. 12 field elements in the Miden VM.
 *   2. 96 bytes.
 *
 * The first view is used for the internal state in the context of the RPO hashing algorithm. The
 * second view is used for the buffer used to absorb the data to be hashed.
 *
 * The pointer into the buffer is updated as the data is absorbed.
 *
 * 'rpo128_context' must be initialized with rpo128_init() before first use.
 */
typedef struct
{
    union
    {
        uint64_t A[12];
        uint8_t dbuf[96];
    } st;
    uint64_t dptr;
} rpo128_context;

/*
 * Initializes an RPO state.
 */
void rpo128_init(rpo128_context *rc);

/*
 * Absorbs an array of bytes of length 'len' into the state.
 */
void rpo128_absorb(rpo128_context *rc, const uint8_t *in, size_t len);

/*
 * Squeezes an array of bytes of length 'len' from the state.
 */
void rpo128_squeeze(rpo128_context *rc, uint8_t *out, size_t len);

/*
 * Finalizes the state in preparation for squeezing.
 *
 * This function should be called after all the data has been absorbed.
 *
 * Note that the current implementation does not perform any sort of padding for domain separation
 * purposes. The reason is that, for our purposes, we always perform the following sequence:
 *   1. Absorb a nonce (which is always 40 bytes packed as 8 field elements).
 *   2. Absorb the message (which is always 4 field elements).
 *   3. Call finalize.
 *   4. Squeeze the output.
 *   5. Call release.
 */
void rpo128_finalize(rpo128_context *rc);

/*
 * Releases the state.
 *
 * This function should be called after the squeeze operation is finished.
 */
void rpo128_release(rpo128_context *rc);

/* ================================================================================================
 * Hash-to-Point algorithm for signature generation and signature verification.
 */

/*
 * Hash-to-Point algorithm.
 *
 * This function generates a point in Z_q[x]/(phi) from a given message.
 *
 * It takes a finalized rpo128_context as input and generates the coefficients of the polynomial
 * representing the point. The coefficients are stored in the array 'x'. The number of
 * coefficients is 2^logn, which in our case must be 512 (i.e., logn = 9).
 */
void PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(rpo128_context *rc, uint16_t *x, unsigned logn);
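
/*
 * Editorial note: an illustrative call sequence (not part of the original
 * header), following the absorb -> finalize -> hash-to-point -> release
 * protocol documented above.
 */
static void example_hash_to_point(const uint8_t packed_nonce[64],
                                  const uint8_t *msg, size_t msg_len,
                                  uint16_t coeffs[512])
{
    rpo128_context rc;
    rpo128_init(&rc);
    rpo128_absorb(&rc, packed_nonce, 64); /* nonce packed as 8 field elements */
    rpo128_absorb(&rc, msg, msg_len);     /* message (4 field elements) */
    rpo128_finalize(&rc);
    PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(&rc, coeffs, 9); /* logn = 9 -> 512 coefficients */
    rpo128_release(&rc);
}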

@@ -1,189 +0,0 @@
use libc::c_int;

// C IMPLEMENTATION INTERFACE
// ================================================================================================

#[link(name = "rpo_falcon512", kind = "static")]
extern "C" {
    /// Generate a new key pair. Public key goes into pk[], private key in sk[].
    /// Key sizes are exact (in bytes):
    /// - public (pk): 897
    /// - private (sk): 1281
    ///
    /// Return value: 0 on success, -1 on error.
    pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(pk: *mut u8, sk: *mut u8) -> c_int;

    /// Generate a new key pair from seed. Public key goes into pk[], private key in sk[].
    /// Key sizes are exact (in bytes):
    /// - public (pk): 897
    /// - private (sk): 1281
    ///
    /// Return value: 0 on success, -1 on error.
    pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
        pk: *mut u8,
        sk: *mut u8,
        seed: *const u8,
    ) -> c_int;

    /// Compute a signature on a provided message (m, mlen), with a given private key (sk).
    /// Signature is written in sig[], with length written into *siglen. Signature length is
    /// variable; maximum signature length (in bytes) is 666.
    ///
    /// sig[], m[] and sk[] may overlap each other arbitrarily.
    ///
    /// Return value: 0 on success, -1 on error.
    pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
        sig: *mut u8,
        siglen: *mut usize,
        m: *const u8,
        mlen: usize,
        sk: *const u8,
    ) -> c_int;

    // TEST HELPERS
    // --------------------------------------------------------------------------------------------

    /// Verify a signature (sig, siglen) on a message (m, mlen) with a given public key (pk).
    ///
    /// sig[], m[] and pk[] may overlap each other arbitrarily.
    ///
    /// Return value: 0 on success, -1 on error.
    #[cfg(test)]
    pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
        sig: *const u8,
        siglen: usize,
        m: *const u8,
        mlen: usize,
        pk: *const u8,
    ) -> c_int;

    /// Hash-to-Point algorithm.
    ///
    /// This function generates a point in Z_q[x]/(phi) from a given message.
    ///
    /// It takes a finalized rpo128_context as input and generates the coefficients of the
    /// polynomial representing the point. The coefficients are stored in the array 'x'. The
    /// number of coefficients is 2^logn, which in our case must be 512 (i.e., logn = 9).
    #[cfg(test)]
    pub fn PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(
        rc: *mut Rpo128Context,
        x: *mut u16,
        logn: usize,
    );

    #[cfg(test)]
    pub fn rpo128_init(sc: *mut Rpo128Context);

    #[cfg(test)]
    pub fn rpo128_absorb(
        sc: *mut Rpo128Context,
        data: *const ::std::os::raw::c_void,
        len: libc::size_t,
    );

    #[cfg(test)]
    pub fn rpo128_finalize(sc: *mut Rpo128Context);
}

#[repr(C)]
#[cfg(test)]
pub struct Rpo128Context {
    pub content: [u64; 13usize],
}

// TESTS
// ================================================================================================

#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use crate::dsa::rpo_falcon512::{NONCE_LEN, PK_LEN, SIG_LEN, SK_LEN};
    use rand_utils::{rand_array, rand_value, rand_vector};

    #[test]
    fn falcon_ffi() {
        unsafe {
            // --- generate a key pair from a seed ----------------------------

            let mut pk = [0u8; PK_LEN];
            let mut sk = [0u8; SK_LEN];
            let seed: [u8; NONCE_LEN] = rand_array();

            assert_eq!(
                0,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
                    pk.as_mut_ptr(),
                    sk.as_mut_ptr(),
                    seed.as_ptr()
                )
            );

            // --- sign a message and make sure it verifies -------------------

            let mlen: usize = rand_value::<u16>() as usize;
            let msg: Vec<u8> = rand_vector(mlen);
            let mut detached_sig = [0u8; NONCE_LEN + SIG_LEN];
            let mut siglen = 0;

            assert_eq!(
                0,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
                    detached_sig.as_mut_ptr(),
                    &mut siglen as *mut usize,
                    msg.as_ptr(),
                    msg.len(),
                    sk.as_ptr()
                )
            );

            assert_eq!(
                0,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
                    detached_sig.as_ptr(),
                    siglen,
                    msg.as_ptr(),
                    msg.len(),
                    pk.as_ptr()
                )
            );

            // --- check verification of different signature ------------------

            assert_eq!(
                -1,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
                    detached_sig.as_ptr(),
                    siglen,
                    msg.as_ptr(),
                    msg.len() - 1,
                    pk.as_ptr()
                )
            );

            // --- check verification against a different pub key -------------

            let mut pk_alt = [0u8; PK_LEN];
            let mut sk_alt = [0u8; SK_LEN];
            assert_eq!(
                0,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
                    pk_alt.as_mut_ptr(),
                    sk_alt.as_mut_ptr()
                )
            );

            assert_eq!(
                -1,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
                    detached_sig.as_ptr(),
                    siglen,
                    msg.as_ptr(),
                    msg.len(),
                    pk_alt.as_ptr()
                )
            );
        }
    }
}
68 src/dsa/rpo_falcon512/hash_to_point.rs Normal file
@@ -0,0 +1,68 @@
use super::{math::FalconFelt, Nonce, Polynomial, Rpo256, Word, MODULUS, N, ZERO};
use alloc::vec::Vec;
use num::Zero;

// HASH-TO-POINT FUNCTIONS
// ================================================================================================

/// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message and
/// nonce using RPO256.
pub fn hash_to_point_rpo256(message: Word, nonce: &Nonce) -> Polynomial<FalconFelt> {
    let mut state = [ZERO; Rpo256::STATE_WIDTH];

    // absorb the nonce into the state
    let nonce_elements = nonce.to_elements();
    for (&n, s) in nonce_elements.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) {
        *s = n;
    }
    Rpo256::apply_permutation(&mut state);

    // absorb message into the state
    for (&m, s) in message.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) {
        *s = m;
    }

    // squeeze the coefficients of the polynomial
    let mut i = 0;
    let mut res = [FalconFelt::zero(); N];
    for _ in 0..64 {
        Rpo256::apply_permutation(&mut state);
        for a in &state[Rpo256::RATE_RANGE] {
            res[i] = FalconFelt::new((a.as_int() % MODULUS as u64) as i16);
            i += 1;
        }
    }

    Polynomial::new(res.to_vec())
}

/// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message and
/// nonce using SHAKE256. This is the hash-to-point algorithm used in the reference implementation.
#[allow(dead_code)]
pub fn hash_to_point_shake256(message: &[u8], nonce: &Nonce) -> Polynomial<FalconFelt> {
    use sha3::{
        digest::{ExtendableOutput, Update, XofReader},
        Shake256,
    };

    let mut data = vec![];
    data.extend_from_slice(nonce.as_bytes());
    data.extend_from_slice(message);
    const K: u32 = (1u32 << 16) / MODULUS as u32;

    let mut hasher = Shake256::default();
    hasher.update(&data);
    let mut reader = hasher.finalize_xof();

    let mut coefficients: Vec<FalconFelt> = Vec::with_capacity(N);
    while coefficients.len() != N {
        let mut randomness = [0u8; 2];
        reader.read(&mut randomness);
        let t = ((randomness[0] as u32) << 8) | (randomness[1] as u32);
        if t < K * MODULUS as u32 {
            coefficients.push(FalconFelt::new((t % MODULUS as u32) as i16));
        }
    }

    Polynomial { coefficients }
}
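
// Editorial note: the acceptance bound `t < K * MODULUS` with K = floor(2^16 / q) = 5
// is what makes the sampled residues exactly uniform -- every residue class occurs
// exactly K times among the accepted 16-bit values. An illustrative self-check
// (not part of the original module):
#[cfg(test)]
#[test]
fn rejection_sampling_is_unbiased() {
    const Q: u32 = 12289;
    const K: u32 = (1u32 << 16) / Q; // = 5
    let mut counts = vec![0u32; Q as usize];
    for t in 0..(1u32 << 16) {
        if t < K * Q {
            counts[(t % Q) as usize] += 1;
        }
    }
    // every residue class is hit exactly K times, so `t % Q` is unbiased
    assert!(counts.iter().all(|&c| c == K));
}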

@@ -1,227 +0,0 @@
use super::{
    ByteReader, ByteWriter, Deserializable, DeserializationError, FalconError, Polynomial,
    PublicKeyBytes, Rpo256, SecretKeyBytes, Serializable, Signature, Word,
};

#[cfg(feature = "std")]
use super::{ffi, NonceBytes, StarkField, NONCE_LEN, PK_LEN, SIG_LEN, SK_LEN};

// PUBLIC KEY
// ================================================================================================

/// A public key for verifying signatures.
///
/// The public key is a [Word] (i.e., 4 field elements) that is the hash of the coefficients of
/// the polynomial representing the raw bytes of the expanded public key.
///
/// For Falcon-512, the first byte of the expanded public key is always equal to log2(512), i.e., 9.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PublicKey(Word);

impl PublicKey {
    /// Returns a new [PublicKey] which is a commitment to the provided expanded public key.
    ///
    /// # Errors
    /// Returns an error if the decoding of the public key fails.
    pub fn new(pk: PublicKeyBytes) -> Result<Self, FalconError> {
        let h = Polynomial::from_pub_key(&pk)?;
        let pk_felts = h.to_elements();
        let pk_digest = Rpo256::hash_elements(&pk_felts).into();
        Ok(Self(pk_digest))
    }

    /// Verifies the provided signature against provided message and this public key.
    pub fn verify(&self, message: Word, signature: &Signature) -> bool {
        signature.verify(message, self.0)
    }
}

impl From<PublicKey> for Word {
    fn from(key: PublicKey) -> Self {
        key.0
    }
}

// KEY PAIR
// ================================================================================================

/// A key pair (public and secret keys) for signing messages.
///
/// The public key is a byte array of length [PK_LEN].
/// The secret key is a byte array of length [SK_LEN].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct KeyPair {
    public_key: PublicKeyBytes,
    secret_key: SecretKeyBytes,
}

#[allow(clippy::new_without_default)]
impl KeyPair {
    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Generates a (public_key, secret_key) key pair from OS-provided randomness.
    ///
    /// # Errors
    /// Returns an error if key generation fails.
    #[cfg(feature = "std")]
    pub fn new() -> Result<Self, FalconError> {
        let mut public_key = [0u8; PK_LEN];
        let mut secret_key = [0u8; SK_LEN];

        let res = unsafe {
            ffi::PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
                public_key.as_mut_ptr(),
                secret_key.as_mut_ptr(),
            )
        };

        if res == 0 {
            Ok(Self { public_key, secret_key })
        } else {
            Err(FalconError::KeyGenerationFailed)
        }
    }

    /// Generates a (public_key, secret_key) key pair from the provided seed.
    ///
    /// # Errors
    /// Returns an error if key generation fails.
    #[cfg(feature = "std")]
    pub fn from_seed(seed: &NonceBytes) -> Result<Self, FalconError> {
        let mut public_key = [0u8; PK_LEN];
        let mut secret_key = [0u8; SK_LEN];

        let res = unsafe {
            ffi::PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
                public_key.as_mut_ptr(),
                secret_key.as_mut_ptr(),
                seed.as_ptr(),
            )
        };

        if res == 0 {
            Ok(Self { public_key, secret_key })
        } else {
            Err(FalconError::KeyGenerationFailed)
        }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the public key corresponding to this key pair.
    pub fn public_key(&self) -> PublicKey {
        // TODO: memoize public key commitment as computing it requires quite a bit of hashing.
        // expect() is fine here because we assume that the key pair was constructed correctly.
        PublicKey::new(self.public_key).expect("invalid key pair")
    }

    /// Returns the expanded public key corresponding to this key pair.
    pub fn expanded_public_key(&self) -> PublicKeyBytes {
        self.public_key
    }

    // SIGNATURE GENERATION
    // --------------------------------------------------------------------------------------------

    /// Signs a message with a secret key and a seed.
    ///
    /// # Errors
    /// Returns an error if signature generation fails.
    #[cfg(feature = "std")]
    pub fn sign(&self, message: Word) -> Result<Signature, FalconError> {
        let msg = message.iter().flat_map(|e| e.as_int().to_le_bytes()).collect::<Vec<_>>();
        let msg_len = msg.len();
        let mut sig = [0_u8; SIG_LEN + NONCE_LEN];
        let mut sig_len: usize = 0;

        let res = unsafe {
            ffi::PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
                sig.as_mut_ptr(),
                &mut sig_len as *mut usize,
                msg.as_ptr(),
                msg_len,
                self.secret_key.as_ptr(),
            )
        };

        if res == 0 {
            Ok(Signature { sig, pk: self.public_key })
        } else {
            Err(FalconError::SigGenerationFailed)
        }
    }
}

// SERIALIZATION / DESERIALIZATION
// ================================================================================================

impl Serializable for KeyPair {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_bytes(&self.public_key);
        target.write_bytes(&self.secret_key);
    }
}

impl Deserializable for KeyPair {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let public_key: PublicKeyBytes = source.read_array()?;
        let secret_key: SecretKeyBytes = source.read_array()?;
        Ok(Self { public_key, secret_key })
    }
}

// TESTS
// ================================================================================================

#[cfg(all(test, feature = "std"))]
mod tests {
    use super::{super::Felt, KeyPair, NonceBytes, Word};
    use rand_utils::{rand_array, rand_vector};

    #[test]
    fn test_falcon_verification() {
        // generate random keys
        let keys = KeyPair::new().unwrap();
        let pk = keys.public_key();

        // sign a random message
        let message: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
        let signature = keys.sign(message);

        // make sure the signature verifies correctly
        assert!(pk.verify(message, signature.as_ref().unwrap()));

        // a signature should not verify against a wrong message
        let message2: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
        assert!(!pk.verify(message2, signature.as_ref().unwrap()));

        // a signature should not verify against a wrong public key
        let keys2 = KeyPair::new().unwrap();
        assert!(!keys2.public_key().verify(message, signature.as_ref().unwrap()))
    }

    #[test]
    fn test_falcon_verification_from_seed() {
        // generate keys from a random seed
        let seed: NonceBytes = rand_array();
        let keys = KeyPair::from_seed(&seed).unwrap();
        let pk = keys.public_key();

        // sign a random message
        let message: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
        let signature = keys.sign(message);

        // make sure the signature verifies correctly
        assert!(pk.verify(message, signature.as_ref().unwrap()));

        // a signature should not verify against a wrong message
        let message2: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
        assert!(!pk.verify(message2, signature.as_ref().unwrap()));

        // a signature should not verify against a wrong public key
        let keys2 = KeyPair::new().unwrap();
        assert!(!keys2.public_key().verify(message, signature.as_ref().unwrap()))
    }
}
53 src/dsa/rpo_falcon512/keys/mod.rs Normal file
@@ -0,0 +1,53 @@
use super::{
    math::{FalconFelt, Polynomial},
    ByteReader, ByteWriter, Deserializable, DeserializationError, Felt, Serializable, Signature,
    Word,
};

mod public_key;
pub use public_key::{PubKeyPoly, PublicKey};

mod secret_key;
pub use secret_key::SecretKey;

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use crate::{dsa::rpo_falcon512::SecretKey, Word, ONE};
    use rand::SeedableRng;
    use rand_chacha::ChaCha20Rng;
    use winter_math::FieldElement;
    use winter_utils::{Deserializable, Serializable};

    #[test]
    fn test_falcon_verification() {
        let seed = [0_u8; 32];
        let mut rng = ChaCha20Rng::from_seed(seed);

        // generate random keys
        let sk = SecretKey::with_rng(&mut rng);
        let pk = sk.public_key();

        // test secret key serialization/deserialization
        let mut buffer = vec![];
        sk.write_into(&mut buffer);
        let sk = SecretKey::read_from_bytes(&buffer).unwrap();

        // sign a random message
        let message: Word = [ONE; 4];
        let signature = sk.sign_with_rng(message, &mut rng);

        // make sure the signature verifies correctly
        assert!(pk.verify(message, &signature));

        // a signature should not verify against a wrong message
        let message2: Word = [ONE.double(); 4];
        assert!(!pk.verify(message2, &signature));

        // a signature should not verify against a wrong public key
        let sk2 = SecretKey::with_rng(&mut rng);
        assert!(!sk2.public_key().verify(message, &signature))
    }
}
138 src/dsa/rpo_falcon512/keys/public_key.rs Normal file
@@ -0,0 +1,138 @@
use crate::dsa::rpo_falcon512::FALCON_ENCODING_BITS;

use super::{
    super::{Rpo256, LOG_N, N, PK_LEN},
    ByteReader, ByteWriter, Deserializable, DeserializationError, FalconFelt, Felt, Polynomial,
    Serializable, Signature, Word,
};
use alloc::string::ToString;
use core::ops::Deref;
use num::Zero;

// PUBLIC KEY
// ================================================================================================

/// A public key for verifying signatures.
///
/// The public key is a [Word] (i.e., 4 field elements) that is the hash of the coefficients of
/// the polynomial representing the raw bytes of the expanded public key. The hash is computed
/// using Rpo256.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PublicKey(Word);

impl PublicKey {
    /// Returns a new [PublicKey] which is a commitment to the provided expanded public key.
    pub fn new(pub_key: Word) -> Self {
        Self(pub_key)
    }

    /// Verifies the provided signature against provided message and this public key.
    pub fn verify(&self, message: Word, signature: &Signature) -> bool {
        signature.verify(message, self.0)
    }
}

impl From<PubKeyPoly> for PublicKey {
    fn from(pk_poly: PubKeyPoly) -> Self {
        let pk_felts: Polynomial<Felt> = pk_poly.0.into();
        let pk_digest = Rpo256::hash_elements(&pk_felts.coefficients).into();
        Self(pk_digest)
    }
}

impl From<PublicKey> for Word {
    fn from(key: PublicKey) -> Self {
        key.0
    }
}

// PUBLIC KEY POLYNOMIAL
// ================================================================================================

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PubKeyPoly(pub Polynomial<FalconFelt>);

impl Deref for PubKeyPoly {
    type Target = Polynomial<FalconFelt>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl From<Polynomial<FalconFelt>> for PubKeyPoly {
    fn from(pk_poly: Polynomial<FalconFelt>) -> Self {
        Self(pk_poly)
    }
}

impl Serializable for &PubKeyPoly {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let mut buf = [0_u8; PK_LEN];
        buf[0] = LOG_N;

        let mut acc = 0_u32;
        let mut acc_len: u32 = 0;

        let mut input_pos = 1;
        for c in self.0.coefficients.iter() {
            let c = c.value();
            acc = (acc << FALCON_ENCODING_BITS) | c as u32;
            acc_len += FALCON_ENCODING_BITS;
            while acc_len >= 8 {
                acc_len -= 8;
                buf[input_pos] = (acc >> acc_len) as u8;
                input_pos += 1;
            }
        }
        if acc_len > 0 {
            buf[input_pos] = (acc >> (8 - acc_len)) as u8;
        }

        target.write(buf);
    }
}

impl Deserializable for PubKeyPoly {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let buf = source.read_array::<PK_LEN>()?;

        if buf[0] != LOG_N {
            return Err(DeserializationError::InvalidValue(format!(
                "Failed to decode public key: expected the first byte to be {LOG_N} but was {}",
                buf[0]
            )));
        }

        let mut acc = 0_u32;
        let mut acc_len = 0;

        let mut output = [FalconFelt::zero(); N];
        let mut output_idx = 0;

        for &byte in buf.iter().skip(1) {
            acc = (acc << 8) | (byte as u32);
            acc_len += 8;

            if acc_len >= FALCON_ENCODING_BITS {
                acc_len -= FALCON_ENCODING_BITS;
                let w = (acc >> acc_len) & 0x3FFF;
                let element = w.try_into().map_err(|err| {
                    DeserializationError::InvalidValue(format!(
                        "Failed to decode public key: {err}"
                    ))
                })?;
                output[output_idx] = element;
                output_idx += 1;
            }
        }

        if (acc & ((1u32 << acc_len) - 1)) == 0 {
            Ok(Polynomial::new(output.to_vec()).into())
        } else {
            Err(DeserializationError::InvalidValue(
                "Failed to decode public key: input not fully consumed".to_string(),
            ))
        }
    }
}
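
// Editorial note (illustrative): assuming FALCON_ENCODING_BITS = 14, which matches
// the 0x3FFF mask above, the serialized sizes are consistent: 512 coefficients of
// 14 bits each pack into 896 bytes, plus the one-byte log2(N) header.
const _: () = assert!(512 * 14 / 8 + 1 == 897); // = PK_LEN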
386 src/dsa/rpo_falcon512/keys/secret_key.rs Normal file
@@ -0,0 +1,386 @@
use super::{
    super::{
        math::{ffldl, ffsampling, gram, normalize_tree, FalconFelt, FastFft, LdlTree, Polynomial},
        signature::SignaturePoly,
        ByteReader, ByteWriter, Deserializable, DeserializationError, Nonce, Serializable,
        ShortLatticeBasis, Signature, Word, MODULUS, N, SIGMA, SIG_L2_BOUND,
    },
    PubKeyPoly, PublicKey,
};
use crate::dsa::rpo_falcon512::{
    hash_to_point::hash_to_point_rpo256, math::ntru_gen, SIG_NONCE_LEN, SK_LEN,
};
use alloc::{string::ToString, vec::Vec};
use num::Complex;
use num_complex::Complex64;
use rand::Rng;

#[cfg(not(feature = "std"))]
use num::Float;

// CONSTANTS
// ================================================================================================

const WIDTH_BIG_POLY_COEFFICIENT: usize = 8;
const WIDTH_SMALL_POLY_COEFFICIENT: usize = 6;

// SECRET KEY
// ================================================================================================

/// The secret key is a quadruple [[g, -f], [G, -F]] of polynomials with integer coefficients. Each
/// polynomial is of degree less than N = 512, and computations with these polynomials are done
/// modulo the monic irreducible polynomial ϕ = x^N + 1. The secret key is a basis for a lattice
/// and has the property of being short with respect to a certain norm and an upper bound
/// appropriate for a given security parameter. The public key, on the other hand, is another
/// basis for the same lattice and can be described by a single polynomial h with integer
/// coefficients modulo ϕ. The two keys are related by the following relations:
///
/// 1. h = g / f [mod ϕ][mod p]
/// 2. f.G - g.F = p [mod ϕ]
///
/// where p = 12289 is the Falcon prime. Equation 2 is called the NTRU equation.
/// The secret key is generated by first sampling a random pair (f, g) of polynomials using
/// an appropriate distribution that yields short but not too short polynomials with integer
/// coefficients modulo ϕ. The NTRU equation is then used to find a matching pair (F, G).
/// The public key is then derived from the secret key using equation 1.
///
/// To allow for fast signature generation, the secret key is pre-processed into a more suitable
/// form, called the LDL tree, and this allows for fast sampling of short vectors in the lattice
/// using Fast Fourier sampling during signature generation (ffSampling, algorithm 11 in [1]).
///
/// [1]: https://falcon-sign.info/falcon.pdf
#[derive(Debug, Clone)]
pub struct SecretKey {
    secret_key: ShortLatticeBasis,
    tree: LdlTree,
}

#[allow(clippy::new_without_default)]
impl SecretKey {
    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Generates a secret key from OS-provided randomness.
    #[cfg(feature = "std")]
    pub fn new() -> Self {
        use rand::{rngs::StdRng, SeedableRng};

        let mut rng = StdRng::from_entropy();
        Self::with_rng(&mut rng)
    }

    /// Generates a secret key using the provided random number generator `Rng`.
    pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
        let basis = ntru_gen(N, rng);
        Self::from_short_lattice_basis(basis)
    }

    /// Given a short basis [[g, -f], [G, -F]], computes the normalized LDL tree, i.e., the
    /// Falcon tree.
    fn from_short_lattice_basis(basis: ShortLatticeBasis) -> SecretKey {
        // FFT each polynomial of the short basis.
        let basis_fft = to_complex_fft(&basis);
        // compute the Gram matrix.
        let gram_fft = gram(basis_fft);
        // construct the LDL tree of the Gram matrix.
        let mut tree = ffldl(gram_fft);
        // normalize the leaves of the LDL tree.
        normalize_tree(&mut tree, SIGMA);
        Self { secret_key: basis, tree }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the polynomials of the short lattice basis of this secret key.
    pub fn short_lattice_basis(&self) -> &ShortLatticeBasis {
        &self.secret_key
    }

    /// Returns the public key corresponding to this secret key.
    pub fn public_key(&self) -> PublicKey {
        self.compute_pub_key_poly().into()
    }

    /// Returns the LDL tree associated to this secret key.
    pub fn tree(&self) -> &LdlTree {
        &self.tree
    }

    // SIGNATURE GENERATION
    // --------------------------------------------------------------------------------------------

    /// Signs a message with this secret key.
    #[cfg(feature = "std")]
    pub fn sign(&self, message: Word) -> Signature {
        use rand::{rngs::StdRng, SeedableRng};

        let mut rng = StdRng::from_entropy();
        self.sign_with_rng(message, &mut rng)
    }

    /// Signs a message with the secret key relying on the provided randomness generator.
    pub fn sign_with_rng<R: Rng>(&self, message: Word, rng: &mut R) -> Signature {
        let mut nonce_bytes = [0u8; SIG_NONCE_LEN];
        rng.fill_bytes(&mut nonce_bytes);
        let nonce = Nonce::new(nonce_bytes);

        let h = self.compute_pub_key_poly();
        let c = hash_to_point_rpo256(message, &nonce);
        let s2 = self.sign_helper(c, rng);

        Signature::new(nonce, h, s2)
    }
|
||||
|
||||
// HELPER METHODS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Derives the public key corresponding to this secret key using h = g /f [mod ϕ][mod p].
|
||||
pub fn compute_pub_key_poly(&self) -> PubKeyPoly {
|
||||
let g: Polynomial<FalconFelt> = self.secret_key[0].clone().into();
|
||||
let g_fft = g.fft();
|
||||
let minus_f: Polynomial<FalconFelt> = self.secret_key[1].clone().into();
|
||||
let f = -minus_f;
|
||||
let f_fft = f.fft();
|
||||
let h_fft = g_fft.hadamard_div(&f_fft);
|
||||
h_fft.ifft().into()
|
||||
}
|
||||
|
||||
/// Signs a message polynomial with the secret key.
|
||||
///
|
||||
/// Takes a randomness generator implementing `Rng` and message polynomial representing `c`
|
||||
/// the hash-to-point of the message to be signed. It outputs a signature polynomial `s2`.
|
||||
fn sign_helper<R: Rng>(&self, c: Polynomial<FalconFelt>, rng: &mut R) -> SignaturePoly {
|
||||
let one_over_q = 1.0 / (MODULUS as f64);
|
||||
let c_over_q_fft = c.map(|cc| Complex::new(one_over_q * cc.value() as f64, 0.0)).fft();
|
||||
|
||||
// B = [[FFT(g), -FFT(f)], [FFT(G), -FFT(F)]]
|
||||
let [g_fft, minus_f_fft, big_g_fft, minus_big_f_fft] = to_complex_fft(&self.secret_key);
|
||||
let t0 = c_over_q_fft.hadamard_mul(&minus_big_f_fft);
|
||||
let t1 = -c_over_q_fft.hadamard_mul(&minus_f_fft);
|
||||
|
||||
loop {
|
||||
let bold_s = loop {
|
||||
let z = ffsampling(&(t0.clone(), t1.clone()), &self.tree, rng);
|
||||
let t0_min_z0 = t0.clone() - z.0;
|
||||
let t1_min_z1 = t1.clone() - z.1;
|
||||
|
||||
// s = (t-z) * B
|
||||
let s0 = t0_min_z0.hadamard_mul(&g_fft) + t1_min_z1.hadamard_mul(&big_g_fft);
|
||||
let s1 =
|
||||
t0_min_z0.hadamard_mul(&minus_f_fft) + t1_min_z1.hadamard_mul(&minus_big_f_fft);
|
||||
|
||||
// compute the norm of (s0||s1) and note that they are in FFT representation
|
||||
let length_squared: f64 =
|
||||
(s0.coefficients.iter().map(|a| (a * a.conj()).re).sum::<f64>()
|
||||
+ s1.coefficients.iter().map(|a| (a * a.conj()).re).sum::<f64>())
|
||||
/ (N as f64);
|
||||
|
||||
if length_squared > (SIG_L2_BOUND as f64) {
|
||||
continue;
|
||||
}
|
||||
|
||||
break [-s0, s1];
|
||||
};
|
||||
|
||||
let s2 = bold_s[1].ifft();
|
||||
let s2_coef: [i16; N] = s2
|
||||
.coefficients
|
||||
.iter()
|
||||
.map(|a| a.re.round() as i16)
|
||||
.collect::<Vec<i16>>()
|
||||
.try_into()
|
||||
.expect("The number of coefficients should be equal to N");
|
||||
|
||||
if let Ok(s2) = SignaturePoly::try_from(&s2_coef) {
|
||||
return s2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
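
// A minimal end-to-end usage sketch (illustrative only; it assumes the `std` feature so that
// OS randomness and `rand::rngs::StdRng` are available, and that `Word: Default`):
//
// ```
// use rand::{rngs::StdRng, SeedableRng};
//
// let mut rng = StdRng::from_entropy();
// let sk = SecretKey::with_rng(&mut rng);    // sample (f, g, F, G) and build the Falcon tree
// let pk = sk.public_key();                  // h = g / f  [mod ϕ][mod p]
// let message = Word::default();             // the 4-element word to be signed
// let sig = sk.sign_with_rng(message, &mut rng);
// ```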

// SERIALIZATION / DESERIALIZATION
// ================================================================================================

impl Serializable for SecretKey {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let basis = &self.secret_key;

        // header
        let n = basis[0].coefficients.len();
        let l = n.checked_ilog2().unwrap() as u8;
        let header: u8 = (5 << 4) | l;

        let f = &basis[1];
        let g = &basis[0];
        let capital_f = &basis[3];

        let mut buffer = Vec::with_capacity(1281);
        buffer.push(header);

        let f_i8: Vec<i8> = f.coefficients.iter().map(|&a| -a as i8).collect();
        let f_i8_encoded = encode_i8(&f_i8, WIDTH_SMALL_POLY_COEFFICIENT).unwrap();
        buffer.extend_from_slice(&f_i8_encoded);

        let g_i8: Vec<i8> = g.coefficients.iter().map(|&a| a as i8).collect();
        let g_i8_encoded = encode_i8(&g_i8, WIDTH_SMALL_POLY_COEFFICIENT).unwrap();
        buffer.extend_from_slice(&g_i8_encoded);

        let big_f_i8: Vec<i8> = capital_f.coefficients.iter().map(|&a| -a as i8).collect();
        let big_f_i8_encoded = encode_i8(&big_f_i8, WIDTH_BIG_POLY_COEFFICIENT).unwrap();
        buffer.extend_from_slice(&big_f_i8_encoded);
        target.write_bytes(&buffer);
    }
}
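
// For the Falcon-512 parameter set the encoding above works out to 1281 bytes: 1 header byte,
// then f and g packed at WIDTH_SMALL_POLY_COEFFICIENT = 6 bits per coefficient
// (512 * 6 / 8 = 384 bytes each), then F packed at WIDTH_BIG_POLY_COEFFICIENT = 8 bits per
// coefficient (512 bytes). G is not stored; it is recomputed on deserialization from the NTRU
// equation. (A sketch of the arithmetic, assuming the constants keep the reference values 6
// and 8: 1 + 384 + 384 + 512 = 1281 = SK_LEN.)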

impl Deserializable for SecretKey {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let byte_vector: [u8; SK_LEN] = source.read_array()?;

        // check length
        if byte_vector.len() < 2 {
            return Err(DeserializationError::InvalidValue(
                "Invalid encoding length: failed to decode as the length differs from the expected one".to_string(),
            ));
        }

        // read fields
        let header = byte_vector[0];

        // check fixed bits in header
        if (header >> 4) != 5 {
            return Err(DeserializationError::InvalidValue("Invalid header format".to_string()));
        }

        // check log n
        let logn = (header & 15) as usize;
        let n = 1 << logn;

        // match against the constant generic parameter of this variant
        if n != N {
            return Err(DeserializationError::InvalidValue(
                "Unsupported Falcon DSA variant".to_string(),
            ));
        }

        if byte_vector.len() != SK_LEN {
            return Err(DeserializationError::InvalidValue(
                "Invalid encoding length: failed to decode as the length differs from the expected one".to_string(),
            ));
        }

        let chunk_size_f = ((n * WIDTH_SMALL_POLY_COEFFICIENT) + 7) >> 3;
        let chunk_size_g = ((n * WIDTH_SMALL_POLY_COEFFICIENT) + 7) >> 3;
        let chunk_size_big_f = ((n * WIDTH_BIG_POLY_COEFFICIENT) + 7) >> 3;

        let f = decode_i8(&byte_vector[1..chunk_size_f + 1], WIDTH_SMALL_POLY_COEFFICIENT).unwrap();
        let g = decode_i8(
            &byte_vector[chunk_size_f + 1..(chunk_size_f + chunk_size_g + 1)],
            WIDTH_SMALL_POLY_COEFFICIENT,
        )
        .unwrap();
        let big_f = decode_i8(
            &byte_vector[(chunk_size_f + chunk_size_g + 1)
                ..(chunk_size_f + chunk_size_g + chunk_size_big_f + 1)],
            WIDTH_BIG_POLY_COEFFICIENT,
        )
        .unwrap();

        let f = Polynomial::new(f.iter().map(|&c| FalconFelt::new(c.into())).collect());
        let g = Polynomial::new(g.iter().map(|&c| FalconFelt::new(c.into())).collect());
        let big_f = Polynomial::new(big_f.iter().map(|&c| FalconFelt::new(c.into())).collect());

        // recompute G from the NTRU equation: big_g * f - g * big_f = p (mod X^n + 1)
        let big_g = g.fft().hadamard_div(&f.fft()).hadamard_mul(&big_f.fft()).ifft();
        let basis = [
            g.map(|f| f.balanced_value()),
            -f.map(|f| f.balanced_value()),
            big_g.map(|f| f.balanced_value()),
            -big_f.map(|f| f.balanced_value()),
        ];
        Ok(Self::from_short_lattice_basis(basis))
    }
}

// HELPER FUNCTIONS
// ================================================================================================

/// Computes the complex FFT of the secret key polynomials.
fn to_complex_fft(basis: &[Polynomial<i16>; 4]) -> [Polynomial<Complex<f64>>; 4] {
    let [g, f, big_g, big_f] = basis.clone();
    let g_fft = g.map(|cc| Complex64::new(*cc as f64, 0.0)).fft();
    let minus_f_fft = f.map(|cc| -Complex64::new(*cc as f64, 0.0)).fft();
    let big_g_fft = big_g.map(|cc| Complex64::new(*cc as f64, 0.0)).fft();
    let minus_big_f_fft = big_f.map(|cc| -Complex64::new(*cc as f64, 0.0)).fft();
    [g_fft, minus_f_fft, big_g_fft, minus_big_f_fft]
}

/// Encodes a sequence of signed integers such that each integer x satisfies |x| < 2^(bits-1)
/// for a given parameter bits. bits can take either the value 6 or 8.
pub fn encode_i8(x: &[i8], bits: usize) -> Option<Vec<u8>> {
    let maxv = (1 << (bits - 1)) - 1_usize;
    let maxv = maxv as i8;
    let minv = -maxv;

    for &c in x {
        if c > maxv || c < minv {
            return None;
        }
    }

    let out_len = ((N * bits) + 7) >> 3;
    let mut buf = vec![0_u8; out_len];

    let mut acc = 0_u32;
    let mut acc_len = 0;
    let mask = ((1_u16 << bits) - 1) as u8;

    let mut output_pos = 0;
    for &c in x {
        acc = (acc << bits) | (c as u8 & mask) as u32;
        acc_len += bits;
        while acc_len >= 8 {
            acc_len -= 8;
            buf[output_pos] = (acc >> acc_len) as u8;
            output_pos += 1;
        }
    }
    if acc_len > 0 {
        // left-align the remaining bits in the final byte; the decoder consumes bits
        // from the top and requires the padding bits to be zero
        buf[output_pos] = (acc << (8 - acc_len)) as u8;
    }

    Some(buf)
}

/// Decodes a sequence of bytes into a sequence of signed integers such that each integer x
/// satisfies |x| < 2^(bits-1) for a given parameter bits. bits can take either the value 6 or 8.
pub fn decode_i8(buf: &[u8], bits: usize) -> Option<Vec<i8>> {
    let mut x = [0_i8; N];

    let mut i = 0;
    let mut j = 0;
    let mut acc = 0_u32;
    let mut acc_len = 0;
    let mask = (1_u32 << bits) - 1;
    let a = (1 << bits) as u8;
    let b = ((1 << (bits - 1)) - 1) as u8;

    while i < N {
        acc = (acc << 8) | (buf[j] as u32);
        j += 1;
        acc_len += 8;

        while acc_len >= bits && i < N {
            acc_len -= bits;
            let w = (acc >> acc_len) & mask;

            let w = w as u8;

            let z = if w > b { w as i8 - a as i8 } else { w as i8 };

            x[i] = z;
            i += 1;
        }
    }

    if (acc & ((1u32 << acc_len) - 1)) == 0 {
        Some(x.to_vec())
    } else {
        None
    }
}
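
// Round-trip sketch for the bit-packing above (illustrative): with bits = 6 every signed
// coefficient must lie in [-31, 31], and decode_i8 rejects buffers whose padding bits in the
// final byte are non-zero.
//
// ```
// let coeffs: Vec<i8> = vec![0i8; N - 2].into_iter().chain([31i8, -31i8]).collect();
// let packed = encode_i8(&coeffs, 6).expect("all values fit in 6 bits");
// assert_eq!(packed.len(), (N * 6 + 7) >> 3);
// assert_eq!(decode_i8(&packed, 6), Some(coeffs));
// ```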
123	src/dsa/rpo_falcon512/math/ffsampling.rs	Normal file
@@ -0,0 +1,123 @@
use super::{fft::FastFft, polynomial::Polynomial, samplerz::sampler_z};
use alloc::boxed::Box;
use num::{One, Zero};
use num_complex::{Complex, Complex64};
use rand::Rng;

#[cfg(not(feature = "std"))]
use num::Float;

const SIGMIN: f64 = 1.2778336969128337;

/// Computes the Gram matrix. The argument must be a 2x2 matrix
/// whose elements are equal-length vectors of complex numbers,
/// representing polynomials in FFT domain.
pub fn gram(b: [Polynomial<Complex64>; 4]) -> [Polynomial<Complex64>; 4] {
    const N: usize = 2;
    let mut g: [Polynomial<Complex<f64>>; 4] =
        [Polynomial::zero(), Polynomial::zero(), Polynomial::zero(), Polynomial::zero()];
    for i in 0..N {
        for j in 0..N {
            for k in 0..N {
                g[N * i + j] = g[N * i + j].clone()
                    + b[N * i + k].hadamard_mul(&b[N * j + k].map(|c| c.conj()));
            }
        }
    }
    g
}

/// Computes the LDL decomposition of a 2x2 matrix G such that
///     L D L* = G
/// where D is diagonal, and L is lower-triangular. The elements of the matrices are in FFT domain.
pub fn ldl(
    g: [Polynomial<Complex64>; 4],
) -> ([Polynomial<Complex64>; 4], [Polynomial<Complex64>; 4]) {
    let zero = Polynomial::<Complex64>::zero();
    let one = Polynomial::<Complex64>::one();

    let l10 = g[2].hadamard_div(&g[0]);
    let bc = l10.map(|c| c * c.conj());
    let abc = g[0].hadamard_mul(&bc);
    let d11 = g[3].clone() - abc;

    let l = [one.clone(), zero.clone(), l10.clone(), one];
    let d = [g[0].clone(), zero.clone(), zero, d11];
    (l, d)
}
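
// For the 2x2 case above the decomposition is fully explicit: with G = [[g00, g01], [g10, g11]],
//
//   L = [[1, 0], [g10/g00, 1]],   D = diag(g00, g11 - |g10/g00|^2 * g00),
//
// so only l10 and d11 have to be computed; the remaining entries of `l` and `d` are the
// constants one and zero and are never read by `ffldl` below.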

#[derive(Debug, Clone)]
pub enum LdlTree {
    Branch(Polynomial<Complex64>, Box<LdlTree>, Box<LdlTree>),
    Leaf([Complex64; 2]),
}

/// Computes the LDL Tree of G. Corresponds to Algorithm 9 of the specification [1, p.37].
/// The argument is a 2x2 matrix of polynomials, given in FFT form.
///
/// [1]: https://falcon-sign.info/falcon.pdf
pub fn ffldl(gram_matrix: [Polynomial<Complex64>; 4]) -> LdlTree {
    let n = gram_matrix[0].coefficients.len();
    let (l, d) = ldl(gram_matrix);

    if n > 2 {
        let (d00, d01) = d[0].split_fft();
        let (d10, d11) = d[3].split_fft();
        let g0 = [d00.clone(), d01.clone(), d01.map(|c| c.conj()), d00];
        let g1 = [d10.clone(), d11.clone(), d11.map(|c| c.conj()), d10];
        LdlTree::Branch(l[2].clone(), Box::new(ffldl(g0)), Box::new(ffldl(g1)))
    } else {
        LdlTree::Branch(
            l[2].clone(),
            Box::new(LdlTree::Leaf(d[0].clone().coefficients.try_into().unwrap())),
            Box::new(LdlTree::Leaf(d[3].clone().coefficients.try_into().unwrap())),
        )
    }
}

/// Normalizes the leaves of an LDL tree using a given normalization value `sigma`.
pub fn normalize_tree(tree: &mut LdlTree, sigma: f64) {
    match tree {
        LdlTree::Branch(_ell, left, right) => {
            normalize_tree(left, sigma);
            normalize_tree(right, sigma);
        }
        LdlTree::Leaf(vector) => {
            vector[0] = Complex::new(sigma / vector[0].re.sqrt(), 0.0);
            vector[1] = Complex64::zero();
        }
    }
}

/// Samples short polynomials using a Falcon tree. Algorithm 11 from the spec [1, p.40].
///
/// [1]: https://falcon-sign.info/falcon.pdf
pub fn ffsampling<R: Rng>(
    t: &(Polynomial<Complex64>, Polynomial<Complex64>),
    tree: &LdlTree,
    mut rng: &mut R,
) -> (Polynomial<Complex64>, Polynomial<Complex64>) {
    match tree {
        LdlTree::Branch(ell, left, right) => {
            let bold_t1 = t.1.split_fft();
            let bold_z1 = ffsampling(&bold_t1, right, rng);
            let z1 = Polynomial::<Complex64>::merge_fft(&bold_z1.0, &bold_z1.1);

            // t0' = t0 + (t1 - z1) * l
            let t0_prime = t.0.clone() + (t.1.clone() - z1.clone()).hadamard_mul(ell);

            let bold_t0 = t0_prime.split_fft();
            let bold_z0 = ffsampling(&bold_t0, left, rng);
            let z0 = Polynomial::<Complex64>::merge_fft(&bold_z0.0, &bold_z0.1);

            (z0, z1)
        }
        LdlTree::Leaf(value) => {
            let z0 = sampler_z(t.0.coefficients[0].re, value[0].re, SIGMIN, &mut rng);
            let z1 = sampler_z(t.1.coefficients[0].re, value[0].re, SIGMIN, &mut rng);
            (
                Polynomial::new(vec![Complex64::new(z0 as f64, 0.0)]),
                Polynomial::new(vec![Complex64::new(z1 as f64, 0.0)]),
            )
        }
    }
}
1927	src/dsa/rpo_falcon512/math/fft.rs	Normal file
File diff suppressed because it is too large
172	src/dsa/rpo_falcon512/math/field.rs	Normal file
@@ -0,0 +1,172 @@
use super::{fft::CyclotomicFourier, Inverse, MODULUS};
use alloc::string::String;
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use num::{One, Zero};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FalconFelt(u32);

impl FalconFelt {
    pub const fn new(value: i16) -> Self {
        let gtz_bool = value >= 0;
        let gtz_int = gtz_bool as i16;
        let gtz_sign = gtz_int - ((!gtz_bool) as i16);
        let reduced = gtz_sign * (gtz_sign * value) % MODULUS;
        let canonical_representative = (reduced + MODULUS * (1 - gtz_int)) as u32;
        FalconFelt(canonical_representative)
    }

    pub const fn value(&self) -> i16 {
        self.0 as i16
    }

    pub fn balanced_value(&self) -> i16 {
        let value = self.value();
        let g = (value > ((MODULUS) / 2)) as i16;
        value - (MODULUS) * g
    }

    pub const fn multiply(&self, other: Self) -> Self {
        FalconFelt((self.0 * other.0) % MODULUS as u32)
    }
}

impl Add for FalconFelt {
    type Output = Self;

    #[allow(clippy::suspicious_arithmetic_impl)]
    fn add(self, rhs: Self) -> Self::Output {
        let (s, _) = self.0.overflowing_add(rhs.0);
        let (d, n) = s.overflowing_sub(MODULUS as u32);
        let (r, _) = d.overflowing_add(MODULUS as u32 * (n as u32));
        FalconFelt(r)
    }
}

impl AddAssign for FalconFelt {
    fn add_assign(&mut self, rhs: Self) {
        *self = *self + rhs;
    }
}

impl Sub for FalconFelt {
    type Output = Self;

    fn sub(self, rhs: Self) -> Self::Output {
        self + -rhs
    }
}

impl SubAssign for FalconFelt {
    fn sub_assign(&mut self, rhs: Self) {
        *self = *self - rhs;
    }
}

impl Neg for FalconFelt {
    type Output = FalconFelt;

    fn neg(self) -> Self::Output {
        let is_nonzero = self.0 != 0;
        let r = MODULUS as u32 - self.0;
        FalconFelt(r * (is_nonzero as u32))
    }
}

impl Mul for FalconFelt {
    type Output = Self;

    fn mul(self, rhs: Self) -> Self::Output {
        FalconFelt((self.0 * rhs.0) % MODULUS as u32)
    }
}

impl MulAssign for FalconFelt {
    fn mul_assign(&mut self, rhs: Self) {
        *self = *self * rhs;
    }
}

impl Div for FalconFelt {
    type Output = FalconFelt;

    #[allow(clippy::suspicious_arithmetic_impl)]
    fn div(self, rhs: Self) -> Self::Output {
        self * rhs.inverse_or_zero()
    }
}

impl DivAssign for FalconFelt {
    fn div_assign(&mut self, rhs: Self) {
        *self = *self / rhs
    }
}

impl Zero for FalconFelt {
    fn zero() -> Self {
        FalconFelt::new(0)
    }

    fn is_zero(&self) -> bool {
        self.0 == 0
    }
}

impl One for FalconFelt {
    fn one() -> Self {
        FalconFelt::new(1)
    }
}

impl Inverse for FalconFelt {
    fn inverse_or_zero(self) -> Self {
        // q - 2 = 12287 = 0b10_1111_1111_1111
        let two = self.multiply(self);
        let three = two.multiply(self);
        let six = three.multiply(three);
        let twelve = six.multiply(six);
        let fifteen = twelve.multiply(three);
        let thirty = fifteen.multiply(fifteen);
        let sixty = thirty.multiply(thirty);
        let sixty_three = sixty.multiply(three);

        let sixty_three_sq = sixty_three.multiply(sixty_three);
        let sixty_three_qu = sixty_three_sq.multiply(sixty_three_sq);
        let sixty_three_oc = sixty_three_qu.multiply(sixty_three_qu);
        let sixty_three_hx = sixty_three_oc.multiply(sixty_three_oc);
        let sixty_three_tt = sixty_three_hx.multiply(sixty_three_hx);
        let sixty_three_sf = sixty_three_tt.multiply(sixty_three_tt);

        let all_ones = sixty_three_sf.multiply(sixty_three);
        let two_e_twelve = all_ones.multiply(self);
        let two_e_thirteen = two_e_twelve.multiply(two_e_twelve);

        two_e_thirteen.multiply(all_ones)
    }
}
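
// Since q = 12289 is prime, the addition chain above computes x^(q-2), i.e., the inverse by
// Fermat's little theorem, and maps zero to zero. A quick sanity sketch:
//
// ```
// let x = FalconFelt::new(42);
// assert_eq!(x.multiply(x.inverse_or_zero()).value(), 1);
// assert_eq!(FalconFelt::new(0).inverse_or_zero().value(), 0);
// ```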

impl CyclotomicFourier for FalconFelt {
    fn primitive_root_of_unity(n: usize) -> Self {
        let log2n = n.ilog2();
        assert!(log2n <= 12);
        // 1331 is a primitive 2^12-th root of unity modulo q = 12289; squaring it
        // (12 - log2(n)) times yields a primitive n-th root of unity
        let mut a = FalconFelt::new(1331);
        let num_squarings = 12 - n.ilog2();
        for _ in 0..num_squarings {
            a *= a;
        }
        a
    }
}

impl TryFrom<u32> for FalconFelt {
    type Error = String;

    fn try_from(value: u32) -> Result<Self, Self::Error> {
        if value >= MODULUS as u32 {
            Err(format!("value {value} is greater than or equal to the field modulus {MODULUS}"))
        } else {
            Ok(FalconFelt::new(value as i16))
        }
    }
}
320	src/dsa/rpo_falcon512/math/mod.rs	Normal file
@@ -0,0 +1,320 @@
//! Contains different structs and methods related to the Falcon DSA.
//!
//! It uses and acknowledges the work in:
//!
//! 1. The [reference](https://falcon-sign.info/impl/README.txt.html) implementation by Thomas Pornin.
//! 2. The [Rust](https://github.com/aszepieniec/falcon-rust) implementation by Alan Szepieniec.
use super::MODULUS;
use alloc::{string::String, vec::Vec};
use core::ops::MulAssign;
use num::{BigInt, FromPrimitive, One, Zero};
use num_complex::Complex64;
use rand::Rng;

#[cfg(not(feature = "std"))]
use num::Float;

mod fft;
pub use fft::{CyclotomicFourier, FastFft};

mod field;
pub use field::FalconFelt;

mod ffsampling;
pub use ffsampling::{ffldl, ffsampling, gram, normalize_tree, LdlTree};

mod samplerz;
use self::samplerz::sampler_z;

mod polynomial;
pub use polynomial::Polynomial;

pub trait Inverse: Copy + Zero + MulAssign + One {
    /// Gets the inverse of a, or zero if it is zero.
    fn inverse_or_zero(self) -> Self;

    /// Gets the inverses of a batch of elements, skipping over any that are zero.
    fn batch_inverse_or_zero(batch: &[Self]) -> Vec<Self> {
        let mut acc = Self::one();
        let mut rp: Vec<Self> = Vec::with_capacity(batch.len());
        for batch_item in batch {
            if !batch_item.is_zero() {
                rp.push(acc);
                acc = *batch_item * acc;
            } else {
                rp.push(Self::zero());
            }
        }
        let mut inv = Self::inverse_or_zero(acc);
        for i in (0..batch.len()).rev() {
            if !batch[i].is_zero() {
                rp[i] *= inv;
                inv *= batch[i];
            }
        }
        rp
    }
}
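
// `batch_inverse_or_zero` is Montgomery's batch-inversion trick: a forward pass accumulates
// running products, a single inversion inverts the total, and a reverse pass peels off the
// individual inverses, so a batch of n elements costs one inversion plus O(n) multiplications.
// A small sketch with FalconFelt:
//
// ```
// let batch = [FalconFelt::new(2), FalconFelt::zero(), FalconFelt::new(3)];
// let inverses = FalconFelt::batch_inverse_or_zero(&batch);
// assert_eq!(batch[0].multiply(inverses[0]).value(), 1);
// assert!(inverses[1].is_zero()); // zeros are passed through untouched
// ```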

impl Inverse for Complex64 {
    fn inverse_or_zero(self) -> Self {
        let norm_squared = self.re * self.re + self.im * self.im;
        Complex64::new(self.re / norm_squared, -self.im / norm_squared)
    }
    fn batch_inverse_or_zero(batch: &[Self]) -> Vec<Self> {
        batch.iter().map(|&c| Complex64::new(1.0, 0.0) / c).collect()
    }
}

impl Inverse for f64 {
    fn inverse_or_zero(self) -> Self {
        1.0 / self
    }
    fn batch_inverse_or_zero(batch: &[Self]) -> Vec<Self> {
        batch.iter().map(|&c| 1.0 / c).collect()
    }
}

/// Samples 4 small polynomials f, g, F, G such that f * G - g * F = q mod (X^n + 1).
/// Algorithm 5 (NTRUgen) of the documentation [1, p.34].
///
/// [1]: https://falcon-sign.info/falcon.pdf
pub(crate) fn ntru_gen<R: Rng>(n: usize, rng: &mut R) -> [Polynomial<i16>; 4] {
    loop {
        let f = gen_poly(n, rng);
        let g = gen_poly(n, rng);
        let f_ntt = f.map(|&i| FalconFelt::new(i)).fft();
        if f_ntt.coefficients.iter().any(|e| e.is_zero()) {
            continue;
        }
        let gamma = gram_schmidt_norm_squared(&f, &g);
        if gamma > 1.3689f64 * (MODULUS as f64) {
            continue;
        }

        if let Some((capital_f, capital_g)) =
            ntru_solve(&f.map(|&i| i.into()), &g.map(|&i| i.into()))
        {
            return [
                f,
                g,
                capital_f.map(|i| i.try_into().unwrap()),
                capital_g.map(|i| i.try_into().unwrap()),
            ];
        }
    }
}

/// Solves the NTRU equation. Given f, g in ZZ[X], find F, G in ZZ[X] such that:
///
///     f G - g F = q mod (X^n + 1)
///
/// Algorithm 6 of the specification [1, p.35].
///
/// [1]: https://falcon-sign.info/falcon.pdf
fn ntru_solve(
    f: &Polynomial<BigInt>,
    g: &Polynomial<BigInt>,
) -> Option<(Polynomial<BigInt>, Polynomial<BigInt>)> {
    let n = f.coefficients.len();
    if n == 1 {
        let (gcd, u, v) = xgcd(&f.coefficients[0], &g.coefficients[0]);
        if gcd != BigInt::one() {
            return None;
        }
        return Some((
            (Polynomial::new(vec![-v * BigInt::from_u32(MODULUS as u32).unwrap()])),
            Polynomial::new(vec![u * BigInt::from_u32(MODULUS as u32).unwrap()]),
        ));
    }

    let f_prime = f.field_norm();
    let g_prime = g.field_norm();

    let (capital_f_prime, capital_g_prime) = ntru_solve(&f_prime, &g_prime)?;
    let capital_f_prime_xsq = capital_f_prime.lift_next_cyclotomic();
    let capital_g_prime_xsq = capital_g_prime.lift_next_cyclotomic();

    let f_minx = f.galois_adjoint();
    let g_minx = g.galois_adjoint();

    let mut capital_f = (capital_f_prime_xsq.karatsuba(&g_minx)).reduce_by_cyclotomic(n);
    let mut capital_g = (capital_g_prime_xsq.karatsuba(&f_minx)).reduce_by_cyclotomic(n);

    match babai_reduce(f, g, &mut capital_f, &mut capital_g) {
        Ok(_) => Some((capital_f, capital_g)),
        Err(_e) => {
            #[cfg(test)]
            {
                panic!("{}", _e);
            }
            #[cfg(not(test))]
            {
                None
            }
        }
    }
}

/// Generates a polynomial of degree at most n-1 whose coefficients are distributed according
/// to a discrete Gaussian with mu = 0 and sigma = 1.17 * sqrt(Q / (2n)).
fn gen_poly<R: Rng>(n: usize, rng: &mut R) -> Polynomial<i16> {
    let mu = 0.0;
    let sigma_star = 1.43300980528773;
    Polynomial {
        coefficients: (0..4096)
            .map(|_| sampler_z(mu, sigma_star, sigma_star - 0.001, rng))
            .collect::<Vec<i16>>()
            .chunks(4096 / n)
            .map(|ch| ch.iter().sum())
            .collect(),
    }
}
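
// The 4096-sample trick above lets a single base sampler serve every ring degree: each of the
// n coefficients is the sum of 4096/n i.i.d. samples of standard deviation sigma_star, giving
// a discrete Gaussian of standard deviation sigma_star * sqrt(4096 / n). For n = 512 this is
// 1.433... * sqrt(8) ≈ 4.05, which matches the target 1.17 * sqrt(12289 / 1024).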

/// Computes the Gram-Schmidt norm of B = [[g, -f], [G, -F]] from f and g.
/// Corresponds to line 9 in algorithm 5 of the spec [1, p.34].
///
/// [1]: https://falcon-sign.info/falcon.pdf
fn gram_schmidt_norm_squared(f: &Polynomial<i16>, g: &Polynomial<i16>) -> f64 {
    let n = f.coefficients.len();
    let norm_f_squared = f.l2_norm_squared();
    let norm_g_squared = g.l2_norm_squared();
    let gamma1 = norm_f_squared + norm_g_squared;

    let f_fft = f.map(|i| Complex64::new(*i as f64, 0.0)).fft();
    let g_fft = g.map(|i| Complex64::new(*i as f64, 0.0)).fft();
    let f_adj_fft = f_fft.map(|c| c.conj());
    let g_adj_fft = g_fft.map(|c| c.conj());
    let ffgg_fft = f_fft.hadamard_mul(&f_adj_fft) + g_fft.hadamard_mul(&g_adj_fft);
    let ffgg_fft_inverse = ffgg_fft.hadamard_inv();
    let qf_over_ffgg_fft = f_adj_fft.map(|c| c * (MODULUS as f64)).hadamard_mul(&ffgg_fft_inverse);
    let qg_over_ffgg_fft = g_adj_fft.map(|c| c * (MODULUS as f64)).hadamard_mul(&ffgg_fft_inverse);
    let norm_f_over_ffgg_squared =
        qf_over_ffgg_fft.coefficients.iter().map(|c| (c * c.conj()).re).sum::<f64>() / (n as f64);
    let norm_g_over_ffgg_squared =
        qg_over_ffgg_fft.coefficients.iter().map(|c| (c * c.conj()).re).sum::<f64>() / (n as f64);

    let gamma2 = norm_f_over_ffgg_squared + norm_g_over_ffgg_squared;

    f64::max(gamma1, gamma2)
}

/// Reduces the vector (F, G) relative to (f, g). This method follows the python implementation [1].
/// Note that this algorithm can end up in an infinite loop. (It's one of the things the author
/// would like to fix.) When this happens, the function returns an error (hence the return type)
/// and the caller generates another keypair with fresh randomness.
///
/// Algorithm 7 in the spec [2, p.35]
///
/// [1]: https://github.com/tprest/falcon.py
///
/// [2]: https://falcon-sign.info/falcon.pdf
fn babai_reduce(
    f: &Polynomial<BigInt>,
    g: &Polynomial<BigInt>,
    capital_f: &mut Polynomial<BigInt>,
    capital_g: &mut Polynomial<BigInt>,
) -> Result<(), String> {
    // round the bit size up to the next multiple of 8
    let bitsize = |bi: &BigInt| (bi.bits() + 7) & (u64::MAX ^ 7);
    let n = f.coefficients.len();
    let size = [
        f.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
        g.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
        53,
    ]
    .into_iter()
    .max()
    .unwrap();
    let shift = (size as i64) - 53;
    let f_adjusted = f
        .map(|bi| Complex64::new(i64::try_from(bi >> shift).unwrap() as f64, 0.0))
        .fft();
    let g_adjusted = g
        .map(|bi| Complex64::new(i64::try_from(bi >> shift).unwrap() as f64, 0.0))
        .fft();

    let f_star_adjusted = f_adjusted.map(|c| c.conj());
    let g_star_adjusted = g_adjusted.map(|c| c.conj());
    let denominator_fft =
        f_adjusted.hadamard_mul(&f_star_adjusted) + g_adjusted.hadamard_mul(&g_star_adjusted);

    let mut counter = 0;
    loop {
        let capital_size = [
            capital_f.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
            capital_g.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
            53,
        ]
        .into_iter()
        .max()
        .unwrap();

        if capital_size < size {
            break;
        }
        let capital_shift = (capital_size as i64) - 53;
        let capital_f_adjusted = capital_f
            .map(|bi| Complex64::new(i64::try_from(bi >> capital_shift).unwrap() as f64, 0.0))
            .fft();
        let capital_g_adjusted = capital_g
            .map(|bi| Complex64::new(i64::try_from(bi >> capital_shift).unwrap() as f64, 0.0))
            .fft();

        let numerator = capital_f_adjusted.hadamard_mul(&f_star_adjusted)
            + capital_g_adjusted.hadamard_mul(&g_star_adjusted);
        let quotient = numerator.hadamard_div(&denominator_fft).ifft();

        let k = quotient.map(|f| Into::<BigInt>::into(f.re.round() as i64));

        if k.is_zero() {
            break;
        }
        let kf = (k.clone().karatsuba(f))
            .reduce_by_cyclotomic(n)
            .map(|bi| bi << (capital_size - size));
        let kg = (k.clone().karatsuba(g))
            .reduce_by_cyclotomic(n)
            .map(|bi| bi << (capital_size - size));
        *capital_f -= kf;
        *capital_g -= kg;

        counter += 1;
        if counter > 1000 {
            // If we get here, that means that (with high likelihood) we are in an
            // infinite loop. We know it happens from time to time -- seldom, but it
            // does. It would be nice to fix that! But in order to fix it we need to be
            // able to reproduce it, and for that we need test vectors. So print them
            // and hope that one day they circle back to the implementor.
            return Err(format!(
                "Encountered infinite loop in babai_reduce of falcon-rust.\n\
                Please help the developer(s) fix it! You can do this by sending them \
                the inputs to the function that caused the behavior:\n\
                f: {:?}\n\
                g: {:?}\n\
                capital_f: {:?}\n\
                capital_g: {:?}\n",
                f.coefficients, g.coefficients, capital_f.coefficients, capital_g.coefficients
            ));
        }
    }
    Ok(())
}

/// Extended Euclidean algorithm for computing the greatest common divisor (g) and
/// Bézout coefficients (u, v) for the relation
///
/// $$ u a + v b = g . $$
///
/// Implementation adapted from Wikipedia [1].
///
/// [1]: https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Pseudocode
fn xgcd(a: &BigInt, b: &BigInt) -> (BigInt, BigInt, BigInt) {
    let (mut old_r, mut r) = (a.clone(), b.clone());
    let (mut old_s, mut s) = (BigInt::one(), BigInt::zero());
    let (mut old_t, mut t) = (BigInt::zero(), BigInt::one());

    while r != BigInt::zero() {
        let quotient = old_r.clone() / r.clone();
        (old_r, r) = (r.clone(), old_r.clone() - quotient.clone() * r);
        (old_s, s) = (s.clone(), old_s.clone() - quotient.clone() * s);
        (old_t, t) = (t.clone(), old_t.clone() - quotient * t);
    }

    (old_r, old_s, old_t)
}
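
// A small worked example of the identity u*a + v*b = g returned by `xgcd` (sketch):
//
// ```
// use num::{BigInt, FromPrimitive};
//
// let a = BigInt::from_i32(240).unwrap();
// let b = BigInt::from_i32(46).unwrap();
// let (g, u, v) = xgcd(&a, &b);
// assert_eq!(g, BigInt::from_i32(2).unwrap());              // gcd(240, 46) = 2
// assert_eq!(u * a + v * b, BigInt::from_i32(2).unwrap());  // -9 * 240 + 47 * 46 = 2
// ```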
616	src/dsa/rpo_falcon512/math/polynomial.rs	Normal file
@@ -0,0 +1,616 @@
use super::{field::FalconFelt, Inverse};
use crate::dsa::rpo_falcon512::{MODULUS, N};
use crate::Felt;
use alloc::vec::Vec;
use core::default::Default;
use core::fmt::Debug;
use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign};
use num::{One, Zero};

#[derive(Debug, Clone, Default)]
pub struct Polynomial<F> {
    pub coefficients: Vec<F>,
}

impl<F> Polynomial<F>
where
    F: Clone,
{
    pub fn new(coefficients: Vec<F>) -> Self {
        Self { coefficients }
    }
}

impl<
        F: Mul<Output = F> + Sub<Output = F> + AddAssign + Zero + Div<Output = F> + Clone + Inverse,
    > Polynomial<F>
{
    pub fn hadamard_mul(&self, other: &Self) -> Self {
        Polynomial::new(
            self.coefficients
                .iter()
                .zip(other.coefficients.iter())
                .map(|(a, b)| *a * *b)
                .collect(),
        )
    }
    pub fn hadamard_div(&self, other: &Self) -> Self {
        let other_coefficients_inverse = F::batch_inverse_or_zero(&other.coefficients);
        Polynomial::new(
            self.coefficients
                .iter()
                .zip(other_coefficients_inverse.iter())
                .map(|(a, b)| *a * *b)
                .collect(),
        )
    }

    pub fn hadamard_inv(&self) -> Self {
        let coefficients_inverse = F::batch_inverse_or_zero(&self.coefficients);
        Polynomial::new(coefficients_inverse)
    }
}

impl<F: Zero + PartialEq + Clone> Polynomial<F> {
    pub fn degree(&self) -> Option<usize> {
        if self.coefficients.is_empty() {
            return None;
        }
        let mut max_index = self.coefficients.len() - 1;
        while self.coefficients[max_index] == F::zero() {
            if let Some(new_index) = max_index.checked_sub(1) {
                max_index = new_index;
            } else {
                return None;
            }
        }
        Some(max_index)
    }

    pub fn lc(&self) -> F {
        match self.degree() {
            Some(non_negative_degree) => self.coefficients[non_negative_degree].clone(),
            None => F::zero(),
        }
    }
}

/// The following implementations are specific to cyclotomic polynomial rings,
/// i.e., F\[ X \] / <X^n + 1>, and are used extensively in Falcon.
impl<
        F: One
            + Zero
            + Clone
            + Neg<Output = F>
            + MulAssign
            + AddAssign
            + Div<Output = F>
            + Sub<Output = F>
            + PartialEq,
    > Polynomial<F>
{
    /// Reduce the polynomial by X^n + 1.
    pub fn reduce_by_cyclotomic(&self, n: usize) -> Self {
        let mut coefficients = vec![F::zero(); n];
        let mut sign = -F::one();
        for (i, c) in self.coefficients.iter().cloned().enumerate() {
            if i % n == 0 {
                sign *= -F::one();
            }
            coefficients[i % n] += sign.clone() * c;
        }
        Polynomial::new(coefficients)
    }

    /// Computes the field norm of the polynomial as an element of the cyclotomic ring
    /// F\[ X \] / <X^n + 1> relative to one of half the size, i.e., F\[ X \] / <X^(n/2) + 1>.
    ///
    /// Corresponds to formula 3.25 in the spec [1, p.30].
    ///
    /// [1]: https://falcon-sign.info/falcon.pdf
    pub fn field_norm(&self) -> Self {
        let n = self.coefficients.len();
        let mut f0_coefficients = vec![F::zero(); n / 2];
        let mut f1_coefficients = vec![F::zero(); n / 2];
        for i in 0..n / 2 {
            f0_coefficients[i] = self.coefficients[2 * i].clone();
            f1_coefficients[i] = self.coefficients[2 * i + 1].clone();
        }
        let f0 = Polynomial::new(f0_coefficients);
        let f1 = Polynomial::new(f1_coefficients);
        let f0_squared = (f0.clone() * f0).reduce_by_cyclotomic(n / 2);
        let f1_squared = (f1.clone() * f1).reduce_by_cyclotomic(n / 2);
        let x = Polynomial::new(vec![F::zero(), F::one()]);
        f0_squared - (x * f1_squared).reduce_by_cyclotomic(n / 2)
    }

    /// Lifts an element from a cyclotomic polynomial ring to one of double the size.
    pub fn lift_next_cyclotomic(&self) -> Self {
        let n = self.coefficients.len();
        let mut coefficients = vec![F::zero(); n * 2];
        for i in 0..n {
            coefficients[2 * i] = self.coefficients[i].clone();
        }
        Self::new(coefficients)
    }

    /// Computes the Galois adjoint of the polynomial in the cyclotomic ring
    /// F\[ X \] / <X^n + 1>, which maps f(x) to f(-x) by negating the odd coefficients.
    pub fn galois_adjoint(&self) -> Self {
        Self::new(
            self.coefficients
                .iter()
                .enumerate()
                .map(|(i, c)| if i % 2 == 0 { c.clone() } else { c.clone().neg() })
                .collect(),
        )
    }
}

impl<F: Clone + Into<f64>> Polynomial<F> {
    pub(crate) fn l2_norm_squared(&self) -> f64 {
        self.coefficients
            .iter()
            .map(|i| Into::<f64>::into(i.clone()))
            .map(|i| i * i)
            .sum::<f64>()
    }
}

impl<F> PartialEq for Polynomial<F>
where
    F: Zero + PartialEq + Clone + AddAssign,
{
    fn eq(&self, other: &Self) -> bool {
        if self.is_zero() && other.is_zero() {
            true
        } else if self.is_zero() || other.is_zero() {
            false
        } else {
            let self_degree = self.degree().unwrap();
            let other_degree = other.degree().unwrap();
            self.coefficients[0..=self_degree] == other.coefficients[0..=other_degree]
        }
    }
}

impl<F> Eq for Polynomial<F> where F: Zero + PartialEq + Clone + AddAssign {}

impl<F> Add for &Polynomial<F>
where
    F: Add<Output = F> + AddAssign + Clone,
{
    type Output = Polynomial<F>;

    fn add(self, rhs: Self) -> Self::Output {
        let coefficients = if self.coefficients.len() >= rhs.coefficients.len() {
            let mut coefficients = self.coefficients.clone();
            for (i, c) in rhs.coefficients.iter().enumerate() {
                coefficients[i] += c.clone();
            }
            coefficients
        } else {
            let mut coefficients = rhs.coefficients.clone();
            for (i, c) in self.coefficients.iter().enumerate() {
                coefficients[i] += c.clone();
            }
            coefficients
        };
        Self::Output { coefficients }
    }
}

impl<F> Add for Polynomial<F>
where
    F: Add<Output = F> + AddAssign + Clone,
{
    type Output = Polynomial<F>;

    fn add(self, rhs: Self) -> Self::Output {
        let coefficients = if self.coefficients.len() >= rhs.coefficients.len() {
            let mut coefficients = self.coefficients.clone();
            for (i, c) in rhs.coefficients.into_iter().enumerate() {
                coefficients[i] += c;
            }
            coefficients
        } else {
            let mut coefficients = rhs.coefficients.clone();
            for (i, c) in self.coefficients.into_iter().enumerate() {
                coefficients[i] += c;
            }
            coefficients
        };
        Self::Output { coefficients }
    }
}

impl<F> AddAssign for Polynomial<F>
where
    F: Add<Output = F> + AddAssign + Clone,
{
    fn add_assign(&mut self, rhs: Self) {
        if self.coefficients.len() >= rhs.coefficients.len() {
            for (i, c) in rhs.coefficients.into_iter().enumerate() {
                self.coefficients[i] += c;
            }
        } else {
            let mut coefficients = rhs.coefficients.clone();
            for (i, c) in self.coefficients.iter().enumerate() {
                coefficients[i] += c.clone();
            }
            self.coefficients = coefficients;
        }
    }
}

impl<F> Sub for &Polynomial<F>
where
    F: Sub<Output = F> + Clone + Neg<Output = F> + Add<Output = F> + AddAssign,
{
    type Output = Polynomial<F>;

    fn sub(self, rhs: Self) -> Self::Output {
        self + &(-rhs)
    }
}

impl<F> Sub for Polynomial<F>
where
    F: Sub<Output = F> + Clone + Neg<Output = F> + Add<Output = F> + AddAssign,
{
    type Output = Polynomial<F>;

    fn sub(self, rhs: Self) -> Self::Output {
        self + (-rhs)
    }
}

impl<F> SubAssign for Polynomial<F>
where
    F: Add<Output = F> + Neg<Output = F> + AddAssign + Clone + Sub<Output = F>,
{
    fn sub_assign(&mut self, rhs: Self) {
        self.coefficients = self.clone().sub(rhs).coefficients;
    }
}

impl<F: Neg<Output = F> + Clone> Neg for &Polynomial<F> {
    type Output = Polynomial<F>;

    fn neg(self) -> Self::Output {
        Self::Output {
            coefficients: self.coefficients.iter().cloned().map(|a| -a).collect(),
        }
    }
}

impl<F: Neg<Output = F> + Clone> Neg for Polynomial<F> {
    type Output = Self;

    fn neg(self) -> Self::Output {
        Self::Output {
            coefficients: self.coefficients.iter().cloned().map(|a| -a).collect(),
        }
    }
}

impl<F> Mul for &Polynomial<F>
where
    F: Add + AddAssign + Mul<Output = F> + Sub<Output = F> + Zero + PartialEq + Clone,
{
    type Output = Polynomial<F>;

    fn mul(self, other: Self) -> Self::Output {
        if self.is_zero() || other.is_zero() {
            return Polynomial::<F>::zero();
        }
        let mut coefficients =
            vec![F::zero(); self.coefficients.len() + other.coefficients.len() - 1];
        for i in 0..self.coefficients.len() {
            for j in 0..other.coefficients.len() {
                coefficients[i + j] += self.coefficients[i].clone() * other.coefficients[j].clone();
            }
        }
        Polynomial { coefficients }
    }
}

impl<F> Mul for Polynomial<F>
where
    F: Add + AddAssign + Mul<Output = F> + Zero + PartialEq + Clone,
{
    type Output = Self;

    fn mul(self, other: Self) -> Self::Output {
        if self.is_zero() || other.is_zero() {
            return Self::zero();
        }
        let mut coefficients =
            vec![F::zero(); self.coefficients.len() + other.coefficients.len() - 1];
        for i in 0..self.coefficients.len() {
            for j in 0..other.coefficients.len() {
                coefficients[i + j] += self.coefficients[i].clone() * other.coefficients[j].clone();
            }
        }
        Self { coefficients }
    }
}

impl<F: Add + Mul<Output = F> + Zero + Clone> Mul<F> for &Polynomial<F> {
    type Output = Polynomial<F>;

    fn mul(self, other: F) -> Self::Output {
        Polynomial {
            coefficients: self.coefficients.iter().cloned().map(|i| i * other.clone()).collect(),
        }
    }
}

impl<F: Add + Mul<Output = F> + Zero + Clone> Mul<F> for Polynomial<F> {
    type Output = Polynomial<F>;

    fn mul(self, other: F) -> Self::Output {
        Polynomial {
            coefficients: self.coefficients.iter().cloned().map(|i| i * other.clone()).collect(),
        }
    }
}

impl<F: Mul<Output = F> + Sub<Output = F> + AddAssign + Zero + Div<Output = F> + Clone>
    Polynomial<F>
{
    /// Multiply two polynomials using Karatsuba's divide-and-conquer algorithm.
    pub fn karatsuba(&self, other: &Self) -> Self {
        Polynomial::new(vector_karatsuba(&self.coefficients, &other.coefficients))
    }
}
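
// Karatsuba trades one of the four half-size products for a few additions: with
// f = f_lo + x^(n/2) f_hi and g = g_lo + x^(n/2) g_hi, the middle term is recovered as
// (f_lo + f_hi)(g_lo + g_hi) - f_lo g_lo - f_hi g_hi. A small sketch checking it against
// schoolbook multiplication:
//
// ```
// let f = Polynomial::new(vec![FalconFelt::new(1), FalconFelt::new(2)]);
// let g = Polynomial::new(vec![FalconFelt::new(3), FalconFelt::new(4)]);
// // (1 + 2x)(3 + 4x) = 3 + 10x + 8x^2
// assert_eq!(f.karatsuba(&g), f.clone() * g);
// ```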

impl<F> One for Polynomial<F>
where
    F: Clone + One + PartialEq + Zero + AddAssign,
{
    fn one() -> Self {
        Self { coefficients: vec![F::one()] }
    }
}

impl<F> Zero for Polynomial<F>
where
    F: Zero + PartialEq + Clone + AddAssign,
{
    fn zero() -> Self {
        Self { coefficients: vec![] }
    }

    fn is_zero(&self) -> bool {
        self.degree().is_none()
    }
}

impl<F: Zero + Clone> Polynomial<F> {
    pub fn shift(&self, shamt: usize) -> Self {
        Self {
            coefficients: [vec![F::zero(); shamt], self.coefficients.clone()].concat(),
        }
    }

    pub fn constant(f: F) -> Self {
        Self { coefficients: vec![f] }
    }

    pub fn map<G: Clone, C: FnMut(&F) -> G>(&self, closure: C) -> Polynomial<G> {
        Polynomial::<G>::new(self.coefficients.iter().map(closure).collect())
    }

    pub fn fold<G, C: FnMut(G, &F) -> G + Clone>(&self, mut initial_value: G, closure: C) -> G {
        for c in self.coefficients.iter() {
            initial_value = (closure.clone())(initial_value, c);
        }
        initial_value
    }
}

impl<F> Div<Polynomial<F>> for Polynomial<F>
where
    F: Zero
        + One
        + PartialEq
        + AddAssign
        + Clone
        + Mul<Output = F>
        + MulAssign
        + Div<Output = F>
        + Neg<Output = F>
        + Sub<Output = F>,
{
    type Output = Polynomial<F>;

    fn div(self, denominator: Self) -> Self::Output {
        if denominator.is_zero() {
            panic!();
        }
        if self.is_zero() {
            return Self::zero();
        }
        let mut remainder = self.clone();
        let mut quotient = Polynomial::<F>::zero();
        while remainder.degree().unwrap() >= denominator.degree().unwrap() {
            let shift = remainder.degree().unwrap() - denominator.degree().unwrap();
            let quotient_coefficient = remainder.lc() / denominator.lc();
            let monomial = Self::constant(quotient_coefficient).shift(shift);
            quotient += monomial.clone();
            remainder -= monomial * denominator.clone();
            if remainder.is_zero() {
                break;
            }
        }
        quotient
    }
}

fn vector_karatsuba<
    F: Zero + AddAssign + Mul<Output = F> + Sub<Output = F> + Div<Output = F> + Clone,
>(
    left: &[F],
    right: &[F],
) -> Vec<F> {
    let n = left.len();
    if n <= 8 {
        let mut product = vec![F::zero(); left.len() + right.len() - 1];
        for (i, l) in left.iter().enumerate() {
            for (j, r) in right.iter().enumerate() {
                product[i + j] += l.clone() * r.clone();
            }
        }
        return product;
    }
    let n_over_2 = n / 2;
    let mut product = vec![F::zero(); 2 * n - 1];
    let left_lo = &left[0..n_over_2];
    let right_lo = &right[0..n_over_2];
    let left_hi = &left[n_over_2..];
    let right_hi = &right[n_over_2..];
    let left_sum: Vec<F> =
        left_lo.iter().zip(left_hi).map(|(a, b)| a.clone() + b.clone()).collect();
    let right_sum: Vec<F> =
        right_lo.iter().zip(right_hi).map(|(a, b)| a.clone() + b.clone()).collect();

    let prod_lo = vector_karatsuba(left_lo, right_lo);
    let prod_hi = vector_karatsuba(left_hi, right_hi);
    let prod_mid: Vec<F> = vector_karatsuba(&left_sum, &right_sum)
        .iter()
        .zip(prod_lo.iter().zip(prod_hi.iter()))
        .map(|(s, (l, h))| s.clone() - (l.clone() + h.clone()))
        .collect();

    for (i, l) in prod_lo.into_iter().enumerate() {
        product[i] = l;
    }
    for (i, m) in prod_mid.into_iter().enumerate() {
        product[i + n_over_2] += m;
    }
    for (i, h) in prod_hi.into_iter().enumerate() {
        product[i + n] += h;
    }
    product
}

impl From<Polynomial<FalconFelt>> for Polynomial<Felt> {
    fn from(item: Polynomial<FalconFelt>) -> Self {
        let res: Vec<Felt> =
            item.coefficients.iter().map(|a| Felt::from(a.value() as u16)).collect();
        Polynomial::new(res)
    }
}

impl From<&Polynomial<FalconFelt>> for Polynomial<Felt> {
    fn from(item: &Polynomial<FalconFelt>) -> Self {
        let res: Vec<Felt> =
            item.coefficients.iter().map(|a| Felt::from(a.value() as u16)).collect();
        Polynomial::new(res)
    }
}

impl From<Polynomial<i16>> for Polynomial<FalconFelt> {
    fn from(item: Polynomial<i16>) -> Self {
        let res: Vec<FalconFelt> = item.coefficients.iter().map(|&a| FalconFelt::new(a)).collect();
        Polynomial::new(res)
    }
}

impl From<&Polynomial<i16>> for Polynomial<FalconFelt> {
    fn from(item: &Polynomial<i16>) -> Self {
        let res: Vec<FalconFelt> = item.coefficients.iter().map(|&a| FalconFelt::new(a)).collect();
        Polynomial::new(res)
    }
}

impl From<Vec<i16>> for Polynomial<FalconFelt> {
    fn from(item: Vec<i16>) -> Self {
        let res: Vec<FalconFelt> = item.iter().map(|&a| FalconFelt::new(a)).collect();
        Polynomial::new(res)
    }
}

impl From<&Vec<i16>> for Polynomial<FalconFelt> {
    fn from(item: &Vec<i16>) -> Self {
        let res: Vec<FalconFelt> = item.iter().map(|&a| FalconFelt::new(a)).collect();
        Polynomial::new(res)
    }
}

impl Polynomial<FalconFelt> {
    pub fn norm_squared(&self) -> u64 {
        self.coefficients
            .iter()
            .map(|&i| i.balanced_value() as i64)
            .map(|i| (i * i) as u64)
            .sum::<u64>()
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the coefficients of this polynomial as field elements.
    pub fn to_elements(&self) -> Vec<Felt> {
        self.coefficients.iter().map(|&a| Felt::from(a.value() as u16)).collect()
    }

    // POLYNOMIAL OPERATIONS
    // --------------------------------------------------------------------------------------------

    /// Multiplies two polynomials over Z_p\[x\] without reducing modulo p. Given that the degrees
    /// of the input polynomials are less than 512 and their coefficients are less than the modulus
    /// q equal to 12289, the resulting product polynomial is guaranteed to have coefficients less
    /// than the Miden prime.
    ///
    /// Note that this multiplication is not over Z_p\[x\]/(phi).
    pub fn mul_modulo_p(a: &Self, b: &Self) -> [u64; 1024] {
        let mut c = [0; 2 * N];
        for i in 0..N {
            for j in 0..N {
                c[i + j] += a.coefficients[i].value() as u64 * b.coefficients[j].value() as u64;
            }
        }

        c
    }

    /// Reduces a polynomial, that is the product of two polynomials over Z_p\[x\], modulo
    /// the irreducible polynomial phi. This results in an element in Z_p\[x\]/(phi).
    pub fn reduce_negacyclic(a: &[u64; 1024]) -> Self {
        let mut c = [FalconFelt::zero(); N];
        let modulus = MODULUS as u16;
        for i in 0..N {
            let ai = a[N + i] % modulus as u64;
            let neg_ai = (modulus - ai as u16) % modulus;

            let bi = (a[i] % modulus as u64) as u16;
            c[i] = FalconFelt::new(((neg_ai + bi) % modulus) as i16);
        }

        Self::new(c.to_vec())
    }
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::{FalconFelt, Polynomial, N};

    #[test]
    fn test_negacyclic_reduction() {
        let coef1: [u8; N] = rand_utils::rand_array();
        let coef2: [u8; N] = rand_utils::rand_array();

        let poly1 = Polynomial::new(coef1.iter().map(|&a| FalconFelt::new(a as i16)).collect());
        let poly2 = Polynomial::new(coef2.iter().map(|&a| FalconFelt::new(a as i16)).collect());
        let prod = poly1.clone() * poly2.clone();

        assert_eq!(
            prod.reduce_by_cyclotomic(N),
            Polynomial::reduce_negacyclic(&Polynomial::mul_modulo_p(&poly1, &poly2))
        );
    }
}
298	src/dsa/rpo_falcon512/math/samplerz.rs	Normal file
@@ -0,0 +1,298 @@
use core::f64::consts::LN_2;
use rand::Rng;

#[cfg(not(feature = "std"))]
use num::Float;

/// Samples an integer from {0, ..., 18} according to the distribution χ, which is close to
/// the half-Gaussian distribution on the natural numbers with mean 0 and standard deviation
/// equal to sigma_max.
fn base_sampler(bytes: [u8; 9]) -> i16 {
    const RCDT: [u128; 18] = [
        3024686241123004913666,
        1564742784480091954050,
        636254429462080897535,
        199560484645026482916,
        47667343854657281903,
        8595902006365044063,
        1163297957344668388,
        117656387352093658,
        8867391802663976,
        496969357462633,
        20680885154299,
        638331848991,
        14602316184,
        247426747,
        3104126,
        28824,
        198,
        1,
    ];
    let u = u128::from_be_bytes([vec![0u8; 7], bytes.to_vec()].concat().try_into().unwrap());
    RCDT.into_iter().filter(|r| u < *r).count() as i16
}

/// Computes an integer approximation of 2^63 * ccs * exp(-x).
fn approx_exp(x: f64, ccs: f64) -> u64 {
    // The constants C are used to approximate exp(-x); these
    // constants are taken from FACCT (up to a scaling factor
    // of 2^63):
    // https://eprint.iacr.org/2018/1234
    // https://github.com/raykzhao/gaussian
    const C: [u64; 13] = [
        0x00000004741183A3u64,
        0x00000036548CFC06u64,
        0x0000024FDCBF140Au64,
        0x0000171D939DE045u64,
        0x0000D00CF58F6F84u64,
        0x000680681CF796E3u64,
        0x002D82D8305B0FEAu64,
        0x011111110E066FD0u64,
        0x0555555555070F00u64,
        0x155555555581FF00u64,
        0x400000000002B400u64,
        0x7FFFFFFFFFFF4800u64,
        0x8000000000000000u64,
    ];

    let mut z: u64;
    let mut y: u64;
    let twoe63 = 1u64 << 63;

    y = C[0];
    z = f64::floor(x * (twoe63 as f64)) as u64;
    for cu in C.iter().skip(1) {
        let zy = (z as u128) * (y as u128);
        y = cu - ((zy >> 63) as u64);
    }

    z = f64::floor((twoe63 as f64) * ccs) as u64;

    (((z as u128) * (y as u128)) >> 63) as u64
}

/// Returns a random bool that is true with probability ≈ ccs · exp(-x).
fn ber_exp(x: f64, ccs: f64, random_bytes: [u8; 7]) -> bool {
    // 0.69314718055994530941 = ln(2)
    let s = f64::floor(x / LN_2) as usize;
    let r = x - LN_2 * (s as f64);
    let shamt = usize::min(s, 63);
    let z = ((((approx_exp(r, ccs) as u128) << 1) - 1) >> shamt) as u64;
    let mut w = 0i16;
    for (index, i) in (0..64).step_by(8).rev().enumerate() {
        let byte = random_bytes[index];
        w = (byte as i16) - (((z >> i) & 0xff) as i16);
        if w != 0 {
            break;
        }
    }
    w < 0
}

/// Samples an integer from the Gaussian distribution with given mean (mu) and standard deviation
/// (sigma).
pub(crate) fn sampler_z<R: Rng>(mu: f64, sigma: f64, sigma_min: f64, rng: &mut R) -> i16 {
    const SIGMA_MAX: f64 = 1.8205;
    const INV_2SIGMA_MAX_SQ: f64 = 1f64 / (2f64 * SIGMA_MAX * SIGMA_MAX);
    let isigma = 1f64 / sigma;
    let dss = 0.5f64 * isigma * isigma;
    let s = f64::floor(mu);
    let r = mu - s;
    let ccs = sigma_min * isigma;
    loop {
        let z0 = base_sampler(rng.gen());
        let random_byte: u8 = rng.gen();
        let b = (random_byte & 1) as i16;
        let z = b + ((b << 1) - 1) * z0;
        let zf_min_r = (z as f64) - r;
        // x = ((z - r)^2) / (2 * sigma^2) - ((z - b)^2) / (2 * sigma_max^2)
        let x = zf_min_r * zf_min_r * dss - (z0 * z0) as f64 * INV_2SIGMA_MAX_SQ;
        if ber_exp(x, ccs, rng.gen()) {
            return z + (s as i16);
        }
    }
}
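
// Usage sketch (illustrative only): sample around a non-integral center, with sigma_min no
// larger than sigma as the function expects. The returned value is an integer distributed
// as a discrete Gaussian centered at mu.
//
// ```
// use rand::{rngs::StdRng, SeedableRng};
//
// let mut rng = StdRng::from_entropy();
// let z = sampler_z(1.7, 1.43, 1.2778336969128337, &mut rng);
// ```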
|
||||
|
||||
#[cfg(all(test, feature = "std"))]
|
||||
mod test {
|
||||
use alloc::vec::Vec;
|
||||
use rand::RngCore;
|
||||
use std::{thread::sleep, time::Duration};
|
||||
|
||||
use super::{approx_exp, ber_exp, sampler_z};
|
||||
|
||||
/// RNG used only for testing purposes: the stream of random bytes it
/// produces is exactly the byte string it was initialized with.
/// Whatever you do, do not use this RNG in production.
|
||||
struct UnsafeBufferRng {
|
||||
buffer: Vec<u8>,
|
||||
index: usize,
|
||||
}
|
||||
|
||||
impl UnsafeBufferRng {
|
||||
fn new(buffer: &[u8]) -> Self {
|
||||
Self { buffer: buffer.to_vec(), index: 0 }
|
||||
}
|
||||
|
||||
fn next(&mut self) -> u8 {
|
||||
if self.buffer.len() <= self.index {
|
||||
// panic!("Ran out of buffer.");
|
||||
sleep(Duration::from_millis(10));
|
||||
0u8
|
||||
} else {
|
||||
let return_value = self.buffer[self.index];
|
||||
self.index += 1;
|
||||
return_value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RngCore for UnsafeBufferRng {
|
||||
fn next_u32(&mut self) -> u32 {
|
||||
// let bytes: [u8; 4] = (0..4)
|
||||
// .map(|_| self.next())
|
||||
// .collect_vec()
|
||||
// .try_into()
|
||||
// .unwrap();
|
||||
// u32::from_be_bytes(bytes)
|
||||
u32::from_le_bytes([self.next(), 0, 0, 0])
|
||||
}
|
||||
|
||||
fn next_u64(&mut self) -> u64 {
|
||||
// let bytes: [u8; 8] = (0..8)
|
||||
// .map(|_| self.next())
|
||||
// .collect_vec()
|
||||
// .try_into()
|
||||
// .unwrap();
|
||||
// u64::from_be_bytes(bytes)
|
||||
u64::from_le_bytes([self.next(), 0, 0, 0, 0, 0, 0, 0])
|
||||
}
|
||||
|
||||
fn fill_bytes(&mut self, dest: &mut [u8]) {
|
||||
for d in dest.iter_mut() {
|
||||
*d = self.next();
|
||||
}
|
||||
}
|
||||
|
||||
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
|
||||
for d in dest.iter_mut() {
|
||||
*d = self.next();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unsafe_buffer_rng() {
|
||||
let seed_bytes = hex::decode("7FFECD162AE2").unwrap();
|
||||
let mut rng = UnsafeBufferRng::new(&seed_bytes);
|
||||
let generated_bytes: Vec<u8> = (0..seed_bytes.len()).map(|_| rng.next()).collect();
|
||||
assert_eq!(seed_bytes, generated_bytes);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_approx_exp() {
|
||||
let precision = 1u64 << 14;
|
||||
// known answers were generated with the following sage script:
|
||||
//```sage
|
||||
// num_samples = 10
|
||||
// precision = 200
|
||||
// R = Reals(precision)
|
||||
//
|
||||
// print(f"let kats : [(f64, f64, u64);{num_samples}] = [")
|
||||
// for i in range(num_samples):
|
||||
// x = RDF.random_element(0.0, 0.693147180559945)
|
||||
// ccs = RDF.random_element(0.0, 1.0)
|
||||
// res = round(2^63 * R(ccs) * exp(R(-x)))
|
||||
// print(f"({x}, {ccs}, {res}),")
|
||||
// print("];")
|
||||
// ```
|
||||
let kats: [(f64, f64, u64); 10] = [
|
||||
(0.2314993926072656, 0.8148006314615972, 5962140072160879737),
|
||||
(0.2648875572812225, 0.12769669655309035, 903712282351034505),
|
||||
(0.11251957513682391, 0.9264611470305881, 7635725498677341553),
|
||||
(0.04353439307256617, 0.5306497137523327, 4685877322232397936),
|
||||
(0.41834495299784347, 0.879438856118578, 5338392138535350986),
|
||||
(0.32579398973228557, 0.16513412873289002, 1099603299296456803),
|
||||
(0.5939508073919817, 0.029776019144967303, 151637565622779016),
|
||||
(0.2932367999399056, 0.37123847662857923, 2553827649386670452),
|
||||
(0.5005699297417507, 0.31447208863888976, 1758235618083658825),
|
||||
(0.4876437338498085, 0.6159515298936868, 3488632981903743976),
|
||||
];
|
||||
for (x, ccs, answer) in kats {
|
||||
let difference = (answer as i128) - (approx_exp(x, ccs) as i128);
|
||||
assert!(
|
||||
(difference * difference) as u64 <= precision * precision,
|
||||
"answer: {answer} versus approximation: {}\ndifference: {} whereas precision: {}",
|
||||
approx_exp(x, ccs),
|
||||
difference,
|
||||
precision
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ber_exp() {
|
||||
let kats = [
|
||||
(
|
||||
1.268_314_048_020_498_4,
|
||||
0.749_990_853_267_664_9,
|
||||
hex::decode("ea000000000000").unwrap(),
|
||||
false,
|
||||
),
|
||||
(
|
||||
0.001_563_917_959_143_409_6,
|
||||
0.749_990_853_267_664_9,
|
||||
hex::decode("6c000000000000").unwrap(),
|
||||
true,
|
||||
),
|
||||
(
|
||||
0.017_921_215_753_999_235,
|
||||
0.749_990_853_267_664_9,
|
||||
hex::decode("c2000000000000").unwrap(),
|
||||
false,
|
||||
),
|
||||
(
|
||||
0.776_117_648_844_980_6,
|
||||
0.751_181_554_542_520_8,
|
||||
hex::decode("58000000000000").unwrap(),
|
||||
true,
|
||||
),
|
||||
];
|
||||
for (x, ccs, bytes, answer) in kats {
|
||||
assert_eq!(answer, ber_exp(x, ccs, bytes.try_into().unwrap()));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sampler_z() {
|
||||
let sigma_min = 1.277833697;
|
||||
// known answers from the doc, table 3.2, page 44
|
||||
// https://falcon-sign.info/falcon.pdf
|
||||
// The zeros were added to account for dropped bytes.
|
||||
let kats = [
|
||||
(-91.90471153063714, 1.7037990414754918, hex::decode("0fc5442ff043d66e91d1ea000000000000cac64ea5450a22941edc6c").unwrap(), -92),
|
||||
(-8.322564895434937, 1.7037990414754918, hex::decode("f4da0f8d8444d1a77265c2000000000000ef6f98bbbb4bee7db8d9b3").unwrap(), -8),
|
||||
(-19.096516109216804, 1.7035823083824078, hex::decode("db47f6d7fb9b19f25c36d6000000000000b9334d477a8bc0be68145d").unwrap(), -20),
|
||||
(-11.335543982423326, 1.7035823083824078, hex::decode("ae41b4f5209665c74d00dc000000000000c1a8168a7bb516b3190cb42c1ded26cd52000000000000aed770eca7dd334e0547bcc3c163ce0b").unwrap(), -12),
|
||||
(7.9386734193997555, 1.6984647769450156, hex::decode("31054166c1012780c603ae0000000000009b833cec73f2f41ca5807c000000000000c89c92158834632f9b1555").unwrap(), 8),
|
||||
(-28.990850086867255, 1.6984647769450156, hex::decode("737e9d68a50a06dbbc6477").unwrap(), -30),
|
||||
(-9.071257914091655, 1.6980782114808988, hex::decode("a98ddd14bf0bf22061d632").unwrap(), -10),
|
||||
(-43.88754568839566, 1.6980782114808988, hex::decode("3cbf6818a68f7ab9991514").unwrap(), -41),
|
||||
(-58.17435547946095, 1.7010983419195522, hex::decode("6f8633f5bfa5d26848668e0000000000003d5ddd46958e97630410587c").unwrap(), -61),
|
||||
(-43.58664906684732, 1.7010983419195522, hex::decode("272bc6c25f5c5ee53f83c40000000000003a361fbc7cc91dc783e20a").unwrap(), -46),
|
||||
(-34.70565203313315, 1.7009387219711465, hex::decode("45443c59574c2c3b07e2e1000000000000d9071e6d133dbe32754b0a").unwrap(), -34),
|
||||
(-44.36009577368896, 1.7009387219711465, hex::decode("6ac116ed60c258e2cbaeab000000000000728c4823e6da36e18d08da0000000000005d0cc104e21cc7fd1f5ca8000000000000d9dbb675266c928448059e").unwrap(), -44),
|
||||
(-21.783037079346236, 1.6958406126012802, hex::decode("68163bc1e2cbf3e18e7426").unwrap(), -23),
|
||||
(-39.68827784633828, 1.6958406126012802, hex::decode("d6a1b51d76222a705a0259").unwrap(), -40),
|
||||
(-18.488607061056847, 1.6955259305261838, hex::decode("f0523bfaa8a394bf4ea5c10000000000000f842366fde286d6a30803").unwrap(), -22),
|
||||
(-48.39610939101591, 1.6955259305261838, hex::decode("87bd87e63374cee62127fc0000000000006931104aab64f136a0485b").unwrap(), -50),
|
||||
];
|
||||
for (mu, sigma, random_bytes, answer) in kats {
|
||||
assert_eq!(
|
||||
sampler_z(mu, sigma, sigma_min, &mut UnsafeBufferRng::new(&random_bytes)),
|
||||
answer
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,60 +1,103 @@
|
||||
use crate::{
|
||||
hash::rpo::Rpo256,
|
||||
utils::{
|
||||
collections::Vec, ByteReader, ByteWriter, Deserializable, DeserializationError,
|
||||
Serializable,
|
||||
},
|
||||
Felt, StarkField, Word, ZERO,
|
||||
utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
|
||||
Felt, Word, ZERO,
|
||||
};
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
mod ffi;
|
||||
|
||||
mod error;
|
||||
mod hash_to_point;
|
||||
mod keys;
|
||||
mod polynomial;
|
||||
mod math;
|
||||
mod signature;
|
||||
|
||||
pub use error::FalconError;
|
||||
pub use keys::{KeyPair, PublicKey};
|
||||
pub use polynomial::Polynomial;
|
||||
pub use signature::Signature;
|
||||
pub use self::keys::{PubKeyPoly, PublicKey, SecretKey};
|
||||
pub use self::math::Polynomial;
|
||||
pub use self::signature::{Signature, SignatureHeader, SignaturePoly};
|
||||
|
||||
// CONSTANTS
|
||||
// ================================================================================================
|
||||
|
||||
// The Falcon modulus.
|
||||
const MODULUS: u16 = 12289;
|
||||
const MODULUS_MINUS_1_OVER_TWO: u16 = 6144;
|
||||
// The Falcon modulus p.
|
||||
const MODULUS: i16 = 12289;
|
||||
|
||||
// Number of bits needed to encode an element in the Falcon field.
|
||||
const FALCON_ENCODING_BITS: u32 = 14;
|
||||
|
||||
// The Falcon parameters for Falcon-512. This is the degree of the polynomial `phi := x^N + 1`
|
||||
// defining the ring Z_p[x]/(phi).
|
||||
const N: usize = 512;
|
||||
const LOG_N: usize = 9;
|
||||
const LOG_N: u8 = 9;
|
||||
|
||||
/// Length of nonce used for key-pair generation.
|
||||
const NONCE_LEN: usize = 40;
|
||||
const SIG_NONCE_LEN: usize = 40;
|
||||
|
||||
/// Number of filed elements used to encode a nonce.
|
||||
const NONCE_ELEMENTS: usize = 8;
|
||||
|
||||
/// Public key length as a u8 vector.
|
||||
const PK_LEN: usize = 897;
|
||||
pub const PK_LEN: usize = 897;
|
||||
|
||||
/// Secret key length as a u8 vector.
|
||||
const SK_LEN: usize = 1281;
|
||||
pub const SK_LEN: usize = 1281;
|
||||
|
||||
/// Signature length as a u8 vector.
|
||||
const SIG_LEN: usize = 626;
|
||||
const SIG_POLY_BYTE_LEN: usize = 625;
|
||||
|
||||
/// Bound on the squared-norm of the signature.
|
||||
const SIG_L2_BOUND: u64 = 34034726;
|
||||
|
||||
/// Standard deviation of the Gaussian over the lattice.
|
||||
const SIGMA: f64 = 165.7366171829776;
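
// Derivation sketch (an assumption based on the Falcon specification, not
// code taken from this crate): the squared-norm bound equals
// floor((1.1 * SIGMA * sqrt(2 * N))^2), which evaluates to 34034726.
#[cfg(test)]
#[test]
fn sig_l2_bound_matches_sigma() {
    let beta = 1.1 * SIGMA * ((2 * N) as f64).sqrt();
    assert_eq!((beta * beta).floor() as u64, SIG_L2_BOUND);
}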
|
||||
|
||||
// TYPE ALIASES
|
||||
// ================================================================================================
|
||||
|
||||
type SignatureBytes = [u8; NONCE_LEN + SIG_LEN];
|
||||
type PublicKeyBytes = [u8; PK_LEN];
|
||||
type SecretKeyBytes = [u8; SK_LEN];
|
||||
type NonceBytes = [u8; NONCE_LEN];
|
||||
type NonceElements = [Felt; NONCE_ELEMENTS];
|
||||
type ShortLatticeBasis = [Polynomial<i16>; 4];
|
||||
|
||||
// NONCE
|
||||
// ================================================================================================
|
||||
|
||||
/// Nonce of the Falcon signature.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Nonce([u8; SIG_NONCE_LEN]);
|
||||
|
||||
impl Nonce {
|
||||
/// Returns a new [Nonce] instantiated from the provided bytes.
|
||||
pub fn new(bytes: [u8; SIG_NONCE_LEN]) -> Self {
|
||||
Self(bytes)
|
||||
}
|
||||
|
||||
/// Returns the underlying bytes of this nonce.
|
||||
pub fn as_bytes(&self) -> &[u8; SIG_NONCE_LEN] {
|
||||
&self.0
|
||||
}
|
||||
|
||||
/// Converts byte representation of the nonce into field element representation.
|
||||
///
|
||||
/// Nonce bytes are converted to field elements by taking consecutive 5 byte chunks
|
||||
/// of the nonce and interpreting them as field elements.
|
||||
pub fn to_elements(&self) -> [Felt; NONCE_ELEMENTS] {
|
||||
let mut buffer = [0_u8; 8];
|
||||
let mut result = [ZERO; 8];
|
||||
for (i, bytes) in self.0.chunks(5).enumerate() {
|
||||
buffer[..5].copy_from_slice(bytes);
|
||||
// we can safely (without overflow) create a new Felt from u64 value here since this
|
||||
// value contains at most 5 bytes
|
||||
result[i] = Felt::new(u64::from_le_bytes(buffer));
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
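
// Note on `to_elements` (illustrative): each 5-byte chunk is at most
// 2^40 - 1, comfortably below the field modulus 2^64 - 2^32 + 1, so the
// `Felt::new` call above never wraps or reduces the value.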
|
||||
|
||||
impl Serializable for &Nonce {
|
||||
fn write_into<W: ByteWriter>(&self, target: &mut W) {
|
||||
target.write_bytes(&self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Deserializable for Nonce {
|
||||
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
|
||||
let bytes = source.read()?;
|
||||
Ok(Self(bytes))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,277 +0,0 @@
|
||||
use super::{FalconError, Felt, Vec, LOG_N, MODULUS, MODULUS_MINUS_1_OVER_TWO, N, PK_LEN};
|
||||
use core::ops::{Add, Mul, Sub};
|
||||
|
||||
// FALCON POLYNOMIAL
|
||||
// ================================================================================================
|
||||
|
||||
/// A polynomial over Z_p[x]/(phi) where phi := x^512 + 1
|
||||
#[derive(Debug, Copy, Clone, PartialEq)]
|
||||
pub struct Polynomial([u16; N]);
|
||||
|
||||
impl Polynomial {
|
||||
// CONSTRUCTORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Constructs a new polynomial from a list of coefficients.
|
||||
///
|
||||
/// # Safety
|
||||
/// This constructor validates that the coefficients are in the valid range only in debug mode.
|
||||
pub unsafe fn new(data: [u16; N]) -> Self {
|
||||
for value in data {
|
||||
debug_assert!(value < MODULUS);
|
||||
}
|
||||
|
||||
Self(data)
|
||||
}
|
||||
|
||||
/// Decodes raw bytes representing a public key into a polynomial in Z_p[x]/(phi).
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// - The provided input is not exactly 897 bytes long.
|
||||
/// - The first byte of the input is not equal to log2(512) i.e., 9.
|
||||
/// - Any of the coefficients encoded in the provided input is greater than or equal to the
|
||||
/// Falcon field modulus.
|
||||
pub fn from_pub_key(input: &[u8]) -> Result<Self, FalconError> {
|
||||
if input.len() != PK_LEN {
|
||||
return Err(FalconError::PubKeyDecodingInvalidLength(input.len()));
|
||||
}
|
||||
|
||||
if input[0] != LOG_N as u8 {
|
||||
return Err(FalconError::PubKeyDecodingInvalidTag(input[0]));
|
||||
}
|
||||
|
||||
let mut acc = 0_u32;
|
||||
let mut acc_len = 0;
|
||||
|
||||
let mut output = [0_u16; N];
|
||||
let mut output_idx = 0;
|
||||
|
||||
for &byte in input.iter().skip(1) {
|
||||
acc = (acc << 8) | (byte as u32);
|
||||
acc_len += 8;
|
||||
|
||||
if acc_len >= 14 {
|
||||
acc_len -= 14;
|
||||
let w = (acc >> acc_len) & 0x3FFF;
|
||||
if w >= MODULUS as u32 {
|
||||
return Err(FalconError::PubKeyDecodingInvalidCoefficient(w));
|
||||
}
|
||||
output[output_idx] = w as u16;
|
||||
output_idx += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (acc & ((1u32 << acc_len) - 1)) == 0 {
|
||||
Ok(Self(output))
|
||||
} else {
|
||||
Err(FalconError::PubKeyDecodingExtraData)
|
||||
}
|
||||
}
|
||||
|
||||
/// Decodes the signature into the coefficients of a polynomial in Z_p[x]/(phi). It assumes
|
||||
/// that the signature has been encoded using the uncompressed format.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// - The signature has been encoded using a different algorithm than the reference compressed
|
||||
/// encoding algorithm.
|
||||
/// - The encoded signature polynomial is in Z_p[x]/(phi') where phi' = x^N' + 1 and N' != 512.
|
||||
/// - While decoding the high bits of a coefficient, the current accumulated value of its
|
||||
/// high bits is larger than 2048.
|
||||
/// - The decoded coefficient is -0.
|
||||
/// - The remaining unused bits in the last byte of `input` are non-zero.
|
||||
pub fn from_signature(input: &[u8]) -> Result<Self, FalconError> {
|
||||
let (encoding, log_n) = (input[0] >> 4, input[0] & 0b00001111);
|
||||
|
||||
if encoding != 0b0011 {
|
||||
return Err(FalconError::SigDecodingIncorrectEncodingAlgorithm);
|
||||
}
|
||||
if log_n != 0b1001 {
|
||||
return Err(FalconError::SigDecodingNotSupportedDegree(log_n));
|
||||
}
|
||||
|
||||
let input = &input[41..];
|
||||
let mut input_idx = 0;
|
||||
let mut acc = 0u32;
|
||||
let mut acc_len = 0;
|
||||
let mut output = [0_u16; N];
|
||||
|
||||
for e in output.iter_mut() {
|
||||
acc = (acc << 8) | (input[input_idx] as u32);
|
||||
input_idx += 1;
|
||||
let b = acc >> acc_len;
|
||||
let s = b & 128;
|
||||
let mut m = b & 127;
|
||||
|
||||
loop {
|
||||
if acc_len == 0 {
|
||||
acc = (acc << 8) | (input[input_idx] as u32);
|
||||
input_idx += 1;
|
||||
acc_len = 8;
|
||||
}
|
||||
acc_len -= 1;
|
||||
if ((acc >> acc_len) & 1) != 0 {
|
||||
break;
|
||||
}
|
||||
m += 128;
|
||||
if m >= 2048 {
|
||||
return Err(FalconError::SigDecodingTooBigHighBits(m));
|
||||
}
|
||||
}
|
||||
if s != 0 && m == 0 {
|
||||
return Err(FalconError::SigDecodingMinusZero);
|
||||
}
|
||||
|
||||
*e = if s != 0 { (MODULUS as u32 - m) as u16 } else { m as u16 };
|
||||
}
|
||||
|
||||
if (acc & ((1 << acc_len) - 1)) != 0 {
|
||||
return Err(FalconError::SigDecodingNonZeroUnusedBitsLastByte);
|
||||
}
|
||||
|
||||
Ok(Self(output))
|
||||
}
|
||||
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns the coefficients of this polynomial as integers.
|
||||
pub fn inner(&self) -> [u16; N] {
|
||||
self.0
|
||||
}
|
||||
|
||||
/// Returns the coefficients of this polynomial as field elements.
|
||||
pub fn to_elements(&self) -> Vec<Felt> {
|
||||
self.0.iter().map(|&a| Felt::from(a)).collect()
|
||||
}
|
||||
|
||||
// POLYNOMIAL OPERATIONS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Multiplies two polynomials over Z_p[x] without reducing modulo p. Given that the degrees
|
||||
/// of the input polynomials are less than 512 and their coefficients are less than the modulus
|
||||
/// q equal to 12289, the resulting product polynomial is guaranteed to have coefficients less
|
||||
/// than the Miden prime.
|
||||
///
|
||||
/// Note that this multiplication is not over Z_p[x]/(phi).
|
||||
pub fn mul_modulo_p(a: &Self, b: &Self) -> [u64; 1024] {
|
||||
let mut c = [0; 2 * N];
|
||||
for i in 0..N {
|
||||
for j in 0..N {
|
||||
c[i + j] += a.0[i] as u64 * b.0[j] as u64;
|
||||
}
|
||||
}
|
||||
|
||||
c
|
||||
}
|
||||
|
||||
/// Reduces a polynomial, that is the product of two polynomials over Z_p[x], modulo
|
||||
/// the irreducible polynomial phi. This results in an element in Z_p[x]/(phi).
|
||||
pub fn reduce_negacyclic(a: &[u64; 1024]) -> Self {
|
||||
let mut c = [0; N];
|
||||
for i in 0..N {
|
||||
let ai = a[N + i] % MODULUS as u64;
|
||||
let neg_ai = (MODULUS - ai as u16) % MODULUS;
|
||||
|
||||
let bi = (a[i] % MODULUS as u64) as u16;
|
||||
c[i] = (neg_ai + bi) % MODULUS;
|
||||
}
|
||||
|
||||
Self(c)
|
||||
}
|
||||
|
||||
/// Computes the norm squared of a polynomial in Z_p[x]/(phi) after normalizing its
|
||||
/// coefficients to be in the interval (-p/2, p/2].
|
||||
pub fn sq_norm(&self) -> u64 {
|
||||
let mut res = 0;
|
||||
for e in self.0 {
|
||||
if e > MODULUS_MINUS_1_OVER_TWO {
|
||||
res += (MODULUS - e) as u64 * (MODULUS - e) as u64
|
||||
} else {
|
||||
res += e as u64 * e as u64
|
||||
}
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a polynomial representing the zero polynomial i.e. default element.
|
||||
impl Default for Polynomial {
|
||||
fn default() -> Self {
|
||||
Self([0_u16; N])
|
||||
}
|
||||
}
|
||||
|
||||
/// Multiplication over Z_p[x]/(phi)
|
||||
impl Mul for Polynomial {
|
||||
type Output = Self;
|
||||
|
||||
fn mul(self, other: Self) -> <Self as Mul<Self>>::Output {
|
||||
let mut result = [0_u16; N];
|
||||
for j in 0..N {
|
||||
for k in 0..N {
|
||||
let i = (j + k) % N;
|
||||
let a = self.0[j] as usize;
|
||||
let b = other.0[k] as usize;
|
||||
let q = MODULUS as usize;
|
||||
let mut prod = a * b % q;
|
||||
if (N - 1) < (j + k) {
|
||||
prod = (q - prod) % q;
|
||||
}
|
||||
result[i] = ((result[i] as usize + prod) % q) as u16;
|
||||
}
|
||||
}
|
||||
|
||||
Polynomial(result)
|
||||
}
|
||||
}
|
||||
|
||||
/// Addition over Z_p[x]/(phi)
|
||||
impl Add for Polynomial {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, other: Self) -> <Self as Add<Self>>::Output {
|
||||
let mut res = self;
|
||||
res.0.iter_mut().zip(other.0.iter()).for_each(|(x, y)| *x = (*x + *y) % MODULUS);
|
||||
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
/// Subtraction over Z_p[x]/(phi)
|
||||
impl Sub for Polynomial {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, other: Self) -> <Self as Add<Self>>::Output {
|
||||
let mut res = self;
|
||||
res.0
|
||||
.iter_mut()
|
||||
.zip(other.0.iter())
|
||||
.for_each(|(x, y)| *x = (*x + MODULUS - *y) % MODULUS);
|
||||
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
// TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{Polynomial, N};
|
||||
|
||||
#[test]
|
||||
fn test_negacyclic_reduction() {
|
||||
let coef1: [u16; N] = rand_utils::rand_array();
|
||||
let coef2: [u16; N] = rand_utils::rand_array();
|
||||
|
||||
let poly1 = Polynomial(coef1);
|
||||
let poly2 = Polynomial(coef2);
|
||||
|
||||
assert_eq!(
|
||||
poly1 * poly2,
|
||||
Polynomial::reduce_negacyclic(&Polynomial::mul_modulo_p(&poly1, &poly2))
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,262 +1,373 @@
|
||||
use alloc::{string::ToString, vec::Vec};
|
||||
use core::ops::Deref;
|
||||
|
||||
use super::{
|
||||
ByteReader, ByteWriter, Deserializable, DeserializationError, NonceBytes, NonceElements,
|
||||
Polynomial, PublicKeyBytes, Rpo256, Serializable, SignatureBytes, StarkField, Word, MODULUS, N,
|
||||
SIG_L2_BOUND, ZERO,
|
||||
hash_to_point::hash_to_point_rpo256,
|
||||
keys::PubKeyPoly,
|
||||
math::{FalconFelt, FastFft, Polynomial},
|
||||
ByteReader, ByteWriter, Deserializable, DeserializationError, Felt, Nonce, Rpo256,
|
||||
Serializable, Word, LOG_N, MODULUS, N, SIG_L2_BOUND, SIG_POLY_BYTE_LEN,
|
||||
};
|
||||
use crate::utils::string::ToString;
|
||||
use num::Zero;
|
||||
|
||||
// FALCON SIGNATURE
|
||||
// ================================================================================================
|
||||
|
||||
/// An RPO Falcon512 signature over a message.
|
||||
///
|
||||
/// The signature is a pair of polynomials (s1, s2) in (Z_p[x]/(phi))^2, where:
|
||||
/// The signature is a pair of polynomials (s1, s2) in (Z_p\[x\]/(phi))^2 a nonce `r`, and a public
|
||||
/// key polynomial `h` where:
|
||||
/// - p := 12289
|
||||
/// - phi := x^512 + 1
|
||||
/// - s1 = c - s2 * h
|
||||
/// - h is a polynomial representing the public key and c is a polynomial that is the hash-to-point
|
||||
/// of the message being signed.
|
||||
///
|
||||
/// The signature verifies if and only if:
|
||||
/// The signature verifies against a public key `pk` if and only if:
|
||||
/// 1. s1 = c - s2 * h
|
||||
/// 2. |s1|^2 + |s2|^2 <= SIG_L2_BOUND
|
||||
///
|
||||
/// where |.| is the norm.
|
||||
/// where |.| is the norm and:
|
||||
/// - c = HashToPoint(r || message)
|
||||
/// - pk = Rpo256::hash(h)
|
||||
///
|
||||
/// [Signature] also includes the extended public key which is serialized as:
|
||||
/// Here h is a polynomial representing the public key and pk is its digest using the Rpo256 hash
|
||||
/// function. c is a polynomial that is the hash-to-point of the message being signed.
|
||||
///
|
||||
/// The polynomial h is serialized as:
|
||||
/// 1. 1 byte representing the log2(512) i.e., 9.
|
||||
/// 2. 896 bytes for the public key. This is decoded into the `h` polynomial above.
|
||||
/// 2. 896 bytes for the public key itself.
|
||||
///
|
||||
/// The actual signature is serialized as:
|
||||
/// The signature is serialized as:
|
||||
/// 1. A header byte specifying the algorithm used to encode the coefficients of the `s2` polynomial
|
||||
/// together with the degree of the irreducible polynomial phi.
|
||||
/// The general format of this byte is 0b0cc1nnnn where:
|
||||
/// a. cc is either 01 when the compressed encoding algorithm is used and 10 when the
|
||||
/// uncompressed algorithm is used.
|
||||
/// b. nnnn is log2(N) where N is the degree of the irreducible polynomial phi.
|
||||
/// The current implementation works always with cc equal to 0b01 and nnnn equal to 0b1001 and
|
||||
/// thus the header byte is always equal to 0b00111001.
|
||||
/// together with the degree of the irreducible polynomial phi. For RPO Falcon512, the header
|
||||
/// byte is set to `10111001` which differentiates it from the standardized instantiation of
|
||||
/// the Falcon signature.
|
||||
/// 2. 40 bytes for the nonce.
|
||||
/// 3. 625 bytes encoding the `s2` polynomial above.
/// 4. 897 bytes encoding the public key polynomial `h` described above.
|
||||
///
|
||||
/// The total size of the signature (including the extended public key) is 1563 bytes.
|
||||
/// The total size of the signature (including the extended public key) is 1563 bytes.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Signature {
|
||||
pub(super) pk: PublicKeyBytes,
|
||||
pub(super) sig: SignatureBytes,
|
||||
header: SignatureHeader,
|
||||
nonce: Nonce,
|
||||
s2: SignaturePoly,
|
||||
h: PubKeyPoly,
|
||||
}
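
// Size check (illustrative): serializing a [Signature] writes
// 1 (header) + 40 (nonce) + 625 (s2) + 897 (h) = 1563 bytes, matching the
// total quoted in the doc comment above.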
|
||||
|
||||
impl Signature {
|
||||
// CONSTRUCTOR
|
||||
// --------------------------------------------------------------------------------------------
|
||||
pub fn new(nonce: Nonce, h: PubKeyPoly, s2: SignaturePoly) -> Signature {
|
||||
Self {
|
||||
header: SignatureHeader::default(),
|
||||
nonce,
|
||||
s2,
|
||||
h,
|
||||
}
|
||||
}
|
||||
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns the public key polynomial h.
|
||||
pub fn pub_key_poly(&self) -> Polynomial {
|
||||
// TODO: memoize
|
||||
// we assume that the signature was constructed with a valid public key, and thus
|
||||
// expect() is OK here.
|
||||
Polynomial::from_pub_key(&self.pk).expect("invalid public key")
|
||||
}
|
||||
|
||||
/// Returns the nonce component of the signature represented as field elements.
|
||||
///
|
||||
/// Nonce bytes are converted to field elements by taking consecutive 5 byte chunks
|
||||
/// of the nonce and interpreting them as field elements.
|
||||
pub fn nonce(&self) -> NonceElements {
|
||||
// we assume that the signature was constructed with a valid signature, and thus
|
||||
// expect() is OK here.
|
||||
let nonce = self.sig[1..41].try_into().expect("invalid signature");
|
||||
decode_nonce(nonce)
|
||||
pub fn pk_poly(&self) -> &PubKeyPoly {
|
||||
&self.h
|
||||
}
|
||||
|
||||
// Returns the polynomial representation of the signature in Z_p[x]/(phi).
|
||||
pub fn sig_poly(&self) -> Polynomial {
|
||||
// TODO: memoize
|
||||
// we assume that the signature was constructed with a valid signature, and thus
|
||||
// expect() is OK here.
|
||||
Polynomial::from_signature(&self.sig).expect("invalid signature")
|
||||
pub fn sig_poly(&self) -> &Polynomial<FalconFelt> {
|
||||
&self.s2
|
||||
}
|
||||
|
||||
// HASH-TO-POINT
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message.
|
||||
pub fn hash_to_point(&self, message: Word) -> Polynomial {
|
||||
hash_to_point(message, &self.nonce())
|
||||
/// Returns the nonce component of the signature.
|
||||
pub fn nonce(&self) -> &Nonce {
|
||||
&self.nonce
|
||||
}
|
||||
|
||||
// SIGNATURE VERIFICATION
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns true if this signature is a valid signature for the specified message generated
|
||||
/// against key pair matching the specified public key commitment.
|
||||
/// against the secret key matching the specified public key commitment.
|
||||
pub fn verify(&self, message: Word, pubkey_com: Word) -> bool {
|
||||
// Make sure the expanded public key matches the provided public key commitment
|
||||
let h = self.pub_key_poly();
|
||||
let h_digest: Word = Rpo256::hash_elements(&h.to_elements()).into();
|
||||
// compute the hash of the public key polynomial
|
||||
let h_felt: Polynomial<Felt> = (&**self.pk_poly()).into();
|
||||
let h_digest: Word = Rpo256::hash_elements(&h_felt.coefficients).into();
|
||||
if h_digest != pubkey_com {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Make sure the signature is valid
|
||||
let s2 = self.sig_poly();
|
||||
let c = self.hash_to_point(message);
|
||||
|
||||
let s1 = c - s2 * h;
|
||||
|
||||
let sq_norm = s1.sq_norm() + s2.sq_norm();
|
||||
sq_norm <= SIG_L2_BOUND
|
||||
let c = hash_to_point_rpo256(message, &self.nonce);
|
||||
h_digest == pubkey_com && verify_helper(&c, &self.s2, self.pk_poly())
|
||||
}
|
||||
}
|
||||
|
||||
// SERIALIZATION / DESERIALIZATION
|
||||
// ================================================================================================
|
||||
|
||||
impl Serializable for Signature {
|
||||
fn write_into<W: ByteWriter>(&self, target: &mut W) {
|
||||
target.write_bytes(&self.pk);
|
||||
target.write_bytes(&self.sig);
|
||||
target.write(&self.header);
|
||||
target.write(&self.nonce);
|
||||
target.write(&self.s2);
|
||||
target.write(&self.h);
|
||||
}
|
||||
}
|
||||
|
||||
impl Deserializable for Signature {
|
||||
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
|
||||
let pk: PublicKeyBytes = source.read_array()?;
|
||||
let sig: SignatureBytes = source.read_array()?;
|
||||
let header = source.read()?;
|
||||
let nonce = source.read()?;
|
||||
let s2 = source.read()?;
|
||||
let h = source.read()?;
|
||||
|
||||
// make sure public key and signature can be decoded correctly
|
||||
Polynomial::from_pub_key(&pk)
|
||||
.map_err(|err| DeserializationError::InvalidValue(err.to_string()))?;
|
||||
Polynomial::from_signature(&sig[41..])
|
||||
.map_err(|err| DeserializationError::InvalidValue(err.to_string()))?;
|
||||
Ok(Self { header, nonce, s2, h })
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self { pk, sig })
|
||||
// SIGNATURE HEADER
|
||||
// ================================================================================================
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct SignatureHeader(u8);
|
||||
|
||||
impl Default for SignatureHeader {
|
||||
/// According to section 3.11.3 in the specification [1], the signature header has the format
|
||||
/// `0cc1nnnn` where:
|
||||
///
|
||||
/// 1. `cc` signifies the encoding method. `01` denotes using the compression encoding method
|
||||
/// and `10` denotes encoding using the uncompressed method.
|
||||
/// 2. `nnnn` encodes `LOG_N`.
|
||||
///
|
||||
/// For RPO Falcon 512 we use compression encoding and N = 512. Moreover, to differentiate the
|
||||
/// RPO Falcon variant from the reference variant using SHAKE256, we flip the first bit in the
|
||||
/// header. Thus, for RPO Falcon 512 the header is `10111001`
|
||||
///
|
||||
/// [1]: https://falcon-sign.info/falcon.pdf
|
||||
fn default() -> Self {
|
||||
Self(0b1011_1001)
|
||||
}
|
||||
}
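
// For example, splitting the default header byte recovers both fields; this
// sketch mirrors the checks performed in `read_from` below.
#[cfg(test)]
#[test]
fn default_header_layout() {
    let SignatureHeader(header) = SignatureHeader::default();
    assert_eq!(header >> 4, 0b1011); // RPO variant with compressed encoding
    assert_eq!(header & 0b0000_1111, LOG_N); // log2(512) = 9
}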
|
||||
|
||||
impl Serializable for &SignatureHeader {
|
||||
fn write_into<W: ByteWriter>(&self, target: &mut W) {
|
||||
target.write_u8(self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Deserializable for SignatureHeader {
|
||||
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
|
||||
let header = source.read_u8()?;
|
||||
let (encoding, log_n) = (header >> 4, header & 0b00001111);
|
||||
if encoding != 0b1011 {
|
||||
return Err(DeserializationError::InvalidValue(
|
||||
"Failed to decode signature: not supported encoding algorithm".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if log_n != LOG_N {
|
||||
return Err(DeserializationError::InvalidValue(
|
||||
format!("Failed to decode signature: only supported irreducible polynomial degree is 512, 2^{log_n} was provided")
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Self(header))
|
||||
}
|
||||
}
|
||||
|
||||
// SIGNATURE POLYNOMIAL
|
||||
// ================================================================================================
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct SignaturePoly(pub Polynomial<FalconFelt>);
|
||||
|
||||
impl Deref for SignaturePoly {
|
||||
type Target = Polynomial<FalconFelt>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Polynomial<FalconFelt>> for SignaturePoly {
|
||||
fn from(pk_poly: Polynomial<FalconFelt>) -> Self {
|
||||
Self(pk_poly)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&[i16; N]> for SignaturePoly {
|
||||
type Error = ();
|
||||
|
||||
fn try_from(coefficients: &[i16; N]) -> Result<Self, Self::Error> {
|
||||
if are_coefficients_valid(coefficients) {
|
||||
Ok(Self(coefficients.to_vec().into()))
|
||||
} else {
|
||||
Err(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serializable for &SignaturePoly {
|
||||
fn write_into<W: ByteWriter>(&self, target: &mut W) {
|
||||
let sig_coeff: Vec<i16> = self.0.coefficients.iter().map(|a| a.balanced_value()).collect();
|
||||
let mut sk_bytes = vec![0_u8; SIG_POLY_BYTE_LEN];
|
||||
|
||||
let mut acc = 0;
|
||||
let mut acc_len = 0;
|
||||
let mut v = 0;
|
||||
let mut t;
|
||||
let mut w;
|
||||
|
||||
// For each coefficient of x:
|
||||
// - the sign is encoded on 1 bit
|
||||
// - the 7 lower bits are encoded naively (binary)
|
||||
// - the high bits are encoded in unary encoding
|
||||
//
|
||||
// Algorithm 17 p. 47 of the specification [1].
|
||||
//
|
||||
// [1]: https://falcon-sign.info/falcon.pdf
|
||||
for &c in sig_coeff.iter() {
|
||||
acc <<= 1;
|
||||
t = c;
|
||||
|
||||
if t < 0 {
|
||||
t = -t;
|
||||
acc |= 1;
|
||||
}
|
||||
w = t as u16;
|
||||
|
||||
acc <<= 7;
|
||||
let mask = 127_u32;
|
||||
acc |= (w as u32) & mask;
|
||||
w >>= 7;
|
||||
|
||||
acc_len += 8;
|
||||
|
||||
acc <<= w + 1;
|
||||
acc |= 1;
|
||||
acc_len += w + 1;
|
||||
|
||||
while acc_len >= 8 {
|
||||
acc_len -= 8;
|
||||
|
||||
sk_bytes[v] = (acc >> acc_len) as u8;
|
||||
v += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if acc_len > 0 {
|
||||
sk_bytes[v] = (acc << (8 - acc_len)) as u8;
|
||||
}
|
||||
target.write_bytes(&sk_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
impl Deserializable for SignaturePoly {
|
||||
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
|
||||
let input = source.read_array::<SIG_POLY_BYTE_LEN>()?;
|
||||
|
||||
let mut input_idx = 0;
|
||||
let mut acc = 0u32;
|
||||
let mut acc_len = 0;
|
||||
let mut coefficients = [FalconFelt::zero(); N];
|
||||
|
||||
// Algorithm 18 p. 48 of the specification [1].
|
||||
//
|
||||
// [1]: https://falcon-sign.info/falcon.pdf
|
||||
for c in coefficients.iter_mut() {
|
||||
acc = (acc << 8) | (input[input_idx] as u32);
|
||||
input_idx += 1;
|
||||
let b = acc >> acc_len;
|
||||
let s = b & 128;
|
||||
let mut m = b & 127;
|
||||
|
||||
loop {
|
||||
if acc_len == 0 {
|
||||
acc = (acc << 8) | (input[input_idx] as u32);
|
||||
input_idx += 1;
|
||||
acc_len = 8;
|
||||
}
|
||||
acc_len -= 1;
|
||||
if ((acc >> acc_len) & 1) != 0 {
|
||||
break;
|
||||
}
|
||||
m += 128;
|
||||
if m >= 2048 {
|
||||
return Err(DeserializationError::InvalidValue(
|
||||
"Failed to decode signature: high bits {m} exceed 2048".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
if s != 0 && m == 0 {
|
||||
return Err(DeserializationError::InvalidValue(
|
||||
"Failed to decode signature: -0 is forbidden".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let felt = if s != 0 { (MODULUS as u32 - m) as u16 } else { m as u16 };
|
||||
*c = FalconFelt::new(felt as i16);
|
||||
}
|
||||
|
||||
if (acc & ((1 << acc_len) - 1)) != 0 {
|
||||
return Err(DeserializationError::InvalidValue(
|
||||
"Failed to decode signature: Non-zero unused bits in the last byte".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(Polynomial::new(coefficients.to_vec()).into())
|
||||
}
|
||||
}
|
||||
|
||||
// HELPER FUNCTIONS
|
||||
// ================================================================================================
|
||||
|
||||
/// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message and
|
||||
/// nonce.
|
||||
fn hash_to_point(message: Word, nonce: &NonceElements) -> Polynomial {
|
||||
let mut state = [ZERO; Rpo256::STATE_WIDTH];
|
||||
/// Takes the hash-to-point polynomial `c` of a message, the signature polynomial over
|
||||
/// the message `s2`, and a public key polynomial, and returns `true` if the signature is a valid
|
||||
/// signature for the given parameters, otherwise it returns `false`.
|
||||
fn verify_helper(c: &Polynomial<FalconFelt>, s2: &SignaturePoly, h: &PubKeyPoly) -> bool {
|
||||
let h_fft = h.fft();
|
||||
let s2_fft = s2.fft();
|
||||
let c_fft = c.fft();
|
||||
|
||||
// absorb the nonce into the state
|
||||
for (&n, s) in nonce.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) {
|
||||
*s = n;
|
||||
}
|
||||
Rpo256::apply_permutation(&mut state);
|
||||
// compute the signature polynomial s1 using s1 = c - s2 * h
|
||||
let s1_fft = c_fft - s2_fft.hadamard_mul(&h_fft);
|
||||
let s1 = s1_fft.ifft();
|
||||
|
||||
// absorb message into the state
|
||||
for (&m, s) in message.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) {
|
||||
*s = m;
|
||||
// compute the norm squared of (s1, s2)
|
||||
let length_squared_s1 = s1.norm_squared();
|
||||
let length_squared_s2 = s2.norm_squared();
|
||||
let length_squared = length_squared_s1 + length_squared_s2;
|
||||
|
||||
length_squared < SIG_L2_BOUND
|
||||
}
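
// Note (illustrative): performing s2 * h in the FFT domain turns the
// negacyclic convolution in Z_p[x]/(x^512 + 1) into a pointwise (Hadamard)
// product, reducing the cost from O(N^2) to O(N log N).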
|
||||
|
||||
/// Checks whether a set of coefficients is a valid one for a signature polynomial.
|
||||
fn are_coefficients_valid(x: &[i16]) -> bool {
|
||||
if x.len() != N {
|
||||
return false;
|
||||
}
|
||||
|
||||
// squeeze the coefficients of the polynomial
|
||||
let mut i = 0;
|
||||
let mut res = [0_u16; N];
|
||||
for _ in 0..64 {
|
||||
Rpo256::apply_permutation(&mut state);
|
||||
for a in &state[Rpo256::RATE_RANGE] {
|
||||
res[i] = (a.as_int() % MODULUS as u64) as u16;
|
||||
i += 1;
|
||||
for &c in x {
|
||||
if !(-2047..=2047).contains(&c) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// using the raw constructor is OK here because we reduce all coefficients by the modulus above
|
||||
unsafe { Polynomial::new(res) }
|
||||
}
|
||||
|
||||
/// Converts byte representation of the nonce into field element representation.
|
||||
fn decode_nonce(nonce: &NonceBytes) -> NonceElements {
|
||||
let mut buffer = [0_u8; 8];
|
||||
let mut result = [ZERO; 8];
|
||||
for (i, bytes) in nonce.chunks(5).enumerate() {
|
||||
buffer[..5].copy_from_slice(bytes);
|
||||
result[i] = u64::from_le_bytes(buffer).into();
|
||||
}
|
||||
|
||||
result
|
||||
true
|
||||
}
|
||||
|
||||
// TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[cfg(all(test, feature = "std"))]
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
super::{ffi::*, Felt},
|
||||
*,
|
||||
};
|
||||
use libc::c_void;
|
||||
use rand_utils::rand_vector;
|
||||
|
||||
// Wrappers for unsafe functions
|
||||
impl Rpo128Context {
|
||||
/// Initializes the RPO state.
|
||||
pub fn init() -> Self {
|
||||
let mut ctx = Rpo128Context { content: [0u64; 13] };
|
||||
unsafe {
|
||||
rpo128_init(&mut ctx as *mut Rpo128Context);
|
||||
}
|
||||
ctx
|
||||
}
|
||||
|
||||
/// Absorbs data into the RPO state.
|
||||
pub fn absorb(&mut self, data: &[u8]) {
|
||||
unsafe {
|
||||
rpo128_absorb(
|
||||
self as *mut Rpo128Context,
|
||||
data.as_ptr() as *const c_void,
|
||||
data.len(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Finalizes the RPO state to prepare for squeezing.
|
||||
pub fn finalize(&mut self) {
|
||||
unsafe { rpo128_finalize(self as *mut Rpo128Context) }
|
||||
}
|
||||
}
|
||||
use super::{super::SecretKey, *};
|
||||
use rand::SeedableRng;
|
||||
use rand_chacha::ChaCha20Rng;
|
||||
|
||||
#[test]
|
||||
fn test_hash_to_point() {
|
||||
// Create a random message and transform it into a u8 vector
|
||||
let msg_felts: Word = rand_vector::<Felt>(4).try_into().unwrap();
|
||||
let msg_bytes = msg_felts.iter().flat_map(|e| e.as_int().to_le_bytes()).collect::<Vec<_>>();
|
||||
fn test_serialization_round_trip() {
|
||||
let seed = [0_u8; 32];
|
||||
let mut rng = ChaCha20Rng::from_seed(seed);
|
||||
|
||||
// Create a nonce i.e. a [u8; 40] array and pack into a [Felt; 8] array.
|
||||
let nonce: [u8; 40] = rand_vector::<u8>(40).try_into().unwrap();
|
||||
|
||||
let mut buffer = [0_u8; 64];
|
||||
for i in 0..8 {
|
||||
buffer[8 * i] = nonce[5 * i];
|
||||
buffer[8 * i + 1] = nonce[5 * i + 1];
|
||||
buffer[8 * i + 2] = nonce[5 * i + 2];
|
||||
buffer[8 * i + 3] = nonce[5 * i + 3];
|
||||
buffer[8 * i + 4] = nonce[5 * i + 4];
|
||||
}
|
||||
|
||||
// Initialize the RPO state
|
||||
let mut rng = Rpo128Context::init();
|
||||
|
||||
// Absorb the nonce and message into the RPO state
|
||||
rng.absorb(&buffer);
|
||||
rng.absorb(&msg_bytes);
|
||||
rng.finalize();
|
||||
|
||||
// Generate the coefficients of the hash-to-point polynomial.
|
||||
let mut res: [u16; N] = [0; N];
|
||||
|
||||
unsafe {
|
||||
PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(
|
||||
&mut rng as *mut Rpo128Context,
|
||||
res.as_mut_ptr(),
|
||||
9,
|
||||
);
|
||||
}
|
||||
|
||||
// Check that the coefficients are correct
|
||||
let nonce = decode_nonce(&nonce);
|
||||
assert_eq!(res, hash_to_point(msg_felts, &nonce).inner());
|
||||
let sk = SecretKey::with_rng(&mut rng);
|
||||
let signature = sk.sign_with_rng(Word::default(), &mut rng);
|
||||
let serialized = signature.to_bytes();
|
||||
let deserialized = Signature::read_from_bytes(&serialized).unwrap();
|
||||
assert_eq!(signature.sig_poly(), deserialized.sig_poly());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
use super::{Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField};
|
||||
use crate::utils::{
|
||||
bytes_to_hex_string, hex_to_bytes, string::String, ByteReader, ByteWriter, Deserializable,
|
||||
DeserializationError, HexParseError, Serializable,
|
||||
};
|
||||
use alloc::string::String;
|
||||
use core::{
|
||||
mem::{size_of, transmute, transmute_copy},
|
||||
ops::Deref,
|
||||
slice::from_raw_parts,
|
||||
};
|
||||
|
||||
use super::{Digest, ElementHasher, Felt, FieldElement, Hasher};
|
||||
use crate::utils::{
|
||||
bytes_to_hex_string, hex_to_bytes, ByteReader, ByteWriter, Deserializable,
|
||||
DeserializationError, HexParseError, Serializable,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
use super::*;
|
||||
use crate::utils::collections::Vec;
|
||||
use proptest::prelude::*;
|
||||
use rand_utils::rand_vector;
|
||||
|
||||
use super::*;
|
||||
use alloc::vec::Vec;
|
||||
|
||||
#[test]
|
||||
fn blake3_hash_elements() {
|
||||
// test multiple of 8
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
//! Cryptographic hash functions used by the Miden VM and the Miden rollup.
|
||||
|
||||
use super::{Felt, FieldElement, StarkField, ONE, ZERO};
|
||||
use super::{CubeExtension, Felt, FieldElement, StarkField, ONE, ZERO};
|
||||
|
||||
pub mod blake;
|
||||
pub mod rpo;
|
||||
|
||||
mod rescue;
|
||||
pub mod rpo {
|
||||
pub use super::rescue::{Rpo256, RpoDigest};
|
||||
}
|
||||
|
||||
pub mod rpx {
|
||||
pub use super::rescue::{Rpx256, RpxDigest};
|
||||
}
|
||||
|
||||
// RE-EXPORTS
|
||||
// ================================================================================================
|
||||
|
||||
101
src/hash/rescue/arch/mod.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
#[cfg(target_feature = "sve")]
|
||||
pub mod optimized {
|
||||
use crate::{hash::rescue::STATE_WIDTH, Felt};
|
||||
|
||||
mod ffi {
|
||||
#[link(name = "rpo_sve", kind = "static")]
|
||||
extern "C" {
|
||||
pub fn add_constants_and_apply_sbox(
|
||||
state: *mut std::ffi::c_ulong,
|
||||
constants: *const std::ffi::c_ulong,
|
||||
) -> bool;
|
||||
pub fn add_constants_and_apply_inv_sbox(
|
||||
state: *mut std::ffi::c_ulong,
|
||||
constants: *const std::ffi::c_ulong,
|
||||
) -> bool;
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn add_constants_and_apply_sbox(
|
||||
state: &mut [Felt; STATE_WIDTH],
|
||||
ark: &[Felt; STATE_WIDTH],
|
||||
) -> bool {
|
||||
unsafe {
|
||||
ffi::add_constants_and_apply_sbox(
|
||||
state.as_mut_ptr() as *mut u64,
|
||||
ark.as_ptr() as *const u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn add_constants_and_apply_inv_sbox(
|
||||
state: &mut [Felt; STATE_WIDTH],
|
||||
ark: &[Felt; STATE_WIDTH],
|
||||
) -> bool {
|
||||
unsafe {
|
||||
ffi::add_constants_and_apply_inv_sbox(
|
||||
state.as_mut_ptr() as *mut u64,
|
||||
ark.as_ptr() as *const u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_feature = "avx2")]
|
||||
mod x86_64_avx2;
|
||||
|
||||
#[cfg(target_feature = "avx2")]
|
||||
pub mod optimized {
|
||||
use super::x86_64_avx2::{apply_inv_sbox, apply_sbox};
|
||||
use crate::{
|
||||
hash::rescue::{add_constants, STATE_WIDTH},
|
||||
Felt,
|
||||
};
|
||||
|
||||
#[inline(always)]
|
||||
pub fn add_constants_and_apply_sbox(
|
||||
state: &mut [Felt; STATE_WIDTH],
|
||||
ark: &[Felt; STATE_WIDTH],
|
||||
) -> bool {
|
||||
add_constants(state, ark);
|
||||
unsafe {
|
||||
apply_sbox(std::mem::transmute(state));
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn add_constants_and_apply_inv_sbox(
|
||||
state: &mut [Felt; STATE_WIDTH],
|
||||
ark: &[Felt; STATE_WIDTH],
|
||||
) -> bool {
|
||||
add_constants(state, ark);
|
||||
unsafe {
|
||||
apply_inv_sbox(std::mem::transmute(state));
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(any(target_feature = "avx2", target_feature = "sve")))]
|
||||
pub mod optimized {
|
||||
use crate::{hash::rescue::STATE_WIDTH, Felt};
|
||||
|
||||
#[inline(always)]
|
||||
pub fn add_constants_and_apply_sbox(
|
||||
_state: &mut [Felt; STATE_WIDTH],
|
||||
_ark: &[Felt; STATE_WIDTH],
|
||||
) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn add_constants_and_apply_inv_sbox(
|
||||
_state: &mut [Felt; STATE_WIDTH],
|
||||
_ark: &[Felt; STATE_WIDTH],
|
||||
) -> bool {
|
||||
false
|
||||
}
|
||||
}
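
// Dispatch note (illustrative): each `optimized` variant returns `true` when
// it handled the constant addition and s-box in SIMD; the portable fallback
// above returns `false`, signaling the caller to run the scalar code path.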
|
||||
325
src/hash/rescue/arch/x86_64_avx2.rs
Normal file
@@ -0,0 +1,325 @@
|
||||
use core::arch::x86_64::*;
|
||||
|
||||
// The following AVX2 implementation has been copied from plonky2:
|
||||
// https://github.com/0xPolygonZero/plonky2/blob/main/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs
|
||||
|
||||
// Preliminary notes:
|
||||
// 1. AVX does not support addition with carry but 128-bit (2-word) addition can be easily
|
||||
// emulated. The method relies on the fact that a + b overflows iff (a + b) < a:
|
||||
// i. res_lo = a_lo + b_lo
|
||||
// ii. carry_mask = res_lo < a_lo
|
||||
// iii. res_hi = a_hi + b_hi - carry_mask
|
||||
// Notice that carry_mask is subtracted, not added. This is because AVX comparison instructions
|
||||
// return -1 (all bits 1) for true and 0 for false.
|
||||
//
|
||||
// 2. AVX does not have unsigned 64-bit comparisons. Those can be emulated with signed comparisons
|
||||
// by recognizing that a <u b iff a + (1 << 63) <s b + (1 << 63), where the addition wraps around
|
||||
// and the comparisons are unsigned and signed respectively. The shift function adds/subtracts
|
||||
// 1 << 63 to enable this trick.
|
||||
// Example: addition with carry.
|
||||
// i. a_lo_s = shift(a_lo)
|
||||
// ii. res_lo_s = a_lo_s + b_lo
|
||||
// iii. carry_mask = res_lo_s <s a_lo_s
|
||||
// iv. res_lo = shift(res_lo_s)
|
||||
// v. res_hi = a_hi + b_hi - carry_mask
|
||||
// The suffix _s denotes a value that has been shifted by 1 << 63. The result of addition is
|
||||
// shifted if exactly one of the operands is shifted, as is the case on line ii. Line iii.
|
||||
// performs a signed comparison res_lo_s <s a_lo_s on shifted values to emulate unsigned
|
||||
// comparison res_lo <u a_lo on unshifted values. Finally, line iv. reverses the shift so the
|
||||
// result can be returned.
|
||||
// When performing a chain of calculations, we can often save instructions by letting the shift
|
||||
// propagate through and only undoing it when necessary. For example, to compute the addition of
|
||||
// three two-word (128-bit) numbers we can do:
|
||||
// i. a_lo_s = shift(a_lo)
|
||||
// ii. tmp_lo_s = a_lo_s + b_lo
|
||||
// iii. tmp_carry_mask = tmp_lo_s <s a_lo_s
|
||||
// iv. tmp_hi = a_hi + b_hi - tmp_carry_mask
|
||||
// v. res_lo_s = tmp_lo_s + c_lo
|
||||
// vi. res_carry_mask = res_lo_s <s tmp_lo_s
|
||||
// vii. res_lo = shift(res_lo_s)
|
||||
// viii. res_hi = tmp_hi + c_hi - res_carry_mask
|
||||
// Notice that the above 3-value addition still only requires two calls to shift, just like our
|
||||
// 2-value addition.
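//
// A scalar sketch of the biasing trick (illustrative, not part of this
// module): a <u b iff (a ^ (1 << 63)) <s (b ^ (1 << 63)), because adding
// 1 << 63 modulo 2^64 is the same as flipping the top bit.
#[cfg(test)]
#[test]
fn unsigned_lt_via_signed_comparison() {
    let lt = |a: u64, b: u64| ((a ^ (1 << 63)) as i64) < ((b ^ (1 << 63)) as i64);
    assert!(lt(1, u64::MAX));
    assert!(!lt(u64::MAX, 1));
}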
|
||||
|
||||
#[inline(always)]
|
||||
pub fn branch_hint() {
|
||||
// NOTE: These are the currently supported assembly architectures. See the
|
||||
// [nightly reference](https://doc.rust-lang.org/nightly/reference/inline-assembly.html) for
|
||||
// the most up-to-date list.
|
||||
#[cfg(any(
|
||||
target_arch = "aarch64",
|
||||
target_arch = "arm",
|
||||
target_arch = "riscv32",
|
||||
target_arch = "riscv64",
|
||||
target_arch = "x86",
|
||||
target_arch = "x86_64",
|
||||
))]
|
||||
unsafe {
|
||||
core::arch::asm!("", options(nomem, nostack, preserves_flags));
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! map3 {
|
||||
($f:ident::<$l:literal>, $v:ident) => {
|
||||
($f::<$l>($v.0), $f::<$l>($v.1), $f::<$l>($v.2))
|
||||
};
|
||||
($f:ident::<$l:literal>, $v1:ident, $v2:ident) => {
|
||||
($f::<$l>($v1.0, $v2.0), $f::<$l>($v1.1, $v2.1), $f::<$l>($v1.2, $v2.2))
|
||||
};
|
||||
($f:ident, $v:ident) => {
|
||||
($f($v.0), $f($v.1), $f($v.2))
|
||||
};
|
||||
($f:ident, $v0:ident, $v1:ident) => {
|
||||
($f($v0.0, $v1.0), $f($v0.1, $v1.1), $f($v0.2, $v1.2))
|
||||
};
|
||||
($f:ident, rep $v0:ident, $v1:ident) => {
|
||||
($f($v0, $v1.0), $f($v0, $v1.1), $f($v0, $v1.2))
|
||||
};
|
||||
|
||||
($f:ident, $v0:ident, rep $v1:ident) => {
|
||||
($f($v0.0, $v1), $f($v0.1, $v1), $f($v0.2, $v1))
|
||||
};
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn square3(
|
||||
x: (__m256i, __m256i, __m256i),
|
||||
) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
|
||||
let x_hi = {
|
||||
// Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster than
|
||||
// bitshift. This instruction only has a floating-point flavor, so we cast to/from float.
|
||||
// This is safe and free.
|
||||
let x_ps = map3!(_mm256_castsi256_ps, x);
|
||||
let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps);
|
||||
map3!(_mm256_castps_si256, x_hi_ps)
|
||||
};
|
||||
|
||||
// All pairwise multiplications.
|
||||
let mul_ll = map3!(_mm256_mul_epu32, x, x);
|
||||
let mul_lh = map3!(_mm256_mul_epu32, x, x_hi);
|
||||
let mul_hh = map3!(_mm256_mul_epu32, x_hi, x_hi);
|
||||
|
||||
// Bignum addition, but mul_lh is shifted by 33 bits (not 32).
|
||||
let mul_ll_hi = map3!(_mm256_srli_epi64::<33>, mul_ll);
|
||||
let t0 = map3!(_mm256_add_epi64, mul_lh, mul_ll_hi);
|
||||
let t0_hi = map3!(_mm256_srli_epi64::<31>, t0);
|
||||
let res_hi = map3!(_mm256_add_epi64, mul_hh, t0_hi);
|
||||
|
||||
// Form low result by adding the mul_ll and the low 31 bits of mul_lh (shifted to the high
|
||||
// position).
|
||||
let mul_lh_lo = map3!(_mm256_slli_epi64::<33>, mul_lh);
|
||||
let res_lo = map3!(_mm256_add_epi64, mul_ll, mul_lh_lo);
|
||||
|
||||
(res_lo, res_hi)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn mul3(
|
||||
x: (__m256i, __m256i, __m256i),
|
||||
y: (__m256i, __m256i, __m256i),
|
||||
) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
|
||||
let epsilon = _mm256_set1_epi64x(0xffffffff);
|
||||
let x_hi = {
|
||||
// Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster than
|
||||
// bitshift. This instruction only has a floating-point flavor, so we cast to/from float.
|
||||
// This is safe and free.
|
||||
let x_ps = map3!(_mm256_castsi256_ps, x);
|
||||
let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps);
|
||||
map3!(_mm256_castps_si256, x_hi_ps)
|
||||
};
|
||||
let y_hi = {
|
||||
let y_ps = map3!(_mm256_castsi256_ps, y);
|
||||
let y_hi_ps = map3!(_mm256_movehdup_ps, y_ps);
|
||||
map3!(_mm256_castps_si256, y_hi_ps)
|
||||
};
|
||||
|
||||
// All four pairwise multiplications
|
||||
let mul_ll = map3!(_mm256_mul_epu32, x, y);
|
||||
let mul_lh = map3!(_mm256_mul_epu32, x, y_hi);
|
||||
let mul_hl = map3!(_mm256_mul_epu32, x_hi, y);
|
||||
let mul_hh = map3!(_mm256_mul_epu32, x_hi, y_hi);
|
||||
|
||||
// Bignum addition
|
||||
// Extract high 32 bits of mul_ll and add to mul_hl. This cannot overflow.
|
||||
let mul_ll_hi = map3!(_mm256_srli_epi64::<32>, mul_ll);
|
||||
let t0 = map3!(_mm256_add_epi64, mul_hl, mul_ll_hi);
|
||||
// Extract low 32 bits of t0 and add to mul_lh. Again, this cannot overflow.
|
||||
// Also, extract high 32 bits of t0 and add to mul_hh.
|
||||
let t0_lo = map3!(_mm256_and_si256, t0, rep epsilon);
|
||||
let t0_hi = map3!(_mm256_srli_epi64::<32>, t0);
|
||||
let t1 = map3!(_mm256_add_epi64, mul_lh, t0_lo);
|
||||
let t2 = map3!(_mm256_add_epi64, mul_hh, t0_hi);
|
||||
// Lastly, extract the high 32 bits of t1 and add to t2.
|
||||
let t1_hi = map3!(_mm256_srli_epi64::<32>, t1);
|
||||
let res_hi = map3!(_mm256_add_epi64, t2, t1_hi);
|
||||
|
||||
// Form res_lo by combining the low half of mul_ll with the low half of t1 (shifted into high
|
||||
// position).
|
||||
let t1_lo = {
|
||||
let t1_ps = map3!(_mm256_castsi256_ps, t1);
|
||||
let t1_lo_ps = map3!(_mm256_moveldup_ps, t1_ps);
|
||||
map3!(_mm256_castps_si256, t1_lo_ps)
|
||||
};
|
||||
let res_lo = map3!(_mm256_blend_epi32::<0xaa>, mul_ll, t1_lo);
|
||||
|
||||
(res_lo, res_hi)
|
||||
}
|
||||
|
||||
/// Addition, where the second operand is `0 <= y < 0xffffffff00000001`.
|
||||
#[inline(always)]
|
||||
unsafe fn add_small(
|
||||
x_s: (__m256i, __m256i, __m256i),
|
||||
y: (__m256i, __m256i, __m256i),
|
||||
) -> (__m256i, __m256i, __m256i) {
|
||||
let res_wrapped_s = map3!(_mm256_add_epi64, x_s, y);
|
||||
let mask = map3!(_mm256_cmpgt_epi32, x_s, res_wrapped_s);
|
||||
let wrapback_amt = map3!(_mm256_srli_epi64::<32>, mask); // EPSILON if overflowed else 0.
|
||||
let res_s = map3!(_mm256_add_epi64, res_wrapped_s, wrapback_amt);
|
||||
res_s
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn maybe_adj_sub(res_wrapped_s: __m256i, mask: __m256i) -> __m256i {
|
||||
// The subtraction is very unlikely to overflow so we're best off branching.
|
||||
// The even u32s in `mask` are meaningless, so we want to ignore them. `_mm256_testz_pd`
|
||||
// branches depending on the sign bit of double-precision (64-bit) floats. Bit cast `mask` to
|
||||
// floating-point (this is free).
|
||||
let mask_pd = _mm256_castsi256_pd(mask);
|
||||
// `_mm256_testz_pd(mask_pd, mask_pd) == 1` iff all sign bits are 0, meaning that underflow
|
||||
// did not occur for any of the vector elements.
|
||||
if _mm256_testz_pd(mask_pd, mask_pd) == 1 {
|
||||
res_wrapped_s
|
||||
} else {
|
||||
branch_hint();
|
||||
// Highly unlikely: underflow did occur. Find adjustment per element and apply it.
|
||||
let adj_amount = _mm256_srli_epi64::<32>(mask); // EPSILON if underflow.
|
||||
_mm256_sub_epi64(res_wrapped_s, adj_amount)
|
||||
}
|
||||
}
|
||||
|
||||
/// Subtraction, where the second operand is much smaller than `0xffffffff00000001`.
|
||||
#[inline(always)]
|
||||
unsafe fn sub_tiny(
|
||||
x_s: (__m256i, __m256i, __m256i),
|
||||
y: (__m256i, __m256i, __m256i),
|
||||
) -> (__m256i, __m256i, __m256i) {
|
||||
let res_wrapped_s = map3!(_mm256_sub_epi64, x_s, y);
|
||||
let mask = map3!(_mm256_cmpgt_epi32, res_wrapped_s, x_s);
|
||||
let res_s = map3!(maybe_adj_sub, res_wrapped_s, mask);
|
||||
res_s
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn reduce3(
|
||||
(lo0, hi0): ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)),
|
||||
) -> (__m256i, __m256i, __m256i) {
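// Reduction note (illustrative): with p = 2^64 - 2^32 + 1 and
// epsilon = 2^32 - 1, we have 2^64 = epsilon (mod p) and 2^96 = -1 (mod p).
// A 128-bit value lo + 2^64 * (hi_lo + 2^32 * hi_hi) therefore reduces to
// lo - hi_hi + epsilon * hi_lo, which is exactly what the steps below compute.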
|
||||
let sign_bit = _mm256_set1_epi64x(i64::MIN);
|
||||
let epsilon = _mm256_set1_epi64x(0xffffffff);
|
||||
let lo0_s = map3!(_mm256_xor_si256, lo0, rep sign_bit);
|
||||
let hi_hi0 = map3!(_mm256_srli_epi64::<32>, hi0);
|
||||
let lo1_s = sub_tiny(lo0_s, hi_hi0);
|
||||
let t1 = map3!(_mm256_mul_epu32, hi0, rep epsilon);
|
||||
let lo2_s = add_small(lo1_s, t1);
|
||||
let lo2 = map3!(_mm256_xor_si256, lo2_s, rep sign_bit);
|
||||
lo2
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn mul_reduce(
|
||||
a: (__m256i, __m256i, __m256i),
|
||||
b: (__m256i, __m256i, __m256i),
|
||||
) -> (__m256i, __m256i, __m256i) {
|
||||
reduce3(mul3(a, b))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn square_reduce(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
|
||||
reduce3(square3(state))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn exp_acc(
|
||||
high: (__m256i, __m256i, __m256i),
|
||||
low: (__m256i, __m256i, __m256i),
|
||||
exp: usize,
|
||||
) -> (__m256i, __m256i, __m256i) {
|
||||
let mut result = high;
|
||||
for _ in 0..exp {
|
||||
result = square_reduce(result);
|
||||
}
|
||||
mul_reduce(result, low)
|
||||
}

#[inline(always)]
unsafe fn do_apply_sbox(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
    let state2 = square_reduce(state);
    let state4_unreduced = square3(state2);
    let state3_unreduced = mul3(state2, state);
    let state4 = reduce3(state4_unreduced);
    let state3 = reduce3(state3_unreduced);
    let state7_unreduced = mul3(state3, state4);
    let state7 = reduce3(state7_unreduced);
    state7
}

#[inline(always)]
unsafe fn do_apply_inv_sbox(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
    // compute base^10540996611094048183 using 72 multiplications per array element
    // 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111

    // compute base^10
    let t1 = square_reduce(state);

    // compute base^100
    let t2 = square_reduce(t1);

    // compute base^100100
    let t3 = exp_acc(t2, t2, 3);

    // compute base^100100100100
    let t4 = exp_acc(t3, t3, 6);

    // compute base^100100100100100100100100
    let t5 = exp_acc(t4, t4, 12);

    // compute base^100100100100100100100100100100
    let t6 = exp_acc(t5, t3, 6);

    // compute base^1001001001001001001001001001000100100100100100100100100100100
    let t7 = exp_acc(t6, t6, 31);

    // compute base^1001001001001001001001001001000110110110110110110110110110110111
    let a = square_reduce(square_reduce(mul_reduce(square_reduce(t7), t6)));
    let b = mul_reduce(t1, mul_reduce(t2, state));
    mul_reduce(a, b)
}

#[inline(always)]
unsafe fn avx2_load(state: &[u64; 12]) -> (__m256i, __m256i, __m256i) {
    (
        _mm256_loadu_si256((&state[0..4]).as_ptr().cast::<__m256i>()),
        _mm256_loadu_si256((&state[4..8]).as_ptr().cast::<__m256i>()),
        _mm256_loadu_si256((&state[8..12]).as_ptr().cast::<__m256i>()),
    )
}

#[inline(always)]
unsafe fn avx2_store(buf: &mut [u64; 12], state: (__m256i, __m256i, __m256i)) {
    _mm256_storeu_si256((&mut buf[0..4]).as_mut_ptr().cast::<__m256i>(), state.0);
    _mm256_storeu_si256((&mut buf[4..8]).as_mut_ptr().cast::<__m256i>(), state.1);
    _mm256_storeu_si256((&mut buf[8..12]).as_mut_ptr().cast::<__m256i>(), state.2);
}

#[inline(always)]
pub unsafe fn apply_sbox(buffer: &mut [u64; 12]) {
    let mut state = avx2_load(&buffer);
    state = do_apply_sbox(state);
    avx2_store(buffer, state);
}

#[inline(always)]
pub unsafe fn apply_inv_sbox(buffer: &mut [u64; 12]) {
    let mut state = avx2_load(&buffer);
    state = do_apply_inv_sbox(state);
    avx2_store(buffer, state);
}

@@ -11,7 +11,8 @@
 /// divisions by 2 and repeated modular reductions. This is because of our explicit choice of
 /// an MDS matrix that has small powers of 2 entries in frequency domain.
 /// The following implementation has benefited greatly from the discussions and insights of
-/// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero.
+/// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero and is based on Nabaglo's Plonky2
+/// implementation.

 // Rescue MDS matrix in frequency domain.
 // More precisely, this is the output of the three 4-point (real) FFTs of the first column of
@@ -26,7 +27,7 @@ const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-8, 1, 1];

 // We use split 3 x 4 FFT transform in order to transform our vectors into the frequency domain.
 #[inline(always)]
-pub(crate) const fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] {
+pub const fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] {
     let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] = state;

     let (u0, u1, u2) = fft4_real([s0, s3, s6, s9]);
@@ -156,9 +157,10 @@ const fn block3(x: [i64; 3], y: [i64; 3]) -> [i64; 3] {

 #[cfg(test)]
 mod tests {
-    use super::super::{Felt, Rpo256, MDS, ZERO};
     use proptest::prelude::*;

+    use super::super::{apply_mds, Felt, MDS, ZERO};
+
     const STATE_WIDTH: usize = 12;

     #[inline(always)]
@@ -185,7 +187,7 @@ mod tests {
         v2 = v1;

         apply_mds_naive(&mut v1);
-        Rpo256::apply_mds(&mut v2);
+        apply_mds(&mut v2);

         prop_assert_eq!(v1, v2);
     }

214 src/hash/rescue/mds/mod.rs Normal file
@@ -0,0 +1,214 @@
use super::{Felt, STATE_WIDTH, ZERO};

mod freq;
pub use freq::mds_multiply_freq;

// MDS MULTIPLICATION
// ================================================================================================

#[inline(always)]
pub fn apply_mds(state: &mut [Felt; STATE_WIDTH]) {
    let mut result = [ZERO; STATE_WIDTH];

    // Using the linearity of the operations we can split the state into a low||high decomposition
    // and operate on each with no overflow and then combine/reduce the result to a field element.
    // Overflow is avoided because the MDS matrix entries are small powers of two in the
    // frequency domain.
    let mut state_l = [0u64; STATE_WIDTH];
    let mut state_h = [0u64; STATE_WIDTH];

    for r in 0..STATE_WIDTH {
        let s = state[r].inner();
        state_h[r] = s >> 32;
        state_l[r] = (s as u32) as u64;
    }

    let state_h = mds_multiply_freq(state_h);
    let state_l = mds_multiply_freq(state_l);

    for r in 0..STATE_WIDTH {
        let s = state_l[r] as u128 + ((state_h[r] as u128) << 32);
        let s_hi = (s >> 64) as u64;
        let s_lo = s as u64;
        let z = (s_hi << 32) - s_hi;
        let (res, over) = s_lo.overflowing_add(z);

        result[r] = Felt::from_mont(res.wrapping_add(0u32.wrapping_sub(over as u32) as u64));
    }
    *state = result;
}
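
// Scalar sketch of the combine/reduce step in the loop above (not part of the
// crate; names are illustrative). For s < 2^96, with p = 2^64 - 2^32 + 1:
//   s = s_lo + 2^64 * s_hi  ≡  s_lo + (2^32 - 1) * s_hi (mod p),
// and `(s_hi << 32) - s_hi` is exactly (2^32 - 1) * s_hi, so one
// overflow-adjusted 64-bit addition finishes the reduction.
fn combine_reduce(s: u128) -> u64 {
    let s_hi = (s >> 64) as u64; // < 2^32 because s < 2^96
    let s_lo = s as u64;
    let z = (s_hi << 32) - s_hi;
    let (res, over) = s_lo.overflowing_add(z);
    res.wrapping_add(0u32.wrapping_sub(over as u32) as u64) // += 2^32 - 1 on wrap
}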

// MDS MATRIX
// ================================================================================================

/// RPO MDS matrix
pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = [
    [
        Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10),
        Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8),
    ],
    [
        Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13),
        Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21),
    ],
    [
        Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26),
        Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22),
    ],
    [
        Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8),
        Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6),
    ],
    [
        Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23),
        Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7),
    ],
    [
        Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7),
        Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9),
    ],
    [
        Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8),
        Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10),
    ],
    [
        Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21),
        Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13),
    ],
    [
        Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22),
        Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26),
    ],
    [
        Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6),
        Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8),
    ],
    [
        Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7),
        Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23),
    ],
    [
        Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9),
        Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7),
    ],
];
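
// Observation (sketch, not part of the crate): MDS is circulant -- every row is
// the previous row rotated right by one -- so the whole table is determined by
// its first row. A quick consistency check (assumes `Felt::as_int`, which is
// used elsewhere in this crate):
#[test]
fn mds_is_circulant() {
    let first: [u64; 12] = [7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8];
    for (r, row) in MDS.iter().enumerate() {
        for (c, v) in row.iter().enumerate() {
            assert_eq!(v.as_int(), first[(c + 12 - r) % 12]);
        }
    }
}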

349 src/hash/rescue/mod.rs Normal file
@@ -0,0 +1,349 @@
use core::ops::Range;

use super::{
    CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ONE, ZERO,
};

mod arch;
pub use arch::optimized::{add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox};

mod mds;
use mds::{apply_mds, MDS};

mod rpo;
pub use rpo::{Rpo256, RpoDigest};

mod rpx;
pub use rpx::{Rpx256, RpxDigest};

#[cfg(test)]
mod tests;

// CONSTANTS
// ================================================================================================

/// The number of rounds is set to 7. For the RPO hash function, all rounds are uniform. For the
/// RPX hash function, there are 3 different types of rounds.
const NUM_ROUNDS: usize = 7;

/// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
/// the remaining 4 elements are reserved for capacity.
const STATE_WIDTH: usize = 12;

/// The rate portion of the state is located in elements 4 through 11.
const RATE_RANGE: Range<usize> = 4..12;
const RATE_WIDTH: usize = RATE_RANGE.end - RATE_RANGE.start;

const INPUT1_RANGE: Range<usize> = 4..8;
const INPUT2_RANGE: Range<usize> = 8..12;

/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
const CAPACITY_RANGE: Range<usize> = 0..4;

/// The output of the hash function is a digest which consists of 4 field elements or 32 bytes.
///
/// The digest is returned from state elements 4, 5, 6, and 7 (the first four elements of the
/// rate portion).
const DIGEST_RANGE: Range<usize> = 4..8;
const DIGEST_SIZE: usize = DIGEST_RANGE.end - DIGEST_RANGE.start;

/// The number of bytes needed to encode a digest
const DIGEST_BYTES: usize = 32;

/// The number of byte chunks defining a field element when hashing a sequence of bytes
const BINARY_CHUNK_SIZE: usize = 7;

/// S-Box and Inverse S-Box powers;
///
/// The constants are defined for tests only because the exponentiations in the code are unrolled
/// for efficiency reasons.
#[cfg(test)]
const ALPHA: u64 = 7;
#[cfg(test)]
const INV_ALPHA: u64 = 10540996611094048183;

// SBOX FUNCTION
// ================================================================================================

#[inline(always)]
fn apply_sbox(state: &mut [Felt; STATE_WIDTH]) {
    state[0] = state[0].exp7();
    state[1] = state[1].exp7();
    state[2] = state[2].exp7();
    state[3] = state[3].exp7();
    state[4] = state[4].exp7();
    state[5] = state[5].exp7();
    state[6] = state[6].exp7();
    state[7] = state[7].exp7();
    state[8] = state[8].exp7();
    state[9] = state[9].exp7();
    state[10] = state[10].exp7();
    state[11] = state[11].exp7();
}

// INVERSE SBOX FUNCTION
// ================================================================================================

#[inline(always)]
fn apply_inv_sbox(state: &mut [Felt; STATE_WIDTH]) {
    // compute base^10540996611094048183 using 72 multiplications per array element
    // 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111

    // compute base^10
    let mut t1 = *state;
    t1.iter_mut().for_each(|t| *t = t.square());

    // compute base^100
    let mut t2 = t1;
    t2.iter_mut().for_each(|t| *t = t.square());

    // compute base^100100
    let t3 = exp_acc::<Felt, STATE_WIDTH, 3>(t2, t2);

    // compute base^100100100100
    let t4 = exp_acc::<Felt, STATE_WIDTH, 6>(t3, t3);

    // compute base^100100100100100100100100
    let t5 = exp_acc::<Felt, STATE_WIDTH, 12>(t4, t4);

    // compute base^100100100100100100100100100100
    let t6 = exp_acc::<Felt, STATE_WIDTH, 6>(t5, t3);

    // compute base^1001001001001001001001001001000100100100100100100100100100100
    let t7 = exp_acc::<Felt, STATE_WIDTH, 31>(t6, t6);

    // compute base^1001001001001001001001001001000110110110110110110110110110110111
    for (i, s) in state.iter_mut().enumerate() {
        let a = (t7[i].square() * t6[i]).square().square();
        let b = t1[i] * t2[i] * *s;
        *s = a * b;
    }

    #[inline(always)]
    fn exp_acc<B: StarkField, const N: usize, const M: usize>(
        base: [B; N],
        tail: [B; N],
    ) -> [B; N] {
        let mut result = base;
        for _ in 0..M {
            result.iter_mut().for_each(|r| *r = r.square());
        }
        result.iter_mut().zip(tail).for_each(|(r, t)| *r *= t);
        result
    }
}
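
// Sanity sketch (not part of the crate): tracking only the *exponents* shows
// that the addition chain above lands exactly on INV_ALPHA.
#[test]
fn inv_alpha_addition_chain() {
    let t1: u64 = 0b10; // base^10 (exponents written in binary)
    let t2 = t1 << 1; // base^100
    let t3 = (t2 << 3) + t2; // base^100100
    let t4 = (t3 << 6) + t3;
    let t5 = (t4 << 12) + t4;
    let t6 = (t5 << 6) + t3;
    let t7 = (t6 << 31) + t6;
    let a = (((t7 << 1) + t6) << 1) << 1; // square, * t6, square, square
    let b = t1 + t2 + 1; // t1 * t2 * base
    assert_eq!(a + b, 10540996611094048183u64); // == INV_ALPHA
}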

#[inline(always)]
fn add_constants(state: &mut [Felt; STATE_WIDTH], ark: &[Felt; STATE_WIDTH]) {
    state.iter_mut().zip(ark).for_each(|(s, &k)| *s += k);
}

// ROUND CONSTANTS
// ================================================================================================

/// Rescue round constants;
/// computed as in [specifications](https://github.com/ASDiscreteMathematics/rpo)
///
/// The constants are broken up into two arrays ARK1 and ARK2; ARK1 contains the constants for the
/// first half of RPO round, and ARK2 contains constants for the second half of RPO round.
const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
    [
        Felt::new(5789762306288267392), Felt::new(6522564764413701783),
        Felt::new(17809893479458208203), Felt::new(107145243989736508),
        Felt::new(6388978042437517382), Felt::new(15844067734406016715),
        Felt::new(9975000513555218239), Felt::new(3344984123768313364),
        Felt::new(9959189626657347191), Felt::new(12960773468763563665),
        Felt::new(9602914297752488475), Felt::new(16657542370200465908),
    ],
    [
        Felt::new(12987190162843096997), Felt::new(653957632802705281),
        Felt::new(4441654670647621225), Felt::new(4038207883745915761),
        Felt::new(5613464648874830118), Felt::new(13222989726778338773),
        Felt::new(3037761201230264149), Felt::new(16683759727265180203),
        Felt::new(8337364536491240715), Felt::new(3227397518293416448),
        Felt::new(8110510111539674682), Felt::new(2872078294163232137),
    ],
    [
        Felt::new(18072785500942327487), Felt::new(6200974112677013481),
        Felt::new(17682092219085884187), Felt::new(10599526828986756440),
        Felt::new(975003873302957338), Felt::new(8264241093196931281),
        Felt::new(10065763900435475170), Felt::new(2181131744534710197),
        Felt::new(6317303992309418647), Felt::new(1401440938888741532),
        Felt::new(8884468225181997494), Felt::new(13066900325715521532),
    ],
    [
        Felt::new(5674685213610121970), Felt::new(5759084860419474071),
        Felt::new(13943282657648897737), Felt::new(1352748651966375394),
        Felt::new(17110913224029905221), Felt::new(1003883795902368422),
        Felt::new(4141870621881018291), Felt::new(8121410972417424656),
        Felt::new(14300518605864919529), Felt::new(13712227150607670181),
        Felt::new(17021852944633065291), Felt::new(6252096473787587650),
    ],
    [
        Felt::new(4887609836208846458), Felt::new(3027115137917284492),
        Felt::new(9595098600469470675), Felt::new(10528569829048484079),
        Felt::new(7864689113198939815), Felt::new(17533723827845969040),
        Felt::new(5781638039037710951), Felt::new(17024078752430719006),
        Felt::new(109659393484013511), Felt::new(7158933660534805869),
        Felt::new(2955076958026921730), Felt::new(7433723648458773977),
    ],
    [
        Felt::new(16308865189192447297), Felt::new(11977192855656444890),
        Felt::new(12532242556065780287), Felt::new(14594890931430968898),
        Felt::new(7291784239689209784), Felt::new(5514718540551361949),
        Felt::new(10025733853830934803), Felt::new(7293794580341021693),
        Felt::new(6728552937464861756), Felt::new(6332385040983343262),
        Felt::new(13277683694236792804), Felt::new(2600778905124452676),
    ],
    [
        Felt::new(7123075680859040534), Felt::new(1034205548717903090),
        Felt::new(7717824418247931797), Felt::new(3019070937878604058),
        Felt::new(11403792746066867460), Felt::new(10280580802233112374),
        Felt::new(337153209462421218), Felt::new(13333398568519923717),
        Felt::new(3596153696935337464), Felt::new(8104208463525993784),
        Felt::new(14345062289456085693), Felt::new(17036731477169661256),
    ],
];

const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
    [
        Felt::new(6077062762357204287), Felt::new(15277620170502011191),
        Felt::new(5358738125714196705), Felt::new(14233283787297595718),
        Felt::new(13792579614346651365), Felt::new(11614812331536767105),
        Felt::new(14871063686742261166), Felt::new(10148237148793043499),
        Felt::new(4457428952329675767), Felt::new(15590786458219172475),
        Felt::new(10063319113072092615), Felt::new(14200078843431360086),
    ],
    [
        Felt::new(6202948458916099932), Felt::new(17690140365333231091),
        Felt::new(3595001575307484651), Felt::new(373995945117666487),
        Felt::new(1235734395091296013), Felt::new(14172757457833931602),
        Felt::new(707573103686350224), Felt::new(15453217512188187135),
        Felt::new(219777875004506018), Felt::new(17876696346199469008),
        Felt::new(17731621626449383378), Felt::new(2897136237748376248),
    ],
    [
        Felt::new(8023374565629191455), Felt::new(15013690343205953430),
        Felt::new(4485500052507912973), Felt::new(12489737547229155153),
        Felt::new(9500452585969030576), Felt::new(2054001340201038870),
        Felt::new(12420704059284934186), Felt::new(355990932618543755),
        Felt::new(9071225051243523860), Felt::new(12766199826003448536),
        Felt::new(9045979173463556963), Felt::new(12934431667190679898),
    ],
    [
        Felt::new(18389244934624494276), Felt::new(16731736864863925227),
        Felt::new(4440209734760478192), Felt::new(17208448209698888938),
        Felt::new(8739495587021565984), Felt::new(17000774922218161967),
        Felt::new(13533282547195532087), Felt::new(525402848358706231),
        Felt::new(16987541523062161972), Felt::new(5466806524462797102),
        Felt::new(14512769585918244983), Felt::new(10973956031244051118),
    ],
    [
        Felt::new(6982293561042362913), Felt::new(14065426295947720331),
        Felt::new(16451845770444974180), Felt::new(7139138592091306727),
        Felt::new(9012006439959783127), Felt::new(14619614108529063361),
        Felt::new(1394813199588124371), Felt::new(4635111139507788575),
        Felt::new(16217473952264203365), Felt::new(10782018226466330683),
        Felt::new(6844229992533662050), Felt::new(7446486531695178711),
    ],
    [
        Felt::new(3736792340494631448), Felt::new(577852220195055341),
        Felt::new(6689998335515779805), Felt::new(13886063479078013492),
        Felt::new(14358505101923202168), Felt::new(7744142531772274164),
        Felt::new(16135070735728404443), Felt::new(12290902521256031137),
        Felt::new(12059913662657709804), Felt::new(16456018495793751911),
        Felt::new(4571485474751953524), Felt::new(17200392109565783176),
    ],
    [
        Felt::new(17130398059294018733), Felt::new(519782857322261988),
        Felt::new(9625384390925085478), Felt::new(1664893052631119222),
        Felt::new(7629576092524553570), Felt::new(3485239601103661425),
        Felt::new(9755891797164033838), Felt::new(15218148195153269027),
        Felt::new(16460604813734957368), Felt::new(9643968136937729763),
        Felt::new(3611348709641382851), Felt::new(18256379591337759196),
    ],
];

@@ -1,13 +1,14 @@
-use super::{Digest, Felt, StarkField, DIGEST_SIZE, ZERO};
-use crate::utils::{
-    bytes_to_hex_string, hex_to_bytes, string::String, ByteReader, ByteWriter, Deserializable,
-    DeserializationError, HexParseError, Serializable,
-};
+use alloc::string::String;
 use core::{cmp::Ordering, fmt::Display, ops::Deref};
-use winter_utils::Randomizable;

-/// The number of bytes needed to encoded a digest
-pub const DIGEST_BYTES: usize = 32;
+use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO};
+use crate::{
+    rand::Randomizable,
+    utils::{
+        bytes_to_hex_string, hex_to_bytes, ByteReader, ByteWriter, Deserializable,
+        DeserializationError, HexParseError, Serializable,
+    },
+};

 // DIGEST TRAIT IMPLEMENTATIONS
 // ================================================================================================
@@ -36,6 +37,11 @@ impl RpoDigest {
     {
         digests.flat_map(|d| d.0.iter())
     }
+
+    /// Returns hexadecimal representation of this digest prefixed with `0x`.
+    pub fn to_hex(&self) -> String {
+        bytes_to_hex_string(self.as_bytes())
+    }
 }

 impl Digest for RpoDigest {
@@ -161,7 +167,7 @@ impl From<RpoDigest> for [u8; DIGEST_BYTES] {
 impl From<RpoDigest> for String {
     /// The returned string starts with `0x`.
     fn from(value: RpoDigest) -> Self {
-        bytes_to_hex_string(value.as_bytes())
+        value.to_hex()
     }
 }

@@ -172,9 +178,21 @@ impl From<&RpoDigest> for String {
     }
 }

-// CONVERSIONS: TO DIGEST
+// CONVERSIONS: TO RPO DIGEST
 // ================================================================================================

+#[derive(Copy, Clone, Debug)]
+pub enum RpoDigestError {
+    /// The provided u64 integer does not fit in the field's modulus.
+    InvalidInteger,
+}
+
 impl From<&[Felt; DIGEST_SIZE]> for RpoDigest {
     fn from(value: &[Felt; DIGEST_SIZE]) -> Self {
         Self(*value)
     }
 }

 impl From<[Felt; DIGEST_SIZE]> for RpoDigest {
     fn from(value: [Felt; DIGEST_SIZE]) -> Self {
         Self(value)
@@ -200,6 +218,43 @@ impl TryFrom<[u8; DIGEST_BYTES]> for RpoDigest {
     }
 }

+impl TryFrom<&[u8; DIGEST_BYTES]> for RpoDigest {
+    type Error = HexParseError;
+
+    fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
+        (*value).try_into()
+    }
+}
+
+impl TryFrom<&[u8]> for RpoDigest {
+    type Error = HexParseError;
+
+    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
+        (*value).try_into()
+    }
+}
+
+impl TryFrom<[u64; DIGEST_SIZE]> for RpoDigest {
+    type Error = RpoDigestError;
+
+    fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
+        Ok(Self([
+            value[0].try_into().map_err(|_| RpoDigestError::InvalidInteger)?,
+            value[1].try_into().map_err(|_| RpoDigestError::InvalidInteger)?,
+            value[2].try_into().map_err(|_| RpoDigestError::InvalidInteger)?,
+            value[3].try_into().map_err(|_| RpoDigestError::InvalidInteger)?,
+        ]))
+    }
+}
+
+impl TryFrom<&[u64; DIGEST_SIZE]> for RpoDigest {
+    type Error = RpoDigestError;
+
+    fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
+        (*value).try_into()
+    }
+}
+
 impl TryFrom<&str> for RpoDigest {
     type Error = HexParseError;

@@ -253,15 +308,28 @@ impl Deserializable for RpoDigest {
     }
 }

+// ITERATORS
+// ================================================================================================
+impl IntoIterator for RpoDigest {
+    type Item = Felt;
+    type IntoIter = <[Felt; 4] as IntoIterator>::IntoIter;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.0.into_iter()
+    }
+}
+
 // TESTS
 // ================================================================================================

 #[cfg(test)]
 mod tests {
-    use super::{Deserializable, Felt, RpoDigest, Serializable, DIGEST_BYTES};
-    use crate::utils::SliceReader;
+    use alloc::string::String;
+
     use rand_utils::rand_value;

+    use super::{Deserializable, Felt, RpoDigest, Serializable, DIGEST_BYTES, DIGEST_SIZE};
+    use crate::utils::SliceReader;
+
     #[test]
     fn digest_serialization() {
         let e1 = Felt::new(rand_value());
@@ -281,7 +349,6 @@ mod tests {
         assert_eq!(d1, d2);
     }

-    #[cfg(feature = "std")]
     #[test]
     fn digest_encoding() {
         let digest = RpoDigest([
@@ -296,4 +363,54 @@ mod tests {

         assert_eq!(digest, round_trip);
     }
+
+    #[test]
+    fn test_conversions() {
+        let digest = RpoDigest([
+            Felt::new(rand_value()),
+            Felt::new(rand_value()),
+            Felt::new(rand_value()),
+            Felt::new(rand_value()),
+        ]);
+
+        let v: [Felt; DIGEST_SIZE] = digest.into();
+        let v2: RpoDigest = v.into();
+        assert_eq!(digest, v2);
+
+        let v: [Felt; DIGEST_SIZE] = (&digest).into();
+        let v2: RpoDigest = v.into();
+        assert_eq!(digest, v2);
+
+        let v: [u64; DIGEST_SIZE] = digest.into();
+        let v2: RpoDigest = v.try_into().unwrap();
+        assert_eq!(digest, v2);
+
+        let v: [u64; DIGEST_SIZE] = (&digest).into();
+        let v2: RpoDigest = v.try_into().unwrap();
+        assert_eq!(digest, v2);
+
+        let v: [u8; DIGEST_BYTES] = digest.into();
+        let v2: RpoDigest = v.try_into().unwrap();
+        assert_eq!(digest, v2);
+
+        let v: [u8; DIGEST_BYTES] = (&digest).into();
+        let v2: RpoDigest = v.try_into().unwrap();
+        assert_eq!(digest, v2);
+
+        let v: String = digest.into();
+        let v2: RpoDigest = v.try_into().unwrap();
+        assert_eq!(digest, v2);
+
+        let v: String = (&digest).into();
+        let v2: RpoDigest = v.try_into().unwrap();
+        assert_eq!(digest, v2);
+
+        let v: [u8; DIGEST_BYTES] = digest.into();
+        let v2: RpoDigest = (&v).try_into().unwrap();
+        assert_eq!(digest, v2);
+
+        let v: [u8; DIGEST_BYTES] = (&digest).into();
+        let v2: RpoDigest = (&v).try_into().unwrap();
+        assert_eq!(digest, v2);
+    }
 }

318 src/hash/rescue/rpo/mod.rs Normal file
@@ -0,0 +1,318 @@
use core::ops::Range;

use super::{
    add_constants, add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox, apply_inv_sbox,
    apply_mds, apply_sbox, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ARK1,
    ARK2, BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE, DIGEST_SIZE, INPUT1_RANGE,
    INPUT2_RANGE, MDS, NUM_ROUNDS, ONE, RATE_RANGE, RATE_WIDTH, STATE_WIDTH, ZERO,
};

mod digest;
pub use digest::RpoDigest;

#[cfg(test)]
mod tests;

// HASHER IMPLEMENTATION
// ================================================================================================

/// Implementation of the Rescue Prime Optimized hash function with 256-bit output.
///
/// The hash function is implemented according to the Rescue Prime Optimized
/// [specifications](https://eprint.iacr.org/2022/1577)
///
/// The parameters used to instantiate the function are:
/// * Field: 64-bit prime field with modulus 2^64 - 2^32 + 1.
/// * State width: 12 field elements.
/// * Capacity size: 4 field elements.
/// * Number of rounds: 7.
/// * S-Box degree: 7.
///
/// The above parameters target a 128-bit security level. The digest consists of four field
/// elements and it can be serialized into 32 bytes (256 bits).
///
/// ## Hash output consistency
/// Functions [hash_elements()](Rpo256::hash_elements), [merge()](Rpo256::merge), and
/// [merge_with_int()](Rpo256::merge_with_int) are internally consistent. That is, computing
/// a hash for the same set of elements using these functions will always produce the same
/// result. For example, merging two digests using [merge()](Rpo256::merge) will produce the
/// same result as hashing 8 elements which make up these digests using the
/// [hash_elements()](Rpo256::hash_elements) function.
///
/// However, the [hash()](Rpo256::hash) function is not consistent with the functions mentioned
/// above. For example, if we take two field elements, serialize them to bytes and hash them using
/// [hash()](Rpo256::hash), the result will differ from the result obtained by hashing these
/// elements directly using the [hash_elements()](Rpo256::hash_elements) function. The reason for
/// this difference is that the [hash()](Rpo256::hash) function needs to be able to handle
/// arbitrary binary strings, which may or may not encode valid field elements - and thus, the
/// deserialization procedure used by this function is different from the procedure used to
/// deserialize valid field elements.
///
/// Thus, if the underlying data consists of valid field elements, it might make more sense
/// to deserialize them into field elements and then hash them using the
/// [hash_elements()](Rpo256::hash_elements) function rather than hashing the serialized bytes
/// using the [hash()](Rpo256::hash) function.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Rpo256();

impl Hasher for Rpo256 {
    /// Rpo256 collision resistance is 128-bits.
    const COLLISION_RESISTANCE: u32 = 128;

    type Digest = RpoDigest;

    fn hash(bytes: &[u8]) -> Self::Digest {
        // initialize the state with zeroes
        let mut state = [ZERO; STATE_WIDTH];

        // set the capacity (first element) to a flag on whether or not the input length is evenly
        // divided by the rate. this will prevent collisions between padded and non-padded inputs,
        // and will rule out the need to perform an extra permutation in case of evenly divided
        // inputs.
        let is_rate_multiple = bytes.len() % RATE_WIDTH == 0;
        if !is_rate_multiple {
            state[CAPACITY_RANGE.start] = ONE;
        }

        // initialize a buffer to receive the little-endian elements.
        let mut buf = [0_u8; 8];

        // iterate the chunks of bytes, creating a field element from each chunk and copying it
        // into the state.
        //
        // every time the rate range is filled, a permutation is performed. if the final value of
        // `i` is not zero, then the chunks count wasn't enough to fill the state range, and an
        // additional permutation must be performed.
        let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| {
            // the last element of the iteration may or may not be a full chunk. if it's not, then
            // we need to pad the remainder bytes of the chunk with zeroes, separated by a `1`.
            // this will avoid collisions.
            if chunk.len() == BINARY_CHUNK_SIZE {
                buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
            } else {
                buf.fill(0);
                buf[..chunk.len()].copy_from_slice(chunk);
                buf[chunk.len()] = 1;
            }

            // set the current rate element to the input. since we take at most 7 bytes, we are
            // guaranteed that the input data will fit into a single field element.
            state[RATE_RANGE.start + i] = Felt::new(u64::from_le_bytes(buf));

            // proceed filling the range. if it's full, then we apply a permutation and reset the
            // counter to the beginning of the range.
            if i == RATE_WIDTH - 1 {
                Self::apply_permutation(&mut state);
                0
            } else {
                i + 1
            }
        });

        // if we absorbed some elements but didn't apply a permutation to them (would happen when
        // the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation. we
        // don't need to apply any extra padding because the first capacity element contains a
        // flag indicating whether the input is evenly divisible by the rate.
        if i != 0 {
            state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO);
            state[RATE_RANGE.start + i] = ONE;
            Self::apply_permutation(&mut state);
        }

        // return the first 4 elements of the rate as hash result.
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
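
    // Sanity sketch (not part of the crate): BINARY_CHUNK_SIZE is 7 because any
    // 7-byte little-endian chunk (even with the `1` padding marker at index <= 6)
    // stays below 2^56 < p = 2^64 - 2^32 + 1, so the `Felt::new(u64::from_le_bytes(buf))`
    // conversion above can never wrap.
    fn _seven_byte_chunks_fit_in_a_felt() {
        const MODULUS: u64 = 0xffff_ffff_0000_0001;
        let max_chunk = u64::from_le_bytes([0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0]);
        assert!(max_chunk < MODULUS); // 2^56 - 1 < p
    }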

    fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
        // initialize the state by copying the digest elements into the rate portion of the state
        // (8 total elements), and set the capacity elements to 0.
        let mut state = [ZERO; STATE_WIDTH];
        let it = Self::Digest::digests_as_elements(values.iter());
        for (i, v) in it.enumerate() {
            state[RATE_RANGE.start + i] = *v;
        }

        // apply the RPO permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
        // initialize the state as follows:
        // - seed is copied into the first 4 elements of the rate portion of the state.
        // - if the value fits into a single field element, copy it into the fifth rate element
        //   and set the sixth rate element to 1.
        // - if the value doesn't fit into a single field element, split it into two field
        //   elements, copy them into rate elements 5 and 6, and set the seventh rate element
        //   to 1.
        // - set the first capacity element to 1
        let mut state = [ZERO; STATE_WIDTH];
        state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
        state[INPUT2_RANGE.start] = Felt::new(value);
        if value < Felt::MODULUS {
            state[INPUT2_RANGE.start + 1] = ONE;
        } else {
            state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
            state[INPUT2_RANGE.start + 2] = ONE;
        }

        // common padding for both cases
        state[CAPACITY_RANGE.start] = ONE;

        // apply the RPO permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
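
    // Sanity sketch (not part of the crate): `merge_with_int` stores `value` as
    // Felt::new(value) (i.e. value mod p) plus, when value >= p, a second element
    // holding value / p. The pair (value mod p, value / p) uniquely recovers any
    // u64, so no two integer inputs absorb the same rate content.
    fn _merge_with_int_value_decomposition() {
        const P: u64 = 0xffff_ffff_0000_0001;
        for value in [0u64, 1, P - 1, P, P + 7, u64::MAX] {
            let lo = value % P; // what Felt::new(value) reduces to
            let hi = value / P; // 0 or 1 for a u64; stored only when value >= p
            assert_eq!((hi as u128) * (P as u128) + lo as u128, value as u128);
        }
    }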
}

impl ElementHasher for Rpo256 {
    type BaseField = Felt;

    fn hash_elements<E: FieldElement<BaseField = Self::BaseField>>(elements: &[E]) -> Self::Digest {
        // convert the elements into a list of base field elements
        let elements = E::slice_as_base_elements(elements);

        // initialize state to all zeros, except for the first element of the capacity part, which
        // is set to 1 if the number of elements is not a multiple of RATE_WIDTH.
        let mut state = [ZERO; STATE_WIDTH];
        if elements.len() % RATE_WIDTH != 0 {
            state[CAPACITY_RANGE.start] = ONE;
        }

        // absorb elements into the state one by one until the rate portion of the state is filled
        // up; then apply the Rescue permutation and start absorbing again; repeat until all
        // elements have been absorbed
        let mut i = 0;
        for &element in elements.iter() {
            state[RATE_RANGE.start + i] = element;
            i += 1;
            if i % RATE_WIDTH == 0 {
                Self::apply_permutation(&mut state);
                i = 0;
            }
        }

        // if we absorbed some elements but didn't apply a permutation to them (would happen when
        // the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation after
        // padding by appending a 1 followed by as many 0 as necessary to make the input length a
        // multiple of the RATE_WIDTH.
        if i > 0 {
            state[RATE_RANGE.start + i] = ONE;
            i += 1;
            while i != RATE_WIDTH {
                state[RATE_RANGE.start + i] = ZERO;
                i += 1;
            }
            Self::apply_permutation(&mut state);
        }

        // return the first 4 elements of the state as hash result
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
}

// HASH FUNCTION IMPLEMENTATION
// ================================================================================================

impl Rpo256 {
    // CONSTANTS
    // --------------------------------------------------------------------------------------------

    /// The number of rounds is set to 7 to target 128-bit security level.
    pub const NUM_ROUNDS: usize = NUM_ROUNDS;

    /// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
    /// the remaining 4 elements are reserved for capacity.
    pub const STATE_WIDTH: usize = STATE_WIDTH;

    /// The rate portion of the state is located in elements 4 through 11 (inclusive).
    pub const RATE_RANGE: Range<usize> = RATE_RANGE;

    /// The capacity portion of the state is located in elements 0, 1, 2, and 3.
    pub const CAPACITY_RANGE: Range<usize> = CAPACITY_RANGE;

    /// The output of the hash function can be read from state elements 4, 5, 6, and 7.
    pub const DIGEST_RANGE: Range<usize> = DIGEST_RANGE;

    /// MDS matrix used for computing the linear layer in an RPO round.
    pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS;

    /// Round constants added to the hasher state in the first half of the RPO round.
    pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1;

    /// Round constants added to the hasher state in the second half of the RPO round.
    pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2;

    // TRAIT PASS-THROUGH FUNCTIONS
    // --------------------------------------------------------------------------------------------

    /// Returns a hash of the provided sequence of bytes.
    #[inline(always)]
    pub fn hash(bytes: &[u8]) -> RpoDigest {
        <Self as Hasher>::hash(bytes)
    }

    /// Returns a hash of two digests. This method is intended for use in construction of
    /// Merkle trees and verification of Merkle paths.
    #[inline(always)]
    pub fn merge(values: &[RpoDigest; 2]) -> RpoDigest {
        <Self as Hasher>::merge(values)
    }

    /// Returns a hash of the provided field elements.
    #[inline(always)]
    pub fn hash_elements<E: FieldElement<BaseField = Felt>>(elements: &[E]) -> RpoDigest {
        <Self as ElementHasher>::hash_elements(elements)
    }

    // DOMAIN IDENTIFIER
    // --------------------------------------------------------------------------------------------

    /// Returns a hash of two digests and a domain identifier.
    pub fn merge_in_domain(values: &[RpoDigest; 2], domain: Felt) -> RpoDigest {
        // initialize the state by copying the digest elements into the rate portion of the state
        // (8 total elements), and set the capacity elements to 0.
        let mut state = [ZERO; STATE_WIDTH];
        let it = RpoDigest::digests_as_elements(values.iter());
        for (i, v) in it.enumerate() {
            state[RATE_RANGE.start + i] = *v;
        }

        // set the second capacity element to the domain value. The first capacity element is used
        // for padding purposes.
        state[CAPACITY_RANGE.start + 1] = domain;

        // apply the RPO permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    // RESCUE PERMUTATION
    // --------------------------------------------------------------------------------------------

    /// Applies RPO permutation to the provided state.
    #[inline(always)]
    pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) {
        for i in 0..NUM_ROUNDS {
            Self::apply_round(state, i);
        }
    }

    /// RPO round function.
    #[inline(always)]
    pub fn apply_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
        // apply first half of RPO round
        apply_mds(state);
        if !add_constants_and_apply_sbox(state, &ARK1[round]) {
            add_constants(state, &ARK1[round]);
            apply_sbox(state);
        }

        // apply second half of RPO round
        apply_mds(state);
        if !add_constants_and_apply_inv_sbox(state, &ARK2[round]) {
            add_constants(state, &ARK2[round]);
            apply_inv_sbox(state);
        }
    }
}
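
// Usage sketch (illustrative; assumes the crate's re-export path
// `miden_crypto::hash::rpo` and the API shown above): merging two digests
// equals hashing their eight underlying field elements, per the
// "Hash output consistency" note.
fn _rpo256_usage_demo() {
    let d1 = Rpo256::hash(b"hello world");
    let d2 = Rpo256::hash(b"hello miden");
    let merged = Rpo256::merge(&[d1, d2]);
    let elements: alloc::vec::Vec<Felt> = d1.iter().chain(d2.iter()).copied().collect();
    assert_eq!(merged, Rpo256::hash_elements(&elements));
}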

@@ -1,21 +1,12 @@
-use super::{
-    Felt, FieldElement, Hasher, Rpo256, RpoDigest, StarkField, ALPHA, INV_ALPHA, ONE, STATE_WIDTH,
-    ZERO,
-};
-use crate::{
-    utils::collections::{BTreeSet, Vec},
-    Word,
-};
-use core::convert::TryInto;
 use proptest::prelude::*;
 use rand_utils::rand_value;

-#[test]
-fn test_alphas() {
-    let e: Felt = Felt::new(rand_value());
-    let e_exp = e.exp(ALPHA);
-    assert_eq!(e, e_exp.exp(INV_ALPHA));
-}
+use super::{
+    super::{apply_inv_sbox, apply_sbox, ALPHA, INV_ALPHA},
+    Felt, FieldElement, Hasher, Rpo256, RpoDigest, StarkField, ONE, STATE_WIDTH, ZERO,
+};
+use crate::Word;
+use alloc::{collections::BTreeSet, vec::Vec};

 #[test]
 fn test_sbox() {
@@ -25,7 +16,7 @@ fn test_sbox() {
     expected.iter_mut().for_each(|v| *v = v.exp(ALPHA));

     let mut actual = state;
-    Rpo256::apply_sbox(&mut actual);
+    apply_sbox(&mut actual);

     assert_eq!(expected, actual);
 }
@@ -38,7 +29,7 @@ fn test_inv_sbox() {
     expected.iter_mut().for_each(|v| *v = v.exp(INV_ALPHA));

     let mut actual = state;
-    Rpo256::apply_inv_sbox(&mut actual);
+    apply_inv_sbox(&mut actual);

     assert_eq!(expected, actual);
 }

406 src/hash/rescue/rpx/digest.rs Normal file
@@ -0,0 +1,406 @@
use alloc::string::String;
use core::{cmp::Ordering, fmt::Display, ops::Deref};

use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO};
use crate::{
    rand::Randomizable,
    utils::{
        bytes_to_hex_string, hex_to_bytes, ByteReader, ByteWriter, Deserializable,
        DeserializationError, HexParseError, Serializable,
    },
};

// DIGEST TRAIT IMPLEMENTATIONS
// ================================================================================================

#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))]
pub struct RpxDigest([Felt; DIGEST_SIZE]);

impl RpxDigest {
    pub const fn new(value: [Felt; DIGEST_SIZE]) -> Self {
        Self(value)
    }

    pub fn as_elements(&self) -> &[Felt] {
        self.as_ref()
    }

    pub fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
        <Self as Digest>::as_bytes(self)
    }

    pub fn digests_as_elements<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt>
    where
        I: Iterator<Item = &'a Self>,
    {
        digests.flat_map(|d| d.0.iter())
    }

    /// Returns hexadecimal representation of this digest prefixed with `0x`.
    pub fn to_hex(&self) -> String {
        bytes_to_hex_string(self.as_bytes())
    }
}

impl Digest for RpxDigest {
    fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
        let mut result = [0; DIGEST_BYTES];

        result[..8].copy_from_slice(&self.0[0].as_int().to_le_bytes());
        result[8..16].copy_from_slice(&self.0[1].as_int().to_le_bytes());
        result[16..24].copy_from_slice(&self.0[2].as_int().to_le_bytes());
        result[24..].copy_from_slice(&self.0[3].as_int().to_le_bytes());

        result
    }
}

impl Deref for RpxDigest {
    type Target = [Felt; DIGEST_SIZE];

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl Ord for RpxDigest {
    fn cmp(&self, other: &Self) -> Ordering {
        // compare the inner u64 of both elements.
        //
        // it will iterate the elements and will return the first computation different from
        // `Equal`. Otherwise, the ordering is equal.
        //
        // the endianness is irrelevant here because, this being a cryptographically secure
        // hash computation, the digest shouldn't have any ordered property of its input.
        //
        // finally, we use `Felt::inner` instead of `Felt::as_int` so we avoid performing a
        // montgomery reduction for every limb. that is safe because every inner element of the
        // digest is guaranteed to be in its canonical form (that is, `x in [0,p)`).
        self.0.iter().map(Felt::inner).zip(other.0.iter().map(Felt::inner)).fold(
            Ordering::Equal,
            |ord, (a, b)| match ord {
                Ordering::Equal => a.cmp(&b),
                _ => ord,
            },
        )
    }
}

impl PartialOrd for RpxDigest {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Display for RpxDigest {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let encoded: String = self.into();
        write!(f, "{}", encoded)?;
        Ok(())
    }
}

impl Randomizable for RpxDigest {
    const VALUE_SIZE: usize = DIGEST_BYTES;

    fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
        let bytes_array: Option<[u8; 32]> = bytes.try_into().ok();
        if let Some(bytes_array) = bytes_array {
            Self::try_from(bytes_array).ok()
        } else {
            None
        }
    }
}

// CONVERSIONS: FROM RPX DIGEST
// ================================================================================================

impl From<&RpxDigest> for [Felt; DIGEST_SIZE] {
    fn from(value: &RpxDigest) -> Self {
        value.0
    }
}

impl From<RpxDigest> for [Felt; DIGEST_SIZE] {
    fn from(value: RpxDigest) -> Self {
        value.0
    }
}

impl From<&RpxDigest> for [u64; DIGEST_SIZE] {
    fn from(value: &RpxDigest) -> Self {
        [
            value.0[0].as_int(),
            value.0[1].as_int(),
            value.0[2].as_int(),
            value.0[3].as_int(),
        ]
    }
}

impl From<RpxDigest> for [u64; DIGEST_SIZE] {
    fn from(value: RpxDigest) -> Self {
        [
            value.0[0].as_int(),
            value.0[1].as_int(),
            value.0[2].as_int(),
            value.0[3].as_int(),
        ]
    }
}

impl From<&RpxDigest> for [u8; DIGEST_BYTES] {
    fn from(value: &RpxDigest) -> Self {
        value.as_bytes()
    }
}

impl From<RpxDigest> for [u8; DIGEST_BYTES] {
    fn from(value: RpxDigest) -> Self {
        value.as_bytes()
    }
}

impl From<RpxDigest> for String {
    /// The returned string starts with `0x`.
    fn from(value: RpxDigest) -> Self {
        value.to_hex()
    }
}

impl From<&RpxDigest> for String {
    /// The returned string starts with `0x`.
    fn from(value: &RpxDigest) -> Self {
        (*value).into()
    }
}

// CONVERSIONS: TO RPX DIGEST
// ================================================================================================

#[derive(Copy, Clone, Debug)]
pub enum RpxDigestError {
    /// The provided u64 integer does not fit in the field's modulus.
    InvalidInteger,
}

impl From<&[Felt; DIGEST_SIZE]> for RpxDigest {
    fn from(value: &[Felt; DIGEST_SIZE]) -> Self {
        Self(*value)
    }
}

impl From<[Felt; DIGEST_SIZE]> for RpxDigest {
    fn from(value: [Felt; DIGEST_SIZE]) -> Self {
        Self(value)
    }
}

impl TryFrom<[u8; DIGEST_BYTES]> for RpxDigest {
    type Error = HexParseError;

    fn try_from(value: [u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
        // Note: the input length is known, the conversion from slice to array must succeed so the
        // `unwrap`s below are safe
        let a = u64::from_le_bytes(value[0..8].try_into().unwrap());
        let b = u64::from_le_bytes(value[8..16].try_into().unwrap());
        let c = u64::from_le_bytes(value[16..24].try_into().unwrap());
        let d = u64::from_le_bytes(value[24..32].try_into().unwrap());

        if [a, b, c, d].iter().any(|v| *v >= Felt::MODULUS) {
            return Err(HexParseError::OutOfRange);
        }

        Ok(RpxDigest([Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)]))
    }
}

impl TryFrom<&[u8; DIGEST_BYTES]> for RpxDigest {
    type Error = HexParseError;

    fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<&[u8]> for RpxDigest {
    type Error = HexParseError;

    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<[u64; DIGEST_SIZE]> for RpxDigest {
    type Error = RpxDigestError;

    fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
        Ok(Self([
            value[0].try_into().map_err(|_| RpxDigestError::InvalidInteger)?,
            value[1].try_into().map_err(|_| RpxDigestError::InvalidInteger)?,
            value[2].try_into().map_err(|_| RpxDigestError::InvalidInteger)?,
            value[3].try_into().map_err(|_| RpxDigestError::InvalidInteger)?,
        ]))
    }
}

impl TryFrom<&[u64; DIGEST_SIZE]> for RpxDigest {
    type Error = RpxDigestError;

    fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
        (*value).try_into()
    }
}

impl TryFrom<&str> for RpxDigest {
    type Error = HexParseError;

    /// Expects the string to start with `0x`.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        hex_to_bytes(value).and_then(|v| v.try_into())
    }
}

impl TryFrom<String> for RpxDigest {
    type Error = HexParseError;

    /// Expects the string to start with `0x`.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        value.as_str().try_into()
    }
}

impl TryFrom<&String> for RpxDigest {
    type Error = HexParseError;

    /// Expects the string to start with `0x`.
    fn try_from(value: &String) -> Result<Self, Self::Error> {
        value.as_str().try_into()
    }
}

// SERIALIZATION / DESERIALIZATION
// ================================================================================================

impl Serializable for RpxDigest {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_bytes(&self.as_bytes());
    }
}

impl Deserializable for RpxDigest {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let mut inner: [Felt; DIGEST_SIZE] = [ZERO; DIGEST_SIZE];
        for inner in inner.iter_mut() {
            let e = source.read_u64()?;
            if e >= Felt::MODULUS {
                return Err(DeserializationError::InvalidValue(String::from(
                    "Value not in the appropriate range",
                )));
            }
            *inner = Felt::new(e);
        }

        Ok(Self(inner))
    }
}

// TESTS
// ================================================================================================

#[cfg(test)]
|
||||
mod tests {
|
||||
use alloc::string::String;
|
||||
use rand_utils::rand_value;
|
||||
|
||||
use super::{Deserializable, Felt, RpxDigest, Serializable, DIGEST_BYTES, DIGEST_SIZE};
|
||||
use crate::utils::SliceReader;
|
||||
|
||||
#[test]
|
||||
fn digest_serialization() {
|
||||
let e1 = Felt::new(rand_value());
|
||||
let e2 = Felt::new(rand_value());
|
||||
let e3 = Felt::new(rand_value());
|
||||
let e4 = Felt::new(rand_value());
|
||||
|
||||
let d1 = RpxDigest([e1, e2, e3, e4]);
|
||||
|
||||
let mut bytes = vec![];
|
||||
d1.write_into(&mut bytes);
|
||||
assert_eq!(DIGEST_BYTES, bytes.len());
|
||||
|
||||
let mut reader = SliceReader::new(&bytes);
|
||||
let d2 = RpxDigest::read_from(&mut reader).unwrap();
|
||||
|
||||
assert_eq!(d1, d2);
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
#[test]
|
||||
fn digest_encoding() {
|
||||
let digest = RpxDigest([
|
||||
Felt::new(rand_value()),
|
||||
Felt::new(rand_value()),
|
||||
Felt::new(rand_value()),
|
||||
Felt::new(rand_value()),
|
||||
]);
|
||||
|
||||
let string: String = digest.into();
|
||||
let round_trip: RpxDigest = string.try_into().expect("decoding failed");
|
||||
|
||||
assert_eq!(digest, round_trip);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conversions() {
|
||||
let digest = RpxDigest([
|
||||
Felt::new(rand_value()),
|
||||
Felt::new(rand_value()),
|
||||
Felt::new(rand_value()),
|
||||
Felt::new(rand_value()),
|
||||
]);
|
||||
|
||||
let v: [Felt; DIGEST_SIZE] = digest.into();
|
||||
let v2: RpxDigest = v.into();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: [Felt; DIGEST_SIZE] = (&digest).into();
|
||||
let v2: RpxDigest = v.into();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: [u64; DIGEST_SIZE] = digest.into();
|
||||
let v2: RpxDigest = v.try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: [u64; DIGEST_SIZE] = (&digest).into();
|
||||
let v2: RpxDigest = v.try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: [u8; DIGEST_BYTES] = digest.into();
|
||||
let v2: RpxDigest = v.try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: [u8; DIGEST_BYTES] = (&digest).into();
|
||||
let v2: RpxDigest = v.try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: String = digest.into();
|
||||
let v2: RpxDigest = v.try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: String = (&digest).into();
|
||||
let v2: RpxDigest = v.try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: [u8; DIGEST_BYTES] = digest.into();
|
||||
let v2: RpxDigest = (&v).try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
|
||||
let v: [u8; DIGEST_BYTES] = (&digest).into();
|
||||
let v2: RpxDigest = (&v).try_into().unwrap();
|
||||
assert_eq!(digest, v2);
|
||||
}
|
||||
}
|
||||
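The conversion impls above compose. A minimal usage sketch of round-tripping a digest through its `0x`-prefixed hex encoding (the `miden_crypto::hash::rpx` import path is an assumption; the diff only shows the module contents):

use miden_crypto::hash::rpx::RpxDigest;

// Encode to a `0x`-prefixed hex string via `String: From<RpxDigest>`, then parse it
// back via `TryFrom<&str>`; parsing fails with `HexParseError` on malformed input.
fn hex_round_trip(digest: RpxDigest) -> RpxDigest {
    let encoded: String = digest.into();
    encoded.as_str().try_into().expect("freshly encoded digest must parse")
}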
353
src/hash/rescue/rpx/mod.rs
Normal file
@@ -0,0 +1,353 @@
use core::ops::Range;

use super::{
    add_constants, add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox, apply_inv_sbox,
    apply_mds, apply_sbox, CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher,
    StarkField, ARK1, ARK2, BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE,
    DIGEST_SIZE, INPUT1_RANGE, INPUT2_RANGE, MDS, NUM_ROUNDS, RATE_RANGE, RATE_WIDTH, STATE_WIDTH,
    ZERO,
};

mod digest;
pub use digest::RpxDigest;

pub type CubicExtElement = CubeExtension<Felt>;

// HASHER IMPLEMENTATION
// ================================================================================================

/// Implementation of the Rescue Prime eXtension hash function with 256-bit output.
///
/// The hash function is based on the XHash12 construction described in the
/// [specifications](https://eprint.iacr.org/2023/1045).
///
/// The parameters used to instantiate the function are:
/// * Field: 64-bit prime field with modulus 2^64 - 2^32 + 1.
/// * State width: 12 field elements.
/// * Capacity size: 4 field elements.
/// * S-Box degree: 7.
/// * Rounds: There are 3 different types of rounds:
///   - (FB): `apply_mds` → `add_constants` → `apply_sbox` → `apply_mds` → `add_constants` →
///     `apply_inv_sbox`.
///   - (E): `add_constants` → `ext_sbox` (which is raising to power 7 in the degree 3
///     extension field).
///   - (M): `apply_mds` → `add_constants`.
/// * Permutation: (FB) (E) (FB) (E) (FB) (E) (M).
///
/// The above parameters target a 128-bit security level. The digest consists of four field
/// elements and it can be serialized into 32 bytes (256 bits).
///
/// ## Hash output consistency
/// Functions [hash_elements()](Rpx256::hash_elements), [merge()](Rpx256::merge), and
/// [merge_with_int()](Rpx256::merge_with_int) are internally consistent. That is, computing
/// a hash for the same set of elements using these functions will always produce the same
/// result. For example, merging two digests using [merge()](Rpx256::merge) will produce the
/// same result as hashing the 8 elements which make up these digests using the
/// [hash_elements()](Rpx256::hash_elements) function.
///
/// However, the [hash()](Rpx256::hash) function is not consistent with the functions mentioned
/// above. For example, if we take two field elements, serialize them to bytes, and hash them
/// using [hash()](Rpx256::hash), the result will differ from the result obtained by hashing
/// these elements directly using the [hash_elements()](Rpx256::hash_elements) function. The
/// reason for this difference is that the [hash()](Rpx256::hash) function needs to be able to
/// handle arbitrary binary strings, which may or may not encode valid field elements - and
/// thus, the deserialization procedure used by this function is different from the procedure
/// used to deserialize valid field elements.
///
/// Thus, if the underlying data consists of valid field elements, it might make more sense
/// to deserialize them into field elements and then hash them using the
/// [hash_elements()](Rpx256::hash_elements) function rather than hashing the serialized bytes
/// using the [hash()](Rpx256::hash) function.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Rpx256();
impl Hasher for Rpx256 {
    /// Rpx256 collision resistance is 128-bits.
    const COLLISION_RESISTANCE: u32 = 128;

    type Digest = RpxDigest;

    fn hash(bytes: &[u8]) -> Self::Digest {
        // initialize the state with zeroes
        let mut state = [ZERO; STATE_WIDTH];

        // determine the number of field elements needed to encode `bytes` when each field
        // element represents at most 7 bytes.
        let num_field_elem = bytes.len().div_ceil(BINARY_CHUNK_SIZE);

        // set the first capacity element to `RATE_WIDTH + (num_field_elem % RATE_WIDTH)`. We do
        // this to achieve:
        // 1. Domain separating hashing of `[u8]` from hashing of `[Felt]`.
        // 2. Avoiding collisions at the `[Felt]` representation of the encoded bytes.
        state[CAPACITY_RANGE.start] =
            Felt::from((RATE_WIDTH + (num_field_elem % RATE_WIDTH)) as u8);

        // initialize a buffer to receive the little-endian elements.
        let mut buf = [0_u8; 8];

        // iterate the chunks of bytes, creating a field element from each chunk and copying it
        // into the state.
        //
        // every time the rate range is filled, a permutation is performed. if the final value of
        // `i` is not zero, then the chunks count wasn't enough to fill the state range, and an
        // additional permutation must be performed.
        let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| {
            // copy the chunk into the buffer
            if i != num_field_elem - 1 {
                buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
            } else {
                // on the last iteration, we pad `buf` with a 1 followed by as many 0's as are
                // needed to fill it
                buf.fill(0);
                buf[..chunk.len()].copy_from_slice(chunk);
                buf[chunk.len()] = 1;
            }

            // set the current rate element to the input. since we take at most 7 bytes, we are
            // guaranteed that the input data will fit into a single field element.
            state[RATE_RANGE.start + i] = Felt::new(u64::from_le_bytes(buf));

            // proceed filling the range. if it's full, then we apply a permutation and reset
            // the counter to the beginning of the range.
            if i == RATE_WIDTH - 1 {
                Self::apply_permutation(&mut state);
                0
            } else {
                i + 1
            }
        });

        // if we absorbed some elements but didn't apply a permutation to them (would happen
        // when the number of elements is not a multiple of RATE_WIDTH), apply the RPX
        // permutation. we don't need to apply any extra padding because the first capacity
        // element contains a flag indicating the number of field elements constituting the last
        // block when the latter is not divisible by `RATE_WIDTH`.
        if i != 0 {
            state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO);
            Self::apply_permutation(&mut state);
        }

        // return the first 4 elements of the rate as hash result.
        RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
        // initialize the state by copying the digest elements into the rate portion of the
        // state (8 total elements), and set the capacity elements to 0.
        let mut state = [ZERO; STATE_WIDTH];
        let it = Self::Digest::digests_as_elements(values.iter());
        for (i, v) in it.enumerate() {
            state[RATE_RANGE.start + i] = *v;
        }

        // apply the RPX permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
        // initialize the state as follows:
        // - seed is copied into the first 4 elements of the rate portion of the state.
        // - if the value fits into a single field element, copy it into the fifth rate element
        //   and set the first capacity element to 5.
        // - if the value doesn't fit into a single field element, split it into two field
        //   elements, copy them into rate elements 5 and 6, and set the first capacity element
        //   to 6.
        let mut state = [ZERO; STATE_WIDTH];
        state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
        state[INPUT2_RANGE.start] = Felt::new(value);
        if value < Felt::MODULUS {
            state[CAPACITY_RANGE.start] = Felt::from(5_u8);
        } else {
            state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
            state[CAPACITY_RANGE.start] = Felt::from(6_u8);
        }

        // apply the RPX permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
}
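A hedged sketch of the internal-consistency property documented above: merging two digests is equivalent to hashing their 8 underlying field elements directly, because both paths absorb the same elements into the rate with a zero capacity flag (the import path is an assumption; the diff only shows the module contents):

use miden_crypto::hash::rpx::{Rpx256, RpxDigest};

// Both calls absorb the same 8 elements, so the resulting digests match.
fn merge_matches_hash_elements(a: RpxDigest, b: RpxDigest) -> bool {
    let mut elements = Vec::with_capacity(8);
    elements.extend_from_slice(a.as_elements());
    elements.extend_from_slice(b.as_elements());
    Rpx256::merge(&[a, b]) == Rpx256::hash_elements(&elements)
}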
impl ElementHasher for Rpx256 {
    type BaseField = Felt;

    fn hash_elements<E: FieldElement<BaseField = Self::BaseField>>(elements: &[E]) -> Self::Digest {
        // convert the elements into a list of base field elements
        let elements = E::slice_as_base_elements(elements);

        // initialize state to all zeros, except for the first element of the capacity part,
        // which is set to `elements.len() % RATE_WIDTH`.
        let mut state = [ZERO; STATE_WIDTH];
        state[CAPACITY_RANGE.start] = Self::BaseField::from((elements.len() % RATE_WIDTH) as u8);

        // absorb elements into the state one by one until the rate portion of the state is
        // filled up; then apply the Rescue permutation and start absorbing again; repeat until
        // all elements have been absorbed
        let mut i = 0;
        for &element in elements.iter() {
            state[RATE_RANGE.start + i] = element;
            i += 1;
            if i % RATE_WIDTH == 0 {
                Self::apply_permutation(&mut state);
                i = 0;
            }
        }

        // if we absorbed some elements but didn't apply a permutation to them (would happen
        // when the number of elements is not a multiple of RATE_WIDTH), apply the RPX
        // permutation after padding with as many 0s as necessary to make the input length a
        // multiple of RATE_WIDTH.
        if i > 0 {
            while i != RATE_WIDTH {
                state[RATE_RANGE.start + i] = ZERO;
                i += 1;
            }
            Self::apply_permutation(&mut state);
        }

        // return the first 4 elements of the state as hash result
        RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
}
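Conversely, the byte-oriented `hash` and the element-oriented `hash_elements` are not interchangeable, as the doc comment on `Rpx256` explains. A hedged sketch (the import paths and the `as_int` accessor are assumptions based on the `StarkField` API used elsewhere in this diff):

use miden_crypto::{hash::rpx::Rpx256, Felt, StarkField};

fn consistency_demo() {
    let elements = [Felt::new(1), Felt::new(2)];
    let from_elements = Rpx256::hash_elements(&elements);

    // Serialize the same elements to little-endian bytes and hash the bytes instead.
    let mut bytes = Vec::new();
    for e in &elements {
        bytes.extend_from_slice(&e.as_int().to_le_bytes());
    }
    let from_bytes = Rpx256::hash(&bytes);

    // The two digests differ: byte hashing uses 7-byte chunks and a different padding rule.
    assert_ne!(from_elements, from_bytes);
}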
// HASH FUNCTION IMPLEMENTATION
// ================================================================================================

impl Rpx256 {
    // CONSTANTS
    // --------------------------------------------------------------------------------------------

    /// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate
    /// and the remaining 4 elements are reserved for capacity.
    pub const STATE_WIDTH: usize = STATE_WIDTH;

    /// The rate portion of the state is located in elements 4 through 11 (inclusive).
    pub const RATE_RANGE: Range<usize> = RATE_RANGE;

    /// The capacity portion of the state is located in elements 0, 1, 2, and 3.
    pub const CAPACITY_RANGE: Range<usize> = CAPACITY_RANGE;

    /// The output of the hash function can be read from state elements 4, 5, 6, and 7.
    pub const DIGEST_RANGE: Range<usize> = DIGEST_RANGE;

    /// MDS matrix used for computing the linear layer in the (FB) and (E) rounds.
    pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS;

    /// Round constants added to the hasher state in the first half of the round.
    pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1;

    /// Round constants added to the hasher state in the second half of the round.
    pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2;

    // TRAIT PASS-THROUGH FUNCTIONS
    // --------------------------------------------------------------------------------------------

    /// Returns a hash of the provided sequence of bytes.
    #[inline(always)]
    pub fn hash(bytes: &[u8]) -> RpxDigest {
        <Self as Hasher>::hash(bytes)
    }

    /// Returns a hash of two digests. This method is intended for use in construction of
    /// Merkle trees and verification of Merkle paths.
    #[inline(always)]
    pub fn merge(values: &[RpxDigest; 2]) -> RpxDigest {
        <Self as Hasher>::merge(values)
    }

    /// Returns a hash of the provided field elements.
    #[inline(always)]
    pub fn hash_elements<E: FieldElement<BaseField = Felt>>(elements: &[E]) -> RpxDigest {
        <Self as ElementHasher>::hash_elements(elements)
    }

    // DOMAIN IDENTIFIER
    // --------------------------------------------------------------------------------------------

    /// Returns a hash of two digests and a domain identifier.
    pub fn merge_in_domain(values: &[RpxDigest; 2], domain: Felt) -> RpxDigest {
        // initialize the state by copying the digest elements into the rate portion of the
        // state (8 total elements), and set the capacity elements to 0.
        let mut state = [ZERO; STATE_WIDTH];
        let it = RpxDigest::digests_as_elements(values.iter());
        for (i, v) in it.enumerate() {
            state[RATE_RANGE.start + i] = *v;
        }

        // set the second capacity element to the domain value. The first capacity element is
        // used for padding purposes.
        state[CAPACITY_RANGE.start + 1] = domain;

        // apply the RPX permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
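Because the domain identifier lands in a capacity element, the same pair of digests merged under different domains yields unrelated outputs. A hedged usage sketch:

fn domain_separation_demo(a: RpxDigest, b: RpxDigest) {
    let d0 = Rpx256::merge_in_domain(&[a, b], Felt::new(0));
    let d1 = Rpx256::merge_in_domain(&[a, b], Felt::new(1));
    // distinct domain identifiers separate the hash instances for the same inputs
    assert_ne!(d0, d1);
}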
    // RPX PERMUTATION
    // --------------------------------------------------------------------------------------------

    /// Applies the RPX permutation to the provided state.
    #[inline(always)]
    pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) {
        Self::apply_fb_round(state, 0);
        Self::apply_ext_round(state, 1);
        Self::apply_fb_round(state, 2);
        Self::apply_ext_round(state, 3);
        Self::apply_fb_round(state, 4);
        Self::apply_ext_round(state, 5);
        Self::apply_final_round(state, 6);
    }

    // RPX PERMUTATION ROUND FUNCTIONS
    // --------------------------------------------------------------------------------------------

    /// (FB) round function.
    #[inline(always)]
    pub fn apply_fb_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
        apply_mds(state);
        if !add_constants_and_apply_sbox(state, &ARK1[round]) {
            add_constants(state, &ARK1[round]);
            apply_sbox(state);
        }

        apply_mds(state);
        if !add_constants_and_apply_inv_sbox(state, &ARK2[round]) {
            add_constants(state, &ARK2[round]);
            apply_inv_sbox(state);
        }
    }

    /// (E) round function.
    #[inline(always)]
    pub fn apply_ext_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
        // add constants
        add_constants(state, &ARK1[round]);

        // decompose the state into 4 elements in the cubic extension field and apply the power
        // 7 map to each of the elements
        let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] = *state;
        let ext0 = Self::exp7(CubicExtElement::new(s0, s1, s2));
        let ext1 = Self::exp7(CubicExtElement::new(s3, s4, s5));
        let ext2 = Self::exp7(CubicExtElement::new(s6, s7, s8));
        let ext3 = Self::exp7(CubicExtElement::new(s9, s10, s11));

        // recompose the 4 extension field elements back into 12 base field elements
        let arr_ext = [ext0, ext1, ext2, ext3];
        *state = CubicExtElement::slice_as_base_elements(&arr_ext)
            .try_into()
            .expect("shouldn't fail");
    }

    /// (M) round function.
    #[inline(always)]
    pub fn apply_final_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
        apply_mds(state);
        add_constants(state, &ARK1[round]);
    }

    /// Computes an exponentiation to the power 7 in the cubic extension field.
    #[inline(always)]
    pub fn exp7(x: CubeExtension<Felt>) -> CubeExtension<Felt> {
        let x2 = x.square();
        let x4 = x2.square();

        let x3 = x2 * x;
        x3 * x4
    }
}
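The `exp7` helper uses a 4-multiplication addition chain: x^7 = x^3 · x^4, built from two squarings and two products. The same bookkeeping over plain integer exponents (a hypothetical check, not crate code; squaring doubles an exponent, multiplying adds them):

fn exp7_exponent_chain() {
    let x = 1u32;           // exponent of x
    let x2 = x + x;         // square: exponent 2
    let x4 = x2 + x2;       // square: exponent 4
    let x3 = x2 + x;        // multiply: exponent 3
    assert_eq!(x3 + x4, 7); // final multiply: exponent 7
}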
10
src/hash/rescue/tests.rs
Normal file
@@ -0,0 +1,10 @@
use rand_utils::rand_value;

use super::{Felt, FieldElement, ALPHA, INV_ALPHA};

#[test]
fn test_alphas() {
    let e: Felt = Felt::new(rand_value());
    let e_exp = e.exp(ALPHA);
    assert_eq!(e, e_exp.exp(INV_ALPHA));
}
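The test passes because INV_ALPHA is the multiplicative inverse of ALPHA modulo p - 1, so (x^ALPHA)^INV_ALPHA = x by Fermat's little theorem. A quick check of the exponent arithmetic for p = 2^64 - 2^32 + 1 (a standalone sketch, not part of the crate):

const P: u128 = 0xffff_ffff_0000_0001;

fn alpha_inverse_check() {
    // 7 * INV_ALPHA ≡ 1 (mod p - 1), so the two exponentiations cancel.
    assert_eq!((7u128 * 10540996611094048183u128) % (P - 1), 1);
}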
@@ -1,905 +0,0 @@
use super::{Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ONE, ZERO};
use core::{convert::TryInto, ops::Range};

mod digest;
pub use digest::RpoDigest;

mod mds_freq;
use mds_freq::mds_multiply_freq;

#[cfg(test)]
mod tests;

#[cfg(all(target_feature = "sve", feature = "sve"))]
#[link(name = "rpo_sve", kind = "static")]
extern "C" {
    fn add_constants_and_apply_sbox(
        state: *mut std::ffi::c_ulong,
        constants: *const std::ffi::c_ulong,
    ) -> bool;
    fn add_constants_and_apply_inv_sbox(
        state: *mut std::ffi::c_ulong,
        constants: *const std::ffi::c_ulong,
    ) -> bool;
}

// CONSTANTS
// ================================================================================================

/// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
/// the remaining 4 elements are reserved for capacity.
const STATE_WIDTH: usize = 12;

/// The rate portion of the state is located in elements 4 through 11.
const RATE_RANGE: Range<usize> = 4..12;
const RATE_WIDTH: usize = RATE_RANGE.end - RATE_RANGE.start;

const INPUT1_RANGE: Range<usize> = 4..8;
const INPUT2_RANGE: Range<usize> = 8..12;

/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
const CAPACITY_RANGE: Range<usize> = 0..4;

/// The output of the hash function is a digest which consists of 4 field elements or 32 bytes.
///
/// The digest is returned from state elements 4, 5, 6, and 7 (the first four elements of the
/// rate portion).
const DIGEST_RANGE: Range<usize> = 4..8;
const DIGEST_SIZE: usize = DIGEST_RANGE.end - DIGEST_RANGE.start;

/// The number of rounds is set to 7 to target a 128-bit security level.
const NUM_ROUNDS: usize = 7;

/// The number of byte chunks defining a field element when hashing a sequence of bytes.
const BINARY_CHUNK_SIZE: usize = 7;

/// S-Box and Inverse S-Box powers.
///
/// The constants are defined for tests only because the exponentiations in the code are unrolled
/// for efficiency reasons.
#[cfg(test)]
const ALPHA: u64 = 7;
#[cfg(test)]
const INV_ALPHA: u64 = 10540996611094048183;
// HASHER IMPLEMENTATION
// ================================================================================================

/// Implementation of the Rescue Prime Optimized hash function with 256-bit output.
///
/// The hash function is implemented according to the Rescue Prime Optimized
/// [specifications](https://eprint.iacr.org/2022/1577).
///
/// The parameters used to instantiate the function are:
/// * Field: 64-bit prime field with modulus 2^64 - 2^32 + 1.
/// * State width: 12 field elements.
/// * Capacity size: 4 field elements.
/// * Number of rounds: 7.
/// * S-Box degree: 7.
///
/// The above parameters target a 128-bit security level. The digest consists of four field
/// elements and it can be serialized into 32 bytes (256 bits).
///
/// ## Hash output consistency
/// Functions [hash_elements()](Rpo256::hash_elements), [merge()](Rpo256::merge), and
/// [merge_with_int()](Rpo256::merge_with_int) are internally consistent. That is, computing
/// a hash for the same set of elements using these functions will always produce the same
/// result. For example, merging two digests using [merge()](Rpo256::merge) will produce the
/// same result as hashing the 8 elements which make up these digests using the
/// [hash_elements()](Rpo256::hash_elements) function.
///
/// However, the [hash()](Rpo256::hash) function is not consistent with the functions mentioned
/// above. For example, if we take two field elements, serialize them to bytes, and hash them
/// using [hash()](Rpo256::hash), the result will differ from the result obtained by hashing
/// these elements directly using the [hash_elements()](Rpo256::hash_elements) function. The
/// reason for this difference is that the [hash()](Rpo256::hash) function needs to be able to
/// handle arbitrary binary strings, which may or may not encode valid field elements - and
/// thus, the deserialization procedure used by this function is different from the procedure
/// used to deserialize valid field elements.
///
/// Thus, if the underlying data consists of valid field elements, it might make more sense
/// to deserialize them into field elements and then hash them using the
/// [hash_elements()](Rpo256::hash_elements) function rather than hashing the serialized bytes
/// using the [hash()](Rpo256::hash) function.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Rpo256();

impl Hasher for Rpo256 {
    /// Rpo256 collision resistance is the same as the security level, that is 128-bits.
    ///
    /// However, our setup of the capacity registers might drop it to 126.
    ///
    /// Related issue: [#69](https://github.com/0xPolygonMiden/crypto/issues/69)
    const COLLISION_RESISTANCE: u32 = 128;

    type Digest = RpoDigest;

    fn hash(bytes: &[u8]) -> Self::Digest {
        // initialize the state with zeroes
        let mut state = [ZERO; STATE_WIDTH];

        // set the capacity (first element) to a flag on whether or not the input length is
        // evenly divided by the rate. this will prevent collisions between padded and
        // non-padded inputs, and will rule out the need to perform an extra permutation in
        // case of evenly divided inputs.
        let is_rate_multiple = bytes.len() % RATE_WIDTH == 0;
        if !is_rate_multiple {
            state[CAPACITY_RANGE.start] = ONE;
        }

        // initialize a buffer to receive the little-endian elements.
        let mut buf = [0_u8; 8];

        // iterate the chunks of bytes, creating a field element from each chunk and copying it
        // into the state.
        //
        // every time the rate range is filled, a permutation is performed. if the final value of
        // `i` is not zero, then the chunks count wasn't enough to fill the state range, and an
        // additional permutation must be performed.
        let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| {
            // the last element of the iteration may or may not be a full chunk. if it's not,
            // then we need to pad the remainder bytes of the chunk with zeroes, separated by a
            // `1`. this will avoid collisions.
            if chunk.len() == BINARY_CHUNK_SIZE {
                buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
            } else {
                buf.fill(0);
                buf[..chunk.len()].copy_from_slice(chunk);
                buf[chunk.len()] = 1;
            }

            // set the current rate element to the input. since we take at most 7 bytes, we are
            // guaranteed that the input data will fit into a single field element.
            state[RATE_RANGE.start + i] = Felt::new(u64::from_le_bytes(buf));

            // proceed filling the range. if it's full, then we apply a permutation and reset
            // the counter to the beginning of the range.
            if i == RATE_WIDTH - 1 {
                Self::apply_permutation(&mut state);
                0
            } else {
                i + 1
            }
        });

        // if we absorbed some elements but didn't apply a permutation to them (would happen
        // when the number of elements is not a multiple of RATE_WIDTH), apply the RPO
        // permutation. we don't need to apply any extra padding because the first capacity
        // element contains a flag indicating whether the input is evenly divisible by the rate.
        if i != 0 {
            state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO);
            state[RATE_RANGE.start + i] = ONE;
            Self::apply_permutation(&mut state);
        }

        // return the first 4 elements of the rate as hash result.
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
        // initialize the state by copying the digest elements into the rate portion of the
        // state (8 total elements), and set the capacity elements to 0.
        let mut state = [ZERO; STATE_WIDTH];
        let it = Self::Digest::digests_as_elements(values.iter());
        for (i, v) in it.enumerate() {
            state[RATE_RANGE.start + i] = *v;
        }

        // apply the RPO permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
        // initialize the state as follows:
        // - seed is copied into the first 4 elements of the rate portion of the state.
        // - if the value fits into a single field element, copy it into the fifth rate element
        //   and set the sixth rate element to 1.
        // - if the value doesn't fit into a single field element, split it into two field
        //   elements, copy them into rate elements 5 and 6, and set the seventh rate element
        //   to 1.
        // - set the first capacity element to 1
        let mut state = [ZERO; STATE_WIDTH];
        state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
        state[INPUT2_RANGE.start] = Felt::new(value);
        if value < Felt::MODULUS {
            state[INPUT2_RANGE.start + 1] = ONE;
        } else {
            state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
            state[INPUT2_RANGE.start + 2] = ONE;
        }

        // common padding for both cases
        state[CAPACITY_RANGE.start] = ONE;

        // apply the RPO permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
}
impl ElementHasher for Rpo256 {
    type BaseField = Felt;

    fn hash_elements<E: FieldElement<BaseField = Self::BaseField>>(elements: &[E]) -> Self::Digest {
        // convert the elements into a list of base field elements
        let elements = E::slice_as_base_elements(elements);

        // initialize state to all zeros, except for the first element of the capacity part,
        // which is set to 1 if the number of elements is not a multiple of RATE_WIDTH.
        let mut state = [ZERO; STATE_WIDTH];
        if elements.len() % RATE_WIDTH != 0 {
            state[CAPACITY_RANGE.start] = ONE;
        }

        // absorb elements into the state one by one until the rate portion of the state is
        // filled up; then apply the Rescue permutation and start absorbing again; repeat until
        // all elements have been absorbed
        let mut i = 0;
        for &element in elements.iter() {
            state[RATE_RANGE.start + i] = element;
            i += 1;
            if i % RATE_WIDTH == 0 {
                Self::apply_permutation(&mut state);
                i = 0;
            }
        }

        // if we absorbed some elements but didn't apply a permutation to them (would happen
        // when the number of elements is not a multiple of RATE_WIDTH), apply the RPO
        // permutation after padding by appending a 1 followed by as many 0s as necessary to
        // make the input length a multiple of RATE_WIDTH.
        if i > 0 {
            state[RATE_RANGE.start + i] = ONE;
            i += 1;
            while i != RATE_WIDTH {
                state[RATE_RANGE.start + i] = ZERO;
                i += 1;
            }
            Self::apply_permutation(&mut state);
        }

        // return the first 4 elements of the state as hash result
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
}
// HASH FUNCTION IMPLEMENTATION
// ================================================================================================

impl Rpo256 {
    // CONSTANTS
    // --------------------------------------------------------------------------------------------

    /// The number of rounds is set to 7 to target a 128-bit security level.
    pub const NUM_ROUNDS: usize = NUM_ROUNDS;

    /// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate
    /// and the remaining 4 elements are reserved for capacity.
    pub const STATE_WIDTH: usize = STATE_WIDTH;

    /// The rate portion of the state is located in elements 4 through 11 (inclusive).
    pub const RATE_RANGE: Range<usize> = RATE_RANGE;

    /// The capacity portion of the state is located in elements 0, 1, 2, and 3.
    pub const CAPACITY_RANGE: Range<usize> = CAPACITY_RANGE;

    /// The output of the hash function can be read from state elements 4, 5, 6, and 7.
    pub const DIGEST_RANGE: Range<usize> = DIGEST_RANGE;

    /// MDS matrix used for computing the linear layer in an RPO round.
    pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS;

    /// Round constants added to the hasher state in the first half of the RPO round.
    pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1;

    /// Round constants added to the hasher state in the second half of the RPO round.
    pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2;

    // TRAIT PASS-THROUGH FUNCTIONS
    // --------------------------------------------------------------------------------------------

    /// Returns a hash of the provided sequence of bytes.
    #[inline(always)]
    pub fn hash(bytes: &[u8]) -> RpoDigest {
        <Self as Hasher>::hash(bytes)
    }

    /// Returns a hash of two digests. This method is intended for use in construction of
    /// Merkle trees and verification of Merkle paths.
    #[inline(always)]
    pub fn merge(values: &[RpoDigest; 2]) -> RpoDigest {
        <Self as Hasher>::merge(values)
    }

    /// Returns a hash of the provided field elements.
    #[inline(always)]
    pub fn hash_elements<E: FieldElement<BaseField = Felt>>(elements: &[E]) -> RpoDigest {
        <Self as ElementHasher>::hash_elements(elements)
    }

    // DOMAIN IDENTIFIER
    // --------------------------------------------------------------------------------------------

    /// Returns a hash of two digests and a domain identifier.
    pub fn merge_in_domain(values: &[RpoDigest; 2], domain: Felt) -> RpoDigest {
        // initialize the state by copying the digest elements into the rate portion of the
        // state (8 total elements), and set the capacity elements to 0.
        let mut state = [ZERO; STATE_WIDTH];
        let it = RpoDigest::digests_as_elements(values.iter());
        for (i, v) in it.enumerate() {
            state[RATE_RANGE.start + i] = *v;
        }

        // set the second capacity element to the domain value. The first capacity element is
        // used for padding purposes.
        state[CAPACITY_RANGE.start + 1] = domain;

        // apply the RPO permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    // RESCUE PERMUTATION
    // --------------------------------------------------------------------------------------------

    /// Applies the RPO permutation to the provided state.
    #[inline(always)]
    pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) {
        for i in 0..NUM_ROUNDS {
            Self::apply_round(state, i);
        }
    }

    /// RPO round function.
    #[inline(always)]
    pub fn apply_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
        // apply first half of RPO round
        Self::apply_mds(state);
        if !Self::optimized_add_constants_and_apply_sbox(state, &ARK1[round]) {
            Self::add_constants(state, &ARK1[round]);
            Self::apply_sbox(state);
        }

        // apply second half of RPO round
        Self::apply_mds(state);
        if !Self::optimized_add_constants_and_apply_inv_sbox(state, &ARK2[round]) {
            Self::add_constants(state, &ARK2[round]);
            Self::apply_inv_sbox(state);
        }
    }
    // HELPER FUNCTIONS
    // --------------------------------------------------------------------------------------------

    #[inline(always)]
    #[cfg(all(target_feature = "sve", feature = "sve"))]
    fn optimized_add_constants_and_apply_sbox(
        state: &mut [Felt; STATE_WIDTH],
        ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        unsafe {
            add_constants_and_apply_sbox(state.as_mut_ptr() as *mut u64, ark.as_ptr() as *const u64)
        }
    }

    #[inline(always)]
    #[cfg(not(all(target_feature = "sve", feature = "sve")))]
    fn optimized_add_constants_and_apply_sbox(
        _state: &mut [Felt; STATE_WIDTH],
        _ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        false
    }

    #[inline(always)]
    #[cfg(all(target_feature = "sve", feature = "sve"))]
    fn optimized_add_constants_and_apply_inv_sbox(
        state: &mut [Felt; STATE_WIDTH],
        ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        unsafe {
            add_constants_and_apply_inv_sbox(
                state.as_mut_ptr() as *mut u64,
                ark.as_ptr() as *const u64,
            )
        }
    }

    #[inline(always)]
    #[cfg(not(all(target_feature = "sve", feature = "sve")))]
    fn optimized_add_constants_and_apply_inv_sbox(
        _state: &mut [Felt; STATE_WIDTH],
        _ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        false
    }
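The pair of cfg-gated stubs above implements a try-fast-path pattern: when the SVE backend is compiled in, the intrinsic does the work and returns true; otherwise the stub returns false and the caller falls through to the portable code. A minimal sketch of the same pattern with a hypothetical feature name:

#[cfg(feature = "accel")]
fn try_fast_path(state: &mut [u64; 12]) -> bool {
    // call the accelerated routine here and report that the work is done
    true
}

#[cfg(not(feature = "accel"))]
fn try_fast_path(_state: &mut [u64; 12]) -> bool {
    false // no accelerator compiled in: signal the caller to run the portable path
}

fn half_round(state: &mut [u64; 12]) {
    if !try_fast_path(state) {
        // portable fallback goes here
    }
}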
    #[inline(always)]
    fn apply_mds(state: &mut [Felt; STATE_WIDTH]) {
        let mut result = [ZERO; STATE_WIDTH];

        // Using the linearity of the operations, we can split the state into a low||high
        // decomposition, operate on each half with no overflow, and then combine/reduce the
        // result to a field element. The absence of overflow is guaranteed by the fact that
        // the MDS matrix entries are small powers of two in the frequency domain.
        let mut state_l = [0u64; STATE_WIDTH];
        let mut state_h = [0u64; STATE_WIDTH];

        for r in 0..STATE_WIDTH {
            let s = state[r].inner();
            state_h[r] = s >> 32;
            state_l[r] = (s as u32) as u64;
        }

        let state_h = mds_multiply_freq(state_h);
        let state_l = mds_multiply_freq(state_l);

        for r in 0..STATE_WIDTH {
            // recombine the halves and reduce modulo p = 2^64 - 2^32 + 1, using the identity
            // 2^64 ≡ 2^32 - 1 (mod p) to fold the high 64 bits of the sum back into 64 bits
            let s = state_l[r] as u128 + ((state_h[r] as u128) << 32);
            let s_hi = (s >> 64) as u64;
            let s_lo = s as u64;
            let z = (s_hi << 32) - s_hi;
            let (res, over) = s_lo.overflowing_add(z);

            result[r] = Felt::from_mont(res.wrapping_add(0u32.wrapping_sub(over as u32) as u64));
        }
        *state = result;
    }

    #[inline(always)]
    fn add_constants(state: &mut [Felt; STATE_WIDTH], ark: &[Felt; STATE_WIDTH]) {
        state.iter_mut().zip(ark).for_each(|(s, &k)| *s += k);
    }

    #[inline(always)]
    fn apply_sbox(state: &mut [Felt; STATE_WIDTH]) {
        state[0] = state[0].exp7();
        state[1] = state[1].exp7();
        state[2] = state[2].exp7();
        state[3] = state[3].exp7();
        state[4] = state[4].exp7();
        state[5] = state[5].exp7();
        state[6] = state[6].exp7();
        state[7] = state[7].exp7();
        state[8] = state[8].exp7();
        state[9] = state[9].exp7();
        state[10] = state[10].exp7();
        state[11] = state[11].exp7();
    }
    #[inline(always)]
    fn apply_inv_sbox(state: &mut [Felt; STATE_WIDTH]) {
        // compute base^10540996611094048183 using 72 multiplications per array element
        // 10540996611094048183 = 0b1001001001001001001001001001000110110110110110110110110110110111

        // compute base^10
        let mut t1 = *state;
        t1.iter_mut().for_each(|t| *t = t.square());

        // compute base^100
        let mut t2 = t1;
        t2.iter_mut().for_each(|t| *t = t.square());

        // compute base^100100
        let t3 = Self::exp_acc::<Felt, STATE_WIDTH, 3>(t2, t2);

        // compute base^100100100100
        let t4 = Self::exp_acc::<Felt, STATE_WIDTH, 6>(t3, t3);

        // compute base^100100100100100100100100
        let t5 = Self::exp_acc::<Felt, STATE_WIDTH, 12>(t4, t4);

        // compute base^100100100100100100100100100100
        let t6 = Self::exp_acc::<Felt, STATE_WIDTH, 6>(t5, t3);

        // compute base^1001001001001001001001001001000100100100100100100100100100100
        let t7 = Self::exp_acc::<Felt, STATE_WIDTH, 31>(t6, t6);

        // compute base^1001001001001001001001001001000110110110110110110110110110110111
        for (i, s) in state.iter_mut().enumerate() {
            let a = (t7[i].square() * t6[i]).square().square();
            let b = t1[i] * t2[i] * *s;
            *s = a * b;
        }
    }

    #[inline(always)]
    fn exp_acc<B: StarkField, const N: usize, const M: usize>(
        base: [B; N],
        tail: [B; N],
    ) -> [B; N] {
        let mut result = base;
        for _ in 0..M {
            result.iter_mut().for_each(|r| *r = r.square());
        }
        result.iter_mut().zip(tail).for_each(|(r, t)| *r *= t);
        result
    }
}
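`exp_acc` computes `base^(2^M) * tail` element-wise: M squarings followed by one multiplication. Tracking exponents as plain integers (a hypothetical check, not crate code) confirms the comment trail above:

fn exp_acc_exponent(base_exp: u64, tail_exp: u64, m: u32) -> u64 {
    // squaring doubles the exponent M times; the final multiply adds the tail's exponent
    base_exp * (1u64 << m) + tail_exp
}

fn chain_check() {
    // t2 has exponent 0b100; t3 = exp_acc::<_, _, 3>(t2, t2) yields exponent 0b100100
    assert_eq!(exp_acc_exponent(0b100, 0b100, 3), 0b100100);
    // t4 = exp_acc::<_, _, 6>(t3, t3) yields exponent 0b100100100100
    assert_eq!(exp_acc_exponent(0b100100, 0b100100, 6), 0b100100100100);
}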
// MDS
// ================================================================================================
/// RPO MDS matrix
const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = [
    [
        Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10),
        Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8),
    ],
    [
        Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13),
        Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21),
    ],
    [
        Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26),
        Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22),
    ],
    [
        Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8),
        Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6),
    ],
    [
        Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23),
        Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7),
    ],
    [
        Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7),
        Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9),
    ],
    [
        Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8),
        Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10),
    ],
    [
        Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21),
        Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13),
    ],
    [
        Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22),
        Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26),
    ],
    [
        Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6),
        Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8),
    ],
    [
        Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7),
        Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23),
    ],
    [
        Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9),
        Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7),
    ],
];

// ROUND CONSTANTS
// ================================================================================================

/// Rescue round constants, computed as in the
/// [specifications](https://github.com/ASDiscreteMathematics/rpo).
///
/// The constants are broken up into two arrays, ARK1 and ARK2; ARK1 contains the constants for
/// the first half of an RPO round, and ARK2 contains the constants for the second half of an
/// RPO round.
const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
    [
        Felt::new(5789762306288267392), Felt::new(6522564764413701783),
        Felt::new(17809893479458208203), Felt::new(107145243989736508),
        Felt::new(6388978042437517382), Felt::new(15844067734406016715),
        Felt::new(9975000513555218239), Felt::new(3344984123768313364),
        Felt::new(9959189626657347191), Felt::new(12960773468763563665),
        Felt::new(9602914297752488475), Felt::new(16657542370200465908),
    ],
    [
        Felt::new(12987190162843096997), Felt::new(653957632802705281),
        Felt::new(4441654670647621225), Felt::new(4038207883745915761),
        Felt::new(5613464648874830118), Felt::new(13222989726778338773),
        Felt::new(3037761201230264149), Felt::new(16683759727265180203),
        Felt::new(8337364536491240715), Felt::new(3227397518293416448),
        Felt::new(8110510111539674682), Felt::new(2872078294163232137),
    ],
    [
        Felt::new(18072785500942327487), Felt::new(6200974112677013481),
        Felt::new(17682092219085884187), Felt::new(10599526828986756440),
        Felt::new(975003873302957338), Felt::new(8264241093196931281),
        Felt::new(10065763900435475170), Felt::new(2181131744534710197),
        Felt::new(6317303992309418647), Felt::new(1401440938888741532),
        Felt::new(8884468225181997494), Felt::new(13066900325715521532),
    ],
    [
        Felt::new(5674685213610121970), Felt::new(5759084860419474071),
        Felt::new(13943282657648897737), Felt::new(1352748651966375394),
        Felt::new(17110913224029905221), Felt::new(1003883795902368422),
        Felt::new(4141870621881018291), Felt::new(8121410972417424656),
        Felt::new(14300518605864919529), Felt::new(13712227150607670181),
        Felt::new(17021852944633065291), Felt::new(6252096473787587650),
    ],
    [
        Felt::new(4887609836208846458), Felt::new(3027115137917284492),
        Felt::new(9595098600469470675), Felt::new(10528569829048484079),
        Felt::new(7864689113198939815), Felt::new(17533723827845969040),
        Felt::new(5781638039037710951), Felt::new(17024078752430719006),
        Felt::new(109659393484013511), Felt::new(7158933660534805869),
        Felt::new(2955076958026921730), Felt::new(7433723648458773977),
    ],
    [
        Felt::new(16308865189192447297), Felt::new(11977192855656444890),
        Felt::new(12532242556065780287), Felt::new(14594890931430968898),
        Felt::new(7291784239689209784), Felt::new(5514718540551361949),
        Felt::new(10025733853830934803), Felt::new(7293794580341021693),
        Felt::new(6728552937464861756), Felt::new(6332385040983343262),
        Felt::new(13277683694236792804), Felt::new(2600778905124452676),
    ],
    [
        Felt::new(7123075680859040534), Felt::new(1034205548717903090),
        Felt::new(7717824418247931797), Felt::new(3019070937878604058),
        Felt::new(11403792746066867460), Felt::new(10280580802233112374),
        Felt::new(337153209462421218), Felt::new(13333398568519923717),
        Felt::new(3596153696935337464), Felt::new(8104208463525993784),
        Felt::new(14345062289456085693), Felt::new(17036731477169661256),
    ],
];

const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
    [
        Felt::new(6077062762357204287), Felt::new(15277620170502011191),
        Felt::new(5358738125714196705), Felt::new(14233283787297595718),
        Felt::new(13792579614346651365), Felt::new(11614812331536767105),
        Felt::new(14871063686742261166), Felt::new(10148237148793043499),
        Felt::new(4457428952329675767), Felt::new(15590786458219172475),
        Felt::new(10063319113072092615), Felt::new(14200078843431360086),
    ],
    [
        Felt::new(6202948458916099932), Felt::new(17690140365333231091),
        Felt::new(3595001575307484651), Felt::new(373995945117666487),
        Felt::new(1235734395091296013), Felt::new(14172757457833931602),
        Felt::new(707573103686350224), Felt::new(15453217512188187135),
        Felt::new(219777875004506018), Felt::new(17876696346199469008),
        Felt::new(17731621626449383378), Felt::new(2897136237748376248),
    ],
    [
        Felt::new(8023374565629191455), Felt::new(15013690343205953430),
        Felt::new(4485500052507912973), Felt::new(12489737547229155153),
        Felt::new(9500452585969030576), Felt::new(2054001340201038870),
        Felt::new(12420704059284934186), Felt::new(355990932618543755),
        Felt::new(9071225051243523860), Felt::new(12766199826003448536),
        Felt::new(9045979173463556963), Felt::new(12934431667190679898),
    ],
    [
        Felt::new(18389244934624494276), Felt::new(16731736864863925227),
        Felt::new(4440209734760478192), Felt::new(17208448209698888938),
        Felt::new(8739495587021565984), Felt::new(17000774922218161967),
        Felt::new(13533282547195532087), Felt::new(525402848358706231),
        Felt::new(16987541523062161972), Felt::new(5466806524462797102),
        Felt::new(14512769585918244983), Felt::new(10973956031244051118),
    ],
    [
        Felt::new(6982293561042362913), Felt::new(14065426295947720331),
        Felt::new(16451845770444974180), Felt::new(7139138592091306727),
        Felt::new(9012006439959783127), Felt::new(14619614108529063361),
        Felt::new(1394813199588124371), Felt::new(4635111139507788575),
        Felt::new(16217473952264203365), Felt::new(10782018226466330683),
        Felt::new(6844229992533662050), Felt::new(7446486531695178711),
    ],
    [
        Felt::new(3736792340494631448), Felt::new(577852220195055341),
        Felt::new(6689998335515779805), Felt::new(13886063479078013492),
        Felt::new(14358505101923202168), Felt::new(7744142531772274164),
        Felt::new(16135070735728404443), Felt::new(12290902521256031137),
        Felt::new(12059913662657709804), Felt::new(16456018495793751911),
        Felt::new(4571485474751953524), Felt::new(17200392109565783176),
    ],
    [
        Felt::new(17130398059294018733), Felt::new(519782857322261988),
        Felt::new(9625384390925085478), Felt::new(1664893052631119222),
        Felt::new(7629576092524553570), Felt::new(3485239601103661425),
        Felt::new(9755891797164033838), Felt::new(15218148195153269027),
        Felt::new(16460604813734957368), Felt::new(9643968136937729763),
        Felt::new(3611348709641382851), Felt::new(18256379591337759196),
    ],
];
13
src/lib.rs
@@ -1,9 +1,11 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![no_std]

#[cfg(not(feature = "std"))]
#[cfg_attr(test, macro_use)]
#[macro_use]
extern crate alloc;

#[cfg(feature = "std")]
extern crate std;

pub mod dsa;
pub mod hash;
pub mod merkle;
@@ -13,7 +15,10 @@ pub mod utils;

// RE-EXPORTS
// ================================================================================================

pub use winter_math::{fields::f64::BaseElement as Felt, FieldElement, StarkField};
pub use winter_math::{
    fields::{f64::BaseElement as Felt, CubeExtension, QuadExtension},
    FieldElement, StarkField,
};

// TYPE ALIASES
// ================================================================================================
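With the expanded re-exports, extension-field types come straight from the crate root. A hedged usage sketch (assuming winter_math's `CubeExtension::new` takes three base elements, as the RPX code above suggests):

use miden_crypto::{CubeExtension, Felt, FieldElement};

fn cube_ext_demo() {
    let x = CubeExtension::<Felt>::new(Felt::new(1), Felt::new(2), Felt::new(3));
    let _x3 = x.square() * x; // extension-field arithmetic works through the re-export
}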
68
src/main.rs
@@ -1,20 +1,15 @@
|
||||
use clap::Parser;
|
||||
use miden_crypto::{
|
||||
hash::rpo::RpoDigest,
|
||||
merkle::MerkleError,
|
||||
Felt, Word, ONE,
|
||||
{hash::rpo::Rpo256, merkle::TieredSmt},
|
||||
};
|
||||
use rand_utils::rand_value;
|
||||
use std::time::Instant;
|
||||
|
||||
use clap::Parser;
|
||||
use miden_crypto::{
|
||||
hash::rpo::{Rpo256, RpoDigest},
|
||||
merkle::{MerkleError, Smt},
|
||||
Felt, Word, ONE,
|
||||
};
|
||||
use rand_utils::rand_value;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(
|
||||
name = "Benchmark",
|
||||
about = "Tiered SMT benchmark",
|
||||
version,
|
||||
rename_all = "kebab-case"
|
||||
)]
|
||||
#[clap(name = "Benchmark", about = "SMT benchmark", version, rename_all = "kebab-case")]
|
||||
pub struct BenchmarkCmd {
|
||||
/// Size of the tree
|
||||
#[clap(short = 's', long = "size")]
|
||||
@@ -22,11 +17,11 @@ pub struct BenchmarkCmd {
|
||||
}
|
||||
|
||||
fn main() {
|
||||
benchmark_tsmt();
|
||||
benchmark_smt();
|
||||
}
|
||||
|
||||
/// Run a benchmark for the Tiered SMT.
|
||||
pub fn benchmark_tsmt() {
|
||||
/// Run a benchmark for [`Smt`].
|
||||
pub fn benchmark_smt() {
|
||||
let args = BenchmarkCmd::parse();
|
||||
let tree_size = args.size;
|
||||
|
||||
@@ -43,38 +38,25 @@ pub fn benchmark_tsmt() {
|
||||
proof_generation(&mut tree, tree_size).unwrap();
|
||||
}
|
||||
|
||||
/// Runs the construction benchmark for the Tiered SMT, returning the constructed tree.
|
||||
pub fn construction(entries: Vec<(RpoDigest, Word)>, size: u64) -> Result<TieredSmt, MerkleError> {
|
||||
/// Runs the construction benchmark for [`Smt`], returning the constructed tree.
|
||||
pub fn construction(entries: Vec<(RpoDigest, Word)>, size: u64) -> Result<Smt, MerkleError> {
|
||||
println!("Running a construction benchmark:");
|
||||
let now = Instant::now();
|
||||
let tree = TieredSmt::with_entries(entries)?;
|
||||
let tree = Smt::with_entries(entries)?;
|
||||
let elapsed = now.elapsed();
|
||||
println!(
|
||||
"Constructed a TSMT with {} key-value pairs in {:.3} seconds",
|
||||
"Constructed a SMT with {} key-value pairs in {:.3} seconds",
|
||||
size,
|
||||
elapsed.as_secs_f32(),
|
||||
);
|
||||
|
||||
// Count how many nodes end up at each tier
|
||||
let mut nodes_num_16_32_48 = (0, 0, 0);
|
||||
|
||||
tree.upper_leaf_nodes().for_each(|(index, _)| match index.depth() {
|
||||
16 => nodes_num_16_32_48.0 += 1,
|
||||
32 => nodes_num_16_32_48.1 += 1,
|
||||
48 => nodes_num_16_32_48.2 += 1,
|
||||
_ => unreachable!(),
|
||||
});
|
||||
|
||||
println!("Number of nodes on depth 16: {}", nodes_num_16_32_48.0);
|
||||
println!("Number of nodes on depth 32: {}", nodes_num_16_32_48.1);
|
||||
println!("Number of nodes on depth 48: {}", nodes_num_16_32_48.2);
|
||||
println!("Number of nodes on depth 64: {}\n", tree.bottom_leaves().count());
|
||||
println!("Number of leaf nodes: {}\n", tree.leaves().count());
|
||||
|
||||
Ok(tree)
|
||||
}
|
||||
|
||||
/// Runs the insertion benchmark for the Tiered SMT.
|
||||
pub fn insertion(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleError> {
|
||||
/// Runs the insertion benchmark for the [`Smt`].
|
||||
pub fn insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
|
||||
println!("Running an insertion benchmark:");
|
||||
|
||||
let mut insertion_times = Vec::new();
|
||||
@@ -90,9 +72,9 @@ pub fn insertion(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleError> {
|
||||
}
|
||||
|
||||
println!(
|
||||
"An average insertion time measured by 20 inserts into a TSMT with {} key-value pairs is {:.3} milliseconds\n",
|
||||
"An average insertion time measured by 20 inserts into a SMT with {} key-value pairs is {:.3} milliseconds\n",
|
||||
size,
|
||||
// calculate the average by dividing by 20 and convert to milliseconds by multiplying by
|
||||
// calculate the average by dividing by 20 and convert to milliseconds by multiplying by
|
||||
// 1000. As a result, we can only multiply by 50
|
||||
insertion_times.iter().sum::<f32>() * 50f32,
|
||||
);
|
||||
@@ -100,8 +82,8 @@ pub fn insertion(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Runs the proof generation benchmark for the Tiered SMT.
|
||||
pub fn proof_generation(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleError> {
|
||||
/// Runs the proof generation benchmark for the [`Smt`].
|
||||
pub fn proof_generation(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
|
||||
println!("Running a proof generation benchmark:");
|
||||
|
||||
let mut insertion_times = Vec::new();
|
||||
@@ -112,13 +94,13 @@ pub fn proof_generation(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleErr
|
||||
tree.insert(test_key, test_value);
|
||||
|
||||
let now = Instant::now();
|
||||
let _proof = tree.prove(test_key);
|
||||
let _proof = tree.open(&test_key);
|
||||
let elapsed = now.elapsed();
|
||||
insertion_times.push(elapsed.as_secs_f32());
|
||||
}
|
||||
|
||||
println!(
|
||||
"An average proving time measured by 20 value proofs in a TSMT with {} key-value pairs in {:.3} microseconds",
|
||||
"An average proving time measured by 20 value proofs in a SMT with {} key-value pairs in {:.3} microseconds",
|
||||
size,
|
||||
// calculate the average by dividing by 20 and convert to microseconds by multiplying by
|
||||
// 1000000. As a result, we can only multiply by 50000
|
||||
|
||||
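The hunks above rename `TieredSmt` to `Smt` and `prove()` to `open()`. A minimal sketch of how the renamed benchmark API fits together, assuming the `miden_crypto` crate paths shown elsewhere in this diff; the key/value construction below is illustrative only, not the benchmark's actual input format:

use std::time::Instant;

use miden_crypto::{hash::rpo::RpoDigest, merkle::Smt, Felt, Word, ZERO};

fn bench_one(entries: Vec<(RpoDigest, Word)>) -> Result<(), miden_crypto::merkle::MerkleError> {
    // construction, as timed in the benchmark above
    let now = Instant::now();
    let tree = Smt::with_entries(entries)?;
    println!("constructed in {:.3} seconds", now.elapsed().as_secs_f32());

    // proof generation now goes through open(), which takes the key by reference
    let key = RpoDigest::new([Felt::new(1), ZERO, ZERO, ZERO]);
    let now = Instant::now();
    let _proof = tree.open(&key);
    println!("opened in {:.3} microseconds", now.elapsed().as_secs_f32() * 1_000_000.0);
    Ok(())
}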
@@ -1,156 +0,0 @@
-use super::{
-    BTreeMap, KvMap, MerkleError, MerkleStore, NodeIndex, RpoDigest, StoreNode, Vec, Word,
-};
-use crate::utils::collections::Diff;
-
-#[cfg(test)]
-use super::{super::ONE, Felt, SimpleSmt, EMPTY_WORD, ZERO};
-
-// MERKLE STORE DELTA
-// ================================================================================================
-
-/// [MerkleStoreDelta] stores a vector of ([RpoDigest], [MerkleTreeDelta]) tuples where the
-/// [RpoDigest] represents the root of the Merkle tree and [MerkleTreeDelta] represents the
-/// differences between the initial and final Merkle tree states.
-#[derive(Default, Debug, Clone, PartialEq, Eq)]
-#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
-pub struct MerkleStoreDelta(pub Vec<(RpoDigest, MerkleTreeDelta)>);
-
-// MERKLE TREE DELTA
-// ================================================================================================
-
-/// [MerkleDelta] stores the differences between the initial and final Merkle tree states.
-///
-/// The differences are represented as follows:
-/// - depth: the depth of the merkle tree.
-/// - cleared_slots: indexes of slots where values were set to [ZERO; 4].
-/// - updated_slots: index-value pairs of slots where values were set to non [ZERO; 4] values.
-#[cfg(not(test))]
-#[derive(Debug, Clone, PartialEq, Eq)]
-#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
-pub struct MerkleTreeDelta {
-    depth: u8,
-    cleared_slots: Vec<u64>,
-    updated_slots: Vec<(u64, Word)>,
-}
-
-impl MerkleTreeDelta {
-    // CONSTRUCTOR
-    // --------------------------------------------------------------------------------------------
-    pub fn new(depth: u8) -> Self {
-        Self {
-            depth,
-            cleared_slots: Vec::new(),
-            updated_slots: Vec::new(),
-        }
-    }
-
-    // ACCESSORS
-    // --------------------------------------------------------------------------------------------
-    /// Returns the depth of the Merkle tree the [MerkleDelta] is associated with.
-    pub fn depth(&self) -> u8 {
-        self.depth
-    }
-
-    /// Returns the indexes of slots where values were set to [ZERO; 4].
-    pub fn cleared_slots(&self) -> &[u64] {
-        &self.cleared_slots
-    }
-
-    /// Returns the index-value pairs of slots where values were set to non [ZERO; 4] values.
-    pub fn updated_slots(&self) -> &[(u64, Word)] {
-        &self.updated_slots
-    }
-
-    // MODIFIERS
-    // --------------------------------------------------------------------------------------------
-    /// Adds a slot index to the list of cleared slots.
-    pub fn add_cleared_slot(&mut self, index: u64) {
-        self.cleared_slots.push(index);
-    }
-
-    /// Adds a slot index and a value to the list of updated slots.
-    pub fn add_updated_slot(&mut self, index: u64, value: Word) {
-        self.updated_slots.push((index, value));
-    }
-}
-
-/// Extracts a [MerkleDelta] object by comparing the leaves of two Merkle trees specified by
-/// their roots and depth.
-pub fn merkle_tree_delta<T: KvMap<RpoDigest, StoreNode>>(
-    tree_root_1: RpoDigest,
-    tree_root_2: RpoDigest,
-    depth: u8,
-    merkle_store: &MerkleStore<T>,
-) -> Result<MerkleTreeDelta, MerkleError> {
-    if tree_root_1 == tree_root_2 {
-        return Ok(MerkleTreeDelta::new(depth));
-    }
-
-    let tree_1_leaves: BTreeMap<NodeIndex, RpoDigest> =
-        merkle_store.non_empty_leaves(tree_root_1, depth).collect();
-    let tree_2_leaves: BTreeMap<NodeIndex, RpoDigest> =
-        merkle_store.non_empty_leaves(tree_root_2, depth).collect();
-    let diff = tree_1_leaves.diff(&tree_2_leaves);
-
-    // TODO: Refactor this diff implementation to prevent allocation of both BTree and Vec.
-    Ok(MerkleTreeDelta {
-        depth,
-        cleared_slots: diff.removed.into_iter().map(|index| index.value()).collect(),
-        updated_slots: diff
-            .updated
-            .into_iter()
-            .map(|(index, leaf)| (index.value(), *leaf))
-            .collect(),
-    })
-}
-
-// INTERNALS
-// --------------------------------------------------------------------------------------------
-#[cfg(test)]
-#[derive(Debug, Clone, PartialEq, Eq)]
-#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
-pub struct MerkleTreeDelta {
-    pub depth: u8,
-    pub cleared_slots: Vec<u64>,
-    pub updated_slots: Vec<(u64, Word)>,
-}
-
-// MERKLE DELTA
-// ================================================================================================
-#[test]
-fn test_compute_merkle_delta() {
-    let entries = vec![
-        (10, [ZERO, ONE, Felt::new(2), Felt::new(3)]),
-        (15, [Felt::new(4), Felt::new(5), Felt::new(6), Felt::new(7)]),
-        (20, [Felt::new(8), Felt::new(9), Felt::new(10), Felt::new(11)]),
-        (31, [Felt::new(12), Felt::new(13), Felt::new(14), Felt::new(15)]),
-    ];
-    let simple_smt = SimpleSmt::with_leaves(30, entries.clone()).unwrap();
-    let mut store: MerkleStore = (&simple_smt).into();
-    let root = simple_smt.root();
-
-    // add a new node
-    let new_value = [Felt::new(16), Felt::new(17), Felt::new(18), Felt::new(19)];
-    let new_index = NodeIndex::new(simple_smt.depth(), 32).unwrap();
-    let root = store.set_node(root, new_index, new_value.into()).unwrap().root;
-
-    // update an existing node
-    let update_value = [Felt::new(20), Felt::new(21), Felt::new(22), Felt::new(23)];
-    let update_idx = NodeIndex::new(simple_smt.depth(), entries[0].0).unwrap();
-    let root = store.set_node(root, update_idx, update_value.into()).unwrap().root;
-
-    // remove a node
-    let remove_idx = NodeIndex::new(simple_smt.depth(), entries[1].0).unwrap();
-    let root = store.set_node(root, remove_idx, EMPTY_WORD.into()).unwrap().root;
-
-    let merkle_delta =
-        merkle_tree_delta(simple_smt.root(), root, simple_smt.depth(), &store).unwrap();
-    let expected_merkle_delta = MerkleTreeDelta {
-        depth: simple_smt.depth(),
-        cleared_slots: vec![remove_idx.value()],
-        updated_slots: vec![(update_idx.value(), update_value), (new_index.value(), new_value)],
-    };
-
-    assert_eq!(merkle_delta, expected_merkle_delta);
-}
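The deleted module above computed a delta by diffing two leaf maps through a (also removed) `Diff` trait. For reference, the same cleared/updated split can be sketched with plain std collections; this is a reimplementation under stated assumptions, not the removed code:

use std::collections::BTreeMap;

// computes (cleared, updated) slots between two leaf maps, mirroring the removed
// merkle_tree_delta logic with plain std collections
fn leaf_delta(
    old: &BTreeMap<u64, [u64; 4]>,
    new: &BTreeMap<u64, [u64; 4]>,
) -> (Vec<u64>, Vec<(u64, [u64; 4])>) {
    // slots present in `old` but absent from `new` were cleared
    let cleared = old.keys().filter(|k| !new.contains_key(*k)).copied().collect();
    // slots whose value appeared or changed were updated
    let updated = new
        .iter()
        .filter(|(k, v)| old.get(*k) != Some(*v))
        .map(|(k, v)| (*k, *v))
        .collect();
    (cleared, updated)
}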
@@ -1,6 +1,7 @@
-use super::{Felt, RpoDigest, EMPTY_WORD};
 use core::slice;

+use super::{Felt, RpoDigest, EMPTY_WORD};
+
 // EMPTY NODES SUBTREES
 // ================================================================================================

@@ -10,12 +11,19 @@ pub struct EmptySubtreeRoots;
 impl EmptySubtreeRoots {
     /// Returns a static slice with roots of empty subtrees of a Merkle tree starting at the
     /// specified depth.
-    pub const fn empty_hashes(depth: u8) -> &'static [RpoDigest] {
-        let ptr = &EMPTY_SUBTREES[255 - depth as usize] as *const RpoDigest;
+    pub const fn empty_hashes(tree_depth: u8) -> &'static [RpoDigest] {
+        let ptr = &EMPTY_SUBTREES[255 - tree_depth as usize] as *const RpoDigest;
         // Safety: this is a static/constant array, so it will never be outlived. If we attempt to
         // use regular slices, this wouldn't be a `const` function, meaning we won't be able to use
         // the returned value for static/constant definitions.
-        unsafe { slice::from_raw_parts(ptr, depth as usize + 1) }
+        unsafe { slice::from_raw_parts(ptr, tree_depth as usize + 1) }
     }
+
+    /// Returns the node's digest for a sub-tree with all its leaves set to the empty word.
+    pub const fn entry(tree_depth: u8, node_depth: u8) -> &'static RpoDigest {
+        assert!(node_depth <= tree_depth);
+        let pos = 255 - tree_depth + node_depth;
+        &EMPTY_SUBTREES[pos as usize]
+    }
 }

@@ -1583,3 +1591,16 @@ fn all_depths_opens_to_zero() {
         .for_each(|(x, computed)| assert_eq!(x, computed));
     }
 }
+
+#[test]
+fn test_entry() {
+    // check the leaf is always the empty word
+    for depth in 0..255 {
+        assert_eq!(EmptySubtreeRoots::entry(depth, depth), &RpoDigest::new(EMPTY_WORD));
+    }
+
+    // check the root matches the first element of empty_hashes
+    for depth in 0..255 {
+        assert_eq!(EmptySubtreeRoots::entry(depth, 0), &EmptySubtreeRoots::empty_hashes(depth)[0]);
+    }
+}
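The new `entry` accessor is a pure offset computation into the 256-entry table of empty-subtree digests. A self-contained restatement of that arithmetic (the function name here is illustrative, not the crate's API):

// mirrors the `entry` index arithmetic above: the table stores digests of empty
// subtrees from depth 255 (leaf) up to depth 0 (root), so the digest for a node
// at `node_depth` inside a tree of `tree_depth` lives at this offset
const fn empty_subtree_pos(tree_depth: u8, node_depth: u8) -> usize {
    assert!(node_depth <= tree_depth);
    255 - tree_depth as usize + node_depth as usize
}

fn main() {
    // the leaf entry of any tree is the last element of the table
    assert_eq!(empty_subtree_pos(64, 64), 255);
    // the root entry of a depth-64 tree sits 64 slots before the leaf entry
    assert_eq!(empty_subtree_pos(64, 0), 191);
    // for a full-depth tree the root entry is the first element
    assert_eq!(empty_subtree_pos(255, 0), 0);
}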
@@ -1,9 +1,8 @@
-use crate::{
-    merkle::{MerklePath, NodeIndex, RpoDigest},
-    utils::collections::Vec,
-};
+use alloc::vec::Vec;
 use core::fmt;

+use super::{smt::SmtLeafError, MerklePath, NodeIndex, RpoDigest};
+
 #[derive(Clone, Debug, PartialEq, Eq)]
 pub enum MerkleError {
     ConflictingRoots(Vec<RpoDigest>),
@@ -13,12 +12,14 @@ pub enum MerkleError {
     DuplicateValuesForKey(RpoDigest),
     InvalidIndex { depth: u8, value: u64 },
     InvalidDepth { expected: u8, provided: u8 },
+    InvalidSubtreeDepth { subtree_depth: u8, tree_depth: u8 },
     InvalidPath(MerklePath),
-    InvalidNumEntries(usize, usize),
+    InvalidNumEntries(usize),
     NodeNotInSet(NodeIndex),
     NodeNotInStore(RpoDigest, NodeIndex),
     NumLeavesNotPowerOfTwo(usize),
     RootNotInStore(RpoDigest),
+    SmtLeaf(SmtLeafError),
 }

 impl fmt::Display for MerkleError {
@@ -30,25 +31,35 @@ impl fmt::Display for MerkleError {
             DepthTooBig(depth) => write!(f, "the provided depth {depth} is too big"),
-            DuplicateValuesForIndex(key) => write!(f, "multiple values provided for key {key}"),
+            DuplicateValuesForKey(key) => write!(f, "multiple values provided for key {key}"),
-            InvalidIndex{ depth, value} => write!(
-                f,
-                "the index value {value} is not valid for the depth {depth}"
-            ),
-            InvalidDepth { expected, provided } => write!(
-                f,
-                "the provided depth {provided} is not valid for {expected}"
-            ),
+            InvalidIndex { depth, value } => {
+                write!(f, "the index value {value} is not valid for the depth {depth}")
+            }
+            InvalidDepth { expected, provided } => {
+                write!(f, "the provided depth {provided} is not valid for {expected}")
+            }
+            InvalidSubtreeDepth { subtree_depth, tree_depth } => {
+                write!(f, "tried inserting a subtree of depth {subtree_depth} into a tree of depth {tree_depth}")
+            }
             InvalidPath(_path) => write!(f, "the provided path is not valid"),
-            InvalidNumEntries(max, provided) => write!(f, "the provided number of entries is {provided}, but the maximum for the given depth is {max}"),
+            InvalidNumEntries(max) => write!(f, "number of entries exceeded the maximum: {max}"),
             NodeNotInSet(index) => write!(f, "the node with index ({index}) is not in the set"),
-            NodeNotInStore(hash, index) => write!(f, "the node {hash:?} with index ({index}) is not in the store"),
+            NodeNotInStore(hash, index) => {
+                write!(f, "the node {hash:?} with index ({index}) is not in the store")
+            }
             NumLeavesNotPowerOfTwo(leaves) => {
                 write!(f, "the leaves count {leaves} is not a power of 2")
             }
             RootNotInStore(root) => write!(f, "the root {:?} is not in the store", root),
+            SmtLeaf(smt_leaf_error) => write!(f, "smt leaf error: {smt_leaf_error}"),
         }
     }
 }

 #[cfg(feature = "std")]
 impl std::error::Error for MerkleError {}
+
+impl From<SmtLeafError> for MerkleError {
+    fn from(value: SmtLeafError) -> Self {
+        Self::SmtLeaf(value)
+    }
+}
@@ -1,7 +1,8 @@
-use super::{Felt, MerkleError, RpoDigest, StarkField};
-use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
 use core::fmt::Display;

+use super::{Felt, MerkleError, RpoDigest};
+use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
+
 // NODE INDEX
 // ================================================================================================

@@ -181,19 +182,27 @@ impl Deserializable for NodeIndex {

 #[cfg(test)]
 mod tests {
-    use super::*;
     use proptest::prelude::*;

+    use super::*;
+
     #[test]
     fn test_node_index_value_too_high() {
         assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 });
-        match NodeIndex::new(0, 1) {
-            Err(MerkleError::InvalidIndex { depth, value }) => {
-                assert_eq!(depth, 0);
-                assert_eq!(value, 1);
-            }
-            _ => unreachable!(),
-        }
+        let err = NodeIndex::new(0, 1).unwrap_err();
+        assert_eq!(err, MerkleError::InvalidIndex { depth: 0, value: 1 });
+
+        assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 });
+        let err = NodeIndex::new(1, 2).unwrap_err();
+        assert_eq!(err, MerkleError::InvalidIndex { depth: 1, value: 2 });
+
+        assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 });
+        let err = NodeIndex::new(2, 4).unwrap_err();
+        assert_eq!(err, MerkleError::InvalidIndex { depth: 2, value: 4 });
+
+        assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 });
+        let err = NodeIndex::new(3, 8).unwrap_err();
+        assert_eq!(err, MerkleError::InvalidIndex { depth: 3, value: 8 });
     }

     #[test]
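The invariant the rewritten test exercises: a tree of depth d has exactly 2^d positions, so an index value is valid iff it is strictly below 2^depth. A restatement of that rule (illustrative helper, not the crate's code):

fn is_valid_index(depth: u8, value: u64) -> bool {
    if depth >= 64 {
        true // every u64 value fits at depth 64 and beyond
    } else {
        value < (1u64 << depth)
    }
}

fn main() {
    // the same boundary cases as the test above
    assert!(is_valid_index(0, 0) && !is_valid_index(0, 1));
    assert!(is_valid_index(1, 1) && !is_valid_index(1, 2));
    assert!(is_valid_index(2, 3) && !is_valid_index(2, 4));
    assert!(is_valid_index(3, 7) && !is_valid_index(3, 8));
}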
@@ -1,8 +1,11 @@
-use super::{InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, Word};
-use crate::utils::{string::String, uninit_vector, word_to_hex};
+use alloc::{string::String, vec::Vec};
 use core::{fmt, ops::Deref, slice};

 use winter_math::log2;

+use super::{InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Word};
+use crate::utils::{uninit_vector, word_to_hex};
+
 // MERKLE TREE
 // ================================================================================================

@@ -283,13 +286,15 @@ pub fn path_to_text(path: &MerklePath) -> Result<String, fmt::Error> {

 #[cfg(test)]
 mod tests {
+    use core::mem::size_of;
+
+    use proptest::prelude::*;
+
+    use super::*;
     use crate::{
-        merkle::{digests_to_words, int_to_leaf, int_to_node, InnerNodeInfo},
-        Felt, Word, WORD_SIZE,
+        merkle::{digests_to_words, int_to_leaf, int_to_node},
+        Felt, WORD_SIZE,
     };
-    use core::mem::size_of;
-    use proptest::prelude::*;

     const LEAVES4: [RpoDigest; WORD_SIZE] =
         [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];
@@ -1,16 +1,17 @@
-use super::super::{RpoDigest, Vec};
+use super::super::RpoDigest;
+use alloc::vec::Vec;

-/// Container for the update data of a [PartialMmr]
+/// Container for the update data of a [super::PartialMmr]
 #[derive(Debug)]
 pub struct MmrDelta {
-    /// The new version of the [Mmr]
+    /// The new version of the [super::Mmr]
     pub forest: usize,

     /// Update data.
     ///
     /// The data is packed as follows:
     /// 1. All the elements needed to perform authentication path updates. These are the right
-    ///    siblings required to perform tree merges on the [PartialMmr].
+    ///    siblings required to perform tree merges on the [super::PartialMmr].
     /// 2. The new peaks.
     pub data: Vec<RpoDigest>,
 }
@@ -1,9 +1,9 @@
-use crate::merkle::MerkleError;
 use core::fmt::{Display, Formatter};

 #[cfg(feature = "std")]
 use std::error::Error;

+use crate::merkle::MerkleError;
+
 #[derive(Debug, PartialEq, Eq, Clone)]
 pub enum MmrError {
     InvalidPosition(usize),
@@ -9,12 +9,14 @@
 //! least number of leaves. The structure preserves the invariant that each tree has different
 //! depths, i.e. as part of adding a new element to the forest the trees with same depth are
 //! merged, creating a new tree with depth d+1, this process is continued until the property is
-//! restabilished.
+//! reestablished.
 use super::{
-    super::{InnerNodeInfo, MerklePath, RpoDigest, Vec},
+    super::{InnerNodeInfo, MerklePath},
     bit::TrueBitPositionIterator,
     leaf_to_corresponding_tree, nodes_in_forest, MmrDelta, MmrError, MmrPeaks, MmrProof, Rpo256,
+    RpoDigest,
 };
+use alloc::vec::Vec;

 // MMR
 // ===============================================================================================
@@ -76,13 +78,13 @@ impl Mmr {
     /// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were
     /// added; this corresponds to the MMR size _prior_ to adding the element. So the 1st element
     /// has position 0, the second position 1, and so on.
-    pub fn open(&self, pos: usize) -> Result<MmrProof, MmrError> {
+    pub fn open(&self, pos: usize, target_forest: usize) -> Result<MmrProof, MmrError> {
         // find the target tree responsible for the MMR position
         let tree_bit =
-            leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::InvalidPosition(pos))?;
+            leaf_to_corresponding_tree(pos, target_forest).ok_or(MmrError::InvalidPosition(pos))?;

         // isolate the trees before the target
-        let forest_before = self.forest & high_bitmask(tree_bit + 1);
+        let forest_before = target_forest & high_bitmask(tree_bit + 1);
         let index_offset = nodes_in_forest(forest_before);

         // update the value position from global to the target tree
@@ -92,7 +94,7 @@ impl Mmr {
         let (_, path) = self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset);

         Ok(MmrProof {
-            forest: self.forest,
+            forest: target_forest,
             position: pos,
             merkle_path: MerklePath::new(path),
         })
@@ -143,9 +145,13 @@ impl Mmr {
         self.forest += 1;
     }

-    /// Returns an accumulator representing the current state of the MMR.
-    pub fn accumulator(&self) -> MmrPeaks {
-        let peaks: Vec<RpoDigest> = TrueBitPositionIterator::new(self.forest)
+    /// Returns the peaks of the MMR for the version specified by `forest`.
+    pub fn peaks(&self, forest: usize) -> Result<MmrPeaks, MmrError> {
+        if forest > self.forest {
+            return Err(MmrError::InvalidPeaks);
+        }
+
+        let peaks: Vec<RpoDigest> = TrueBitPositionIterator::new(forest)
             .rev()
             .map(|bit| nodes_in_forest(1 << bit))
             .scan(0, |offset, el| {
@@ -156,39 +162,41 @@ impl Mmr {
             .collect();

         // Safety: the invariant is maintained by the [Mmr]
-        MmrPeaks::new(self.forest, peaks).unwrap()
+        let peaks = MmrPeaks::new(forest, peaks).unwrap();
+
+        Ok(peaks)
     }

     /// Compute the required update to `original_forest`.
     ///
     /// The result is a packed sequence of the authentication elements required to update the trees
     /// that have been merged together, followed by the new peaks of the [Mmr].
-    pub fn get_delta(&self, original_forest: usize) -> Result<MmrDelta, MmrError> {
-        if original_forest > self.forest {
+    pub fn get_delta(&self, from_forest: usize, to_forest: usize) -> Result<MmrDelta, MmrError> {
+        if to_forest > self.forest || from_forest > to_forest {
             return Err(MmrError::InvalidPeaks);
         }

-        if original_forest == self.forest {
-            return Ok(MmrDelta { forest: self.forest, data: Vec::new() });
+        if from_forest == to_forest {
+            return Ok(MmrDelta { forest: to_forest, data: Vec::new() });
         }

         let mut result = Vec::new();

-        // Find the largest tree in this [Mmr] which is new to `original_forest`.
-        let candidate_trees = self.forest ^ original_forest;
+        // Find the largest tree in this [Mmr] which is new to `from_forest`.
+        let candidate_trees = to_forest ^ from_forest;
         let mut new_high = 1 << candidate_trees.ilog2();

         // Collect authentication nodes used for tree merges
         // ----------------------------------------------------------------------------------------

-        // Find the trees from `original_forest` that have been merged into `new_high`.
-        let mut merges = original_forest & (new_high - 1);
+        // Find the trees from `from_forest` that have been merged into `new_high`.
+        let mut merges = from_forest & (new_high - 1);

-        // Find the peaks that are common to `original_forest` and this [Mmr]
-        let common_trees = original_forest ^ merges;
+        // Find the peaks that are common to `from_forest` and this [Mmr]
+        let common_trees = from_forest ^ merges;

         if merges != 0 {
-            // Skip the smallest trees unknown to `original_forest`.
+            // Skip the smallest trees unknown to `from_forest`.
             let mut target = 1 << merges.trailing_zeros();

             // Collect siblings required to compute the merged tree's peak
@@ -213,15 +221,15 @@ impl Mmr {
             }
         } else {
             // The new high tree may not be the result of any merges, if it is smaller than all the
-            // trees of `original_forest`.
+            // trees of `from_forest`.
             new_high = 0;
         }

         // Collect the new [Mmr] peaks
         // ----------------------------------------------------------------------------------------

-        let mut new_peaks = self.forest ^ common_trees ^ new_high;
-        let old_peaks = self.forest ^ new_peaks;
+        let mut new_peaks = to_forest ^ common_trees ^ new_high;
+        let old_peaks = to_forest ^ new_peaks;
         let mut offset = nodes_in_forest(old_peaks);
         while new_peaks != 0 {
             let target = 1 << new_peaks.ilog2();
@@ -230,7 +238,7 @@ impl Mmr {
             new_peaks ^= target;
         }

-        Ok(MmrDelta { forest: self.forest, data: result })
+        Ok(MmrDelta { forest: to_forest, data: result })
     }

     /// An iterator over inner nodes in the MMR. The order of iteration is unspecified.
@@ -273,7 +281,7 @@ impl Mmr {
         // Update the depth of the tree to correspond to a subtree
         forest_target >>= 1;

-        // compute the indeces of the right and left subtrees based on the post-order
+        // compute the indices of the right and left subtrees based on the post-order
         let right_offset = index - 1;
         let left_offset = right_offset - nodes_in_forest(forest_target);
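After this change, `open`, `peaks`, and `get_delta` all take an explicit forest version rather than implicitly using the latest one. A minimal usage sketch, assuming the `miden_crypto` crate paths shown elsewhere in this diff:

use miden_crypto::hash::rpo::{Rpo256, RpoDigest};
use miden_crypto::merkle::{Mmr, PartialMmr};
use miden_crypto::Felt;

fn main() {
    // build a small MMR; `forest()` doubles as its version tag
    let mut mmr = Mmr::default();
    (0..5u64).for_each(|i| {
        mmr.add(Rpo256::hash_elements(&[Felt::new(i)]));
    });

    // open a leaf against an explicit version and inspect the proof
    let proof = mmr.open(3, mmr.forest()).unwrap();
    assert_eq!(proof.position, 3);

    // peaks and deltas are likewise requested for an explicit version
    let peaks = mmr.peaks(mmr.forest()).unwrap();
    let delta = mmr.get_delta(mmr.forest(), mmr.forest()).unwrap();
    assert!(delta.data.is_empty()); // same version in and out: empty delta
    let _partial: PartialMmr = peaks.into();
}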
@@ -6,6 +6,9 @@
 //! leaves count.
 use core::num::NonZeroUsize;

+// IN-ORDER INDEX
+// ================================================================================================
+
 /// Index of nodes in a perfectly balanced binary tree based on an in-order tree walk.
 #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
 pub struct InOrderIndex {
@@ -13,15 +16,17 @@ pub struct InOrderIndex {
 }

 impl InOrderIndex {
-    /// Constructor for a new [InOrderIndex].
+    // CONSTRUCTORS
+    // --------------------------------------------------------------------------------------------
+
+    /// Returns a new [InOrderIndex] instantiated from the provided value.
     pub fn new(idx: NonZeroUsize) -> InOrderIndex {
         InOrderIndex { idx: idx.get() }
     }

-    /// Constructs an index from a leaf position.
-    ///
-    /// Panics:
+    /// Returns a new [InOrderIndex] instantiated from the specified leaf position.
+    ///
+    /// # Panics:
     /// If `leaf` is higher than or equal to `usize::MAX / 2`.
     pub fn from_leaf_pos(leaf: usize) -> InOrderIndex {
         // Convert the position from 0-indexed to 1-indexed, since the bit manipulation in this
@@ -30,6 +35,9 @@ impl InOrderIndex {
         InOrderIndex { idx: pos * 2 - 1 }
     }

+    // PUBLIC ACCESSORS
+    // --------------------------------------------------------------------------------------------
+
     /// True if the index is pointing at a leaf.
     ///
     /// Every odd number represents a leaf.
@@ -37,6 +45,11 @@ impl InOrderIndex {
         self.idx & 1 == 1
     }

+    /// Returns true if this node is a left child of its parent.
+    pub fn is_left_child(&self) -> bool {
+        self.parent().left_child() == *self
+    }
+
     /// Returns the level of the index.
     ///
     /// Starts at level zero for leaves and increases by one for each parent.
@@ -46,8 +59,7 @@ impl InOrderIndex {

     /// Returns the index of the left child.
     ///
-    /// Panics:
-    ///
+    /// # Panics:
     /// If the index corresponds to a leaf.
     pub fn left_child(&self) -> InOrderIndex {
         // The left child is itself a parent, with an index that splits its left/right subtrees. To
@@ -59,8 +71,7 @@ impl InOrderIndex {

     /// Returns the index of the right child.
     ///
-    /// Panics:
-    ///
+    /// # Panics:
     /// If the index corresponds to a leaf.
     pub fn right_child(&self) -> InOrderIndex {
         // To compute the index of the parent of the right subtree it is sufficient to add the size
@@ -94,13 +105,31 @@ impl InOrderIndex {
             parent.right_child()
         }
     }
+
+    /// Returns the inner value of this [InOrderIndex].
+    pub fn inner(&self) -> u64 {
+        self.idx as u64
+    }
 }

+// CONVERSIONS FROM IN-ORDER INDEX
+// ------------------------------------------------------------------------------------------------
+
+impl From<InOrderIndex> for u64 {
+    fn from(index: InOrderIndex) -> Self {
+        index.idx as u64
+    }
+}
+
+// TESTS
+// ================================================================================================
+
 #[cfg(test)]
 mod test {
-    use super::InOrderIndex;
     use proptest::prelude::*;

+    use super::InOrderIndex;
+
     proptest! {
         #[test]
         fn proptest_inorder_index_random(count in 1..1000usize) {
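A derivation of the in-order index arithmetic this type relies on (free functions here are illustrative, not the crate's methods): leaves sit at the odd indices 1, 3, 5, ..., a node's level equals the number of trailing zero bits of its index, and the parent sits one level-sized step above or below depending on which child the node is:

// level of an in-order index: leaves are odd (level 0); level-v nodes are
// congruent to 2^v mod 2^(v+1), i.e. the level is the trailing-zero count
fn level(idx: usize) -> u32 {
    idx.trailing_zeros()
}

// a node at level v is 2^v below its parent (left child) or 2^v above it
// (right child); bit v+1 of the index tells the two cases apart
fn parent(idx: usize) -> usize {
    let v = level(idx);
    let bit = 1usize << v;
    if (idx >> (v + 1)) & 1 == 0 { idx + bit } else { idx - bit }
}

fn main() {
    // leaves 1 and 3 share parent 2; leaves 5 and 7 share parent 6
    assert_eq!(parent(1), 2);
    assert_eq!(parent(3), 2);
    assert_eq!(parent(5), 6);
    assert_eq!(parent(7), 6);
    // nodes 2 and 6 share parent 4
    assert_eq!(parent(2), 4);
    assert_eq!(parent(6), 4);
    // from_leaf_pos maps the 0-indexed leaf `pos` to in-order index 2 * pos + 1
    assert_eq!((0..4usize).map(|pos| 2 * pos + 1).collect::<Vec<_>>(), vec![1, 3, 5, 7]);
}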
@@ -10,7 +10,7 @@ mod proof;
 #[cfg(test)]
 mod tests;

-use super::{Felt, Rpo256, Word};
+use super::{Felt, Rpo256, RpoDigest, Word};

 // REEXPORTS
 // ================================================================================================
@@ -40,10 +40,10 @@ const fn leaf_to_corresponding_tree(pos: usize, forest: usize) -> Option<u32> {
     // - each bit in the forest is a unique tree and the bit position its power-of-two size
     // - each tree owns a consecutive range of positions equal to its size from left-to-right
     // - this means the first tree owns from `0` up to the `2^k_0` first positions, where `k_0`
     //   is the highest true bit position, the second tree from `2^k_0 + 1` up to `2^k_1` where
     //   `k_1` is the second highest bit, so on.
     // - this means the highest bits work as a category marker, and the position is owned by
     //   the first tree which doesn't share a high bit with the position
     let before = forest & pos;
     let after = forest ^ before;
     let tree = after.ilog2();
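The bit reasoning documented in that comment block can be checked end to end. A self-contained sketch of the helpers referenced in this hunk (reimplemented from the visible code and comments; the out-of-range guard is an assumption):

// total node count of a forest: each tree of 2^k leaves has 2 * 2^k - 1 nodes
fn nodes_in_forest(forest: usize) -> usize {
    2 * forest - forest.count_ones() as usize
}

// mask selecting the trees at bit position `bit` and above
fn high_bitmask(bit: u32) -> usize {
    if bit >= usize::BITS { 0 } else { usize::MAX << bit }
}

// strip the high bits of `forest` shared with `pos`; the highest remaining
// bit is the tree owning the position
fn leaf_to_corresponding_tree(pos: usize, forest: usize) -> Option<u32> {
    if pos >= forest {
        return None;
    }
    let before = forest & pos;
    let after = forest ^ before;
    Some(after.ilog2())
}

fn main() {
    // forest 0b1010 holds trees of 8 and 2 leaves (10 leaves total)
    assert_eq!(leaf_to_corresponding_tree(3, 0b1010), Some(3)); // size-8 tree
    assert_eq!(leaf_to_corresponding_tree(9, 0b1010), Some(1)); // size-2 tree
    assert_eq!(leaf_to_corresponding_tree(10, 0b1010), None); // out of range
    // isolating the trees before the size-2 target gives its node offset
    let forest_before = 0b1010 & high_bitmask(1 + 1);
    assert_eq!(nodes_in_forest(forest_before), 15);
}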
@@ -1,58 +1,64 @@
-use crate::{
-    hash::rpo::{Rpo256, RpoDigest},
-    merkle::{
-        mmr::{leaf_to_corresponding_tree, nodes_in_forest},
-        InOrderIndex, MerklePath, MmrError, MmrPeaks,
-    },
-    utils::collections::{BTreeMap, Vec},
-};
-
-use super::{MmrDelta, MmrProof};
+use super::{MmrDelta, MmrProof, Rpo256, RpoDigest};
+use crate::merkle::{
+    mmr::{leaf_to_corresponding_tree, nodes_in_forest},
+    InOrderIndex, InnerNodeInfo, MerklePath, MmrError, MmrPeaks,
+};
+use alloc::{
+    collections::{BTreeMap, BTreeSet},
+    vec::Vec,
+};

+// TYPE ALIASES
+// ================================================================================================
+
+type NodeMap = BTreeMap<InOrderIndex, RpoDigest>;
+
+// PARTIAL MERKLE MOUNTAIN RANGE
+// ================================================================================================
+
-/// Partially materialized [Mmr], used to efficiently store and update the authentication paths for
-/// a subset of the elements in a full [Mmr].
+/// Partially materialized Merkle Mountain Range (MMR), used to efficiently store and update the
+/// authentication paths for a subset of the elements in a full MMR.
 ///
 /// This structure stores only the authentication path for a value; the value itself is stored
 /// separately.
-#[derive(Debug)]
+#[derive(Debug, Clone, PartialEq, Eq)]
 pub struct PartialMmr {
-    /// The version of the [Mmr].
+    /// The version of the MMR.
     ///
     /// This value serves the following purposes:
     ///
-    /// - The forest is a counter for the total number of elements in the [Mmr].
-    /// - Since the [Mmr] is an append-only structure, every change to it causes a change to the
+    /// - The forest is a counter for the total number of elements in the MMR.
+    /// - Since the MMR is an append-only structure, every change to it causes a change to the
     ///   `forest`, so this value has a dual purpose as a version tag.
     /// - The bits in the forest also correspond to the count and size of every perfect binary
-    ///   tree that composes the [Mmr] structure, which serve to compute indexes and perform
+    ///   tree that composes the MMR structure, which serve to compute indexes and perform
     ///   validation.
     pub(crate) forest: usize,

-    /// The [Mmr] peaks.
+    /// The MMR peaks.
     ///
     /// The peaks are used for two reasons:
     ///
     /// 1. It authenticates the addition of an element to the [PartialMmr], ensuring only valid
     ///    elements are tracked.
-    /// 2. During a [Mmr] update peaks can be merged by hashing the left and right hand sides. The
+    /// 2. During a MMR update peaks can be merged by hashing the left and right hand sides. The
     ///    peaks are used as the left hand.
     ///
-    /// All the peaks of every tree in the [Mmr] forest. The peaks are always ordered by number of
+    /// All the peaks of every tree in the MMR forest. The peaks are always ordered by number of
     /// leaves, starting from the peak with most children, to the one with least.
     pub(crate) peaks: Vec<RpoDigest>,

-    /// Authentication nodes used to construct merkle paths for a subset of the [Mmr]'s leaves.
+    /// Authentication nodes used to construct merkle paths for a subset of the MMR's leaves.
     ///
-    /// This does not include the [Mmr]'s peaks nor the tracked nodes, only the elements required
-    /// to construct their authentication paths. This property is used to detect when elements can
-    /// be safely removed from, because they are no longer required to authenticate any element in
-    /// the [PartialMmr].
+    /// This does not include the MMR's peaks nor the tracked nodes, only the elements required to
+    /// construct their authentication paths. This property is used to detect when elements can be
+    /// safely removed, because they are no longer required to authenticate any element in the
+    /// [PartialMmr].
     ///
-    /// The elements in the [Mmr] are referenced using an in-order tree index. This indexing scheme
+    /// The elements in the MMR are referenced using an in-order tree index. This indexing scheme
     /// permits for easy computation of the relative nodes (left/right children, sibling, parent),
     /// which is useful for traversal. The indexing is also stable, meaning that merges to the
-    /// trees in the [Mmr] can be represented without rewrites of the indexes.
-    pub(crate) nodes: BTreeMap<InOrderIndex, RpoDigest>,
+    /// trees in the MMR can be represented without rewrites of the indexes.
+    pub(crate) nodes: NodeMap,

     /// Flag indicating if the odd element should be tracked.
     ///
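The dual role of `forest` described in the doc comment can be made concrete: each set bit is one peak, the bit position gives that tree's leaf count, and the whole value is simultaneously the total leaf count and the version. A small illustration in plain Rust:

fn main() {
    // `forest` encodes everything at once: leaf count, version, and tree shape
    let forest: usize = 0b1011; // 11 leaves
    // one peak per set bit, largest tree first
    let peak_sizes: Vec<usize> = (0..usize::BITS)
        .rev()
        .filter(|&bit| (forest >> bit) & 1 == 1)
        .map(|bit| 1usize << bit)
        .collect();
    assert_eq!(peak_sizes, vec![8, 2, 1]);
    assert_eq!(forest.count_ones(), 3); // number of peaks
    assert_eq!(peak_sizes.iter().sum::<usize>(), forest); // total leaves
}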
@@ -65,38 +71,75 @@ impl PartialMmr {
     // CONSTRUCTORS
     // --------------------------------------------------------------------------------------------

-    /// Constructs a [PartialMmr] from the given [MmrPeaks].
-    pub fn from_peaks(accumulator: MmrPeaks) -> Self {
-        let forest = accumulator.num_leaves();
-        let peaks = accumulator.peaks().to_vec();
+    /// Returns a new [PartialMmr] instantiated from the specified peaks.
+    pub fn from_peaks(peaks: MmrPeaks) -> Self {
+        let forest = peaks.num_leaves();
+        let peaks = peaks.into();
         let nodes = BTreeMap::new();
         let track_latest = false;

         Self { forest, peaks, nodes, track_latest }
     }

-    // ACCESSORS
+    /// Returns a new [PartialMmr] instantiated from the specified components.
+    ///
+    /// This constructor does not check the consistency between peaks and nodes. If the specified
+    /// peaks and nodes are inconsistent, the returned partial MMR may exhibit undefined behavior.
+    pub fn from_parts(peaks: MmrPeaks, nodes: NodeMap, track_latest: bool) -> Self {
+        let forest = peaks.num_leaves();
+        let peaks = peaks.into();
+
+        Self { forest, peaks, nodes, track_latest }
+    }
+
+    // PUBLIC ACCESSORS
     // --------------------------------------------------------------------------------------------

-    // Gets the current `forest`.
-    //
-    // This value corresponds to the version of the [PartialMmr] and the number of leaves in it.
+    /// Returns the current `forest` of this [PartialMmr].
+    ///
+    /// This value corresponds to the version of the [PartialMmr] and the number of leaves in the
+    /// underlying MMR.
     pub fn forest(&self) -> usize {
         self.forest
     }

-    // Returns a reference to the current peaks in the [PartialMmr]
-    pub fn peaks(&self) -> &[RpoDigest] {
-        &self.peaks
+    /// Returns the number of leaves in the underlying MMR for this [PartialMmr].
+    pub fn num_leaves(&self) -> usize {
+        self.forest
     }

-    /// Given a leaf position, returns the Merkle path to its corresponding peak. If the position
-    /// is greater-or-equal than the tree size an error is returned. If the requested value is not
-    /// tracked returns `None`.
+    /// Returns the peaks of the MMR for this [PartialMmr].
+    pub fn peaks(&self) -> MmrPeaks {
+        // expect() is OK here because the constructor ensures that MMR peaks can be constructed
+        // correctly
+        MmrPeaks::new(self.forest, self.peaks.clone()).expect("invalid MMR peaks")
+    }
+
+    /// Returns true if this partial MMR tracks an authentication path for the leaf at the
+    /// specified position.
+    pub fn is_tracked(&self, pos: usize) -> bool {
+        if pos >= self.forest {
+            return false;
+        } else if pos == self.forest - 1 && self.forest & 1 != 0 {
+            // if the number of leaves in the MMR is odd and the position is for the last leaf
+            // whether the leaf is tracked is defined by the `track_latest` flag
+            return self.track_latest;
+        }
+
+        let leaf_index = InOrderIndex::from_leaf_pos(pos);
+        self.is_tracked_node(&leaf_index)
+    }
+
+    /// Given a leaf position, returns the Merkle path to its corresponding peak, or None if this
+    /// partial MMR does not track an authentication path for the specified leaf.
+    ///
+    /// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were
+    /// added; this corresponds to the MMR size _prior_ to adding the element. So the 1st element
+    /// has position 0, the second position 1, and so on.
+    ///
+    /// # Errors
+    /// Returns an error if the specified position is greater-or-equal than the number of leaves
+    /// in the underlying MMR.
     pub fn open(&self, pos: usize) -> Result<Option<MmrProof>, MmrError> {
         let tree_bit =
             leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::InvalidPosition(pos))?;
@@ -125,23 +168,127 @@ impl PartialMmr {
         }
     }

-    // MODIFIERS
+    // ITERATORS
     // --------------------------------------------------------------------------------------------

-    /// Add the authentication path represented by [MerklePath] if it is valid.
+    /// Returns an iterator over the nodes of all authentication paths of this [PartialMmr].
+    pub fn nodes(&self) -> impl Iterator<Item = (&InOrderIndex, &RpoDigest)> {
+        self.nodes.iter()
+    }
+
+    /// Returns an iterator over inner nodes of this [PartialMmr] for the specified leaves.
     ///
-    /// The `index` refers to the global position of the leaf in the [Mmr], these are 0-indexed
-    /// values assigned in a strictly monotonic fashion as elements are inserted into the [Mmr],
-    /// this value corresponds to the values used in the [Mmr] structure.
+    /// The order of iteration is not defined. If a leaf is not present in this partial MMR it
+    /// is silently ignored.
+    pub fn inner_nodes<'a, I: Iterator<Item = (usize, RpoDigest)> + 'a>(
+        &'a self,
+        mut leaves: I,
+    ) -> impl Iterator<Item = InnerNodeInfo> + '_ {
+        let stack = if let Some((pos, leaf)) = leaves.next() {
+            let idx = InOrderIndex::from_leaf_pos(pos);
+            vec![(idx, leaf)]
+        } else {
+            Vec::new()
+        };
+
+        InnerNodeIterator {
+            nodes: &self.nodes,
+            leaves,
+            stack,
+            seen_nodes: BTreeSet::new(),
+        }
+    }
+
+    // STATE MUTATORS
+    // --------------------------------------------------------------------------------------------
+
+    /// Adds a new peak and optionally tracks it. Returns a vector of the authentication nodes
+    /// inserted into this [PartialMmr] as a result of this operation.
+    ///
-    /// The `node` corresponds to the value at `index`, and `path` is the authentication path for
-    /// that element up to its corresponding Mmr peak. The `node` is only used to compute the root
+    /// When `track` is `true` the new leaf is tracked.
+    pub fn add(&mut self, leaf: RpoDigest, track: bool) -> Vec<(InOrderIndex, RpoDigest)> {
+        self.forest += 1;
+        let merges = self.forest.trailing_zeros() as usize;
+        let mut new_nodes = Vec::with_capacity(merges);
+
+        let peak = if merges == 0 {
+            self.track_latest = track;
+            leaf
+        } else {
+            let mut track_right = track;
+            let mut track_left = self.track_latest;
+
+            let mut right = leaf;
+            let mut right_idx = forest_to_rightmost_index(self.forest);
+
+            for _ in 0..merges {
+                let left = self.peaks.pop().expect("Missing peak");
+                let left_idx = right_idx.sibling();
+
+                if track_right {
+                    let old = self.nodes.insert(left_idx, left);
+                    new_nodes.push((left_idx, left));
+
+                    debug_assert!(
+                        old.is_none(),
+                        "Idx {:?} already contained an element {:?}",
+                        left_idx,
+                        old
+                    );
+                };
+                if track_left {
+                    let old = self.nodes.insert(right_idx, right);
+                    new_nodes.push((right_idx, right));
+
+                    debug_assert!(
+                        old.is_none(),
+                        "Idx {:?} already contained an element {:?}",
+                        right_idx,
+                        old
+                    );
+                };
+
+                // Update state for the next iteration.
+                // --------------------------------------------------------------------------------
+
+                // This layer is merged, go up one layer.
+                right_idx = right_idx.parent();
+
+                // Merge the current layer. The result is either the right element of the next
+                // merge, or a new peak.
+                right = Rpo256::merge(&[left, right]);
+
+                // This iteration merged the left and right nodes, the new value is always used as
+                // the next iteration's right node. Therefore the tracking flags of this iteration
+                // have to be merged into the right side only.
+                track_right = track_right || track_left;
+
+                // On the next iteration, a peak will be merged. If any of its children are tracked,
+                // then we have to track the left side
+                track_left = self.is_tracked_node(&right_idx.sibling());
+            }
+            right
+        };
+
+        self.peaks.push(peak);
+
+        new_nodes
+    }
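The merge count in `add` above falls out of the forest arithmetic: after incrementing, `self.forest.trailing_zeros()` is exactly the number of peak merges the insertion triggers. Equivalently, from the pre-add forest (illustrative helper, not the crate's API):

// after incrementing the forest, the trailing-zero count of the new value is
// exactly how many peak merges the add performs
fn merges_for_add(forest_before: usize) -> u32 {
    (forest_before + 1).trailing_zeros()
}

fn main() {
    assert_eq!(merges_for_add(0), 0); // first leaf becomes a lone peak
    assert_eq!(merges_for_add(1), 1); // 1 + 1 leaves merge once
    assert_eq!(merges_for_add(7), 3); // 0b111 + 1 = 0b1000: three merges
    assert_eq!(merges_for_add(15), 4); // 0b1111 collapses to a single peak
}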
+    /// Adds the authentication path represented by [MerklePath] if it is valid.
+    ///
+    /// The `leaf_pos` refers to the global position of the leaf in the MMR, these are 0-indexed
+    /// values assigned in a strictly monotonic fashion as elements are inserted into the MMR,
+    /// this value corresponds to the values used in the MMR structure.
+    ///
+    /// The `leaf` corresponds to the value at `leaf_pos`, and `path` is the authentication path for
+    /// that element up to its corresponding Mmr peak. The `leaf` is only used to compute the root
+    /// from the authentication path to validate the data; only the authentication data is saved in
+    /// the structure. If the value is required it should be stored out-of-band.
-    pub fn add(
+    pub fn track(
         &mut self,
-        index: usize,
-        node: RpoDigest,
+        leaf_pos: usize,
+        leaf: RpoDigest,
         path: &MerklePath,
     ) -> Result<(), MmrError> {
         // Checks there is a tree with same depth as the authentication path, if not the path is
@@ -151,42 +298,42 @@ impl PartialMmr {
             return Err(MmrError::UnknownPeak);
         };

-        if index + 1 == self.forest
+        if leaf_pos + 1 == self.forest
             && path.depth() == 0
-            && self.peaks.last().map_or(false, |v| *v == node)
+            && self.peaks.last().map_or(false, |v| *v == leaf)
         {
             self.track_latest = true;
             return Ok(());
         }

         // ignore the trees smaller than the target (these elements are positioned after the
-        // current target and don't affect the target index)
+        // current target and don't affect the target leaf_pos)
         let target_forest = self.forest ^ (self.forest & (tree - 1));
         let peak_pos = (target_forest.count_ones() - 1) as usize;

-        // translate from mmr index to merkle path
-        let path_idx = index - (target_forest ^ tree);
+        // translate from mmr leaf_pos to merkle path
+        let path_idx = leaf_pos - (target_forest ^ tree);

         // Compute the root of the authentication path, and check it matches the current version of
         // the PartialMmr.
-        let computed = path.compute_root(path_idx as u64, node).map_err(MmrError::MerkleError)?;
+        let computed = path.compute_root(path_idx as u64, leaf).map_err(MmrError::MerkleError)?;
         if self.peaks[peak_pos] != computed {
             return Err(MmrError::InvalidPeak);
         }

-        let mut idx = InOrderIndex::from_leaf_pos(index);
-        for node in path.nodes() {
-            self.nodes.insert(idx.sibling(), *node);
+        let mut idx = InOrderIndex::from_leaf_pos(leaf_pos);
+        for leaf in path.nodes() {
+            self.nodes.insert(idx.sibling(), *leaf);
             idx = idx.parent();
         }

         Ok(())
     }

-    /// Remove a leaf of the [PartialMmr] and the unused nodes from the authentication path.
+    /// Removes a leaf of the [PartialMmr] and the unused nodes from the authentication path.
     ///
-    /// Note: `leaf_pos` corresponds to the position the [Mmr] and not on an individual tree.
-    pub fn remove(&mut self, leaf_pos: usize) {
+    /// Note: `leaf_pos` corresponds to the position in the MMR and not on an individual tree.
+    pub fn untrack(&mut self, leaf_pos: usize) {
         let mut idx = InOrderIndex::from_leaf_pos(leaf_pos);

         self.nodes.remove(&idx.sibling());
@@ -202,18 +349,21 @@ impl PartialMmr {
         }
     }

-    /// Applies updates to the [PartialMmr].
-    pub fn apply(&mut self, delta: MmrDelta) -> Result<(), MmrError> {
+    /// Applies updates to this [PartialMmr] and returns a vector of new authentication nodes
+    /// inserted into the partial MMR.
+    pub fn apply(&mut self, delta: MmrDelta) -> Result<Vec<(InOrderIndex, RpoDigest)>, MmrError> {
         if delta.forest < self.forest {
             return Err(MmrError::InvalidPeaks);
         }

+        let mut inserted_nodes = Vec::new();
+
         if delta.forest == self.forest {
             if !delta.data.is_empty() {
                 return Err(MmrError::InvalidUpdate);
             }

-            return Ok(());
+            return Ok(inserted_nodes);
         }

         // find the tree merges
@@ -268,16 +418,21 @@ impl PartialMmr {
             // check if either the left or right subtrees have saved authentication paths.
             // If so, turn tracking on to update those paths.
             if target != 1 && !track {
-                let left_child = peak_idx.left_child();
-                let right_child = peak_idx.right_child();
-                track = self.nodes.contains_key(&left_child)
-                    | self.nodes.contains_key(&right_child);
+                track = self.is_tracked_node(&peak_idx);
             }

             // update data only contains the nodes from the right subtrees, left nodes are
             // either previously known peaks or computed values
             let (left, right) = if target & merges != 0 {
                 let peak = self.peaks[peak_count];
+                let sibling_idx = peak_idx.sibling();
+
+                // if the sibling peak is tracked, add this peak to the set of
+                // authentication nodes
+                if self.is_tracked_node(&sibling_idx) {
+                    self.nodes.insert(peak_idx, new);
+                    inserted_nodes.push((peak_idx, new));
+                }
                 peak_count += 1;
                 (peak, new)
             } else {
@@ -287,7 +442,14 @@ impl PartialMmr {
             };

             if track {
-                self.nodes.insert(peak_idx.sibling(), right);
+                let sibling_idx = peak_idx.sibling();
+                if peak_idx.is_left_child() {
+                    self.nodes.insert(sibling_idx, right);
+                    inserted_nodes.push((sibling_idx, right));
+                } else {
+                    self.nodes.insert(sibling_idx, left);
+                    inserted_nodes.push((sibling_idx, left));
+                }
             }

             peak_idx = peak_idx.parent();
@@ -313,7 +475,22 @@ impl PartialMmr {

         debug_assert!(self.peaks.len() == (self.forest.count_ones() as usize));

-        Ok(())
+        Ok(inserted_nodes)
+    }
+
+    // HELPER METHODS
+    // --------------------------------------------------------------------------------------------
+
+    /// Returns true if this [PartialMmr] tracks authentication path for the node at the specified
+    /// index.
+    fn is_tracked_node(&self, node_index: &InOrderIndex) -> bool {
+        if node_index.is_leaf() {
+            self.nodes.contains_key(&node_index.sibling())
+        } else {
+            let left_child = node_index.left_child();
+            let right_child = node_index.right_child();
+            self.nodes.contains_key(&left_child) | self.nodes.contains_key(&right_child)
+        }
     }
 }

@@ -348,12 +525,59 @@ impl From<&PartialMmr> for MmrPeaks {
     }
 }

+// ITERATORS
+// ================================================================================================
+
+/// An iterator over every inner node of the [PartialMmr].
+pub struct InnerNodeIterator<'a, I: Iterator<Item = (usize, RpoDigest)>> {
+    nodes: &'a NodeMap,
+    leaves: I,
+    stack: Vec<(InOrderIndex, RpoDigest)>,
+    seen_nodes: BTreeSet<InOrderIndex>,
+}
+
+impl<'a, I: Iterator<Item = (usize, RpoDigest)>> Iterator for InnerNodeIterator<'a, I> {
+    type Item = InnerNodeInfo;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        while let Some((idx, node)) = self.stack.pop() {
+            let parent_idx = idx.parent();
+            let new_node = self.seen_nodes.insert(parent_idx);
+
+            // if we haven't seen this node's parent before, and the node has a sibling, return
+            // the inner node defined by the parent of this node, and move up the branch
+            if new_node {
+                if let Some(sibling) = self.nodes.get(&idx.sibling()) {
+                    let (left, right) = if parent_idx.left_child() == idx {
+                        (node, *sibling)
+                    } else {
+                        (*sibling, node)
+                    };
+                    let parent = Rpo256::merge(&[left, right]);
+                    let inner_node = InnerNodeInfo { value: parent, left, right };
+
+                    self.stack.push((parent_idx, parent));
+                    return Some(inner_node);
+                }
+            }
+
+            // the previous leaf has been processed, try to process the next leaf
+            if let Some((pos, leaf)) = self.leaves.next() {
+                let idx = InOrderIndex::from_leaf_pos(pos);
+                self.stack.push((idx, leaf));
+            }
+        }
+
+        None
+    }
+}
+
 // UTILS
 // ================================================================================================

 /// Given the description of a `forest`, returns the index of the root element of the smallest tree
 /// in it.
-pub fn forest_to_root_index(forest: usize) -> InOrderIndex {
+fn forest_to_root_index(forest: usize) -> InOrderIndex {
     // Count total size of all trees in the forest.
     let nodes = nodes_in_forest(forest);

@@ -370,10 +594,41 @@ pub fn forest_to_root_index(forest: usize) -> InOrderIndex {
     InOrderIndex::new(idx.try_into().unwrap())
 }

+/// Given the description of a `forest`, returns the index of the rightmost element.
+fn forest_to_rightmost_index(forest: usize) -> InOrderIndex {
+    // Count total size of all trees in the forest.
+    let nodes = nodes_in_forest(forest);
+
+    // Add the count for the parent nodes that separate each tree. These are allocated but
+    // currently empty, and correspond to the nodes that will be used once the trees are merged.
+    let open_trees = (forest.count_ones() - 1) as usize;
+
+    let idx = nodes + open_trees;
+
+    InOrderIndex::new(idx.try_into().unwrap())
+}
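The arithmetic in the new `forest_to_rightmost_index` always lands on an odd (leaf) index: the total is 2*forest - popcount(forest) nodes plus popcount(forest) - 1 reserved merge slots, i.e. 2*forest - 1. A standalone check against the test table below:

// mirrors forest_to_rightmost_index above: total nodes plus one reserved slot
// per merge point between open trees
fn rightmost_index(forest: usize) -> usize {
    let nodes = 2 * forest - forest.count_ones() as usize; // nodes_in_forest
    let open_trees = forest.count_ones() as usize - 1;
    nodes + open_trees
}

fn main() {
    // values match the test_forest_to_rightmost_index table below
    assert_eq!(rightmost_index(0b0001), 1);
    assert_eq!(rightmost_index(0b0011), 5);
    assert_eq!(rightmost_index(0b0110), 11);
    assert_eq!(rightmost_index(0b1111), 29);
    // 2 * forest - 1 is always odd, so the result is always a leaf index
    for forest in 1usize..256 {
        assert_eq!(rightmost_index(forest) % 2, 1);
    }
}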
// TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::forest_to_root_index;
|
||||
use crate::merkle::InOrderIndex;
|
||||
mod tests {
|
||||
use super::{
|
||||
forest_to_rightmost_index, forest_to_root_index, InOrderIndex, MmrPeaks, PartialMmr,
|
||||
RpoDigest,
|
||||
};
|
||||
use crate::merkle::{int_to_node, MerkleStore, Mmr, NodeIndex};
|
||||
use alloc::{collections::BTreeSet, vec::Vec};
|
||||
|
||||
const LEAVES: [RpoDigest; 7] = [
|
||||
int_to_node(0),
|
||||
int_to_node(1),
|
||||
int_to_node(2),
|
||||
int_to_node(3),
|
||||
int_to_node(4),
|
||||
int_to_node(5),
|
||||
int_to_node(6),
|
||||
];
|
||||
|
||||
#[test]
|
||||
fn test_forest_to_root_index() {
|
||||
@@ -400,4 +655,256 @@ mod test {
|
||||
assert_eq!(forest_to_root_index(0b1100), idx(20));
|
||||
assert_eq!(forest_to_root_index(0b1110), idx(26));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_forest_to_rightmost_index() {
|
||||
fn idx(pos: usize) -> InOrderIndex {
|
||||
InOrderIndex::new(pos.try_into().unwrap())
|
||||
}
|
||||
|
||||
for forest in 1..256 {
|
||||
assert!(forest_to_rightmost_index(forest).inner() % 2 == 1, "Leaves are always odd");
|
||||
}
|
||||
|
||||
assert_eq!(forest_to_rightmost_index(0b0001), idx(1));
|
||||
assert_eq!(forest_to_rightmost_index(0b0010), idx(3));
|
||||
assert_eq!(forest_to_rightmost_index(0b0011), idx(5));
|
||||
assert_eq!(forest_to_rightmost_index(0b0100), idx(7));
|
||||
assert_eq!(forest_to_rightmost_index(0b0101), idx(9));
|
||||
assert_eq!(forest_to_rightmost_index(0b0110), idx(11));
|
||||
assert_eq!(forest_to_rightmost_index(0b0111), idx(13));
|
||||
assert_eq!(forest_to_rightmost_index(0b1000), idx(15));
|
||||
assert_eq!(forest_to_rightmost_index(0b1001), idx(17));
|
||||
assert_eq!(forest_to_rightmost_index(0b1010), idx(19));
|
||||
assert_eq!(forest_to_rightmost_index(0b1011), idx(21));
|
||||
assert_eq!(forest_to_rightmost_index(0b1100), idx(23));
|
||||
assert_eq!(forest_to_rightmost_index(0b1101), idx(25));
|
||||
assert_eq!(forest_to_rightmost_index(0b1110), idx(27));
|
||||
assert_eq!(forest_to_rightmost_index(0b1111), idx(29));
|
||||
}

#[test]
fn test_partial_mmr_apply_delta() {
    // build an MMR with 10 nodes (2 peaks) and a partial MMR based on it
    let mut mmr = Mmr::default();
    (0..10).for_each(|i| mmr.add(int_to_node(i)));
    let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();

    // add authentication paths for positions 1 and 8
    {
        let node = mmr.get(1).unwrap();
        let proof = mmr.open(1, mmr.forest()).unwrap();
        partial_mmr.track(1, node, &proof.merkle_path).unwrap();
    }

    {
        let node = mmr.get(8).unwrap();
        let proof = mmr.open(8, mmr.forest()).unwrap();
        partial_mmr.track(8, node, &proof.merkle_path).unwrap();
    }

    // add 2 more nodes into the MMR and validate apply_delta()
    (10..12).for_each(|i| mmr.add(int_to_node(i)));
    validate_apply_delta(&mmr, &mut partial_mmr);

    // add 1 more node to the MMR, validate apply_delta() and start tracking the node
    mmr.add(int_to_node(12));
    validate_apply_delta(&mmr, &mut partial_mmr);
    {
        let node = mmr.get(12).unwrap();
        let proof = mmr.open(12, mmr.forest()).unwrap();
        partial_mmr.track(12, node, &proof.merkle_path).unwrap();
        assert!(partial_mmr.track_latest);
    }

    // by this point we are tracking authentication paths for positions: 1, 8, and 12

    // add 3 more nodes to the MMR (collapses to 1 peak) and validate apply_delta()
    (13..16).for_each(|i| mmr.add(int_to_node(i)));
    validate_apply_delta(&mmr, &mut partial_mmr);
}

fn validate_apply_delta(mmr: &Mmr, partial: &mut PartialMmr) {
    let tracked_leaves = partial
        .nodes
        .iter()
        .filter_map(|(index, _)| if index.is_leaf() { Some(index.sibling()) } else { None })
        .collect::<Vec<_>>();
    let nodes_before = partial.nodes.clone();

    // compute and apply delta
    let delta = mmr.get_delta(partial.forest(), mmr.forest()).unwrap();
    let nodes_delta = partial.apply(delta).unwrap();

    // new peaks were computed correctly
    assert_eq!(mmr.peaks(mmr.forest()).unwrap(), partial.peaks());

    let mut expected_nodes = nodes_before;
    for (key, value) in nodes_delta {
        // nodes should not be duplicated
        assert!(expected_nodes.insert(key, value).is_none());
    }

    // new nodes should be a combination of original nodes and delta
    assert_eq!(expected_nodes, partial.nodes);

    // make sure tracked leaves open to the same proofs as in the underlying MMR
    for index in tracked_leaves {
        let index_value: u64 = index.into();
        let pos = index_value / 2;
        let proof1 = partial.open(pos as usize).unwrap().unwrap();
        let proof2 = mmr.open(pos as usize, mmr.forest()).unwrap();
        assert_eq!(proof1, proof2);
    }
}

#[test]
fn test_partial_mmr_inner_nodes_iterator() {
    // build the MMR
    let mmr: Mmr = LEAVES.into();
    let first_peak = mmr.peaks(mmr.forest).unwrap().peaks()[0];

    // -- test single tree ----------------------------

    // get path and node for position 1
    let node1 = mmr.get(1).unwrap();
    let proof1 = mmr.open(1, mmr.forest()).unwrap();

    // create partial MMR and add authentication path to node at position 1
    let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();
    partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();

    // empty iterator should have no nodes
    assert_eq!(partial_mmr.inner_nodes([].iter().cloned()).next(), None);

    // build Merkle store from authentication paths in partial MMR
    let mut store: MerkleStore = MerkleStore::new();
    store.extend(partial_mmr.inner_nodes([(1, node1)].iter().cloned()));

    let index1 = NodeIndex::new(2, 1).unwrap();
    let path1 = store.get_path(first_peak, index1).unwrap().path;

    assert_eq!(path1, proof1.merkle_path);

    // -- test no duplicates --------------------------

    // build the partial MMR
    let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();

    let node0 = mmr.get(0).unwrap();
    let proof0 = mmr.open(0, mmr.forest()).unwrap();

    let node2 = mmr.get(2).unwrap();
    let proof2 = mmr.open(2, mmr.forest()).unwrap();

    partial_mmr.track(0, node0, &proof0.merkle_path).unwrap();
    partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
    partial_mmr.track(2, node2, &proof2.merkle_path).unwrap();

    // make sure there are no duplicates
    let leaves = [(0, node0), (1, node1), (2, node2)];
    let mut nodes = BTreeSet::new();
    for node in partial_mmr.inner_nodes(leaves.iter().cloned()) {
        assert!(nodes.insert(node.value));
    }

    // and also that the store is still built correctly
    store.extend(partial_mmr.inner_nodes(leaves.iter().cloned()));

    let index0 = NodeIndex::new(2, 0).unwrap();
    let index1 = NodeIndex::new(2, 1).unwrap();
    let index2 = NodeIndex::new(2, 2).unwrap();

    let path0 = store.get_path(first_peak, index0).unwrap().path;
    let path1 = store.get_path(first_peak, index1).unwrap().path;
    let path2 = store.get_path(first_peak, index2).unwrap().path;

    assert_eq!(path0, proof0.merkle_path);
    assert_eq!(path1, proof1.merkle_path);
    assert_eq!(path2, proof2.merkle_path);

    // -- test multiple trees -------------------------

    // build the partial MMR
    let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();

    let node5 = mmr.get(5).unwrap();
    let proof5 = mmr.open(5, mmr.forest()).unwrap();

    partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
    partial_mmr.track(5, node5, &proof5.merkle_path).unwrap();

    // build Merkle store from authentication paths in partial MMR
    let mut store: MerkleStore = MerkleStore::new();
    store.extend(partial_mmr.inner_nodes([(1, node1), (5, node5)].iter().cloned()));

    let index1 = NodeIndex::new(2, 1).unwrap();
    let index5 = NodeIndex::new(1, 1).unwrap();

    let second_peak = mmr.peaks(mmr.forest).unwrap().peaks()[1];

    let path1 = store.get_path(first_peak, index1).unwrap().path;
    let path5 = store.get_path(second_peak, index5).unwrap().path;

    assert_eq!(path1, proof1.merkle_path);
    assert_eq!(path5, proof5.merkle_path);
}

#[test]
fn test_partial_mmr_add_without_track() {
    let mut mmr = Mmr::default();
    let empty_peaks = MmrPeaks::new(0, vec![]).unwrap();
    let mut partial_mmr = PartialMmr::from_peaks(empty_peaks);

    for el in (0..256).map(int_to_node) {
        mmr.add(el);
        partial_mmr.add(el, false);

        let mmr_peaks = mmr.peaks(mmr.forest()).unwrap();
        assert_eq!(mmr_peaks, partial_mmr.peaks());
        assert_eq!(mmr.forest(), partial_mmr.forest());
    }
}

#[test]
fn test_partial_mmr_add_with_track() {
    let mut mmr = Mmr::default();
    let empty_peaks = MmrPeaks::new(0, vec![]).unwrap();
    let mut partial_mmr = PartialMmr::from_peaks(empty_peaks);

    for i in 0..256 {
        let el = int_to_node(i);
        mmr.add(el);
        partial_mmr.add(el, true);

        let mmr_peaks = mmr.peaks(mmr.forest()).unwrap();
        assert_eq!(mmr_peaks, partial_mmr.peaks());
        assert_eq!(mmr.forest(), partial_mmr.forest());

        for pos in 0..i {
            let mmr_proof = mmr.open(pos as usize, mmr.forest()).unwrap();
            let partialmmr_proof = partial_mmr.open(pos as usize).unwrap().unwrap();
            assert_eq!(mmr_proof, partialmmr_proof);
        }
    }
}

#[test]
fn test_partial_mmr_add_existing_track() {
    let mut mmr = Mmr::from((0..7).map(int_to_node));

    // derive a partial Mmr from it which tracks the authentication path to leaf 5
    let mut partial_mmr = PartialMmr::from_peaks(mmr.peaks(mmr.forest()).unwrap());
    let path_to_5 = mmr.open(5, mmr.forest()).unwrap().merkle_path;
    let leaf_at_5 = mmr.get(5).unwrap();
    partial_mmr.track(5, leaf_at_5, &path_to_5).unwrap();

    // add a new leaf to both Mmr and partial Mmr
    let leaf_at_7 = int_to_node(7);
    mmr.add(leaf_at_7);
    partial_mmr.add(leaf_at_7, false);

    // the openings should be the same
    assert_eq!(mmr.open(5, mmr.forest()).unwrap(), partial_mmr.open(5).unwrap().unwrap());
}
}

@@ -1,19 +1,20 @@
use super::{
    super::{RpoDigest, Vec, ZERO},
    Felt, MmrError, MmrProof, Rpo256, Word,
};
use super::{super::ZERO, Felt, MmrError, MmrProof, Rpo256, RpoDigest, Word};
use alloc::vec::Vec;

#[derive(Debug, Clone, PartialEq)]
// MMR PEAKS
// ================================================================================================

#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MmrPeaks {
    /// The number of leaves is used to differentiate accumulators that have the same number of
    /// peaks. This happens because the number of peaks goes up-and-down as the structure is used
    /// causing existing trees to be merged and new ones to be created. As an example, every time
    /// the [Mmr] has a power-of-two number of leaves there is a single peak.
    /// The number of leaves is used to differentiate MMRs that have the same number of peaks. This
    /// happens because the number of peaks goes up-and-down as the structure is used causing
    /// existing trees to be merged and new ones to be created. As an example, every time the MMR
    /// has a power-of-two number of leaves there is a single peak.
    ///
    /// Every tree in the [Mmr] forest has a distinct power-of-two size, this means only the
    /// rightmost tree can have an odd number of elements (e.g. `1`). Additionally this means
    /// that the bits in `num_leaves` conveniently encode the size of each individual tree.
    /// Every tree in the MMR forest has a distinct power-of-two size, this means only the right-
    /// most tree can have an odd number of elements (e.g. `1`). Additionally this means that the
    /// bits in `num_leaves` conveniently encode the size of each individual tree.
    ///
    /// Examples:
    ///
@@ -25,7 +26,7 @@ pub struct MmrPeaks {
    /// leftmost tree has `2**3=8` elements, and the right most has `2**2=4` elements.
    num_leaves: usize,

    /// All the peaks of every tree in the [Mmr] forest. The peaks are always ordered by number of
    /// All the peaks of every tree in the MMR forest. The peaks are always ordered by number of
    /// leaves, starting from the peak with most children, to the one with least.
    ///
    /// Invariant: The length of `peaks` must be equal to the number of true bits in `num_leaves`.
@@ -33,6 +34,14 @@ pub struct MmrPeaks {
}

impl MmrPeaks {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------

    /// Returns new [MmrPeaks] instantiated from the provided vector of peaks and the number of
    /// leaves in the underlying MMR.
    ///
    /// # Errors
    /// Returns an error if the number of leaves and the number of peaks are inconsistent.
    pub fn new(num_leaves: usize, peaks: Vec<RpoDigest>) -> Result<Self, MmrError> {
        if num_leaves.count_ones() as usize != peaks.len() {
            return Err(MmrError::InvalidPeaks);
@@ -44,23 +53,34 @@ impl MmrPeaks {
    // ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns a count of the [Mmr]'s leaves.
    /// Returns a count of leaves in the underlying MMR.
    pub fn num_leaves(&self) -> usize {
        self.num_leaves
    }

    /// Returns the current peaks of the [Mmr].
    /// Returns the number of peaks of the underlying MMR.
    pub fn num_peaks(&self) -> usize {
        self.peaks.len()
    }

    /// Returns the list of peaks of the underlying MMR.
    pub fn peaks(&self) -> &[RpoDigest] {
        &self.peaks
    }

    /// Converts this [MmrPeaks] into its components: number of leaves and a vector of peaks of
    /// the underlying MMR.
    pub fn into_parts(self) -> (usize, Vec<RpoDigest>) {
        (self.num_leaves, self.peaks)
    }

    /// Hashes the peaks.
    ///
    /// The procedure will:
    /// - Flatten and pad the peaks to a vector of Felts.
    /// - Hash the vector of Felts.
    pub fn hash_peaks(&self) -> Word {
        Rpo256::hash_elements(&self.flatten_and_pad_peaks()).into()
    pub fn hash_peaks(&self) -> RpoDigest {
        Rpo256::hash_elements(&self.flatten_and_pad_peaks())
    }
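
    // A hedged usage sketch (variable names are illustrative, not from this diff):
    // the digest returned by `hash_peaks()` commits to the whole peak set, so two
    // peak sets committing to the same MMR state compare equal directly:
    //
    //     let commitment = peaks.hash_peaks();
    //     assert_eq!(commitment, same_peaks.hash_peaks());
    //
    // Note that after this change the commitment is an `RpoDigest` rather than a
    // `Word`, which is why the dereferencing `*` was dropped from the tests below.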

    pub fn verify(&self, value: RpoDigest, opening: MmrProof) -> bool {
@@ -110,3 +130,9 @@ impl MmrPeaks {
        elements
    }
}

impl From<MmrPeaks> for Vec<RpoDigest> {
    fn from(peaks: MmrPeaks) -> Self {
        peaks.peaks
    }
}

@@ -2,6 +2,9 @@
use super::super::MerklePath;
use super::{full::high_bitmask, leaf_to_corresponding_tree};

// MMR PROOF
// ================================================================================================

#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MmrProof {
@@ -26,9 +29,78 @@ impl MmrProof {
        self.position - forest_before
    }

    /// Returns the index of the MMR peak against which the Merkle path in this proof can be
    /// verified.
    pub fn peak_index(&self) -> usize {
        let root = leaf_to_corresponding_tree(self.position, self.forest)
            .expect("position must be part of the forest");
        (self.forest.count_ones() - root - 1) as usize
        let smaller_peak_mask = 2_usize.pow(root) as usize - 1;
        let num_smaller_peaks = (self.forest & smaller_peak_mask).count_ones();
        (self.forest.count_ones() - num_smaller_peaks - 1) as usize
    }
}
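
// Worked example (consistent with `test_peak_index` below): for `forest = 0b1011`
// a leaf inside the 8-leaf tree maps to `root = 3`, so `smaller_peak_mask = 0b0111`
// and `num_smaller_peaks = (0b1011 & 0b0111).count_ones() = 2`, giving peak index
// `3 - 2 - 1 = 0`; the old formula `count_ones - root - 1` would underflow here.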

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::{MerklePath, MmrProof};

    #[test]
    fn test_peak_index() {
        // --- single peak forest ---------------------------------------------
        let forest = 0b1000;

        // all 8 leaves belong to the single peak 0
        for position in 0..8 {
            let proof = make_dummy_proof(forest, position);
            assert_eq!(proof.peak_index(), 0);
        }

        // --- forest with non-consecutive peaks ------------------------------
        let forest = 11;

        // the first 8 leaves belong to peak 0
        for position in 0..8 {
            let proof = make_dummy_proof(forest, position);
            assert_eq!(proof.peak_index(), 0);
        }

        // the next 2 leaves belong to peak 1
        for position in 8..10 {
            let proof = make_dummy_proof(forest, position);
            assert_eq!(proof.peak_index(), 1);
        }

        // the last leaf is the peak 2
        let proof = make_dummy_proof(forest, 10);
        assert_eq!(proof.peak_index(), 2);

        // --- forest with consecutive peaks ----------------------------------
        let forest = 7;

        // the first 4 leaves belong to peak 0
        for position in 0..4 {
            let proof = make_dummy_proof(forest, position);
            assert_eq!(proof.peak_index(), 0);
        }

        // the next 2 leaves belong to peak 1
        for position in 4..6 {
            let proof = make_dummy_proof(forest, position);
            assert_eq!(proof.peak_index(), 1);
        }

        // the last leaf is the peak 2
        let proof = make_dummy_proof(forest, 6);
        assert_eq!(proof.peak_index(), 2);
    }

    fn make_dummy_proof(forest: usize, position: usize) -> MmrProof {
        MmrProof {
            forest,
            position,
            merkle_path: MerklePath::default(),
        }
    }
}

@@ -1,14 +1,14 @@
use super::{
    super::{InnerNodeInfo, Vec},
    super::{InnerNodeInfo, Rpo256, RpoDigest},
    bit::TrueBitPositionIterator,
    full::high_bitmask,
    leaf_to_corresponding_tree, nodes_in_forest, Mmr, MmrPeaks, PartialMmr, Rpo256,
    leaf_to_corresponding_tree, nodes_in_forest, Mmr, MmrPeaks, PartialMmr,
};
use crate::{
    hash::rpo::RpoDigest,
    merkle::{int_to_node, InOrderIndex, MerklePath, MerkleTree, MmrProof, NodeIndex},
    Felt, Word,
};
use alloc::vec::Vec;

#[test]
fn test_position_equal_or_higher_than_leafs_is_never_contained() {
@@ -115,13 +115,14 @@ const LEAVES: [RpoDigest; 7] = [

#[test]
fn test_mmr_simple() {
    let mut postorder = Vec::new();
    postorder.push(LEAVES[0]);
    postorder.push(LEAVES[1]);
    postorder.push(merge(LEAVES[0], LEAVES[1]));
    postorder.push(LEAVES[2]);
    postorder.push(LEAVES[3]);
    postorder.push(merge(LEAVES[2], LEAVES[3]));
    let mut postorder = vec![
        LEAVES[0],
        LEAVES[1],
        merge(LEAVES[0], LEAVES[1]),
        LEAVES[2],
        LEAVES[3],
        merge(LEAVES[2], LEAVES[3]),
    ];
    postorder.push(merge(postorder[2], postorder[5]));
    postorder.push(LEAVES[4]);
    postorder.push(LEAVES[5]);
@@ -137,7 +138,7 @@ fn test_mmr_simple() {
    assert_eq!(mmr.nodes.len(), 1);
    assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);

    let acc = mmr.accumulator();
    let acc = mmr.peaks(mmr.forest()).unwrap();
    assert_eq!(acc.num_leaves(), 1);
    assert_eq!(acc.peaks(), &[postorder[0]]);

@@ -146,7 +147,7 @@ fn test_mmr_simple() {
    assert_eq!(mmr.nodes.len(), 3);
    assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);

    let acc = mmr.accumulator();
    let acc = mmr.peaks(mmr.forest()).unwrap();
    assert_eq!(acc.num_leaves(), 2);
    assert_eq!(acc.peaks(), &[postorder[2]]);

@@ -155,7 +156,7 @@ fn test_mmr_simple() {
    assert_eq!(mmr.nodes.len(), 4);
    assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);

    let acc = mmr.accumulator();
    let acc = mmr.peaks(mmr.forest()).unwrap();
    assert_eq!(acc.num_leaves(), 3);
    assert_eq!(acc.peaks(), &[postorder[2], postorder[3]]);

@@ -164,7 +165,7 @@ fn test_mmr_simple() {
    assert_eq!(mmr.nodes.len(), 7);
    assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);

    let acc = mmr.accumulator();
    let acc = mmr.peaks(mmr.forest()).unwrap();
    assert_eq!(acc.num_leaves(), 4);
    assert_eq!(acc.peaks(), &[postorder[6]]);

@@ -173,7 +174,7 @@ fn test_mmr_simple() {
    assert_eq!(mmr.nodes.len(), 8);
    assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);

    let acc = mmr.accumulator();
    let acc = mmr.peaks(mmr.forest()).unwrap();
    assert_eq!(acc.num_leaves(), 5);
    assert_eq!(acc.peaks(), &[postorder[6], postorder[7]]);

@@ -182,7 +183,7 @@ fn test_mmr_simple() {
    assert_eq!(mmr.nodes.len(), 10);
    assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);

    let acc = mmr.accumulator();
    let acc = mmr.peaks(mmr.forest()).unwrap();
    assert_eq!(acc.num_leaves(), 6);
    assert_eq!(acc.peaks(), &[postorder[6], postorder[9]]);

@@ -191,7 +192,7 @@ fn test_mmr_simple() {
    assert_eq!(mmr.nodes.len(), 11);
    assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);

    let acc = mmr.accumulator();
    let acc = mmr.peaks(mmr.forest()).unwrap();
    assert_eq!(acc.num_leaves(), 7);
    assert_eq!(acc.peaks(), &[postorder[6], postorder[9], postorder[10]]);
}
@@ -203,96 +204,139 @@ fn test_mmr_open() {
    let h23 = merge(LEAVES[2], LEAVES[3]);

    // there is no node at pos 7
    assert!(mmr.open(7).is_err(), "Element 7 is not in the tree, result should be an error");
    assert!(
        mmr.open(7, mmr.forest()).is_err(),
        "Element 7 is not in the tree, result should be an error"
    );

    // node at pos 6 is the root
    let empty: MerklePath = MerklePath::new(vec![]);
    let opening = mmr
        .open(6)
        .open(6, mmr.forest())
        .expect("Element 6 is contained in the tree, expected an opening result.");
    assert_eq!(opening.merkle_path, empty);
    assert_eq!(opening.forest, mmr.forest);
    assert_eq!(opening.position, 6);
    assert!(
        mmr.accumulator().verify(LEAVES[6], opening),
        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[6], opening),
        "MmrProof should be valid for the current accumulator."
    );

    // nodes 4,5 are at depth 1
    let root_to_path = MerklePath::new(vec![LEAVES[4]]);
    let opening = mmr
        .open(5)
        .open(5, mmr.forest())
        .expect("Element 5 is contained in the tree, expected an opening result.");
    assert_eq!(opening.merkle_path, root_to_path);
    assert_eq!(opening.forest, mmr.forest);
    assert_eq!(opening.position, 5);
    assert!(
        mmr.accumulator().verify(LEAVES[5], opening),
        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[5], opening),
        "MmrProof should be valid for the current accumulator."
    );

    let root_to_path = MerklePath::new(vec![LEAVES[5]]);
    let opening = mmr
        .open(4)
        .open(4, mmr.forest())
        .expect("Element 4 is contained in the tree, expected an opening result.");
    assert_eq!(opening.merkle_path, root_to_path);
    assert_eq!(opening.forest, mmr.forest);
    assert_eq!(opening.position, 4);
    assert!(
        mmr.accumulator().verify(LEAVES[4], opening),
        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[4], opening),
        "MmrProof should be valid for the current accumulator."
    );

    // nodes 0,1,2,3 are at depth 2
    let root_to_path = MerklePath::new(vec![LEAVES[2], h01]);
    let opening = mmr
        .open(3)
        .open(3, mmr.forest())
        .expect("Element 3 is contained in the tree, expected an opening result.");
    assert_eq!(opening.merkle_path, root_to_path);
    assert_eq!(opening.forest, mmr.forest);
    assert_eq!(opening.position, 3);
    assert!(
        mmr.accumulator().verify(LEAVES[3], opening),
        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[3], opening),
        "MmrProof should be valid for the current accumulator."
    );

    let root_to_path = MerklePath::new(vec![LEAVES[3], h01]);
    let opening = mmr
        .open(2)
        .open(2, mmr.forest())
        .expect("Element 2 is contained in the tree, expected an opening result.");
    assert_eq!(opening.merkle_path, root_to_path);
    assert_eq!(opening.forest, mmr.forest);
    assert_eq!(opening.position, 2);
    assert!(
        mmr.accumulator().verify(LEAVES[2], opening),
        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[2], opening),
        "MmrProof should be valid for the current accumulator."
    );

    let root_to_path = MerklePath::new(vec![LEAVES[0], h23]);
    let opening = mmr
        .open(1)
        .open(1, mmr.forest())
        .expect("Element 1 is contained in the tree, expected an opening result.");
    assert_eq!(opening.merkle_path, root_to_path);
    assert_eq!(opening.forest, mmr.forest);
    assert_eq!(opening.position, 1);
    assert!(
        mmr.accumulator().verify(LEAVES[1], opening),
        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[1], opening),
        "MmrProof should be valid for the current accumulator."
    );

    let root_to_path = MerklePath::new(vec![LEAVES[1], h23]);
    let opening = mmr
        .open(0)
        .open(0, mmr.forest())
        .expect("Element 0 is contained in the tree, expected an opening result.");
    assert_eq!(opening.merkle_path, root_to_path);
    assert_eq!(opening.forest, mmr.forest);
    assert_eq!(opening.position, 0);
    assert!(
        mmr.accumulator().verify(LEAVES[0], opening),
        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[0], opening),
        "MmrProof should be valid for the current accumulator."
    );
}

#[test]
fn test_mmr_open_older_version() {
    let mmr: Mmr = LEAVES.into();

    fn is_even(v: &usize) -> bool {
        v & 1 == 0
    }

    // merkle path of a node is empty if there are no elements to pair with it
    for pos in (0..mmr.forest()).filter(is_even) {
        let forest = pos + 1;
        let proof = mmr.open(pos, forest).unwrap();
        assert_eq!(proof.forest, forest);
        assert_eq!(proof.merkle_path.nodes(), []);
        assert_eq!(proof.position, pos);
    }

    // openings match that of a merkle tree
    let mtree: MerkleTree = LEAVES[..4].try_into().unwrap();
    for forest in 4..=LEAVES.len() {
        for pos in 0..4 {
            let idx = NodeIndex::new(2, pos).unwrap();
            let path = mtree.get_path(idx).unwrap();
            let proof = mmr.open(pos as usize, forest).unwrap();
            assert_eq!(path, proof.merkle_path);
        }
    }
    let mtree: MerkleTree = LEAVES[4..6].try_into().unwrap();
    for forest in 6..=LEAVES.len() {
        for pos in 0..2 {
            let idx = NodeIndex::new(1, pos).unwrap();
            let path = mtree.get_path(idx).unwrap();
            // account for the bigger tree with 4 elements
            let mmr_pos = (pos + 4) as usize;
            let proof = mmr.open(mmr_pos, forest).unwrap();
            assert_eq!(path, proof.merkle_path);
        }
    }
}

/// Tests the openings of a simple Mmr with a single tree of 8 leaves.
#[test]
fn test_mmr_open_eight() {
@@ -313,49 +357,49 @@ fn test_mmr_open_eight() {
    let root = mtree.root();

    let position = 0;
    let proof = mmr.open(position).unwrap();
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

    let position = 1;
    let proof = mmr.open(position).unwrap();
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

    let position = 2;
    let proof = mmr.open(position).unwrap();
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

    let position = 3;
    let proof = mmr.open(position).unwrap();
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

    let position = 4;
    let proof = mmr.open(position).unwrap();
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

    let position = 5;
    let proof = mmr.open(position).unwrap();
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

    let position = 6;
    let proof = mmr.open(position).unwrap();
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

    let position = 7;
    let proof = mmr.open(position).unwrap();
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
@@ -371,47 +415,47 @@ fn test_mmr_open_seven() {
    let mmr: Mmr = LEAVES.into();

    let position = 0;
    let proof = mmr.open(position).unwrap();
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path: MerklePath =
        mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(0, LEAVES[0]).unwrap(), mtree1.root());

    let position = 1;
    let proof = mmr.open(position).unwrap();
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path: MerklePath =
        mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(1, LEAVES[1]).unwrap(), mtree1.root());

    let position = 2;
    let proof = mmr.open(position).unwrap();
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path: MerklePath =
        mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(2, LEAVES[2]).unwrap(), mtree1.root());

    let position = 3;
    let proof = mmr.open(position).unwrap();
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path: MerklePath =
        mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(3, LEAVES[3]).unwrap(), mtree1.root());

    let position = 4;
    let proof = mmr.open(position).unwrap();
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 0u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(0, LEAVES[4]).unwrap(), mtree2.root());

    let position = 5;
    let proof = mmr.open(position).unwrap();
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 1u64).unwrap()).unwrap();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(1, LEAVES[5]).unwrap(), mtree2.root());

    let position = 6;
    let proof = mmr.open(position).unwrap();
    let proof = mmr.open(position, mmr.forest()).unwrap();
    let merkle_path: MerklePath = [].as_ref().into();
    assert_eq!(proof, MmrProof { forest, position, merkle_path });
    assert_eq!(proof.merkle_path.compute_root(0, LEAVES[6]).unwrap(), LEAVES[6]);
@@ -435,7 +479,7 @@ fn test_mmr_invariants() {
    let mut mmr = Mmr::new();
    for v in 1..=1028 {
        mmr.add(int_to_node(v));
        let accumulator = mmr.accumulator();
        let accumulator = mmr.peaks(mmr.forest()).unwrap();
        assert_eq!(v as usize, mmr.forest(), "MMR leaf count must increase by one on every add");
        assert_eq!(
            v as usize,
@@ -516,10 +560,50 @@ fn test_mmr_inner_nodes() {
    assert_eq!(postorder, nodes);
}

#[test]
fn test_mmr_peaks() {
    let mmr: Mmr = LEAVES.into();

    let forest = 0b0001;
    let acc = mmr.peaks(forest).unwrap();
    assert_eq!(acc.num_leaves(), forest);
    assert_eq!(acc.peaks(), &[mmr.nodes[0]]);

    let forest = 0b0010;
    let acc = mmr.peaks(forest).unwrap();
    assert_eq!(acc.num_leaves(), forest);
    assert_eq!(acc.peaks(), &[mmr.nodes[2]]);

    let forest = 0b0011;
    let acc = mmr.peaks(forest).unwrap();
    assert_eq!(acc.num_leaves(), forest);
    assert_eq!(acc.peaks(), &[mmr.nodes[2], mmr.nodes[3]]);

    let forest = 0b0100;
    let acc = mmr.peaks(forest).unwrap();
    assert_eq!(acc.num_leaves(), forest);
    assert_eq!(acc.peaks(), &[mmr.nodes[6]]);

    let forest = 0b0101;
    let acc = mmr.peaks(forest).unwrap();
    assert_eq!(acc.num_leaves(), forest);
    assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[7]]);

    let forest = 0b0110;
    let acc = mmr.peaks(forest).unwrap();
    assert_eq!(acc.num_leaves(), forest);
    assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9]]);

    let forest = 0b0111;
    let acc = mmr.peaks(forest).unwrap();
    assert_eq!(acc.num_leaves(), forest);
    assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9], mmr.nodes[10]]);
}
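
// Note: the peaks follow the bits of `forest` from most significant to least
// significant; e.g. for `forest = 0b0110` the forest holds trees of 4 and 2
// leaves, whose peaks are `mmr.nodes[6]` and `mmr.nodes[9]` in the postorder
// node array, matching the assertions above.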

#[test]
fn test_mmr_hash_peaks() {
    let mmr: Mmr = LEAVES.into();
    let peaks = mmr.accumulator();
    let peaks = mmr.peaks(mmr.forest()).unwrap();

    let first_peak = Rpo256::merge(&[
        Rpo256::merge(&[LEAVES[0], LEAVES[1]]),
@@ -531,10 +615,7 @@ fn test_mmr_hash_peaks() {
    // minimum length is 16
    let mut expected_peaks = [first_peak, second_peak, third_peak].to_vec();
    expected_peaks.resize(16, RpoDigest::default());
    assert_eq!(
        peaks.hash_peaks(),
        *Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
    );
    assert_eq!(peaks.hash_peaks(), Rpo256::hash_elements(&digests_to_elements(&expected_peaks)));
}

#[test]
@@ -552,7 +633,7 @@ fn test_mmr_peaks_hash_less_than_16() {
    expected_peaks.resize(16, RpoDigest::default());
    assert_eq!(
        accumulator.hash_peaks(),
        *Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
        Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
    );
}
}
@@ -569,47 +650,47 @@ fn test_mmr_peaks_hash_odd() {
    expected_peaks.resize(18, RpoDigest::default());
    assert_eq!(
        accumulator.hash_peaks(),
        *Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
        Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
    );
}

#[test]
fn test_mmr_updates() {
fn test_mmr_delta() {
    let mmr: Mmr = LEAVES.into();
    let acc = mmr.accumulator();
    let acc = mmr.peaks(mmr.forest()).unwrap();

    // original_forest can't have more elements
    assert!(
        mmr.get_delta(LEAVES.len() + 1).is_err(),
        mmr.get_delta(LEAVES.len() + 1, mmr.forest()).is_err(),
        "Cannot provide updates for a newer Mmr"
    );

    // if the number of elements is the same there is no change
    assert!(
        mmr.get_delta(LEAVES.len()).unwrap().data.is_empty(),
        mmr.get_delta(LEAVES.len(), mmr.forest()).unwrap().data.is_empty(),
        "There are no updates for the same Mmr version"
    );

    // missing the last element added, which is itself a tree peak
    assert_eq!(mmr.get_delta(6).unwrap().data, vec![acc.peaks()[2]], "one peak");
    assert_eq!(mmr.get_delta(6, mmr.forest()).unwrap().data, vec![acc.peaks()[2]], "one peak");

    // missing the sibling to complete the tree of depth 2, and the last element
    assert_eq!(
        mmr.get_delta(5).unwrap().data,
        mmr.get_delta(5, mmr.forest()).unwrap().data,
        vec![LEAVES[5], acc.peaks()[2]],
        "one sibling, one peak"
    );

    // missing the whole last two trees, only send the peaks
    assert_eq!(
        mmr.get_delta(4).unwrap().data,
        mmr.get_delta(4, mmr.forest()).unwrap().data,
        vec![acc.peaks()[1], acc.peaks()[2]],
        "two peaks"
    );

    // missing the sibling to complete the first tree, and the two last trees
    assert_eq!(
        mmr.get_delta(3).unwrap().data,
        mmr.get_delta(3, mmr.forest()).unwrap().data,
        vec![LEAVES[3], acc.peaks()[1], acc.peaks()[2]],
        "one sibling, two peaks"
    );
@@ -617,37 +698,79 @@ fn test_mmr_updates() {
    // missing half of the first tree, only send the computed element (not the leaves), and the new
    // peaks
    assert_eq!(
        mmr.get_delta(2).unwrap().data,
        mmr.get_delta(2, mmr.forest()).unwrap().data,
        vec![mmr.nodes[5], acc.peaks()[1], acc.peaks()[2]],
        "one sibling, two peaks"
    );

    assert_eq!(
        mmr.get_delta(1).unwrap().data,
        mmr.get_delta(1, mmr.forest()).unwrap().data,
        vec![LEAVES[1], mmr.nodes[5], acc.peaks()[1], acc.peaks()[2]],
        "one sibling, two peaks"
    );

    assert_eq!(&mmr.get_delta(0).unwrap().data, acc.peaks(), "all peaks");
    assert_eq!(&mmr.get_delta(0, mmr.forest()).unwrap().data, acc.peaks(), "all peaks");
}

#[test]
fn test_mmr_delta_old_forest() {
    let mmr: Mmr = LEAVES.into();

    // from_forest must be smaller-or-equal to to_forest
    for version in 1..=mmr.forest() {
        assert!(mmr.get_delta(version + 1, version).is_err());
    }

    // when from_forest and to_forest are equal, there are no updates
    for version in 1..=mmr.forest() {
        let delta = mmr.get_delta(version, version).unwrap();
        assert!(delta.data.is_empty());
        assert_eq!(delta.forest, version);
    }

    // test update which merges the odd peak to the right
    for count in 0..(mmr.forest() / 2) {
        // *2 because every iteration tests a pair
        // +1 because the Mmr is 1-indexed
        let from_forest = (count * 2) + 1;
        let to_forest = (count * 2) + 2;
        let delta = mmr.get_delta(from_forest, to_forest).unwrap();

        // *2 because every iteration tests a pair
        // +1 because sibling is the odd element
        let sibling = (count * 2) + 1;
        assert_eq!(delta.data, [LEAVES[sibling]]);
        assert_eq!(delta.forest, to_forest);
    }

    let version = 4;
    let delta = mmr.get_delta(1, version).unwrap();
    assert_eq!(delta.data, [mmr.nodes[1], mmr.nodes[5]]);
    assert_eq!(delta.forest, version);

    let version = 5;
    let delta = mmr.get_delta(1, version).unwrap();
    assert_eq!(delta.data, [mmr.nodes[1], mmr.nodes[5], mmr.nodes[7]]);
    assert_eq!(delta.forest, version);
}
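
// A short worked case, consistent with the assertions above: upgrading from
// forest 1 to forest 4 requires `mmr.nodes[1]` (the sibling of leaf 0) and
// `mmr.nodes[5]` (the root of the second 2-leaf subtree); with those two
// values a client tracking leaf 0 can recompute the new 4-leaf peak locally.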

#[test]
fn test_partial_mmr_simple() {
    let mmr: Mmr = LEAVES.into();
    let acc = mmr.accumulator();
    let mut partial: PartialMmr = acc.clone().into();
    let peaks = mmr.peaks(mmr.forest()).unwrap();
    let mut partial: PartialMmr = peaks.clone().into();

    // check initial state of the partial mmr
    assert_eq!(partial.peaks(), acc.peaks());
    assert_eq!(partial.forest(), acc.num_leaves());
    assert_eq!(partial.peaks(), peaks);
    assert_eq!(partial.forest(), peaks.num_leaves());
    assert_eq!(partial.forest(), LEAVES.len());
    assert_eq!(partial.peaks().len(), 3);
    assert_eq!(partial.peaks().num_peaks(), 3);
    assert_eq!(partial.nodes.len(), 0);

    // check state after starting to track one element
    let proof1 = mmr.open(0).unwrap();
    let proof1 = mmr.open(0, mmr.forest()).unwrap();
    let el1 = mmr.get(proof1.position).unwrap();
    partial.add(proof1.position, el1, &proof1.merkle_path).unwrap();
    partial.track(proof1.position, el1, &proof1.merkle_path).unwrap();

    // check the number of nodes increased by the number of nodes in the proof
    assert_eq!(partial.nodes.len(), proof1.merkle_path.len());
@@ -657,9 +780,9 @@ fn test_partial_mmr_simple() {
    let idx = idx.parent();
    assert_eq!(partial.nodes[&idx.sibling()], proof1.merkle_path[1]);

    let proof2 = mmr.open(1).unwrap();
    let proof2 = mmr.open(1, mmr.forest()).unwrap();
    let el2 = mmr.get(proof2.position).unwrap();
    partial.add(proof2.position, el2, &proof2.merkle_path).unwrap();
    partial.track(proof2.position, el2, &proof2.merkle_path).unwrap();

    // check the number of nodes increased by a single element (the one that is not shared)
    assert_eq!(partial.nodes.len(), 3);
@@ -675,22 +798,22 @@ fn test_partial_mmr_update_single() {
    let mut full = Mmr::new();
    let zero = int_to_node(0);
    full.add(zero);
    let mut partial: PartialMmr = full.accumulator().into();
    let mut partial: PartialMmr = full.peaks(full.forest()).unwrap().into();

    let proof = full.open(0).unwrap();
    partial.add(proof.position, zero, &proof.merkle_path).unwrap();
    let proof = full.open(0, full.forest()).unwrap();
    partial.track(proof.position, zero, &proof.merkle_path).unwrap();

    for i in 1..100 {
        let node = int_to_node(i);
        full.add(node);
        let delta = full.get_delta(partial.forest()).unwrap();
        let delta = full.get_delta(partial.forest(), full.forest()).unwrap();
        partial.apply(delta).unwrap();

        assert_eq!(partial.forest(), full.forest());
        assert_eq!(partial.peaks(), full.accumulator().peaks());
        assert_eq!(partial.peaks(), full.peaks(full.forest()).unwrap());

        let proof1 = full.open(i as usize).unwrap();
        partial.add(proof1.position, node, &proof1.merkle_path).unwrap();
        let proof1 = full.open(i as usize, full.forest()).unwrap();
        partial.track(proof1.position, node, &proof1.merkle_path).unwrap();
        let proof2 = partial.open(proof1.position).unwrap().unwrap();
        assert_eq!(proof1.merkle_path, proof2.merkle_path);
    }
@@ -699,25 +822,26 @@ fn test_partial_mmr_update_single() {
#[test]
fn test_mmr_add_invalid_odd_leaf() {
    let mmr: Mmr = LEAVES.into();
    let acc = mmr.accumulator();
    let acc = mmr.peaks(mmr.forest()).unwrap();
    let mut partial: PartialMmr = acc.clone().into();

    let empty = MerklePath::new(Vec::new());

    // None of the other leaves should work
    for node in LEAVES.iter().cloned().rev().skip(1) {
        let result = partial.add(LEAVES.len() - 1, node, &empty);
        let result = partial.track(LEAVES.len() - 1, node, &empty);
        assert!(result.is_err());
    }

    let result = partial.add(LEAVES.len() - 1, LEAVES[6], &empty);
    let result = partial.track(LEAVES.len() - 1, LEAVES[6], &empty);
    assert!(result.is_ok());
}

mod property_tests {
    use super::leaf_to_corresponding_tree;
    use proptest::prelude::*;

    use super::leaf_to_corresponding_tree;

    proptest! {
        #[test]
        fn test_last_position_is_always_contained_in_the_last_tree(leaves in any::<usize>().prop_filter("cant have an empty tree", |v| *v != 0)) {

@@ -2,8 +2,7 @@

use super::{
    hash::rpo::{Rpo256, RpoDigest},
    utils::collections::{vec, BTreeMap, BTreeSet, KvMap, RecordingMap, TryApplyDiff, Vec},
    Felt, StarkField, Word, EMPTY_WORD, ZERO,
    Felt, Word, EMPTY_WORD, ZERO,
};

// REEXPORTS
@@ -12,9 +11,6 @@ use super::{
mod empty_roots;
pub use empty_roots::EmptySubtreeRoots;

mod delta;
pub use delta::{merkle_tree_delta, MerkleStoreDelta, MerkleTreeDelta};

mod index;
pub use index::NodeIndex;

@@ -24,14 +20,14 @@ pub use merkle_tree::{path_to_text, tree_to_text, MerkleTree};
mod path;
pub use path::{MerklePath, RootPath, ValuePath};

mod simple_smt;
pub use simple_smt::SimpleSmt;

mod tiered_smt;
pub use tiered_smt::{TieredSmt, TieredSmtProof, TieredSmtProofError};
mod smt;
pub use smt::{
    LeafIndex, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH,
    SMT_MAX_DEPTH, SMT_MIN_DEPTH,
};

mod mmr;
pub use mmr::{InOrderIndex, Mmr, MmrError, MmrPeaks, MmrProof, PartialMmr};
pub use mmr::{InOrderIndex, Mmr, MmrDelta, MmrError, MmrPeaks, MmrProof, PartialMmr};

mod store;
pub use store::{DefaultMerkleStore, MerkleStore, RecordingMerkleStore, StoreNode};
@@ -59,6 +55,6 @@ const fn int_to_leaf(value: u64) -> Word {
}

#[cfg(test)]
fn digests_to_words(digests: &[RpoDigest]) -> Vec<Word> {
fn digests_to_words(digests: &[RpoDigest]) -> alloc::vec::Vec<Word> {
    digests.iter().map(|d| d.into()).collect()
}

@@ -1,4 +1,4 @@
use crate::hash::rpo::RpoDigest;
use super::RpoDigest;

/// Representation of a node with two children used for iterating over containers.
#[derive(Debug, Clone, PartialEq, Eq)]

@@ -1,13 +1,18 @@
use super::{
    BTreeMap, BTreeSet, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest,
    ValuePath, Vec, Word, EMPTY_WORD,
};
use crate::utils::{
    format, string::String, vec, word_to_hex, ByteReader, ByteWriter, Deserializable,
    DeserializationError, Serializable,
use alloc::{
    collections::{BTreeMap, BTreeSet},
    string::String,
    vec::Vec,
};
use core::fmt;

use super::{
    InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, ValuePath, Word,
    EMPTY_WORD,
};
use crate::utils::{
    word_to_hex, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
};

#[cfg(test)]
mod tests;

@@ -109,9 +114,9 @@ impl PartialMerkleTree {

        // check if the number of leaves can be accommodated by the tree's depth; we use a min
        // depth of 63 because we consider passing in a vector of size 2^64 infeasible.
        let max = (1_u64 << 63) as usize;
        let max = 2usize.pow(63);
        if layers.len() > max {
            return Err(MerkleError::InvalidNumEntries(max, layers.len()));
            return Err(MerkleError::InvalidNumEntries(max));
        }

        // Get maximum depth
@@ -179,7 +184,7 @@ impl PartialMerkleTree {
    /// # Errors
    /// Returns an error if the specified NodeIndex is not contained in the nodes map.
    pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
        self.nodes.get(&index).ok_or(MerkleError::NodeNotInSet(index)).map(|hash| *hash)
        self.nodes.get(&index).ok_or(MerkleError::NodeNotInSet(index)).copied()
    }

    /// Returns true if the provided index is contained in the leaves set, false otherwise.

@@ -1,10 +1,11 @@
use super::{
    super::{
        digests_to_words, int_to_node, BTreeMap, DefaultMerkleStore as MerkleStore, MerkleTree,
        NodeIndex, PartialMerkleTree,
        digests_to_words, int_to_node, DefaultMerkleStore as MerkleStore, MerkleTree, NodeIndex,
        PartialMerkleTree,
    },
    Deserializable, InnerNodeInfo, RpoDigest, Serializable, ValuePath, Vec,
    Deserializable, InnerNodeInfo, RpoDigest, Serializable, ValuePath,
};
use alloc::{collections::BTreeMap, vec::Vec};

// TEST DATA
// ================================================================================================
@@ -209,7 +210,7 @@ fn get_paths() {
    // Which have leaf nodes 20, 22, 23, 32 and 33. Hence overall we will have 5 paths -- one path
    // for each leaf.

    let leaves = vec![NODE20, NODE22, NODE23, NODE32, NODE33];
    let leaves = [NODE20, NODE22, NODE23, NODE32, NODE33];
    let expected_paths: Vec<(NodeIndex, ValuePath)> = leaves
        .iter()
        .map(|&leaf| {
@@ -257,7 +258,7 @@ fn leaves() {
    let value32 = mt.get_node(NODE32).unwrap();
    let value33 = mt.get_node(NODE33).unwrap();

    let leaves = vec![(NODE11, value11), (NODE20, value20), (NODE32, value32), (NODE33, value33)];
    let leaves = [(NODE11, value11), (NODE20, value20), (NODE32, value32), (NODE33, value33)];

    let expected_leaves = leaves.iter().copied();
    assert!(expected_leaves.eq(pmt.leaves()));

@@ -1,6 +1,12 @@
use super::{vec, InnerNodeInfo, MerkleError, NodeIndex, Rpo256, RpoDigest, Vec};
use alloc::vec::Vec;
use core::ops::{Deref, DerefMut};

use super::{InnerNodeInfo, MerkleError, NodeIndex, Rpo256, RpoDigest};
use crate::{
    utils::{ByteReader, Deserializable, DeserializationError, Serializable},
    Word,
};

// MERKLE PATH
// ================================================================================================

@@ -17,6 +23,7 @@ impl MerklePath {

    /// Creates a new Merkle path from a list of nodes.
    pub fn new(nodes: Vec<RpoDigest>) -> Self {
        assert!(nodes.len() <= u8::MAX.into(), "MerklePath may have at most 255 items");
        Self { nodes }
    }

@@ -122,7 +129,7 @@ impl FromIterator<RpoDigest> for MerklePath {

impl IntoIterator for MerklePath {
    type Item = RpoDigest;
    type IntoIter = vec::IntoIter<RpoDigest>;
    type IntoIter = alloc::vec::IntoIter<RpoDigest>;

    fn into_iter(self) -> Self::IntoIter {
        self.nodes.into_iter()
@@ -161,7 +168,7 @@ impl<'a> Iterator for InnerNodeIterator<'a> {
// MERKLE PATH CONTAINERS
// ================================================================================================

/// A container for a [Word] value and its [MerklePath] opening.
/// A container for a [crate::Word] value and its [MerklePath] opening.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ValuePath {
    /// The node value opening for `path`.
@@ -172,12 +179,18 @@ pub struct ValuePath {

impl ValuePath {
    /// Returns a new [ValuePath] instantiated from the specified value and path.
    pub fn new(value: RpoDigest, path: Vec<RpoDigest>) -> Self {
        Self { value, path: MerklePath::new(path) }
    pub fn new(value: RpoDigest, path: MerklePath) -> Self {
        Self { value, path }
    }
}

/// A container for a [MerklePath] and its [Word] root.
impl From<(MerklePath, Word)> for ValuePath {
    fn from((path, value): (MerklePath, Word)) -> Self {
        ValuePath::new(value.into(), path)
    }
}
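
// A small usage sketch (`path` and `word` are hypothetical values): the
// conversion above makes `let vp: ValuePath = (path, word).into();` equivalent
// to `ValuePath::new(word.into(), path)`.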

/// A container for a [MerklePath] and its [crate::Word] root.
///
/// This structure does not provide any guarantees regarding the correctness of the path to the
/// root. For more information, check [MerklePath::verify].
@@ -189,6 +202,55 @@ pub struct RootPath {
    pub path: MerklePath,
}

// SERIALIZATION
// ================================================================================================

impl Serializable for MerklePath {
    fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
        assert!(self.nodes.len() <= u8::MAX.into(), "Length enforced in the constructor");
        target.write_u8(self.nodes.len() as u8);
        target.write_many(&self.nodes);
    }
}

impl Deserializable for MerklePath {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let count = source.read_u8()?.into();
        let nodes = source.read_many::<RpoDigest>(count)?;
        Ok(Self { nodes })
    }
}
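
// A hedged roundtrip sketch (assuming the `to_bytes`/`read_from_bytes` helpers
// provided by the winter_utils `Serializable`/`Deserializable` traits): the
// length is written as a single `u8`, which the constructor guarantees fits,
// so deserialization restores the original path.
//
//     let bytes = merkle_path.to_bytes();
//     let decoded = MerklePath::read_from_bytes(&bytes).unwrap();
//     assert_eq!(merkle_path, decoded);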

impl Serializable for ValuePath {
    fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
        self.value.write_into(target);
        self.path.write_into(target);
    }
}

impl Deserializable for ValuePath {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let value = RpoDigest::read_from(source)?;
        let path = MerklePath::read_from(source)?;
        Ok(Self { value, path })
    }
}

impl Serializable for RootPath {
    fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
        self.root.write_into(target);
        self.path.write_into(target);
    }
}

impl Deserializable for RootPath {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let root = RpoDigest::read_from(source)?;
        let path = MerklePath::read_from(source)?;
        Ok(Self { root, path })
    }
}

// TESTS
// ================================================================================================

@@ -1,302 +0,0 @@
use super::{
    BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, MerkleTreeDelta,
    NodeIndex, Rpo256, RpoDigest, StoreNode, TryApplyDiff, Vec, Word,
};

#[cfg(test)]
mod tests;

// SPARSE MERKLE TREE
// ================================================================================================

/// A sparse Merkle tree with 64-bit keys and 4-element leaf values, without compaction.
///
/// The root of the tree is recomputed on each new leaf update.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct SimpleSmt {
    depth: u8,
    root: RpoDigest,
    leaves: BTreeMap<u64, Word>,
    branches: BTreeMap<NodeIndex, BranchNode>,
    empty_hashes: Vec<RpoDigest>,
}

impl SimpleSmt {
    // CONSTANTS
    // --------------------------------------------------------------------------------------------

    /// Minimum supported depth.
    pub const MIN_DEPTH: u8 = 1;

    /// Maximum supported depth.
    pub const MAX_DEPTH: u8 = 64;

    /// Value of an empty leaf.
    pub const EMPTY_VALUE: Word = super::EMPTY_WORD;

    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Returns a new [SimpleSmt] instantiated with the specified depth.
    ///
    /// All leaves in the returned tree are set to [ZERO; 4].
    ///
    /// # Errors
    /// Returns an error if the depth is 0 or is greater than 64.
    pub fn new(depth: u8) -> Result<Self, MerkleError> {
        // validate the range of the depth.
        if depth < Self::MIN_DEPTH {
            return Err(MerkleError::DepthTooSmall(depth));
        } else if Self::MAX_DEPTH < depth {
            return Err(MerkleError::DepthTooBig(depth as u64));
        }

        let empty_hashes = EmptySubtreeRoots::empty_hashes(depth).to_vec();
        let root = empty_hashes[0];

        Ok(Self {
            root,
            depth,
            empty_hashes,
            leaves: BTreeMap::new(),
            branches: BTreeMap::new(),
        })
    }

    /// Returns a new [SimpleSmt] instantiated with the specified depth and with leaves
    /// set as specified by the provided entries.
    ///
    /// All leaves omitted from the entries list are set to [ZERO; 4].
    ///
    /// # Errors
    /// Returns an error if:
    /// - The depth is 0 or is greater than 64.
    /// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
    /// - The provided entries contain multiple values for the same key.
    pub fn with_leaves<R, I>(depth: u8, entries: R) -> Result<Self, MerkleError>
    where
        R: IntoIterator<IntoIter = I>,
        I: Iterator<Item = (u64, Word)> + ExactSizeIterator,
    {
        // create an empty tree
        let mut tree = Self::new(depth)?;

        // check if the number of leaves can be accommodated by the tree's depth; we use a min
        // depth of 63 because we consider passing in a vector of size 2^64 infeasible.
        let entries = entries.into_iter();
        let max = 1 << tree.depth.min(63);
        if entries.len() > max {
            return Err(MerkleError::InvalidNumEntries(max, entries.len()));
        }

        // append leaves to the tree returning an error if a duplicate entry for the same key
        // is found
        let mut empty_entries = BTreeSet::new();
        for (key, value) in entries {
            let old_value = tree.update_leaf(key, value)?;
            if old_value != Self::EMPTY_VALUE || empty_entries.contains(&key) {
                return Err(MerkleError::DuplicateValuesForIndex(key));
            }
            // if we've processed an empty entry, add the key to the set of empty entry keys, and
            // if this key was already in the set, return an error
            if value == Self::EMPTY_VALUE && !empty_entries.insert(key) {
                return Err(MerkleError::DuplicateValuesForIndex(key));
            }
        }
        Ok(tree)
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the root of this Merkle tree.
    pub const fn root(&self) -> RpoDigest {
        self.root
    }

    /// Returns the depth of this Merkle tree.
    pub const fn depth(&self) -> u8 {
        self.depth
    }

    /// Returns a node at the specified index.
    ///
    /// # Errors
    /// Returns an error if the specified index has depth set to 0 or the depth is greater than
    /// the depth of this Merkle tree.
    pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
        if index.is_root() {
            Err(MerkleError::DepthTooSmall(index.depth()))
        } else if index.depth() > self.depth() {
            Err(MerkleError::DepthTooBig(index.depth() as u64))
        } else if index.depth() == self.depth() {
            // the lookup in empty_hashes could fail only if empty_hashes were not built correctly
            // by the constructor as we check the depth of the lookup above.
            Ok(RpoDigest::from(
                self.get_leaf_node(index.value())
                    .unwrap_or_else(|| *self.empty_hashes[index.depth() as usize]),
            ))
        } else {
            Ok(self.get_branch_node(&index).parent())
        }
    }

    /// Returns a value of the leaf at the specified index.
    ///
    /// # Errors
    /// Returns an error if the index is greater than the maximum tree capacity, that is 2^{depth}.
    pub fn get_leaf(&self, index: u64) -> Result<Word, MerkleError> {
        let index = NodeIndex::new(self.depth, index)?;
        Ok(self.get_node(index)?.into())
    }

    /// Returns a Merkle path from the node at the specified index to the root.
    ///
    /// The node itself is not included in the path.
    ///
    /// # Errors
    /// Returns an error if the specified index has depth set to 0 or the depth is greater than
    /// the depth of this Merkle tree.
    pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
        if index.is_root() {
            return Err(MerkleError::DepthTooSmall(index.depth()));
        } else if index.depth() > self.depth() {
            return Err(MerkleError::DepthTooBig(index.depth() as u64));
        }

        let mut path = Vec::with_capacity(index.depth() as usize);
        for _ in 0..index.depth() {
            let is_right = index.is_value_odd();
            index.move_up();
            let BranchNode { left, right } = self.get_branch_node(&index);
            let value = if is_right { left } else { right };
            path.push(value);
        }
        Ok(MerklePath::new(path))
    }

    /// Returns a Merkle path from the leaf at the specified index to the root.
    ///
    /// The leaf itself is not included in the path.
    ///
    /// # Errors
    /// Returns an error if the index is greater than the maximum tree capacity, that is 2^{depth}.
    pub fn get_leaf_path(&self, index: u64) -> Result<MerklePath, MerkleError> {
        let index = NodeIndex::new(self.depth(), index)?;
        self.get_path(index)
    }

    // ITERATORS
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over the leaves of this [SimpleSmt].
    pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
        self.leaves.iter().map(|(i, w)| (*i, w))
    }

    /// Returns an iterator over the inner nodes of this Merkle tree.
    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
        self.branches.values().map(|e| InnerNodeInfo {
            value: e.parent(),
            left: e.left,
            right: e.right,
        })
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Updates the value of the leaf at the specified index, returning the old leaf value.
    ///
    /// This also recomputes all hashes between the leaf and the root, updating the root itself.
    ///
    /// # Errors
    /// Returns an error if the index is greater than the maximum tree capacity, that is 2^{depth}.
    pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<Word, MerkleError> {
        let old_value = self.insert_leaf_node(index, value).unwrap_or(Self::EMPTY_VALUE);
|
||||
|
||||
// if the old value and new value are the same, there is nothing to update
|
||||
if value == old_value {
|
||||
return Ok(value);
|
||||
}
|
||||
|
||||
let mut index = NodeIndex::new(self.depth(), index)?;
|
||||
let mut value = RpoDigest::from(value);
|
||||
for _ in 0..index.depth() {
|
||||
let is_right = index.is_value_odd();
|
||||
index.move_up();
|
||||
let BranchNode { left, right } = self.get_branch_node(&index);
|
||||
let (left, right) = if is_right { (left, value) } else { (value, right) };
|
||||
self.insert_branch_node(index, left, right);
|
||||
value = Rpo256::merge(&[left, right]);
|
||||
}
|
||||
self.root = value;
|
||||
Ok(old_value)
|
||||
}
|
||||
|
||||
// HELPER METHODS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
fn get_leaf_node(&self, key: u64) -> Option<Word> {
|
||||
self.leaves.get(&key).copied()
|
||||
}
|
||||
|
||||
fn insert_leaf_node(&mut self, key: u64, node: Word) -> Option<Word> {
|
||||
self.leaves.insert(key, node)
|
||||
}
|
||||
|
||||
fn get_branch_node(&self, index: &NodeIndex) -> BranchNode {
|
||||
self.branches.get(index).cloned().unwrap_or_else(|| {
|
||||
let node = self.empty_hashes[index.depth() as usize + 1];
|
||||
BranchNode { left: node, right: node }
|
||||
})
|
||||
}
|
||||
|
||||
fn insert_branch_node(&mut self, index: NodeIndex, left: RpoDigest, right: RpoDigest) {
|
||||
let branch = BranchNode { left, right };
|
||||
self.branches.insert(index, branch);
|
||||
}
|
||||
}
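
// Usage sketch (an illustrative addition, not part of the original file): exercises the
// constructors and mutators defined above. It leans on the crate's `int_to_leaf` test helper
// for building leaf values, so it is framed as test-only code; the final assertion assumes
// `MerklePath` exposes the length of its sibling list, as the tests below suggest.
#[cfg(test)]
fn simple_smt_usage_sketch() -> Result<(), MerkleError> {
    use crate::merkle::int_to_leaf;

    // build a depth-2 tree with two of its four leaves populated
    let mut tree = SimpleSmt::with_leaves(2, vec![(0_u64, int_to_leaf(1)), (3, int_to_leaf(2))])?;

    // updating a leaf recomputes every hash up to the root
    let old_root = tree.root();
    tree.update_leaf(0, int_to_leaf(9))?;
    assert_ne!(tree.root(), old_root);

    // a leaf opening carries one sibling digest per tree level
    let path = tree.get_leaf_path(3)?;
    assert_eq!(path.len(), 2);

    Ok(())
}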

// BRANCH NODE
// ================================================================================================

#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
struct BranchNode {
    left: RpoDigest,
    right: RpoDigest,
}

impl BranchNode {
    fn parent(&self) -> RpoDigest {
        Rpo256::merge(&[self.left, self.right])
    }
}
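
// Sketch (illustrative, not in the original file): a branch node's parent digest is simply the
// RPO merge of its two children, which is the invariant that `update_leaf()` re-establishes
// level by level.
#[cfg(test)]
fn branch_node_parent_sketch() {
    let node = RpoDigest::default();
    let branch = BranchNode { left: node, right: node };
    assert_eq!(branch.parent(), Rpo256::merge(&[node, node]));
}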

// TRY APPLY DIFF
// ================================================================================================
impl TryApplyDiff<RpoDigest, StoreNode> for SimpleSmt {
    type Error = MerkleError;
    type DiffType = MerkleTreeDelta;

    fn try_apply(&mut self, diff: MerkleTreeDelta) -> Result<(), MerkleError> {
        if diff.depth() != self.depth() {
            return Err(MerkleError::InvalidDepth {
                expected: self.depth(),
                provided: diff.depth(),
            });
        }

        for slot in diff.cleared_slots() {
            self.update_leaf(*slot, Self::EMPTY_VALUE)?;
        }

        for (slot, value) in diff.updated_slots() {
            self.update_leaf(*slot, *value)?;
        }

        Ok(())
    }
}
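
// Sketch (illustrative, not in the original file): applying a delta produced elsewhere replays
// cleared slots as empty-value updates and updated slots as regular updates, exactly as
// `try_apply()` above does internally. How the `MerkleTreeDelta` is produced is outside this
// file's scope, so it is taken here as a parameter.
#[cfg(test)]
fn apply_delta_sketch(tree: &mut SimpleSmt, delta: MerkleTreeDelta) -> Result<(), MerkleError> {
    // a depth mismatch would make try_apply return MerkleError::InvalidDepth
    assert_eq!(delta.depth(), tree.depth());
    tree.try_apply(delta)
}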
@@ -1,251 +0,0 @@
use super::{
    super::{InnerNodeInfo, MerkleError, MerkleTree, RpoDigest, SimpleSmt, EMPTY_WORD},
    NodeIndex, Rpo256, Vec,
};
use crate::{
    merkle::{digests_to_words, int_to_leaf, int_to_node},
    Word,
};

// TEST DATA
// ================================================================================================

const KEYS4: [u64; 4] = [0, 1, 2, 3];
const KEYS8: [u64; 8] = [0, 1, 2, 3, 4, 5, 6, 7];

const VALUES4: [RpoDigest; 4] = [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];

const VALUES8: [RpoDigest; 8] = [
    int_to_node(1),
    int_to_node(2),
    int_to_node(3),
    int_to_node(4),
    int_to_node(5),
    int_to_node(6),
    int_to_node(7),
    int_to_node(8),
];

const ZERO_VALUES8: [Word; 8] = [int_to_leaf(0); 8];

// TESTS
// ================================================================================================

#[test]
fn build_empty_tree() {
    // tree of depth 3
    let smt = SimpleSmt::new(3).unwrap();
    let mt = MerkleTree::new(ZERO_VALUES8.to_vec()).unwrap();
    assert_eq!(mt.root(), smt.root());
}

#[test]
fn build_sparse_tree() {
    let mut smt = SimpleSmt::new(3).unwrap();
    let mut values = ZERO_VALUES8.to_vec();

    // insert single value
    let key = 6;
    let new_node = int_to_leaf(7);
    values[key as usize] = new_node;
    let old_value = smt.update_leaf(key, new_node).expect("Failed to update leaf");
    let mt2 = MerkleTree::new(values.clone()).unwrap();
    assert_eq!(mt2.root(), smt.root());
    assert_eq!(
        mt2.get_path(NodeIndex::make(3, 6)).unwrap(),
        smt.get_path(NodeIndex::make(3, 6)).unwrap()
    );
    assert_eq!(old_value, EMPTY_WORD);

    // insert second value at distinct leaf branch
    let key = 2;
    let new_node = int_to_leaf(3);
    values[key as usize] = new_node;
    let old_value = smt.update_leaf(key, new_node).expect("Failed to update leaf");
    let mt3 = MerkleTree::new(values).unwrap();
    assert_eq!(mt3.root(), smt.root());
    assert_eq!(
        mt3.get_path(NodeIndex::make(3, 2)).unwrap(),
        smt.get_path(NodeIndex::make(3, 2)).unwrap()
    );
    assert_eq!(old_value, EMPTY_WORD);
}

#[test]
fn test_depth2_tree() {
    let tree =
        SimpleSmt::with_leaves(2, KEYS4.into_iter().zip(digests_to_words(&VALUES4).into_iter()))
            .unwrap();

    // check internal structure
    let (root, node2, node3) = compute_internal_nodes();
    assert_eq!(root, tree.root());
    assert_eq!(node2, tree.get_node(NodeIndex::make(1, 0)).unwrap());
    assert_eq!(node3, tree.get_node(NodeIndex::make(1, 1)).unwrap());

    // check get_node()
    assert_eq!(VALUES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap());
    assert_eq!(VALUES4[1], tree.get_node(NodeIndex::make(2, 1)).unwrap());
    assert_eq!(VALUES4[2], tree.get_node(NodeIndex::make(2, 2)).unwrap());
    assert_eq!(VALUES4[3], tree.get_node(NodeIndex::make(2, 3)).unwrap());

    // check get_path(): depth 2
    assert_eq!(vec![VALUES4[1], node3], *tree.get_path(NodeIndex::make(2, 0)).unwrap());
    assert_eq!(vec![VALUES4[0], node3], *tree.get_path(NodeIndex::make(2, 1)).unwrap());
    assert_eq!(vec![VALUES4[3], node2], *tree.get_path(NodeIndex::make(2, 2)).unwrap());
    assert_eq!(vec![VALUES4[2], node2], *tree.get_path(NodeIndex::make(2, 3)).unwrap());

    // check get_path(): depth 1
    assert_eq!(vec![node3], *tree.get_path(NodeIndex::make(1, 0)).unwrap());
    assert_eq!(vec![node2], *tree.get_path(NodeIndex::make(1, 1)).unwrap());
}

#[test]
fn test_inner_node_iterator() -> Result<(), MerkleError> {
    let tree =
        SimpleSmt::with_leaves(2, KEYS4.into_iter().zip(digests_to_words(&VALUES4).into_iter()))
            .unwrap();

    // check depth 2
    assert_eq!(VALUES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap());
    assert_eq!(VALUES4[1], tree.get_node(NodeIndex::make(2, 1)).unwrap());
    assert_eq!(VALUES4[2], tree.get_node(NodeIndex::make(2, 2)).unwrap());
    assert_eq!(VALUES4[3], tree.get_node(NodeIndex::make(2, 3)).unwrap());

    // get parent nodes
    let root = tree.root();
    let l1n0 = tree.get_node(NodeIndex::make(1, 0))?;
    let l1n1 = tree.get_node(NodeIndex::make(1, 1))?;
    let l2n0 = tree.get_node(NodeIndex::make(2, 0))?;
    let l2n1 = tree.get_node(NodeIndex::make(2, 1))?;
    let l2n2 = tree.get_node(NodeIndex::make(2, 2))?;
    let l2n3 = tree.get_node(NodeIndex::make(2, 3))?;

    let nodes: Vec<InnerNodeInfo> = tree.inner_nodes().collect();
    let expected = vec![
        InnerNodeInfo { value: root, left: l1n0, right: l1n1 },
        InnerNodeInfo { value: l1n0, left: l2n0, right: l2n1 },
        InnerNodeInfo { value: l1n1, left: l2n2, right: l2n3 },
    ];
    assert_eq!(nodes, expected);

    Ok(())
}

#[test]
fn update_leaf() {
    let mut tree =
        SimpleSmt::with_leaves(3, KEYS8.into_iter().zip(digests_to_words(&VALUES8).into_iter()))
            .unwrap();

    // update one value
    let key = 3;
    let new_node = int_to_leaf(9);
    let mut expected_values = digests_to_words(&VALUES8);
    expected_values[key] = new_node;
    let expected_tree = MerkleTree::new(expected_values.clone()).unwrap();

    let old_leaf = tree.update_leaf(key as u64, new_node).unwrap();
    assert_eq!(expected_tree.root(), tree.root);
    assert_eq!(old_leaf, *VALUES8[key]);

    // update another value
    let key = 6;
    let new_node = int_to_leaf(10);
    expected_values[key] = new_node;
    let expected_tree = MerkleTree::new(expected_values.clone()).unwrap();

    let old_leaf = tree.update_leaf(key as u64, new_node).unwrap();
    assert_eq!(expected_tree.root(), tree.root);
    assert_eq!(old_leaf, *VALUES8[key]);
}

#[test]
fn small_tree_opening_is_consistent() {
    //        ____k____
    //       /         \
    //     _i_         _j_
    //    /   \       /   \
    //   e     f     g     h
    //  / \   / \   / \   / \
    // a   b 0   0 c   0 0   d

    let z = EMPTY_WORD;

    let a = Word::from(Rpo256::merge(&[z.into(); 2]));
    let b = Word::from(Rpo256::merge(&[a.into(); 2]));
    let c = Word::from(Rpo256::merge(&[b.into(); 2]));
    let d = Word::from(Rpo256::merge(&[c.into(); 2]));

    let e = Rpo256::merge(&[a.into(), b.into()]);
    let f = Rpo256::merge(&[z.into(), z.into()]);
    let g = Rpo256::merge(&[c.into(), z.into()]);
    let h = Rpo256::merge(&[z.into(), d.into()]);

    let i = Rpo256::merge(&[e, f]);
    let j = Rpo256::merge(&[g, h]);

    let k = Rpo256::merge(&[i, j]);

    let depth = 3;
    let entries = vec![(0, a), (1, b), (4, c), (7, d)];
    let tree = SimpleSmt::with_leaves(depth, entries).unwrap();

    assert_eq!(tree.root(), k);

    let cases: Vec<(u8, u64, Vec<RpoDigest>)> = vec![
        (3, 0, vec![b.into(), f, j]),
        (3, 1, vec![a.into(), f, j]),
        (3, 4, vec![z.into(), h, i]),
        (3, 7, vec![z.into(), g, i]),
        (2, 0, vec![f, j]),
        (2, 1, vec![e, j]),
        (2, 2, vec![h, i]),
        (2, 3, vec![g, i]),
        (1, 0, vec![j]),
        (1, 1, vec![i]),
    ];

    for (depth, key, path) in cases {
        let opening = tree.get_path(NodeIndex::make(depth, key)).unwrap();

        assert_eq!(path, *opening);
    }
}

#[test]
fn fail_on_duplicates() {
    let entries = [(1_u64, int_to_leaf(1)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(3))];
    let smt = SimpleSmt::with_leaves(64, entries);
    assert!(smt.is_err());

    let entries = [(1_u64, int_to_leaf(0)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(0))];
    let smt = SimpleSmt::with_leaves(64, entries);
    assert!(smt.is_err());

    let entries = [(1_u64, int_to_leaf(0)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(1))];
    let smt = SimpleSmt::with_leaves(64, entries);
    assert!(smt.is_err());

    let entries = [(1_u64, int_to_leaf(1)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(0))];
    let smt = SimpleSmt::with_leaves(64, entries);
    assert!(smt.is_err());
}

#[test]
fn with_no_duplicates_empty_node() {
    let entries = [(1_u64, int_to_leaf(0)), (5, int_to_leaf(2))];
    let smt = SimpleSmt::with_leaves(64, entries);
    assert!(smt.is_ok());
}

// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------

fn compute_internal_nodes() -> (RpoDigest, RpoDigest, RpoDigest) {
    let node2 = Rpo256::merge(&[VALUES4[0], VALUES4[1]]);
    let node3 = Rpo256::merge(&[VALUES4[2], VALUES4[3]]);
    let root = Rpo256::merge(&[node2, node3]);

    (root, node2, node3)
}
86
src/merkle/smt/full/error.rs
Normal file
@@ -0,0 +1,86 @@
use alloc::vec::Vec;
use core::fmt;

use crate::{
    hash::rpo::RpoDigest,
    merkle::{LeafIndex, SMT_DEPTH},
    Word,
};

// SMT LEAF ERROR
// =================================================================================================

#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SmtLeafError {
    InconsistentKeys {
        entries: Vec<(RpoDigest, Word)>,
        key_1: RpoDigest,
        key_2: RpoDigest,
    },
    InvalidNumEntriesForMultiple(usize),
    SingleKeyInconsistentWithLeafIndex {
        key: RpoDigest,
        leaf_index: LeafIndex<SMT_DEPTH>,
    },
    MultipleKeysInconsistentWithLeafIndex {
        leaf_index_from_keys: LeafIndex<SMT_DEPTH>,
        leaf_index_supplied: LeafIndex<SMT_DEPTH>,
    },
}

#[cfg(feature = "std")]
impl std::error::Error for SmtLeafError {}

impl fmt::Display for SmtLeafError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use SmtLeafError::*;
        match self {
            InvalidNumEntriesForMultiple(num_entries) => {
                write!(f, "Multiple leaf requires 2 or more entries. Got: {num_entries}")
            }
            InconsistentKeys { entries, key_1, key_2 } => {
                write!(f, "Multiple leaf requires all keys to map to the same leaf index. Offending keys: {key_1} and {key_2}. Entries: {entries:?}.")
            }
            SingleKeyInconsistentWithLeafIndex { key, leaf_index } => {
                write!(
                    f,
                    "Single key in leaf inconsistent with leaf index. Key: {key}, leaf index: {}",
                    leaf_index.value()
                )
            }
            MultipleKeysInconsistentWithLeafIndex {
                leaf_index_from_keys,
                leaf_index_supplied,
            } => {
                write!(
                    f,
                    "Keys in entries map to leaf index {}, but leaf index {} was supplied",
                    leaf_index_from_keys.value(),
                    leaf_index_supplied.value()
                )
            }
        }
    }
}

// SMT PROOF ERROR
// =================================================================================================

#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SmtProofError {
    InvalidPathLength(usize),
}

#[cfg(feature = "std")]
impl std::error::Error for SmtProofError {}

impl fmt::Display for SmtProofError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use SmtProofError::*;
        match self {
            InvalidPathLength(path_length) => {
                write!(f, "Invalid Merkle path length. Expected {SMT_DEPTH}, got {path_length}")
            }
        }
    }
}
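
// Sketch (illustrative, not in the original file): both error types implement `Display`, so
// they can be rendered directly into user-facing messages.
#[cfg(test)]
fn smt_error_display_sketch() {
    use alloc::string::ToString;

    let err = SmtProofError::InvalidPathLength(10);
    assert!(err.to_string().contains("Invalid Merkle path length"));
}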
371
src/merkle/smt/full/leaf.rs
Normal file
@@ -0,0 +1,371 @@
use alloc::{string::ToString, vec::Vec};
use core::cmp::Ordering;

use super::{Felt, LeafIndex, Rpo256, RpoDigest, SmtLeafError, Word, EMPTY_WORD, SMT_DEPTH};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum SmtLeaf {
    Empty(LeafIndex<SMT_DEPTH>),
    Single((RpoDigest, Word)),
    Multiple(Vec<(RpoDigest, Word)>),
}

impl SmtLeaf {
    // CONSTRUCTORS
    // ---------------------------------------------------------------------------------------------

    /// Returns a new leaf with the specified entries
    ///
    /// # Errors
    /// - Returns an error if 2 keys in `entries` map to different leaf indices
    /// - Returns an error if 1 or more keys in `entries` map to a leaf index
    ///   different from `leaf_index`
    pub fn new(
        entries: Vec<(RpoDigest, Word)>,
        leaf_index: LeafIndex<SMT_DEPTH>,
    ) -> Result<Self, SmtLeafError> {
        match entries.len() {
            0 => Ok(Self::new_empty(leaf_index)),
            1 => {
                let (key, value) = entries[0];

                if LeafIndex::<SMT_DEPTH>::from(key) != leaf_index {
                    return Err(SmtLeafError::SingleKeyInconsistentWithLeafIndex {
                        key,
                        leaf_index,
                    });
                }

                Ok(Self::new_single(key, value))
            }
            _ => {
                let leaf = Self::new_multiple(entries)?;

                // `new_multiple()` checked that all keys map to the same leaf index. We still need
                // to ensure that this leaf index is `leaf_index`.
                if leaf.index() != leaf_index {
                    Err(SmtLeafError::MultipleKeysInconsistentWithLeafIndex {
                        leaf_index_from_keys: leaf.index(),
                        leaf_index_supplied: leaf_index,
                    })
                } else {
                    Ok(leaf)
                }
            }
        }
    }

    /// Returns a new empty leaf with the specified leaf index
    pub fn new_empty(leaf_index: LeafIndex<SMT_DEPTH>) -> Self {
        Self::Empty(leaf_index)
    }

    /// Returns a new single leaf with the specified entry. The leaf index is derived from the
    /// entry's key.
    pub fn new_single(key: RpoDigest, value: Word) -> Self {
        Self::Single((key, value))
    }

    /// Returns a new multiple leaf with the specified entries. The leaf index is derived from the
    /// entries' keys.
    ///
    /// # Errors
    /// - Returns an error if 2 keys in `entries` map to different leaf indices
    pub fn new_multiple(entries: Vec<(RpoDigest, Word)>) -> Result<Self, SmtLeafError> {
        if entries.len() < 2 {
            return Err(SmtLeafError::InvalidNumEntriesForMultiple(entries.len()));
        }

        // Check that all keys map to the same leaf index
        {
            let mut keys = entries.iter().map(|(key, _)| key);

            let first_key = *keys.next().expect("ensured at least 2 entries");
            let first_leaf_index: LeafIndex<SMT_DEPTH> = first_key.into();

            for &next_key in keys {
                let next_leaf_index: LeafIndex<SMT_DEPTH> = next_key.into();

                if next_leaf_index != first_leaf_index {
                    return Err(SmtLeafError::InconsistentKeys {
                        entries,
                        key_1: first_key,
                        key_2: next_key,
                    });
                }
            }
        }

        Ok(Self::Multiple(entries))
    }

    // PUBLIC ACCESSORS
    // ---------------------------------------------------------------------------------------------

    /// Returns true if the leaf is empty
    pub fn is_empty(&self) -> bool {
        matches!(self, Self::Empty(_))
    }

    /// Returns the leaf's index in the [`super::Smt`]
    pub fn index(&self) -> LeafIndex<SMT_DEPTH> {
        match self {
            SmtLeaf::Empty(leaf_index) => *leaf_index,
            SmtLeaf::Single((key, _)) => key.into(),
            SmtLeaf::Multiple(entries) => {
                // Note: All keys are guaranteed to have the same leaf index
                let (first_key, _) = entries[0];
                first_key.into()
            }
        }
    }

    /// Returns the number of entries stored in the leaf
    pub fn num_entries(&self) -> u64 {
        match self {
            SmtLeaf::Empty(_) => 0,
            SmtLeaf::Single(_) => 1,
            SmtLeaf::Multiple(entries) => {
                entries.len().try_into().expect("shouldn't have more than 2^64 entries")
            }
        }
    }

    /// Computes the hash of the leaf
    pub fn hash(&self) -> RpoDigest {
        match self {
            SmtLeaf::Empty(_) => EMPTY_WORD.into(),
            SmtLeaf::Single((key, value)) => Rpo256::merge(&[*key, value.into()]),
            SmtLeaf::Multiple(kvs) => {
                let elements: Vec<Felt> = kvs.iter().copied().flat_map(kv_to_elements).collect();
                Rpo256::hash_elements(&elements)
            }
        }
    }

    // ITERATORS
    // ---------------------------------------------------------------------------------------------

    /// Returns the key-value pairs in the leaf
    pub fn entries(&self) -> Vec<&(RpoDigest, Word)> {
        match self {
            SmtLeaf::Empty(_) => Vec::new(),
            SmtLeaf::Single(kv_pair) => vec![kv_pair],
            SmtLeaf::Multiple(kv_pairs) => kv_pairs.iter().collect(),
        }
    }

    // CONVERSIONS
    // ---------------------------------------------------------------------------------------------

    /// Converts a leaf to a list of field elements
    pub fn to_elements(&self) -> Vec<Felt> {
        self.clone().into_elements()
    }

    /// Converts a leaf to a list of field elements
    pub fn into_elements(self) -> Vec<Felt> {
        self.into_entries().into_iter().flat_map(kv_to_elements).collect()
    }

    /// Converts a leaf into the key-value pairs it contains
    pub fn into_entries(self) -> Vec<(RpoDigest, Word)> {
        match self {
            SmtLeaf::Empty(_) => Vec::new(),
            SmtLeaf::Single(kv_pair) => vec![kv_pair],
            SmtLeaf::Multiple(kv_pairs) => kv_pairs,
        }
    }

    // HELPERS
    // ---------------------------------------------------------------------------------------------

    /// Returns the value associated with `key` in the leaf, or `None` if `key` maps to another
    /// leaf.
    pub(super) fn get_value(&self, key: &RpoDigest) -> Option<Word> {
        // Ensure that `key` maps to this leaf
        if self.index() != key.into() {
            return None;
        }

        match self {
            SmtLeaf::Empty(_) => Some(EMPTY_WORD),
            SmtLeaf::Single((key_in_leaf, value_in_leaf)) => {
                if key == key_in_leaf {
                    Some(*value_in_leaf)
                } else {
                    Some(EMPTY_WORD)
                }
            }
            SmtLeaf::Multiple(kv_pairs) => {
                for (key_in_leaf, value_in_leaf) in kv_pairs {
                    if key == key_in_leaf {
                        return Some(*value_in_leaf);
                    }
                }

                Some(EMPTY_WORD)
            }
        }
    }

    /// Inserts a key-value pair into the leaf; returns the previous value associated with `key`,
    /// if any.
    ///
    /// The caller needs to ensure that `key` has the same leaf index as all other keys in the leaf
    pub(super) fn insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
        match self {
            SmtLeaf::Empty(_) => {
                *self = SmtLeaf::new_single(key, value);
                None
            }
            SmtLeaf::Single(kv_pair) => {
                if kv_pair.0 == key {
                    // the key is already in this leaf. Update the value and return the previous
                    // value
                    let old_value = kv_pair.1;
                    kv_pair.1 = value;
                    Some(old_value)
                } else {
                    // Another entry is present in this leaf. Transform the entry into a list
                    // entry, and make sure the key-value pairs are sorted by key
                    let mut pairs = vec![*kv_pair, (key, value)];
                    pairs.sort_by(|(key_1, _), (key_2, _)| cmp_keys(*key_1, *key_2));

                    *self = SmtLeaf::Multiple(pairs);

                    None
                }
            }
            SmtLeaf::Multiple(kv_pairs) => {
                match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
                    Ok(pos) => {
                        let old_value = kv_pairs[pos].1;
                        kv_pairs[pos].1 = value;

                        Some(old_value)
                    }
                    Err(pos) => {
                        kv_pairs.insert(pos, (key, value));

                        None
                    }
                }
            }
        }
    }

    /// Removes the key-value pair stored at `key` from the leaf; returns the previous value
    /// associated with `key`, if any. Also returns an `is_empty` flag, indicating whether the
    /// leaf became empty, and must be removed from the data structure it is contained in.
    pub(super) fn remove(&mut self, key: RpoDigest) -> (Option<Word>, bool) {
        match self {
            SmtLeaf::Empty(_) => (None, false),
            SmtLeaf::Single((key_at_leaf, value_at_leaf)) => {
                if *key_at_leaf == key {
                    // our key was indeed stored in the leaf, so we return the value that was stored
                    // in it, and indicate that the leaf should be removed
                    let old_value = *value_at_leaf;

                    // Note: this is not strictly needed, since the caller is expected to drop this
                    // `SmtLeaf` object.
                    *self = SmtLeaf::new_empty(key.into());

                    (Some(old_value), true)
                } else {
                    // another key is stored at leaf; nothing to update
                    (None, false)
                }
            }
            SmtLeaf::Multiple(kv_pairs) => {
                match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
                    Ok(pos) => {
                        let old_value = kv_pairs[pos].1;

                        kv_pairs.remove(pos);
                        debug_assert!(!kv_pairs.is_empty());

                        if kv_pairs.len() == 1 {
                            // convert the leaf into `Single`
                            *self = SmtLeaf::Single(kv_pairs[0]);
                        }

                        (Some(old_value), false)
                    }
                    Err(_) => {
                        // other keys are stored at leaf; nothing to update
                        (None, false)
                    }
                }
            }
        }
    }
}

impl Serializable for SmtLeaf {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        // Write: num entries
        self.num_entries().write_into(target);

        // Write: leaf index
        let leaf_index: u64 = self.index().value();
        leaf_index.write_into(target);

        // Write: entries
        for (key, value) in self.entries() {
            key.write_into(target);
            value.write_into(target);
        }
    }
}

impl Deserializable for SmtLeaf {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        // Read: num entries
        let num_entries = source.read_u64()?;

        // Read: leaf index
        let leaf_index: LeafIndex<SMT_DEPTH> = {
            let value = source.read_u64()?;
            LeafIndex::new_max_depth(value)
        };

        // Read: entries
        let mut entries: Vec<(RpoDigest, Word)> = Vec::new();
        for _ in 0..num_entries {
            let key: RpoDigest = source.read()?;
            let value: Word = source.read()?;

            entries.push((key, value));
        }

        Self::new(entries, leaf_index)
            .map_err(|err| DeserializationError::InvalidValue(err.to_string()))
    }
}

// HELPER FUNCTIONS
// ================================================================================================

/// Converts a key-value tuple to an iterator of `Felt`s
fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator<Item = Felt> {
    let key_elements = key.into_iter();
    let value_elements = value.into_iter();

    key_elements.chain(value_elements)
}

/// Compares two keys element-by-element using their integer representations, starting with the
/// most significant element.
fn cmp_keys(key_1: RpoDigest, key_2: RpoDigest) -> Ordering {
    for (v1, v2) in key_1.iter().zip(key_2.iter()).rev() {
        let v1 = v1.as_int();
        let v2 = v2.as_int();
        if v1 != v2 {
            return v1.cmp(&v2);
        }
    }

    Ordering::Equal
}
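
// Sketch (illustrative, not in the original file): `cmp_keys()` orders keys by their most
// significant element first, and the hash of a `Multiple` leaf is exactly the RPO hash of the
// flattened, key-sorted element sequence that `to_elements()` yields.
#[cfg(test)]
fn leaf_ordering_and_hash_sketch() {
    // both keys share the most significant element, so they belong to the same leaf;
    // cmp_keys falls through to the lower elements and orders key_1 before key_2
    let key_1 = RpoDigest::from([Felt::new(1), Felt::new(0), Felt::new(0), Felt::new(7)]);
    let key_2 = RpoDigest::from([Felt::new(2), Felt::new(0), Felt::new(0), Felt::new(7)]);
    assert_eq!(cmp_keys(key_1, key_2), Ordering::Less);

    let leaf = SmtLeaf::new_multiple(vec![
        (key_1, [Felt::new(1); 4]),
        (key_2, [Felt::new(2); 4]),
    ])
    .unwrap();
    assert_eq!(leaf.hash(), Rpo256::hash_elements(&leaf.to_elements()));
}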
296
src/merkle/smt/full/mod.rs
Normal file
@@ -0,0 +1,296 @@
use super::{
    EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, MerklePath,
    NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
};
use alloc::collections::{BTreeMap, BTreeSet};

mod error;
pub use error::{SmtLeafError, SmtProofError};

mod leaf;
pub use leaf::SmtLeaf;

mod proof;
pub use proof::SmtProof;

#[cfg(test)]
mod tests;

// CONSTANTS
// ================================================================================================

pub const SMT_DEPTH: u8 = 64;

// SMT
// ================================================================================================

/// Sparse Merkle tree mapping 256-bit keys to 256-bit values. Both keys and values are represented
/// by 4 field elements.
///
/// All leaves sit at depth 64. The most significant element of the key is used to identify the
/// leaf to which the key maps.
///
/// A leaf is either empty, or holds one or more key-value pairs. An empty leaf hashes to the empty
/// word. Otherwise, a leaf hashes to the hash of its key-value pairs, ordered by key first, value
/// second.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Smt {
    root: RpoDigest,
    leaves: BTreeMap<u64, SmtLeaf>,
    inner_nodes: BTreeMap<NodeIndex, InnerNode>,
}

impl Smt {
    // CONSTANTS
    // --------------------------------------------------------------------------------------------
    /// The default value used to compute the hash of empty leaves
    pub const EMPTY_VALUE: Word = <Self as SparseMerkleTree<SMT_DEPTH>>::EMPTY_VALUE;

    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Returns a new [Smt].
    ///
    /// All leaves in the returned tree are set to [Self::EMPTY_VALUE].
    pub fn new() -> Self {
        let root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);

        Self {
            root,
            leaves: BTreeMap::new(),
            inner_nodes: BTreeMap::new(),
        }
    }

    /// Returns a new [Smt] instantiated with leaves set as specified by the provided entries.
    ///
    /// All leaves omitted from the entries list are set to [Self::EMPTY_VALUE].
    ///
    /// # Errors
    /// Returns an error if the provided entries contain multiple values for the same key.
    pub fn with_entries(
        entries: impl IntoIterator<Item = (RpoDigest, Word)>,
    ) -> Result<Self, MerkleError> {
        // create an empty tree
        let mut tree = Self::new();

        // This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so
        // entries with the empty value need additional tracking.
        let mut key_set_to_zero = BTreeSet::new();

        for (key, value) in entries {
            let old_value = tree.insert(key, value);

            if old_value != EMPTY_WORD || key_set_to_zero.contains(&key) {
                return Err(MerkleError::DuplicateValuesForIndex(
                    LeafIndex::<SMT_DEPTH>::from(key).value(),
                ));
            }

            if value == EMPTY_WORD {
                key_set_to_zero.insert(key);
            };
        }
        Ok(tree)
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the depth of the tree
    pub const fn depth(&self) -> u8 {
        SMT_DEPTH
    }

    /// Returns the root of the tree
    pub fn root(&self) -> RpoDigest {
        <Self as SparseMerkleTree<SMT_DEPTH>>::root(self)
    }

    /// Returns the leaf to which `key` maps
    pub fn get_leaf(&self, key: &RpoDigest) -> SmtLeaf {
        <Self as SparseMerkleTree<SMT_DEPTH>>::get_leaf(self, key)
    }

    /// Returns the value associated with `key`
    pub fn get_value(&self, key: &RpoDigest) -> Word {
        let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();

        match self.leaves.get(&leaf_pos) {
            Some(leaf) => leaf.get_value(key).unwrap_or_default(),
            None => EMPTY_WORD,
        }
    }

    /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
    /// path to the leaf, as well as the leaf itself.
    pub fn open(&self, key: &RpoDigest) -> SmtProof {
        <Self as SparseMerkleTree<SMT_DEPTH>>::open(self, key)
    }

    // ITERATORS
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over the leaves of this [Smt].
    pub fn leaves(&self) -> impl Iterator<Item = (LeafIndex<SMT_DEPTH>, &SmtLeaf)> {
        self.leaves
            .iter()
            .map(|(leaf_index, leaf)| (LeafIndex::new_max_depth(*leaf_index), leaf))
    }

    /// Returns an iterator over the key-value pairs of this [Smt].
    pub fn entries(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
        self.leaves().flat_map(|(_, leaf)| leaf.entries())
    }

    /// Returns an iterator over the inner nodes of this [Smt].
    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
        self.inner_nodes.values().map(|e| InnerNodeInfo {
            value: e.hash(),
            left: e.left,
            right: e.right,
        })
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Inserts a value at the specified key, returning the previous value associated with that
    /// key. Recall that by definition, any key that hasn't been updated is associated with
    /// [`Self::EMPTY_VALUE`].
    ///
    /// This also recomputes all hashes between the leaf (associated with the key) and the root,
    /// updating the root itself.
    pub fn insert(&mut self, key: RpoDigest, value: Word) -> Word {
        <Self as SparseMerkleTree<SMT_DEPTH>>::insert(self, key, value)
    }

    // HELPERS
    // --------------------------------------------------------------------------------------------

    /// Inserts `value` at the leaf index pointed to by `key`. `value` is guaranteed to not be the
    /// empty value, such that this is indeed an insertion.
    fn perform_insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
        debug_assert_ne!(value, Self::EMPTY_VALUE);

        let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);

        match self.leaves.get_mut(&leaf_index.value()) {
            Some(leaf) => leaf.insert(key, value),
            None => {
                self.leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value)));

                None
            }
        }
    }

    /// Removes the key-value pair at the leaf index pointed to by `key` if it exists.
    fn perform_remove(&mut self, key: RpoDigest) -> Option<Word> {
        let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);

        if let Some(leaf) = self.leaves.get_mut(&leaf_index.value()) {
            let (old_value, is_empty) = leaf.remove(key);
            if is_empty {
                self.leaves.remove(&leaf_index.value());
            }
            old_value
        } else {
            // there's nothing stored at the leaf; nothing to update
            None
        }
    }
}

impl SparseMerkleTree<SMT_DEPTH> for Smt {
    type Key = RpoDigest;
    type Value = Word;
    type Leaf = SmtLeaf;
    type Opening = SmtProof;

    const EMPTY_VALUE: Self::Value = EMPTY_WORD;

    fn root(&self) -> RpoDigest {
        self.root
    }

    fn set_root(&mut self, root: RpoDigest) {
        self.root = root;
    }

    fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
        self.inner_nodes.get(&index).cloned().unwrap_or_else(|| {
            let node = EmptySubtreeRoots::entry(SMT_DEPTH, index.depth() + 1);

            InnerNode { left: *node, right: *node }
        })
    }

    fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
        self.inner_nodes.insert(index, inner_node);
    }

    fn remove_inner_node(&mut self, index: NodeIndex) {
        let _ = self.inner_nodes.remove(&index);
    }

    fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value> {
        // inserting an `EMPTY_VALUE` is equivalent to removing any value associated with `key`
        if value != Self::EMPTY_VALUE {
            self.perform_insert(key, value)
        } else {
            self.perform_remove(key)
        }
    }

    fn get_leaf(&self, key: &RpoDigest) -> Self::Leaf {
        let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();

        match self.leaves.get(&leaf_pos) {
            Some(leaf) => leaf.clone(),
            None => SmtLeaf::new_empty(key.into()),
        }
    }

    fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest {
        leaf.hash()
    }

    fn key_to_leaf_index(key: &RpoDigest) -> LeafIndex<SMT_DEPTH> {
        let most_significant_felt = key[3];
        LeafIndex::new_max_depth(most_significant_felt.as_int())
    }

    fn path_and_leaf_to_opening(path: MerklePath, leaf: SmtLeaf) -> SmtProof {
        SmtProof::new_unchecked(path, leaf)
    }
}

impl Default for Smt {
    fn default() -> Self {
        Self::new()
    }
}

// CONVERSIONS
// ================================================================================================

impl From<Word> for LeafIndex<SMT_DEPTH> {
    fn from(value: Word) -> Self {
        // We use the most significant `Felt` of a `Word` as the leaf index.
        Self::new_max_depth(value[3].as_int())
    }
}

impl From<RpoDigest> for LeafIndex<SMT_DEPTH> {
    fn from(value: RpoDigest) -> Self {
        Word::from(value).into()
    }
}

impl From<&RpoDigest> for LeafIndex<SMT_DEPTH> {
    fn from(value: &RpoDigest) -> Self {
        Word::from(value).into()
    }
}
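
// Usage sketch (an illustrative addition, not part of the original file): two keys that share
// the same most significant element map to the same depth-64 leaf, and inserting `EMPTY_WORD`
// removes an entry again, restoring the empty root.
#[cfg(test)]
fn smt_usage_sketch() {
    use crate::{ONE, WORD_SIZE};

    let key_1 = RpoDigest::from([ONE, ONE, ONE, Felt::new(42)]);
    let key_2 = RpoDigest::from([ONE, ONE, Felt::new(2), Felt::new(42)]);

    // both keys resolve to leaf index 42
    assert_eq!(LeafIndex::<SMT_DEPTH>::from(key_1), LeafIndex::<SMT_DEPTH>::from(key_2));

    let mut smt = Smt::with_entries([(key_1, [ONE; WORD_SIZE])]).unwrap();
    assert_eq!(smt.get_value(&key_1), [ONE; WORD_SIZE]);

    // inserting the empty value removes the entry
    smt.insert(key_1, EMPTY_WORD);
    assert_eq!(smt.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));
}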
114
src/merkle/smt/full/proof.rs
Normal file
@@ -0,0 +1,114 @@
use super::{MerklePath, RpoDigest, SmtLeaf, SmtProofError, Word, SMT_DEPTH};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
use alloc::string::ToString;

/// A proof which can be used to assert membership (or non-membership) of key-value pairs in a
/// [`super::Smt`].
///
/// The proof consists of a Merkle path and a leaf which describes the node located at the base of
/// the path.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SmtProof {
    path: MerklePath,
    leaf: SmtLeaf,
}

impl SmtProof {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------

    /// Returns a new instance of [`SmtProof`] instantiated from the specified path and leaf.
    ///
    /// # Errors
    /// Returns an error if the path length is not [`SMT_DEPTH`].
    pub fn new(path: MerklePath, leaf: SmtLeaf) -> Result<Self, SmtProofError> {
        let depth: usize = SMT_DEPTH.into();
        if path.len() != depth {
            return Err(SmtProofError::InvalidPathLength(path.len()));
        }

        Ok(Self { path, leaf })
    }

    /// Returns a new instance of [`SmtProof`] instantiated from the specified path and leaf.
    ///
    /// The length of the path is not checked. Reserved for internal use.
    pub(super) fn new_unchecked(path: MerklePath, leaf: SmtLeaf) -> Self {
        Self { path, leaf }
    }

    // PROOF VERIFIER
    // --------------------------------------------------------------------------------------------

    /// Returns true if a [`super::Smt`] with the specified root contains the provided
    /// key-value pair.
    ///
    /// Note: this method cannot be used to assert non-membership. That is, if false is returned,
    /// it does not mean that the provided key-value pair is not in the tree.
    pub fn verify_membership(&self, key: &RpoDigest, value: &Word, root: &RpoDigest) -> bool {
        let maybe_value_in_leaf = self.leaf.get_value(key);

        match maybe_value_in_leaf {
            Some(value_in_leaf) => {
                // The value must match for the proof to be valid
                if value_in_leaf != *value {
                    return false;
                }

                // make sure the Merkle path resolves to the correct root
                self.compute_root() == *root
            }
            // If the key maps to a different leaf, the proof cannot verify membership of `value`
            None => false,
        }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the value associated with the specific key according to this proof, or None if
    /// this proof does not contain a value for the specified key.
    ///
    /// A key-value pair generated by using this method should pass the `verify_membership()`
    /// check.
    pub fn get(&self, key: &RpoDigest) -> Option<Word> {
        self.leaf.get_value(key)
    }

    /// Computes the root of a [`super::Smt`] to which this proof resolves.
    pub fn compute_root(&self) -> RpoDigest {
        self.path
            .compute_root(self.leaf.index().value(), self.leaf.hash())
            .expect("failed to compute Merkle path root")
    }

    /// Returns the proof's Merkle path.
    pub fn path(&self) -> &MerklePath {
        &self.path
    }

    /// Returns the leaf associated with the proof.
    pub fn leaf(&self) -> &SmtLeaf {
        &self.leaf
    }

    /// Consumes the proof and returns its parts.
    pub fn into_parts(self) -> (MerklePath, SmtLeaf) {
        (self.path, self.leaf)
    }
}

impl Serializable for SmtProof {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        self.path.write_into(target);
        self.leaf.write_into(target);
    }
}

impl Deserializable for SmtProof {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let path = MerklePath::read_from(source)?;
        let leaf = SmtLeaf::read_from(source)?;

        Self::new(path, leaf).map_err(|err| DeserializationError::InvalidValue(err.to_string()))
    }
}
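
// Sketch (illustrative, not in the original file): an opening produced by `Smt::open()` should
// verify against the tree's root for the key-value pair it was generated for, and its computed
// root should match the tree root.
#[cfg(test)]
fn smt_proof_sketch() {
    use super::Smt;
    use crate::{ONE, WORD_SIZE};

    let key = RpoDigest::from([ONE, ONE, ONE, ONE]);
    let value = [ONE; WORD_SIZE];

    let smt = Smt::with_entries([(key, value)]).unwrap();
    let proof = smt.open(&key);

    assert!(proof.verify_membership(&key, &value, &smt.root()));
    assert_eq!(proof.compute_root(), smt.root());
}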
407
src/merkle/smt/full/tests.rs
Normal file
@@ -0,0 +1,407 @@
use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH};
use crate::{
    merkle::{EmptySubtreeRoots, MerkleStore},
    utils::{Deserializable, Serializable},
    Word, ONE, WORD_SIZE,
};
use alloc::vec::Vec;

// SMT
// --------------------------------------------------------------------------------------------

/// This test checks that inserting twice at the same key functions as expected. The test covers
/// only the case where the key is alone in its leaf
#[test]
fn test_smt_insert_at_same_key() {
    let mut smt = Smt::default();
    let mut store: MerkleStore = MerkleStore::default();

    assert_eq!(smt.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));

    let key_1: RpoDigest = {
        let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;

        RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)])
    };
    let key_1_index: NodeIndex = LeafIndex::<SMT_DEPTH>::from(key_1).into();

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [ONE + ONE; WORD_SIZE];

    // Insert value 1 and ensure root is as expected
    {
        let leaf_node = build_empty_or_single_leaf_node(key_1, value_1);
        let tree_root = store.set_node(smt.root(), key_1_index, leaf_node).unwrap().root;

        let old_value_1 = smt.insert(key_1, value_1);
        assert_eq!(old_value_1, EMPTY_WORD);

        assert_eq!(smt.root(), tree_root);
    }

    // Insert value 2 and ensure root is as expected
    {
        let leaf_node = build_empty_or_single_leaf_node(key_1, value_2);
        let tree_root = store.set_node(smt.root(), key_1_index, leaf_node).unwrap().root;

        let old_value_2 = smt.insert(key_1, value_2);
        assert_eq!(old_value_2, value_1);

        assert_eq!(smt.root(), tree_root);
    }
}

/// This test checks that inserting twice at the same key functions as expected. The test covers
/// only the case where the leaf type is `SmtLeaf::Multiple`
#[test]
fn test_smt_insert_at_same_key_2() {
    // The most significant u64 used for both keys (to ensure they map to the same leaf)
    let key_msb: u64 = 42;

    let key_already_present: RpoDigest =
        RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(key_msb)]);
    let key_already_present_index: NodeIndex =
        LeafIndex::<SMT_DEPTH>::from(key_already_present).into();
    let value_already_present = [ONE + ONE + ONE; WORD_SIZE];

    let mut smt =
        Smt::with_entries(core::iter::once((key_already_present, value_already_present))).unwrap();
    let mut store: MerkleStore = {
        let mut store = MerkleStore::default();

        let leaf_node = build_empty_or_single_leaf_node(key_already_present, value_already_present);
        store
            .set_node(*EmptySubtreeRoots::entry(SMT_DEPTH, 0), key_already_present_index, leaf_node)
            .unwrap();
        store
    };

    let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(key_msb)]);
    let key_1_index: NodeIndex = LeafIndex::<SMT_DEPTH>::from(key_1).into();

    assert_eq!(key_1_index, key_already_present_index);

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [ONE + ONE; WORD_SIZE];

    // Insert value 1 and ensure root is as expected
    {
        // Note: key_1 comes first because it is smaller
        let leaf_node = build_multiple_leaf_node(&[
            (key_1, value_1),
            (key_already_present, value_already_present),
        ]);
        let tree_root = store.set_node(smt.root(), key_1_index, leaf_node).unwrap().root;

        let old_value_1 = smt.insert(key_1, value_1);
        assert_eq!(old_value_1, EMPTY_WORD);

        assert_eq!(smt.root(), tree_root);
    }

    // Insert value 2 and ensure root is as expected
    {
        let leaf_node = build_multiple_leaf_node(&[
            (key_1, value_2),
            (key_already_present, value_already_present),
        ]);
        let tree_root = store.set_node(smt.root(), key_1_index, leaf_node).unwrap().root;

        let old_value_2 = smt.insert(key_1, value_2);
        assert_eq!(old_value_2, value_1);

        assert_eq!(smt.root(), tree_root);
    }
}

/// This test ensures that the root of the tree is as expected when we add/remove 3 items at 3
/// different keys. This also tests that the Merkle paths produced are as expected.
#[test]
fn test_smt_insert_and_remove_multiple_values() {
    fn insert_values_and_assert_path(
        smt: &mut Smt,
        store: &mut MerkleStore,
        key_values: &[(RpoDigest, Word)],
    ) {
        for &(key, value) in key_values {
            let key_index: NodeIndex = LeafIndex::<SMT_DEPTH>::from(key).into();

            let leaf_node = build_empty_or_single_leaf_node(key, value);
            let tree_root = store.set_node(smt.root(), key_index, leaf_node).unwrap().root;

            let _ = smt.insert(key, value);

            assert_eq!(smt.root(), tree_root);

            let expected_path = store.get_path(tree_root, key_index).unwrap();
            assert_eq!(smt.open(&key).into_parts().0, expected_path.path);
        }
    }
    let mut smt = Smt::default();
    let mut store: MerkleStore = MerkleStore::default();

    assert_eq!(smt.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));

    let key_1: RpoDigest = {
        let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;

        RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)])
    };

    let key_2: RpoDigest = {
        let raw = 0b_11111111_11111111_11111111_11111111_11111111_11111111_11111111_11111111_u64;

        RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)])
    };

    let key_3: RpoDigest = {
        let raw = 0b_00000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000_u64;

        RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)])
    };

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [ONE + ONE; WORD_SIZE];
    let value_3 = [ONE + ONE + ONE; WORD_SIZE];

    // Insert values in the tree
    let key_values = [(key_1, value_1), (key_2, value_2), (key_3, value_3)];
    insert_values_and_assert_path(&mut smt, &mut store, &key_values);

    // Remove values from the tree
    let key_empty_values = [(key_1, EMPTY_WORD), (key_2, EMPTY_WORD), (key_3, EMPTY_WORD)];
    insert_values_and_assert_path(&mut smt, &mut store, &key_empty_values);

    let empty_root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
    assert_eq!(smt.root(), empty_root);

    // an empty tree should have no leaves or inner nodes
    assert!(smt.leaves.is_empty());
    assert!(smt.inner_nodes.is_empty());
}

/// This tests that inserting the empty value does indeed remove the key-value contained at the
/// leaf. We insert & remove 3 values at the same leaf to ensure that all cases are covered (empty,
/// single, multiple).
#[test]
fn test_smt_removal() {
    let mut smt = Smt::default();

    let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;

    let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
    let key_2: RpoDigest =
        RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]);
    let key_3: RpoDigest =
        RpoDigest::from([3_u32.into(), 3_u32.into(), 3_u32.into(), Felt::new(raw)]);

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [2_u32.into(); WORD_SIZE];
    let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE];

    // insert key-value 1
    {
        let old_value_1 = smt.insert(key_1, value_1);
        assert_eq!(old_value_1, EMPTY_WORD);

        assert_eq!(smt.get_leaf(&key_1), SmtLeaf::Single((key_1, value_1)));
    }

    // insert key-value 2
    {
        let old_value_2 = smt.insert(key_2, value_2);
        assert_eq!(old_value_2, EMPTY_WORD);

        assert_eq!(
            smt.get_leaf(&key_2),
            SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2)])
        );
    }

    // insert key-value 3
    {
        let old_value_3 = smt.insert(key_3, value_3);
        assert_eq!(old_value_3, EMPTY_WORD);

        assert_eq!(
            smt.get_leaf(&key_3),
            SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2), (key_3, value_3)])
        );
    }

    // remove key 3
    {
        let old_value_3 = smt.insert(key_3, EMPTY_WORD);
        assert_eq!(old_value_3, value_3);

        assert_eq!(
            smt.get_leaf(&key_3),
            SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2)])
        );
    }

    // remove key 2
    {
        let old_value_2 = smt.insert(key_2, EMPTY_WORD);
        assert_eq!(old_value_2, value_2);

        assert_eq!(smt.get_leaf(&key_2), SmtLeaf::Single((key_1, value_1)));
    }

    // remove key 1
    {
        let old_value_1 = smt.insert(key_1, EMPTY_WORD);
        assert_eq!(old_value_1, value_1);

        assert_eq!(smt.get_leaf(&key_1), SmtLeaf::new_empty(key_1.into()));
    }
}

/// Tests that 2 key-value pairs stored in the same leaf have the same path
#[test]
fn test_smt_path_to_keys_in_same_leaf_are_equal() {
    let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;

    let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
    let key_2: RpoDigest =
        RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]);

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [2_u32.into(); WORD_SIZE];

    let smt = Smt::with_entries([(key_1, value_1), (key_2, value_2)]).unwrap();

    assert_eq!(smt.open(&key_1), smt.open(&key_2));
}

/// Tests that an empty leaf hashes to the empty word
#[test]
fn test_empty_leaf_hash() {
    let smt = Smt::default();

    let leaf = smt.get_leaf(&RpoDigest::default());
    assert_eq!(leaf.hash(), EMPTY_WORD.into());
}

/// Tests that `get_value()` works as expected
#[test]
fn test_smt_get_value() {
    let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, ONE]);
    let key_2: RpoDigest =
        RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), 2_u32.into()]);

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [2_u32.into(); WORD_SIZE];

    let smt = Smt::with_entries([(key_1, value_1), (key_2, value_2)]).unwrap();

    let returned_value_1 = smt.get_value(&key_1);
    let returned_value_2 = smt.get_value(&key_2);

    assert_eq!(value_1, returned_value_1);
    assert_eq!(value_2, returned_value_2);

    // Check that a key with no inserted value returns the empty word
    let key_no_value =
        RpoDigest::from([42_u32.into(), 42_u32.into(), 42_u32.into(), 42_u32.into()]);

    assert_eq!(EMPTY_WORD, smt.get_value(&key_no_value));
}

/// Tests that `entries()` works as expected
#[test]
fn test_smt_entries() {
    let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, ONE]);
    let key_2: RpoDigest =
        RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), 2_u32.into()]);

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [2_u32.into(); WORD_SIZE];

    let smt = Smt::with_entries([(key_1, value_1), (key_2, value_2)]).unwrap();

    let mut entries = smt.entries();

    // Note: for simplicity, we assume the order `(k1,v1), (k2,v2)`. If a new implementation
    // switches the order, it is OK to modify the order here as well.
    assert_eq!(&(key_1, value_1), entries.next().unwrap());
    assert_eq!(&(key_2, value_2), entries.next().unwrap());
    assert!(entries.next().is_none());
}

// SMT LEAF
// --------------------------------------------------------------------------------------------

#[test]
fn test_empty_smt_leaf_serialization() {
    let empty_leaf = SmtLeaf::new_empty(LeafIndex::new_max_depth(42));

    let mut serialized = empty_leaf.to_bytes();
    // extend buffer with random bytes
    serialized.extend([1, 2, 3, 4, 5]);
    let deserialized = SmtLeaf::read_from_bytes(&serialized).unwrap();

    assert_eq!(empty_leaf, deserialized);
}

#[test]
fn test_single_smt_leaf_serialization() {
    let single_leaf = SmtLeaf::new_single(
        RpoDigest::from([10_u32.into(), 11_u32.into(), 12_u32.into(), 13_u32.into()]),
        [1_u32.into(), 2_u32.into(), 3_u32.into(), 4_u32.into()],
    );

    let mut serialized = single_leaf.to_bytes();
    // extend buffer with random bytes
    serialized.extend([1, 2, 3, 4, 5]);
    let deserialized = SmtLeaf::read_from_bytes(&serialized).unwrap();

    assert_eq!(single_leaf, deserialized);
}

#[test]
fn test_multiple_smt_leaf_serialization_success() {
    let multiple_leaf = SmtLeaf::new_multiple(vec![
        (
            RpoDigest::from([10_u32.into(), 11_u32.into(), 12_u32.into(), 13_u32.into()]),
            [1_u32.into(), 2_u32.into(), 3_u32.into(), 4_u32.into()],
        ),
        (
            RpoDigest::from([100_u32.into(), 101_u32.into(), 102_u32.into(), 13_u32.into()]),
            [11_u32.into(), 12_u32.into(), 13_u32.into(), 14_u32.into()],
        ),
    ])
    .unwrap();

    let mut serialized = multiple_leaf.to_bytes();
    // extend buffer with random bytes
    serialized.extend([1, 2, 3, 4, 5]);
    let deserialized = SmtLeaf::read_from_bytes(&serialized).unwrap();
|
||||
|
||||
assert_eq!(multiple_leaf, deserialized);
|
||||
}
|
||||
|
||||
// HELPERS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
fn build_empty_or_single_leaf_node(key: RpoDigest, value: Word) -> RpoDigest {
|
||||
if value == EMPTY_WORD {
|
||||
SmtLeaf::new_empty(key.into()).hash()
|
||||
} else {
|
||||
SmtLeaf::Single((key, value)).hash()
|
||||
}
|
||||
}
|
||||
|
||||
fn build_multiple_leaf_node(kv_pairs: &[(RpoDigest, Word)]) -> RpoDigest {
|
||||
let elements: Vec<Felt> = kv_pairs
|
||||
.iter()
|
||||
.flat_map(|(key, value)| {
|
||||
let key_elements = key.into_iter();
|
||||
let value_elements = (*value).into_iter();
|
||||
|
||||
key_elements.chain(value_elements)
|
||||
})
|
||||
.collect();
|
||||
|
||||
Rpo256::hash_elements(&elements)
|
||||
}
|
||||
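// A hedged note (illustrative, not from the source): `build_multiple_leaf_node`
// hashes the flattened `key || value` field elements of all pairs in order, so
// for two pairs the input to `Rpo256::hash_elements` is 16 felts:
// k1[0..4], v1[0..4], k2[0..4], v2[0..4].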
245
src/merkle/smt/mod.rs
Normal file
@@ -0,0 +1,245 @@
use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex};
use crate::{
    hash::rpo::{Rpo256, RpoDigest},
    Felt, Word, EMPTY_WORD,
};
use alloc::vec::Vec;

mod full;
pub use full::{Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH};

mod simple;
pub use simple::SimpleSmt;

// CONSTANTS
// ================================================================================================

/// Minimum supported depth.
pub const SMT_MIN_DEPTH: u8 = 1;

/// Maximum supported depth.
pub const SMT_MAX_DEPTH: u8 = 64;

// SPARSE MERKLE TREE
// ================================================================================================

/// An abstract description of a sparse Merkle tree.
///
/// A sparse Merkle tree is a key-value map which also supports proving that a given value is indeed
/// stored at a given key in the tree. It is viewed as always being fully populated. If a leaf's
/// value was not explicitly set, then its value is the default value. Typically, the vast majority
/// of leaves will store the default value (hence it is "sparse"), and therefore the internal
/// representation of the tree will only keep track of the leaves that have a different value from
/// the default.
///
/// All leaves sit at the same depth. The deeper the tree, the more leaves it has; but also the
/// longer its proofs are: a proof contains exactly `DEPTH` sibling nodes (the depth is the log of
/// the leaf count). A tree cannot have depth 0, since such a tree is just a single value, and is
/// probably a programming mistake.
///
/// Every key maps to one leaf. If there are as many keys as there are leaves, then
/// [Self::Leaf] should be the same type as [Self::Value], as is the case with
/// [crate::merkle::SimpleSmt]. However, if there are more keys than leaves, then [`Self::Leaf`]
/// must accommodate all keys that map to the same leaf.
///
/// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs.
pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
    /// The type for a key
    type Key: Clone;
    /// The type for a value
    type Value: Clone + PartialEq;
    /// The type for a leaf
    type Leaf;
    /// The type for an opening (i.e. a "proof") of a leaf
    type Opening;

    /// The default value used to compute the hash of empty leaves
    const EMPTY_VALUE: Self::Value;

    // PROVIDED METHODS
    // ---------------------------------------------------------------------------------------------

    /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
    /// path to the leaf, as well as the leaf itself.
    fn open(&self, key: &Self::Key) -> Self::Opening {
        let leaf = self.get_leaf(key);

        let mut index: NodeIndex = {
            let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(key);
            leaf_index.into()
        };

        let merkle_path = {
            let mut path = Vec::with_capacity(index.depth() as usize);
            for _ in 0..index.depth() {
                let is_right = index.is_value_odd();
                index.move_up();
                let InnerNode { left, right } = self.get_inner_node(index);
                let value = if is_right { left } else { right };
                path.push(value);
            }

            MerklePath::new(path)
        };

        Self::path_and_leaf_to_opening(merkle_path, leaf)
    }
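    // A hedged verification sketch (not part of the trait): the root can be
    // recomputed from an opening by folding the path from the leaf upward,
    // choosing operand order from the low bit of the index at each level.
    // `leaf_hash`, `path`, `pos`, and `root` below are hypothetical locals.
    //
    //     let mut hash = leaf_hash;
    //     for (depth, sibling) in path.iter().enumerate() {
    //         let is_right = (pos >> depth) & 1 == 1;
    //         hash = if is_right {
    //             Rpo256::merge(&[*sibling, hash])
    //         } else {
    //             Rpo256::merge(&[hash, *sibling])
    //         };
    //     }
    //     assert_eq!(hash, root);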

    /// Inserts a value at the specified key, returning the previous value associated with that key.
    /// Recall that by definition, any key that hasn't been updated is associated with
    /// [`Self::EMPTY_VALUE`].
    ///
    /// This also recomputes all hashes between the leaf (associated with the key) and the root,
    /// updating the root itself.
    fn insert(&mut self, key: Self::Key, value: Self::Value) -> Self::Value {
        let old_value = self.insert_value(key.clone(), value.clone()).unwrap_or(Self::EMPTY_VALUE);

        // if the old value and new value are the same, there is nothing to update
        if value == old_value {
            return value;
        }

        let leaf = self.get_leaf(&key);
        let node_index = {
            let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(&key);
            leaf_index.into()
        };

        self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf));

        old_value
    }
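    // A hedged behavioral note (illustrative, not from the source): because unset
    // keys map to `EMPTY_VALUE`, inserting `EMPTY_VALUE` acts as a removal, and
    // `insert` always returns the previous value:
    //
    //     let old = smt.insert(key, value);        // old == EMPTY_VALUE for a fresh key
    //     let prev = smt.insert(key, EMPTY_VALUE); // prev == value; the key is now unset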

    /// Recomputes the branch nodes (including the root) from `index` all the way to the root.
    /// `node_hash_at_index` is the hash of the node stored at index.
    fn recompute_nodes_from_index_to_root(
        &mut self,
        mut index: NodeIndex,
        node_hash_at_index: RpoDigest,
    ) {
        let mut node_hash = node_hash_at_index;
        for node_depth in (0..index.depth()).rev() {
            let is_right = index.is_value_odd();
            index.move_up();
            let InnerNode { left, right } = self.get_inner_node(index);
            let (left, right) = if is_right {
                (left, node_hash)
            } else {
                (node_hash, right)
            };
            node_hash = Rpo256::merge(&[left, right]);

            if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) {
                // if a subtree is empty, we can remove the inner node, since it's equal to the
                // default value
                self.remove_inner_node(index)
            } else {
                self.insert_inner_node(index, InnerNode { left, right });
            }
        }
        self.set_root(node_hash);
    }
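    // A hedged worked example (assumed semantics of `NodeIndex`, not from the
    // source): moving up one level halves the index value, so the parent of
    // (depth = 3, value = 5) is (depth = 2, value = 2), and `is_value_odd()`
    // tells whether the current node is the right child of that parent.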

    // REQUIRED METHODS
    // ---------------------------------------------------------------------------------------------

    /// The root of the tree
    fn root(&self) -> RpoDigest;

    /// Sets the root of the tree
    fn set_root(&mut self, root: RpoDigest);

    /// Retrieves an inner node at the given index
    fn get_inner_node(&self, index: NodeIndex) -> InnerNode;

    /// Inserts an inner node at the given index
    fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode);

    /// Removes an inner node at the given index
    fn remove_inner_node(&mut self, index: NodeIndex);

    /// Inserts a leaf node, and returns the value at the key if it already exists
    fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value>;

    /// Returns the leaf at the specified index.
    fn get_leaf(&self, key: &Self::Key) -> Self::Leaf;

    /// Returns the hash of a leaf
    fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest;

    /// Maps a key to a leaf index
    fn key_to_leaf_index(key: &Self::Key) -> LeafIndex<DEPTH>;

    /// Maps a (MerklePath, Self::Leaf) to an opening.
    ///
    /// The length of `path` is guaranteed to be equal to `DEPTH`
    fn path_and_leaf_to_opening(path: MerklePath, leaf: Self::Leaf) -> Self::Opening;
}

// INNER NODE
// ================================================================================================

#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub(crate) struct InnerNode {
    pub left: RpoDigest,
    pub right: RpoDigest,
}

impl InnerNode {
    pub fn hash(&self) -> RpoDigest {
        Rpo256::merge(&[self.left, self.right])
    }
}

// LEAF INDEX
// ================================================================================================

/// The index of a leaf, at a depth known at compile-time.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct LeafIndex<const DEPTH: u8> {
    index: NodeIndex,
}

impl<const DEPTH: u8> LeafIndex<DEPTH> {
    pub fn new(value: u64) -> Result<Self, MerkleError> {
        if DEPTH < SMT_MIN_DEPTH {
            return Err(MerkleError::DepthTooSmall(DEPTH));
        }

        Ok(LeafIndex { index: NodeIndex::new(DEPTH, value)? })
    }

    pub fn value(&self) -> u64 {
        self.index.value()
    }
}

impl LeafIndex<SMT_MAX_DEPTH> {
    pub const fn new_max_depth(value: u64) -> Self {
        LeafIndex {
            index: NodeIndex::new_unchecked(SMT_MAX_DEPTH, value),
        }
    }
}

impl<const DEPTH: u8> From<LeafIndex<DEPTH>> for NodeIndex {
    fn from(value: LeafIndex<DEPTH>) -> Self {
        value.index
    }
}

impl<const DEPTH: u8> TryFrom<NodeIndex> for LeafIndex<DEPTH> {
    type Error = MerkleError;

    fn try_from(node_index: NodeIndex) -> Result<Self, Self::Error> {
        if node_index.depth() != DEPTH {
            return Err(MerkleError::InvalidDepth {
                expected: DEPTH,
                provided: node_index.depth(),
            });
        }

        Self::new(node_index.value())
    }
}
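// A hedged round-trip sketch (illustrative only): a `NodeIndex` at the matching
// depth converts into a `LeafIndex` and back without changing its value.
//
//     let node_index = NodeIndex::new(SMT_MAX_DEPTH, 7).unwrap();
//     let leaf_index: LeafIndex<SMT_MAX_DEPTH> = node_index.try_into().unwrap();
//     assert_eq!(NodeIndex::from(leaf_index), node_index);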
305
src/merkle/smt/simple/mod.rs
Normal file
@@ -0,0 +1,305 @@
use super::{
    super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, LeafIndex, MerkleError,
    MerklePath, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, SMT_MAX_DEPTH,
    SMT_MIN_DEPTH,
};
use alloc::collections::{BTreeMap, BTreeSet};

#[cfg(test)]
mod tests;

// SPARSE MERKLE TREE
// ================================================================================================

/// A sparse Merkle tree with 64-bit keys and 4-element leaf values, without compaction.
///
/// The root of the tree is recomputed on each new leaf update.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct SimpleSmt<const DEPTH: u8> {
    root: RpoDigest,
    leaves: BTreeMap<u64, Word>,
    inner_nodes: BTreeMap<NodeIndex, InnerNode>,
}
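// A hedged usage sketch (illustrative, not part of this file): build a depth-8
// tree, insert a value, and open it. The key and value below are arbitrary.
//
//     let mut smt = SimpleSmt::<8>::new().unwrap();
//     let key = LeafIndex::<8>::new(3).unwrap();
//     let old = smt.insert(key, [ONE; 4]);
//     assert_eq!(old, SimpleSmt::<8>::EMPTY_VALUE);
//     let opening = smt.open(&key);    // Merkle path of length 8 plus the leaf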

impl<const DEPTH: u8> SimpleSmt<DEPTH> {
    // CONSTANTS
    // --------------------------------------------------------------------------------------------

    /// The default value used to compute the hash of empty leaves
    pub const EMPTY_VALUE: Word = <Self as SparseMerkleTree<DEPTH>>::EMPTY_VALUE;

    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Returns a new [SimpleSmt].
    ///
    /// All leaves in the returned tree are set to [ZERO; 4].
    ///
    /// # Errors
    /// Returns an error if DEPTH is 0 or is greater than 64.
    pub fn new() -> Result<Self, MerkleError> {
        // validate the range of the depth.
        if DEPTH < SMT_MIN_DEPTH {
            return Err(MerkleError::DepthTooSmall(DEPTH));
        } else if SMT_MAX_DEPTH < DEPTH {
            return Err(MerkleError::DepthTooBig(DEPTH as u64));
        }

        let root = *EmptySubtreeRoots::entry(DEPTH, 0);

        Ok(Self {
            root,
            leaves: BTreeMap::new(),
            inner_nodes: BTreeMap::new(),
        })
    }

    /// Returns a new [SimpleSmt] instantiated with leaves set as specified by the provided entries.
    ///
    /// All leaves omitted from the entries list are set to [ZERO; 4].
    ///
    /// # Errors
    /// Returns an error if:
    /// - The depth is 0 or is greater than 64.
    /// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
    /// - The provided entries contain multiple values for the same key.
    pub fn with_leaves(
        entries: impl IntoIterator<Item = (u64, Word)>,
    ) -> Result<Self, MerkleError> {
        // create an empty tree
        let mut tree = Self::new()?;

        // compute the max number of entries. We use an upper bound of depth 63 because we consider
        // passing in a vector of size 2^64 infeasible.
        let max_num_entries = 2_usize.pow(DEPTH.min(63).into());

        // This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so
        // entries with the empty value need additional tracking.
        let mut key_set_to_zero = BTreeSet::new();

        for (idx, (key, value)) in entries.into_iter().enumerate() {
            if idx >= max_num_entries {
                return Err(MerkleError::InvalidNumEntries(max_num_entries));
            }

            let old_value = tree.insert(LeafIndex::<DEPTH>::new(key)?, value);

            if old_value != Self::EMPTY_VALUE || key_set_to_zero.contains(&key) {
                return Err(MerkleError::DuplicateValuesForIndex(key));
            }

            if value == Self::EMPTY_VALUE {
                key_set_to_zero.insert(key);
            };
        }
        Ok(tree)
    }

    /// Wrapper around [`SimpleSmt::with_leaves`] which inserts leaves at contiguous indices
    /// starting at index 0.
    pub fn with_contiguous_leaves(
        entries: impl IntoIterator<Item = Word>,
    ) -> Result<Self, MerkleError> {
        Self::with_leaves(
            entries
                .into_iter()
                .enumerate()
                .map(|(idx, word)| (idx.try_into().expect("tree max depth is 2^8"), word)),
        )
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the depth of the tree
    pub const fn depth(&self) -> u8 {
        DEPTH
    }

    /// Returns the root of the tree
    pub fn root(&self) -> RpoDigest {
        <Self as SparseMerkleTree<DEPTH>>::root(self)
    }

    /// Returns the leaf at the specified index.
    pub fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
        <Self as SparseMerkleTree<DEPTH>>::get_leaf(self, key)
    }

    /// Returns a node at the specified index.
    ///
    /// # Errors
    /// Returns an error if the specified index has depth set to 0 or the depth is greater than
    /// the depth of this Merkle tree.
    pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
        if index.is_root() {
            Err(MerkleError::DepthTooSmall(index.depth()))
        } else if index.depth() > DEPTH {
            Err(MerkleError::DepthTooBig(index.depth() as u64))
        } else if index.depth() == DEPTH {
            let leaf = self.get_leaf(&LeafIndex::<DEPTH>::try_from(index)?);

            Ok(leaf.into())
        } else {
            Ok(self.get_inner_node(index).hash())
        }
    }

    /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
    /// path to the leaf, as well as the leaf itself.
    pub fn open(&self, key: &LeafIndex<DEPTH>) -> ValuePath {
        <Self as SparseMerkleTree<DEPTH>>::open(self, key)
    }

    // ITERATORS
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over the leaves of this [SimpleSmt].
    pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
        self.leaves.iter().map(|(i, w)| (*i, w))
    }

    /// Returns an iterator over the inner nodes of this [SimpleSmt].
    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
        self.inner_nodes.values().map(|e| InnerNodeInfo {
            value: e.hash(),
            left: e.left,
            right: e.right,
        })
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Inserts a value at the specified key, returning the previous value associated with that key.
    /// Recall that by definition, any key that hasn't been updated is associated with
    /// [`EMPTY_WORD`].
    ///
    /// This also recomputes all hashes between the leaf (associated with the key) and the root,
    /// updating the root itself.
    pub fn insert(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Word {
        <Self as SparseMerkleTree<DEPTH>>::insert(self, key, value)
    }

    /// Inserts a subtree at the specified index. The depth at which the subtree is inserted is
    /// computed as `DEPTH - SUBTREE_DEPTH`.
    ///
    /// Returns the new root.
    pub fn set_subtree<const SUBTREE_DEPTH: u8>(
        &mut self,
        subtree_insertion_index: u64,
        subtree: SimpleSmt<SUBTREE_DEPTH>,
    ) -> Result<RpoDigest, MerkleError> {
        if SUBTREE_DEPTH > DEPTH {
            return Err(MerkleError::InvalidSubtreeDepth {
                subtree_depth: SUBTREE_DEPTH,
                tree_depth: DEPTH,
            });
        }

        // Verify that `subtree_insertion_index` is valid.
        let subtree_root_insertion_depth = DEPTH - SUBTREE_DEPTH;
        let subtree_root_index =
            NodeIndex::new(subtree_root_insertion_depth, subtree_insertion_index)?;

        // add leaves
        // --------------

        // The subtree's leaf indices live in their own context - i.e. a subtree of depth `d`. If we
        // insert the subtree at `subtree_insertion_index = 0`, then the subtree leaf indices are
        // valid as they are. However, consider what happens when we insert at
        // `subtree_insertion_index = 1`. The first leaf of our subtree now will have index `2^d`;
        // you can see it as there's a full subtree sitting on its left. In general, for
        // `subtree_insertion_index = i`, there are `i` subtrees sitting before the subtree we want
        // to insert, so we need to adjust all its leaves by `i * 2^d`.
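        // A hedged worked example (numbers are illustrative, not from the source):
        // with `SUBTREE_DEPTH = 2` and `subtree_insertion_index = 3`, the shift is
        // `3 * 2^2 = 12`, so subtree leaf 1 lands at overall leaf index `13`.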
        let leaf_index_shift: u64 = subtree_insertion_index * 2_u64.pow(SUBTREE_DEPTH.into());
        for (subtree_leaf_idx, leaf_value) in subtree.leaves() {
            let new_leaf_idx = leaf_index_shift + subtree_leaf_idx;
            debug_assert!(new_leaf_idx < 2_u64.pow(DEPTH.into()));

            self.leaves.insert(new_leaf_idx, *leaf_value);
        }

        // add subtree's branch nodes (which includes the root)
        // --------------
        for (branch_idx, branch_node) in subtree.inner_nodes {
            let new_branch_idx = {
                let new_depth = subtree_root_insertion_depth + branch_idx.depth();
                let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into())
                    + branch_idx.value();

                NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid")
            };

            self.inner_nodes.insert(new_branch_idx, branch_node);
        }

        // recompute nodes starting from subtree root
        // --------------
        self.recompute_nodes_from_index_to_root(subtree_root_index, subtree.root);

        Ok(self.root)
    }
}

impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
    type Key = LeafIndex<DEPTH>;
    type Value = Word;
    type Leaf = Word;
    type Opening = ValuePath;

    const EMPTY_VALUE: Self::Value = EMPTY_WORD;

    fn root(&self) -> RpoDigest {
        self.root
    }

    fn set_root(&mut self, root: RpoDigest) {
        self.root = root;
    }

    fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
        self.inner_nodes.get(&index).cloned().unwrap_or_else(|| {
            let node = EmptySubtreeRoots::entry(DEPTH, index.depth() + 1);

            InnerNode { left: *node, right: *node }
        })
    }

    fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
        self.inner_nodes.insert(index, inner_node);
    }

    fn remove_inner_node(&mut self, index: NodeIndex) {
        let _ = self.inner_nodes.remove(&index);
    }

    fn insert_value(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Option<Word> {
        self.leaves.insert(key.value(), value)
    }

    fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
        // a leaf that was never explicitly set takes the default empty value, i.e. the
        // empty subtree root at leaf depth
        let leaf_pos = key.value();

        match self.leaves.get(&leaf_pos) {
            Some(word) => *word,
            None => Word::from(*EmptySubtreeRoots::entry(DEPTH, DEPTH)),
        }
    }

    fn hash_leaf(leaf: &Word) -> RpoDigest {
        // `SimpleSmt` takes the leaf value itself as the hash
        leaf.into()
    }

    fn key_to_leaf_index(key: &LeafIndex<DEPTH>) -> LeafIndex<DEPTH> {
        *key
    }

    fn path_and_leaf_to_opening(path: MerklePath, leaf: Word) -> ValuePath {
        (path, leaf).into()
    }
}
437
src/merkle/smt/simple/tests.rs
Normal file
@@ -0,0 +1,437 @@
use super::{
    super::{MerkleError, RpoDigest, SimpleSmt},
    NodeIndex,
};
use crate::{
    hash::rpo::Rpo256,
    merkle::{
        digests_to_words, int_to_leaf, int_to_node, smt::SparseMerkleTree, EmptySubtreeRoots,
        InnerNodeInfo, LeafIndex, MerkleTree,
    },
    Word, EMPTY_WORD,
};
use alloc::vec::Vec;

// TEST DATA
// ================================================================================================

const KEYS4: [u64; 4] = [0, 1, 2, 3];
const KEYS8: [u64; 8] = [0, 1, 2, 3, 4, 5, 6, 7];

const VALUES4: [RpoDigest; 4] = [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];

const VALUES8: [RpoDigest; 8] = [
    int_to_node(1),
    int_to_node(2),
    int_to_node(3),
    int_to_node(4),
    int_to_node(5),
    int_to_node(6),
    int_to_node(7),
    int_to_node(8),
];

const ZERO_VALUES8: [Word; 8] = [int_to_leaf(0); 8];

// TESTS
// ================================================================================================

#[test]
fn build_empty_tree() {
    // tree of depth 3
    let smt = SimpleSmt::<3>::new().unwrap();
    let mt = MerkleTree::new(ZERO_VALUES8).unwrap();
    assert_eq!(mt.root(), smt.root());
}

#[test]
fn build_sparse_tree() {
    const DEPTH: u8 = 3;
    let mut smt = SimpleSmt::<DEPTH>::new().unwrap();
    let mut values = ZERO_VALUES8.to_vec();

    // insert single value
    let key = 6;
    let new_node = int_to_leaf(7);
    values[key as usize] = new_node;
    let old_value = smt.insert(LeafIndex::<DEPTH>::new(key).unwrap(), new_node);
    let mt2 = MerkleTree::new(values.clone()).unwrap();
    assert_eq!(mt2.root(), smt.root());
    assert_eq!(
        mt2.get_path(NodeIndex::make(3, 6)).unwrap(),
        smt.open(&LeafIndex::<3>::new(6).unwrap()).path
    );
    assert_eq!(old_value, EMPTY_WORD);

    // insert second value at distinct leaf branch
    let key = 2;
    let new_node = int_to_leaf(3);
    values[key as usize] = new_node;
    let old_value = smt.insert(LeafIndex::<DEPTH>::new(key).unwrap(), new_node);
    let mt3 = MerkleTree::new(values).unwrap();
    assert_eq!(mt3.root(), smt.root());
    assert_eq!(
        mt3.get_path(NodeIndex::make(3, 2)).unwrap(),
        smt.open(&LeafIndex::<3>::new(2).unwrap()).path
    );
    assert_eq!(old_value, EMPTY_WORD);
}

/// Tests that [`SimpleSmt::with_contiguous_leaves`] works as expected
#[test]
fn build_contiguous_tree() {
    let tree_with_leaves =
        SimpleSmt::<2>::with_leaves([0, 1, 2, 3].into_iter().zip(digests_to_words(&VALUES4)))
            .unwrap();

    let tree_with_contiguous_leaves =
        SimpleSmt::<2>::with_contiguous_leaves(digests_to_words(&VALUES4)).unwrap();

    assert_eq!(tree_with_leaves, tree_with_contiguous_leaves);
}

#[test]
fn test_depth2_tree() {
    let tree =
        SimpleSmt::<2>::with_leaves(KEYS4.into_iter().zip(digests_to_words(&VALUES4))).unwrap();

    // check internal structure
    let (root, node2, node3) = compute_internal_nodes();
    assert_eq!(root, tree.root());
    assert_eq!(node2, tree.get_node(NodeIndex::make(1, 0)).unwrap());
    assert_eq!(node3, tree.get_node(NodeIndex::make(1, 1)).unwrap());

    // check get_node()
    assert_eq!(VALUES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap());
    assert_eq!(VALUES4[1], tree.get_node(NodeIndex::make(2, 1)).unwrap());
    assert_eq!(VALUES4[2], tree.get_node(NodeIndex::make(2, 2)).unwrap());
    assert_eq!(VALUES4[3], tree.get_node(NodeIndex::make(2, 3)).unwrap());

    // check get_path(): depth 2
    assert_eq!(vec![VALUES4[1], node3], *tree.open(&LeafIndex::<2>::new(0).unwrap()).path);
    assert_eq!(vec![VALUES4[0], node3], *tree.open(&LeafIndex::<2>::new(1).unwrap()).path);
    assert_eq!(vec![VALUES4[3], node2], *tree.open(&LeafIndex::<2>::new(2).unwrap()).path);
    assert_eq!(vec![VALUES4[2], node2], *tree.open(&LeafIndex::<2>::new(3).unwrap()).path);
}

#[test]
fn test_inner_node_iterator() -> Result<(), MerkleError> {
    let tree =
        SimpleSmt::<2>::with_leaves(KEYS4.into_iter().zip(digests_to_words(&VALUES4))).unwrap();

    // check depth 2
    assert_eq!(VALUES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap());
    assert_eq!(VALUES4[1], tree.get_node(NodeIndex::make(2, 1)).unwrap());
    assert_eq!(VALUES4[2], tree.get_node(NodeIndex::make(2, 2)).unwrap());
    assert_eq!(VALUES4[3], tree.get_node(NodeIndex::make(2, 3)).unwrap());

    // get parent nodes
    let root = tree.root();
    let l1n0 = tree.get_node(NodeIndex::make(1, 0))?;
    let l1n1 = tree.get_node(NodeIndex::make(1, 1))?;
    let l2n0 = tree.get_node(NodeIndex::make(2, 0))?;
    let l2n1 = tree.get_node(NodeIndex::make(2, 1))?;
    let l2n2 = tree.get_node(NodeIndex::make(2, 2))?;
    let l2n3 = tree.get_node(NodeIndex::make(2, 3))?;

    let nodes: Vec<InnerNodeInfo> = tree.inner_nodes().collect();
    let expected = vec![
        InnerNodeInfo { value: root, left: l1n0, right: l1n1 },
        InnerNodeInfo { value: l1n0, left: l2n0, right: l2n1 },
        InnerNodeInfo { value: l1n1, left: l2n2, right: l2n3 },
    ];
    assert_eq!(nodes, expected);

    Ok(())
}

#[test]
fn update_leaf() {
    const DEPTH: u8 = 3;
    let mut tree =
        SimpleSmt::<DEPTH>::with_leaves(KEYS8.into_iter().zip(digests_to_words(&VALUES8))).unwrap();

    // update one value
    let key = 3;
    let new_node = int_to_leaf(9);
    let mut expected_values = digests_to_words(&VALUES8);
    expected_values[key] = new_node;
    let expected_tree = MerkleTree::new(expected_values.clone()).unwrap();

    let old_leaf = tree.insert(LeafIndex::<DEPTH>::new(key as u64).unwrap(), new_node);
    assert_eq!(expected_tree.root(), tree.root);
    assert_eq!(old_leaf, *VALUES8[key]);

    // update another value
    let key = 6;
    let new_node = int_to_leaf(10);
    expected_values[key] = new_node;
    let expected_tree = MerkleTree::new(expected_values.clone()).unwrap();

    let old_leaf = tree.insert(LeafIndex::<DEPTH>::new(key as u64).unwrap(), new_node);
    assert_eq!(expected_tree.root(), tree.root);
    assert_eq!(old_leaf, *VALUES8[key]);
}

#[test]
fn small_tree_opening_is_consistent() {
    //         ____k____
    //        /         \
    //      _i_         _j_
    //     /   \       /   \
    //    e     f     g     h
    //   / \   / \   / \   / \
    //  a   b 0   0 c   0 0   d

    let z = EMPTY_WORD;

    let a = Word::from(Rpo256::merge(&[z.into(); 2]));
    let b = Word::from(Rpo256::merge(&[a.into(); 2]));
    let c = Word::from(Rpo256::merge(&[b.into(); 2]));
    let d = Word::from(Rpo256::merge(&[c.into(); 2]));

    let e = Rpo256::merge(&[a.into(), b.into()]);
    let f = Rpo256::merge(&[z.into(), z.into()]);
    let g = Rpo256::merge(&[c.into(), z.into()]);
    let h = Rpo256::merge(&[z.into(), d.into()]);

    let i = Rpo256::merge(&[e, f]);
    let j = Rpo256::merge(&[g, h]);

    let k = Rpo256::merge(&[i, j]);

    let entries = vec![(0, a), (1, b), (4, c), (7, d)];
    let tree = SimpleSmt::<3>::with_leaves(entries).unwrap();

    assert_eq!(tree.root(), k);

    let cases: Vec<(u64, Vec<RpoDigest>)> = vec![
        (0, vec![b.into(), f, j]),
        (1, vec![a.into(), f, j]),
        (4, vec![z.into(), h, i]),
        (7, vec![z.into(), g, i]),
    ];

    for (key, path) in cases {
        let opening = tree.open(&LeafIndex::<3>::new(key).unwrap());

        assert_eq!(path, *opening.path);
    }
}

#[test]
fn test_simplesmt_fail_on_duplicates() {
    let values = [
        // same key, same value
        (int_to_leaf(1), int_to_leaf(1)),
        // same key, different values
        (int_to_leaf(1), int_to_leaf(2)),
        // same key, set to zero
        (EMPTY_WORD, int_to_leaf(1)),
        // same key, re-set to zero
        (int_to_leaf(1), EMPTY_WORD),
        // same key, set to zero twice
        (EMPTY_WORD, EMPTY_WORD),
    ];

    for (first, second) in values.iter() {
        // consecutive
        let entries = [(1, *first), (1, *second)];
        let smt = SimpleSmt::<64>::with_leaves(entries);
        assert_eq!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));

        // not consecutive
        let entries = [(1, *first), (5, int_to_leaf(5)), (1, *second)];
        let smt = SimpleSmt::<64>::with_leaves(entries);
        assert_eq!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));
    }
}

#[test]
fn with_no_duplicates_empty_node() {
    let entries = [(1_u64, int_to_leaf(0)), (5, int_to_leaf(2))];
    let smt = SimpleSmt::<64>::with_leaves(entries);
    assert!(smt.is_ok());
}

#[test]
fn test_simplesmt_with_leaves_nonexisting_leaf() {
    // TESTING WITH EMPTY WORD
    // --------------------------------------------------------------------------------------------

    // Depth 1 has 2 leaves. Position is 0-indexed, position 2 doesn't exist.
    let leaves = [(2, EMPTY_WORD)];
    let result = SimpleSmt::<1>::with_leaves(leaves);
    assert!(result.is_err());

    // Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
    let leaves = [(4, EMPTY_WORD)];
    let result = SimpleSmt::<2>::with_leaves(leaves);
    assert!(result.is_err());

    // Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
    let leaves = [(8, EMPTY_WORD)];
    let result = SimpleSmt::<3>::with_leaves(leaves);
    assert!(result.is_err());

    // TESTING WITH A VALUE
    // --------------------------------------------------------------------------------------------
    let value = int_to_node(1);

    // Depth 1 has 2 leaves. Position is 0-indexed, position 2 doesn't exist.
    let leaves = [(2, *value)];
    let result = SimpleSmt::<1>::with_leaves(leaves);
    assert!(result.is_err());

    // Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
    let leaves = [(4, *value)];
    let result = SimpleSmt::<2>::with_leaves(leaves);
    assert!(result.is_err());

    // Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
    let leaves = [(8, *value)];
    let result = SimpleSmt::<3>::with_leaves(leaves);
    assert!(result.is_err());
}

#[test]
fn test_simplesmt_set_subtree() {
    // Final Tree:
    //
    //         ____k____
    //        /         \
    //      _i_         _j_
    //     /   \       /   \
    //    e     f     g     h
    //   / \   / \   / \   / \
    //  a   b 0   0 c   0 0   d

    let z = EMPTY_WORD;

    let a = Word::from(Rpo256::merge(&[z.into(); 2]));
    let b = Word::from(Rpo256::merge(&[a.into(); 2]));
    let c = Word::from(Rpo256::merge(&[b.into(); 2]));
    let d = Word::from(Rpo256::merge(&[c.into(); 2]));

    let e = Rpo256::merge(&[a.into(), b.into()]);
    let f = Rpo256::merge(&[z.into(), z.into()]);
    let g = Rpo256::merge(&[c.into(), z.into()]);
    let h = Rpo256::merge(&[z.into(), d.into()]);

    let i = Rpo256::merge(&[e, f]);
    let j = Rpo256::merge(&[g, h]);

    let k = Rpo256::merge(&[i, j]);

    // subtree:
    //   g
    //  / \
    // c   0
    let subtree = {
        let entries = vec![(0, c)];
        SimpleSmt::<1>::with_leaves(entries).unwrap()
    };

    // insert subtree
    const TREE_DEPTH: u8 = 3;
    let tree = {
        let entries = vec![(0, a), (1, b), (7, d)];
        let mut tree = SimpleSmt::<TREE_DEPTH>::with_leaves(entries).unwrap();

        tree.set_subtree(2, subtree).unwrap();

        tree
    };

    assert_eq!(tree.root(), k);
    assert_eq!(tree.get_leaf(&LeafIndex::<TREE_DEPTH>::new(4).unwrap()), c);
    assert_eq!(tree.get_inner_node(NodeIndex::new_unchecked(2, 2)).hash(), g);
}

/// Ensures that an invalid input node index into `set_subtree()` does not mutate the tree
#[test]
fn test_simplesmt_set_subtree_unchanged_for_wrong_index() {
    // Final Tree:
    //
    //         ____k____
    //        /         \
    //      _i_         _j_
    //     /   \       /   \
    //    e     f     g     h
    //   / \   / \   / \   / \
    //  a   b 0   0 c   0 0   d

    let z = EMPTY_WORD;

    let a = Word::from(Rpo256::merge(&[z.into(); 2]));
    let b = Word::from(Rpo256::merge(&[a.into(); 2]));
    let c = Word::from(Rpo256::merge(&[b.into(); 2]));
    let d = Word::from(Rpo256::merge(&[c.into(); 2]));

    // subtree:
    //   g
    //  / \
    // c   0
    let subtree = {
        let entries = vec![(0, c)];
        SimpleSmt::<1>::with_leaves(entries).unwrap()
    };

    let mut tree = {
        let entries = vec![(0, a), (1, b), (7, d)];
        SimpleSmt::<3>::with_leaves(entries).unwrap()
    };
    let tree_root_before_insertion = tree.root();

    // insert subtree
    assert!(tree.set_subtree(500, subtree).is_err());

    assert_eq!(tree.root(), tree_root_before_insertion);
}

/// We insert an empty subtree that has the same depth as the original tree
#[test]
fn test_simplesmt_set_subtree_entire_tree() {
    // Initial Tree:
    //
    //         ____k____
    //        /         \
    //      _i_         _j_
    //     /   \       /   \
    //    e     f     g     h
    //   / \   / \   / \   / \
    //  a   b 0   0 c   0 0   d

    let z = EMPTY_WORD;

    let a = Word::from(Rpo256::merge(&[z.into(); 2]));
    let b = Word::from(Rpo256::merge(&[a.into(); 2]));
    let c = Word::from(Rpo256::merge(&[b.into(); 2]));
    let d = Word::from(Rpo256::merge(&[c.into(); 2]));

    // subtree: E3
    const DEPTH: u8 = 3;
    let subtree = { SimpleSmt::<DEPTH>::with_leaves(Vec::new()).unwrap() };
    assert_eq!(subtree.root(), *EmptySubtreeRoots::entry(DEPTH, 0));

    // insert subtree
    let mut tree = {
        let entries = vec![(0, a), (1, b), (4, c), (7, d)];
        SimpleSmt::<3>::with_leaves(entries).unwrap()
    };

    tree.set_subtree(0, subtree).unwrap();

    assert_eq!(tree.root(), *EmptySubtreeRoots::entry(DEPTH, 0));
}

// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------

fn compute_internal_nodes() -> (RpoDigest, RpoDigest, RpoDigest) {
    let node2 = Rpo256::merge(&[VALUES4[0], VALUES4[1]]);
    let node3 = Rpo256::merge(&[VALUES4[2], VALUES4[3]]);
    let root = Rpo256::merge(&[node2, node3]);

    (root, node2, node3)
}
@@ -1,11 +1,15 @@
use super::{
    mmr::Mmr, BTreeMap, EmptySubtreeRoots, InnerNodeInfo, KvMap, MerkleError, MerklePath,
    MerkleStoreDelta, MerkleTree, NodeIndex, PartialMerkleTree, RecordingMap, RootPath, Rpo256,
    RpoDigest, SimpleSmt, TieredSmt, TryApplyDiff, ValuePath, Vec, EMPTY_WORD,
};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
use alloc::{collections::BTreeMap, vec::Vec};
use core::borrow::Borrow;

use super::{
    mmr::Mmr, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, MerkleTree, NodeIndex,
    PartialMerkleTree, RootPath, Rpo256, RpoDigest, SimpleSmt, Smt, ValuePath,
};
use crate::utils::{
    collections::{KvMap, RecordingMap},
    ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
};

#[cfg(test)]
mod tests;

@@ -173,7 +177,7 @@ impl<T: KvMap<RpoDigest, StoreNode>> MerkleStore<T> {
        // the path is computed from root to leaf, so it must be reversed
        path.reverse();

        Ok(ValuePath::new(hash, path))
        Ok(ValuePath::new(hash, MerklePath::new(path)))
    }

    // LEAF TRAVERSAL
@@ -361,9 +365,6 @@ impl<T: KvMap<RpoDigest, StoreNode>> MerkleStore<T> {

            // if the node is not in the store assume it is a leaf
            } else {
                // assert that if we have a leaf that is not at the max depth then it must be
                // at the depth of one of the tiers of a TSMT.
                debug_assert!(TieredSmt::TIER_DEPTHS[..3].contains(&index.depth()));
                return Some((index, node_hash));
            }
        }
@@ -490,8 +491,15 @@ impl<T: KvMap<RpoDigest, StoreNode>> From<&MerkleTree> for MerkleStore<T> {
    }
}

impl<T: KvMap<RpoDigest, StoreNode>> From<&SimpleSmt> for MerkleStore<T> {
    fn from(value: &SimpleSmt) -> Self {
impl<T: KvMap<RpoDigest, StoreNode>, const DEPTH: u8> From<&SimpleSmt<DEPTH>> for MerkleStore<T> {
    fn from(value: &SimpleSmt<DEPTH>) -> Self {
        let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
        Self { nodes }
    }
}

impl<T: KvMap<RpoDigest, StoreNode>> From<&Smt> for MerkleStore<T> {
    fn from(value: &Smt) -> Self {
        let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
        Self { nodes }
    }
@@ -504,13 +512,6 @@ impl<T: KvMap<RpoDigest, StoreNode>> From<&Mmr> for MerkleStore<T> {
    }
}

impl<T: KvMap<RpoDigest, StoreNode>> From<&TieredSmt> for MerkleStore<T> {
    fn from(value: &TieredSmt) -> Self {
        let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
        Self { nodes }
    }
}

impl<T: KvMap<RpoDigest, StoreNode>> From<&PartialMerkleTree> for MerkleStore<T> {
    fn from(value: &PartialMerkleTree) -> Self {
        let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
@@ -550,39 +551,6 @@ impl<T: KvMap<RpoDigest, StoreNode>> Extend<InnerNodeInfo> for MerkleStore<T> {
    }
}

// DiffT & ApplyDiffT TRAIT IMPLEMENTATION
// ================================================================================================
impl<T: KvMap<RpoDigest, StoreNode>> TryApplyDiff<RpoDigest, StoreNode> for MerkleStore<T> {
    type Error = MerkleError;
    type DiffType = MerkleStoreDelta;

    fn try_apply(&mut self, diff: Self::DiffType) -> Result<(), MerkleError> {
        for (root, delta) in diff.0 {
            let mut root = root;
            for cleared_slot in delta.cleared_slots() {
                root = self
                    .set_node(
                        root,
                        NodeIndex::new(delta.depth(), *cleared_slot)?,
                        EMPTY_WORD.into(),
                    )?
                    .root;
            }
            for (updated_slot, updated_value) in delta.updated_slots() {
                root = self
                    .set_node(
                        root,
                        NodeIndex::new(delta.depth(), *updated_slot)?,
                        (*updated_value).into(),
                    )?
                    .root;
            }
        }

        Ok(())
    }
}

// SERIALIZATION
// ================================================================================================

@@ -1,18 +1,22 @@
use seq_macro::seq;

use super::{
    DefaultMerkleStore as MerkleStore, EmptySubtreeRoots, MerkleError, MerklePath, NodeIndex,
    PartialMerkleTree, RecordingMerkleStore, RpoDigest,
    PartialMerkleTree, RecordingMerkleStore, Rpo256, RpoDigest,
};
use crate::{
    hash::rpo::Rpo256,
    merkle::{digests_to_words, int_to_leaf, int_to_node, MerkleTree, SimpleSmt},
    merkle::{
        digests_to_words, int_to_leaf, int_to_node, LeafIndex, MerkleTree, SimpleSmt, SMT_MAX_DEPTH,
    },
    Felt, Word, ONE, WORD_SIZE, ZERO,
};

#[cfg(feature = "std")]
use super::{Deserializable, Serializable};

#[cfg(feature = "std")]
use std::error::Error;
use {
    super::{Deserializable, Serializable},
    alloc::boxed::Box,
    std::error::Error,
};

// TEST DATA
// ================================================================================================
@@ -104,7 +108,7 @@ fn test_merkle_tree() -> Result<(), MerkleError> {
        "node 3 must be the same for both MerkleTree and MerkleStore"
    );

    // STORE MERKLE PATH MATCHS ==============================================================
    // STORE MERKLE PATH MATCHES ==============================================================
    // assert the merkle path returned by the store is the same as the one in the tree
    let result = store.get_path(mtree.root(), NodeIndex::make(mtree.depth(), 0)).unwrap();
    assert_eq!(
@@ -174,12 +178,12 @@ fn test_leaf_paths_for_empty_trees() -> Result<(), MerkleError> {
    // Starts at 1 because leaves are not included in the store.
    // Ends at 64 because it is not possible to represent an index of a depth greater than 64,
    // because a u64 is used to index the leaf.
    for depth in 1..64 {
        let smt = SimpleSmt::new(depth)?;
    seq!(DEPTH in 1_u8..64_u8 {
        let smt = SimpleSmt::<DEPTH>::new()?;

        let index = NodeIndex::make(depth, 0);
        let index = NodeIndex::make(DEPTH, 0);
        let store_path = store.get_path(smt.root(), index)?;
        let smt_path = smt.get_path(index)?;
        let smt_path = smt.open(&LeafIndex::<DEPTH>::new(0)?).path;
        assert_eq!(
            store_path.value,
            RpoDigest::default(),
@@ -190,11 +194,12 @@ fn test_leaf_paths_for_empty_trees() -> Result<(), MerkleError> {
            "the returned merkle path does not match the computed values"
        );
        assert_eq!(
            store_path.path.compute_root(depth.into(), RpoDigest::default()).unwrap(),
            store_path.path.compute_root(DEPTH.into(), RpoDigest::default()).unwrap(),
            smt.root(),
            "computed root from the path must match the empty tree root"
        );
    }

    });

    Ok(())
}
@@ -211,7 +216,7 @@ fn test_get_invalid_node() {
fn test_add_sparse_merkle_tree_one_level() -> Result<(), MerkleError> {
    let keys2: [u64; 2] = [0, 1];
    let leaves2: [Word; 2] = [int_to_leaf(1), int_to_leaf(2)];
    let smt = SimpleSmt::with_leaves(1, keys2.into_iter().zip(leaves2.into_iter())).unwrap();
    let smt = SimpleSmt::<1>::with_leaves(keys2.into_iter().zip(leaves2)).unwrap();
    let store = MerkleStore::from(&smt);

    let idx = NodeIndex::make(1, 0);
@@ -227,38 +232,36 @@ fn test_add_sparse_merkle_tree_one_level() -> Result<(), MerkleError> {

#[test]
fn test_sparse_merkle_tree() -> Result<(), MerkleError> {
    let smt = SimpleSmt::with_leaves(
        SimpleSmt::MAX_DEPTH,
        KEYS4.into_iter().zip(digests_to_words(&VALUES4).into_iter()),
    )
    .unwrap();
    let smt =
        SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(KEYS4.into_iter().zip(digests_to_words(&VALUES4)))
            .unwrap();

    let store = MerkleStore::from(&smt);

    // STORE LEAVES ARE CORRECT ==============================================================
    // checks the leaves in the store correspond to the expected values
    assert_eq!(
        store.get_node(smt.root(), NodeIndex::make(smt.depth(), 0)),
        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 0)),
        Ok(VALUES4[0]),
        "node 0 must be in the tree"
    );
    assert_eq!(
        store.get_node(smt.root(), NodeIndex::make(smt.depth(), 1)),
        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 1)),
        Ok(VALUES4[1]),
        "node 1 must be in the tree"
    );
    assert_eq!(
        store.get_node(smt.root(), NodeIndex::make(smt.depth(), 2)),
        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 2)),
        Ok(VALUES4[2]),
        "node 2 must be in the tree"
    );
    assert_eq!(
        store.get_node(smt.root(), NodeIndex::make(smt.depth(), 3)),
        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 3)),
        Ok(VALUES4[3]),
        "node 3 must be in the tree"
    );
    assert_eq!(
        store.get_node(smt.root(), NodeIndex::make(smt.depth(), 4)),
        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 4)),
        Ok(RpoDigest::default()),
        "unmodified node 4 must be ZERO"
    );
@@ -266,86 +269,86 @@ fn test_sparse_merkle_tree() -> Result<(), MerkleError> {
    // STORE LEAVES MATCH TREE ===============================================================
    // sanity check the values returned by the store and the tree
    assert_eq!(
        smt.get_node(NodeIndex::make(smt.depth(), 0)),
        store.get_node(smt.root(), NodeIndex::make(smt.depth(), 0)),
        smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 0)),
        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 0)),
        "node 0 must be the same for both SparseMerkleTree and MerkleStore"
    );
    assert_eq!(
        smt.get_node(NodeIndex::make(smt.depth(), 1)),
        store.get_node(smt.root(), NodeIndex::make(smt.depth(), 1)),
        smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 1)),
        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 1)),
        "node 1 must be the same for both SparseMerkleTree and MerkleStore"
    );
    assert_eq!(
        smt.get_node(NodeIndex::make(smt.depth(), 2)),
        store.get_node(smt.root(), NodeIndex::make(smt.depth(), 2)),
        smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 2)),
        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 2)),
        "node 2 must be the same for both SparseMerkleTree and MerkleStore"
    );
    assert_eq!(
        smt.get_node(NodeIndex::make(smt.depth(), 3)),
        store.get_node(smt.root(), NodeIndex::make(smt.depth(), 3)),
        smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 3)),
        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 3)),
        "node 3 must be the same for both SparseMerkleTree and MerkleStore"
    );
    assert_eq!(
        smt.get_node(NodeIndex::make(smt.depth(), 4)),
        store.get_node(smt.root(), NodeIndex::make(smt.depth(), 4)),
        smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 4)),
        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 4)),
        "node 4 must be the same for both SparseMerkleTree and MerkleStore"
    );

    // STORE MERKLE PATH MATCHS ==============================================================
    // STORE MERKLE PATH MATCHES ==============================================================
    // assert the merkle path returned by the store is the same as the one in the tree
    let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 0)).unwrap();
    let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 0)).unwrap();
    assert_eq!(
        VALUES4[0], result.value,
        "Value for merkle path at index 0 must match leaf value"
    );
    assert_eq!(
        smt.get_path(NodeIndex::make(smt.depth(), 0)),
        Ok(result.path),
        smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(0).unwrap()).path,
        result.path,
        "merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
    );

    let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 1)).unwrap();
    let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 1)).unwrap();
    assert_eq!(
        VALUES4[1], result.value,
        "Value for merkle path at index 1 must match leaf value"
    );
    assert_eq!(
        smt.get_path(NodeIndex::make(smt.depth(), 1)),
        Ok(result.path),
        smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(1).unwrap()).path,
        result.path,
        "merkle path for index 1 must be the same for the MerkleTree and MerkleStore"
    );

    let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 2)).unwrap();
    let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 2)).unwrap();
    assert_eq!(
        VALUES4[2], result.value,
        "Value for merkle path at index 2 must match leaf value"
    );
    assert_eq!(
        smt.get_path(NodeIndex::make(smt.depth(), 2)),
        Ok(result.path),
        smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(2).unwrap()).path,
        result.path,
        "merkle path for index 2 must be the same for the MerkleTree and MerkleStore"
    );

    let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 3)).unwrap();
    let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 3)).unwrap();
    assert_eq!(
        VALUES4[3], result.value,
        "Value for merkle path at index 3 must match leaf value"
    );
    assert_eq!(
        smt.get_path(NodeIndex::make(smt.depth(), 3)),
        Ok(result.path),
        smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(3).unwrap()).path,
        result.path,
        "merkle path for index 3 must be the same for the MerkleTree and MerkleStore"
    );

    let result = store.get_path(smt.root(), NodeIndex::make(smt.depth(), 4)).unwrap();
    let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 4)).unwrap();
    assert_eq!(
        RpoDigest::default(),
        result.value,
        "Value for merkle path at index 4 must match leaf value"
    );
    assert_eq!(
        smt.get_path(NodeIndex::make(smt.depth(), 4)),
        Ok(result.path),
        smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(4).unwrap()).path,
        result.path,
        "merkle path for index 4 must be the same for the MerkleTree and MerkleStore"
    );

@@ -426,7 +429,7 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> {
        "node 3 must be the same for both PartialMerkleTree and MerkleStore"
    );

    // STORE MERKLE PATH MATCHS ==============================================================
    // STORE MERKLE PATH MATCHES ==============================================================
    // assert the merkle path returned by the store is the same as the one in the pmt
    let result = store.get_path(pmt.root(), NodeIndex::make(pmt.max_depth(), 0)).unwrap();
    assert_eq!(
@@ -553,19 +556,15 @@ fn test_constructors() -> Result<(), MerkleError> {
        assert_eq!(mtree.get_path(index)?, value_path.path);
    }

    let depth = 32;
    let smt = SimpleSmt::with_leaves(
        depth,
        KEYS4.into_iter().zip(digests_to_words(&VALUES4).into_iter()),
    )
    .unwrap();
    const DEPTH: u8 = 32;
    let smt =
        SimpleSmt::<DEPTH>::with_leaves(KEYS4.into_iter().zip(digests_to_words(&VALUES4))).unwrap();
    let store = MerkleStore::from(&smt);
    let depth = smt.depth();

    for key in KEYS4 {
        let index = NodeIndex::make(depth, key);
        let index = NodeIndex::make(DEPTH, key);
        let value_path = store.get_path(smt.root(), index)?;
        assert_eq!(smt.get_path(index)?, value_path.path);
        assert_eq!(smt.open(&LeafIndex::<DEPTH>::new(key).unwrap()).path, value_path.path);
    }

    let d = 2;
@@ -653,7 +652,7 @@ fn get_leaf_depth_works_depth_64() {
        let index = NodeIndex::new(64, k).unwrap();

        // assert the leaf doesn't exist before the insert. the returned depth should always
        // increment with the paths count of the set, as they are insersecting one another up to
        // increment with the paths count of the set, as they are intersecting one another up to
        // the first bits of the used key.
        assert_eq!(d, store.get_leaf_depth(root, 64, k).unwrap());

@@ -884,8 +883,9 @@ fn test_serialization() -> Result<(), Box<dyn Error>> {
fn test_recorder() {
    // instantiate recorder from MerkleTree and SimpleSmt
    let mtree = MerkleTree::new(digests_to_words(&VALUES4)).unwrap();
    let smtree = SimpleSmt::with_leaves(
        64,

    const TREE_DEPTH: u8 = 64;
    let smtree = SimpleSmt::<TREE_DEPTH>::with_leaves(
        KEYS8.into_iter().zip(VALUES8.into_iter().map(|x| x.into()).rev()),
    )
    .unwrap();
@@ -898,13 +898,13 @@ fn test_recorder() {
    let node = recorder.get_node(mtree.root(), index_0).unwrap();
    assert_eq!(node, mtree.get_node(index_0).unwrap());

    let index_1 = NodeIndex::new(smtree.depth(), 1).unwrap();
    let index_1 = NodeIndex::new(TREE_DEPTH, 1).unwrap();
    let node = recorder.get_node(smtree.root(), index_1).unwrap();
    assert_eq!(node, smtree.get_node(index_1).unwrap());

    // insert a value and assert that when we request it next time it is accurate
    let new_value = [ZERO, ZERO, ONE, ONE].into();
    let index_2 = NodeIndex::new(smtree.depth(), 2).unwrap();
    let index_2 = NodeIndex::new(TREE_DEPTH, 2).unwrap();
    let root = recorder.set_node(smtree.root(), index_2, new_value).unwrap().root;
    assert_eq!(recorder.get_node(root, index_2).unwrap(), new_value);

@@ -921,10 +921,13 @@ fn test_recorder() {
    assert_eq!(node, smtree.get_node(index_1).unwrap());

    let node = merkle_store.get_node(smtree.root(), index_2).unwrap();
    assert_eq!(node, smtree.get_leaf(index_2.value()).unwrap().into());
    assert_eq!(
        node,
        smtree.get_leaf(&LeafIndex::<TREE_DEPTH>::try_from(index_2).unwrap()).into()
    );

// assert that is doesnt contain nodes that were not recorded
|
||||
let not_recorded_index = NodeIndex::new(smtree.depth(), 4).unwrap();
|
||||
let not_recorded_index = NodeIndex::new(TREE_DEPTH, 4).unwrap();
|
||||
assert!(merkle_store.get_node(smtree.root(), not_recorded_index).is_err());
|
||||
assert!(smtree.get_node(not_recorded_index).is_ok());
|
||||
}
|
||||
|
||||
@@ -1,48 +0,0 @@
use core::fmt::Display;

#[derive(Debug, PartialEq, Eq)]
pub enum TieredSmtProofError {
    EntriesEmpty,
    EmptyValueNotAllowed,
    MismatchedPrefixes(u64, u64),
    MultipleEntriesOutsideLastTier,
    NotATierPath(u8),
    PathTooLong,
}

impl Display for TieredSmtProofError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            TieredSmtProofError::EntriesEmpty => {
                write!(f, "Missing entries for tiered sparse merkle tree proof")
            }
            TieredSmtProofError::EmptyValueNotAllowed => {
                write!(
                    f,
                    "The empty value [0, 0, 0, 0] is not allowed inside a tiered sparse merkle tree"
                )
            }
            TieredSmtProofError::MismatchedPrefixes(first, second) => {
                write!(f, "Not all leaves have the same prefix. First {first} second {second}")
            }
            TieredSmtProofError::MultipleEntriesOutsideLastTier => {
                write!(f, "Multiple entries are only allowed for the last tier (depth 64)")
            }
            TieredSmtProofError::NotATierPath(got) => {
                write!(
                    f,
                    "Path length does not correspond to a tier. Got {got} Expected one of 16, 32, 48, 64"
                )
            }
            TieredSmtProofError::PathTooLong => {
                write!(
                    f,
                    "Path longer than maximum depth of 64 for tiered sparse merkle tree proof"
                )
            }
        }
    }
}

#[cfg(feature = "std")]
impl std::error::Error for TieredSmtProofError {}
@@ -1,509 +0,0 @@
use super::{
    BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex,
    Rpo256, RpoDigest, StarkField, Vec, Word,
};
use crate::utils::vec;
use core::{cmp, ops::Deref};

mod nodes;
use nodes::NodeStore;

mod values;
use values::ValueStore;

mod proof;
pub use proof::TieredSmtProof;

mod error;
pub use error::TieredSmtProofError;

#[cfg(test)]
mod tests;

// TIERED SPARSE MERKLE TREE
// ================================================================================================

/// Tiered (compacted) Sparse Merkle tree mapping 256-bit keys to 256-bit values. Both keys and
/// values are represented by 4 field elements.
///
/// Leaves in the tree can exist only at specific depths called "tiers". These depths are: 16, 32,
/// 48, and 64. Initially, when a tree is empty, it is equivalent to an empty Sparse Merkle tree
/// of depth 64 (i.e., leaves at depth 64 are set to [ZERO; 4]). As non-empty values are inserted
/// into the tree, they are added to the first available tier.
///
/// For example, when the first key-value pair is inserted, it will be stored in a node at depth
/// 16 such that the 16 most significant bits of the key determine the position of the node at
/// depth 16. If another value with a key sharing the same 16-bit prefix is inserted, both values
/// move into the next tier (depth 32). This process is repeated until values end up at the bottom
/// tier (depth 64). If multiple values have keys with a common 64-bit prefix, such key-value pairs
/// are stored in a sorted list at the bottom tier.
///
/// To differentiate between internal and leaf nodes, node values are computed as follows:
/// - Internal nodes: hash(left_child, right_child).
/// - Leaf node at depths 16, 32, or 48: hash(key, value, domain=depth).
/// - Leaf node at depth 64: hash([key_0, value_0, ..., key_n, value_n], domain=64).
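///
/// # Examples
///
/// A minimal usage sketch (illustrative only; `ONE` and the element types are the same helpers
/// used by the tests in this module, and the snippet is not compiled here):
/// ```ignore
/// let mut smt = TieredSmt::default();
/// let key = RpoDigest::from([ONE, ONE, ONE, ONE]);
/// let value = [ONE; 4];
/// // the first insertion returns the value previously stored under the key, i.e., the empty word
/// assert_eq!(smt.insert(key, value), TieredSmt::EMPTY_VALUE);
/// assert_eq!(smt.get_value(key), value);
/// ```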
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct TieredSmt {
    root: RpoDigest,
    nodes: NodeStore,
    values: ValueStore,
}

impl TieredSmt {
    // CONSTANTS
    // --------------------------------------------------------------------------------------------

    /// The number of levels between tiers.
    pub const TIER_SIZE: u8 = 16;

    /// Depths at which leaves can exist in a tiered SMT.
    pub const TIER_DEPTHS: [u8; 4] = [16, 32, 48, 64];

    /// Maximum node depth. This is also the bottom tier of the tree.
    pub const MAX_DEPTH: u8 = 64;

    /// Value of an empty leaf.
    pub const EMPTY_VALUE: Word = super::EMPTY_WORD;

    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Returns a new [TieredSmt] instantiated with the specified key-value pairs.
    ///
    /// # Errors
    /// Returns an error if the provided entries contain multiple values for the same key.
    pub fn with_entries<R, I>(entries: R) -> Result<Self, MerkleError>
    where
        R: IntoIterator<IntoIter = I>,
        I: Iterator<Item = (RpoDigest, Word)> + ExactSizeIterator,
    {
        // create an empty tree
        let mut tree = Self::default();

        // append leaves to the tree returning an error if a duplicate entry for the same key
        // is found
        let mut empty_entries = BTreeSet::new();
        for (key, value) in entries {
            let old_value = tree.insert(key, value);
            if old_value != Self::EMPTY_VALUE || empty_entries.contains(&key) {
                return Err(MerkleError::DuplicateValuesForKey(key));
            }
            // if we've processed an empty entry, add the key to the set of empty entry keys, and
            // if this key was already in the set, return an error
            if value == Self::EMPTY_VALUE && !empty_entries.insert(key) {
                return Err(MerkleError::DuplicateValuesForKey(key));
            }
        }
        Ok(tree)
    }
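
    // Sketch of the duplicate-key behavior above (hypothetical entries): calling
    // `TieredSmt::with_entries([(key, val_a), (key, val_b)])` returns
    // `Err(MerkleError::DuplicateValuesForKey(key))`, because the second insert for `key`
    // yields a non-empty old value.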

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the root of this Merkle tree.
    pub const fn root(&self) -> RpoDigest {
        self.root
    }

    /// Returns a node at the specified index.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The specified index depth is 0 or greater than 64.
    /// - The node with the specified index does not exist in the Merkle tree. This is possible
    ///   when a leaf node with the same index prefix exists at a tier higher than the requested
    ///   node.
    pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
        self.nodes.get_node(index)
    }

    /// Returns a Merkle path from the node at the specified index to the root.
    ///
    /// The node itself is not included in the path.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The specified index depth is 0 or greater than 64.
    /// - The node with the specified index does not exist in the Merkle tree. This is possible
    ///   when a leaf node with the same index prefix exists at a tier higher than the node to
    ///   which the path is requested.
    pub fn get_path(&self, index: NodeIndex) -> Result<MerklePath, MerkleError> {
        self.nodes.get_path(index)
    }

    /// Returns the value associated with the specified key.
    ///
    /// If nothing was inserted into this tree for the specified key, [ZERO; 4] is returned.
    pub fn get_value(&self, key: RpoDigest) -> Word {
        match self.values.get(&key) {
            Some(value) => *value,
            None => Self::EMPTY_VALUE,
        }
    }

    /// Returns a proof for a key-value pair defined by the specified key.
    ///
    /// The proof can be used to attest membership of this key-value pair in a Tiered Sparse Merkle
    /// Tree defined by the same root as this tree.
    pub fn prove(&self, key: RpoDigest) -> TieredSmtProof {
        let (path, index, leaf_exists) = self.nodes.get_proof(&key);

        let entries = if index.depth() == Self::MAX_DEPTH {
            match self.values.get_all(index.value()) {
                Some(entries) => entries,
                None => vec![(key, Self::EMPTY_VALUE)],
            }
        } else if leaf_exists {
            let entry =
                self.values.get_first(index_to_prefix(&index)).expect("leaf entry not found");
            debug_assert_eq!(entry.0, key);
            vec![*entry]
        } else {
            vec![(key, Self::EMPTY_VALUE)]
        };

        TieredSmtProof::new(path, entries).expect("Bug detected, TSMT produced invalid proof")
    }
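
    // A membership-proof round-trip, as a hedged sketch using the API above (not compiled here):
    //
    //     let mut smt = TieredSmt::default();
    //     smt.insert(key, value);
    //     let proof = smt.prove(key);
    //     assert!(proof.verify_membership(&key, &value, &smt.root()));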

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Inserts the provided value into the tree under the specified key and returns the value
    /// previously stored under this key.
    ///
    /// If the value for the specified key was not previously set, [ZERO; 4] is returned.
    pub fn insert(&mut self, key: RpoDigest, value: Word) -> Word {
        // if an empty value is being inserted, remove the leaf node to make it look as if the
        // value was never inserted
        if value == Self::EMPTY_VALUE {
            return self.remove_leaf_node(key);
        }

        // insert the value into the value store, and if the key was already in the store, update
        // it with the new value
        if let Some(old_value) = self.values.insert(key, value) {
            if old_value != value {
                // if the new value is different from the old value, determine the location of
                // the leaf node for this key, build the node, and update the root
                let (index, leaf_exists) = self.nodes.get_leaf_index(&key);
                debug_assert!(leaf_exists);
                let node = self.build_leaf_node(index, key, value);
                self.root = self.nodes.update_leaf_node(index, node);
            }
            return old_value;
        };

        // determine the location for the leaf node; this index could have 3 different meanings:
        // - it points to a root of an empty subtree or an empty node at depth 64; in this case,
        //   we can replace the node with the value node immediately.
        // - it points to an existing leaf at the bottom tier (i.e., depth = 64); in this case,
        //   we need to update the bottom leaf.
        // - it points to an existing leaf node for a different key with the same prefix (same
        //   key case was handled above); in this case, we need to move the leaf to a lower tier
        let (index, leaf_exists) = self.nodes.get_leaf_index(&key);

        self.root = if leaf_exists && index.depth() == Self::MAX_DEPTH {
            // returned index points to a leaf at the bottom tier
            let node = self.build_leaf_node(index, key, value);
            self.nodes.update_leaf_node(index, node)
        } else if leaf_exists {
            // returned index points to a leaf for a different key with the same prefix

            // get the key-value pair for the key with the same prefix; since the key-value
            // pair has already been inserted into the value store, we need to filter it out
            // when looking for the other key-value pair
            let (other_key, other_value) = self
                .values
                .get_first_filtered(index_to_prefix(&index), &key)
                .expect("other key-value pair not found");

            // determine how far down the tree we should move the leaves
            let common_prefix_len = get_common_prefix_tier_depth(&key, other_key);
            let depth = cmp::min(common_prefix_len + Self::TIER_SIZE, Self::MAX_DEPTH);

            // compute node locations for new and existing key-value pairs
            let new_index = LeafNodeIndex::from_key(&key, depth);
            let other_index = LeafNodeIndex::from_key(other_key, depth);

            // compute node values for the new and existing key-value pairs
            let new_node = self.build_leaf_node(new_index, key, value);
            let other_node = self.build_leaf_node(other_index, *other_key, *other_value);

            // replace the leaf located at index with a subtree containing nodes for new and
            // existing key-value pairs
            self.nodes.replace_leaf_with_subtree(
                index,
                [(new_index, new_node), (other_index, other_node)],
            )
        } else {
            // returned index points to an empty subtree or an empty leaf at the bottom tier
            let node = self.build_leaf_node(index, key, value);
            self.nodes.insert_leaf_node(index, node)
        };

        Self::EMPTY_VALUE
    }
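
    // Illustrative behavior of the logic above: inserting two keys that share only their 16
    // most significant bits first places one leaf at depth 16; the second insert moves both
    // key-value pairs down to the depth-32 tier (see the `tsmt_insert_two_16` test below).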

    // ITERATORS
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over all key-value pairs in this [TieredSmt].
    pub fn iter(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
        self.values.iter()
    }

    /// Returns an iterator over all inner nodes of this [TieredSmt] (i.e., nodes not at depths 16,
    /// 32, 48, or 64).
    ///
    /// The iterator order is unspecified.
    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
        self.nodes.inner_nodes()
    }

    /// Returns an iterator over upper leaves (i.e., depth = 16, 32, or 48) for this [TieredSmt]
    /// where each yielded item is a (node, key, value) tuple.
    ///
    /// The iterator order is unspecified.
    pub fn upper_leaves(&self) -> impl Iterator<Item = (RpoDigest, RpoDigest, Word)> + '_ {
        self.nodes.upper_leaves().map(|(index, node)| {
            let key_prefix = index_to_prefix(index);
            let (key, value) = self.values.get_first(key_prefix).expect("upper leaf not found");
            debug_assert_eq!(*index, LeafNodeIndex::from_key(key, index.depth()).into());
            (*node, *key, *value)
        })
    }

    /// Returns an iterator over upper leaves (i.e., depth = 16, 32, or 48) for this [TieredSmt]
    /// where each yielded item is a (node_index, value) tuple.
    pub fn upper_leaf_nodes(&self) -> impl Iterator<Item = (&NodeIndex, &RpoDigest)> {
        self.nodes.upper_leaves()
    }

    /// Returns an iterator over bottom leaves (i.e., depth = 64) of this [TieredSmt].
    ///
    /// Each yielded item consists of the hash of the leaf and its contents, where contents is
    /// a vector containing key-value pairs of entries stored in this leaf.
    ///
    /// The iterator order is unspecified.
    pub fn bottom_leaves(&self) -> impl Iterator<Item = (RpoDigest, Vec<(RpoDigest, Word)>)> + '_ {
        self.nodes.bottom_leaves().map(|(&prefix, node)| {
            let values = self.values.get_all(prefix).expect("bottom leaf not found");
            (*node, values)
        })
    }

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Removes the node holding the key-value pair for the specified key from this tree, and
    /// returns the value associated with the specified key.
    ///
    /// If no value was associated with the specified key, [ZERO; 4] is returned.
    fn remove_leaf_node(&mut self, key: RpoDigest) -> Word {
        // remove the key-value pair from the value store; if no value was associated with the
        // specified key, return.
        let old_value = match self.values.remove(&key) {
            Some(old_value) => old_value,
            None => return Self::EMPTY_VALUE,
        };

        // determine the location of the leaf holding the key-value pair to be removed
        let (index, leaf_exists) = self.nodes.get_leaf_index(&key);
        debug_assert!(leaf_exists);

        // if the leaf is at the bottom tier and after removing the key-value pair from it, the
        // leaf is still not empty, we either just update it, or move it up to a higher tier (if
        // the leaf doesn't have siblings at lower tiers)
        if index.depth() == Self::MAX_DEPTH {
            if let Some(entries) = self.values.get_all(index.value()) {
                // if there is only one key-value pair left at the bottom leaf, and it can be
                // moved up to a higher tier, truncate the branch and return
                if entries.len() == 1 {
                    let new_depth = self.nodes.get_last_single_child_parent_depth(index.value());
                    if new_depth != Self::MAX_DEPTH {
                        let node = hash_upper_leaf(entries[0].0, entries[0].1, new_depth);
                        self.root = self.nodes.truncate_branch(index.value(), new_depth, node);
                        return old_value;
                    }
                }

                // otherwise just recompute the leaf hash and update the leaf node
                let node = hash_bottom_leaf(&entries);
                self.root = self.nodes.update_leaf_node(index, node);
                return old_value;
            };
        }

        // if the removed key-value pair has a lone sibling at the current tier with a root at
        // higher tier, we need to move the sibling to a higher tier
        if let Some((sib_key, sib_val, new_sib_index)) = self.values.get_lone_sibling(index) {
            // determine the current index of the sibling node
            let sib_index = LeafNodeIndex::from_key(sib_key, index.depth());
            debug_assert!(sib_index.depth() > new_sib_index.depth());

            // compute node value for the new location of the sibling leaf and replace the subtree
            // with this leaf node
            let node = self.build_leaf_node(new_sib_index, *sib_key, *sib_val);
            let new_sib_depth = new_sib_index.depth();
            self.root = self.nodes.replace_subtree_with_leaf(index, sib_index, new_sib_depth, node);
        } else {
            // if the removed key-value pair did not have a sibling at the current tier with a
            // root at higher tiers, just clear the leaf node
            self.root = self.nodes.clear_leaf_node(index);
        }

        old_value
    }

    /// Builds and returns a leaf node value for the node located at the specified index.
    ///
    /// This method assumes that the key-value pair for the node has already been inserted into
    /// the value store, however, for depths 16, 32, and 48, the node is computed directly from
    /// the passed-in values (for depth 64, the value store is queried to get all the key-value
    /// pairs located at the specified index).
    fn build_leaf_node(&self, index: LeafNodeIndex, key: RpoDigest, value: Word) -> RpoDigest {
        let depth = index.depth();

        // insert the key into index-key map and compute the new value of the node
        if index.depth() == Self::MAX_DEPTH {
            // for the bottom tier, we add the key-value pair to the existing leaf, or create a
            // new leaf with this key-value pair
            let values = self.values.get_all(index.value()).unwrap();
            hash_bottom_leaf(&values)
        } else {
            debug_assert_eq!(self.values.get_first(index_to_prefix(&index)), Some(&(key, value)));
            hash_upper_leaf(key, value, depth)
        }
    }
}

impl Default for TieredSmt {
    fn default() -> Self {
        let root = EmptySubtreeRoots::empty_hashes(Self::MAX_DEPTH)[0];
        Self {
            root,
            nodes: NodeStore::new(root),
            values: ValueStore::default(),
        }
    }
}

// LEAF NODE INDEX
// ================================================================================================
/// A wrapper around [NodeIndex] to provide type-safe references to nodes at depths 16, 32, 48, and
/// 64.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub struct LeafNodeIndex(NodeIndex);

impl LeafNodeIndex {
    /// Returns a new [LeafNodeIndex] instantiated from the provided [NodeIndex].
    ///
    /// In debug mode, panics if index depth is not 16, 32, 48, or 64.
    pub fn new(index: NodeIndex) -> Self {
        // check if the depth is 16, 32, 48, or 64; this works because for a valid depth,
        // depth - 16 can be 0, 16, 32, or 48 - i.e., only the 4th and/or 5th bits may be set.
        // We can test for this by computing a bitwise AND with a value which has all but the
        // 4th and 5th bits set (which is !48).
        debug_assert_eq!(((index.depth() - 16) & !48), 0, "invalid tier depth {}", index.depth());
        Self(index)
    }
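
    // Worked check of the bitmask above: for a valid depth such as 32, (32 - 16) & !48 =
    // 16 & !48 = 0; for an invalid depth such as 33, (33 - 16) & !48 = 17 & !48 = 1 != 0,
    // so the debug assertion fires.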

    /// Returns a new [LeafNodeIndex] instantiated from the specified key inserted at the specified
    /// depth.
    ///
    /// The value for the key is computed by taking n most significant bits from the most significant
    /// element of the key, where n is the specified depth.
    pub fn from_key(key: &RpoDigest, depth: u8) -> Self {
        let mse = get_key_prefix(key);
        Self::new(NodeIndex::new_unchecked(depth, mse >> (TieredSmt::MAX_DEPTH - depth)))
    }

    /// Returns a new [LeafNodeIndex] instantiated for testing purposes.
    #[cfg(test)]
    pub fn make(depth: u8, value: u64) -> Self {
        Self::new(NodeIndex::make(depth, value))
    }

    /// Traverses towards the root until the specified depth is reached.
    ///
    /// The new depth must be a valid tier depth - i.e., 16, 32, 48, or 64.
    pub fn move_up_to(&mut self, depth: u8) {
        debug_assert_eq!(((depth - 16) & !48), 0, "invalid tier depth: {depth}");
        self.0.move_up_to(depth);
    }
}

impl Deref for LeafNodeIndex {
    type Target = NodeIndex;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl From<NodeIndex> for LeafNodeIndex {
    fn from(value: NodeIndex) -> Self {
        Self::new(value)
    }
}

impl From<LeafNodeIndex> for NodeIndex {
    fn from(value: LeafNodeIndex) -> Self {
        value.0
    }
}

// HELPER FUNCTIONS
// ================================================================================================

/// Returns the value representing the 64 most significant bits of the specified key.
fn get_key_prefix(key: &RpoDigest) -> u64 {
    Word::from(key)[3].as_int()
}

/// Returns the index value shifted to be in the most significant bit positions of the returned
/// u64 value.
fn index_to_prefix(index: &NodeIndex) -> u64 {
    index.value() << (TieredSmt::MAX_DEPTH - index.depth())
}
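
// Example of the shift above: an index at depth 16 with value 0xABCD maps to the prefix
// 0xABCD << 48, i.e., the index value occupies the 16 most significant bits of the returned u64.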

/// Returns the tiered common prefix length between the most significant elements of the provided
/// keys.
///
/// Specifically:
/// - returns 64 if the most significant elements are equal.
/// - returns 48 if the common prefix is between 48 and 63 bits.
/// - returns 32 if the common prefix is between 32 and 47 bits.
/// - returns 16 if the common prefix is between 16 and 31 bits.
/// - returns 0 if the common prefix is fewer than 16 bits.
fn get_common_prefix_tier_depth(key1: &RpoDigest, key2: &RpoDigest) -> u8 {
    let e1 = get_key_prefix(key1);
    let e2 = get_key_prefix(key2);
    let ex = (e1 ^ e2).leading_zeros() as u8;
    (ex / 16) * 16
}
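
// Worked example of the rounding above: if the most significant elements agree on exactly their
// top 37 bits, then ex = 37 and (37 / 16) * 16 = 32, so the keys land on the depth-32 tier.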

/// Computes node value for leaves at tiers 16, 32, or 48.
///
/// Node value is computed as: hash(key || value, domain = depth).
pub fn hash_upper_leaf(key: RpoDigest, value: Word, depth: u8) -> RpoDigest {
    const NUM_UPPER_TIERS: usize = TieredSmt::TIER_DEPTHS.len() - 1;
    debug_assert!(TieredSmt::TIER_DEPTHS[..NUM_UPPER_TIERS].contains(&depth));
    Rpo256::merge_in_domain(&[key, value.into()], depth.into())
}
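
// Note on the domain tag above: passing the depth as the hashing domain is what keeps leaf
// hashes at different tiers, and leaf vs. internal node hashes, distinct on identical inputs,
// per the node-value rules in the [TieredSmt] docs.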

/// Computes node value for leaves at the bottom tier (depth 64).
///
/// Node value is computed as: hash([key_0, value_0, ..., key_n, value_n], domain=64).
///
/// TODO: when hashing in domain is implemented for `hash_elements()`, combine this function with
/// the `hash_upper_leaf()` function.
pub fn hash_bottom_leaf(values: &[(RpoDigest, Word)]) -> RpoDigest {
    let mut elements = Vec::with_capacity(values.len() * 8);
    for (key, val) in values.iter() {
        elements.extend_from_slice(key.as_elements());
        elements.extend_from_slice(val.as_slice());
    }
    // TODO: hash in domain
    Rpo256::hash_elements(&elements)
}
@@ -1,419 +0,0 @@
use super::{
    BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, LeafNodeIndex, MerkleError, MerklePath,
    NodeIndex, Rpo256, RpoDigest, Vec,
};

// CONSTANTS
// ================================================================================================

/// The number of levels between tiers.
const TIER_SIZE: u8 = super::TieredSmt::TIER_SIZE;

/// Depths at which leaves can exist in a tiered SMT.
const TIER_DEPTHS: [u8; 4] = super::TieredSmt::TIER_DEPTHS;

/// Maximum node depth. This is also the bottom tier of the tree.
const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;

// NODE STORE
// ================================================================================================

/// A store of nodes for a Tiered Sparse Merkle tree.
///
/// The store contains information about all nodes as well as information about which of the nodes
/// represent leaf nodes in a Tiered Sparse Merkle tree. In the current implementation, [BTreeSet]s
/// are used to determine the position of the leaves in the tree.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct NodeStore {
    nodes: BTreeMap<NodeIndex, RpoDigest>,
    upper_leaves: BTreeSet<NodeIndex>,
    bottom_leaves: BTreeSet<u64>,
}

impl NodeStore {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------
    /// Returns a new instance of [NodeStore] instantiated with the specified root node.
    ///
    /// Root node is assumed to be a root of an empty sparse Merkle tree.
    pub fn new(root_node: RpoDigest) -> Self {
        let mut nodes = BTreeMap::default();
        nodes.insert(NodeIndex::root(), root_node);

        Self {
            nodes,
            upper_leaves: BTreeSet::default(),
            bottom_leaves: BTreeSet::default(),
        }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns a node at the specified index.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The specified index depth is 0 or greater than 64.
    /// - The node with the specified index does not exist in the Merkle tree. This is possible
    ///   when a leaf node with the same index prefix exists at a tier higher than the requested
    ///   node.
    pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
        self.validate_node_access(index)?;
        Ok(self.get_node_unchecked(&index))
    }

    /// Returns a Merkle path from the node at the specified index to the root.
    ///
    /// The node itself is not included in the path.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The specified index depth is 0 or greater than 64.
    /// - The node with the specified index does not exist in the Merkle tree. This is possible
    ///   when a leaf node with the same index prefix exists at a tier higher than the node to
    ///   which the path is requested.
    pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
        self.validate_node_access(index)?;

        let mut path = Vec::with_capacity(index.depth() as usize);
        for _ in 0..index.depth() {
            let node = self.get_node_unchecked(&index.sibling());
            path.push(node);
            index.move_up();
        }

        Ok(path.into())
    }

    /// Returns a Merkle path to the node specified by the key together with a flag indicating
    /// whether this node is a leaf at depths 16, 32, or 48.
    pub fn get_proof(&self, key: &RpoDigest) -> (MerklePath, NodeIndex, bool) {
        let (index, leaf_exists) = self.get_leaf_index(key);
        let index: NodeIndex = index.into();
        let path = self.get_path(index).expect("failed to retrieve Merkle path for a node index");
        (path, index, leaf_exists)
    }

    /// Returns an index at which a leaf node for the specified key should be inserted.
    ///
    /// The second value in the returned tuple is set to true if the node at the returned index
    /// is already a leaf node.
    pub fn get_leaf_index(&self, key: &RpoDigest) -> (LeafNodeIndex, bool) {
        // traverse the tree from the root down checking nodes at tiers 16, 32, and 48. Return if
        // a node at any of the tiers is either a leaf or a root of an empty subtree.
        const NUM_UPPER_TIERS: usize = TIER_DEPTHS.len() - 1;
        for &tier_depth in TIER_DEPTHS[..NUM_UPPER_TIERS].iter() {
            let index = LeafNodeIndex::from_key(key, tier_depth);
            if self.upper_leaves.contains(&index) {
                return (index, true);
            } else if !self.nodes.contains_key(&index) {
                return (index, false);
            }
        }

        // if we got here, that means all of the nodes checked so far are internal nodes, and
        // the new node would need to be inserted in the bottom tier.
        let index = LeafNodeIndex::from_key(key, MAX_DEPTH);
        (index, self.bottom_leaves.contains(&index.value()))
    }

    /// Traverses the tree up from the bottom tier starting at the specified leaf index and
    /// returns the depth of the first node which has more than one child. The returned depth
    /// is rounded up to the next tier.
    pub fn get_last_single_child_parent_depth(&self, leaf_index: u64) -> u8 {
        let mut index = NodeIndex::new_unchecked(MAX_DEPTH, leaf_index);

        for _ in (TIER_DEPTHS[0]..MAX_DEPTH).rev() {
            let sibling_index = index.sibling();
            if self.nodes.contains_key(&sibling_index) {
                break;
            }
            index.move_up();
        }

        let tier = (index.depth() - 1) / TIER_SIZE;
        TIER_DEPTHS[tier as usize]
    }
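
    // Example of the rounding above: if the loop stops with `index` at depth 50, then
    // (50 - 1) / 16 = 3 and TIER_DEPTHS[3] = 64; stopping exactly at depth 48 gives
    // (48 - 1) / 16 = 2 and TIER_DEPTHS[2] = 48.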

    // ITERATORS
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over all inner nodes of the Tiered Sparse Merkle tree (i.e., nodes not
    /// at depths 16, 32, 48, or 64).
    ///
    /// The iterator order is unspecified.
    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
        self.nodes.iter().filter_map(|(index, node)| {
            if self.is_internal_node(index) {
                Some(InnerNodeInfo {
                    value: *node,
                    left: self.get_node_unchecked(&index.left_child()),
                    right: self.get_node_unchecked(&index.right_child()),
                })
            } else {
                None
            }
        })
    }

    /// Returns an iterator over the upper leaves (i.e., leaves with depths 16, 32, 48) of the
    /// Tiered Sparse Merkle tree.
    pub fn upper_leaves(&self) -> impl Iterator<Item = (&NodeIndex, &RpoDigest)> {
        self.upper_leaves.iter().map(|index| (index, &self.nodes[index]))
    }

    /// Returns an iterator over the bottom leaves (i.e., leaves with depth 64) of the Tiered
    /// Sparse Merkle tree.
    pub fn bottom_leaves(&self) -> impl Iterator<Item = (&u64, &RpoDigest)> {
        self.bottom_leaves.iter().map(|value| {
            let index = NodeIndex::new_unchecked(MAX_DEPTH, *value);
            (value, &self.nodes[&index])
        })
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Replaces the leaf node at the specified index with a tree consisting of two leaves located
    /// at the specified indexes. Recomputes and returns the new root.
    pub fn replace_leaf_with_subtree(
        &mut self,
        leaf_index: LeafNodeIndex,
        subtree_leaves: [(LeafNodeIndex, RpoDigest); 2],
    ) -> RpoDigest {
        debug_assert!(self.is_non_empty_leaf(&leaf_index));
        debug_assert!(!is_empty_root(&subtree_leaves[0].1));
        debug_assert!(!is_empty_root(&subtree_leaves[1].1));
        debug_assert_eq!(subtree_leaves[0].0.depth(), subtree_leaves[1].0.depth());
        debug_assert!(leaf_index.depth() < subtree_leaves[0].0.depth());

        self.upper_leaves.remove(&leaf_index);

        if subtree_leaves[0].0 == subtree_leaves[1].0 {
            // if the subtree is for a single node at depth 64, we only need to insert one node
            debug_assert_eq!(subtree_leaves[0].0.depth(), MAX_DEPTH);
            debug_assert_eq!(subtree_leaves[0].1, subtree_leaves[1].1);
            self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1)
        } else {
            self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1);
            self.insert_leaf_node(subtree_leaves[1].0, subtree_leaves[1].1)
        }
    }

    /// Replaces a subtree containing the retained and the removed leaf nodes, with a leaf node
    /// containing the retained leaf.
    ///
    /// This has the effect of deleting the node at the `removed_leaf` index from the tree,
    /// moving the node at the `retained_leaf` index up to the tier specified by `new_depth`.
    pub fn replace_subtree_with_leaf(
        &mut self,
        removed_leaf: LeafNodeIndex,
        retained_leaf: LeafNodeIndex,
        new_depth: u8,
        node: RpoDigest,
    ) -> RpoDigest {
        debug_assert!(!is_empty_root(&node));
        debug_assert!(self.is_non_empty_leaf(&removed_leaf));
        debug_assert!(self.is_non_empty_leaf(&retained_leaf));
        debug_assert_eq!(removed_leaf.depth(), retained_leaf.depth());
        debug_assert!(removed_leaf.depth() > new_depth);

        // remove the branches leading up to the tier to which the retained leaf is to be moved
        self.remove_branch(removed_leaf, new_depth);
        self.remove_branch(retained_leaf, new_depth);

        // compute the index of the common root for retained and removed leaves
        let mut new_index = retained_leaf;
        new_index.move_up_to(new_depth);

        // insert the node at the root index
        self.insert_leaf_node(new_index, node)
    }

    /// Inserts the specified node at the specified index; recomputes and returns the new root
    /// of the Tiered Sparse Merkle tree.
    ///
    /// This method assumes that the provided node is a non-empty value, and that there is no node
    /// at the specified index.
    pub fn insert_leaf_node(&mut self, index: LeafNodeIndex, mut node: RpoDigest) -> RpoDigest {
        debug_assert!(!is_empty_root(&node));
        debug_assert_eq!(self.nodes.get(&index), None);

        // mark the node as the leaf
        if index.depth() == MAX_DEPTH {
            self.bottom_leaves.insert(index.value());
        } else {
            self.upper_leaves.insert(index.into());
        };

        // insert the node and update the path from the node to the root
        let mut index: NodeIndex = index.into();
        for _ in 0..index.depth() {
            self.nodes.insert(index, node);
            let sibling = self.get_node_unchecked(&index.sibling());
            node = Rpo256::merge(&index.build_node(node, sibling));
            index.move_up();
        }

        // update the root
        self.nodes.insert(NodeIndex::root(), node);
        node
    }

    /// Updates the node at the specified index with the specified node value; recomputes and
    /// returns the new root of the Tiered Sparse Merkle tree.
    ///
    /// This method can accept `node` as either an empty or a non-empty value.
    pub fn update_leaf_node(&mut self, index: LeafNodeIndex, mut node: RpoDigest) -> RpoDigest {
        debug_assert!(self.is_non_empty_leaf(&index));

        // if the value we are updating the node to is a root of an empty tree, clear the leaf
        // flag for this node
        if node == EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize] {
            if index.depth() == MAX_DEPTH {
                self.bottom_leaves.remove(&index.value());
            } else {
                self.upper_leaves.remove(&index);
            }
        } else {
            debug_assert!(!is_empty_root(&node));
        }

        // update the path from the node to the root
        let mut index: NodeIndex = index.into();
        for _ in 0..index.depth() {
            if node == EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize] {
                self.nodes.remove(&index);
            } else {
                self.nodes.insert(index, node);
            }

            let sibling = self.get_node_unchecked(&index.sibling());
            node = Rpo256::merge(&index.build_node(node, sibling));
            index.move_up();
        }

        // update the root
        self.nodes.insert(NodeIndex::root(), node);
        node
    }
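
    // Note on the empty-value handling above: updating a leaf to the empty-subtree root for its
    // depth clears the leaf flag and prunes the stored path, which is exactly how
    // clear_leaf_node() below is implemented.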

    /// Replaces the leaf node at the specified index with a root of an empty subtree; recomputes
    /// and returns the new root of the Tiered Sparse Merkle tree.
    pub fn clear_leaf_node(&mut self, index: LeafNodeIndex) -> RpoDigest {
        debug_assert!(self.is_non_empty_leaf(&index));
        let node = EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize];
        self.update_leaf_node(index, node)
    }

    /// Truncates a branch starting with specified leaf at the bottom tier to new depth.
    ///
    /// This involves removing the part of the branch below the new depth, and then inserting a new
    /// node at the new depth.
    pub fn truncate_branch(
        &mut self,
        leaf_index: u64,
        new_depth: u8,
        node: RpoDigest,
    ) -> RpoDigest {
        debug_assert!(self.bottom_leaves.contains(&leaf_index));

        let mut leaf_index = LeafNodeIndex::new(NodeIndex::new_unchecked(MAX_DEPTH, leaf_index));
        self.remove_branch(leaf_index, new_depth);

        leaf_index.move_up_to(new_depth);
        self.insert_leaf_node(leaf_index, node)
    }

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Returns true if the node at the specified index is a leaf node.
    fn is_non_empty_leaf(&self, index: &LeafNodeIndex) -> bool {
        if index.depth() == MAX_DEPTH {
            self.bottom_leaves.contains(&index.value())
        } else {
            self.upper_leaves.contains(index)
        }
    }

    /// Returns true if the node at the specified index is an internal node - i.e., there is
    /// no leaf at that node and the node does not belong to the bottom tier.
    fn is_internal_node(&self, index: &NodeIndex) -> bool {
        if index.depth() == MAX_DEPTH {
            false
        } else {
            !self.upper_leaves.contains(index)
        }
    }

    /// Checks if the specified index is valid in the context of this Merkle tree.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The specified index depth is 0 or greater than 64.
    /// - The node for the specified index does not exist in the Merkle tree. This is possible
    ///   when an ancestor of the specified index is a leaf node.
    fn validate_node_access(&self, index: NodeIndex) -> Result<(), MerkleError> {
        if index.is_root() {
            return Err(MerkleError::DepthTooSmall(index.depth()));
        } else if index.depth() > MAX_DEPTH {
            return Err(MerkleError::DepthTooBig(index.depth() as u64));
        } else {
            // make sure that there are no leaf nodes in the ancestors of the index; since leaf
            // nodes can live only at specific depths, we just need to check these depths.
            let tier = ((index.depth() - 1) / TIER_SIZE) as usize;
            let mut tier_index = index;
            for &depth in TIER_DEPTHS[..tier].iter().rev() {
                tier_index.move_up_to(depth);
                if self.upper_leaves.contains(&tier_index) {
                    return Err(MerkleError::NodeNotInSet(index));
                }
            }
        }

        Ok(())
    }

    /// Returns a node at the specified index. If the node does not exist at this index, a root
    /// for an empty subtree at the index's depth is returned.
    ///
    /// Unlike [NodeStore::get_node()] this does not perform any checks to verify that the
    /// returned node is valid in the context of this tree.
    fn get_node_unchecked(&self, index: &NodeIndex) -> RpoDigest {
        match self.nodes.get(index) {
            Some(node) => *node,
            None => EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize],
        }
    }

    /// Removes a sequence of nodes starting at the specified index and traversing the tree up to
    /// the specified depth. The node at the `end_depth` is also removed, and the appropriate leaf
    /// flag is cleared.
    ///
    /// This method does not update any other nodes and does not recompute the tree root.
    fn remove_branch(&mut self, index: LeafNodeIndex, end_depth: u8) {
        if index.depth() == MAX_DEPTH {
            self.bottom_leaves.remove(&index.value());
        } else {
            self.upper_leaves.remove(&index);
        }

        let mut index: NodeIndex = index.into();
        assert!(index.depth() > end_depth);
        for _ in 0..(index.depth() - end_depth + 1) {
            self.nodes.remove(&index);
            index.move_up()
        }
    }
}

// HELPER FUNCTIONS
// ================================================================================================

/// Returns true if the specified node is a root of an empty tree or an empty value ([ZERO; 4]).
fn is_empty_root(node: &RpoDigest) -> bool {
    EmptySubtreeRoots::empty_hashes(MAX_DEPTH).contains(node)
}
@@ -1,162 +0,0 @@
use super::{
    get_common_prefix_tier_depth, get_key_prefix, hash_bottom_leaf, hash_upper_leaf,
    EmptySubtreeRoots, LeafNodeIndex, MerklePath, RpoDigest, TieredSmtProofError, Vec, Word,
};

// CONSTANTS
// ================================================================================================

/// Maximum node depth. This is also the bottom tier of the tree.
const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;

/// Value of an empty leaf.
pub const EMPTY_VALUE: Word = super::TieredSmt::EMPTY_VALUE;

/// Depths at which leaves can exist in a tiered SMT.
pub const TIER_DEPTHS: [u8; 4] = super::TieredSmt::TIER_DEPTHS;

// TIERED SPARSE MERKLE TREE PROOF
// ================================================================================================

/// A proof which can be used to assert membership (or non-membership) of key-value pairs in a
/// Tiered Sparse Merkle tree.
///
/// The proof consists of a Merkle path and one or more key-value entries which describe the node
/// located at the base of the path. If the node at the base of the path resolves to [ZERO; 4],
/// the entries will contain a single item with value set to [ZERO; 4].
#[derive(PartialEq, Eq, Debug)]
pub struct TieredSmtProof {
    path: MerklePath,
    entries: Vec<(RpoDigest, Word)>,
}

impl TieredSmtProof {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------
    /// Returns a new instance of [TieredSmtProof] instantiated from the specified path and entries.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The length of the path is greater than 64.
    /// - Entries is an empty vector.
    /// - Entries contains more than 1 item, but the length of the path is not 64.
    /// - Entries contains more than 1 item, and one of the items has value set to [ZERO; 4].
    /// - Entries contains multiple items with keys which don't share the same 64-bit prefix.
    pub fn new<I>(path: MerklePath, entries: I) -> Result<Self, TieredSmtProofError>
    where
        I: IntoIterator<Item = (RpoDigest, Word)>,
    {
        let entries: Vec<(RpoDigest, Word)> = entries.into_iter().collect();

        if !TIER_DEPTHS.into_iter().any(|e| e == path.depth()) {
            return Err(TieredSmtProofError::NotATierPath(path.depth()));
        }

        if entries.is_empty() {
            return Err(TieredSmtProofError::EntriesEmpty);
        }

        if entries.len() > 1 {
            if path.depth() != MAX_DEPTH {
                return Err(TieredSmtProofError::MultipleEntriesOutsideLastTier);
            }

            let prefix = get_key_prefix(&entries[0].0);
            for entry in entries.iter().skip(1) {
                if entry.1 == EMPTY_VALUE {
                    return Err(TieredSmtProofError::EmptyValueNotAllowed);
                }
                let current = get_key_prefix(&entry.0);
                if prefix != current {
                    return Err(TieredSmtProofError::MismatchedPrefixes(prefix, current));
                }
            }
        }

        Ok(Self { path, entries })
    }

    // PROOF VERIFIER
    // --------------------------------------------------------------------------------------------

    /// Returns true if a Tiered Sparse Merkle tree with the specified root contains the provided
    /// key-value pair.
    ///
    /// Note: this method cannot be used to assert non-membership. That is, if false is returned,
    /// it does not mean that the provided key-value pair is not in the tree.
    pub fn verify_membership(&self, key: &RpoDigest, value: &Word, root: &RpoDigest) -> bool {
        if self.is_value_empty() {
            if value != &EMPTY_VALUE {
                return false;
            }
            // if the proof is for an empty value, we can verify it against any key which has a
            // common prefix with the key stored in entries, but the prefix must be greater than
            // the path length
            let common_prefix_tier = get_common_prefix_tier_depth(key, &self.entries[0].0);
            if common_prefix_tier < self.path.depth() {
                return false;
            }
        } else if !self.entries.contains(&(*key, *value)) {
            return false;
        }

        // make sure the Merkle path resolves to the correct root
        root == &self.compute_root()
    }
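
    // Illustrative non-membership check (hypothetical values): a proof whose entries hold
    // (key_a, EMPTY_VALUE) with a depth-16 path also verifies for any other key sharing
    // key_a's 16-bit prefix, since get_common_prefix_tier_depth() returns at least 16 there.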

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the value associated with the specific key according to this proof, or None if
    /// this proof does not contain a value for the specified key.
    ///
    /// A key-value pair generated by using this method should pass the `verify_membership()` check.
    pub fn get(&self, key: &RpoDigest) -> Option<Word> {
        if self.is_value_empty() {
            let common_prefix_tier = get_common_prefix_tier_depth(key, &self.entries[0].0);
            if common_prefix_tier < self.path.depth() {
                None
            } else {
                Some(EMPTY_VALUE)
            }
        } else {
            self.entries.iter().find(|(k, _)| k == key).map(|(_, value)| *value)
        }
    }

    /// Computes the root of a Tiered Sparse Merkle tree to which this proof resolves.
    pub fn compute_root(&self) -> RpoDigest {
        let node = self.build_node();
        let index = LeafNodeIndex::from_key(&self.entries[0].0, self.path.depth());
        self.path
            .compute_root(index.value(), node)
            .expect("failed to compute Merkle path root")
    }

    /// Consumes the proof and returns its parts.
    pub fn into_parts(self) -> (MerklePath, Vec<(RpoDigest, Word)>) {
        (self.path, self.entries)
    }

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Returns true if the proof is for an empty value.
    fn is_value_empty(&self) -> bool {
        self.entries[0].1 == EMPTY_VALUE
    }

    /// Converts the entries contained in this proof into a node value for the node at the base of
    /// the path contained in this proof.
    fn build_node(&self) -> RpoDigest {
        let depth = self.path.depth();
        if self.is_value_empty() {
            EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[depth as usize]
        } else if depth == MAX_DEPTH {
            hash_bottom_leaf(&self.entries)
        } else {
            let (key, value) = self.entries[0];
            hash_upper_leaf(key, value, depth)
        }
    }
}
@@ -1,936 +0,0 @@
|
||||
use super::{
|
||||
super::{super::ONE, super::WORD_SIZE, Felt, MerkleStore, EMPTY_WORD, ZERO},
|
||||
EmptySubtreeRoots, InnerNodeInfo, NodeIndex, Rpo256, RpoDigest, TieredSmt, Vec, Word,
|
||||
};
|
||||
|
||||
// INSERTION TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[test]
|
||||
fn tsmt_insert_one() {
|
||||
let mut smt = TieredSmt::default();
|
||||
let mut store = MerkleStore::default();
|
||||
|
||||
let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
|
||||
let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
|
||||
let value = [ONE; WORD_SIZE];
|
||||
|
||||
// since the tree is empty, the first node will be inserted at depth 16 and the index will be
|
||||
// 16 most significant bits of the key
|
||||
let index = NodeIndex::make(16, raw >> 48);
|
||||
let leaf_node = build_leaf_node(key, value, 16);
|
||||
let tree_root = store.set_node(smt.root(), index, leaf_node).unwrap().root;
|
||||
|
||||
smt.insert(key, value);
|
||||
|
||||
assert_eq!(smt.root(), tree_root);
|
||||
|
||||
// make sure the value was inserted, and the node is at the expected index
|
||||
assert_eq!(smt.get_value(key), value);
|
||||
assert_eq!(smt.get_node(index).unwrap(), leaf_node);
|
||||
|
||||
// make sure the paths we get from the store and the tree match
|
||||
let expected_path = store.get_path(tree_root, index).unwrap();
|
||||
assert_eq!(smt.get_path(index).unwrap(), expected_path.path);
|
||||
|
||||
// make sure inner nodes match
|
||||
let expected_nodes = get_non_empty_nodes(&store);
|
||||
let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
|
||||
assert_eq!(actual_nodes.len(), expected_nodes.len());
|
||||
actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));
|
||||
|
||||
// make sure leaves are returned correctly
|
||||
let mut leaves = smt.upper_leaves();
|
||||
assert_eq!(leaves.next(), Some((leaf_node, key, value)));
|
||||
assert_eq!(leaves.next(), None);
|
||||
}
|
||||
|
||||
#[test]
fn tsmt_insert_two_16() {
    let mut smt = TieredSmt::default();
    let mut store = MerkleStore::default();

    // --- insert the first value ---------------------------------------------
    let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let val_a = [ONE; WORD_SIZE];
    smt.insert(key_a, val_a);

    // --- insert the second value --------------------------------------------
    // the key for this value has the same 16-bit prefix as the key for the first value;
    // thus, on insertion, both values should be pushed to the depth 32 tier
    let raw_b = 0b_10101010_10101010_10011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let val_b = [Felt::new(2); WORD_SIZE];
    smt.insert(key_b, val_b);

    // --- build Merkle store with equivalent data ----------------------------
    let mut tree_root = get_init_root();
    let index_a = NodeIndex::make(32, raw_a >> 32);
    let leaf_node_a = build_leaf_node(key_a, val_a, 32);
    tree_root = store.set_node(tree_root, index_a, leaf_node_a).unwrap().root;

    let index_b = NodeIndex::make(32, raw_b >> 32);
    let leaf_node_b = build_leaf_node(key_b, val_b, 32);
    tree_root = store.set_node(tree_root, index_b, leaf_node_b).unwrap().root;

    // --- verify that data is consistent between store and tree --------------

    assert_eq!(smt.root(), tree_root);

    assert_eq!(smt.get_value(key_a), val_a);
    assert_eq!(smt.get_node(index_a).unwrap(), leaf_node_a);
    let expected_path = store.get_path(tree_root, index_a).unwrap().path;
    assert_eq!(smt.get_path(index_a).unwrap(), expected_path);

    assert_eq!(smt.get_value(key_b), val_b);
    assert_eq!(smt.get_node(index_b).unwrap(), leaf_node_b);
    let expected_path = store.get_path(tree_root, index_b).unwrap().path;
    assert_eq!(smt.get_path(index_b).unwrap(), expected_path);

    // make sure inner nodes match - the store contains more entries because it keeps track of
    // all prior state - so, we don't check that the number of inner nodes is the same in both
    let expected_nodes = get_non_empty_nodes(&store);
    let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
    actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));

    // make sure leaves are returned correctly
    let mut leaves = smt.upper_leaves();
    assert_eq!(leaves.next(), Some((leaf_node_a, key_a, val_a)));
    assert_eq!(leaves.next(), Some((leaf_node_b, key_b, val_b)));
    assert_eq!(leaves.next(), None);
}

#[test]
fn tsmt_insert_two_32() {
    let mut smt = TieredSmt::default();
    let mut store = MerkleStore::default();

    // --- insert the first value ---------------------------------------------
    let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let val_a = [ONE; WORD_SIZE];
    smt.insert(key_a, val_a);

    // --- insert the second value --------------------------------------------
    // the key for this value has the same 32-bit prefix as the key for the first value;
    // thus, on insertion, both values should be pushed to the depth 48 tier
    let raw_b = 0b_10101010_10101010_00011111_11111111_00010110_10010011_11100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let val_b = [Felt::new(2); WORD_SIZE];
    smt.insert(key_b, val_b);

    // --- build Merkle store with equivalent data ----------------------------
    let mut tree_root = get_init_root();
    let index_a = NodeIndex::make(48, raw_a >> 16);
    let leaf_node_a = build_leaf_node(key_a, val_a, 48);
    tree_root = store.set_node(tree_root, index_a, leaf_node_a).unwrap().root;

    let index_b = NodeIndex::make(48, raw_b >> 16);
    let leaf_node_b = build_leaf_node(key_b, val_b, 48);
    tree_root = store.set_node(tree_root, index_b, leaf_node_b).unwrap().root;

    // --- verify that data is consistent between store and tree --------------

    assert_eq!(smt.root(), tree_root);

    assert_eq!(smt.get_value(key_a), val_a);
    assert_eq!(smt.get_node(index_a).unwrap(), leaf_node_a);
    let expected_path = store.get_path(tree_root, index_a).unwrap().path;
    assert_eq!(smt.get_path(index_a).unwrap(), expected_path);

    assert_eq!(smt.get_value(key_b), val_b);
    assert_eq!(smt.get_node(index_b).unwrap(), leaf_node_b);
    let expected_path = store.get_path(tree_root, index_b).unwrap().path;
    assert_eq!(smt.get_path(index_b).unwrap(), expected_path);

    // make sure inner nodes match - the store contains more entries because it keeps track of
    // all prior state - so, we don't check that the number of inner nodes is the same in both
    let expected_nodes = get_non_empty_nodes(&store);
    let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
    actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));
}

#[test]
fn tsmt_insert_three() {
    let mut smt = TieredSmt::default();
    let mut store = MerkleStore::default();

    // --- insert the first value ---------------------------------------------
    let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let val_a = [ONE; WORD_SIZE];
    smt.insert(key_a, val_a);

    // --- insert the second value --------------------------------------------
    // the key for this value has the same 16-bit prefix as the key for the first value;
    // thus, on insertion, both values should be pushed to the depth 32 tier
    let raw_b = 0b_10101010_10101010_10011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let val_b = [Felt::new(2); WORD_SIZE];
    smt.insert(key_b, val_b);

    // --- insert the third value ---------------------------------------------
    // the key for this value has the same 16-bit prefix as the keys for the first two
    // values; thus, on insertion, it will be inserted into the depth 32 tier, but will not
    // affect the locations of the other two values
    let raw_c = 0b_10101010_10101010_11011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    let val_c = [Felt::new(3); WORD_SIZE];
    smt.insert(key_c, val_c);

    // --- build Merkle store with equivalent data ----------------------------
    let mut tree_root = get_init_root();
    let index_a = NodeIndex::make(32, raw_a >> 32);
    let leaf_node_a = build_leaf_node(key_a, val_a, 32);
    tree_root = store.set_node(tree_root, index_a, leaf_node_a).unwrap().root;

    let index_b = NodeIndex::make(32, raw_b >> 32);
    let leaf_node_b = build_leaf_node(key_b, val_b, 32);
    tree_root = store.set_node(tree_root, index_b, leaf_node_b).unwrap().root;

    let index_c = NodeIndex::make(32, raw_c >> 32);
    let leaf_node_c = build_leaf_node(key_c, val_c, 32);
    tree_root = store.set_node(tree_root, index_c, leaf_node_c).unwrap().root;

    // --- verify that data is consistent between store and tree --------------

    assert_eq!(smt.root(), tree_root);

    assert_eq!(smt.get_value(key_a), val_a);
    assert_eq!(smt.get_node(index_a).unwrap(), leaf_node_a);
    let expected_path = store.get_path(tree_root, index_a).unwrap().path;
    assert_eq!(smt.get_path(index_a).unwrap(), expected_path);

    assert_eq!(smt.get_value(key_b), val_b);
    assert_eq!(smt.get_node(index_b).unwrap(), leaf_node_b);
    let expected_path = store.get_path(tree_root, index_b).unwrap().path;
    assert_eq!(smt.get_path(index_b).unwrap(), expected_path);

    assert_eq!(smt.get_value(key_c), val_c);
    assert_eq!(smt.get_node(index_c).unwrap(), leaf_node_c);
    let expected_path = store.get_path(tree_root, index_c).unwrap().path;
    assert_eq!(smt.get_path(index_c).unwrap(), expected_path);

    // make sure inner nodes match - the store contains more entries because it keeps track of
    // all prior state - so, we don't check that the number of inner nodes is the same in both
    let expected_nodes = get_non_empty_nodes(&store);
    let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
    actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));
}

// UPDATE TESTS
// ================================================================================================

#[test]
fn tsmt_update() {
    let mut smt = TieredSmt::default();
    let mut store = MerkleStore::default();

    // --- insert a value into the tree ---------------------------------------
    let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
    let value_a = [ONE; WORD_SIZE];
    smt.insert(key, value_a);

    // --- update the value ---------------------------------------------------
    let value_b = [Felt::new(2); WORD_SIZE];
    smt.insert(key, value_b);

    // --- verify consistency -------------------------------------------------
    let mut tree_root = get_init_root();
    let index = NodeIndex::make(16, raw >> 48);
    let leaf_node = build_leaf_node(key, value_b, 16);
    tree_root = store.set_node(tree_root, index, leaf_node).unwrap().root;

    assert_eq!(smt.root(), tree_root);

    assert_eq!(smt.get_value(key), value_b);
    assert_eq!(smt.get_node(index).unwrap(), leaf_node);
    let expected_path = store.get_path(tree_root, index).unwrap().path;
    assert_eq!(smt.get_path(index).unwrap(), expected_path);

    // make sure inner nodes match - the store contains more entries because it keeps track of
    // all prior state - so, we don't check that the number of inner nodes is the same in both
    let expected_nodes = get_non_empty_nodes(&store);
    let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
    actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));
}

// DELETION TESTS
// ================================================================================================

#[test]
fn tsmt_delete_16() {
    let mut smt = TieredSmt::default();

    // --- insert a value into the tree ---------------------------------------
    let smt0 = smt.clone();
    let raw_a = 0b_01010101_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // --- insert another value into the tree ---------------------------------
    let smt1 = smt.clone();
    let raw_b = 0b_01011111_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let value_b = [ONE, ONE, ONE, ZERO];
    smt.insert(key_b, value_b);

    // --- delete the last inserted value -------------------------------------
    assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b);
    assert_eq!(smt, smt1);

    // --- delete the first inserted value ------------------------------------
    assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a);
    assert_eq!(smt, smt0);
}

#[test]
fn tsmt_delete_32() {
    let mut smt = TieredSmt::default();

    // --- insert a value into the tree ---------------------------------------
    let smt0 = smt.clone();
    let raw_a = 0b_01010101_01101100_01111111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // --- insert another with the same 16-bit prefix into the tree -----------
    let smt1 = smt.clone();
    let raw_b = 0b_01010101_01101100_00111111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let value_b = [ONE, ONE, ONE, ZERO];
    smt.insert(key_b, value_b);

    // --- insert the 3rd value with the same 16-bit prefix into the tree -----
    let smt2 = smt.clone();
    let raw_c = 0b_01010101_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    let value_c = [ONE, ONE, ZERO, ZERO];
    smt.insert(key_c, value_c);

    // --- delete the last inserted value -------------------------------------
    assert_eq!(smt.insert(key_c, EMPTY_WORD), value_c);
    assert_eq!(smt, smt2);

    // --- delete the second inserted value -----------------------------------
    assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b);
    assert_eq!(smt, smt1);

    // --- delete the first inserted value ------------------------------------
    assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a);
    assert_eq!(smt, smt0);
}

#[test]
fn tsmt_delete_48_same_32_bit_prefix() {
    let mut smt = TieredSmt::default();

    // test the case when all values share the same 32-bit prefix

    // --- insert a value into the tree ---------------------------------------
    let smt0 = smt.clone();
    let raw_a = 0b_01010101_01010101_11111111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // --- insert another with the same 32-bit prefix into the tree -----------
    let smt1 = smt.clone();
    let raw_b = 0b_01010101_01010101_11111111_11111111_11010110_10010011_11100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let value_b = [ONE, ONE, ONE, ZERO];
    smt.insert(key_b, value_b);

    // --- insert the 3rd value with the same 32-bit prefix into the tree -----
    let smt2 = smt.clone();
    let raw_c = 0b_01010101_01010101_11111111_11111111_11110110_10010011_11100000_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    let value_c = [ONE, ONE, ZERO, ZERO];
    smt.insert(key_c, value_c);

    // --- delete the last inserted value -------------------------------------
    assert_eq!(smt.insert(key_c, EMPTY_WORD), value_c);
    assert_eq!(smt, smt2);

    // --- delete the second inserted value -----------------------------------
    assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b);
    assert_eq!(smt, smt1);

    // --- delete the first inserted value ------------------------------------
    assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a);
    assert_eq!(smt, smt0);
}

#[test]
fn tsmt_delete_48_mixed_prefix() {
    let mut smt = TieredSmt::default();

    // test the case when some values share a 32-bit prefix and others share a 16-bit prefix

    // --- insert a value into the tree ---------------------------------------
    let smt0 = smt.clone();
    let raw_a = 0b_01010101_01010101_11111111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // --- insert another with the same 16-bit prefix into the tree -----------
    let smt1 = smt.clone();
    let raw_b = 0b_01010101_01010101_01111111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let value_b = [ONE, ONE, ONE, ZERO];
    smt.insert(key_b, value_b);

    // --- insert a value with the same 32-bit prefix as the first value -----
    let smt2 = smt.clone();
    let raw_c = 0b_01010101_01010101_11111111_11111111_11010110_10010011_11100000_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    let value_c = [ONE, ONE, ZERO, ZERO];
    smt.insert(key_c, value_c);

    // --- insert another value with the same 32-bit prefix as the first value
    let smt3 = smt.clone();
    let raw_d = 0b_01010101_01010101_11111111_11111111_11110110_10010011_11100000_00000000_u64;
    let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_d)]);
    let value_d = [ONE, ZERO, ZERO, ZERO];
    smt.insert(key_d, value_d);

    // --- delete the inserted values one-by-one ------------------------------
    assert_eq!(smt.insert(key_d, EMPTY_WORD), value_d);
    assert_eq!(smt, smt3);

    assert_eq!(smt.insert(key_c, EMPTY_WORD), value_c);
    assert_eq!(smt, smt2);

    assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b);
    assert_eq!(smt, smt1);

    assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a);
    assert_eq!(smt, smt0);
}

#[test]
fn tsmt_delete_64() {
    let mut smt = TieredSmt::default();

    // test the case when values share 48-bit, 32-bit, and full 64-bit key prefixes

    // --- insert a value into the tree ---------------------------------------
    let smt0 = smt.clone();
    let raw_a = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // --- insert a value with the same 48-bit prefix into the tree -----------
    let smt1 = smt.clone();
    let raw_b = 0b_01010101_01010101_11111111_11111111_10110101_10101010_10111100_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let value_b = [ONE, ONE, ONE, ZERO];
    smt.insert(key_b, value_b);

    // --- insert a value with the same 32-bit prefix into the tree -----------
    let smt2 = smt.clone();
    let raw_c = 0b_01010101_01010101_11111111_11111111_11111101_10101010_10111100_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    let value_c = [ONE, ONE, ZERO, ZERO];
    smt.insert(key_c, value_c);

    // --- insert a value with the same 64-bit prefix as the first value ------
    let smt3 = smt.clone();
    let raw_d = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000000_u64;
    let key_d = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_d)]);
    let value_d = [ONE, ZERO, ZERO, ZERO];
    smt.insert(key_d, value_d);

    // --- delete the inserted values one-by-one (in reverse order) -----------
    assert_eq!(smt.insert(key_d, EMPTY_WORD), value_d);
    assert_eq!(smt, smt3);

    assert_eq!(smt.insert(key_c, EMPTY_WORD), value_c);
    assert_eq!(smt, smt2);

    assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b);
    assert_eq!(smt, smt1);

    assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a);
    assert_eq!(smt, smt0);
}

#[test]
fn tsmt_delete_64_leaf_promotion() {
    let mut smt = TieredSmt::default();

    // --- delete from bottom tier (no promotion to upper tiers) --------------

    // insert a value into the tree
    let raw_a = 0b_01010101_01010101_11111111_11111111_10101010_10101010_11111111_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // insert another value with a key having the same 64-bit prefix
    let key_b = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
    let value_b = [ONE, ONE, ONE, ZERO];
    smt.insert(key_b, value_b);

    // insert a value with a key which shares the same 48-bit prefix
    let raw_c = 0b_01010101_01010101_11111111_11111111_10101010_10101010_00111111_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    let value_c = [ONE, ONE, ZERO, ZERO];
    smt.insert(key_c, value_c);

    // delete entry A and compare to the tree which was built from B and C
    smt.insert(key_a, EMPTY_WORD);

    let mut expected_smt = TieredSmt::default();
    expected_smt.insert(key_b, value_b);
    expected_smt.insert(key_c, value_c);
    assert_eq!(smt, expected_smt);

    // entries B and C should stay at depth 64
    assert_eq!(smt.nodes.get_leaf_index(&key_b).0.depth(), 64);
    assert_eq!(smt.nodes.get_leaf_index(&key_c).0.depth(), 64);

    // --- delete from bottom tier (promotion to depth 48) --------------------

    let mut smt = TieredSmt::default();
    smt.insert(key_a, value_a);
    smt.insert(key_b, value_b);

    // insert a value with a key which shares the same 32-bit prefix
    let raw_c = 0b_01010101_01010101_11111111_11111111_11101010_10101010_11111111_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    smt.insert(key_c, value_c);

    // delete entry A and compare to the tree which was built from B and C
    smt.insert(key_a, EMPTY_WORD);

    let mut expected_smt = TieredSmt::default();
    expected_smt.insert(key_b, value_b);
    expected_smt.insert(key_c, value_c);
    assert_eq!(smt, expected_smt);

    // entry B moves to depth 48, entry C stays at depth 48
    assert_eq!(smt.nodes.get_leaf_index(&key_b).0.depth(), 48);
    assert_eq!(smt.nodes.get_leaf_index(&key_c).0.depth(), 48);

    // --- delete from bottom tier (promotion to depth 32) --------------------

    let mut smt = TieredSmt::default();
    smt.insert(key_a, value_a);
    smt.insert(key_b, value_b);

    // insert a value with a key which shares the same 16-bit prefix
    let raw_c = 0b_01010101_01010101_01111111_11111111_10101010_10101010_11111111_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    smt.insert(key_c, value_c);

    // delete entry A and compare to the tree which was built from B and C
    smt.insert(key_a, EMPTY_WORD);

    let mut expected_smt = TieredSmt::default();
    expected_smt.insert(key_b, value_b);
    expected_smt.insert(key_c, value_c);
    assert_eq!(smt, expected_smt);

    // entry B moves to depth 32, entry C stays at depth 32
    assert_eq!(smt.nodes.get_leaf_index(&key_b).0.depth(), 32);
    assert_eq!(smt.nodes.get_leaf_index(&key_c).0.depth(), 32);

    // --- delete from bottom tier (promotion to depth 16) --------------------

    let mut smt = TieredSmt::default();
    smt.insert(key_a, value_a);
    smt.insert(key_b, value_b);

    // insert a value with a key which shares a prefix shorter than 16 bits
    let raw_c = 0b_01010101_01010100_11111111_11111111_10101010_10101010_11111111_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    smt.insert(key_c, value_c);

    // delete entry A and compare to the tree which was built from B and C
    smt.insert(key_a, EMPTY_WORD);

    let mut expected_smt = TieredSmt::default();
    expected_smt.insert(key_b, value_b);
    expected_smt.insert(key_c, value_c);
    assert_eq!(smt, expected_smt);

    // entry B moves to depth 16, entry C stays at depth 16
    assert_eq!(smt.nodes.get_leaf_index(&key_b).0.depth(), 16);
    assert_eq!(smt.nodes.get_leaf_index(&key_c).0.depth(), 16);
}

#[test]
fn test_order_sensitivity() {
    let raw = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000001_u64;
    let value = [ONE; WORD_SIZE];

    let key_1 = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
    let key_2 = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw)]);

    let mut smt_1 = TieredSmt::default();

    smt_1.insert(key_1, value);
    smt_1.insert(key_2, value);
    smt_1.insert(key_2, EMPTY_WORD);

    let mut smt_2 = TieredSmt::default();
    smt_2.insert(key_1, value);

    assert_eq!(smt_1.root(), smt_2.root());
}

// BOTTOM TIER TESTS
// ================================================================================================

#[test]
fn tsmt_bottom_tier() {
    let mut smt = TieredSmt::default();
    let mut store = MerkleStore::default();

    // common prefix for the keys
    let prefix = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;

    // --- insert the first value ---------------------------------------------
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(prefix)]);
    let val_a = [ONE; WORD_SIZE];
    smt.insert(key_a, val_a);

    // --- insert the second value --------------------------------------------
    // this key has the same 64-bit prefix and thus both values should end up in the same
    // node at depth 64
    let key_b = RpoDigest::from([ZERO, ONE, ONE, Felt::new(prefix)]);
    let val_b = [Felt::new(2); WORD_SIZE];
    smt.insert(key_b, val_b);

    // --- build Merkle store with equivalent data ----------------------------
    let index = NodeIndex::make(64, prefix);
    // to build the bottom leaf we sort the entries by key; keys are compared element-by-element
    // starting from the most significant element, and since the keys here differ only in their
    // first element (ZERO vs ONE), key_b is smaller than key_a
    let leaf_node = build_bottom_leaf_node(&[key_b, key_a], &[val_b, val_a]);
    let mut tree_root = get_init_root();
    tree_root = store.set_node(tree_root, index, leaf_node).unwrap().root;

    // --- verify that data is consistent between store and tree --------------

    assert_eq!(smt.root(), tree_root);

    assert_eq!(smt.get_value(key_a), val_a);
    assert_eq!(smt.get_value(key_b), val_b);

    assert_eq!(smt.get_node(index).unwrap(), leaf_node);
    let expected_path = store.get_path(tree_root, index).unwrap().path;
    assert_eq!(smt.get_path(index).unwrap(), expected_path);

    // make sure inner nodes match - the store contains more entries because it keeps track of
    // all prior state - so, we don't check that the number of inner nodes is the same in both
    let expected_nodes = get_non_empty_nodes(&store);
    let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
    actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));

    // make sure leaves are returned correctly
    let smt_clone = smt.clone();
    let mut leaves = smt_clone.bottom_leaves();
    assert_eq!(leaves.next(), Some((leaf_node, vec![(key_b, val_b), (key_a, val_a)])));
    assert_eq!(leaves.next(), None);

    // --- update a leaf at the bottom tier -------------------------------------------------------

    let val_a2 = [Felt::new(3); WORD_SIZE];
    assert_eq!(smt.insert(key_a, val_a2), val_a);

    let leaf_node = build_bottom_leaf_node(&[key_b, key_a], &[val_b, val_a2]);
    store.set_node(tree_root, index, leaf_node).unwrap();

    let expected_nodes = get_non_empty_nodes(&store);
    let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
    actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));

    let mut leaves = smt.bottom_leaves();
    assert_eq!(leaves.next(), Some((leaf_node, vec![(key_b, val_b), (key_a, val_a2)])));
    assert_eq!(leaves.next(), None);
}

#[test]
fn tsmt_bottom_tier_two() {
    let mut smt = TieredSmt::default();
    let mut store = MerkleStore::default();

    // --- insert the first value ---------------------------------------------
    let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let val_a = [ONE; WORD_SIZE];
    smt.insert(key_a, val_a);

    // --- insert the second value --------------------------------------------
    // the key for this value has the same 48-bit prefix as the key for the first value;
    // thus, on insertion, the two values should end up in different nodes at depth 64
    let raw_b = 0b_10101010_10101010_00011111_11111111_10010110_10010011_01100000_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let val_b = [Felt::new(2); WORD_SIZE];
    smt.insert(key_b, val_b);

    // --- build Merkle store with equivalent data ----------------------------
    let mut tree_root = get_init_root();
    let index_a = NodeIndex::make(64, raw_a);
    let leaf_node_a = build_bottom_leaf_node(&[key_a], &[val_a]);
    tree_root = store.set_node(tree_root, index_a, leaf_node_a).unwrap().root;

    let index_b = NodeIndex::make(64, raw_b);
    let leaf_node_b = build_bottom_leaf_node(&[key_b], &[val_b]);
    tree_root = store.set_node(tree_root, index_b, leaf_node_b).unwrap().root;

    // --- verify that data is consistent between store and tree --------------

    assert_eq!(smt.root(), tree_root);

    assert_eq!(smt.get_value(key_a), val_a);
    assert_eq!(smt.get_node(index_a).unwrap(), leaf_node_a);
    let expected_path = store.get_path(tree_root, index_a).unwrap().path;
    assert_eq!(smt.get_path(index_a).unwrap(), expected_path);

    assert_eq!(smt.get_value(key_b), val_b);
    assert_eq!(smt.get_node(index_b).unwrap(), leaf_node_b);
    let expected_path = store.get_path(tree_root, index_b).unwrap().path;
    assert_eq!(smt.get_path(index_b).unwrap(), expected_path);

    // make sure inner nodes match - the store contains more entries because it keeps track of
    // all prior state - so, we don't check that the number of inner nodes is the same in both
    let expected_nodes = get_non_empty_nodes(&store);
    let actual_nodes = smt.inner_nodes().collect::<Vec<_>>();
    actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node)));

    // make sure leaves are returned correctly
    let mut leaves = smt.bottom_leaves();
    assert_eq!(leaves.next(), Some((leaf_node_b, vec![(key_b, val_b)])));
    assert_eq!(leaves.next(), Some((leaf_node_a, vec![(key_a, val_a)])));
    assert_eq!(leaves.next(), None);
}

// GET PROOF TESTS
// ================================================================================================

#[test]
fn tsmt_get_proof() {
    let mut smt = TieredSmt::default();

    // --- insert a value into the tree ---------------------------------------
    let raw_a = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000000_u64;
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
    let value_a = [ONE, ONE, ONE, ONE];
    smt.insert(key_a, value_a);

    // --- insert a value with the same 48-bit prefix into the tree -----------
    let raw_b = 0b_01010101_01010101_11111111_11111111_10110101_10101010_10111100_00000000_u64;
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
    let value_b = [ONE, ONE, ONE, ZERO];
    smt.insert(key_b, value_b);

    let smt_alt = smt.clone();

    // --- insert a value with the same 32-bit prefix into the tree -----------
    let raw_c = 0b_01010101_01010101_11111111_11111111_11111101_10101010_10111100_00000000_u64;
    let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]);
    let value_c = [ONE, ONE, ZERO, ZERO];
    smt.insert(key_c, value_c);

    // --- insert a value with the same 64-bit prefix as A into the tree ------
    let raw_d = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000000_u64;
    let key_d = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_d)]);
    let value_d = [ONE, ZERO, ZERO, ZERO];
    smt.insert(key_d, value_d);

    // at this point the tree looks as follows:
    // - A and D are located in the same node at depth 64.
    // - B is located at depth 64 and shares the same 48-bit prefix with A and D.
    // - C is located at depth 48 and shares the same 32-bit prefix with A, B, and D.

    // --- generate proof for key A and test that it verifies correctly -------
    let proof = smt.prove(key_a);
    assert!(proof.verify_membership(&key_a, &value_a, &smt.root()));

    assert!(!proof.verify_membership(&key_a, &value_b, &smt.root()));
    assert!(!proof.verify_membership(&key_a, &EMPTY_WORD, &smt.root()));
    assert!(!proof.verify_membership(&key_b, &value_a, &smt.root()));
    assert!(!proof.verify_membership(&key_a, &value_a, &smt_alt.root()));

    assert_eq!(proof.get(&key_a), Some(value_a));
    assert_eq!(proof.get(&key_b), None);

    // since A and D are stored in the same node, we should be able to use the proof to verify
    // membership of D
    assert!(proof.verify_membership(&key_d, &value_d, &smt.root()));
    assert_eq!(proof.get(&key_d), Some(value_d));

    // --- generate proof for key B and test that it verifies correctly -------
    let proof = smt.prove(key_b);
    assert!(proof.verify_membership(&key_b, &value_b, &smt.root()));

    assert!(!proof.verify_membership(&key_b, &value_a, &smt.root()));
    assert!(!proof.verify_membership(&key_b, &EMPTY_WORD, &smt.root()));
    assert!(!proof.verify_membership(&key_a, &value_b, &smt.root()));
    assert!(!proof.verify_membership(&key_b, &value_b, &smt_alt.root()));

    assert_eq!(proof.get(&key_b), Some(value_b));
    assert_eq!(proof.get(&key_a), None);

    // --- generate proof for key C and test that it verifies correctly -------
    let proof = smt.prove(key_c);
    assert!(proof.verify_membership(&key_c, &value_c, &smt.root()));

    assert!(!proof.verify_membership(&key_c, &value_a, &smt.root()));
    assert!(!proof.verify_membership(&key_c, &EMPTY_WORD, &smt.root()));
    assert!(!proof.verify_membership(&key_a, &value_c, &smt.root()));
    assert!(!proof.verify_membership(&key_c, &value_c, &smt_alt.root()));

    assert_eq!(proof.get(&key_c), Some(value_c));
    assert_eq!(proof.get(&key_b), None);

    // --- generate proof for key D and test that it verifies correctly -------
    let proof = smt.prove(key_d);
    assert!(proof.verify_membership(&key_d, &value_d, &smt.root()));

    assert!(!proof.verify_membership(&key_d, &value_b, &smt.root()));
    assert!(!proof.verify_membership(&key_d, &EMPTY_WORD, &smt.root()));
    assert!(!proof.verify_membership(&key_b, &value_d, &smt.root()));
    assert!(!proof.verify_membership(&key_d, &value_d, &smt_alt.root()));

    assert_eq!(proof.get(&key_d), Some(value_d));
    assert_eq!(proof.get(&key_b), None);

    // since A and D are stored in the same node, we should be able to use the proof to verify
    // membership of A
    assert!(proof.verify_membership(&key_a, &value_a, &smt.root()));
    assert_eq!(proof.get(&key_a), Some(value_a));

    // --- generate proof for an empty key at depth 64 ------------------------
    // this key has the same 48-bit prefix as A but is different from B
    let raw = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000011_u64;
    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);

    let proof = smt.prove(key);
    assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root()));

    assert!(!proof.verify_membership(&key, &value_a, &smt.root()));
    assert!(!proof.verify_membership(&key, &EMPTY_WORD, &smt_alt.root()));

    assert_eq!(proof.get(&key), Some(EMPTY_WORD));
    assert_eq!(proof.get(&key_b), None);

    // the same proof should verify against any key with the same 64-bit prefix
    let key2 = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw)]);
    assert!(proof.verify_membership(&key2, &EMPTY_WORD, &smt.root()));
    assert_eq!(proof.get(&key2), Some(EMPTY_WORD));

    // but verifying it against a key with the same 63-bit prefix (or shorter) should fail
    let raw3 = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000010_u64;
    let key3 = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw3)]);
    assert!(!proof.verify_membership(&key3, &EMPTY_WORD, &smt.root()));
    assert_eq!(proof.get(&key3), None);

    // --- generate proof for an empty key at depth 48 ------------------------
    // this key has the same 32-bit prefix as A, B, C, and D, but is different from C
    let raw = 0b_01010101_01010101_11111111_11111111_00110101_10101010_11111100_00000000_u64;
    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);

    let proof = smt.prove(key);
    assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root()));

    assert!(!proof.verify_membership(&key, &value_a, &smt.root()));
    assert!(!proof.verify_membership(&key, &EMPTY_WORD, &smt_alt.root()));

    assert_eq!(proof.get(&key), Some(EMPTY_WORD));
    assert_eq!(proof.get(&key_b), None);

    // the same proof should verify against any key with the same 48-bit prefix
    let raw2 = 0b_01010101_01010101_11111111_11111111_00110101_10101010_01111100_00000000_u64;
    let key2 = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw2)]);
    assert!(proof.verify_membership(&key2, &EMPTY_WORD, &smt.root()));
    assert_eq!(proof.get(&key2), Some(EMPTY_WORD));

    // but verifying against a key with the same 47-bit prefix (or shorter) should fail
    let raw3 = 0b_01010101_01010101_11111111_11111111_00110101_10101011_11111100_00000000_u64;
    let key3 = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw3)]);
    assert!(!proof.verify_membership(&key3, &EMPTY_WORD, &smt.root()));
    assert_eq!(proof.get(&key3), None);
}

// ERROR TESTS
// ================================================================================================

#[test]
fn tsmt_node_not_available() {
    let mut smt = TieredSmt::default();

    let raw = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
    let value = [ONE; WORD_SIZE];

    // build an index which is just below the depth at which the leaf node will be inserted
    let index = NodeIndex::make(17, raw >> 47);

    // since we haven't inserted the leaf yet, we should be able to get the node and the path
    // for this index
    assert!(smt.get_node(index).is_ok());
    assert!(smt.get_path(index).is_ok());

    smt.insert(key, value);

    // but once the node is inserted, everything under it should be unavailable
    assert!(smt.get_node(index).is_err());
    assert!(smt.get_path(index).is_err());

    let index = NodeIndex::make(32, raw >> 32);
    assert!(smt.get_node(index).is_err());
    assert!(smt.get_path(index).is_err());

    let index = NodeIndex::make(34, raw >> 30);
    assert!(smt.get_node(index).is_err());
    assert!(smt.get_path(index).is_err());

    let index = NodeIndex::make(50, raw >> 14);
    assert!(smt.get_node(index).is_err());
    assert!(smt.get_path(index).is_err());

    let index = NodeIndex::make(64, raw);
    assert!(smt.get_node(index).is_err());
    assert!(smt.get_path(index).is_err());
}

// HELPER FUNCTIONS
// ================================================================================================

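/// Returns the root of an empty tiered SMT.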
fn get_init_root() -> RpoDigest {
    EmptySubtreeRoots::empty_hashes(64)[0]
}

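/// Computes the hash of an upper-tier leaf node; the tier depth is used as the domain for the
/// merge.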
fn build_leaf_node(key: RpoDigest, value: Word, depth: u8) -> RpoDigest {
    Rpo256::merge_in_domain(&[key, value.into()], depth.into())
}

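/// Computes the hash of a bottom-tier (depth 64) leaf node by hashing the flattened list of the
/// specified key-value pairs.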
fn build_bottom_leaf_node(keys: &[RpoDigest], values: &[Word]) -> RpoDigest {
    assert_eq!(keys.len(), values.len());

    let mut elements = Vec::with_capacity(keys.len());
    for (key, val) in keys.iter().zip(values.iter()) {
        elements.extend_from_slice(key.as_elements());
        elements.extend_from_slice(val.as_slice());
    }

    Rpo256::hash_elements(&elements)
}

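/// Returns all inner nodes in the store which are not roots of empty subtrees.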
fn get_non_empty_nodes(store: &MerkleStore) -> Vec<InnerNodeInfo> {
    store
        .inner_nodes()
        .filter(|node| !is_empty_subtree(&node.value))
        .collect::<Vec<_>>()
}

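/// Returns true if the specified digest is the root of an empty subtree (at any supported depth).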
fn is_empty_subtree(node: &RpoDigest) -> bool {
    EmptySubtreeRoots::empty_hashes(255).contains(node)
}
@@ -1,584 +0,0 @@
use super::{get_key_prefix, BTreeMap, LeafNodeIndex, RpoDigest, StarkField, Vec, Word};
use crate::utils::vec;
use core::{
    cmp::{Ord, Ordering},
    ops::RangeBounds,
};
use winter_utils::collections::btree_map::Entry;

// CONSTANTS
// ================================================================================================

/// Depths at which leaves can exist in a tiered SMT.
const TIER_DEPTHS: [u8; 4] = super::TieredSmt::TIER_DEPTHS;

/// Maximum node depth. This is also the bottom tier of the tree.
const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH;

// VALUE STORE
// ================================================================================================
/// A store for key-value pairs for a Tiered Sparse Merkle tree.
///
/// The store is organized in a [BTreeMap] where keys are the 64 most significant bits of a key,
/// and the values are the corresponding key-value pairs (or a list of key-value pairs if more
/// than a single key-value pair shares the same 64-bit prefix).
///
/// The store supports lookup by the full key (i.e. [RpoDigest]) as well as by the 64-bit key
/// prefix.
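///
/// # Example
///
/// An illustrative sketch of how the store behaves (not compiled as a doctest; key and value
/// literals are placeholders):
///
/// ```ignore
/// let mut store = ValueStore::default();
///
/// // both keys share the same most significant element, and thus the same 64-bit prefix
/// let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(123)]);
/// let key_b = RpoDigest::from([ZERO, ONE, ONE, Felt::new(123)]);
///
/// store.insert(key_a, [ONE; 4]);  // creates a StoreEntry::Single under prefix 123
/// store.insert(key_b, [ZERO; 4]); // upgrades the entry to a StoreEntry::List
///
/// // lookup by full key ...
/// assert_eq!(store.get(&key_a), Some(&[ONE; 4]));
/// // ... or by 64-bit prefix
/// assert_eq!(store.get_all(123).map(|pairs| pairs.len()), Some(2));
/// ```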
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct ValueStore {
    values: BTreeMap<u64, StoreEntry>,
}

impl ValueStore {
    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns a reference to the value stored under the specified key, or None if there is no
    /// value associated with the specified key.
    pub fn get(&self, key: &RpoDigest) -> Option<&Word> {
        let prefix = get_key_prefix(key);
        self.values.get(&prefix).and_then(|entry| entry.get(key))
    }

    /// Returns the first key-value pair such that the key prefix is greater than or equal to the
    /// specified prefix.
    pub fn get_first(&self, prefix: u64) -> Option<&(RpoDigest, Word)> {
        self.range(prefix..).next()
    }

    /// Returns the first key-value pair such that the key prefix is greater than or equal to the
    /// specified prefix and the key value is not equal to the exclude_key value.
    pub fn get_first_filtered(
        &self,
        prefix: u64,
        exclude_key: &RpoDigest,
    ) -> Option<&(RpoDigest, Word)> {
        self.range(prefix..).find(|(key, _)| key != exclude_key)
    }

    /// Returns a vector with key-value pairs for all keys with the specified 64-bit prefix, or
    /// None if no keys with the specified prefix are present in this store.
    pub fn get_all(&self, prefix: u64) -> Option<Vec<(RpoDigest, Word)>> {
        self.values.get(&prefix).map(|entry| match entry {
            StoreEntry::Single(kv_pair) => vec![*kv_pair],
            StoreEntry::List(kv_pairs) => kv_pairs.clone(),
        })
    }

    /// Returns information about a sibling of a leaf node with the specified index, but only if
    /// this is the only sibling the leaf has in some subtree starting at the first tier.
    ///
    /// For example, if `index` is an index at depth 32, and there is a leaf node at depth 32 with
    /// the same root at depth 16 as `index`, we say that this leaf is a lone sibling.
    ///
    /// The returned tuple contains: the key-value pair of the sibling as well as the index of
    /// the node for the root of the common subtree in which both nodes are leaves.
    ///
    /// This method assumes that the key-value pair for the specified index has already been
    /// removed from the store.
    pub fn get_lone_sibling(
        &self,
        index: LeafNodeIndex,
    ) -> Option<(&RpoDigest, &Word, LeafNodeIndex)> {
        // iterate over tiers from top to bottom, looking at the tiers which are strictly above
        // the depth of the index. This implies that only tiers at depth 32 and 48 will be
        // considered. For each tier, check if the parent of the index at the higher tier
        // contains a single node. The first tier (depth 16) is excluded because we cannot move
        // nodes at depth 16 to a higher tier. This implies that nodes at the first tier will
        // never have "lone siblings".
        for &tier_depth in TIER_DEPTHS.iter().filter(|&t| index.depth() > *t) {
            // compute the index of the root at a higher tier
            let mut parent_index = index;
            parent_index.move_up_to(tier_depth);

            // find the lone sibling, if any; we need to handle the "last node" at a given tier
            // separately in order to specify the bounds for the search correctly
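            // for example (illustrative): for tier_depth = 48, the last node at that tier has
            // parent_index.value() == 0xFFFF_FFFF_FFFF; then start_prefix == 0xFFFF_FFFF_FFFF_0000
            // has exactly 48 leading ones, and (parent_index.value() + 1) << 16 would overflow
            // u64, so for this case we search with an unbounded upper range instead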
            let start_prefix = parent_index.value() << (MAX_DEPTH - tier_depth);
            let sibling = if start_prefix.leading_ones() as u8 == tier_depth {
                let mut iter = self.range(start_prefix..);
                iter.next().filter(|_| iter.next().is_none())
            } else {
                let end_prefix = (parent_index.value() + 1) << (MAX_DEPTH - tier_depth);
                let mut iter = self.range(start_prefix..end_prefix);
                iter.next().filter(|_| iter.next().is_none())
            };

            if let Some((key, value)) = sibling {
                return Some((key, value, parent_index));
            }
        }

        None
    }

    /// Returns an iterator over all key-value pairs in this store.
    pub fn iter(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
        self.values.iter().flat_map(|(_, entry)| entry.iter())
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Inserts the specified key-value pair into this store and returns the value previously
    /// associated with the specified key.
    ///
    /// If no value was previously associated with the specified key, None is returned.
    pub fn insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
        let prefix = get_key_prefix(&key);
        match self.values.entry(prefix) {
            Entry::Occupied(mut entry) => entry.get_mut().insert(key, value),
            Entry::Vacant(entry) => {
                entry.insert(StoreEntry::new(key, value));
                None
            }
        }
    }

    /// Removes the key-value pair for the specified key from this store and returns the value
    /// associated with this key.
    ///
    /// If no value was associated with the specified key, None is returned.
    pub fn remove(&mut self, key: &RpoDigest) -> Option<Word> {
        let prefix = get_key_prefix(key);
        match self.values.entry(prefix) {
            Entry::Occupied(mut entry) => {
                let (value, remove_entry) = entry.get_mut().remove(key);
                if remove_entry {
                    entry.remove_entry();
                }
                value
            }
            Entry::Vacant(_) => None,
        }
    }

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over all key-value pairs contained in this store such that the most
    /// significant 64 bits of the key lie within the specified bounds.
    ///
    /// The order of iteration is from the smallest to the largest key.
    fn range<R: RangeBounds<u64>>(&self, bounds: R) -> impl Iterator<Item = &(RpoDigest, Word)> {
        self.values.range(bounds).flat_map(|(_, entry)| entry.iter())
    }
}

// VALUE NODE
// ================================================================================================

/// An entry in the [ValueStore].
///
/// An entry can contain either a single key-value pair or a vector of key-value pairs sorted by
/// key.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum StoreEntry {
    Single((RpoDigest, Word)),
    List(Vec<(RpoDigest, Word)>),
}

impl StoreEntry {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------
    /// Returns a new [StoreEntry] instantiated with a single key-value pair.
    pub fn new(key: RpoDigest, value: Word) -> Self {
        Self::Single((key, value))
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the value associated with the specified key, or None if this entry does not
    /// contain a value associated with the specified key.
    pub fn get(&self, key: &RpoDigest) -> Option<&Word> {
        match self {
            StoreEntry::Single(kv_pair) => {
                if kv_pair.0 == *key {
                    Some(&kv_pair.1)
                } else {
                    None
                }
            }
            StoreEntry::List(kv_pairs) => {
                match kv_pairs.binary_search_by(|kv_pair| cmp_digests(&kv_pair.0, key)) {
                    Ok(pos) => Some(&kv_pairs[pos].1),
                    Err(_) => None,
                }
            }
        }
    }

    /// Returns an iterator over all key-value pairs in this entry.
    pub fn iter(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
        EntryIterator { entry: self, pos: 0 }
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Inserts the specified key-value pair into this entry and returns the value previously
    /// associated with the specified key, or None if no value was associated with the specified
    /// key.
    ///
    /// If a new key is inserted, this will also transform a `Single` entry into a `List` entry.
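    ///
    /// A minimal sketch of the upgrade behavior (not compiled as a doctest; keys and values are
    /// placeholders):
    ///
    /// ```ignore
    /// let mut entry = StoreEntry::new(key_a, value_a);          // StoreEntry::Single
    /// assert_eq!(entry.insert(key_a, value_a2), Some(value_a)); // value updated in place
    /// assert_eq!(entry.insert(key_b, value_b), None);           // entry becomes StoreEntry::List
    /// ```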
    pub fn insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
        match self {
            StoreEntry::Single(kv_pair) => {
                // if the key is already in this entry, update the value and return
                if kv_pair.0 == key {
                    let old_value = kv_pair.1;
                    kv_pair.1 = value;
                    return Some(old_value);
                }

                // transform the entry into a list entry, and make sure the key-value pairs
                // are sorted by key
                let mut pairs = vec![*kv_pair, (key, value)];
                pairs.sort_by(|a, b| cmp_digests(&a.0, &b.0));

                *self = StoreEntry::List(pairs);
                None
            }
            StoreEntry::List(pairs) => {
                match pairs.binary_search_by(|kv_pair| cmp_digests(&kv_pair.0, &key)) {
                    Ok(pos) => {
                        let old_value = pairs[pos].1;
                        pairs[pos].1 = value;
                        Some(old_value)
                    }
                    Err(pos) => {
                        pairs.insert(pos, (key, value));
                        None
                    }
                }
            }
        }
    }

    /// Removes the key-value pair with the specified key from this entry, and returns the value
    /// of the removed pair. If the entry did not contain a key-value pair for the specified key,
    /// None is returned.
    ///
    /// If the last key-value pair was removed from the entry, the second tuple value will
    /// be set to true.
    pub fn remove(&mut self, key: &RpoDigest) -> (Option<Word>, bool) {
        match self {
            StoreEntry::Single(kv_pair) => {
                if kv_pair.0 == *key {
                    (Some(kv_pair.1), true)
                } else {
                    (None, false)
                }
            }
            StoreEntry::List(kv_pairs) => {
                match kv_pairs.binary_search_by(|kv_pair| cmp_digests(&kv_pair.0, key)) {
                    Ok(pos) => {
                        let kv_pair = kv_pairs.remove(pos);
                        if kv_pairs.len() == 1 {
                            *self = StoreEntry::Single(kv_pairs[0]);
                        }
                        (Some(kv_pair.1), false)
                    }
                    Err(_) => (None, false),
                }
            }
        }
    }
}

/// A custom iterator over key-value pairs of a [StoreEntry].
///
/// For a `Single` entry this returns only one value, but for a `List` entry, this iterates over
/// the entire list of key-value pairs.
pub struct EntryIterator<'a> {
    entry: &'a StoreEntry,
    pos: usize,
}

impl<'a> Iterator for EntryIterator<'a> {
    type Item = &'a (RpoDigest, Word);

    fn next(&mut self) -> Option<Self::Item> {
        match self.entry {
            StoreEntry::Single(kv_pair) => {
                if self.pos == 0 {
                    self.pos = 1;
                    Some(kv_pair)
                } else {
                    None
                }
            }
            StoreEntry::List(kv_pairs) => {
                if self.pos >= kv_pairs.len() {
                    None
                } else {
                    let kv_pair = &kv_pairs[self.pos];
                    self.pos += 1;
                    Some(kv_pair)
                }
            }
        }
    }
}

// HELPER FUNCTIONS
// ================================================================================================

/// Compares two digests element-by-element using their integer representations starting with the
/// most significant element.
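///
/// For example (illustrative): the digest `[ONE, ZERO, ZERO, ZERO]` compares as less than
/// `[ZERO, ZERO, ZERO, ONE]`, because the comparison starts at the last (most significant)
/// element, where `ZERO < ONE`.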
fn cmp_digests(d1: &RpoDigest, d2: &RpoDigest) -> Ordering {
|
||||
let d1 = Word::from(d1);
|
||||
let d2 = Word::from(d2);
|
||||
|
||||
for (v1, v2) in d1.iter().zip(d2.iter()).rev() {
|
||||
let v1 = v1.as_int();
|
||||
let v2 = v2.as_int();
|
||||
if v1 != v2 {
|
||||
return v1.cmp(&v2);
|
||||
}
|
||||
}
|
||||
|
||||
Ordering::Equal
|
||||
}
|
||||
|
||||
// TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{LeafNodeIndex, RpoDigest, StoreEntry, ValueStore};
|
||||
use crate::{Felt, ONE, WORD_SIZE, ZERO};
|
||||
|
||||
    #[test]
    fn test_insert() {
        let mut store = ValueStore::default();

        // insert the first key-value pair into the store
        let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
        let value_a = [ONE; WORD_SIZE];

        assert!(store.insert(key_a, value_a).is_none());
        assert_eq!(store.values.len(), 1);

        let entry = store.values.get(&raw_a).unwrap();
        let expected_entry = StoreEntry::Single((key_a, value_a));
        assert_eq!(entry, &expected_entry);

        // insert a key-value pair with a different key into the store; since the keys are
        // different, another entry is added to the values map
        let raw_b = 0b_11111110_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
        let value_b = [ONE, ZERO, ONE, ZERO];

        assert!(store.insert(key_b, value_b).is_none());
        assert_eq!(store.values.len(), 2);

        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 = StoreEntry::Single((key_a, value_a));
        assert_eq!(entry1, &expected_entry1);

        let entry2 = store.values.get(&raw_b).unwrap();
        let expected_entry2 = StoreEntry::Single((key_b, value_b));
        assert_eq!(entry2, &expected_entry2);

        // insert a key-value pair with the same 64-bit key prefix as the first key; this should
        // transform the first entry into a List entry
        let key_c = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
        let value_c = [ONE, ONE, ZERO, ZERO];

        assert!(store.insert(key_c, value_c).is_none());
        assert_eq!(store.values.len(), 2);

        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 = StoreEntry::List(vec![(key_c, value_c), (key_a, value_a)]);
        assert_eq!(entry1, &expected_entry1);

        let entry2 = store.values.get(&raw_b).unwrap();
        let expected_entry2 = StoreEntry::Single((key_b, value_b));
        assert_eq!(entry2, &expected_entry2);

        // replace values for keys a and b
        let value_a2 = [ONE, ONE, ONE, ZERO];
        let value_b2 = [ZERO, ZERO, ZERO, ONE];

        assert_eq!(store.insert(key_a, value_a2), Some(value_a));
        assert_eq!(store.values.len(), 2);

        assert_eq!(store.insert(key_b, value_b2), Some(value_b));
        assert_eq!(store.values.len(), 2);

        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 = StoreEntry::List(vec![(key_c, value_c), (key_a, value_a2)]);
        assert_eq!(entry1, &expected_entry1);

        let entry2 = store.values.get(&raw_b).unwrap();
        let expected_entry2 = StoreEntry::Single((key_b, value_b2));
        assert_eq!(entry2, &expected_entry2);

        // insert one more key-value pair with the same 64-bit key-prefix as the first key
        let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
        let value_d = [ZERO, ONE, ZERO, ZERO];

        assert!(store.insert(key_d, value_d).is_none());
        assert_eq!(store.values.len(), 2);

        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 =
            StoreEntry::List(vec![(key_c, value_c), (key_a, value_a2), (key_d, value_d)]);
        assert_eq!(entry1, &expected_entry1);

        let entry2 = store.values.get(&raw_b).unwrap();
        let expected_entry2 = StoreEntry::Single((key_b, value_b2));
        assert_eq!(entry2, &expected_entry2);
    }
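
    // Illustrative sketch (hypothetical types, not part of this diff) of the Single -> List
    // promotion exercised above: the real `ValueStore` keys its map by the most significant
    // 64 bits of each key and keeps List entries sorted under the key ordering defined earlier
    // in this file, which is why `key_c` precedes `key_a` in the assertions.
    #[allow(dead_code)]
    enum SketchEntry<K, V> {
        Single((K, V)),
        List(Vec<(K, V)>),
    }

    #[allow(dead_code)]
    fn sketch_insert<K: Ord, V>(entry: SketchEntry<K, V>, key: K, value: V) -> SketchEntry<K, V> {
        // a second key sharing the same 64-bit prefix turns a Single entry into a sorted List
        let mut list = match entry {
            SketchEntry::Single(pair) => vec![pair],
            SketchEntry::List(list) => list,
        };
        list.push((key, value));
        list.sort_by(|a, b| a.0.cmp(&b.0));
        SketchEntry::List(list)
    }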

    #[test]
    fn test_remove() {
        // populate the value store
        let mut store = ValueStore::default();

        let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
        let value_a = [ONE; WORD_SIZE];
        store.insert(key_a, value_a);

        let raw_b = 0b_11111110_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
        let value_b = [ONE, ZERO, ONE, ZERO];
        store.insert(key_b, value_b);

        let key_c = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
        let value_c = [ONE, ONE, ZERO, ZERO];
        store.insert(key_c, value_c);

        let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
        let value_d = [ZERO, ONE, ZERO, ZERO];
        store.insert(key_d, value_d);

        assert_eq!(store.values.len(), 2);

        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 =
            StoreEntry::List(vec![(key_c, value_c), (key_a, value_a), (key_d, value_d)]);
        assert_eq!(entry1, &expected_entry1);

        let entry2 = store.values.get(&raw_b).unwrap();
        let expected_entry2 = StoreEntry::Single((key_b, value_b));
        assert_eq!(entry2, &expected_entry2);

        // remove non-existent keys
        let key_e = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_a)]);
        assert!(store.remove(&key_e).is_none());

        let raw_f = 0b_11111110_11111111_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_f = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_f)]);
        assert!(store.remove(&key_f).is_none());

        // remove keys from the list entry
        assert_eq!(store.remove(&key_c).unwrap(), value_c);
        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 = StoreEntry::List(vec![(key_a, value_a), (key_d, value_d)]);
        assert_eq!(entry1, &expected_entry1);

        assert_eq!(store.remove(&key_a).unwrap(), value_a);
        let entry1 = store.values.get(&raw_a).unwrap();
        let expected_entry1 = StoreEntry::Single((key_d, value_d));
        assert_eq!(entry1, &expected_entry1);

        assert_eq!(store.remove(&key_d).unwrap(), value_d);
        assert!(store.values.get(&raw_a).is_none());
        assert_eq!(store.values.len(), 1);

        // remove a key from a single entry
        assert_eq!(store.remove(&key_b).unwrap(), value_b);
        assert!(store.values.get(&raw_b).is_none());
        assert_eq!(store.values.len(), 0);
    }

    #[test]
    fn test_range() {
        // populate the value store
        let mut store = ValueStore::default();

        let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
        let value_a = [ONE; WORD_SIZE];
        store.insert(key_a, value_a);

        let raw_b = 0b_11111110_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
        let value_b = [ONE, ZERO, ONE, ZERO];
        store.insert(key_b, value_b);

        let key_c = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]);
        let value_c = [ONE, ONE, ZERO, ZERO];
        store.insert(key_c, value_c);

        let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]);
        let value_d = [ZERO, ONE, ZERO, ZERO];
        store.insert(key_d, value_d);

        let raw_e = 0b_10101000_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_e = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_e)]);
        let value_e = [ZERO, ZERO, ZERO, ONE];
        store.insert(key_e, value_e);

        // check the entire range
        let mut iter = store.range(..u64::MAX);
        assert_eq!(iter.next(), Some(&(key_e, value_e)));
        assert_eq!(iter.next(), Some(&(key_c, value_c)));
        assert_eq!(iter.next(), Some(&(key_a, value_a)));
        assert_eq!(iter.next(), Some(&(key_d, value_d)));
        assert_eq!(iter.next(), Some(&(key_b, value_b)));
        assert_eq!(iter.next(), None);

        // check all but e
        let mut iter = store.range(raw_a..u64::MAX);
        assert_eq!(iter.next(), Some(&(key_c, value_c)));
        assert_eq!(iter.next(), Some(&(key_a, value_a)));
        assert_eq!(iter.next(), Some(&(key_d, value_d)));
        assert_eq!(iter.next(), Some(&(key_b, value_b)));
        assert_eq!(iter.next(), None);

        // check all but e and b
        let mut iter = store.range(raw_a..raw_b);
        assert_eq!(iter.next(), Some(&(key_c, value_c)));
        assert_eq!(iter.next(), Some(&(key_a, value_a)));
        assert_eq!(iter.next(), Some(&(key_d, value_d)));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn test_get_lone_sibling() {
        // populate the value store
        let mut store = ValueStore::default();

        let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]);
        let value_a = [ONE; WORD_SIZE];
        store.insert(key_a, value_a);

        let raw_b = 0b_11111111_11111111_00011111_11111111_10010110_10010011_11100000_00000000_u64;
        let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]);
        let value_b = [ONE, ZERO, ONE, ZERO];
        store.insert(key_b, value_b);

        // check sibling node for `a`
        let index = LeafNodeIndex::make(32, 0b_10101010_10101010_00011111_11111110);
        let parent_index = LeafNodeIndex::make(16, 0b_10101010_10101010);
        assert_eq!(store.get_lone_sibling(index), Some((&key_a, &value_a, parent_index)));

        // check sibling node for `b`
        let index = LeafNodeIndex::make(32, 0b_11111111_11111111_00011111_11111111);
        let parent_index = LeafNodeIndex::make(16, 0b_11111111_11111111);
        assert_eq!(store.get_lone_sibling(index), Some((&key_b, &value_b, parent_index)));

        // an index whose subtree contains no lone sibling returns None
        let index = LeafNodeIndex::make(32, 0b_11101010_10101010);
        assert_eq!(store.get_lone_sibling(index), None);
    }
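
    // Editorial sketch of the index arithmetic exercised above: tiers are 16 levels apart, so
    // the depth-16 tier ancestor of a depth-32 index is obtained by dropping the low 16 index
    // bits. `tier_parent` is a hypothetical helper, not part of the crate.
    #[allow(dead_code)]
    fn tier_parent(depth: u8, value: u64) -> (u8, u64) {
        // e.g. (32, 0b_10101010_10101010_00011111_11111110) -> (16, 0b_10101010_10101010)
        (depth - 16, value >> 16)
    }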
}
src/rand/mod.rs
@@ -1,3 +1,21 @@
 //! Pseudo-random element generation.
 
-pub use winter_crypto::{RandomCoin, RandomCoinError};
+use rand::RngCore;
+pub use winter_crypto::{DefaultRandomCoin as WinterRandomCoin, RandomCoin, RandomCoinError};
+pub use winter_utils::Randomizable;
+
+use crate::{Felt, FieldElement, Word, ZERO};
+
+mod rpo;
+pub use rpo::RpoRandomCoin;
+
+/// Pseudo-random element generator.
+///
+/// An instance can be used to draw, uniformly at random, base field elements as well as [Word]s.
+pub trait FeltRng: RngCore {
+    /// Draw, uniformly at random, a base field element.
+    fn draw_element(&mut self) -> Felt;
+
+    /// Draw, uniformly at random, a [Word].
+    fn draw_word(&mut self) -> Word;
+}
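
// Usage sketch (editorial, not part of the diff): how the `FeltRng` trait above is meant to be
// consumed, assuming the `RpoRandomCoin` implementation added in `src/rand/rpo.rs` below and
// the crate's root re-exports of `Word` and `ZERO`; the zero seed mirrors that file's tests.
use miden_crypto::rand::{FeltRng, RpoRandomCoin};
use miden_crypto::{Word, ZERO};

fn draw_samples(rng: &mut impl FeltRng) -> Word {
    let _element = rng.draw_element(); // one uniformly random base field element
    rng.draw_word() // four field elements packed as a Word
}

fn main() {
    let mut rng = RpoRandomCoin::new([ZERO; 4]);
    let _word = draw_samples(&mut rng);
}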

292
src/rand/rpo.rs
Normal file
@@ -0,0 +1,292 @@
use super::{Felt, FeltRng, FieldElement, RandomCoin, RandomCoinError, RngCore, Word, ZERO};
use crate::{
    hash::rpo::{Rpo256, RpoDigest},
    utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
};
use alloc::{string::ToString, vec::Vec};
use rand_core::impls;

// CONSTANTS
// ================================================================================================

const STATE_WIDTH: usize = Rpo256::STATE_WIDTH;
const RATE_START: usize = Rpo256::RATE_RANGE.start;
const RATE_END: usize = Rpo256::RATE_RANGE.end;
const HALF_RATE_WIDTH: usize = (Rpo256::RATE_RANGE.end - Rpo256::RATE_RANGE.start) / 2;

// RPO RANDOM COIN
// ================================================================================================
/// A simplified version of the `SPONGE_PRG` reseedable pseudo-random number generator algorithm
/// described in <https://eprint.iacr.org/2011/499.pdf>.
///
/// The simplification is related to the following facts:
/// 1. A call to the reseed method implies one and only one call to the permutation function.
/// This is possible because in our case we never reseed with more than 4 field elements.
/// 2. As a result of the previous point, we don't make use of an input buffer to accumulate seed
/// material.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RpoRandomCoin {
    state: [Felt; STATE_WIDTH],
    current: usize,
}

impl RpoRandomCoin {
    /// Returns a new [RpoRandomCoin] initialized with the specified seed.
    pub fn new(seed: Word) -> Self {
        let mut state = [ZERO; STATE_WIDTH];

        for i in 0..HALF_RATE_WIDTH {
            state[RATE_START + i] += seed[i];
        }

        // Absorb
        Rpo256::apply_permutation(&mut state);

        RpoRandomCoin { state, current: RATE_START }
    }

    /// Returns an [RpoRandomCoin] instantiated from the provided components.
    ///
    /// # Panics
    /// Panics if `current` is smaller than 4 or greater than or equal to 12.
    pub fn from_parts(state: [Felt; STATE_WIDTH], current: usize) -> Self {
        assert!(
            (RATE_START..RATE_END).contains(&current),
            "current value outside of valid range"
        );
        Self { state, current }
    }

    /// Returns components of this random coin.
    pub fn into_parts(self) -> ([Felt; STATE_WIDTH], usize) {
        (self.state, self.current)
    }

    /// Fills `dest` with random data.
    pub fn fill_bytes(&mut self, dest: &mut [u8]) {
        <Self as RngCore>::fill_bytes(self, dest)
    }

    fn draw_basefield(&mut self) -> Felt {
        if self.current == RATE_END {
            Rpo256::apply_permutation(&mut self.state);
            self.current = RATE_START;
        }

        self.current += 1;
        self.state[self.current - 1]
    }
}
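
// Editorial sketch (not part of the diff): the reseed flow promised by the doc comment on
// [RpoRandomCoin]. One `reseed` call costs exactly one permutation; draws then consume rate
// elements until the rate is exhausted, at which point `draw_basefield` permutes again.
#[allow(dead_code)]
fn reseed_sketch() {
    let mut coin = RpoRandomCoin::new([ZERO; 4]);
    // absorb four new seed elements with a single permutation call
    coin.reseed(Rpo256::hash_elements(&[Felt::ONE; 4]));
    let _element = coin.draw_element(); // serves elements from the freshly permuted rate
}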

// RANDOM COIN IMPLEMENTATION
// ------------------------------------------------------------------------------------------------

impl RandomCoin for RpoRandomCoin {
    type BaseField = Felt;
    type Hasher = Rpo256;

    fn new(seed: &[Self::BaseField]) -> Self {
        let digest: Word = Rpo256::hash_elements(seed).into();
        Self::new(digest)
    }

    fn reseed(&mut self, data: RpoDigest) {
        // Reset buffer
        self.current = RATE_START;

        // Add the new seed material to the first half of the rate portion of the RPO state
        let data: Word = data.into();

        self.state[RATE_START] += data[0];
        self.state[RATE_START + 1] += data[1];
        self.state[RATE_START + 2] += data[2];
        self.state[RATE_START + 3] += data[3];

        // Absorb
        Rpo256::apply_permutation(&mut self.state);
    }

    fn check_leading_zeros(&self, value: u64) -> u32 {
        let value = Felt::new(value);
        let mut state_tmp = self.state;

        state_tmp[RATE_START] += value;

        Rpo256::apply_permutation(&mut state_tmp);

        let first_rate_element = state_tmp[RATE_START].as_int();
        first_rate_element.trailing_zeros()
    }

    fn draw<E: FieldElement<BaseField = Felt>>(&mut self) -> Result<E, RandomCoinError> {
        let ext_degree = E::EXTENSION_DEGREE;
        let mut result = vec![ZERO; ext_degree];
        for r in result.iter_mut().take(ext_degree) {
            *r = self.draw_basefield();
        }

        let result = E::slice_from_base_elements(&result);
        Ok(result[0])
    }

    fn draw_integers(
        &mut self,
        num_values: usize,
        domain_size: usize,
        nonce: u64,
    ) -> Result<Vec<usize>, RandomCoinError> {
        assert!(domain_size.is_power_of_two(), "domain size must be a power of two");
        assert!(num_values < domain_size, "number of values must be smaller than domain size");

        // absorb the nonce
        let nonce = Felt::new(nonce);
        self.state[RATE_START] += nonce;
        Rpo256::apply_permutation(&mut self.state);

        // reset the buffer
        self.current = RATE_START;

        // since the domain size is a power of two, this mask selects only valid values
        let v_mask = (domain_size - 1) as u64;

        // draw values from the PRNG until we get as many values as requested by num_values,
        // giving up after 1000 draws
        let mut values = Vec::new();
        for _ in 0..1000 {
            // get the next pseudo-random field element
            let value = self.draw_basefield().as_int();

            // use the mask to get a value within the range
            let value = (value & v_mask) as usize;

            values.push(value);
            if values.len() == num_values {
                break;
            }
        }

        if values.len() < num_values {
            return Err(RandomCoinError::FailedToDrawIntegers(num_values, values.len(), 1000));
        }

        Ok(values)
    }
}
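
// Editorial sketch (not part of the diff): drawing query positions with `draw_integers` above;
// the power-of-two mask keeps each position in [0, domain_size).
#[allow(dead_code)]
fn draw_positions_sketch() -> Vec<usize> {
    let mut coin = RpoRandomCoin::new([ZERO; 4]);
    // 30 positions in a domain of 2^10 elements, with grinding nonce 0
    coin.draw_integers(30, 1024, 0).expect("failed to draw query positions")
}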

// FELT RNG IMPLEMENTATION
// ------------------------------------------------------------------------------------------------

impl FeltRng for RpoRandomCoin {
    fn draw_element(&mut self) -> Felt {
        self.draw_basefield()
    }

    fn draw_word(&mut self) -> Word {
        let mut output = [ZERO; 4];
        for o in output.iter_mut() {
            *o = self.draw_basefield();
        }
        output
    }
}

// RNGCORE IMPLEMENTATION
// ------------------------------------------------------------------------------------------------

impl RngCore for RpoRandomCoin {
    fn next_u32(&mut self) -> u32 {
        self.draw_basefield().as_int() as u32
    }

    fn next_u64(&mut self) -> u64 {
        impls::next_u64_via_u32(self)
    }

    fn fill_bytes(&mut self, dest: &mut [u8]) {
        impls::fill_bytes_via_next(self, dest)
    }

    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
        self.fill_bytes(dest);
        Ok(())
    }
}

// SERIALIZATION
// ------------------------------------------------------------------------------------------------

impl Serializable for RpoRandomCoin {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        self.state.iter().for_each(|v| v.write_into(target));
        // casting to u8 is OK because `current` is always between 4 and 12.
        target.write_u8(self.current as u8);
    }
}

impl Deserializable for RpoRandomCoin {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let state = [
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
        ];
        let current = source.read_u8()? as usize;
        if !(RATE_START..RATE_END).contains(&current) {
            return Err(DeserializationError::InvalidValue(
                "current value outside of valid range".to_string(),
            ));
        }
        Ok(Self { state, current })
    }
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::{Deserializable, FeltRng, RpoRandomCoin, Serializable, ZERO};
    use crate::ONE;

    #[test]
    fn test_feltrng_felt() {
        let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
        let output = rpocoin.draw_element();

        let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
        let expected = rpocoin.draw_basefield();

        assert_eq!(output, expected);
    }

    #[test]
    fn test_feltrng_word() {
        let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
        let output = rpocoin.draw_word();

        let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
        let mut expected = [ZERO; 4];
        for o in expected.iter_mut() {
            *o = rpocoin.draw_basefield();
        }

        assert_eq!(output, expected);
    }

    #[test]
    fn test_feltrng_serialization() {
        let coin1 = RpoRandomCoin::from_parts([ONE; 12], 5);

        let bytes = coin1.to_bytes();
        let coin2 = RpoRandomCoin::read_from_bytes(&bytes).unwrap();
        assert_eq!(coin1, coin2);
    }
}
@@ -1,31 +0,0 @@
/// A trait for computing the difference between two objects.
pub trait Diff<K: Ord + Clone, V: Clone> {
    /// The type that describes the difference between two objects.
    type DiffType;

    /// Returns a [Self::DiffType] object that represents the difference between this object and
    /// other.
    fn diff(&self, other: &Self) -> Self::DiffType;
}

/// A trait for applying the difference between two objects.
pub trait ApplyDiff<K: Ord + Clone, V: Clone> {
    /// The type that describes the difference between two objects.
    type DiffType;

    /// Applies the provided changes described by [Self::DiffType] to the object implementing this trait.
    fn apply(&mut self, diff: Self::DiffType);
}

/// A trait for applying the difference between two objects with the possibility of failure.
pub trait TryApplyDiff<K: Ord + Clone, V: Clone> {
    /// The type that describes the difference between two objects.
    type DiffType;

    /// An error type that can be returned if the changes cannot be applied.
    type Error;

    /// Applies the provided changes described by [Self::DiffType] to the object implementing this trait.
    /// Returns an error if the changes cannot be applied.
    fn try_apply(&mut self, diff: Self::DiffType) -> Result<(), Self::Error>;
}
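
// Editorial sketch (not part of the diff): how the removed traits above were typically wired
// up, assuming their definitions are in scope. `MapDiff` and the `BTreeMap` impls are
// hypothetical, shown only to illustrate the Diff/ApplyDiff contract; deletions are omitted
// for brevity.
use std::collections::BTreeMap;

struct MapDiff<K: Ord + Clone, V: Clone> {
    upserts: Vec<(K, V)>,
}

impl<K: Ord + Clone, V: Clone + PartialEq> Diff<K, V> for BTreeMap<K, V> {
    type DiffType = MapDiff<K, V>;

    fn diff(&self, other: &Self) -> Self::DiffType {
        // collect entries that are new or changed in `other` relative to `self`
        let mut upserts = Vec::new();
        for (k, v) in other.iter() {
            if self.get(k) != Some(v) {
                upserts.push((k.clone(), v.clone()));
            }
        }
        MapDiff { upserts }
    }
}

impl<K: Ord + Clone, V: Clone> ApplyDiff<K, V> for BTreeMap<K, V> {
    type DiffType = MapDiff<K, V>;

    fn apply(&mut self, diff: Self::DiffType) {
        for (k, v) in diff.upserts {
            self.insert(k, v);
        }
    }
}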
Some files were not shown because too many files have changed in this diff.