mirror of https://github.com/arnaucube/miden-crypto.git
synced 2026-01-11 16:41:29 +01:00

Compare commits: km/mkdocs-... vs rpo-dsa (75 commits)
| SHA1 |
| --- |
| d2a6739605 |
| cae87a2790 |
| 335c50f54d |
| b151773b0d |
| 1867f842d3 |
| e1072ecc7f |
| 063ad49afd |
| a27f9ad828 |
| 50dd6bda19 |
| ee20a49953 |
| 0d75e3593b |
| 689cc93ed1 |
| 7970d3a736 |
| a734dace1e |
| 940cc04670 |
| e82baa35bb |
| 876d1bf97a |
| 8adc0ab418 |
| c2eb38c236 |
| a924ac6b81 |
| e214608c85 |
| c44ccd9dec |
| e34900c7d8 |
| 2b184cd4ca |
| 913384600d |
| ae807a47ae |
| f4a9d5b027 |
| ee42d87121 |
| b1cb2b6ec3 |
| e4a9a2ac00 |
| c5077b1683 |
| 2e74028fd4 |
| 8bf6ef890d |
| e2aeb25e01 |
| 790846cc73 |
| 4cb6bed428 |
| a12e62ff22 |
| 9aa4987858 |
| 70a0a1e970 |
| 025fbb66a9 |
| 5ee5e8554b |
| ac3c6976bd |
| 374a10f340 |
| ad0f472708 |
| 8bb893345b |
| d92fae7f82 |
| b171575776 |
| dfdd5f722f |
| 9f63b50510 |
| d6ab367d32 |
| b06cfa3c03 |
| 8556c8fc43 |
| 78ac70120d |
| ccde10af13 |
| f967211b5a |
| d58c717956 |
| c0743adac9 |
| f72add58cd |
| 63f97e5621 |
| 43fe7a1072 |
| bb42388827 |
| 2a0ae70645 |
| da67f8c7e5 |
| 9454e1a8ae |
| 4bf087daf8 |
| b4dc373925 |
| 4885f885a4 |
| 5a2e917dd5 |
| 2be17b74fb |
| b35e99c390 |
| 4c8a9809ed |
| ce9b45fe77 |
| 56d014898d |
| 8e81ccdb68 |
| 999a64fca6 |
.config/nextest.toml (new file, +3)
@@ -0,0 +1,3 @@
[profile.default]
failure-output = "immediate-final"
fail-fast = false
.github/workflows/build.yml (new file, +25)
@@ -0,0 +1,25 @@
# Runs build related jobs.

name: build

on:
  push:
    branches: [main, next]
  pull_request:
    types: [opened, reopened, synchronize]

jobs:
  no-std:
    name: Build for no-std
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        toolchain: [stable, nightly]
    steps:
      - uses: actions/checkout@main
      - name: Build for no-std
        run: |
          rustup update --no-self-update ${{ matrix.toolchain }}
          rustup target add wasm32-unknown-unknown
          make build-no-std
.github/workflows/changelog.yml (new file, +23)
@@ -0,0 +1,23 @@
# Runs changelog related jobs.
# CI job heavily inspired by: https://github.com/tarides/changelog-check-action

name: changelog

on:
  pull_request:
    types: [opened, reopened, synchronize, labeled, unlabeled]

jobs:
  changelog:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@main
        with:
          fetch-depth: 0
      - name: Check for changes in changelog
        env:
          BASE_REF: ${{ github.event.pull_request.base.ref }}
          NO_CHANGELOG_LABEL: ${{ contains(github.event.pull_request.labels.*.name, 'no changelog') }}
        run: ./scripts/check-changelog.sh "${{ inputs.changelog }}"
        shell: bash
.github/workflows/ci.yml (deleted, -133)
@@ -1,133 +0,0 @@
name: CI
on:
  push:
    branches:
      - main
  pull_request:
    types: [opened, reopened, synchronize]

jobs:
  rustfmt:
    name: rustfmt ${{matrix.toolchain}} on ${{matrix.os}}
    runs-on: ${{matrix.os}}-latest
    strategy:
      fail-fast: false
      matrix:
        toolchain: [nightly]
        os: [ubuntu]
    steps:
      - uses: actions/checkout@v4
      - name: Install minimal Rust with rustfmt
        uses: actions-rs/toolchain@v1
        with:
          profile: minimal
          toolchain: ${{matrix.toolchain}}
          components: rustfmt
          override: true
      - name: fmt
        uses: actions-rs/cargo@v1
        with:
          command: fmt
          args: --all -- --check

  clippy:
    name: clippy ${{matrix.toolchain}} on ${{matrix.os}}
    runs-on: ${{matrix.os}}-latest
    strategy:
      fail-fast: false
      matrix:
        toolchain: [nightly]
        os: [ubuntu]
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: Install minimal Rust with clippy
        uses: actions-rs/toolchain@v1
        with:
          profile: minimal
          toolchain: ${{matrix.toolchain}}
          components: clippy
          override: true
      - name: Clippy
        uses: actions-rs/cargo@v1
        with:
          command: clippy
          args: --all-targets -- -D clippy::all -D warnings
      - name: Clippy all features
        uses: actions-rs/cargo@v1
        with:
          command: clippy
          args: --all-targets --all-features -- -D clippy::all -D warnings

  test:
    name: test ${{matrix.toolchain}} on ${{matrix.os}} with ${{matrix.features}}
    runs-on: ${{matrix.os}}-latest
    strategy:
      fail-fast: false
      matrix:
        toolchain: [stable, nightly]
        os: [ubuntu]
        features: ["--features default,serde", --no-default-features]
    timeout-minutes: 30
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: Install rust
        uses: actions-rs/toolchain@v1
        with:
          toolchain: ${{matrix.toolchain}}
          override: true
      - name: Test
        uses: actions-rs/cargo@v1
        with:
          command: test
          args: ${{matrix.features}}

  no-std:
    name: build ${{matrix.toolchain}} no-std for wasm32-unknown-unknown
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        toolchain: [stable, nightly]
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: Install rust
        uses: actions-rs/toolchain@v1
        with:
          toolchain: ${{matrix.toolchain}}
          override: true
      - run: rustup target add wasm32-unknown-unknown
      - name: Build
        uses: actions-rs/cargo@v1
        with:
          command: build
          args: --no-default-features --target wasm32-unknown-unknown

  docs:
    name: Verify the docs on ${{matrix.toolchain}}
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        toolchain: [stable]
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: Install rust
        uses: actions-rs/toolchain@v1
        with:
          toolchain: ${{matrix.toolchain}}
          override: true
      - name: Check docs
        uses: actions-rs/cargo@v1
        env:
          RUSTDOCFLAGS: -D warnings
        with:
          command: doc
          args: --verbose --all-features --keep-going
.github/workflows/lint.yml (new file, +53)
@@ -0,0 +1,53 @@
# Runs linting related jobs.

name: lint

on:
  push:
    branches: [main, next]
  pull_request:
    types: [opened, reopened, synchronize]

jobs:
  clippy:
    name: clippy nightly on ubuntu-latest
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@main
      - name: Clippy
        run: |
          rustup update --no-self-update nightly
          rustup +nightly component add clippy
          make clippy

  rustfmt:
    name: rustfmt check nightly on ubuntu-latest
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@main
      - name: Rustfmt
        run: |
          rustup update --no-self-update nightly
          rustup +nightly component add rustfmt
          make format-check

  doc:
    name: doc stable on ubuntu-latest
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@main
      - name: Build docs
        run: |
          rustup update --no-self-update
          make doc

  version:
    name: check rust version consistency
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@main
        with:
          profile: minimal
          override: true
      - name: check rust versions
        run: ./scripts/check-rust-version.sh
.github/workflows/test.yml (new file, +28)
@@ -0,0 +1,28 @@
# Runs test related jobs.

name: test

on:
  push:
    branches: [main, next]
  pull_request:
    types: [opened, reopened, synchronize]

jobs:
  test:
    name: test ${{matrix.toolchain}} on ${{matrix.os}} with ${{matrix.args}}
    runs-on: ${{matrix.os}}-latest
    strategy:
      fail-fast: false
      matrix:
        toolchain: [stable, nightly]
        os: [ubuntu]
        args: [default, no-std]
    timeout-minutes: 30
    steps:
      - uses: actions/checkout@main
      - uses: taiki-e/install-action@nextest
      - name: Perform tests
        run: |
          rustup update --no-self-update ${{matrix.toolchain}}
          make test-${{matrix.args}}
.gitignore
@@ -2,10 +2,6 @@
 # will have compiled files and executables
 /target/
 
-# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
-# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
-Cargo.lock
-
 # These are backup files generated by rustfmt
 **/*.rs.bk
.gitmodules (deleted, -3)
@@ -1,3 +0,0 @@
[submodule "PQClean"]
	path = PQClean
	url = https://github.com/PQClean/PQClean.git
.pre-commit-config.yaml
@@ -1,8 +1,8 @@
 # See https://pre-commit.com for more information
 # See https://pre-commit.com/hooks.html for more hooks
 repos:
-  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v3.2.0
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.6.0
     hooks:
       - id: trailing-whitespace
       - id: end-of-file-fixer
@@ -15,29 +15,20 @@ repos:
       - id: check-executables-have-shebangs
       - id: check-merge-conflict
       - id: detect-private-key
-  - repo: https://github.com/hackaugusto/pre-commit-cargo
-    rev: v1.0.0
+  - repo: local
     hooks:
-      # Allows cargo fmt to modify the source code prior to the commit
-      - id: cargo
-        name: Cargo fmt
-        args: ["+stable", "fmt", "--all"]
-      # Requires code to be properly formatted prior to pushing upstream
-      - id: cargo
-        name: Cargo fmt --check
-        args: ["+stable", "fmt", "--all", "--check"]
-        stages: [push, manual]
-      - id: cargo
-        name: Cargo check --all-targets
-        args: ["+stable", "check", "--all-targets"]
-      - id: cargo
-        name: Cargo check --all-targets --no-default-features
-        args: ["+stable", "check", "--all-targets", "--no-default-features"]
-      - id: cargo
-        name: Cargo check --all-targets --features default,std,serde
-        args: ["+stable", "check", "--all-targets", "--features", "default,std,serde"]
-      # Unlike fmt, clippy will not be automatically applied
-      - id: cargo
-        name: Cargo clippy
-        args: ["+nightly", "clippy", "--workspace", "--", "--deny", "clippy::all", "--deny", "warnings"]
+      - id: lint
+        name: Make lint
+        stages: [commit]
+        language: rust
+        entry: make lint
+      - id: doc
+        name: Make doc
+        stages: [commit]
+        language: rust
+        entry: make doc
+      - id: check
+        name: Make check
+        stages: [commit]
+        language: rust
+        entry: make check
CHANGELOG.md
@@ -1,46 +1,120 @@
+## 0.13.0 (2024-11-24)
+
+- Fixed a bug in the implementation of `draw_integers` for `RpoRandomCoin` (#343).
+- [BREAKING] Refactor error messages and use `thiserror` to derive errors (#344).
+- [BREAKING] Updated Winterfell dependency to v0.11 (#346).
+- Added RPO-STARK based DSA (#349).
+- Added benchmarks for DSA implementations (#354).
+- Implemented deterministic RPO-STARK based DSA (#358).
+
+## 0.12.0 (2024-10-30)
+
+- [BREAKING] Updated Winterfell dependency to v0.10 (#338).
+- Added parallel implementation of `Smt::with_entries()` with significantly better performance when the `concurrent` feature is enabled (#341).
+
+## 0.11.0 (2024-10-17)
+
+- [BREAKING]: renamed `Mmr::open()` into `Mmr::open_at()` and `Mmr::peaks()` into `Mmr::peaks_at()` (#234).
+- Added `Mmr::open()` and `Mmr::peaks()` which rely on `Mmr::open_at()` and `Mmr::peaks()` respectively (#234).
+- Standardized CI and Makefile across Miden repos (#323).
+- Added `Smt::compute_mutations()` and `Smt::apply_mutations()` for validation-checked insertions (#327).
+- Changed padding rule for RPO/RPX hash functions (#318).
+- [BREAKING] Changed return value of the `Mmr::verify()` and `MerklePath::verify()` from `bool` to `Result<>` (#335).
+- Added `is_empty()` functions to the `SimpleSmt` and `Smt` structures. Added `EMPTY_ROOT` constant to the `SparseMerkleTree` trait (#337).
+
+## 0.10.3 (2024-09-25)
+
+- Implement `get_size_hint` for `Smt` (#331).
+
+## 0.10.2 (2024-09-25)
+
+- Implement `get_size_hint` for `RpoDigest` and `RpxDigest` and expose constants for their serialized size (#330).
+
+## 0.10.1 (2024-09-13)
+
+- Added `Serializable` and `Deserializable` implementations for `PartialMmr` and `InOrderIndex` (#329).
+
+## 0.10.0 (2024-08-06)
+
+- Added more `RpoDigest` and `RpxDigest` conversions (#311).
+- [BREAKING] Migrated to Winterfell v0.9 (#315).
+- Fixed encoding of Falcon secret key (#319).
+
+## 0.9.3 (2024-04-24)
+
+- Added `RpxRandomCoin` struct (#307).
+
+## 0.9.2 (2024-04-21)
+
+- Implemented serialization for the `Smt` struct (#304).
+- Fixed a bug in Falcon signature generation (#305).
+
+## 0.9.1 (2024-04-02)
+
+- Added `num_leaves()` method to `SimpleSmt` (#302).
+
+## 0.9.0 (2024-03-24)
+
+- [BREAKING] Removed deprecated re-exports from liballoc/libstd (#290).
+- [BREAKING] Refactored RpoFalcon512 signature to work with pure Rust (#285).
+- [BREAKING] Added `RngCore` as supertrait for `FeltRng` (#299).
+
+# 0.8.4 (2024-03-17)
+
+- Re-added unintentionally removed re-exported liballoc macros (`vec` and `format` macros).
+
+# 0.8.3 (2024-03-17)
+
+- Re-added unintentionally removed re-exported liballoc macros (#292).
+
+# 0.8.2 (2024-03-17)
+
+- Updated `no-std` approach to be in sync with winterfell v0.8.3 release (#290).
+
 ## 0.8.1 (2024-02-21)
 
-* Fixed clippy warnings (#280)
+- Fixed clippy warnings (#280)
 
 ## 0.8.0 (2024-02-14)
 
-* Implemented the `PartialMmr` data structure (#195).
-* Implemented RPX hash function (#201).
-* Added `FeltRng` and `RpoRandomCoin` (#237).
-* Accelerated RPO/RPX hash functions using AVX512 instructions (#234).
-* Added `inner_nodes()` method to `PartialMmr` (#238).
-* Improved `PartialMmr::apply_delta()` (#242).
-* Refactored `SimpleSmt` struct (#245).
-* Replaced `TieredSmt` struct with `Smt` struct (#254, #277).
-* Updated Winterfell dependency to v0.8 (#275).
+- Implemented the `PartialMmr` data structure (#195).
+- Implemented RPX hash function (#201).
+- Added `FeltRng` and `RpoRandomCoin` (#237).
+- Accelerated RPO/RPX hash functions using AVX512 instructions (#234).
+- Added `inner_nodes()` method to `PartialMmr` (#238).
+- Improved `PartialMmr::apply_delta()` (#242).
+- Refactored `SimpleSmt` struct (#245).
+- Replaced `TieredSmt` struct with `Smt` struct (#254, #277).
+- Updated Winterfell dependency to v0.8 (#275).
 
 ## 0.7.1 (2023-10-10)
 
-* Fixed RPO Falcon signature build on Windows.
+- Fixed RPO Falcon signature build on Windows.
 
 ## 0.7.0 (2023-10-05)
 
-* Replaced `MerklePathSet` with `PartialMerkleTree` (#165).
-* Implemented clearing of nodes in `TieredSmt` (#173).
-* Added ability to generate inclusion proofs for `TieredSmt` (#174).
-* Implemented Falcon DSA (#179).
-* Added conditional `serde` support for various structs (#180).
-* Implemented benchmarking for `TieredSmt` (#182).
-* Added more leaf traversal methods for `MerkleStore` (#185).
-* Added SVE acceleration for RPO hash function (#189).
+- Replaced `MerklePathSet` with `PartialMerkleTree` (#165).
+- Implemented clearing of nodes in `TieredSmt` (#173).
+- Added ability to generate inclusion proofs for `TieredSmt` (#174).
+- Implemented Falcon DSA (#179).
+- Added conditional `serde` support for various structs (#180).
+- Implemented benchmarking for `TieredSmt` (#182).
+- Added more leaf traversal methods for `MerkleStore` (#185).
+- Added SVE acceleration for RPO hash function (#189).
 
 ## 0.6.0 (2023-06-25)
 
-* [BREAKING] Added support for recording capabilities for `MerkleStore` (#162).
-* [BREAKING] Refactored Merkle struct APIs to use `RpoDigest` instead of `Word` (#157).
-* Added initial implementation of `PartialMerkleTree` (#156).
+- [BREAKING] Added support for recording capabilities for `MerkleStore` (#162).
+- [BREAKING] Refactored Merkle struct APIs to use `RpoDigest` instead of `Word` (#157).
+- Added initial implementation of `PartialMerkleTree` (#156).
 
 ## 0.5.0 (2023-05-26)
 
-* Implemented `TieredSmt` (#152, #153).
-* Implemented ability to extract a subset of a `MerkleStore` (#151).
-* Cleaned up `SimpleSmt` interface (#149).
-* Decoupled hashing and padding of peaks in `Mmr` (#148).
-* Added `inner_nodes()` to `MerkleStore` (#146).
+- Implemented `TieredSmt` (#152, #153).
+- Implemented ability to extract a subset of a `MerkleStore` (#151).
+- Cleaned up `SimpleSmt` interface (#149).
+- Decoupled hashing and padding of peaks in `Mmr` (#148).
+- Added `inner_nodes()` to `MerkleStore` (#146).
 
 ## 0.4.0 (2023-04-21)
Cargo.lock (new generated file, +1329): diff suppressed because it is too large.
Cargo.toml
@@ -1,16 +1,16 @@
 [package]
 name = "miden-crypto"
-version = "0.8.1"
+version = "0.14.0"
 description = "Miden Cryptographic primitives"
 authors = ["miden contributors"]
 readme = "README.md"
 license = "MIT"
 repository = "https://github.com/0xPolygonMiden/crypto"
-documentation = "https://docs.rs/miden-crypto/0.8.1"
+documentation = "https://docs.rs/miden-crypto/0.14.0"
 categories = ["cryptography", "no-std"]
 keywords = ["miden", "crypto", "hash", "merkle"]
 edition = "2021"
-rust-version = "1.75"
+rust-version = "1.82"
@@ -19,6 +19,10 @@ bench = false
 doctest = false
 required-features = ["executable"]
 
+[[bench]]
+name = "dsa"
+harness = false
+
 [[bench]]
 name = "hash"
 harness = false
@@ -27,39 +31,68 @@ harness = false
 name = "smt"
 harness = false
 
+[[bench]]
+name = "smt-subtree"
+harness = false
+required-features = ["internal"]
+
 [[bench]]
 name = "merkle"
 harness = false
 
+[[bench]]
+name = "smt-with-entries"
+harness = false
+
 [[bench]]
 name = "store"
 harness = false
 
 [features]
-default = ["std"]
-executable = ["dep:clap", "dep:rand_utils", "std"]
-serde = ["dep:serde", "serde?/alloc", "winter_math/serde"]
+concurrent = ["dep:rayon"]
+default = ["std", "concurrent"]
+executable = ["dep:clap", "dep:rand-utils", "std"]
+internal = []
+serde = ["dep:serde", "serde?/alloc", "winter-math/serde"]
 std = [
   "blake3/std",
   "dep:cc",
   "dep:libc",
-  "winter_crypto/std",
-  "winter_math/std",
-  "winter_utils/std",
+  "rand/std",
+  "rand/std_rng",
+  "winter-crypto/std",
+  "winter-math/std",
+  "winter-utils/std",
 ]
 
 [dependencies]
 blake3 = { version = "1.5", default-features = false }
-clap = { version = "4.5", features = ["derive"], optional = true }
 libc = { version = "0.2", default-features = false, optional = true }
-rand_utils = { version = "0.8", package = "winter-rand-utils", optional = true }
-serde = { version = "1.0", features = ["derive"], default-features = false, optional = true }
-winter_crypto = { version = "0.8", package = "winter-crypto", default-features = false }
-winter_math = { version = "0.8", package = "winter-math", default-features = false }
-winter_utils = { version = "0.8", package = "winter-utils", default-features = false }
+clap = { version = "4.5", optional = true, features = ["derive"] }
+getrandom = { version = "0.2", features = ["js"] }
+num = { version = "0.4", default-features = false, features = ["alloc", "libm"] }
+num-complex = { version = "0.4", default-features = false }
+rand = { version = "0.8", default-features = false }
+rand_chacha = { version = "0.3", default-features = false }
+rand_core = { version = "0.6", default-features = false }
+rand-utils = { git = 'https://github.com/Al-Kindi-0/winterfell', package = "winter-rand-utils", branch = 'al-zk', optional = true }
+rayon = { version = "1.10", optional = true }
+serde = { version = "1.0", default-features = false, optional = true, features = ["derive"] }
+sha3 = { version = "0.10", default-features = false }
+thiserror = { version = "2.0", default-features = false }
+winter-air = { git = 'https://github.com/Al-Kindi-0/winterfell', branch = 'al-zk' }
+winter-crypto = { git = 'https://github.com/Al-Kindi-0/winterfell', branch = 'al-zk' }
+winter-prover = { git = 'https://github.com/Al-Kindi-0/winterfell', branch = 'al-zk' }
+winter-verifier = { git = 'https://github.com/Al-Kindi-0/winterfell', branch = 'al-zk' }
+winter-math = { git = 'https://github.com/Al-Kindi-0/winterfell', branch = 'al-zk' }
+winter-utils = { git = 'https://github.com/Al-Kindi-0/winterfell', branch = 'al-zk' }
 
 [dev-dependencies]
-seq-macro = { version = "0.3" }
+assert_matches = { version = "1.5", default-features = false }
 criterion = { version = "0.5", features = ["html_reports"] }
-proptest = "1.4"
-rand_utils = { version = "0.8", package = "winter-rand-utils" }
+hex = { version = "0.4", default-features = false, features = ["alloc"] }
+proptest = "1.5"
+rand-utils = { git = 'https://github.com/Al-Kindi-0/winterfell', package = "winter-rand-utils", branch = 'al-zk' }
+seq-macro = { version = "0.3" }
 
 [build-dependencies]
-cc = { version = "1.0", features = ["parallel"], optional = true }
+cc = { version = "1.2", optional = true, features = ["parallel"] }
 glob = "0.3"
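The new `concurrent = ["dep:rayon"]` entry above follows the standard optional-dependency pattern: the feature switches the `rayon` dependency on, and code paths are gated on it. A generic sketch of how such a gate is typically consumed (illustrative only; this is not miden-crypto's actual internal code):

```rust
// Illustrative only: the usual shape of code gated behind a `concurrent`
// feature that enables an optional `rayon` dependency.
#[cfg(feature = "concurrent")]
fn double_all(items: &mut [u64]) {
    use rayon::prelude::*;
    // Process items on the rayon thread pool when `concurrent` is enabled.
    items.par_iter_mut().for_each(|x| *x = x.wrapping_mul(2));
}

#[cfg(not(feature = "concurrent"))]
fn double_all(items: &mut [u64]) {
    // Sequential fallback when the feature is disabled.
    items.iter_mut().for_each(|x| *x = x.wrapping_mul(2));
}
```

Building with `--no-default-features` (as the `no-std` CI job and `make build-no-std` above do) drops both `std` and `concurrent`.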
LICENSE
@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2023 Polygon Miden
+Copyright (c) 2024 Polygon Miden
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
Makefile (new file, +86)
@@ -0,0 +1,86 @@
.DEFAULT_GOAL := help

.PHONY: help
help:
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'

# -- variables --------------------------------------------------------------------------------------

WARNINGS=RUSTDOCFLAGS="-D warnings"
DEBUG_OVERFLOW_INFO=RUSTFLAGS="-C debug-assertions -C overflow-checks -C debuginfo=2"

# -- linting --------------------------------------------------------------------------------------

.PHONY: clippy
clippy: ## Run Clippy with configs
	$(WARNINGS) cargo +nightly clippy --workspace --all-targets --all-features

.PHONY: fix
fix: ## Run Fix with configs
	cargo +nightly fix --allow-staged --allow-dirty --all-targets --all-features

.PHONY: format
format: ## Run Format using nightly toolchain
	cargo +nightly fmt --all

.PHONY: format-check
format-check: ## Run Format using nightly toolchain but only in check mode
	cargo +nightly fmt --all --check

.PHONY: lint
lint: format fix clippy ## Run all linting tasks at once (Clippy, fixing, formatting)

# --- docs ----------------------------------------------------------------------------------------

.PHONY: doc
doc: ## Generate and check documentation
	$(WARNINGS) cargo doc --all-features --keep-going --release

# --- testing -------------------------------------------------------------------------------------

.PHONY: test-default
test-default: ## Run tests with default features
	$(DEBUG_OVERFLOW_INFO) cargo nextest run --profile default --release --all-features

.PHONY: test-no-std
test-no-std: ## Run tests with `no-default-features` (std)
	$(DEBUG_OVERFLOW_INFO) cargo nextest run --profile default --release --no-default-features

.PHONY: test
test: test-default test-no-std ## Run all tests

# --- checking ------------------------------------------------------------------------------------

.PHONY: check
check: ## Check all targets and features for errors without code generation
	cargo check --all-targets --all-features

# --- building ------------------------------------------------------------------------------------

.PHONY: build
build: ## Build with default features enabled
	cargo build --release

.PHONY: build-no-std
build-no-std: ## Build without the standard library
	cargo build --release --no-default-features --target wasm32-unknown-unknown

.PHONY: build-avx2
build-avx2: ## Build with avx2 support
	RUSTFLAGS="-C target-feature=+avx2" cargo build --release

.PHONY: build-sve
build-sve: ## Build with sve support
	RUSTFLAGS="-C target-feature=+sve" cargo build --release

# --- benchmarking --------------------------------------------------------------------------------

.PHONY: bench-tx
bench-tx: ## Run crypto benchmarks
	cargo bench --features="concurrent"
PQClean: submodule deleted (was at c3abebf4ab)
README.md
@@ -1,78 +1,109 @@
 # Miden Crypto
 
 [](https://github.com/0xPolygonMiden/crypto/blob/main/LICENSE)
 [](https://github.com/0xPolygonMiden/crypto/actions/workflows/test.yml)
 [](https://github.com/0xPolygonMiden/crypto/actions/workflows/build.yml)
 [](https://www.rust-lang.org/tools/install)
 [](https://crates.io/crates/miden-crypto)
 
 This crate contains cryptographic primitives used in Polygon Miden.
 
 ## Hash
 
 [Hash module](./src/hash) provides a set of cryptographic hash functions which are used by the Miden VM and the Miden rollup. Currently, these functions are:
 
-* [BLAKE3](https://github.com/BLAKE3-team/BLAKE3) hash function with 256-bit, 192-bit, or 160-bit output. The 192-bit and 160-bit outputs are obtained by truncating the 256-bit output of the standard BLAKE3.
-* [RPO](https://eprint.iacr.org/2022/1577) hash function with 256-bit output. This hash function is an algebraic hash function suitable for recursive STARKs.
-* [RPX](https://eprint.iacr.org/2023/1045) hash function with 256-bit output. Similar to RPO, this hash function is suitable for recursive STARKs but it is about 2x faster as compared to RPO.
+- [BLAKE3](https://github.com/BLAKE3-team/BLAKE3) hash function with 256-bit, 192-bit, or 160-bit output. The 192-bit and 160-bit outputs are obtained by truncating the 256-bit output of the standard BLAKE3.
+- [RPO](https://eprint.iacr.org/2022/1577) hash function with 256-bit output. This hash function is an algebraic hash function suitable for recursive STARKs.
+- [RPX](https://eprint.iacr.org/2023/1045) hash function with 256-bit output. Similar to RPO, this hash function is suitable for recursive STARKs but it is about 2x faster as compared to RPO.
 
 For performance benchmarks of these hash functions and their comparison to other popular hash functions please see [here](./benches/).
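To make the hashing interface above concrete, here is a minimal sketch of sequential and 2-to-1 hashing with RPO. The `Rpo256` functions used (`hash_elements`, `merge`) follow the Winterfell hasher interface that this crate implements; treat the exact paths and signatures as assumptions to check against the crate docs:

```rust
use miden_crypto::{
    hash::rpo::{Rpo256, RpoDigest},
    Felt, ONE,
};

fn rpo_examples() {
    // Sequential hashing: absorb a slice of field elements into one 4-element digest.
    let elements: Vec<Felt> = (0..100u64).map(Felt::new).collect();
    let digest: RpoDigest = Rpo256::hash_elements(&elements);

    // 2-to-1 hashing h(a, b): merge two digests, as done at each Merkle tree node.
    let a = Rpo256::hash_elements(&[ONE]);
    let b = Rpo256::hash_elements(&[ONE, ONE]);
    let parent: RpoDigest = Rpo256::merge(&[a, b]);

    let _ = (digest, parent);
}
```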
 ## Merkle
 
 [Merkle module](./src/merkle/) provides a set of data structures related to Merkle trees. All these data structures are implemented using the RPO hash function described above. The data structures are:
 
-* `MerkleStore`: a collection of Merkle trees of different heights designed to efficiently store trees with common subtrees. When instantiated with `RecordingMap`, a Merkle store records all accesses to the original data.
-* `MerkleTree`: a regular fully-balanced binary Merkle tree. The depth of this tree can be at most 64.
-* `Mmr`: a Merkle mountain range structure designed to function as an append-only log.
-* `PartialMerkleTree`: a partial view of a Merkle tree where some sub-trees may not be known. This is similar to a collection of Merkle paths all resolving to the same root. The length of the paths can be at most 64.
-* `PartialMmr`: a partial view of a Merkle mountain range structure.
-* `SimpleSmt`: a Sparse Merkle Tree (with no compaction), mapping 64-bit keys to 4-element values.
-* `Smt`: a Sparse Merkle tree (with compaction at depth 64), mapping 4-element keys to 4-element values.
+- `MerkleStore`: a collection of Merkle trees of different heights designed to efficiently store trees with common subtrees. When instantiated with `RecordingMap`, a Merkle store records all accesses to the original data.
+- `MerkleTree`: a regular fully-balanced binary Merkle tree. The depth of this tree can be at most 64.
+- `Mmr`: a Merkle mountain range structure designed to function as an append-only log.
+- `PartialMerkleTree`: a partial view of a Merkle tree where some sub-trees may not be known. This is similar to a collection of Merkle paths all resolving to the same root. The length of the paths can be at most 64.
+- `PartialMmr`: a partial view of a Merkle mountain range structure.
+- `SimpleSmt`: a Sparse Merkle Tree (with no compaction), mapping 64-bit keys to 4-element values.
+- `Smt`: a Sparse Merkle tree (with compaction at depth 64), mapping 4-element keys to 4-element values.
 
 The module also contains additional supporting components such as `NodeIndex`, `MerklePath`, and `MerkleError` to assist with tree indexation, opening proofs, and reporting inconsistent arguments/state.
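For a concrete sense of these structures, the sketch below condenses the usage exercised by `benches/merkle.rs` and `benches/smt-with-entries.rs` later in this diff (`MerkleTree::new` and `Smt::with_entries` appear there verbatim; the surrounding glue is assumed):

```rust
use miden_crypto::{
    hash::rpo::RpoDigest,
    merkle::{MerkleTree, Smt},
    Felt, Word, ONE,
};

fn merkle_examples() {
    // A fully balanced Merkle tree over 256 words has depth 8.
    let leaves: Vec<Word> = (0..256).map(|i| [Felt::new(i), ONE, ONE, Felt::new(i)]).collect();
    let tree = MerkleTree::new(leaves).unwrap();
    assert_eq!(tree.depth(), 8);

    // A compacted sparse Merkle tree mapping 4-element keys to 4-element values.
    let key = RpoDigest::new([ONE, ONE, ONE, ONE]);
    let value: Word = [ONE, ONE, ONE, ONE];
    let smt = Smt::with_entries([(key, value)]).unwrap();
    let _root: RpoDigest = smt.root();
}
```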
 ## Signatures
 
 [DSA module](./src/dsa) provides a set of digital signature schemes supported by default in the Miden VM. Currently, these schemes are:
 
-* `RPO Falcon512`: a variant of the [Falcon](https://falcon-sign.info/) signature scheme. This variant differs from the standard in that instead of using SHAKE256 hash function in the *hash-to-point* algorithm we use RPO256. This makes the signature more efficient to verify in Miden VM.
+- `RPO Falcon512`: a variant of the [Falcon](https://falcon-sign.info/) signature scheme. This variant differs from the standard in that instead of using the SHAKE256 hash function in the _hash-to-point_ algorithm we use RPO256. This makes the signature more efficient to verify in the Miden VM.
 
-For the above signatures, key generation and signing is available only in the `std` context (see [crate features](#crate-features) below), while signature verification is available in `no_std` context as well.
+For the above signatures, key generation, signing, and signature verification are available in both `std` and `no_std` contexts (see [crate features](#crate-features) below). However, in the `no_std` context, the user is responsible for supplying the key generation and signing procedures with a random number generator.
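A minimal signing round-trip, condensed from `benches/dsa.rs` later in this diff (the `SecretKey::new`, `public_key`, `sign`, and `verify` calls are taken from that benchmark; the fixed message and the assumption that `verify` returns a `bool` are ours):

```rust
use miden_crypto::{dsa::rpo_falcon512::SecretKey, Word, ONE};

fn falcon_roundtrip() {
    let sk = SecretKey::new(); // key generation: samples a fresh Falcon keypair
    let pk = sk.public_key();
    let msg: Word = [ONE, ONE, ONE, ONE];
    let sig = sk.sign(msg); // hash-to-point uses RPO256 instead of SHAKE256
    assert!(pk.verify(msg, &sig));
}
```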
 ## Pseudo-Random Element Generator
 
 [Pseudo random element generator module](./src/rand/) provides a set of traits and data structures that facilitate generating pseudo-random elements in the context of Miden VM and Miden rollup. The module currently includes:
 
-* `FeltRng`: a trait for generating random field elements and random 4 field elements.
-* `RpoRandomCoin`: a struct implementing `FeltRng` as well as the [`RandomCoin`](https://github.com/facebook/winterfell/blob/main/crypto/src/random/mod.rs) trait.
+- `FeltRng`: a trait for generating a random field element and a random array of 4 field elements.
+- `RpoRandomCoin`: a struct implementing `FeltRng` as well as the [`RandomCoin`](https://github.com/facebook/winterfell/blob/main/crypto/src/random/mod.rs) trait using the RPO hash function.
+- `RpxRandomCoin`: a struct implementing `FeltRng` as well as the [`RandomCoin`](https://github.com/facebook/winterfell/blob/main/crypto/src/random/mod.rs) trait using the RPX hash function.
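A sketch of drawing randomness from `RpoRandomCoin`; it assumes `RpoRandomCoin::new` seeds from a `Word` and that `FeltRng` exposes `draw_element`/`draw_word`, so check these against the trait definition in `src/rand/`:

```rust
use miden_crypto::{
    rand::{FeltRng, RpoRandomCoin},
    Felt, Word, ONE,
};

fn draw_randomness() {
    let seed: Word = [ONE, ONE, ONE, ONE];
    let mut coin = RpoRandomCoin::new(seed);
    let element: Felt = coin.draw_element(); // one random field element
    let word: Word = coin.draw_word(); // four random field elements at once
    let _ = (element, word);
}
```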
 ## Make commands
 
 We use `make` to automate building, testing, and other processes. In most cases, `make` commands are wrappers around `cargo` commands with specific arguments. You can view the list of available commands in the [Makefile](Makefile), or run the following command:
 
 ```shell
 make
 ```
 
 ## Crate features
 
 This crate can be compiled with the following features:
 
-* `std` - enabled by default and relies on the Rust standard library.
-* `no_std` does not rely on the Rust standard library and enables compilation to WebAssembly.
+- `concurrent` - enabled by default; enables a multi-threaded implementation of `Smt::with_entries()`, which significantly improves performance on multi-core CPUs.
+- `std` - enabled by default and relies on the Rust standard library.
+- `no_std` - does not rely on the Rust standard library and enables compilation to WebAssembly.
 
-Both of these features imply the use of [alloc](https://doc.rust-lang.org/alloc/) to support heap-allocated collections.
+All of these features imply the use of [alloc](https://doc.rust-lang.org/alloc/) to support heap-allocated collections.
 
-To compile with `no_std`, disable default features via `--no-default-features` flag.
+To compile with `no_std`, disable default features via the `--no-default-features` flag, or use the following command:
+
+```shell
+make build-no-std
+```
 
 ### AVX2 acceleration
 
 On platforms with [AVX2](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) support, the RPO and RPX hash functions can be accelerated by using the vector processing unit. To enable AVX2 acceleration, the code needs to be compiled with the `avx2` target feature enabled. For example:
 
 ```shell
-RUSTFLAGS="-C target-feature=+avx2" cargo build --release
+make build-avx2
 ```
 
 ### SVE acceleration
 
-On platforms with [SVE](https://en.wikipedia.org/wiki/AArch64#Scalable_Vector_Extension_(SVE)) support, RPO and RPX hash function can be accelerated by using the vector processing unit. To enable SVE acceleration, the code needs to be compiled with the `sve` target feature enabled. For example:
+On platforms with [SVE](<https://en.wikipedia.org/wiki/AArch64#Scalable_Vector_Extension_(SVE)>) support, the RPO and RPX hash functions can be accelerated by using the vector processing unit. To enable SVE acceleration, the code needs to be compiled with the `sve` target feature enabled. For example:
 
 ```shell
-RUSTFLAGS="-C target-feature=+sve" cargo build --release
+make build-sve
 ```
 
 ## Testing
 
-You can use cargo defaults to test the library:
+The best way to test the library is using our [Makefile](Makefile); this will enable you to use our pre-defined, optimized testing commands:
 
 ```shell
-cargo test
+make test
 ```
 
-However, some of the functions are heavy and might take a while for the tests to complete. In order to test in release mode, we have to replicate the test conditions of the development mode so all debug assertions can be verified.
+For example, some of the functions are heavy and the tests might take a while to complete when simply using `cargo test`. In order to test in release and optimized mode, we have to replicate the test conditions of the development mode so that all debug assertions can be verified.
 
-We do that by enabling some special [flags](https://doc.rust-lang.org/cargo/reference/profiles.html) for the compilation.
+We do that by enabling some special [flags](https://doc.rust-lang.org/cargo/reference/profiles.html) for the compilation (which we have set as defaults in our [Makefile](Makefile)):
 
 ```shell
 RUSTFLAGS="-C debug-assertions -C overflow-checks -C debuginfo=2" cargo test --release
 ```
 
 ## License
 
 This project is [MIT licensed](./LICENSE).
benches/README.md
@@ -1,4 +1,6 @@
-# Miden VM Hash Functions
+# Benchmarks
+
+## Miden VM Hash Functions
 
 In the Miden VM, we make use of different hash functions. Some of these are "traditional" hash functions, like `BLAKE3`, which are optimized for out-of-STARK performance, while others are algebraic hash functions, like `Rescue Prime`, and are more optimized for a better performance inside the STARK. In what follows, we benchmark several such hash functions and compare against other constructions that are used by other proving systems. More precisely, we benchmark:
 
 * **BLAKE3** as specified [here](https://github.com/BLAKE3-team/BLAKE3-specs/blob/master/blake3.pdf) and implemented [here](https://github.com/BLAKE3-team/BLAKE3) (with a wrapper exposed via this crate).
@@ -8,13 +10,13 @@ In the Miden VM, we make use of different hash functions. Some of these are "tra
 * **Rescue Prime Optimized (RPO)** as specified [here](https://eprint.iacr.org/2022/1577) and implemented in this crate.
 * **Rescue Prime Extended (RPX)** a variant of the [xHash](https://eprint.iacr.org/2023/1045) hash function as implemented in this crate.
 
-## Comparison and Instructions
+### Comparison and Instructions
 
-### Comparison
+#### Comparison
 
 We benchmark the above hash functions using two scenarios. The first is a 2-to-1 $(a,b)\mapsto h(a,b)$ hashing where both $a$, $b$ and $h(a,b)$ are the digests corresponding to each of the hash functions.
 
 The second scenario is that of sequential hashing where we take a sequence of length $100$ field elements and hash these to produce a single digest. The digests are $4$ field elements in a prime field with modulus $2^{64} - 2^{32} + 1$ (i.e., 32 bytes) for Poseidon, Rescue Prime and RPO, and an array `[u8; 32]` for SHA3 and BLAKE3.
 
-#### Scenario 1: 2-to-1 hashing `h(a,b)`
+##### Scenario 1: 2-to-1 hashing `h(a,b)`
 
 | Function | BLAKE3 | SHA3 | Poseidon | Rp64_256 | RPO_256 | RPX_256 |
 | ------------------- | ------ | ------- | --------- | --------- | ------- | ------- |
@@ -26,7 +28,7 @@ The second scenario is that of sequential hashing where we take a sequence of le
 | Intel Core i5-8279U | 68 ns | 536 ns | 2.0 µs | 13.6 µs | 8.5 µs | 4.4 µs |
 | Intel Xeon 8375C | 67 ns | | | | 8.2 µs | |
 
-#### Scenario 2: Sequential hashing of 100 elements `h([a_0,...,a_99])`
+##### Scenario 2: Sequential hashing of 100 elements `h([a_0,...,a_99])`
 
 | Function | BLAKE3 | SHA3 | Poseidon | Rp64_256 | RPO_256 | RPX_256 |
 | ------------------- | -------| ------- | --------- | --------- | ------- | ------- |
@@ -42,7 +44,7 @@ Notes:
 - On Graviton 3, RPO256 and RPX256 are run with SVE acceleration enabled.
 - On AMD EPYC 9R14, RPO256 and RPX256 are run with AVX2 acceleration enabled.
 
-### Instructions
+#### Instructions
 
 Before you can run the benchmarks, you'll need to make sure you have Rust [installed](https://www.rust-lang.org/tools/install). After that, to run the benchmarks for RPO and BLAKE3, clone the current repository, and from the root directory of the repo run the following:
 
 ```
@@ -54,3 +56,47 @@ To run the benchmarks for Rescue Prime, Poseidon and SHA3, clone the following [
 ```
 cargo bench hash
 ```
+
+## Miden VM DSA
+
+We make use of the following digital signature algorithms (DSA) in the Miden VM:
+
+* **RPO-Falcon512** as specified [here](https://falcon-sign.info/falcon.pdf), with the one difference being the use of the RPO hash function for the hash-to-point algorithm (Algorithm 3 in the previous reference) instead of SHAKE256.
+* **RPO-STARK** as specified [here](https://eprint.iacr.org/2024/1553), where the parameters are the ones for the unique-decoding regime (UDR), with two differences:
+  * We rely on Conjecture 1 in the [ethSTARK](https://eprint.iacr.org/2021/582) paper.
+  * The number of FRI queries is $30$ and the grinding factor is $12$ bits. Thus, using the previous point, we can argue that the modified version achieves at least $102$ bits of average-case existential unforgeability security against $2^{113}$-query bound adversaries that can obtain up to $2^{64}$ signatures under the same public key.
+
+### Comparison and Instructions
+
+#### Comparison
+
+##### Key Generation
+
+| DSA | RPO-Falcon512 | RPO-STARK |
+| ------------------- | :-----------: | :-------: |
+| Apple M1 Pro | 590 ms | 6 µs |
+| Intel Core i5-8279U | 585 ms | 10 µs |
+
+##### Signature Generation
+
+| DSA | RPO-Falcon512 | RPO-STARK |
+| ------------------- | :-----------: | :-------: |
+| Apple M1 Pro | 1.5 ms | 78 ms |
+| Intel Core i5-8279U | 1.8 ms | 130 ms |
+
+##### Signature Verification
+
+| DSA | RPO-Falcon512 | RPO-STARK |
+| ------------------- | :-----------: | :-------: |
+| Apple M1 Pro | 0.7 ms | 4.5 ms |
+| Intel Core i5-8279U | 1.2 ms | 7.9 ms |
+
+#### Instructions
+
+Before you can run the benchmarks, you'll need to make sure you have Rust [installed](https://www.rust-lang.org/tools/install). After that, to run the benchmarks, clone the current repository, and from the root directory of the repo run the following:
+
+```
+cargo bench --bench dsa
+```
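Condensed from `benches/dsa.rs` below, a round-trip with the RPO-STARK scheme looks like this (the `random`, `public_key`, `sign`, and `verify` calls come from that benchmark; the fixed message and the `bool` return of `verify` are assumptions):

```rust
use miden_crypto::{dsa::rpo_stark::SecretKey, Word, ONE};

fn rpo_stark_roundtrip() {
    let sk = SecretKey::random(); // sampling a key is cheap (microseconds in the table above)
    let pk = sk.public_key();
    let msg: Word = [ONE, ONE, ONE, ONE];
    let sig = sk.sign(msg); // signing produces a STARK proof, hence the ~100 ms cost above
    assert!(pk.verify(msg, &sig));
}
```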
benches/dsa.rs (new file, +88)
@@ -0,0 +1,88 @@
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use miden_crypto::dsa::{
    rpo_falcon512::SecretKey as FalconSecretKey, rpo_stark::SecretKey as RpoStarkSecretKey,
};
use rand_utils::rand_array;

fn key_gen_falcon(c: &mut Criterion) {
    c.bench_function("Falcon public key generation", |bench| {
        bench.iter_batched(|| FalconSecretKey::new(), |sk| sk.public_key(), BatchSize::SmallInput)
    });

    c.bench_function("Falcon secret key generation", |bench| {
        bench.iter_batched(|| {}, |_| FalconSecretKey::new(), BatchSize::SmallInput)
    });
}

fn key_gen_rpo_stark(c: &mut Criterion) {
    c.bench_function("RPO-STARK public key generation", |bench| {
        bench.iter_batched(
            || RpoStarkSecretKey::random(),
            |sk| sk.public_key(),
            BatchSize::SmallInput,
        )
    });

    c.bench_function("RPO-STARK secret key generation", |bench| {
        bench.iter_batched(|| {}, |_| RpoStarkSecretKey::random(), BatchSize::SmallInput)
    });
}

fn signature_gen_falcon(c: &mut Criterion) {
    c.bench_function("Falcon signature generation", |bench| {
        bench.iter_batched(
            || (FalconSecretKey::new(), rand_array().into()),
            |(sk, msg)| sk.sign(msg),
            BatchSize::SmallInput,
        )
    });
}

fn signature_gen_rpo_stark(c: &mut Criterion) {
    c.bench_function("RPO-STARK signature generation", |bench| {
        bench.iter_batched(
            || (RpoStarkSecretKey::random(), rand_array().into()),
            |(sk, msg)| sk.sign(msg),
            BatchSize::SmallInput,
        )
    });
}

fn signature_ver_falcon(c: &mut Criterion) {
    c.bench_function("Falcon signature verification", |bench| {
        bench.iter_batched(
            || {
                let sk = FalconSecretKey::new();
                let msg = rand_array().into();
                (sk.public_key(), msg, sk.sign(msg))
            },
            |(pk, msg, sig)| pk.verify(msg, &sig),
            BatchSize::SmallInput,
        )
    });
}

fn signature_ver_rpo_stark(c: &mut Criterion) {
    c.bench_function("RPO-STARK signature verification", |bench| {
        bench.iter_batched(
            || {
                let sk = RpoStarkSecretKey::random();
                let msg = rand_array().into();
                (sk.public_key(), msg, sk.sign(msg))
            },
            |(pk, msg, sig)| pk.verify(msg, &sig),
            BatchSize::SmallInput,
        )
    });
}

criterion_group!(
    dsa_group,
    key_gen_falcon,
    key_gen_rpo_stark,
    signature_gen_falcon,
    signature_gen_rpo_stark,
    signature_ver_falcon,
    signature_ver_rpo_stark
);
criterion_main!(dsa_group);
benches/merkle.rs (new file, +66)
@@ -0,0 +1,66 @@
//! Benchmark for building a [`miden_crypto::merkle::MerkleTree`]. This is intended to be compared
//! with the results from `benches/smt-subtree.rs`, as building a fully balanced Merkle tree with
//! 256 leaves should indicate the *absolute best* performance we could *possibly* get for building
//! a depth-8 sparse Merkle subtree, though practically speaking building a fully balanced Merkle
//! tree will perform better than the sparse version. At the time of this writing (2024/11/24), this
//! benchmark is about four times more efficient than the equivalent benchmark in
//! `benches/smt-subtree.rs`.
use std::{hint, mem, time::Duration};

use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use miden_crypto::{merkle::MerkleTree, Felt, Word, ONE};
use rand_utils::prng_array;

fn balanced_merkle_even(c: &mut Criterion) {
    c.bench_function("balanced-merkle-even", |b| {
        b.iter_batched(
            || {
                let entries: Vec<Word> =
                    (0..256).map(|i| [Felt::new(i), ONE, ONE, Felt::new(i)]).collect();
                assert_eq!(entries.len(), 256);
                entries
            },
            |leaves| {
                let tree = MerkleTree::new(hint::black_box(leaves)).unwrap();
                assert_eq!(tree.depth(), 8);
            },
            BatchSize::SmallInput,
        );
    });
}

fn balanced_merkle_rand(c: &mut Criterion) {
    let mut seed = [0u8; 32];
    c.bench_function("balanced-merkle-rand", |b| {
        b.iter_batched(
            || {
                let entries: Vec<Word> = (0..256).map(|_| generate_word(&mut seed)).collect();
                assert_eq!(entries.len(), 256);
                entries
            },
            |leaves| {
                let tree = MerkleTree::new(hint::black_box(leaves)).unwrap();
                assert_eq!(tree.depth(), 8);
            },
            BatchSize::SmallInput,
        );
    });
}

criterion_group! {
    name = smt_subtree_group;
    config = Criterion::default()
        .measurement_time(Duration::from_secs(20))
        .configure_from_args();
    targets = balanced_merkle_even, balanced_merkle_rand
}
criterion_main!(smt_subtree_group);

// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------

fn generate_word(seed: &mut [u8; 32]) -> Word {
    mem::swap(seed, &mut prng_array(*seed));
    let nums: [u64; 4] = prng_array(*seed);
    [Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])]
}
benches/smt-subtree.rs (new file, +142)
@@ -0,0 +1,142 @@
use std::{fmt::Debug, hint, mem, time::Duration};

use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use miden_crypto::{
    hash::rpo::RpoDigest,
    merkle::{build_subtree_for_bench, NodeIndex, SmtLeaf, SubtreeLeaf, SMT_DEPTH},
    Felt, Word, ONE,
};
use rand_utils::prng_array;
use winter_utils::Randomizable;

const PAIR_COUNTS: [u64; 5] = [1, 64, 128, 192, 256];

fn smt_subtree_even(c: &mut Criterion) {
    let mut seed = [0u8; 32];

    let mut group = c.benchmark_group("subtree8-even");

    for pair_count in PAIR_COUNTS {
        let bench_id = BenchmarkId::from_parameter(pair_count);
        group.bench_with_input(bench_id, &pair_count, |b, &pair_count| {
            b.iter_batched(
                || {
                    // Setup.
                    let entries: Vec<(RpoDigest, Word)> = (0..pair_count)
                        .map(|n| {
                            // A single depth-8 subtree can have a maximum of 255 leaves.
                            let leaf_index = ((n as f64 / pair_count as f64) * 255.0) as u64;
                            let key = RpoDigest::new([
                                generate_value(&mut seed),
                                ONE,
                                Felt::new(n),
                                Felt::new(leaf_index),
                            ]);
                            let value = generate_word(&mut seed);
                            (key, value)
                        })
                        .collect();

                    let mut leaves: Vec<_> = entries
                        .iter()
                        .map(|(key, value)| {
                            let leaf = SmtLeaf::new_single(*key, *value);
                            let col = NodeIndex::from(leaf.index()).value();
                            let hash = leaf.hash();
                            SubtreeLeaf { col, hash }
                        })
                        .collect();
                    leaves.sort();
                    leaves.dedup_by_key(|leaf| leaf.col);
                    leaves
                },
                |leaves| {
                    // Benchmarked function.
                    let (subtree, _) = build_subtree_for_bench(
                        hint::black_box(leaves),
                        hint::black_box(SMT_DEPTH),
                        hint::black_box(SMT_DEPTH),
                    );
                    assert!(!subtree.is_empty());
                },
                BatchSize::SmallInput,
            );
        });
    }
}

fn smt_subtree_random(c: &mut Criterion) {
    let mut seed = [0u8; 32];

    let mut group = c.benchmark_group("subtree8-rand");

    for pair_count in PAIR_COUNTS {
        let bench_id = BenchmarkId::from_parameter(pair_count);
        group.bench_with_input(bench_id, &pair_count, |b, &pair_count| {
            b.iter_batched(
                || {
                    // Setup.
                    let entries: Vec<(RpoDigest, Word)> = (0..pair_count)
                        .map(|i| {
                            let leaf_index: u8 = generate_value(&mut seed);
                            let key = RpoDigest::new([
                                ONE,
                                ONE,
                                Felt::new(i),
                                Felt::new(leaf_index as u64),
                            ]);
                            let value = generate_word(&mut seed);
                            (key, value)
                        })
                        .collect();

                    let mut leaves: Vec<_> = entries
                        .iter()
                        .map(|(key, value)| {
                            let leaf = SmtLeaf::new_single(*key, *value);
                            let col = NodeIndex::from(leaf.index()).value();
                            let hash = leaf.hash();
                            SubtreeLeaf { col, hash }
                        })
                        .collect();
                    leaves.sort();
                    leaves
                },
                |leaves| {
                    let (subtree, _) = build_subtree_for_bench(
                        hint::black_box(leaves),
                        hint::black_box(SMT_DEPTH),
                        hint::black_box(SMT_DEPTH),
                    );
                    assert!(!subtree.is_empty());
                },
                BatchSize::SmallInput,
            );
        });
    }
}

criterion_group! {
    name = smt_subtree_group;
    config = Criterion::default()
        .measurement_time(Duration::from_secs(40))
        .sample_size(60)
        .configure_from_args();
    targets = smt_subtree_even, smt_subtree_random
}
criterion_main!(smt_subtree_group);

// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------

fn generate_value<T: Copy + Debug + Randomizable>(seed: &mut [u8; 32]) -> T {
    mem::swap(seed, &mut prng_array(*seed));
    let value: [T; 1] = rand_utils::prng_array(*seed);
    value[0]
}

fn generate_word(seed: &mut [u8; 32]) -> Word {
    mem::swap(seed, &mut prng_array(*seed));
    let nums: [u64; 4] = prng_array(*seed);
    [Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])]
}
benches/smt-with-entries.rs (new file, +71)
@@ -0,0 +1,71 @@
use std::{fmt::Debug, hint, mem, time::Duration};

use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use miden_crypto::{hash::rpo::RpoDigest, merkle::Smt, Felt, Word, ONE};
use rand_utils::prng_array;
use winter_utils::Randomizable;

// 2^0, 2^4, 2^8, 2^12, 2^16, 2^20
const PAIR_COUNTS: [u64; 6] = [1, 16, 256, 4096, 65536, 1_048_576];

fn smt_with_entries(c: &mut Criterion) {
    let mut seed = [0u8; 32];

    let mut group = c.benchmark_group("smt-with-entries");

    for pair_count in PAIR_COUNTS {
        let bench_id = BenchmarkId::from_parameter(pair_count);
        group.bench_with_input(bench_id, &pair_count, |b, &pair_count| {
            b.iter_batched(
                || {
                    // Setup.
                    prepare_entries(pair_count, &mut seed)
                },
                |entries| {
                    // Benchmarked function.
                    Smt::with_entries(hint::black_box(entries)).unwrap();
                },
                BatchSize::SmallInput,
            );
        });
    }
}

criterion_group! {
    name = smt_with_entries_group;
    config = Criterion::default()
        //.measurement_time(Duration::from_secs(960))
        .measurement_time(Duration::from_secs(60))
        .sample_size(10)
        .configure_from_args();
    targets = smt_with_entries
}
criterion_main!(smt_with_entries_group);

// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------

fn prepare_entries(pair_count: u64, seed: &mut [u8; 32]) -> Vec<(RpoDigest, [Felt; 4])> {
    let entries: Vec<(RpoDigest, Word)> = (0..pair_count)
        .map(|i| {
            let count = pair_count as f64;
            let idx = ((i as f64 / count) * (count)) as u64;
            let key = RpoDigest::new([generate_value(seed), ONE, Felt::new(i), Felt::new(idx)]);
            let value = generate_word(seed);
            (key, value)
        })
        .collect();
    entries
}

fn generate_value<T: Copy + Debug + Randomizable>(seed: &mut [u8; 32]) -> T {
    mem::swap(seed, &mut prng_array(*seed));
    let value: [T; 1] = rand_utils::prng_array(*seed);
    value[0]
}

fn generate_word(seed: &mut [u8; 32]) -> Word {
    mem::swap(seed, &mut prng_array(*seed));
    let nums: [u64; 4] = prng_array(*seed);
    [Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])]
}
31 build.rs
@@ -1,39 +1,8 @@
fn main() {
    #[cfg(feature = "std")]
    compile_rpo_falcon();

    #[cfg(target_feature = "sve")]
    compile_arch_arm64_sve();
}

#[cfg(feature = "std")]
fn compile_rpo_falcon() {
    use std::path::PathBuf;

    const RPO_FALCON_PATH: &str = "src/dsa/rpo_falcon512/falcon_c";

    println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/falcon.h");
    println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/falcon.c");
    println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/rpo.h");
    println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/rpo.c");

    let target_dir: PathBuf = ["PQClean", "crypto_sign", "falcon-512", "clean"].iter().collect();
    let common_dir: PathBuf = ["PQClean", "common"].iter().collect();

    let scheme_files = glob::glob(target_dir.join("*.c").to_str().unwrap()).unwrap();
    let common_files = glob::glob(common_dir.join("*.c").to_str().unwrap()).unwrap();

    cc::Build::new()
        .include(&common_dir)
        .include(target_dir)
        .files(scheme_files.into_iter().map(|p| p.unwrap().to_string_lossy().into_owned()))
        .files(common_files.into_iter().map(|p| p.unwrap().to_string_lossy().into_owned()))
        .file(format!("{RPO_FALCON_PATH}/falcon.c"))
        .file(format!("{RPO_FALCON_PATH}/rpo.c"))
        .flag("-O3")
        .compile("rpo_falcon512");
}

#[cfg(target_feature = "sve")]
fn compile_arch_arm64_sve() {
    const RPO_SVE_PATH: &str = "arch/arm64-sve/rpo";

5 rust-toolchain.toml Normal file
@@ -0,0 +1,5 @@
[toolchain]
channel = "1.82"
components = ["rustfmt", "rust-src", "clippy"]
targets = ["wasm32-unknown-unknown"]
profile = "minimal"
24 rustfmt.toml
@@ -2,20 +2,22 @@ edition = "2021"
array_width = 80
attr_fn_like_width = 80
chain_width = 80
#condense_wildcard_suffixes = true
#enum_discrim_align_threshold = 40
comment_width = 100
condense_wildcard_suffixes = true
fn_call_width = 80
#fn_single_line = true
#format_code_in_doc_comments = true
#format_macro_matchers = true
#format_strings = true
#group_imports = "StdExternalCrate"
#hex_literal_case = "Lower"
#imports_granularity = "Crate"
format_code_in_doc_comments = true
format_macro_matchers = true
group_imports = "StdExternalCrate"
hex_literal_case = "Lower"
imports_granularity = "Crate"
match_block_trailing_comma = true
newline_style = "Unix"
#normalize_doc_attributes = true
#reorder_impl_items = true
reorder_imports = true
reorder_modules = true
single_line_if_else_max_width = 60
single_line_let_else_max_width = 60
struct_lit_width = 40
struct_variant_width = 40
use_field_init_shorthand = true
use_try_shorthand = true
wrap_comments = true

21 scripts/check-changelog.sh Executable file
@@ -0,0 +1,21 @@
#!/bin/bash
set -uo pipefail

CHANGELOG_FILE="${1:-CHANGELOG.md}"

if [ "${NO_CHANGELOG_LABEL}" = "true" ]; then
    # 'no changelog' set, so finish successfully
    echo "\"no changelog\" label has been set"
    exit 0
else
    # a changelog check is required
    # fail if the diff is empty
    if git diff --exit-code "origin/${BASE_REF}" -- "${CHANGELOG_FILE}"; then
        >&2 echo "Changes should come with an entry in the \"CHANGELOG.md\" file. This behavior
can be overridden by using the \"no changelog\" label, which is used for changes
that are trivial / explicitly stated not to require a changelog entry."
        exit 1
    fi

    echo "The \"CHANGELOG.md\" file has been updated."
fi
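# Example invocation (added for illustration; mirrors the environment variables
# the script reads above -- BASE_REF and NO_CHANGELOG_LABEL must be exported by
# the caller, e.g. the CI workflow):
#   BASE_REF=main NO_CHANGELOG_LABEL=false ./scripts/check-changelog.sh CHANGELOG.md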
15 scripts/check-rust-version.sh Executable file
@@ -0,0 +1,15 @@
#!/bin/bash

# Get rust-toolchain.toml file channel
TOOLCHAIN_VERSION=$(grep 'channel' rust-toolchain.toml | sed -E 's/.*"(.*)".*/\1/')

# Get workspace Cargo.toml file rust-version
CARGO_VERSION=$(grep 'rust-version' Cargo.toml | sed -E 's/.*"(.*)".*/\1/')

# Check version match
if [ "$CARGO_VERSION" != "$TOOLCHAIN_VERSION" ]; then
    echo "Mismatch in Cargo.toml: Expected $TOOLCHAIN_VERSION, found $CARGO_VERSION"
    exit 1
fi

echo "Rust versions match ✅"
@@ -1,3 +1,5 @@
//! Digital signature schemes supported by default in the Miden VM.

pub mod rpo_falcon512;

pub mod rpo_stark;

@@ -1,56 +0,0 @@
use core::fmt;

use super::{LOG_N, MODULUS, PK_LEN};

#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FalconError {
    KeyGenerationFailed,
    PubKeyDecodingExtraData,
    PubKeyDecodingInvalidCoefficient(u32),
    PubKeyDecodingInvalidLength(usize),
    PubKeyDecodingInvalidTag(u8),
    SigDecodingTooBigHighBits(u32),
    SigDecodingInvalidRemainder,
    SigDecodingNonZeroUnusedBitsLastByte,
    SigDecodingMinusZero,
    SigDecodingIncorrectEncodingAlgorithm,
    SigDecodingNotSupportedDegree(u8),
    SigGenerationFailed,
}

impl fmt::Display for FalconError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use FalconError::*;
        match self {
            KeyGenerationFailed => write!(f, "Failed to generate a private-public key pair"),
            PubKeyDecodingExtraData => {
                write!(f, "Failed to decode public key: input not fully consumed")
            }
            PubKeyDecodingInvalidCoefficient(val) => {
                write!(f, "Failed to decode public key: coefficient {val} is greater than or equal to the field modulus {MODULUS}")
            }
            PubKeyDecodingInvalidLength(len) => {
                write!(f, "Failed to decode public key: expected {PK_LEN} bytes but received {len}")
            }
            PubKeyDecodingInvalidTag(byte) => {
                write!(f, "Failed to decode public key: expected the first byte to be {LOG_N} but was {byte}")
            }
            SigDecodingTooBigHighBits(m) => {
                write!(f, "Failed to decode signature: high bits {m} exceed 2048")
            }
            SigDecodingInvalidRemainder => {
                write!(f, "Failed to decode signature: incorrect remaining data")
            }
            SigDecodingNonZeroUnusedBitsLastByte => {
                write!(f, "Failed to decode signature: Non-zero unused bits in the last byte")
            }
            SigDecodingMinusZero => write!(f, "Failed to decode signature: -0 is forbidden"),
            SigDecodingIncorrectEncodingAlgorithm => write!(f, "Failed to decode signature: not supported encoding algorithm"),
            SigDecodingNotSupportedDegree(log_n) => write!(f, "Failed to decode signature: only supported irreducible polynomial degree is 512, 2^{log_n} was provided"),
            SigGenerationFailed => write!(f, "Failed to generate a signature"),
        }
    }
}

#[cfg(feature = "std")]
impl std::error::Error for FalconError {}
@@ -1,402 +0,0 @@
/*
 * Wrapper for implementing the PQClean API.
 */

#include <string.h>
#include "randombytes.h"
#include "falcon.h"
#include "inner.h"
#include "rpo.h"

#define NONCELEN 40

/*
 * Encoding formats (nnnn = log of degree, 9 for Falcon-512, 10 for Falcon-1024)
 *
 * private key:
 *   header byte: 0101nnnn
 *   private f  (6 or 5 bits by element, depending on degree)
 *   private g  (6 or 5 bits by element, depending on degree)
 *   private F  (8 bits by element)
 *
 * public key:
 *   header byte: 0000nnnn
 *   public h  (14 bits by element)
 *
 * signature:
 *   header byte: 0011nnnn
 *   nonce  40 bytes
 *   value  (12 bits by element)
 *
 * message + signature:
 *   signature length  (2 bytes, big-endian)
 *   nonce             40 bytes
 *   message
 *   header byte: 0010nnnn
 *   value             (12 bits by element)
 *   (signature length is 1+len(value), not counting the nonce)
 */
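/*
 * Illustrative aside (added, not part of the original file): the concrete
 * Falcon-512 header bytes implied by the table above with nnnn = logn = 9.
 * These match the literals used in the functions below.
 */
enum falcon512_header_sketch {
    FALCON512_SK_HEADER  = 0x50 + 9, /* 0101nnnn -> 0x59 */
    FALCON512_PK_HEADER  = 0x00 + 9, /* 0000nnnn -> 0x09 */
    FALCON512_SIG_HEADER = 0x30 + 9, /* 0011nnnn -> 0x39 */
};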

/* see falcon.h */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
    uint8_t *pk,
    uint8_t *sk,
    unsigned char *seed
) {
    union
    {
        uint8_t b[FALCON_KEYGEN_TEMP_9];
        uint64_t dummy_u64;
        fpr dummy_fpr;
    } tmp;
    int8_t f[512], g[512], F[512];
    uint16_t h[512];
    inner_shake256_context rng;
    size_t u, v;

    /*
     * Generate key pair.
     */
    inner_shake256_init(&rng);
    inner_shake256_inject(&rng, seed, sizeof seed);
    inner_shake256_flip(&rng);
    PQCLEAN_FALCON512_CLEAN_keygen(&rng, f, g, F, NULL, h, 9, tmp.b);
    inner_shake256_ctx_release(&rng);

    /*
     * Encode private key.
     */
    sk[0] = 0x50 + 9;
    u = 1;
    v = PQCLEAN_FALCON512_CLEAN_trim_i8_encode(
        sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u,
        f, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9]);
    if (v == 0)
    {
        return -1;
    }
    u += v;
    v = PQCLEAN_FALCON512_CLEAN_trim_i8_encode(
        sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u,
        g, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9]);
    if (v == 0)
    {
        return -1;
    }
    u += v;
    v = PQCLEAN_FALCON512_CLEAN_trim_i8_encode(
        sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u,
        F, 9, PQCLEAN_FALCON512_CLEAN_max_FG_bits[9]);
    if (v == 0)
    {
        return -1;
    }
    u += v;
    if (u != PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES)
    {
        return -1;
    }

    /*
     * Encode public key.
     */
    pk[0] = 0x00 + 9;
    v = PQCLEAN_FALCON512_CLEAN_modq_encode(
        pk + 1, PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1,
        h, 9);
    if (v != PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1)
    {
        return -1;
    }

    return 0;
}

int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
    uint8_t *pk,
    uint8_t *sk
) {
    unsigned char seed[48];

    /*
     * Generate a random seed.
     */
    randombytes(seed, sizeof seed);

    return PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(pk, sk, seed);
}

/*
 * Compute the signature. nonce[] receives the nonce and must have length
 * NONCELEN bytes. sigbuf[] receives the signature value (without nonce
 * or header byte), with *sigbuflen providing the maximum value length and
 * receiving the actual value length.
 *
 * If a signature could be computed but not encoded because it would
 * exceed the output buffer size, then a new signature is computed. If
 * the provided buffer size is too low, this could loop indefinitely, so
 * the caller must provide a size that can accommodate signatures with a
 * large enough probability.
 *
 * Return value: 0 on success, -1 on error.
 */
static int do_sign(
    uint8_t *nonce,
    uint8_t *sigbuf,
    size_t *sigbuflen,
    const uint8_t *m,
    size_t mlen,
    const uint8_t *sk
) {
    union
    {
        uint8_t b[72 * 512];
        uint64_t dummy_u64;
        fpr dummy_fpr;
    } tmp;
    int8_t f[512], g[512], F[512], G[512];
    struct
    {
        int16_t sig[512];
        uint16_t hm[512];
    } r;
    unsigned char seed[48];
    inner_shake256_context sc;
    rpo128_context rc;
    size_t u, v;

    /*
     * Decode the private key.
     */
    if (sk[0] != 0x50 + 9)
    {
        return -1;
    }
    u = 1;
    v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode(
        f, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9],
        sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u);
    if (v == 0)
    {
        return -1;
    }
    u += v;
    v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode(
        g, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9],
        sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u);
    if (v == 0)
    {
        return -1;
    }
    u += v;
    v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode(
        F, 9, PQCLEAN_FALCON512_CLEAN_max_FG_bits[9],
        sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u);
    if (v == 0)
    {
        return -1;
    }
    u += v;
    if (u != PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES)
    {
        return -1;
    }
    if (!PQCLEAN_FALCON512_CLEAN_complete_private(G, f, g, F, 9, tmp.b))
    {
        return -1;
    }

    /*
     * Create a random nonce (40 bytes).
     */
    randombytes(nonce, NONCELEN);

    /* ==== Start: Deviation from the reference implementation ================================= */

    // Transform the nonce into 8 chunks each of size 5 bytes. We do this in order to be sure that
    // the conversion to field elements succeeds
    uint8_t buffer[64];
    memset(buffer, 0, 64);
    for (size_t i = 0; i < 8; i++)
    {
        buffer[8 * i] = nonce[5 * i];
        buffer[8 * i + 1] = nonce[5 * i + 1];
        buffer[8 * i + 2] = nonce[5 * i + 2];
        buffer[8 * i + 3] = nonce[5 * i + 3];
        buffer[8 * i + 4] = nonce[5 * i + 4];
    }
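    /*
     * Added note (not in the original): each 5-byte chunk is at most 2^40 - 1,
     * which is below the field modulus p = 2^64 - 2^32 + 1, so every zero-padded
     * 8-byte word in 'buffer' is already a canonical field element and the
     * conversion can never fail.
     */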

    /*
     * Hash message nonce + message into a vector.
     */
    rpo128_init(&rc);
    rpo128_absorb(&rc, buffer, NONCELEN + 24);
    rpo128_absorb(&rc, m, mlen);
    rpo128_finalize(&rc);
    PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(&rc, r.hm, 9);
    rpo128_release(&rc);

    /* ==== End: Deviation from the reference implementation =================================== */

    /*
     * Initialize a RNG.
     */
    randombytes(seed, sizeof seed);
    inner_shake256_init(&sc);
    inner_shake256_inject(&sc, seed, sizeof seed);
    inner_shake256_flip(&sc);

    /*
     * Compute and return the signature. This loops until a signature
     * value is found that fits in the provided buffer.
     */
    for (;;)
    {
        PQCLEAN_FALCON512_CLEAN_sign_dyn(r.sig, &sc, f, g, F, G, r.hm, 9, tmp.b);
        v = PQCLEAN_FALCON512_CLEAN_comp_encode(sigbuf, *sigbuflen, r.sig, 9);
        if (v != 0)
        {
            inner_shake256_ctx_release(&sc);
            *sigbuflen = v;
            return 0;
        }
    }
}

/*
 * Verify a signature. The nonce has size NONCELEN bytes. sigbuf[]
 * (of size sigbuflen) contains the signature value, not including the
 * header byte or nonce. Return value is 0 on success, -1 on error.
 */
static int do_verify(
    const uint8_t *nonce,
    const uint8_t *sigbuf,
    size_t sigbuflen,
    const uint8_t *m,
    size_t mlen,
    const uint8_t *pk
) {
    union
    {
        uint8_t b[2 * 512];
        uint64_t dummy_u64;
        fpr dummy_fpr;
    } tmp;
    uint16_t h[512], hm[512];
    int16_t sig[512];
    rpo128_context rc;

    /*
     * Decode public key.
     */
    if (pk[0] != 0x00 + 9)
    {
        return -1;
    }
    if (PQCLEAN_FALCON512_CLEAN_modq_decode(h, 9,
            pk + 1, PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1)
        != PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1)
    {
        return -1;
    }
    PQCLEAN_FALCON512_CLEAN_to_ntt_monty(h, 9);

    /*
     * Decode signature.
     */
    if (sigbuflen == 0)
    {
        return -1;
    }
    if (PQCLEAN_FALCON512_CLEAN_comp_decode(sig, 9, sigbuf, sigbuflen) != sigbuflen)
    {
        return -1;
    }

    /* ==== Start: Deviation from the reference implementation ================================= */

    /*
     * Hash nonce + message into a vector.
     */

    // Transform the nonce into 8 chunks each of size 5 bytes. We do this in order to be sure that
    // the conversion to field elements succeeds
    uint8_t buffer[64];
    memset(buffer, 0, 64);
    for (size_t i = 0; i < 8; i++)
    {
        buffer[8 * i] = nonce[5 * i];
        buffer[8 * i + 1] = nonce[5 * i + 1];
        buffer[8 * i + 2] = nonce[5 * i + 2];
        buffer[8 * i + 3] = nonce[5 * i + 3];
        buffer[8 * i + 4] = nonce[5 * i + 4];
    }

    rpo128_init(&rc);
    rpo128_absorb(&rc, buffer, NONCELEN + 24);
    rpo128_absorb(&rc, m, mlen);
    rpo128_finalize(&rc);
    PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(&rc, hm, 9);
    rpo128_release(&rc);

    /* === End: Deviation from the reference implementation ==================================== */

    /*
     * Verify signature.
     */
    if (!PQCLEAN_FALCON512_CLEAN_verify_raw(hm, sig, h, 9, tmp.b))
    {
        return -1;
    }
    return 0;
}

/* see falcon.h */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
    uint8_t *sig,
    size_t *siglen,
    const uint8_t *m,
    size_t mlen,
    const uint8_t *sk
) {
    /*
     * The PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES constant is used for
     * the signed message object (as produced by crypto_sign())
     * and includes a two-byte length value, so we take care here
     * to only generate signatures that are two bytes shorter than
     * the maximum. This is done to ensure that crypto_sign()
     * and crypto_sign_signature() produce the exact same signature
     * value, if used on the same message, with the same private key,
     * and using the same output from randombytes() (this is for
     * reproducibility of tests).
     */
    size_t vlen;

    vlen = PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES - NONCELEN - 3;
    if (do_sign(sig + 1, sig + 1 + NONCELEN, &vlen, m, mlen, sk) < 0)
    {
        return -1;
    }
    sig[0] = 0x30 + 9;
    *siglen = 1 + NONCELEN + vlen;
    return 0;
}
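/*
 * Added note (not in the original): the resulting detached signature layout is
 *   sig[0]                  header byte 0x30 + 9 = 0x39
 *   sig[1 .. 1 + NONCELEN)  40-byte nonce
 *   sig[1 + NONCELEN .. )   compressed signature value (vlen bytes)
 * so *siglen = 1 + NONCELEN + vlen, which is at most CRYPTO_BYTES - 2.
 */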

/* see falcon.h */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
    const uint8_t *sig,
    size_t siglen,
    const uint8_t *m,
    size_t mlen,
    const uint8_t *pk
) {
    if (siglen < 1 + NONCELEN)
    {
        return -1;
    }
    if (sig[0] != 0x30 + 9)
    {
        return -1;
    }
    return do_verify(sig + 1, sig + 1 + NONCELEN, siglen - 1 - NONCELEN, m, mlen, pk);
}
@@ -1,66 +0,0 @@
#include <stddef.h>
#include <stdint.h>

#define PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES 1281
#define PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES 897
#define PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES 666

/*
 * Generate a new key pair. Public key goes into pk[], private key in sk[].
 * Key sizes are exact (in bytes):
 *   public (pk): PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES
 *   private (sk): PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES
 *
 * Return value: 0 on success, -1 on error.
 *
 * Note: This implementation follows the reference implementation in PQClean
 * https://github.com/PQClean/PQClean/tree/master/crypto_sign/falcon-512
 * verbatim except for the sections that are marked otherwise.
 */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
    uint8_t *pk, uint8_t *sk);

/*
 * Generate a new key pair from seed. Public key goes into pk[], private key in sk[].
 * Key sizes are exact (in bytes):
 *   public (pk): PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES
 *   private (sk): PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES
 *
 * Return value: 0 on success, -1 on error.
 */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
    uint8_t *pk, uint8_t *sk, unsigned char *seed);

/*
 * Compute a signature on a provided message (m, mlen), with a given
 * private key (sk). Signature is written in sig[], with length written
 * into *siglen. Signature length is variable; maximum signature length
 * (in bytes) is PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES.
 *
 * sig[], m[] and sk[] may overlap each other arbitrarily.
 *
 * Return value: 0 on success, -1 on error.
 *
 * Note: This implementation follows the reference implementation in PQClean
 * https://github.com/PQClean/PQClean/tree/master/crypto_sign/falcon-512
 * verbatim except for the sections that are marked otherwise.
 */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
    uint8_t *sig, size_t *siglen,
    const uint8_t *m, size_t mlen, const uint8_t *sk);

/*
 * Verify a signature (sig, siglen) on a message (m, mlen) with a given
 * public key (pk).
 *
 * sig[], m[] and pk[] may overlap each other arbitrarily.
 *
 * Return value: 0 on success, -1 on error.
 *
 * Note: This implementation follows the reference implementation in PQClean
 * https://github.com/PQClean/PQClean/tree/master/crypto_sign/falcon-512
 * verbatim except for the sections that are marked otherwise.
 */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
    const uint8_t *sig, size_t siglen,
    const uint8_t *m, size_t mlen, const uint8_t *pk);
@@ -1,582 +0,0 @@
/*
 * RPO implementation.
 */

#include <stdint.h>
#include <string.h>
#include <stdlib.h>

/* ================================================================================================
 * Modular Arithmetic
 */

#define P 0xFFFFFFFF00000001
#define M 12289

// From https://github.com/ncw/iprime/blob/master/mod_math_noasm.go
static uint64_t add_mod_p(uint64_t a, uint64_t b)
{
    a = P - a;
    uint64_t res = b - a;
    if (b < a)
        res += P;
    return res;
}

static uint64_t sub_mod_p(uint64_t a, uint64_t b)
{
    uint64_t r = a - b;
    if (a < b)
        r += P;
    return r;
}

static uint64_t reduce_mod_p(uint64_t b, uint64_t a)
{
    uint32_t d = b >> 32,
             c = b;
    if (a >= P)
        a -= P;
    a = sub_mod_p(a, c);
    a = sub_mod_p(a, d);
    a = add_mod_p(a, ((uint64_t)c) << 32);
    return a;
}
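/*
 * Added note (not in the original): reduce_mod_p folds the high 64 bits 'b' of
 * a 128-bit product into the low half 'a' using the Goldilocks identity
 * 2^64 = 2^32 - 1 (mod p). Writing b = d * 2^32 + c gives
 * b * 2^64 = c * 2^32 - c - d (mod p), which is exactly the subtract-c,
 * subtract-d, add-(c << 32) sequence above.
 */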

static uint64_t mult_mod_p(uint64_t x, uint64_t y)
{
    uint32_t a = x,
             b = x >> 32,
             c = y,
             d = y >> 32;

    /* first synthesize the product using 32*32 -> 64 bit multiplies */
    x = b * (uint64_t)c; /* b*c */
    y = a * (uint64_t)d; /* a*d */
    uint64_t e = a * (uint64_t)c, /* a*c */
             f = b * (uint64_t)d, /* b*d */
             t;

    x += y; /* b*c + a*d */
    /* carry? */
    if (x < y)
        f += 1LL << 32; /* carry into upper 32 bits - can't overflow */

    t = x << 32;
    e += t; /* a*c + LSW(b*c + a*d) */
    /* carry? */
    if (e < t)
        f += 1; /* carry into upper 64 bits - can't overflow*/
    t = x >> 32;
    f += t; /* b*d + MSW(b*c + a*d) */
    /* can't overflow */

    /* now reduce: (b*d + MSW(b*c + a*d), a*c + LSW(b*c + a*d)) */
    return reduce_mod_p(f, e);
}
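/*
 * Illustrative self-check (added, not part of the original file): two easy
 * identities in GF(p). Multiplying by 1 is the identity, and since
 * p - 1 = -1 (mod p) we have (p - 1)^2 = 1 (mod p).
 */
static int mult_mod_p_sanity_sketch(void)
{
    uint64_t pm1 = P - 1;
    return mult_mod_p(pm1, 1) == pm1 && mult_mod_p(pm1, pm1) == 1;
}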

/* ================================================================================================
 * RPO128 Permutation
 */

#define STATE_WIDTH 12
#define NUM_ROUNDS 7

/*
 * MDS matrix
 */
static const uint64_t MDS[12][12] = {
    { 7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8 },
    { 8, 7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21 },
    { 21, 8, 7, 23, 8, 26, 13, 10, 9, 7, 6, 22 },
    { 22, 21, 8, 7, 23, 8, 26, 13, 10, 9, 7, 6 },
    { 6, 22, 21, 8, 7, 23, 8, 26, 13, 10, 9, 7 },
    { 7, 6, 22, 21, 8, 7, 23, 8, 26, 13, 10, 9 },
    { 9, 7, 6, 22, 21, 8, 7, 23, 8, 26, 13, 10 },
    { 10, 9, 7, 6, 22, 21, 8, 7, 23, 8, 26, 13 },
    { 13, 10, 9, 7, 6, 22, 21, 8, 7, 23, 8, 26 },
    { 26, 13, 10, 9, 7, 6, 22, 21, 8, 7, 23, 8 },
    { 8, 26, 13, 10, 9, 7, 6, 22, 21, 8, 7, 23 },
    { 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8, 7 },
};

/*
 * Round constants.
 */
static const uint64_t ARK1[7][12] = {
    { 5789762306288267392ULL, 6522564764413701783ULL, 17809893479458208203ULL,
      107145243989736508ULL, 6388978042437517382ULL, 15844067734406016715ULL,
      9975000513555218239ULL, 3344984123768313364ULL, 9959189626657347191ULL,
      12960773468763563665ULL, 9602914297752488475ULL, 16657542370200465908ULL },
    { 12987190162843096997ULL, 653957632802705281ULL, 4441654670647621225ULL,
      4038207883745915761ULL, 5613464648874830118ULL, 13222989726778338773ULL,
      3037761201230264149ULL, 16683759727265180203ULL, 8337364536491240715ULL,
      3227397518293416448ULL, 8110510111539674682ULL, 2872078294163232137ULL },
    { 18072785500942327487ULL, 6200974112677013481ULL, 17682092219085884187ULL,
      10599526828986756440ULL, 975003873302957338ULL, 8264241093196931281ULL,
      10065763900435475170ULL, 2181131744534710197ULL, 6317303992309418647ULL,
      1401440938888741532ULL, 8884468225181997494ULL, 13066900325715521532ULL },
    { 5674685213610121970ULL, 5759084860419474071ULL, 13943282657648897737ULL,
      1352748651966375394ULL, 17110913224029905221ULL, 1003883795902368422ULL,
      4141870621881018291ULL, 8121410972417424656ULL, 14300518605864919529ULL,
      13712227150607670181ULL, 17021852944633065291ULL, 6252096473787587650ULL },
    { 4887609836208846458ULL, 3027115137917284492ULL, 9595098600469470675ULL,
      10528569829048484079ULL, 7864689113198939815ULL, 17533723827845969040ULL,
      5781638039037710951ULL, 17024078752430719006ULL, 109659393484013511ULL,
      7158933660534805869ULL, 2955076958026921730ULL, 7433723648458773977ULL },
    { 16308865189192447297ULL, 11977192855656444890ULL, 12532242556065780287ULL,
      14594890931430968898ULL, 7291784239689209784ULL, 5514718540551361949ULL,
      10025733853830934803ULL, 7293794580341021693ULL, 6728552937464861756ULL,
      6332385040983343262ULL, 13277683694236792804ULL, 2600778905124452676ULL },
    { 7123075680859040534ULL, 1034205548717903090ULL, 7717824418247931797ULL,
      3019070937878604058ULL, 11403792746066867460ULL, 10280580802233112374ULL,
      337153209462421218ULL, 13333398568519923717ULL, 3596153696935337464ULL,
      8104208463525993784ULL, 14345062289456085693ULL, 17036731477169661256ULL }};

const uint64_t ARK2[7][12] = {
    { 6077062762357204287ULL, 15277620170502011191ULL, 5358738125714196705ULL,
      14233283787297595718ULL, 13792579614346651365ULL, 11614812331536767105ULL,
      14871063686742261166ULL, 10148237148793043499ULL, 4457428952329675767ULL,
      15590786458219172475ULL, 10063319113072092615ULL, 14200078843431360086ULL },
    { 6202948458916099932ULL, 17690140365333231091ULL, 3595001575307484651ULL,
      373995945117666487ULL, 1235734395091296013ULL, 14172757457833931602ULL,
      707573103686350224ULL, 15453217512188187135ULL, 219777875004506018ULL,
      17876696346199469008ULL, 17731621626449383378ULL, 2897136237748376248ULL },
    { 8023374565629191455ULL, 15013690343205953430ULL, 4485500052507912973ULL,
      12489737547229155153ULL, 9500452585969030576ULL, 2054001340201038870ULL,
      12420704059284934186ULL, 355990932618543755ULL, 9071225051243523860ULL,
      12766199826003448536ULL, 9045979173463556963ULL, 12934431667190679898ULL },
    { 18389244934624494276ULL, 16731736864863925227ULL, 4440209734760478192ULL,
      17208448209698888938ULL, 8739495587021565984ULL, 17000774922218161967ULL,
      13533282547195532087ULL, 525402848358706231ULL, 16987541523062161972ULL,
      5466806524462797102ULL, 14512769585918244983ULL, 10973956031244051118ULL },
    { 6982293561042362913ULL, 14065426295947720331ULL, 16451845770444974180ULL,
      7139138592091306727ULL, 9012006439959783127ULL, 14619614108529063361ULL,
      1394813199588124371ULL, 4635111139507788575ULL, 16217473952264203365ULL,
      10782018226466330683ULL, 6844229992533662050ULL, 7446486531695178711ULL },
    { 3736792340494631448ULL, 577852220195055341ULL, 6689998335515779805ULL,
      13886063479078013492ULL, 14358505101923202168ULL, 7744142531772274164ULL,
      16135070735728404443ULL, 12290902521256031137ULL, 12059913662657709804ULL,
      16456018495793751911ULL, 4571485474751953524ULL, 17200392109565783176ULL },
    { 17130398059294018733ULL, 519782857322261988ULL, 9625384390925085478ULL,
      1664893052631119222ULL, 7629576092524553570ULL, 3485239601103661425ULL,
      9755891797164033838ULL, 15218148195153269027ULL, 16460604813734957368ULL,
      9643968136937729763ULL, 3611348709641382851ULL, 18256379591337759196ULL },
};

static void apply_sbox(uint64_t *const state)
{
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        uint64_t t2 = mult_mod_p(*(state + i), *(state + i));
        uint64_t t4 = mult_mod_p(t2, t2);

        *(state + i) = mult_mod_p(*(state + i), mult_mod_p(t2, t4));
    }
}
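/*
 * Added note (not in the original): with t2 = x^2 and t4 = x^4, each state
 * element is mapped to x * x^2 * x^4 = x^7, the RPO forward S-box.
 */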

static void apply_mds(uint64_t *state)
{
    uint64_t res[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        res[i] = 0;
    }
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        for (uint64_t j = 0; j < STATE_WIDTH; j++)
        {
            res[i] = add_mod_p(res[i], mult_mod_p(MDS[i][j], *(state + j)));
        }
    }

    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        *(state + i) = res[i];
    }
}

static void apply_constants(uint64_t *const state, const uint64_t *ark)
{
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        *(state + i) = add_mod_p(*(state + i), *(ark + i));
    }
}

static void exp_acc(const uint64_t m, const uint64_t *base, const uint64_t *tail, uint64_t *const res)
{
    for (uint64_t i = 0; i < m; i++)
    {
        for (uint64_t j = 0; j < STATE_WIDTH; j++)
        {
            if (i == 0)
            {
                *(res + j) = mult_mod_p(*(base + j), *(base + j));
            }
            else
            {
                *(res + j) = mult_mod_p(*(res + j), *(res + j));
            }
        }
    }

    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        *(res + i) = mult_mod_p(*(res + i), *(tail + i));
    }
}

static void apply_inv_sbox(uint64_t *const state)
{
    uint64_t t1[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t1[i] = 0;
    }
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t1[i] = mult_mod_p(*(state + i), *(state + i));
    }

    uint64_t t2[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t2[i] = 0;
    }
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t2[i] = mult_mod_p(t1[i], t1[i]);
    }

    uint64_t t3[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t3[i] = 0;
    }
    exp_acc(3, t2, t2, t3);

    uint64_t t4[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t4[i] = 0;
    }
    exp_acc(6, t3, t3, t4);

    uint64_t tmp[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        tmp[i] = 0;
    }
    exp_acc(12, t4, t4, tmp);

    uint64_t t5[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t5[i] = 0;
    }
    exp_acc(6, tmp, t3, t5);

    uint64_t t6[STATE_WIDTH];
    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        t6[i] = 0;
    }
    exp_acc(31, t5, t5, t6);

    for (uint64_t i = 0; i < STATE_WIDTH; i++)
    {
        uint64_t a = mult_mod_p(mult_mod_p(t6[i], t6[i]), t5[i]);
        a = mult_mod_p(a, a);
        a = mult_mod_p(a, a);
        uint64_t b = mult_mod_p(mult_mod_p(t1[i], t2[i]), *(state + i));

        *(state + i) = mult_mod_p(a, b);
    }
}
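/*
 * Added note (not in the original): the squaring chain above raises each
 * element to the multiplicative inverse of 7 modulo p - 1, i.e. it computes
 * x -> x^(1/7) and thereby undoes the x^7 S-box applied by apply_sbox().
 */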

static void apply_round(uint64_t *const state, const uint64_t round)
{
    apply_mds(state);
    apply_constants(state, ARK1[round]);
    apply_sbox(state);

    apply_mds(state);
    apply_constants(state, ARK2[round]);
    apply_inv_sbox(state);
}

static void apply_permutation(uint64_t *state)
{
    for (uint64_t i = 0; i < NUM_ROUNDS; i++)
    {
        apply_round(state, i);
    }
}

/* ================================================================================================
 * RPO128 implementation. This is supposed to substitute SHAKE256 in the hash-to-point algorithm.
 */

#include "rpo.h"

void rpo128_init(rpo128_context *rc)
{
    rc->dptr = 32;

    memset(rc->st.A, 0, sizeof rc->st.A);
}

void rpo128_absorb(rpo128_context *rc, const uint8_t *in, size_t len)
{
    size_t dptr;

    dptr = (size_t)rc->dptr;
    while (len > 0)
    {
        size_t clen, u;

        /* 136 * 8 = 1088 bit for the rate portion in the case of SHAKE256
         * For RPO, this is 64 * 8 = 512 bits
         * The capacity for SHAKE256 is at the end while for RPO128 it is at the beginning
         */
        clen = 96 - dptr;
        if (clen > len)
        {
            clen = len;
        }

        for (u = 0; u < clen; u++)
        {
            rc->st.dbuf[dptr + u] = in[u];
        }

        dptr += clen;
        in += clen;
        len -= clen;
        if (dptr == 96)
        {
            apply_permutation(rc->st.A);
            dptr = 32;
        }
    }
    rc->dptr = dptr;
}

void rpo128_finalize(rpo128_context *rc)
{
    // Set dptr to the end of the buffer, so that first call to extract will call the permutation.
    rc->dptr = 96;
}

void rpo128_squeeze(rpo128_context *rc, uint8_t *out, size_t len)
{
    size_t dptr;

    dptr = (size_t)rc->dptr;
    while (len > 0)
    {
        size_t clen;

        if (dptr == 96)
        {
            apply_permutation(rc->st.A);
            dptr = 32;
        }
        clen = 96 - dptr;
        if (clen > len)
        {
            clen = len;
        }
        len -= clen;

        memcpy(out, rc->st.dbuf + dptr, clen);
        dptr += clen;
        out += clen;
    }
    rc->dptr = dptr;
}

void rpo128_release(rpo128_context *rc)
{
    memset(rc->st.A, 0, sizeof rc->st.A);
    rc->dptr = 32;
}
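/*
 * Hedged usage sketch (added, not part of the original file): the sponge call
 * sequence that the signing and verification code in falcon.c performs --
 * absorb, finalize, squeeze, release.
 */
static void rpo128_usage_sketch(const uint8_t *data, size_t len, uint8_t out[32])
{
    rpo128_context rc;
    rpo128_init(&rc);
    rpo128_absorb(&rc, data, len); /* absorb input bytes (rate = 64 bytes) */
    rpo128_finalize(&rc);          /* switch the sponge to squeeze mode */
    rpo128_squeeze(&rc, out, 32);  /* extract output bytes */
    rpo128_release(&rc);           /* zeroize the state */
}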

/* ================================================================================================
 * Hash-to-Point algorithm implementation based on RPO128
 */

void PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(rpo128_context *rc, uint16_t *x, unsigned logn)
{
    /*
     * This implementation avoids the rejection sampling step needed in the
     * per-the-spec implementation. It uses a remark in https://falcon-sign.info/falcon.pdf
     * page 31, which argues that the current variant is secure for the parameters set by NIST.
     * Avoiding the rejection-sampling step leads to an implementation that is constant-time.
     * TODO: Check that the current implementation is indeed constant-time.
     */
    size_t n;

    n = (size_t)1 << logn;
    while (n > 0)
    {
        uint8_t buf[8];
        uint64_t w;

        rpo128_squeeze(rc, (void *)buf, sizeof buf);
        w = ((uint64_t)(buf[7]) << 56) |
            ((uint64_t)(buf[6]) << 48) |
            ((uint64_t)(buf[5]) << 40) |
            ((uint64_t)(buf[4]) << 32) |
            ((uint64_t)(buf[3]) << 24) |
            ((uint64_t)(buf[2]) << 16) |
            ((uint64_t)(buf[1]) << 8) |
            ((uint64_t)(buf[0]));

        w %= M;

        *x++ = (uint16_t)w;
        n--;
    }
}
@@ -1,83 +0,0 @@
#include <stdint.h>
#include <string.h>

/* ================================================================================================
 * RPO hashing algorithm related structs and methods.
 */

/*
 * RPO128 context.
 *
 * This structure is used by the hashing API. It is composed of an internal state that can be
 * viewed as either:
 * 1. 12 field elements in the Miden VM.
 * 2. 96 bytes.
 *
 * The first view is used for the internal state in the context of the RPO hashing algorithm. The
 * second view is used for the buffer used to absorb the data to be hashed.
 *
 * The pointer to the buffer is updated as the data is absorbed.
 *
 * 'rpo128_context' must be initialized with rpo128_init() before first use.
 */
typedef struct
{
    union
    {
        uint64_t A[12];
        uint8_t dbuf[96];
    } st;
    uint64_t dptr;
} rpo128_context;

/*
 * Initializes an RPO state
 */
void rpo128_init(rpo128_context *rc);

/*
 * Absorbs an array of bytes of length 'len' into the state.
 */
void rpo128_absorb(rpo128_context *rc, const uint8_t *in, size_t len);

/*
 * Squeezes an array of bytes of length 'len' from the state.
 */
void rpo128_squeeze(rpo128_context *rc, uint8_t *out, size_t len);

/*
 * Finalizes the state in preparation for squeezing.
 *
 * This function should be called after all the data has been absorbed.
 *
 * Note that the current implementation does not perform any sort of padding for domain separation
 * purposes. The reason being that, for our purposes, we always perform the following sequence:
 * 1. Absorb a Nonce (which is always 40 bytes packed as 8 field elements).
 * 2. Absorb the message (which is always 4 field elements).
 * 3. Call finalize.
 * 4. Squeeze the output.
 * 5. Call release.
 */
void rpo128_finalize(rpo128_context *rc);

/*
 * Releases the state.
 *
 * This function should be called after the squeeze operation is finished.
 */
void rpo128_release(rpo128_context *rc);

/* ================================================================================================
 * Hash-to-Point algorithm for signature generation and signature verification.
 */

/*
 * Hash-to-Point algorithm.
 *
 * This function generates a point in Z_q[x]/(phi) from a given message.
 *
 * It takes a finalized rpo128_context as input and it generates the coefficients of the polynomial
 * representing the point. The coefficients are stored in the array 'x'. The number of coefficients
 * is 2^logn, which in our case must be 512 (i.e., logn = 9).
 */
void PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(rpo128_context *rc, uint16_t *x, unsigned logn);
@@ -1,190 +0,0 @@
use libc::c_int;

// C IMPLEMENTATION INTERFACE
// ================================================================================================

#[link(name = "rpo_falcon512", kind = "static")]
extern "C" {
    /// Generate a new key pair. Public key goes into pk[], private key in sk[].
    /// Key sizes are exact (in bytes):
    /// - public (pk): 897
    /// - private (sk): 1281
    ///
    /// Return value: 0 on success, -1 on error.
    pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(pk: *mut u8, sk: *mut u8) -> c_int;

    /// Generate a new key pair from seed. Public key goes into pk[], private key in sk[].
    /// Key sizes are exact (in bytes):
    /// - public (pk): 897
    /// - private (sk): 1281
    ///
    /// Return value: 0 on success, -1 on error.
    pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
        pk: *mut u8,
        sk: *mut u8,
        seed: *const u8,
    ) -> c_int;

    /// Compute a signature on a provided message (m, mlen), with a given private key (sk).
    /// Signature is written in sig[], with length written into *siglen. Signature length is
    /// variable; maximum signature length (in bytes) is 666.
    ///
    /// sig[], m[] and sk[] may overlap each other arbitrarily.
    ///
    /// Return value: 0 on success, -1 on error.
    pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
        sig: *mut u8,
        siglen: *mut usize,
        m: *const u8,
        mlen: usize,
        sk: *const u8,
    ) -> c_int;

    // TEST HELPERS
    // --------------------------------------------------------------------------------------------

    /// Verify a signature (sig, siglen) on a message (m, mlen) with a given public key (pk).
    ///
    /// sig[], m[] and pk[] may overlap each other arbitrarily.
    ///
    /// Return value: 0 on success, -1 on error.
    #[cfg(test)]
    pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
        sig: *const u8,
        siglen: usize,
        m: *const u8,
        mlen: usize,
        pk: *const u8,
    ) -> c_int;

    /// Hash-to-Point algorithm.
    ///
    /// This function generates a point in Z_q[x]/(phi) from a given message.
    ///
    /// It takes a finalized rpo128_context as input and it generates the coefficients of the
    /// polynomial representing the point. The coefficients are stored in the array 'x'. The
    /// number of coefficients is 2^logn, which in our case must be 512 (i.e., logn = 9).
    #[cfg(test)]
    pub fn PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(
        rc: *mut Rpo128Context,
        x: *mut u16,
        logn: usize,
    );

    #[cfg(test)]
    pub fn rpo128_init(sc: *mut Rpo128Context);

    #[cfg(test)]
    pub fn rpo128_absorb(
        sc: *mut Rpo128Context,
        data: *const ::std::os::raw::c_void,
        len: libc::size_t,
    );

    #[cfg(test)]
    pub fn rpo128_finalize(sc: *mut Rpo128Context);
}

#[repr(C)]
#[cfg(test)]
pub struct Rpo128Context {
    pub content: [u64; 13usize],
}

// TESTS
// ================================================================================================

#[cfg(all(test, feature = "std"))]
mod tests {
    use rand_utils::{rand_array, rand_value, rand_vector};

    use super::*;
    use crate::dsa::rpo_falcon512::{NONCE_LEN, PK_LEN, SIG_LEN, SK_LEN};

    #[test]
    fn falcon_ffi() {
        unsafe {
            //let mut rng = rand::thread_rng();

            // --- generate a key pair from a seed ----------------------------

            let mut pk = [0u8; PK_LEN];
            let mut sk = [0u8; SK_LEN];
            let seed: [u8; NONCE_LEN] = rand_array();

            assert_eq!(
                0,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
                    pk.as_mut_ptr(),
                    sk.as_mut_ptr(),
                    seed.as_ptr()
                )
            );

            // --- sign a message and make sure it verifies -------------------

            let mlen: usize = rand_value::<u16>() as usize;
            let msg: Vec<u8> = rand_vector(mlen);
            let mut detached_sig = [0u8; NONCE_LEN + SIG_LEN];
            let mut siglen = 0;

            assert_eq!(
                0,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
                    detached_sig.as_mut_ptr(),
                    &mut siglen as *mut usize,
                    msg.as_ptr(),
                    msg.len(),
                    sk.as_ptr()
                )
            );

            assert_eq!(
                0,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
                    detached_sig.as_ptr(),
                    siglen,
                    msg.as_ptr(),
                    msg.len(),
                    pk.as_ptr()
                )
            );

            // --- check verification of different signature ------------------

            assert_eq!(
                -1,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
                    detached_sig.as_ptr(),
                    siglen,
                    msg.as_ptr(),
                    msg.len() - 1,
                    pk.as_ptr()
                )
            );

            // --- check verification against a different pub key -------------

            let mut pk_alt = [0u8; PK_LEN];
            let mut sk_alt = [0u8; SK_LEN];
            assert_eq!(
                0,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
                    pk_alt.as_mut_ptr(),
                    sk_alt.as_mut_ptr()
                )
            );

            assert_eq!(
                -1,
                PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
                    detached_sig.as_ptr(),
                    siglen,
                    msg.as_ptr(),
                    msg.len(),
                    pk_alt.as_ptr()
                )
            );
        }
    }
}
70 src/dsa/rpo_falcon512/hash_to_point.rs Normal file
@@ -0,0 +1,70 @@
use alloc::vec::Vec;

use num::Zero;

use super::{math::FalconFelt, Nonce, Polynomial, Rpo256, Word, MODULUS, N, ZERO};

// HASH-TO-POINT FUNCTIONS
// ================================================================================================

/// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message and
/// nonce using RPO256.
pub fn hash_to_point_rpo256(message: Word, nonce: &Nonce) -> Polynomial<FalconFelt> {
    let mut state = [ZERO; Rpo256::STATE_WIDTH];

    // absorb the nonce into the state
    let nonce_elements = nonce.to_elements();
    for (&n, s) in nonce_elements.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) {
        *s = n;
    }
    Rpo256::apply_permutation(&mut state);

    // absorb message into the state
    for (&m, s) in message.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) {
        *s = m;
    }

    // squeeze the coefficients of the polynomial
    let mut i = 0;
    let mut res = [FalconFelt::zero(); N];
    for _ in 0..64 {
        Rpo256::apply_permutation(&mut state);
        for a in &state[Rpo256::RATE_RANGE] {
            res[i] = FalconFelt::new((a.as_int() % MODULUS as u64) as i16);
            i += 1;
        }
    }

    Polynomial::new(res.to_vec())
}
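// Note (added for exposition): each permutation exposes 8 rate elements, so 64
// squeezes yield exactly 64 * 8 = 512 = N coefficients, each reduced modulo
// MODULUS (12289). Unlike the SHAKE256 variant below, no rejection sampling is
// performed, so the squeeze loop executes a fixed number of operations.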

/// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message and
/// nonce using SHAKE256. This is the hash-to-point algorithm used in the reference implementation.
#[allow(dead_code)]
pub fn hash_to_point_shake256(message: &[u8], nonce: &Nonce) -> Polynomial<FalconFelt> {
    use sha3::{
        digest::{ExtendableOutput, Update, XofReader},
        Shake256,
    };

    let mut data = vec![];
    data.extend_from_slice(nonce.as_bytes());
    data.extend_from_slice(message);
    const K: u32 = (1u32 << 16) / MODULUS as u32;

    let mut hasher = Shake256::default();
    hasher.update(&data);
    let mut reader = hasher.finalize_xof();

    let mut coefficients: Vec<FalconFelt> = Vec::with_capacity(N);
    while coefficients.len() != N {
        let mut randomness = [0u8; 2];
        reader.read(&mut randomness);
        let t = ((randomness[0] as u32) << 8) | (randomness[1] as u32);
        if t < K * MODULUS as u32 {
            coefficients.push(FalconFelt::new((t % MODULUS as u32) as i16));
        }
    }

    Polynomial { coefficients }
}
@@ -1,232 +0,0 @@
#[cfg(feature = "std")]
use super::{ffi, NonceBytes, NONCE_LEN, PK_LEN, SIG_LEN, SK_LEN};
use super::{
    ByteReader, ByteWriter, Deserializable, DeserializationError, FalconError, Polynomial,
    PublicKeyBytes, Rpo256, SecretKeyBytes, Serializable, Signature, Word,
};

// PUBLIC KEY
// ================================================================================================

/// A public key for verifying signatures.
///
/// The public key is a [Word] (i.e., 4 field elements) that is the hash of the coefficients of
/// the polynomial representing the raw bytes of the expanded public key.
///
/// For Falcon-512, the first byte of the expanded public key is always equal to log2(512) i.e., 9.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PublicKey(Word);

impl PublicKey {
    /// Returns a new [PublicKey] which is a commitment to the provided expanded public key.
    ///
    /// # Errors
    /// Returns an error if the decoding of the public key fails.
    pub fn new(pk: PublicKeyBytes) -> Result<Self, FalconError> {
        let h = Polynomial::from_pub_key(&pk)?;
        let pk_felts = h.to_elements();
        let pk_digest = Rpo256::hash_elements(&pk_felts).into();
        Ok(Self(pk_digest))
    }

    /// Verifies the provided signature against provided message and this public key.
    pub fn verify(&self, message: Word, signature: &Signature) -> bool {
        signature.verify(message, self.0)
    }
}

impl From<PublicKey> for Word {
    fn from(key: PublicKey) -> Self {
        key.0
    }
}

// KEY PAIR
// ================================================================================================

/// A key pair (public and secret keys) for signing messages.
///
/// The secret key is a byte array of length [SK_LEN].
/// The public key is a byte array of length [PK_LEN].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct KeyPair {
    public_key: PublicKeyBytes,
    secret_key: SecretKeyBytes,
}

#[allow(clippy::new_without_default)]
impl KeyPair {
    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Generates a (public_key, secret_key) key pair from OS-provided randomness.
    ///
    /// # Errors
    /// Returns an error if key generation fails.
    #[cfg(feature = "std")]
    pub fn new() -> Result<Self, FalconError> {
        let mut public_key = [0u8; PK_LEN];
        let mut secret_key = [0u8; SK_LEN];

        let res = unsafe {
            ffi::PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
                public_key.as_mut_ptr(),
                secret_key.as_mut_ptr(),
            )
        };

        if res == 0 {
            Ok(Self { public_key, secret_key })
        } else {
            Err(FalconError::KeyGenerationFailed)
        }
    }

    /// Generates a (public_key, secret_key) key pair from the provided seed.
    ///
    /// # Errors
    /// Returns an error if key generation fails.
    #[cfg(feature = "std")]
    pub fn from_seed(seed: &NonceBytes) -> Result<Self, FalconError> {
        let mut public_key = [0u8; PK_LEN];
        let mut secret_key = [0u8; SK_LEN];

        let res = unsafe {
            ffi::PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
                public_key.as_mut_ptr(),
                secret_key.as_mut_ptr(),
                seed.as_ptr(),
            )
        };

        if res == 0 {
            Ok(Self { public_key, secret_key })
        } else {
            Err(FalconError::KeyGenerationFailed)
        }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the public key corresponding to this key pair.
    pub fn public_key(&self) -> PublicKey {
        // TODO: memoize public key commitment as computing it requires quite a bit of hashing.
        // expect() is fine here because we assume that the key pair was constructed correctly.
        PublicKey::new(self.public_key).expect("invalid key pair")
    }

    /// Returns the expanded public key corresponding to this key pair.
    pub fn expanded_public_key(&self) -> PublicKeyBytes {
        self.public_key
    }

    // SIGNATURE GENERATION
    // --------------------------------------------------------------------------------------------

    /// Signs a message with a secret key and a seed.
    ///
    /// # Errors
    /// Returns an error if signature generation fails.
    #[cfg(feature = "std")]
    pub fn sign(&self, message: Word) -> Result<Signature, FalconError> {
        let msg = message.iter().flat_map(|e| e.as_int().to_le_bytes()).collect::<Vec<_>>();
        let msg_len = msg.len();
        let mut sig = [0_u8; SIG_LEN + NONCE_LEN];
        let mut sig_len: usize = 0;

        let res = unsafe {
            ffi::PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
                sig.as_mut_ptr(),
                &mut sig_len as *mut usize,
                msg.as_ptr(),
                msg_len,
                self.secret_key.as_ptr(),
            )
        };

        if res == 0 {
            Ok(Signature {
                sig,
                pk: self.public_key,
                pk_polynomial: Default::default(),
                sig_polynomial: Default::default(),
            })
        } else {
            Err(FalconError::SigGenerationFailed)
        }
    }
}

// SERIALIZATION / DESERIALIZATION
// ================================================================================================

impl Serializable for KeyPair {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_bytes(&self.public_key);
        target.write_bytes(&self.secret_key);
    }
}

impl Deserializable for KeyPair {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let public_key: PublicKeyBytes = source.read_array()?;
        let secret_key: SecretKeyBytes = source.read_array()?;
        Ok(Self { public_key, secret_key })
    }
}

// TESTS
// ================================================================================================

#[cfg(all(test, feature = "std"))]
mod tests {
    use rand_utils::{rand_array, rand_vector};

    use super::{super::Felt, KeyPair, NonceBytes, Word};

    #[test]
    fn test_falcon_verification() {
        // generate random keys
        let keys = KeyPair::new().unwrap();
        let pk = keys.public_key();

        // sign a random message
        let message: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
        let signature = keys.sign(message);

        // make sure the signature verifies correctly
        assert!(pk.verify(message, signature.as_ref().unwrap()));

        // a signature should not verify against a wrong message
        let message2: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
        assert!(!pk.verify(message2, signature.as_ref().unwrap()));

        // a signature should not verify against a wrong public key
        let keys2 = KeyPair::new().unwrap();
        assert!(!keys2.public_key().verify(message, signature.as_ref().unwrap()))
    }

    #[test]
    fn test_falcon_verification_from_seed() {
        // generate keys from a random seed
        let seed: NonceBytes = rand_array();
        let keys = KeyPair::from_seed(&seed).unwrap();
        let pk = keys.public_key();

        // sign a random message
        let message: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
        let signature = keys.sign(message);

        // make sure the signature verifies correctly
        assert!(pk.verify(message, signature.as_ref().unwrap()));

        // a signature should not verify against a wrong message
        let message2: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
        assert!(!pk.verify(message2, signature.as_ref().unwrap()));

        // a signature should not verify against a wrong public key
        let keys2 = KeyPair::new().unwrap();
        assert!(!keys2.public_key().verify(message, signature.as_ref().unwrap()))
    }
}
55
src/dsa/rpo_falcon512/keys/mod.rs
Normal file
@@ -0,0 +1,55 @@
use super::{
    math::{FalconFelt, Polynomial},
    ByteReader, ByteWriter, Deserializable, DeserializationError, Felt, Serializable, Signature,
    Word,
};

mod public_key;
pub use public_key::{PubKeyPoly, PublicKey};

mod secret_key;
pub use secret_key::SecretKey;

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use rand::SeedableRng;
    use rand_chacha::ChaCha20Rng;
    use winter_math::FieldElement;
    use winter_utils::{Deserializable, Serializable};

    use crate::{dsa::rpo_falcon512::SecretKey, Word, ONE};

    #[test]
    fn test_falcon_verification() {
        let seed = [0_u8; 32];
        let mut rng = ChaCha20Rng::from_seed(seed);

        // generate random keys
        let sk = SecretKey::with_rng(&mut rng);
        let pk = sk.public_key();

        // test secret key serialization/deserialization
        let mut buffer = vec![];
        sk.write_into(&mut buffer);
        let sk_deserialized = SecretKey::read_from_bytes(&buffer).unwrap();
        assert_eq!(sk.short_lattice_basis(), sk_deserialized.short_lattice_basis());

        // sign a message
        let message: Word = [ONE; 4];
        let signature = sk.sign_with_rng(message, &mut rng);

        // make sure the signature verifies correctly
        assert!(pk.verify(message, &signature));

        // a signature should not verify against a wrong message
        let message2: Word = [ONE.double(); 4];
        assert!(!pk.verify(message2, &signature));

        // a signature should not verify against a wrong public key
        let sk2 = SecretKey::with_rng(&mut rng);
        assert!(!sk2.public_key().verify(message, &signature))
    }
}
139
src/dsa/rpo_falcon512/keys/public_key.rs
Normal file
@@ -0,0 +1,139 @@
use alloc::string::ToString;
use core::ops::Deref;

use num::Zero;

use super::{
    super::{Rpo256, LOG_N, N, PK_LEN},
    ByteReader, ByteWriter, Deserializable, DeserializationError, FalconFelt, Felt, Polynomial,
    Serializable, Signature, Word,
};
use crate::dsa::rpo_falcon512::FALCON_ENCODING_BITS;

// PUBLIC KEY
// ================================================================================================

/// A public key for verifying signatures.
///
/// The public key is a [Word] (i.e., 4 field elements) that is the hash of the coefficients of
/// the polynomial representing the raw bytes of the expanded public key. The hash is computed
/// using Rpo256.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PublicKey(Word);

impl PublicKey {
    /// Returns a new [PublicKey] which is a commitment to the provided expanded public key.
    pub fn new(pub_key: Word) -> Self {
        Self(pub_key)
    }

    /// Verifies the provided signature against provided message and this public key.
    pub fn verify(&self, message: Word, signature: &Signature) -> bool {
        signature.verify(message, self.0)
    }
}

impl From<PubKeyPoly> for PublicKey {
    fn from(pk_poly: PubKeyPoly) -> Self {
        let pk_felts: Polynomial<Felt> = pk_poly.0.into();
        let pk_digest = Rpo256::hash_elements(&pk_felts.coefficients).into();
        Self(pk_digest)
    }
}

impl From<PublicKey> for Word {
    fn from(key: PublicKey) -> Self {
        key.0
    }
}

// PUBLIC KEY POLYNOMIAL
// ================================================================================================

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PubKeyPoly(pub Polynomial<FalconFelt>);

impl Deref for PubKeyPoly {
    type Target = Polynomial<FalconFelt>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl From<Polynomial<FalconFelt>> for PubKeyPoly {
    fn from(pk_poly: Polynomial<FalconFelt>) -> Self {
        Self(pk_poly)
    }
}

impl Serializable for &PubKeyPoly {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let mut buf = [0_u8; PK_LEN];
        buf[0] = LOG_N;

        let mut acc = 0_u32;
        let mut acc_len: u32 = 0;

        let mut input_pos = 1;
        for c in self.0.coefficients.iter() {
            let c = c.value();
            acc = (acc << FALCON_ENCODING_BITS) | c as u32;
            acc_len += FALCON_ENCODING_BITS;
            while acc_len >= 8 {
                acc_len -= 8;
                buf[input_pos] = (acc >> acc_len) as u8;
                input_pos += 1;
            }
        }
        if acc_len > 0 {
            // left-align any remaining bits in the final byte; for N = 512 and a 14-bit
            // encoding this branch is never taken since 512 * 14 is a multiple of 8
            buf[input_pos] = (acc << (8 - acc_len)) as u8;
        }

        target.write(buf);
    }
}

impl Deserializable for PubKeyPoly {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let buf = source.read_array::<PK_LEN>()?;

        if buf[0] != LOG_N {
            return Err(DeserializationError::InvalidValue(format!(
                "Failed to decode public key: expected the first byte to be {LOG_N} but was {}",
                buf[0]
            )));
        }

        let mut acc = 0_u32;
        let mut acc_len = 0;

        let mut output = [FalconFelt::zero(); N];
        let mut output_idx = 0;

        for &byte in buf.iter().skip(1) {
            acc = (acc << 8) | (byte as u32);
            acc_len += 8;

            if acc_len >= FALCON_ENCODING_BITS {
                acc_len -= FALCON_ENCODING_BITS;
                let w = (acc >> acc_len) & 0x3fff;
                let element = w.try_into().map_err(|err| {
                    DeserializationError::InvalidValue(format!(
                        "Failed to decode public key: {err}"
                    ))
                })?;
                output[output_idx] = element;
                output_idx += 1;
            }
        }

        if (acc & ((1u32 << acc_len) - 1)) == 0 {
            Ok(Polynomial::new(output.to_vec()).into())
        } else {
            Err(DeserializationError::InvalidValue(
                "Failed to decode public key: input not fully consumed".to_string(),
            ))
        }
    }
}
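
As a sanity check on the 14-bit packing scheme above, the following minimal, self-contained sketch (illustrative only; `pack_14bit` is not part of the diff) isolates the same accumulator technique:

// Illustrative sketch: pack 14-bit values into bytes, most significant bits
// first, mirroring the accumulator loop in `Serializable for &PubKeyPoly`.
fn pack_14bit(values: &[u16]) -> Vec<u8> {
    let mut out = Vec::new();
    let (mut acc, mut acc_len) = (0u32, 0u32);
    for &v in values {
        debug_assert!(v < 1 << 14);
        acc = (acc << 14) | v as u32; // append 14 bits to the accumulator
        acc_len += 14;
        while acc_len >= 8 {
            acc_len -= 8;
            out.push((acc >> acc_len) as u8); // emit the top full byte
        }
    }
    if acc_len > 0 {
        out.push((acc << (8 - acc_len)) as u8); // left-align any tail bits
    }
    out
}

For example, `pack_14bit(&[0x3fff, 0x0001])` yields `[0xff, 0xfc, 0x00, 0x10]`: 28 payload bits followed by 4 zero padding bits. For the public key itself, N * FALCON_ENCODING_BITS = 512 * 14 is a multiple of 8, so the padding branch never runs.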
401
src/dsa/rpo_falcon512/keys/secret_key.rs
Normal file
@@ -0,0 +1,401 @@
use alloc::{string::ToString, vec::Vec};

use num::Complex;
#[cfg(not(feature = "std"))]
use num::Float;
use num_complex::Complex64;
use rand::Rng;

use super::{
    super::{
        math::{ffldl, ffsampling, gram, normalize_tree, FalconFelt, FastFft, LdlTree, Polynomial},
        signature::SignaturePoly,
        ByteReader, ByteWriter, Deserializable, DeserializationError, Nonce, Serializable,
        ShortLatticeBasis, Signature, Word, MODULUS, N, SIGMA, SIG_L2_BOUND,
    },
    PubKeyPoly, PublicKey,
};
use crate::dsa::rpo_falcon512::{
    hash_to_point::hash_to_point_rpo256, math::ntru_gen, SIG_NONCE_LEN, SK_LEN,
};

// CONSTANTS
// ================================================================================================

const WIDTH_BIG_POLY_COEFFICIENT: usize = 8;
const WIDTH_SMALL_POLY_COEFFICIENT: usize = 6;

// SECRET KEY
// ================================================================================================

/// Represents the secret key for Falcon DSA.
///
/// The secret key is a quadruple [[g, -f], [G, -F]] of polynomials with integer coefficients. Each
/// polynomial is of degree at most N = 512 and computations with these polynomials are done modulo
/// the monic irreducible polynomial ϕ = x^N + 1. The secret key is a basis for a lattice and has
/// the property of being short with respect to a certain norm and an upper bound appropriate for
/// a given security parameter. The public key, on the other hand, is another basis for the same
/// lattice and can be described by a single polynomial h with integer coefficients modulo ϕ.
/// The two keys are related by the following relations:
///
/// 1. h = g / f [mod ϕ][mod p]
/// 2. f.G - g.F = p [mod ϕ]
///
/// where p = 12289 is the Falcon prime. Equation 2 is called the NTRU equation.
/// The secret key is generated by first sampling a random pair (f, g) of polynomials using
/// an appropriate distribution that yields short but not too short polynomials with integer
/// coefficients modulo ϕ. The NTRU equation is then used to find a matching pair (F, G).
/// The public key is then derived from the secret key using equation 1.
///
/// To allow for fast signature generation, the secret key is pre-processed into a more suitable
/// form, called the LDL tree, and this allows for fast sampling of short vectors in the lattice
/// using Fast Fourier sampling during signature generation (ffSampling algorithm 11 in [1]).
///
/// [1]: https://falcon-sign.info/falcon.pdf
#[derive(Debug, Clone)]
pub struct SecretKey {
    secret_key: ShortLatticeBasis,
    tree: LdlTree,
}

#[allow(clippy::new_without_default)]
impl SecretKey {
    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Generates a secret key from OS-provided randomness.
    #[cfg(feature = "std")]
    pub fn new() -> Self {
        use rand::{rngs::StdRng, SeedableRng};

        let mut rng = StdRng::from_entropy();
        Self::with_rng(&mut rng)
    }

    /// Generates a secret key using the provided random number generator `Rng`.
    pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
        let basis = ntru_gen(N, rng);
        Self::from_short_lattice_basis(basis)
    }

    /// Given a short basis [[g, -f], [G, -F]], computes the normalized LDL tree, i.e., the
    /// Falcon tree.
    fn from_short_lattice_basis(basis: ShortLatticeBasis) -> SecretKey {
        // FFT each polynomial of the short basis.
        let basis_fft = to_complex_fft(&basis);
        // compute the Gram matrix.
        let gram_fft = gram(basis_fft);
        // construct the LDL tree of the Gram matrix.
        let mut tree = ffldl(gram_fft);
        // normalize the leaves of the LDL tree.
        normalize_tree(&mut tree, SIGMA);
        Self { secret_key: basis, tree }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the polynomials of the short lattice basis of this secret key.
    pub fn short_lattice_basis(&self) -> &ShortLatticeBasis {
        &self.secret_key
    }

    /// Returns the public key corresponding to this secret key.
    pub fn public_key(&self) -> PublicKey {
        self.compute_pub_key_poly().into()
    }

    /// Returns the LDL tree associated to this secret key.
    pub fn tree(&self) -> &LdlTree {
        &self.tree
    }

    // SIGNATURE GENERATION
    // --------------------------------------------------------------------------------------------

    /// Signs a message with this secret key.
    #[cfg(feature = "std")]
    pub fn sign(&self, message: Word) -> Signature {
        use rand::{rngs::StdRng, SeedableRng};

        let mut rng = StdRng::from_entropy();
        self.sign_with_rng(message, &mut rng)
    }

    /// Signs a message with the secret key relying on the provided randomness generator.
    pub fn sign_with_rng<R: Rng>(&self, message: Word, rng: &mut R) -> Signature {
        let mut nonce_bytes = [0u8; SIG_NONCE_LEN];
        rng.fill_bytes(&mut nonce_bytes);
        let nonce = Nonce::new(nonce_bytes);

        let h = self.compute_pub_key_poly();
        let c = hash_to_point_rpo256(message, &nonce);
        let s2 = self.sign_helper(c, rng);

        Signature::new(nonce, h, s2)
    }
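
    // For reference: with c = hash_to_point_rpo256(message, nonce), the sampled polynomial s2
    // satisfies s1 + s2 * h = c [mod ϕ][mod p] for an implicit short s1, with
    // ||(s1, s2)||^2 <= SIG_L2_BOUND, so a verifier can recompute s1 as c - s2 * h.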

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Derives the public key corresponding to this secret key using h = g / f [mod ϕ][mod p].
    pub fn compute_pub_key_poly(&self) -> PubKeyPoly {
        let g: Polynomial<FalconFelt> = self.secret_key[0].clone().into();
        let g_fft = g.fft();
        let minus_f: Polynomial<FalconFelt> = self.secret_key[1].clone().into();
        let f = -minus_f;
        let f_fft = f.fft();
        let h_fft = g_fft.hadamard_div(&f_fft);
        h_fft.ifft().into()
    }

    /// Signs a message polynomial with the secret key.
    ///
    /// Takes a randomness generator implementing `Rng` and a message polynomial `c` representing
    /// the hash-to-point of the message to be signed. It outputs a signature polynomial `s2`.
    fn sign_helper<R: Rng>(&self, c: Polynomial<FalconFelt>, rng: &mut R) -> SignaturePoly {
        let one_over_q = 1.0 / (MODULUS as f64);
        let c_over_q_fft = c.map(|cc| Complex::new(one_over_q * cc.value() as f64, 0.0)).fft();

        // B = [[FFT(g), -FFT(f)], [FFT(G), -FFT(F)]]
        let [g_fft, minus_f_fft, big_g_fft, minus_big_f_fft] = to_complex_fft(&self.secret_key);
        let t0 = c_over_q_fft.hadamard_mul(&minus_big_f_fft);
        let t1 = -c_over_q_fft.hadamard_mul(&minus_f_fft);

        loop {
            let bold_s = loop {
                let z = ffsampling(&(t0.clone(), t1.clone()), &self.tree, rng);
                let t0_min_z0 = t0.clone() - z.0;
                let t1_min_z1 = t1.clone() - z.1;

                // s = (t-z) * B
                let s0 = t0_min_z0.hadamard_mul(&g_fft) + t1_min_z1.hadamard_mul(&big_g_fft);
                let s1 =
                    t0_min_z0.hadamard_mul(&minus_f_fft) + t1_min_z1.hadamard_mul(&minus_big_f_fft);

                // compute the norm of (s0||s1) and note that they are in FFT representation
                let length_squared: f64 =
                    (s0.coefficients.iter().map(|a| (a * a.conj()).re).sum::<f64>()
                        + s1.coefficients.iter().map(|a| (a * a.conj()).re).sum::<f64>())
                        / (N as f64);

                if length_squared > (SIG_L2_BOUND as f64) {
                    continue;
                }

                break [-s0, s1];
            };

            let s2 = bold_s[1].ifft();
            let s2_coef: [i16; N] = s2
                .coefficients
                .iter()
                .map(|a| a.re.round() as i16)
                .collect::<Vec<i16>>()
                .try_into()
                .expect("The number of coefficients should be equal to N");

            if let Ok(s2) = SignaturePoly::try_from(&s2_coef) {
                return s2;
            }
        }
    }
}

// SERIALIZATION / DESERIALIZATION
// ================================================================================================

impl Serializable for SecretKey {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let basis = &self.secret_key;

        // header
        let n = basis[0].coefficients.len();
        let l = n.checked_ilog2().unwrap() as u8;
        let header: u8 = (5 << 4) | l;

        let neg_f = &basis[1];
        let g = &basis[0];
        let neg_big_f = &basis[3];

        let mut buffer = Vec::with_capacity(1281);
        buffer.push(header);

        let f_i8: Vec<i8> = neg_f
            .coefficients
            .iter()
            .map(|&a| FalconFelt::new(-a).balanced_value() as i8)
            .collect();
        let f_i8_encoded = encode_i8(&f_i8, WIDTH_SMALL_POLY_COEFFICIENT).unwrap();
        buffer.extend_from_slice(&f_i8_encoded);

        let g_i8: Vec<i8> = g
            .coefficients
            .iter()
            .map(|&a| FalconFelt::new(a).balanced_value() as i8)
            .collect();
        let g_i8_encoded = encode_i8(&g_i8, WIDTH_SMALL_POLY_COEFFICIENT).unwrap();
        buffer.extend_from_slice(&g_i8_encoded);

        let big_f_i8: Vec<i8> = neg_big_f
            .coefficients
            .iter()
            .map(|&a| FalconFelt::new(-a).balanced_value() as i8)
            .collect();
        let big_f_i8_encoded = encode_i8(&big_f_i8, WIDTH_BIG_POLY_COEFFICIENT).unwrap();
        buffer.extend_from_slice(&big_f_i8_encoded);
        target.write_bytes(&buffer);
    }
}

impl Deserializable for SecretKey {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let byte_vector: [u8; SK_LEN] = source.read_array()?;

        // check length
        if byte_vector.len() < 2 {
            return Err(DeserializationError::InvalidValue("Invalid encoding length: Failed to decode as length is different from the one expected".to_string()));
        }

        // read fields
        let header = byte_vector[0];

        // check fixed bits in header
        if (header >> 4) != 5 {
            return Err(DeserializationError::InvalidValue("Invalid header format".to_string()));
        }

        // check log n
        let logn = (header & 15) as usize;
        let n = 1 << logn;

        // match against const variant generic parameter
        if n != N {
            return Err(DeserializationError::InvalidValue(
                "Unsupported Falcon DSA variant".to_string(),
            ));
        }

        if byte_vector.len() != SK_LEN {
            return Err(DeserializationError::InvalidValue("Invalid encoding length: Failed to decode as length is different from the one expected".to_string()));
        }

        let chunk_size_f = ((n * WIDTH_SMALL_POLY_COEFFICIENT) + 7) >> 3;
        let chunk_size_g = ((n * WIDTH_SMALL_POLY_COEFFICIENT) + 7) >> 3;
        let chunk_size_big_f = ((n * WIDTH_BIG_POLY_COEFFICIENT) + 7) >> 3;

        let f = decode_i8(&byte_vector[1..chunk_size_f + 1], WIDTH_SMALL_POLY_COEFFICIENT).unwrap();
        let g = decode_i8(
            &byte_vector[chunk_size_f + 1..(chunk_size_f + chunk_size_g + 1)],
            WIDTH_SMALL_POLY_COEFFICIENT,
        )
        .unwrap();
        let big_f = decode_i8(
            &byte_vector[(chunk_size_f + chunk_size_g + 1)
                ..(chunk_size_f + chunk_size_g + chunk_size_big_f + 1)],
            WIDTH_BIG_POLY_COEFFICIENT,
        )
        .unwrap();

        let f = Polynomial::new(f.iter().map(|&c| FalconFelt::new(c.into())).collect());
        let g = Polynomial::new(g.iter().map(|&c| FalconFelt::new(c.into())).collect());
        let big_f = Polynomial::new(big_f.iter().map(|&c| FalconFelt::new(c.into())).collect());

        // big_g * f - g * big_f = p (mod X^n + 1)
        let big_g = g.fft().hadamard_div(&f.fft()).hadamard_mul(&big_f.fft()).ifft();
        let basis = [
            g.map(|f| f.balanced_value()),
            -f.map(|f| f.balanced_value()),
            big_g.map(|f| f.balanced_value()),
            -big_f.map(|f| f.balanced_value()),
        ];
        Ok(Self::from_short_lattice_basis(basis))
    }
}

// HELPER FUNCTIONS
// ================================================================================================

/// Computes the complex FFT of the secret key polynomials.
fn to_complex_fft(basis: &[Polynomial<i16>; 4]) -> [Polynomial<Complex<f64>>; 4] {
    let [g, f, big_g, big_f] = basis.clone();
    let g_fft = g.map(|cc| Complex64::new(*cc as f64, 0.0)).fft();
    let minus_f_fft = f.map(|cc| -Complex64::new(*cc as f64, 0.0)).fft();
    let big_g_fft = big_g.map(|cc| Complex64::new(*cc as f64, 0.0)).fft();
    let minus_big_f_fft = big_f.map(|cc| -Complex64::new(*cc as f64, 0.0)).fft();
    [g_fft, minus_f_fft, big_g_fft, minus_big_f_fft]
}

/// Encodes a sequence of signed integers such that each integer x satisfies |x| < 2^(bits-1)
/// for a given parameter bits. bits can take either the value 6 or 8.
pub fn encode_i8(x: &[i8], bits: usize) -> Option<Vec<u8>> {
    let maxv = (1 << (bits - 1)) - 1_usize;
    let maxv = maxv as i8;
    let minv = -maxv;

    for &c in x {
        if c > maxv || c < minv {
            return None;
        }
    }

    let out_len = ((N * bits) + 7) >> 3;
    let mut buf = vec![0_u8; out_len];

    let mut acc = 0_u32;
    let mut acc_len = 0;
    let mask = ((1_u16 << bits) - 1) as u8;

    let mut input_pos = 0;
    for &c in x {
        acc = (acc << bits) | (c as u8 & mask) as u32;
        acc_len += bits;
        while acc_len >= 8 {
            acc_len -= 8;
            buf[input_pos] = (acc >> acc_len) as u8;
            input_pos += 1;
        }
    }
    if acc_len > 0 {
        // left-align any remaining bits in the final byte; for N = 512 and bits in {6, 8}
        // this branch is never taken since N * bits is a multiple of 8
        buf[input_pos] = (acc << (8 - acc_len)) as u8;
    }

    Some(buf)
}

/// Decodes a sequence of bytes into a sequence of signed integers such that each integer x
/// satisfies |x| < 2^(bits-1) for a given parameter bits. bits can take either the value 6 or 8.
pub fn decode_i8(buf: &[u8], bits: usize) -> Option<Vec<i8>> {
    let mut x = [0_i8; N];

    let mut i = 0;
    let mut j = 0;
    let mut acc = 0_u32;
    let mut acc_len = 0;
    let mask = (1_u32 << bits) - 1;
    let a = (1 << bits) as u8;
    let b = ((1 << (bits - 1)) - 1) as u8;

    while i < N {
        acc = (acc << 8) | (buf[j] as u32);
        j += 1;
        acc_len += 8;

        while acc_len >= bits && i < N {
            acc_len -= bits;
            let w = (acc >> acc_len) & mask;

            let w = w as u8;

            let z = if w > b { w as i8 - a as i8 } else { w as i8 };

            x[i] = z;
            i += 1;
        }
    }

    if (acc & ((1u32 << acc_len) - 1)) == 0 {
        Some(x.to_vec())
    } else {
        None
    }
}
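
The encode/decode pair above is round-trip safe for in-range coefficients; a minimal test sketch (illustrative, not part of the diff) in the style of the crate's existing tests:

#[cfg(test)]
mod codec_tests {
    use alloc::vec::Vec;

    use super::{decode_i8, encode_i8, N};

    #[test]
    fn encode_decode_i8_round_trip() {
        // coefficients must satisfy |x| < 2^(bits - 1); with bits = 6 that is |x| <= 31
        let coeffs: Vec<i8> = (0..N).map(|i| (i % 61) as i8 - 30).collect();
        let encoded = encode_i8(&coeffs, 6).expect("all values in range");
        let decoded = decode_i8(&encoded, 6).expect("valid encoding");
        assert_eq!(coeffs, decoded);
    }
}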
124
src/dsa/rpo_falcon512/math/ffsampling.rs
Normal file
@@ -0,0 +1,124 @@
use alloc::boxed::Box;

#[cfg(not(feature = "std"))]
use num::Float;
use num::{One, Zero};
use num_complex::{Complex, Complex64};
use rand::Rng;

use super::{fft::FastFft, polynomial::Polynomial, samplerz::sampler_z};

const SIGMIN: f64 = 1.2778336969128337;

/// Computes the Gram matrix. The argument must be a 2x2 matrix
/// whose elements are equal-length vectors of complex numbers,
/// representing polynomials in FFT domain.
pub fn gram(b: [Polynomial<Complex64>; 4]) -> [Polynomial<Complex64>; 4] {
    const N: usize = 2;
    let mut g: [Polynomial<Complex<f64>>; 4] =
        [Polynomial::zero(), Polynomial::zero(), Polynomial::zero(), Polynomial::zero()];
    for i in 0..N {
        for j in 0..N {
            for k in 0..N {
                g[N * i + j] = g[N * i + j].clone()
                    + b[N * i + k].hadamard_mul(&b[N * j + k].map(|c| c.conj()));
            }
        }
    }
    g
}

/// Computes the LDL decomposition of a 2x2 matrix G such that
///     L D L* = G
/// where D is diagonal, and L is lower-triangular. The elements of the matrices are in FFT domain.
pub fn ldl(
    g: [Polynomial<Complex64>; 4],
) -> ([Polynomial<Complex64>; 4], [Polynomial<Complex64>; 4]) {
    let zero = Polynomial::<Complex64>::zero();
    let one = Polynomial::<Complex64>::one();

    let l10 = g[2].hadamard_div(&g[0]);
    let bc = l10.map(|c| c * c.conj());
    let abc = g[0].hadamard_mul(&bc);
    let d11 = g[3].clone() - abc;

    let l = [one.clone(), zero.clone(), l10.clone(), one];
    let d = [g[0].clone(), zero.clone(), zero, d11];
    (l, d)
}
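
// Concretely, for a Hermitian 2x2 matrix G = [[g00, g01], [g10, g11]], the
// decomposition computed by `ldl` above is
//     L = [[1, 0], [g10 / g00, 1]],    D = diag(g00, g11 - g10 * conj(g10) / g00),
// i.e., l10 = g10 / g00 and d11 = g11 - g00 * |l10|^2.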

#[derive(Debug, Clone)]
pub enum LdlTree {
    Branch(Polynomial<Complex64>, Box<LdlTree>, Box<LdlTree>),
    Leaf([Complex64; 2]),
}

/// Computes the LDL Tree of G. Corresponds to Algorithm 9 of the specification [1, p.37].
/// The argument is a 2x2 matrix of polynomials, given in FFT form.
///
/// [1]: https://falcon-sign.info/falcon.pdf
pub fn ffldl(gram_matrix: [Polynomial<Complex64>; 4]) -> LdlTree {
    let n = gram_matrix[0].coefficients.len();
    let (l, d) = ldl(gram_matrix);

    if n > 2 {
        let (d00, d01) = d[0].split_fft();
        let (d10, d11) = d[3].split_fft();
        let g0 = [d00.clone(), d01.clone(), d01.map(|c| c.conj()), d00];
        let g1 = [d10.clone(), d11.clone(), d11.map(|c| c.conj()), d10];
        LdlTree::Branch(l[2].clone(), Box::new(ffldl(g0)), Box::new(ffldl(g1)))
    } else {
        LdlTree::Branch(
            l[2].clone(),
            Box::new(LdlTree::Leaf(d[0].clone().coefficients.try_into().unwrap())),
            Box::new(LdlTree::Leaf(d[3].clone().coefficients.try_into().unwrap())),
        )
    }
}

/// Normalizes the leaves of an LDL tree using a given normalization value `sigma`.
pub fn normalize_tree(tree: &mut LdlTree, sigma: f64) {
    match tree {
        LdlTree::Branch(_ell, left, right) => {
            normalize_tree(left, sigma);
            normalize_tree(right, sigma);
        },
        LdlTree::Leaf(vector) => {
            vector[0] = Complex::new(sigma / vector[0].re.sqrt(), 0.0);
            vector[1] = Complex64::zero();
        },
    }
}

/// Samples short polynomials using a Falcon tree. Algorithm 11 from the spec [1, p.40].
///
/// [1]: https://falcon-sign.info/falcon.pdf
pub fn ffsampling<R: Rng>(
    t: &(Polynomial<Complex64>, Polynomial<Complex64>),
    tree: &LdlTree,
    mut rng: &mut R,
) -> (Polynomial<Complex64>, Polynomial<Complex64>) {
    match tree {
        LdlTree::Branch(ell, left, right) => {
            let bold_t1 = t.1.split_fft();
            let bold_z1 = ffsampling(&bold_t1, right, rng);
            let z1 = Polynomial::<Complex64>::merge_fft(&bold_z1.0, &bold_z1.1);

            // t0' = t0 + (t1 - z1) * l
            let t0_prime = t.0.clone() + (t.1.clone() - z1.clone()).hadamard_mul(ell);

            let bold_t0 = t0_prime.split_fft();
            let bold_z0 = ffsampling(&bold_t0, left, rng);
            let z0 = Polynomial::<Complex64>::merge_fft(&bold_z0.0, &bold_z0.1);

            (z0, z1)
        },
        LdlTree::Leaf(value) => {
            let z0 = sampler_z(t.0.coefficients[0].re, value[0].re, SIGMIN, &mut rng);
            let z1 = sampler_z(t.1.coefficients[0].re, value[0].re, SIGMIN, &mut rng);
            (
                Polynomial::new(vec![Complex64::new(z0 as f64, 0.0)]),
                Polynomial::new(vec![Complex64::new(z1 as f64, 0.0)]),
            )
        },
    }
}
1919
src/dsa/rpo_falcon512/math/fft.rs
Normal file
File diff suppressed because it is too large
174
src/dsa/rpo_falcon512/math/field.rs
Normal file
@@ -0,0 +1,174 @@
use alloc::string::String;
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};

use num::{One, Zero};

use super::{fft::CyclotomicFourier, Inverse, MODULUS};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FalconFelt(u32);

impl FalconFelt {
    pub const fn new(value: i16) -> Self {
        let gtz_bool = value >= 0;
        let gtz_int = gtz_bool as i16;
        let gtz_sign = gtz_int - ((!gtz_bool) as i16);
        let reduced = gtz_sign * (gtz_sign * value) % MODULUS;
        let canonical_representative = (reduced + MODULUS * (1 - gtz_int)) as u32;
        FalconFelt(canonical_representative)
    }

    pub const fn value(&self) -> i16 {
        self.0 as i16
    }

    pub fn balanced_value(&self) -> i16 {
        let value = self.value();
        let g = (value > ((MODULUS) / 2)) as i16;
        value - (MODULUS) * g
    }

    pub const fn multiply(&self, other: Self) -> Self {
        FalconFelt((self.0 * other.0) % MODULUS as u32)
    }
}

impl Add for FalconFelt {
    type Output = Self;

    #[allow(clippy::suspicious_arithmetic_impl)]
    fn add(self, rhs: Self) -> Self::Output {
        let (s, _) = self.0.overflowing_add(rhs.0);
        let (d, n) = s.overflowing_sub(MODULUS as u32);
        let (r, _) = d.overflowing_add(MODULUS as u32 * (n as u32));
        FalconFelt(r)
    }
}

impl AddAssign for FalconFelt {
    fn add_assign(&mut self, rhs: Self) {
        *self = *self + rhs;
    }
}

impl Sub for FalconFelt {
    type Output = Self;

    fn sub(self, rhs: Self) -> Self::Output {
        self + -rhs
    }
}

impl SubAssign for FalconFelt {
    fn sub_assign(&mut self, rhs: Self) {
        *self = *self - rhs;
    }
}

impl Neg for FalconFelt {
    type Output = FalconFelt;

    fn neg(self) -> Self::Output {
        let is_nonzero = self.0 != 0;
        let r = MODULUS as u32 - self.0;
        FalconFelt(r * (is_nonzero as u32))
    }
}

impl Mul for FalconFelt {
    type Output = Self;

    fn mul(self, rhs: Self) -> Self::Output {
        FalconFelt((self.0 * rhs.0) % MODULUS as u32)
    }
}

impl MulAssign for FalconFelt {
    fn mul_assign(&mut self, rhs: Self) {
        *self = *self * rhs;
    }
}

impl Div for FalconFelt {
    type Output = FalconFelt;

    #[allow(clippy::suspicious_arithmetic_impl)]
    fn div(self, rhs: Self) -> Self::Output {
        self * rhs.inverse_or_zero()
    }
}

impl DivAssign for FalconFelt {
    fn div_assign(&mut self, rhs: Self) {
        *self = *self / rhs
    }
}

impl Zero for FalconFelt {
    fn zero() -> Self {
        FalconFelt::new(0)
    }

    fn is_zero(&self) -> bool {
        self.0 == 0
    }
}

impl One for FalconFelt {
    fn one() -> Self {
        FalconFelt::new(1)
    }
}

impl Inverse for FalconFelt {
    fn inverse_or_zero(self) -> Self {
        // q-2 = 0b10 11 11 11 11 11 11
        let two = self.multiply(self);
        let three = two.multiply(self);
        let six = three.multiply(three);
        let twelve = six.multiply(six);
        let fifteen = twelve.multiply(three);
        let thirty = fifteen.multiply(fifteen);
        let sixty = thirty.multiply(thirty);
        let sixty_three = sixty.multiply(three);

        let sixty_three_sq = sixty_three.multiply(sixty_three);
        let sixty_three_qu = sixty_three_sq.multiply(sixty_three_sq);
        let sixty_three_oc = sixty_three_qu.multiply(sixty_three_qu);
        let sixty_three_hx = sixty_three_oc.multiply(sixty_three_oc);
        let sixty_three_tt = sixty_three_hx.multiply(sixty_three_hx);
        let sixty_three_sf = sixty_three_tt.multiply(sixty_three_tt);

        let all_ones = sixty_three_sf.multiply(sixty_three);
        let two_e_twelve = all_ones.multiply(self);
        let two_e_thirteen = two_e_twelve.multiply(two_e_twelve);

        two_e_thirteen.multiply(all_ones)
    }
}
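
// Sanity check on the addition chain above: by Fermat's little theorem,
// a^-1 = a^(q - 2) with q - 2 = 12287 = 2^13 + (2^12 - 1). The chain builds
// all_ones = a^(2^12 - 1) = a^4095 and two_e_thirteen = a^8192, and
// a^8192 * a^4095 = a^12287, as required.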

impl CyclotomicFourier for FalconFelt {
    fn primitive_root_of_unity(n: usize) -> Self {
        let log2n = n.ilog2();
        assert!(log2n <= 12);
        // 1331 is a primitive 2^12-th root of unity modulo q; squaring it
        // (12 - log2(n)) times yields a primitive n-th root of unity
        let mut a = FalconFelt::new(1331);
        let num_squarings = 12 - n.ilog2();
        for _ in 0..num_squarings {
            a *= a;
        }
        a
    }
}

impl TryFrom<u32> for FalconFelt {
    type Error = String;

    fn try_from(value: u32) -> Result<Self, Self::Error> {
        if value >= MODULUS as u32 {
            Err(format!("value {value} is greater than or equal to the field modulus {MODULUS}"))
        } else {
            Ok(FalconFelt::new(value as i16))
        }
    }
}
322
src/dsa/rpo_falcon512/math/mod.rs
Normal file
@@ -0,0 +1,322 @@
//! Contains different structs and methods related to the Falcon DSA.
//!
//! It uses and acknowledges the work in:
//!
//! 1. The [reference](https://falcon-sign.info/impl/README.txt.html) implementation by Thomas
//!    Pornin.
//! 2. The [Rust](https://github.com/aszepieniec/falcon-rust) implementation by Alan Szepieniec.
use alloc::{string::String, vec::Vec};
use core::ops::MulAssign;

#[cfg(not(feature = "std"))]
use num::Float;
use num::{BigInt, FromPrimitive, One, Zero};
use num_complex::Complex64;
use rand::Rng;

use super::MODULUS;

mod fft;
pub use fft::{CyclotomicFourier, FastFft};

mod field;
pub use field::FalconFelt;

mod ffsampling;
pub use ffsampling::{ffldl, ffsampling, gram, normalize_tree, LdlTree};

mod samplerz;
use self::samplerz::sampler_z;

mod polynomial;
pub use polynomial::Polynomial;

pub trait Inverse: Copy + Zero + MulAssign + One {
    /// Gets the inverse of a, or zero if it is zero.
    fn inverse_or_zero(self) -> Self;

    /// Gets the inverses of a batch of elements, skipping over any that are zero.
    fn batch_inverse_or_zero(batch: &[Self]) -> Vec<Self> {
        let mut acc = Self::one();
        let mut rp: Vec<Self> = Vec::with_capacity(batch.len());
        for batch_item in batch {
            if !batch_item.is_zero() {
                rp.push(acc);
                acc = *batch_item * acc;
            } else {
                rp.push(Self::zero());
            }
        }
        let mut inv = Self::inverse_or_zero(acc);
        for i in (0..batch.len()).rev() {
            if !batch[i].is_zero() {
                rp[i] *= inv;
                inv *= batch[i];
            }
        }
        rp
    }
}

impl Inverse for Complex64 {
    fn inverse_or_zero(self) -> Self {
        let modulus = self.re * self.re + self.im * self.im;
        Complex64::new(self.re / modulus, -self.im / modulus)
    }
    fn batch_inverse_or_zero(batch: &[Self]) -> Vec<Self> {
        batch.iter().map(|&c| Complex64::new(1.0, 0.0) / c).collect()
    }
}

impl Inverse for f64 {
    fn inverse_or_zero(self) -> Self {
        1.0 / self
    }
    fn batch_inverse_or_zero(batch: &[Self]) -> Vec<Self> {
        batch.iter().map(|&c| 1.0 / c).collect()
    }
}

/// Samples 4 small polynomials f, g, F, G such that f * G - g * F = q mod (X^n + 1).
/// Algorithm 5 (NTRUgen) of the documentation [1, p.34].
///
/// [1]: https://falcon-sign.info/falcon.pdf
pub(crate) fn ntru_gen<R: Rng>(n: usize, rng: &mut R) -> [Polynomial<i16>; 4] {
    loop {
        let f = gen_poly(n, rng);
        let g = gen_poly(n, rng);
        let f_ntt = f.map(|&i| FalconFelt::new(i)).fft();
        if f_ntt.coefficients.iter().any(|e| e.is_zero()) {
            continue;
        }
        let gamma = gram_schmidt_norm_squared(&f, &g);
        if gamma > 1.3689f64 * (MODULUS as f64) {
            continue;
        }

        if let Some((capital_f, capital_g)) =
            ntru_solve(&f.map(|&i| i.into()), &g.map(|&i| i.into()))
        {
            return [
                g,
                -f,
                capital_g.map(|i| i.try_into().unwrap()),
                -capital_f.map(|i| i.try_into().unwrap()),
            ];
        }
    }
}

/// Solves the NTRU equation. Given f, g in ZZ[X], find F, G in ZZ[X] such that:
///
///     f G - g F = q mod (X^n + 1)
///
/// Algorithm 6 of the specification [1, p.35].
///
/// [1]: https://falcon-sign.info/falcon.pdf
fn ntru_solve(
    f: &Polynomial<BigInt>,
    g: &Polynomial<BigInt>,
) -> Option<(Polynomial<BigInt>, Polynomial<BigInt>)> {
    let n = f.coefficients.len();
    if n == 1 {
        let (gcd, u, v) = xgcd(&f.coefficients[0], &g.coefficients[0]);
        if gcd != BigInt::one() {
            return None;
        }
        return Some((
            (Polynomial::new(vec![-v * BigInt::from_u32(MODULUS as u32).unwrap()])),
            Polynomial::new(vec![u * BigInt::from_u32(MODULUS as u32).unwrap()]),
        ));
    }

    let f_prime = f.field_norm();
    let g_prime = g.field_norm();

    let (capital_f_prime, capital_g_prime) = ntru_solve(&f_prime, &g_prime)?;
    let capital_f_prime_xsq = capital_f_prime.lift_next_cyclotomic();
    let capital_g_prime_xsq = capital_g_prime.lift_next_cyclotomic();

    let f_minx = f.galois_adjoint();
    let g_minx = g.galois_adjoint();

    let mut capital_f = (capital_f_prime_xsq.karatsuba(&g_minx)).reduce_by_cyclotomic(n);
    let mut capital_g = (capital_g_prime_xsq.karatsuba(&f_minx)).reduce_by_cyclotomic(n);

    match babai_reduce(f, g, &mut capital_f, &mut capital_g) {
        Ok(_) => Some((capital_f, capital_g)),
        Err(_e) => {
            #[cfg(test)]
            {
                panic!("{}", _e);
            }
            #[cfg(not(test))]
            {
                None
            }
        },
    }
}
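
// The recursion in `ntru_solve` exploits the tower identity N(f)(x^2) = f(x) * f(-x):
// a solution (F', G') one level down is lifted via F(x) = F'(x^2) * g(-x) and
// G(x) = G'(x^2) * f(-x), both reduced modulo x^n + 1, and then shortened with
// `babai_reduce`.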

/// Generates a polynomial of degree at most n-1 whose coefficients are distributed according
/// to a discrete Gaussian with mu = 0 and sigma = 1.17 * sqrt(Q / (2n)).
fn gen_poly<R: Rng>(n: usize, rng: &mut R) -> Polynomial<i16> {
    let mu = 0.0;
    let sigma_star = 1.43300980528773;
    Polynomial {
        coefficients: (0..4096)
            .map(|_| sampler_z(mu, sigma_star, sigma_star - 0.001, rng))
            .collect::<Vec<i16>>()
            .chunks(4096 / n)
            .map(|ch| ch.iter().sum())
            .collect(),
    }
}
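
// Summing 4096 / n base samples of standard deviation sigma_star yields coefficients
// with standard deviation sigma_star * sqrt(4096 / n); for n = 512 this is ~4.05,
// which matches 1.17 * sqrt(Q / (2n)) from the doc comment above.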

/// Computes the Gram-Schmidt norm of B = [[g, -f], [G, -F]] from f and g.
/// Corresponds to line 9 in algorithm 5 of the spec [1, p.34].
///
/// [1]: https://falcon-sign.info/falcon.pdf
fn gram_schmidt_norm_squared(f: &Polynomial<i16>, g: &Polynomial<i16>) -> f64 {
    let n = f.coefficients.len();
    let norm_f_squared = f.l2_norm_squared();
    let norm_g_squared = g.l2_norm_squared();
    let gamma1 = norm_f_squared + norm_g_squared;

    let f_fft = f.map(|i| Complex64::new(*i as f64, 0.0)).fft();
    let g_fft = g.map(|i| Complex64::new(*i as f64, 0.0)).fft();
    let f_adj_fft = f_fft.map(|c| c.conj());
    let g_adj_fft = g_fft.map(|c| c.conj());
    let ffgg_fft = f_fft.hadamard_mul(&f_adj_fft) + g_fft.hadamard_mul(&g_adj_fft);
    let ffgg_fft_inverse = ffgg_fft.hadamard_inv();
    let qf_over_ffgg_fft = f_adj_fft.map(|c| c * (MODULUS as f64)).hadamard_mul(&ffgg_fft_inverse);
    let qg_over_ffgg_fft = g_adj_fft.map(|c| c * (MODULUS as f64)).hadamard_mul(&ffgg_fft_inverse);
    let norm_f_over_ffgg_squared =
        qf_over_ffgg_fft.coefficients.iter().map(|c| (c * c.conj()).re).sum::<f64>() / (n as f64);
    let norm_g_over_ffgg_squared =
        qg_over_ffgg_fft.coefficients.iter().map(|c| (c * c.conj()).re).sum::<f64>() / (n as f64);

    let gamma2 = norm_f_over_ffgg_squared + norm_g_over_ffgg_squared;

    f64::max(gamma1, gamma2)
}

/// Reduces the vector (F,G) relative to (f,g). This method follows the python implementation [1].
/// Note that this algorithm can end up in an infinite loop. (It's one of the things the author
/// would like to fix.) When this happens, an error is returned (hence the return type), and the
/// caller generates another key pair with fresh randomness.
///
/// Algorithm 7 in the spec [2, p.35].
///
/// [1]: https://github.com/tprest/falcon.py
///
/// [2]: https://falcon-sign.info/falcon.pdf
fn babai_reduce(
    f: &Polynomial<BigInt>,
    g: &Polynomial<BigInt>,
    capital_f: &mut Polynomial<BigInt>,
    capital_g: &mut Polynomial<BigInt>,
) -> Result<(), String> {
    let bitsize = |bi: &BigInt| (bi.bits() + 7) & (u64::MAX ^ 7);
    let n = f.coefficients.len();
    let size = [
        f.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
        g.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
        53,
    ]
    .into_iter()
    .max()
    .unwrap();
    let shift = (size as i64) - 53;
    let f_adjusted = f
        .map(|bi| Complex64::new(i64::try_from(bi >> shift).unwrap() as f64, 0.0))
        .fft();
    let g_adjusted = g
        .map(|bi| Complex64::new(i64::try_from(bi >> shift).unwrap() as f64, 0.0))
        .fft();

    let f_star_adjusted = f_adjusted.map(|c| c.conj());
    let g_star_adjusted = g_adjusted.map(|c| c.conj());
    let denominator_fft =
        f_adjusted.hadamard_mul(&f_star_adjusted) + g_adjusted.hadamard_mul(&g_star_adjusted);

    let mut counter = 0;
    loop {
        let capital_size = [
            capital_f.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
            capital_g.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
            53,
        ]
        .into_iter()
        .max()
        .unwrap();

        if capital_size < size {
            break;
        }
        let capital_shift = (capital_size as i64) - 53;
        let capital_f_adjusted = capital_f
            .map(|bi| Complex64::new(i64::try_from(bi >> capital_shift).unwrap() as f64, 0.0))
            .fft();
        let capital_g_adjusted = capital_g
            .map(|bi| Complex64::new(i64::try_from(bi >> capital_shift).unwrap() as f64, 0.0))
            .fft();

        let numerator = capital_f_adjusted.hadamard_mul(&f_star_adjusted)
            + capital_g_adjusted.hadamard_mul(&g_star_adjusted);
        let quotient = numerator.hadamard_div(&denominator_fft).ifft();

        let k = quotient.map(|f| Into::<BigInt>::into(f.re.round() as i64));

        if k.is_zero() {
            break;
        }
        let kf = (k.clone().karatsuba(f))
            .reduce_by_cyclotomic(n)
            .map(|bi| bi << (capital_size - size));
        let kg = (k.clone().karatsuba(g))
            .reduce_by_cyclotomic(n)
            .map(|bi| bi << (capital_size - size));
        *capital_f -= kf;
        *capital_g -= kg;

        counter += 1;
        if counter > 1000 {
            // If we get here, that means that (with high likelihood) we are in an
            // infinite loop. We know it happens from time to time -- seldomly, but it
            // does. It would be nice to fix that! But in order to fix it we need to be
            // able to reproduce it, and for that we need test vectors. So print them
            // and hope that one day they circle back to the implementor.
            return Err(format!("Encountered infinite loop in babai_reduce of falcon-rust.\n\\
            Please help the developer(s) fix it! You can do this by sending them the inputs to the function that caused the behavior:\n\\
            f: {:?}\n\\
            g: {:?}\n\\
            capital_f: {:?}\n\\
            capital_g: {:?}\n", f.coefficients, g.coefficients, capital_f.coefficients, capital_g.coefficients));
        }
    }
    Ok(())
}
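
// Each iteration computes the Babai rounding coefficient (in double precision, on
// values truncated to their top 53 bits):
//     k = round((F * adj(f) + G * adj(g)) / (f * adj(f) + g * adj(g))),
// and subtracts (k * f, k * g) from (F, G), shrinking their coefficients until no
// further progress is possible.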

/// Extended Euclidean algorithm for computing the greatest common divisor (g) and
/// Bézout coefficients (u, v) for the relation
///
/// $$ u a + v b = g . $$
///
/// Implementation adapted from Wikipedia [1].
///
/// [1]: https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Pseudocode
fn xgcd(a: &BigInt, b: &BigInt) -> (BigInt, BigInt, BigInt) {
    let (mut old_r, mut r) = (a.clone(), b.clone());
    let (mut old_s, mut s) = (BigInt::one(), BigInt::zero());
    let (mut old_t, mut t) = (BigInt::zero(), BigInt::one());

    while r != BigInt::zero() {
        let quotient = old_r.clone() / r.clone();
        (old_r, r) = (r.clone(), old_r.clone() - quotient.clone() * r);
        (old_s, s) = (s.clone(), old_s.clone() - quotient.clone() * s);
        (old_t, t) = (t.clone(), old_t.clone() - quotient * t);
    }

    (old_r, old_s, old_t)
}
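
A quick worked example of the Bézout relation computed by `xgcd`: gcd(240, 46) = 2 with (u, v) = (-9, 47), since (-9) * 240 + 47 * 46 = 2. A minimal test sketch (illustrative, not part of the diff):

#[cfg(test)]
mod xgcd_tests {
    use num::BigInt;

    use super::xgcd;

    #[test]
    fn xgcd_bezout_example() {
        // gcd(240, 46) = 2, with Bézout coefficients (-9, 47)
        let (g, u, v) = xgcd(&BigInt::from(240), &BigInt::from(46));
        assert_eq!(g, BigInt::from(2));
        assert_eq!(u * BigInt::from(240) + v * BigInt::from(46), BigInt::from(2));
    }
}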
622
src/dsa/rpo_falcon512/math/polynomial.rs
Normal file
@@ -0,0 +1,622 @@
use alloc::vec::Vec;
use core::{
    default::Default,
    fmt::Debug,
    ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign},
};

use num::{One, Zero};

use super::{field::FalconFelt, Inverse};
use crate::{
    dsa::rpo_falcon512::{MODULUS, N},
    Felt,
};

#[derive(Debug, Clone, Default)]
pub struct Polynomial<F> {
    pub coefficients: Vec<F>,
}

impl<F> Polynomial<F>
where
    F: Clone,
{
    pub fn new(coefficients: Vec<F>) -> Self {
        Self { coefficients }
    }
}

impl<
        F: Mul<Output = F> + Sub<Output = F> + AddAssign + Zero + Div<Output = F> + Clone + Inverse,
    > Polynomial<F>
{
    pub fn hadamard_mul(&self, other: &Self) -> Self {
        Polynomial::new(
            self.coefficients
                .iter()
                .zip(other.coefficients.iter())
                .map(|(a, b)| *a * *b)
                .collect(),
        )
    }
    pub fn hadamard_div(&self, other: &Self) -> Self {
        let other_coefficients_inverse = F::batch_inverse_or_zero(&other.coefficients);
        Polynomial::new(
            self.coefficients
                .iter()
                .zip(other_coefficients_inverse.iter())
                .map(|(a, b)| *a * *b)
                .collect(),
        )
    }

    pub fn hadamard_inv(&self) -> Self {
        let coefficients_inverse = F::batch_inverse_or_zero(&self.coefficients);
        Polynomial::new(coefficients_inverse)
    }
}

impl<F: Zero + PartialEq + Clone> Polynomial<F> {
    pub fn degree(&self) -> Option<usize> {
        if self.coefficients.is_empty() {
            return None;
        }
        let mut max_index = self.coefficients.len() - 1;
        while self.coefficients[max_index] == F::zero() {
            if let Some(new_index) = max_index.checked_sub(1) {
                max_index = new_index;
            } else {
                return None;
            }
        }
        Some(max_index)
    }

    pub fn lc(&self) -> F {
        match self.degree() {
            Some(non_negative_degree) => self.coefficients[non_negative_degree].clone(),
            None => F::zero(),
        }
    }
}

/// The following implementations are specific to cyclotomic polynomial rings,
/// i.e., F\[ X \] / <X^n + 1>, and are used extensively in Falcon.
impl<
        F: One
            + Zero
            + Clone
            + Neg<Output = F>
            + MulAssign
            + AddAssign
            + Div<Output = F>
            + Sub<Output = F>
            + PartialEq,
    > Polynomial<F>
{
    /// Reduce the polynomial by X^n + 1.
    pub fn reduce_by_cyclotomic(&self, n: usize) -> Self {
        let mut coefficients = vec![F::zero(); n];
        let mut sign = -F::one();
        for (i, c) in self.coefficients.iter().cloned().enumerate() {
            if i % n == 0 {
                sign *= -F::one();
            }
            coefficients[i % n] += sign.clone() * c;
        }
        Polynomial::new(coefficients)
    }

    /// Computes the field norm of the polynomial as an element of the cyclotomic ring
    /// F\[ X \] / <X^n + 1> relative to one of half the size, i.e., F\[ X \] / <X^(n/2) + 1>.
    ///
    /// Corresponds to formula 3.25 in the spec [1, p.30].
    ///
    /// [1]: https://falcon-sign.info/falcon.pdf
    pub fn field_norm(&self) -> Self {
        let n = self.coefficients.len();
        let mut f0_coefficients = vec![F::zero(); n / 2];
        let mut f1_coefficients = vec![F::zero(); n / 2];
        for i in 0..n / 2 {
            f0_coefficients[i] = self.coefficients[2 * i].clone();
            f1_coefficients[i] = self.coefficients[2 * i + 1].clone();
        }
        let f0 = Polynomial::new(f0_coefficients);
        let f1 = Polynomial::new(f1_coefficients);
        let f0_squared = (f0.clone() * f0).reduce_by_cyclotomic(n / 2);
        let f1_squared = (f1.clone() * f1).reduce_by_cyclotomic(n / 2);
        let x = Polynomial::new(vec![F::zero(), F::one()]);
        f0_squared - (x * f1_squared).reduce_by_cyclotomic(n / 2)
    }
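
    // In other words, writing f(x) = f0(x^2) + x * f1(x^2), the field norm is
    //     N(f) = f0^2 - x * f1^2 (mod x^(n/2) + 1),
    // an element of the half-size ring; this is what drives the `ntru_solve` recursion.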
|
||||
/// Lifts an element from a cyclotomic polynomial ring to one of double the size.
|
||||
pub fn lift_next_cyclotomic(&self) -> Self {
|
||||
let n = self.coefficients.len();
|
||||
let mut coefficients = vec![F::zero(); n * 2];
|
||||
for i in 0..n {
|
||||
coefficients[2 * i] = self.coefficients[i].clone();
|
||||
}
|
||||
Self::new(coefficients)
|
||||
}
|
||||
|
||||
/// Computes the galois adjoint of the polynomial in the cyclotomic ring F\[ X \] / < X^n + 1 >
|
||||
/// , which corresponds to f(x^2).
|
||||
pub fn galois_adjoint(&self) -> Self {
|
||||
Self::new(
|
||||
self.coefficients
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, c)| if i % 2 == 0 { c.clone() } else { c.clone().neg() })
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Clone + Into<f64>> Polynomial<F> {
|
||||
pub(crate) fn l2_norm_squared(&self) -> f64 {
|
||||
self.coefficients
|
||||
.iter()
|
||||
.map(|i| Into::<f64>::into(i.clone()))
|
||||
.map(|i| i * i)
|
||||
.sum::<f64>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> PartialEq for Polynomial<F>
|
||||
where
|
||||
F: Zero + PartialEq + Clone + AddAssign,
|
||||
{
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.is_zero() && other.is_zero() {
|
||||
true
|
||||
} else if self.is_zero() || other.is_zero() {
|
||||
false
|
||||
} else {
|
||||
let self_degree = self.degree().unwrap();
|
||||
let other_degree = other.degree().unwrap();
|
||||
self.coefficients[0..=self_degree] == other.coefficients[0..=other_degree]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Eq for Polynomial<F> where F: Zero + PartialEq + Clone + AddAssign {}
|
||||
|
||||
impl<F> Add for &Polynomial<F>
|
||||
where
|
||||
F: Add<Output = F> + AddAssign + Clone,
|
||||
{
|
||||
type Output = Polynomial<F>;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
let coefficients = if self.coefficients.len() >= rhs.coefficients.len() {
|
||||
let mut coefficients = self.coefficients.clone();
|
||||
for (i, c) in rhs.coefficients.iter().enumerate() {
|
||||
coefficients[i] += c.clone();
|
||||
}
|
||||
coefficients
|
||||
} else {
|
||||
let mut coefficients = rhs.coefficients.clone();
|
||||
for (i, c) in self.coefficients.iter().enumerate() {
|
||||
coefficients[i] += c.clone();
|
||||
}
|
||||
coefficients
|
||||
};
|
||||
Self::Output { coefficients }
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Add for Polynomial<F>
|
||||
where
|
||||
F: Add<Output = F> + AddAssign + Clone,
|
||||
{
|
||||
type Output = Polynomial<F>;
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
let coefficients = if self.coefficients.len() >= rhs.coefficients.len() {
|
||||
let mut coefficients = self.coefficients.clone();
|
||||
for (i, c) in rhs.coefficients.into_iter().enumerate() {
|
||||
coefficients[i] += c;
|
||||
}
|
||||
coefficients
|
||||
} else {
|
||||
let mut coefficients = rhs.coefficients.clone();
|
||||
for (i, c) in self.coefficients.into_iter().enumerate() {
|
||||
coefficients[i] += c;
|
||||
}
|
||||
coefficients
|
||||
};
|
||||
Self::Output { coefficients }
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> AddAssign for Polynomial<F>
|
||||
where
|
||||
F: Add<Output = F> + AddAssign + Clone,
|
||||
{
|
||||
fn add_assign(&mut self, rhs: Self) {
|
||||
if self.coefficients.len() >= rhs.coefficients.len() {
|
||||
for (i, c) in rhs.coefficients.into_iter().enumerate() {
|
||||
self.coefficients[i] += c;
|
||||
}
|
||||
} else {
|
||||
let mut coefficients = rhs.coefficients.clone();
|
||||
for (i, c) in self.coefficients.iter().enumerate() {
|
||||
coefficients[i] += c.clone();
|
||||
}
|
||||
self.coefficients = coefficients;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Sub for &Polynomial<F>
|
||||
where
|
    F: Sub<Output = F> + Clone + Neg<Output = F> + Add<Output = F> + AddAssign,
{
    type Output = Polynomial<F>;

    fn sub(self, rhs: Self) -> Self::Output {
        self + &(-rhs)
    }
}

impl<F> Sub for Polynomial<F>
where
    F: Sub<Output = F> + Clone + Neg<Output = F> + Add<Output = F> + AddAssign,
{
    type Output = Polynomial<F>;

    fn sub(self, rhs: Self) -> Self::Output {
        self + (-rhs)
    }
}

impl<F> SubAssign for Polynomial<F>
where
    F: Add<Output = F> + Neg<Output = F> + AddAssign + Clone + Sub<Output = F>,
{
    fn sub_assign(&mut self, rhs: Self) {
        self.coefficients = self.clone().sub(rhs).coefficients;
    }
}

impl<F: Neg<Output = F> + Clone> Neg for &Polynomial<F> {
    type Output = Polynomial<F>;

    fn neg(self) -> Self::Output {
        Self::Output {
            coefficients: self.coefficients.iter().cloned().map(|a| -a).collect(),
        }
    }
}

impl<F: Neg<Output = F> + Clone> Neg for Polynomial<F> {
    type Output = Self;

    fn neg(self) -> Self::Output {
        Self::Output {
            coefficients: self.coefficients.iter().cloned().map(|a| -a).collect(),
        }
    }
}

impl<F> Mul for &Polynomial<F>
where
    F: Add + AddAssign + Mul<Output = F> + Sub<Output = F> + Zero + PartialEq + Clone,
{
    type Output = Polynomial<F>;

    fn mul(self, other: Self) -> Self::Output {
        if self.is_zero() || other.is_zero() {
            return Polynomial::<F>::zero();
        }
        let mut coefficients =
            vec![F::zero(); self.coefficients.len() + other.coefficients.len() - 1];
        for i in 0..self.coefficients.len() {
            for j in 0..other.coefficients.len() {
                coefficients[i + j] += self.coefficients[i].clone() * other.coefficients[j].clone();
            }
        }
        Polynomial { coefficients }
    }
}

impl<F> Mul for Polynomial<F>
where
    F: Add + AddAssign + Mul<Output = F> + Zero + PartialEq + Clone,
{
    type Output = Self;

    fn mul(self, other: Self) -> Self::Output {
        if self.is_zero() || other.is_zero() {
            return Self::zero();
        }
        let mut coefficients =
            vec![F::zero(); self.coefficients.len() + other.coefficients.len() - 1];
        for i in 0..self.coefficients.len() {
            for j in 0..other.coefficients.len() {
                coefficients[i + j] += self.coefficients[i].clone() * other.coefficients[j].clone();
            }
        }
        Self { coefficients }
    }
}

impl<F: Add + Mul<Output = F> + Zero + Clone> Mul<F> for &Polynomial<F> {
    type Output = Polynomial<F>;

    fn mul(self, other: F) -> Self::Output {
        Polynomial {
            coefficients: self.coefficients.iter().cloned().map(|i| i * other.clone()).collect(),
        }
    }
}

impl<F: Add + Mul<Output = F> + Zero + Clone> Mul<F> for Polynomial<F> {
    type Output = Polynomial<F>;

    fn mul(self, other: F) -> Self::Output {
        Polynomial {
            coefficients: self.coefficients.iter().cloned().map(|i| i * other.clone()).collect(),
        }
    }
}

impl<F: Mul<Output = F> + Sub<Output = F> + AddAssign + Zero + Div<Output = F> + Clone>
    Polynomial<F>
{
    /// Multiplies two polynomials using Karatsuba's divide-and-conquer algorithm.
    pub fn karatsuba(&self, other: &Self) -> Self {
        Polynomial::new(vector_karatsuba(&self.coefficients, &other.coefficients))
    }
}

impl<F> One for Polynomial<F>
where
    F: Clone + One + PartialEq + Zero + AddAssign,
{
    fn one() -> Self {
        Self { coefficients: vec![F::one()] }
    }
}

impl<F> Zero for Polynomial<F>
where
    F: Zero + PartialEq + Clone + AddAssign,
{
    fn zero() -> Self {
        Self { coefficients: vec![] }
    }

    fn is_zero(&self) -> bool {
        self.degree().is_none()
    }
}

impl<F: Zero + Clone> Polynomial<F> {
    pub fn shift(&self, shamt: usize) -> Self {
        Self {
            coefficients: [vec![F::zero(); shamt], self.coefficients.clone()].concat(),
        }
    }

    pub fn constant(f: F) -> Self {
        Self { coefficients: vec![f] }
    }

    pub fn map<G: Clone, C: FnMut(&F) -> G>(&self, closure: C) -> Polynomial<G> {
        Polynomial::<G>::new(self.coefficients.iter().map(closure).collect())
    }

    pub fn fold<G, C: FnMut(G, &F) -> G + Clone>(&self, mut initial_value: G, closure: C) -> G {
        for c in self.coefficients.iter() {
            initial_value = (closure.clone())(initial_value, c);
        }
        initial_value
    }
}

impl<F> Div<Polynomial<F>> for Polynomial<F>
where
    F: Zero
        + One
        + PartialEq
        + AddAssign
        + Clone
        + Mul<Output = F>
        + MulAssign
        + Div<Output = F>
        + Neg<Output = F>
        + Sub<Output = F>,
{
    type Output = Polynomial<F>;

    fn div(self, denominator: Self) -> Self::Output {
        if denominator.is_zero() {
            panic!("cannot divide by the zero polynomial");
        }
        if self.is_zero() {
            // without this early return, the degree().unwrap() below would panic
            return Self::zero();
        }
        let mut remainder = self.clone();
        let mut quotient = Polynomial::<F>::zero();
        while remainder.degree().unwrap() >= denominator.degree().unwrap() {
            let shift = remainder.degree().unwrap() - denominator.degree().unwrap();
            let quotient_coefficient = remainder.lc() / denominator.lc();
            let monomial = Self::constant(quotient_coefficient).shift(shift);
            quotient += monomial.clone();
            remainder -= monomial * denominator.clone();
            if remainder.is_zero() {
                break;
            }
        }
        quotient
    }
}
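
// Editorial sketch (not part of the original source): the `Div` impl above is textbook long
// division: repeatedly divide leading coefficients, shift, and subtract. The standalone helper
// below (hypothetical name, plain f64 slices, coefficients ordered low to high) mirrors that loop.
#[test]
fn long_division_example() {
    fn long_div(mut num: Vec<f64>, den: &[f64]) -> Vec<f64> {
        // the quotient has degree deg(num) - deg(den)
        let mut quotient = vec![0.0; num.len() - den.len() + 1];
        for shift in (0..quotient.len()).rev() {
            // ratio of the current leading coefficient to the divisor's leading coefficient
            let c = num[shift + den.len() - 1] / den[den.len() - 1];
            quotient[shift] = c;
            // subtract c * x^shift * den from the running remainder
            for (i, d) in den.iter().enumerate() {
                num[shift + i] -= c * d;
            }
        }
        quotient
    }
    // (x^2 + 3x + 2) / (x + 1) = x + 2
    assert_eq!(long_div(vec![2.0, 3.0, 1.0], &[1.0, 1.0]), vec![2.0, 1.0]);
}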

fn vector_karatsuba<
    F: Zero + AddAssign + Mul<Output = F> + Sub<Output = F> + Div<Output = F> + Clone,
>(
    left: &[F],
    right: &[F],
) -> Vec<F> {
    // NOTE: the recursion below assumes `left` and `right` have the same (power-of-two) length,
    // as is the case for the Falcon polynomials (N = 512) this helper is used with.
    let n = left.len();
    if n <= 8 {
        // fall back to schoolbook multiplication for small inputs
        let mut product = vec![F::zero(); left.len() + right.len() - 1];
        for (i, l) in left.iter().enumerate() {
            for (j, r) in right.iter().enumerate() {
                product[i + j] += l.clone() * r.clone();
            }
        }
        return product;
    }
    let n_over_2 = n / 2;
    let mut product = vec![F::zero(); 2 * n - 1];
    let left_lo = &left[0..n_over_2];
    let right_lo = &right[0..n_over_2];
    let left_hi = &left[n_over_2..];
    let right_hi = &right[n_over_2..];
    let left_sum: Vec<F> =
        left_lo.iter().zip(left_hi).map(|(a, b)| a.clone() + b.clone()).collect();
    let right_sum: Vec<F> =
        right_lo.iter().zip(right_hi).map(|(a, b)| a.clone() + b.clone()).collect();

    let prod_lo = vector_karatsuba(left_lo, right_lo);
    let prod_hi = vector_karatsuba(left_hi, right_hi);
    let prod_mid: Vec<F> = vector_karatsuba(&left_sum, &right_sum)
        .iter()
        .zip(prod_lo.iter().zip(prod_hi.iter()))
        .map(|(s, (l, h))| s.clone() - (l.clone() + h.clone()))
        .collect();

    for (i, l) in prod_lo.into_iter().enumerate() {
        product[i] = l;
    }
    for (i, m) in prod_mid.into_iter().enumerate() {
        product[i + n_over_2] += m;
    }
    for (i, h) in prod_hi.into_iter().enumerate() {
        product[i + n] += h;
    }
    product
}
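
// Editorial note (added for illustration): Karatsuba trades one multiplication for extra
// additions. For halves (lo, hi) it computes lo*lo', hi*hi', and (lo+hi)*(lo'+hi'), recovering
// the middle term as the third product minus the other two. A standalone check of that identity
// on plain integers (the names here are hypothetical, not part of the source):
#[test]
fn karatsuba_identity_example() {
    let (a0, a1, b0, b1) = (3i64, 5, 7, 11);
    let (lo, hi) = (a0 * b0, a1 * b1);
    let mid = (a0 + a1) * (b0 + b1) - lo - hi;
    // the coefficient of x in (a0 + a1*x)(b0 + b1*x) is a0*b1 + a1*b0
    assert_eq!(mid, a0 * b1 + a1 * b0);
}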

impl From<Polynomial<FalconFelt>> for Polynomial<Felt> {
    fn from(item: Polynomial<FalconFelt>) -> Self {
        let res: Vec<Felt> =
            item.coefficients.iter().map(|a| Felt::from(a.value() as u16)).collect();
        Polynomial::new(res)
    }
}

impl From<&Polynomial<FalconFelt>> for Polynomial<Felt> {
    fn from(item: &Polynomial<FalconFelt>) -> Self {
        let res: Vec<Felt> =
            item.coefficients.iter().map(|a| Felt::from(a.value() as u16)).collect();
        Polynomial::new(res)
    }
}

impl From<Polynomial<i16>> for Polynomial<FalconFelt> {
    fn from(item: Polynomial<i16>) -> Self {
        let res: Vec<FalconFelt> = item.coefficients.iter().map(|&a| FalconFelt::new(a)).collect();
        Polynomial::new(res)
    }
}

impl From<&Polynomial<i16>> for Polynomial<FalconFelt> {
    fn from(item: &Polynomial<i16>) -> Self {
        let res: Vec<FalconFelt> = item.coefficients.iter().map(|&a| FalconFelt::new(a)).collect();
        Polynomial::new(res)
    }
}

impl From<Vec<i16>> for Polynomial<FalconFelt> {
    fn from(item: Vec<i16>) -> Self {
        let res: Vec<FalconFelt> = item.iter().map(|&a| FalconFelt::new(a)).collect();
        Polynomial::new(res)
    }
}

impl From<&Vec<i16>> for Polynomial<FalconFelt> {
    fn from(item: &Vec<i16>) -> Self {
        let res: Vec<FalconFelt> = item.iter().map(|&a| FalconFelt::new(a)).collect();
        Polynomial::new(res)
    }
}

impl Polynomial<FalconFelt> {
    pub fn norm_squared(&self) -> u64 {
        self.coefficients
            .iter()
            .map(|&i| i.balanced_value() as i64)
            .map(|i| (i * i) as u64)
            .sum::<u64>()
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the coefficients of this polynomial as field elements.
    pub fn to_elements(&self) -> Vec<Felt> {
        self.coefficients.iter().map(|&a| Felt::from(a.value() as u16)).collect()
    }

    // POLYNOMIAL OPERATIONS
    // --------------------------------------------------------------------------------------------

    /// Multiplies two polynomials over Z_p\[x\] without reducing modulo p. Given that the degrees
    /// of the input polynomials are less than 512 and their coefficients are less than the modulus
    /// q equal to 12289, the resulting product polynomial is guaranteed to have coefficients less
    /// than the Miden prime.
    ///
    /// Note that this multiplication is not over Z_p\[x\]/(phi).
    pub fn mul_modulo_p(a: &Self, b: &Self) -> [u64; 1024] {
        let mut c = [0; 2 * N];
        for i in 0..N {
            for j in 0..N {
                c[i + j] += a.coefficients[i].value() as u64 * b.coefficients[j].value() as u64;
            }
        }

        c
    }

    /// Reduces a polynomial, that is the product of two polynomials over Z_p\[x\], modulo
    /// the irreducible polynomial phi. This results in an element in Z_p\[x\]/(phi).
    pub fn reduce_negacyclic(a: &[u64; 1024]) -> Self {
        let mut c = [FalconFelt::zero(); N];
        let modulus = MODULUS as u16;
        for i in 0..N {
            let ai = a[N + i] % modulus as u64;
            let neg_ai = (modulus - ai as u16) % modulus;

            let bi = (a[i] % modulus as u64) as u16;
            c[i] = FalconFelt::new(((neg_ai + bi) % modulus) as i16);
        }

        Self::new(c.to_vec())
    }
}
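
// Editorial sketch (not part of the original source): `reduce_negacyclic` folds the upper half
// of a length-2N product back onto the lower half with a sign flip, because x^N = -1 modulo
// phi = x^N + 1. A toy standalone check with N = 2 and q = 12289:
#[test]
fn negacyclic_fold_example() {
    const Q: u64 = 12289;
    // (1 + x)^2 = 1 + 2x + x^2; modulo x^2 + 1 this is (1 - 1) + 2x = 2x
    let prod = [1u64, 2, 1, 0]; // coefficients of 1 + 2x + x^2, padded to length 2N
    let folded: Vec<u64> = (0..2).map(|i| (prod[i] + Q - prod[2 + i] % Q) % Q).collect();
    assert_eq!(folded, vec![0, 2]);
}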

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::{FalconFelt, Polynomial, N};

    #[test]
    fn test_negacyclic_reduction() {
        let coef1: [u8; N] = rand_utils::rand_array();
        let coef2: [u8; N] = rand_utils::rand_array();

        let poly1 = Polynomial::new(coef1.iter().map(|&a| FalconFelt::new(a as i16)).collect());
        let poly2 = Polynomial::new(coef2.iter().map(|&a| FalconFelt::new(a as i16)).collect());
        let prod = poly1.clone() * poly2.clone();

        assert_eq!(
            prod.reduce_by_cyclotomic(N),
            Polynomial::reduce_negacyclic(&Polynomial::mul_modulo_p(&poly1, &poly2))
        );
    }
}
src/dsa/rpo_falcon512/math/samplerz.rs (new file, 299 lines)
@@ -0,0 +1,299 @@
use core::f64::consts::LN_2;

#[cfg(not(feature = "std"))]
use num::Float;
use rand::Rng;

/// Samples an integer from {0, ..., 18} according to the distribution χ, which is close to
/// the half-Gaussian distribution on the natural numbers with mean 0 and standard deviation
/// equal to sigma_max.
fn base_sampler(bytes: [u8; 9]) -> i16 {
    const RCDT: [u128; 18] = [
        3024686241123004913666,
        1564742784480091954050,
        636254429462080897535,
        199560484645026482916,
        47667343854657281903,
        8595902006365044063,
        1163297957344668388,
        117656387352093658,
        8867391802663976,
        496969357462633,
        20680885154299,
        638331848991,
        14602316184,
        247426747,
        3104126,
        28824,
        198,
        1,
    ];
    let u = u128::from_be_bytes([vec![0u8; 7], bytes.to_vec()].concat().try_into().unwrap());
    RCDT.into_iter().filter(|r| u < *r).count() as i16
}
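
// Editorial sketch (not part of the original source): RCDT is a reverse cumulative distribution
// table scaled to 72 bits. The sampler draws a uniform 72-bit integer u and returns how many
// table entries exceed u, so large u maps to 0 and only u below the last entry reaches 18:
#[test]
fn base_sampler_extremes() {
    // u = 2^72 - 1 (all sample bytes 0xff) is at least every table entry, giving 0
    assert_eq!(base_sampler([0xff; 9]), 0);
    // u = 0 is below all 18 entries, giving the maximal output
    assert_eq!(base_sampler([0x00; 9]), 18);
}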

/// Computes an integer approximation of 2^63 * ccs * exp(-x).
fn approx_exp(x: f64, ccs: f64) -> u64 {
    // The constants C are used to approximate exp(-x); these
    // constants are taken from FACCT (up to a scaling factor
    // of 2^63):
    // https://eprint.iacr.org/2018/1234
    // https://github.com/raykzhao/gaussian
    const C: [u64; 13] = [
        0x00000004741183a3u64,
        0x00000036548cfc06u64,
        0x0000024fdcbf140au64,
        0x0000171d939de045u64,
        0x0000d00cf58f6f84u64,
        0x000680681cf796e3u64,
        0x002d82d8305b0feau64,
        0x011111110e066fd0u64,
        0x0555555555070f00u64,
        0x155555555581ff00u64,
        0x400000000002b400u64,
        0x7fffffffffff4800u64,
        0x8000000000000000u64,
    ];

    let mut z: u64;
    let mut y: u64;
    let twoe63 = 1u64 << 63;

    y = C[0];
    z = f64::floor(x * (twoe63 as f64)) as u64;
    for cu in C.iter().skip(1) {
        let zy = (z as u128) * (y as u128);
        y = cu - ((zy >> 63) as u64);
    }

    z = f64::floor((twoe63 as f64) * ccs) as u64;

    (((z as u128) * (y as u128)) >> 63) as u64
}
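
// Editorial note (illustration only): the recurrence y <- C[k] - (z * y) >> 63 is Horner
// evaluation of a degree-12 polynomial approximating exp(-x) in Q1.63 fixed point; the final
// constant 0x8000000000000000 is exactly 1.0 in that format. A quick sanity check of the scale:
#[test]
fn q63_scale_example() {
    let one_q63 = 0x8000000000000000u64 as f64 / (1u64 << 63) as f64;
    assert_eq!(one_q63, 1.0);
}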

/// A random bool that is true with probability ≈ ccs · exp(-x).
fn ber_exp(x: f64, ccs: f64, random_bytes: [u8; 7]) -> bool {
    // decompose x = s * ln(2) + r with r in [0, ln(2)), where 0.69314718055994530941 = ln(2)
    let s = f64::floor(x / LN_2) as usize;
    let r = x - LN_2 * (s as f64);
    let shamt = usize::min(s, 63);
    // z ≈ 2^64 * ccs * exp(-x); the "(<< 1) - 1" keeps the value strictly below 2^64
    let z = ((((approx_exp(r, ccs) as u128) << 1) - 1) >> shamt) as u64;
    let mut w = 0i16;
    // lazily compare the random bytes against z, most-significant byte first, stopping at the
    // first differing byte
    for (index, i) in (0..64).step_by(8).rev().enumerate() {
        let byte = random_bytes[index];
        w = (byte as i16) - (((z >> i) & 0xff) as i16);
        if w != 0 {
            break;
        }
    }
    w < 0
}

/// Samples an integer from the Gaussian distribution with given mean (mu) and standard deviation
/// (sigma).
pub(crate) fn sampler_z<R: Rng>(mu: f64, sigma: f64, sigma_min: f64, rng: &mut R) -> i16 {
    const SIGMA_MAX: f64 = 1.8205;
    const INV_2SIGMA_MAX_SQ: f64 = 1f64 / (2f64 * SIGMA_MAX * SIGMA_MAX);
    let isigma = 1f64 / sigma;
    let dss = 0.5f64 * isigma * isigma;
    let s = f64::floor(mu);
    let r = mu - s;
    let ccs = sigma_min * isigma;
    loop {
        let z0 = base_sampler(rng.gen());
        let random_byte: u8 = rng.gen();
        let b = (random_byte & 1) as i16;
        let z = b + ((b << 1) - 1) * z0;
        let zf_min_r = (z as f64) - r;
        // x = ((z-r)^2)/(2*sigma^2) - ((z-b)^2)/(2*sigma0^2)
        let x = zf_min_r * zf_min_r * dss - (z0 * z0) as f64 * INV_2SIGMA_MAX_SQ;
        if ber_exp(x, ccs, rng.gen()) {
            return z + (s as i16);
        }
    }
}
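
// Editorial note (illustration only, not part of the original source): `z = b + (2b - 1) * z0`
// turns the half-Gaussian sample z0 >= 0 into a bimodal candidate: b = 0 gives -z0 and b = 1
// gives 1 + z0, so together the two branches cover all integers. The `ber_exp` rejection step
// then corrects this bimodal distribution to the target Gaussian centered at r.
#[test]
fn bimodal_candidate_example() {
    for z0 in 0i16..4 {
        let map = |b: i16| b + ((b << 1) - 1) * z0;
        assert_eq!(map(0), -z0);
        assert_eq!(map(1), 1 + z0);
    }
}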

#[cfg(all(test, feature = "std"))]
mod test {
    use alloc::vec::Vec;
    use std::{thread::sleep, time::Duration};

    use rand::RngCore;

    use super::{approx_exp, ber_exp, sampler_z};

    /// RNG used only for testing purposes, whereby the produced
    /// string of random bytes is equal to the one it is initialized
    /// with. Whatever you do, do not use this RNG in production.
    struct UnsafeBufferRng {
        buffer: Vec<u8>,
        index: usize,
    }

    impl UnsafeBufferRng {
        fn new(buffer: &[u8]) -> Self {
            Self { buffer: buffer.to_vec(), index: 0 }
        }

        fn next(&mut self) -> u8 {
            if self.buffer.len() <= self.index {
                // buffer exhausted: stall briefly and return zeros (the KATs below pad their
                // byte strings with zeros to account for these dropped bytes)
                sleep(Duration::from_millis(10));
                0u8
            } else {
                let return_value = self.buffer[self.index];
                self.index += 1;
                return_value
            }
        }
    }

    impl RngCore for UnsafeBufferRng {
        fn next_u32(&mut self) -> u32 {
            // only one buffered byte is consumed per u32; the rest are zero-padded
            u32::from_le_bytes([self.next(), 0, 0, 0])
        }

        fn next_u64(&mut self) -> u64 {
            // only one buffered byte is consumed per u64; the rest are zero-padded
            u64::from_le_bytes([self.next(), 0, 0, 0, 0, 0, 0, 0])
        }

        fn fill_bytes(&mut self, dest: &mut [u8]) {
            for d in dest.iter_mut() {
                *d = self.next();
            }
        }

        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
            for d in dest.iter_mut() {
                *d = self.next();
            }
            Ok(())
        }
    }

    #[test]
    fn test_unsafe_buffer_rng() {
        let seed_bytes = hex::decode("7FFECD162AE2").unwrap();
        let mut rng = UnsafeBufferRng::new(&seed_bytes);
        let generated_bytes: Vec<u8> = (0..seed_bytes.len()).map(|_| rng.next()).collect();
        assert_eq!(seed_bytes, generated_bytes);
    }

    #[test]
    fn test_approx_exp() {
        let precision = 1u64 << 14;
        // known answers were generated with the following sage script:
        // ```sage
        // num_samples = 10
        // precision = 200
        // R = Reals(precision)
        //
        // print(f"let kats : [(f64, f64, u64);{num_samples}] = [")
        // for i in range(num_samples):
        //     x = RDF.random_element(0.0, 0.693147180559945)
        //     ccs = RDF.random_element(0.0, 1.0)
        //     res = round(2^63 * R(ccs) * exp(R(-x)))
        //     print(f"({x}, {ccs}, {res}),")
        // print("];")
        // ```
        let kats: [(f64, f64, u64); 10] = [
            (0.2314993926072656, 0.8148006314615972, 5962140072160879737),
            (0.2648875572812225, 0.12769669655309035, 903712282351034505),
            (0.11251957513682391, 0.9264611470305881, 7635725498677341553),
            (0.04353439307256617, 0.5306497137523327, 4685877322232397936),
            (0.41834495299784347, 0.879438856118578, 5338392138535350986),
            (0.32579398973228557, 0.16513412873289002, 1099603299296456803),
            (0.5939508073919817, 0.029776019144967303, 151637565622779016),
            (0.2932367999399056, 0.37123847662857923, 2553827649386670452),
            (0.5005699297417507, 0.31447208863888976, 1758235618083658825),
            (0.4876437338498085, 0.6159515298936868, 3488632981903743976),
        ];
        for (x, ccs, answer) in kats {
            let difference = (answer as i128) - (approx_exp(x, ccs) as i128);
            assert!(
                (difference * difference) as u64 <= precision * precision,
                "answer: {answer} versus approximation: {}\ndifference: {} whereas precision: {}",
                approx_exp(x, ccs),
                difference,
                precision
            );
        }
    }

    #[test]
    fn test_ber_exp() {
        let kats = [
            (1.268_314_048_020_498_4, 0.749_990_853_267_664_9, hex::decode("ea000000000000").unwrap(), false),
            (0.001_563_917_959_143_409_6, 0.749_990_853_267_664_9, hex::decode("6c000000000000").unwrap(), true),
            (0.017_921_215_753_999_235, 0.749_990_853_267_664_9, hex::decode("c2000000000000").unwrap(), false),
            (0.776_117_648_844_980_6, 0.751_181_554_542_520_8, hex::decode("58000000000000").unwrap(), true),
        ];
        for (x, ccs, bytes, answer) in kats {
            assert_eq!(answer, ber_exp(x, ccs, bytes.try_into().unwrap()));
        }
    }

    #[test]
    fn test_sampler_z() {
        let sigma_min = 1.277833697;
        // known answers from the doc, table 3.2, page 44
        // https://falcon-sign.info/falcon.pdf
        // The zeros were added to account for dropped bytes.
        let kats = [
            (-91.90471153063714, 1.7037990414754918, hex::decode("0fc5442ff043d66e91d1ea000000000000cac64ea5450a22941edc6c").unwrap(), -92),
            (-8.322564895434937, 1.7037990414754918, hex::decode("f4da0f8d8444d1a77265c2000000000000ef6f98bbbb4bee7db8d9b3").unwrap(), -8),
            (-19.096516109216804, 1.7035823083824078, hex::decode("db47f6d7fb9b19f25c36d6000000000000b9334d477a8bc0be68145d").unwrap(), -20),
            (-11.335543982423326, 1.7035823083824078, hex::decode("ae41b4f5209665c74d00dc000000000000c1a8168a7bb516b3190cb42c1ded26cd52000000000000aed770eca7dd334e0547bcc3c163ce0b").unwrap(), -12),
            (7.9386734193997555, 1.6984647769450156, hex::decode("31054166c1012780c603ae0000000000009b833cec73f2f41ca5807c000000000000c89c92158834632f9b1555").unwrap(), 8),
            (-28.990850086867255, 1.6984647769450156, hex::decode("737e9d68a50a06dbbc6477").unwrap(), -30),
            (-9.071257914091655, 1.6980782114808988, hex::decode("a98ddd14bf0bf22061d632").unwrap(), -10),
            (-43.88754568839566, 1.6980782114808988, hex::decode("3cbf6818a68f7ab9991514").unwrap(), -41),
            (-58.17435547946095, 1.7010983419195522, hex::decode("6f8633f5bfa5d26848668e0000000000003d5ddd46958e97630410587c").unwrap(), -61),
            (-43.58664906684732, 1.7010983419195522, hex::decode("272bc6c25f5c5ee53f83c40000000000003a361fbc7cc91dc783e20a").unwrap(), -46),
            (-34.70565203313315, 1.7009387219711465, hex::decode("45443c59574c2c3b07e2e1000000000000d9071e6d133dbe32754b0a").unwrap(), -34),
            (-44.36009577368896, 1.7009387219711465, hex::decode("6ac116ed60c258e2cbaeab000000000000728c4823e6da36e18d08da0000000000005d0cc104e21cc7fd1f5ca8000000000000d9dbb675266c928448059e").unwrap(), -44),
            (-21.783037079346236, 1.6958406126012802, hex::decode("68163bc1e2cbf3e18e7426").unwrap(), -23),
            (-39.68827784633828, 1.6958406126012802, hex::decode("d6a1b51d76222a705a0259").unwrap(), -40),
            (-18.488607061056847, 1.6955259305261838, hex::decode("f0523bfaa8a394bf4ea5c10000000000000f842366fde286d6a30803").unwrap(), -22),
            (-48.39610939101591, 1.6955259305261838, hex::decode("87bd87e63374cee62127fc0000000000006931104aab64f136a0485b").unwrap(), -50),
        ];
        for (mu, sigma, random_bytes, answer) in kats {
            assert_eq!(
                sampler_z(mu, sigma, sigma_min, &mut UnsafeBufferRng::new(&random_bytes)),
                answer
            );
        }
    }
}
@@ -4,33 +4,33 @@ use crate::{
    Felt, Word, ZERO,
};

#[cfg(feature = "std")]
mod ffi;

mod error;
mod hash_to_point;
mod keys;
-mod polynomial;
+mod math;
mod signature;

-pub use error::FalconError;
-pub use keys::{KeyPair, PublicKey};
-pub use polynomial::Polynomial;
-pub use signature::Signature;
+pub use self::{
+    keys::{PubKeyPoly, PublicKey, SecretKey},
+    math::Polynomial,
+    signature::{Signature, SignatureHeader, SignaturePoly},
+};

// CONSTANTS
// ================================================================================================

-// The Falcon modulus.
-const MODULUS: u16 = 12289;
-const MODULUS_MINUS_1_OVER_TWO: u16 = 6144;
+// The Falcon modulus p.
+const MODULUS: i16 = 12289;
+
+// Number of bits needed to encode an element in the Falcon field.
+const FALCON_ENCODING_BITS: u32 = 14;

// The Falcon parameters for Falcon-512. This is the degree of the polynomial `phi := x^N + 1`
// defining the ring Z_p[x]/(phi).
const N: usize = 512;
-const LOG_N: usize = 9;
+const LOG_N: u8 = 9;

/// Length of nonce used for key-pair generation.
-const NONCE_LEN: usize = 40;
+const SIG_NONCE_LEN: usize = 40;

/// Number of field elements used to encode a nonce.
const NONCE_ELEMENTS: usize = 8;
@@ -42,16 +42,64 @@ pub const PK_LEN: usize = 897;
pub const SK_LEN: usize = 1281;

/// Signature length as a u8 vector.
-const SIG_LEN: usize = 626;
+const SIG_POLY_BYTE_LEN: usize = 625;

/// Bound on the squared-norm of the signature.
const SIG_L2_BOUND: u64 = 34034726;

+/// Standard deviation of the Gaussian over the lattice.
+const SIGMA: f64 = 165.7366171829776;

// TYPE ALIASES
// ================================================================================================

-type SignatureBytes = [u8; NONCE_LEN + SIG_LEN];
-type PublicKeyBytes = [u8; PK_LEN];
-type SecretKeyBytes = [u8; SK_LEN];
-type NonceBytes = [u8; NONCE_LEN];
-type NonceElements = [Felt; NONCE_ELEMENTS];
+type ShortLatticeBasis = [Polynomial<i16>; 4];

+// NONCE
+// ================================================================================================
+
+/// Nonce of the Falcon signature.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct Nonce([u8; SIG_NONCE_LEN]);
+
+impl Nonce {
+    /// Returns a new [Nonce] instantiated from the provided bytes.
+    pub fn new(bytes: [u8; SIG_NONCE_LEN]) -> Self {
+        Self(bytes)
+    }
+
+    /// Returns the underlying bytes of this nonce.
+    pub fn as_bytes(&self) -> &[u8; SIG_NONCE_LEN] {
+        &self.0
+    }
+
+    /// Converts byte representation of the nonce into field element representation.
+    ///
+    /// Nonce bytes are converted to field elements by taking consecutive 5 byte chunks
+    /// of the nonce and interpreting them as field elements.
+    pub fn to_elements(&self) -> [Felt; NONCE_ELEMENTS] {
+        let mut buffer = [0_u8; 8];
+        let mut result = [ZERO; 8];
+        for (i, bytes) in self.0.chunks(5).enumerate() {
+            buffer[..5].copy_from_slice(bytes);
+            // we can safely (without overflow) create a new Felt from u64 value here since this
+            // value contains at most 5 bytes
+            result[i] = Felt::new(u64::from_le_bytes(buffer));
+        }
+
+        result
+    }
+}
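
// Editorial sketch (aside, not part of the diff): each 5-byte chunk is zero-extended to 8 bytes
// and read as a little-endian u64, so the value is below 2^40 and fits in a field element
// without reduction. Standalone check of the chunk-to-u64 step:
#[test]
fn nonce_chunk_decoding_example() {
    let chunk = [0x01u8, 0x02, 0x03, 0x04, 0x05]; // one 5-byte chunk of the 40-byte nonce
    let mut buffer = [0u8; 8];
    buffer[..5].copy_from_slice(&chunk);
    assert_eq!(u64::from_le_bytes(buffer), 0x05_04_03_02_01);
    assert!(u64::from_le_bytes(buffer) < (1u64 << 40));
}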

+impl Serializable for &Nonce {
+    fn write_into<W: ByteWriter>(&self, target: &mut W) {
+        target.write_bytes(&self.0)
+    }
+}
+
+impl Deserializable for Nonce {
+    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
+        let bytes = source.read()?;
+        Ok(Self(bytes))
+    }
+}

@@ -1,279 +0,0 @@
use core::ops::{Add, Mul, Sub};

use super::{FalconError, Felt, LOG_N, MODULUS, MODULUS_MINUS_1_OVER_TWO, N, PK_LEN};
use crate::utils::collections::*;

// FALCON POLYNOMIAL
// ================================================================================================

/// A polynomial over Z_p\[x\]/(phi) where phi := x^512 + 1.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Polynomial([u16; N]);

impl Polynomial {
    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Constructs a new polynomial from a list of coefficients.
    ///
    /// # Safety
    /// This constructor validates that the coefficients are in the valid range only in debug mode.
    pub unsafe fn new(data: [u16; N]) -> Self {
        for value in data {
            debug_assert!(value < MODULUS);
        }

        Self(data)
    }

    /// Decodes raw bytes representing a public key into a polynomial in Z_p\[x\]/(phi).
    ///
    /// # Errors
    /// Returns an error if:
    /// - The provided input is not exactly 897 bytes long.
    /// - The first byte of the input is not equal to log2(512), i.e., 9.
    /// - Any of the coefficients encoded in the provided input is greater than or equal to the
    ///   Falcon field modulus.
    pub fn from_pub_key(input: &[u8]) -> Result<Self, FalconError> {
        if input.len() != PK_LEN {
            return Err(FalconError::PubKeyDecodingInvalidLength(input.len()));
        }

        if input[0] != LOG_N as u8 {
            return Err(FalconError::PubKeyDecodingInvalidTag(input[0]));
        }

        let mut acc = 0_u32;
        let mut acc_len = 0;

        let mut output = [0_u16; N];
        let mut output_idx = 0;

        for &byte in input.iter().skip(1) {
            acc = (acc << 8) | (byte as u32);
            acc_len += 8;

            if acc_len >= 14 {
                acc_len -= 14;
                let w = (acc >> acc_len) & 0x3FFF;
                if w >= MODULUS as u32 {
                    return Err(FalconError::PubKeyDecodingInvalidCoefficient(w));
                }
                output[output_idx] = w as u16;
                output_idx += 1;
            }
        }

        if (acc & ((1u32 << acc_len) - 1)) == 0 {
            Ok(Self(output))
        } else {
            Err(FalconError::PubKeyDecodingExtraData)
        }
    }

    /// Decodes the signature into the coefficients of a polynomial in Z_p\[x\]/(phi). It assumes
    /// that the signature has been encoded using the compressed format.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The signature has been encoded using a different algorithm than the reference compressed
    ///   encoding algorithm.
    /// - The encoded signature polynomial is in Z_p\[x\]/(phi') where phi' = x^N' + 1 and N' != 512.
    /// - While decoding the high bits of a coefficient, the current accumulated value of its
    ///   high bits is larger than 2048.
    /// - The decoded coefficient is -0.
    /// - The remaining unused bits in the last byte of `input` are non-zero.
    pub fn from_signature(input: &[u8]) -> Result<Self, FalconError> {
        let (encoding, log_n) = (input[0] >> 4, input[0] & 0b00001111);

        if encoding != 0b0011 {
            return Err(FalconError::SigDecodingIncorrectEncodingAlgorithm);
        }
        if log_n != 0b1001 {
            return Err(FalconError::SigDecodingNotSupportedDegree(log_n));
        }

        let input = &input[41..];
        let mut input_idx = 0;
        let mut acc = 0u32;
        let mut acc_len = 0;
        let mut output = [0_u16; N];

        for e in output.iter_mut() {
            acc = (acc << 8) | (input[input_idx] as u32);
            input_idx += 1;
            let b = acc >> acc_len;
            let s = b & 128;
            let mut m = b & 127;

            loop {
                if acc_len == 0 {
                    acc = (acc << 8) | (input[input_idx] as u32);
                    input_idx += 1;
                    acc_len = 8;
                }
                acc_len -= 1;
                if ((acc >> acc_len) & 1) != 0 {
                    break;
                }
                m += 128;
                if m >= 2048 {
                    return Err(FalconError::SigDecodingTooBigHighBits(m));
                }
            }
            if s != 0 && m == 0 {
                return Err(FalconError::SigDecodingMinusZero);
            }

            *e = if s != 0 { (MODULUS as u32 - m) as u16 } else { m as u16 };
        }

        if (acc & ((1 << acc_len) - 1)) != 0 {
            return Err(FalconError::SigDecodingNonZeroUnusedBitsLastByte);
        }

        Ok(Self(output))
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the coefficients of this polynomial as integers.
    pub fn inner(&self) -> [u16; N] {
        self.0
    }

    /// Returns the coefficients of this polynomial as field elements.
    pub fn to_elements(&self) -> Vec<Felt> {
        self.0.iter().map(|&a| Felt::from(a)).collect()
    }

    // POLYNOMIAL OPERATIONS
    // --------------------------------------------------------------------------------------------

    /// Multiplies two polynomials over Z_p\[x\] without reducing modulo p. Given that the degrees
    /// of the input polynomials are less than 512 and their coefficients are less than the modulus
    /// q equal to 12289, the resulting product polynomial is guaranteed to have coefficients less
    /// than the Miden prime.
    ///
    /// Note that this multiplication is not over Z_p\[x\]/(phi).
    pub fn mul_modulo_p(a: &Self, b: &Self) -> [u64; 1024] {
        let mut c = [0; 2 * N];
        for i in 0..N {
            for j in 0..N {
                c[i + j] += a.0[i] as u64 * b.0[j] as u64;
            }
        }

        c
    }

    /// Reduces a polynomial, that is the product of two polynomials over Z_p\[x\], modulo
    /// the irreducible polynomial phi. This results in an element in Z_p\[x\]/(phi).
    pub fn reduce_negacyclic(a: &[u64; 1024]) -> Self {
        let mut c = [0; N];
        for i in 0..N {
            let ai = a[N + i] % MODULUS as u64;
            let neg_ai = (MODULUS - ai as u16) % MODULUS;

            let bi = (a[i] % MODULUS as u64) as u16;
            c[i] = (neg_ai + bi) % MODULUS;
        }

        Self(c)
    }

    /// Computes the norm squared of a polynomial in Z_p\[x\]/(phi) after normalizing its
    /// coefficients to be in the interval (-p/2, p/2].
    pub fn sq_norm(&self) -> u64 {
        let mut res = 0;
        for e in self.0 {
            if e > MODULUS_MINUS_1_OVER_TWO {
                res += (MODULUS - e) as u64 * (MODULUS - e) as u64
            } else {
                res += e as u64 * e as u64
            }
        }
        res
    }
}

// Returns a polynomial representing the zero polynomial, i.e., the default element.
impl Default for Polynomial {
    fn default() -> Self {
        Self([0_u16; N])
    }
}

/// Multiplication over Z_p\[x\]/(phi)
impl Mul for Polynomial {
    type Output = Self;

    fn mul(self, other: Self) -> <Self as Mul<Self>>::Output {
        let mut result = [0_u16; N];
        for j in 0..N {
            for k in 0..N {
                let i = (j + k) % N;
                let a = self.0[j] as usize;
                let b = other.0[k] as usize;
                let q = MODULUS as usize;
                let mut prod = a * b % q;
                if (N - 1) < (j + k) {
                    prod = (q - prod) % q;
                }
                result[i] = ((result[i] as usize + prod) % q) as u16;
            }
        }

        Polynomial(result)
    }
}

/// Addition over Z_p\[x\]/(phi)
impl Add for Polynomial {
    type Output = Self;

    fn add(self, other: Self) -> <Self as Add<Self>>::Output {
        let mut res = self;
        res.0.iter_mut().zip(other.0.iter()).for_each(|(x, y)| *x = (*x + *y) % MODULUS);

        res
    }
}

/// Subtraction over Z_p\[x\]/(phi)
impl Sub for Polynomial {
    type Output = Self;

    fn sub(self, other: Self) -> <Self as Add<Self>>::Output {
        let mut res = self;
        res.0
            .iter_mut()
            .zip(other.0.iter())
            .for_each(|(x, y)| *x = (*x + MODULUS - *y) % MODULUS);

        res
    }
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::{Polynomial, N};

    #[test]
    fn test_negacyclic_reduction() {
        let coef1: [u16; N] = rand_utils::rand_array();
        let coef2: [u16; N] = rand_utils::rand_array();

        let poly1 = Polynomial(coef1);
        let poly2 = Polynomial(coef2);

        assert_eq!(
            poly1 * poly2,
            Polynomial::reduce_negacyclic(&Polynomial::mul_modulo_p(&poly1, &poly2))
        );
    }
}
@@ -1,286 +1,375 @@
-use core::cell::OnceCell;
+use alloc::{string::ToString, vec::Vec};
+use core::ops::Deref;
+
+use num::Zero;

use super::{
-    ByteReader, ByteWriter, Deserializable, DeserializationError, Felt, NonceBytes, NonceElements,
-    Polynomial, PublicKeyBytes, Rpo256, Serializable, SignatureBytes, Word, MODULUS, N,
-    SIG_L2_BOUND, ZERO,
+    hash_to_point::hash_to_point_rpo256,
+    keys::PubKeyPoly,
+    math::{FalconFelt, FastFft, Polynomial},
+    ByteReader, ByteWriter, Deserializable, DeserializationError, Felt, Nonce, Rpo256,
+    Serializable, Word, LOG_N, MODULUS, N, SIG_L2_BOUND, SIG_POLY_BYTE_LEN,
};
-use crate::utils::string::*;

// FALCON SIGNATURE
// ================================================================================================

/// An RPO Falcon512 signature over a message.
///
-/// The signature is a pair of polynomials (s1, s2) in (Z_p\[x\]/(phi))^2, where:
+/// The signature is a pair of polynomials (s1, s2) in (Z_p\[x\]/(phi))^2, a nonce `r`, and a public
+/// key polynomial `h`, where:
/// - p := 12289
/// - phi := x^512 + 1
/// - s1 = c - s2 * h
-/// - h is a polynomial representing the public key and c is a polynomial that is the hash-to-point
-///   of the message being signed.
///
-/// The signature verifies if and only if:
+/// The signature verifies against a public key `pk` if and only if:
/// 1. s1 = c - s2 * h
/// 2. |s1|^2 + |s2|^2 <= SIG_L2_BOUND
///
-/// where |.| is the norm.
+/// where |.| is the norm and:
+/// - c = HashToPoint(r || message)
+/// - pk = Rpo256::hash(h)
+///
+/// Here h is a polynomial representing the public key and pk is its digest using the Rpo256 hash
+/// function. c is a polynomial that is the hash-to-point of the message being signed.
///
-/// [Signature] also includes the extended public key which is serialized as:
+/// The polynomial h is serialized as:
/// 1. 1 byte representing the log2(512), i.e., 9.
-/// 2. 896 bytes for the public key. This is decoded into the `h` polynomial above.
+/// 2. 896 bytes for the public key itself.
///
-/// The actual signature is serialized as:
+/// The signature is serialized as:
/// 1. A header byte specifying the algorithm used to encode the coefficients of the `s2` polynomial
-///    together with the degree of the irreducible polynomial phi.
-///    The general format of this byte is 0b0cc1nnnn where:
-///    a. cc is either 01 when the compressed encoding algorithm is used and 10 when the
-///       uncompressed algorithm is used.
-///    b. nnnn is log2(N) where N is the degree of the irreducible polynomial phi.
-///    The current implementation always works with cc equal to 0b01 and nnnn equal to 0b1001, and
-///    thus the header byte is always equal to 0b00111001.
+///    together with the degree of the irreducible polynomial phi. For RPO Falcon512, the header
+///    byte is set to `10111001`, which differentiates it from the standardized instantiation of
+///    the Falcon signature.
/// 2. 40 bytes for the nonce.
/// 3. 625 bytes encoding the `s2` polynomial above.
///
/// The total size of the signature (including the extended public key) is 1563 bytes.
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Signature {
-    pub(super) pk: PublicKeyBytes,
-    pub(super) sig: SignatureBytes,
-
-    // Cached polynomial decoding for public key and signatures
-    pub(super) pk_polynomial: OnceCell<Polynomial>,
-    pub(super) sig_polynomial: OnceCell<Polynomial>,
+    header: SignatureHeader,
+    nonce: Nonce,
+    s2: SignaturePoly,
+    h: PubKeyPoly,
}

impl Signature {
+    // CONSTRUCTOR
+    // --------------------------------------------------------------------------------------------
+    pub fn new(nonce: Nonce, h: PubKeyPoly, s2: SignaturePoly) -> Signature {
+        Self {
+            header: SignatureHeader::default(),
+            nonce,
+            s2,
+            h,
+        }
+    }
+
    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the public key polynomial h.
-    pub fn pub_key_poly(&self) -> Polynomial {
-        *self.pk_polynomial.get_or_init(|| {
-            // we assume that the signature was constructed with a valid public key, and thus
-            // expect() is OK here.
-            Polynomial::from_pub_key(&self.pk).expect("invalid public key")
-        })
-    }
-
-    /// Returns the nonce component of the signature represented as field elements.
-    ///
-    /// Nonce bytes are converted to field elements by taking consecutive 5 byte chunks
-    /// of the nonce and interpreting them as field elements.
-    pub fn nonce(&self) -> NonceElements {
-        // we assume that the signature was constructed with a valid signature, and thus
-        // expect() is OK here.
-        let nonce = self.sig[1..41].try_into().expect("invalid signature");
-        decode_nonce(nonce)
+    pub fn pk_poly(&self) -> &PubKeyPoly {
+        &self.h
    }

    // Returns the polynomial representation of the signature in Z_p[x]/(phi).
-    pub fn sig_poly(&self) -> Polynomial {
-        *self.sig_polynomial.get_or_init(|| {
-            // we assume that the signature was constructed with a valid signature, and thus
-            // expect() is OK here.
-            Polynomial::from_signature(&self.sig).expect("invalid signature")
-        })
+    pub fn sig_poly(&self) -> &Polynomial<FalconFelt> {
+        &self.s2
    }

-    // HASH-TO-POINT
-    // --------------------------------------------------------------------------------------------
-
-    /// Returns a polynomial in Z_p\[x\]/(phi) representing the hash of the provided message.
-    pub fn hash_to_point(&self, message: Word) -> Polynomial {
-        hash_to_point(message, &self.nonce())
+    /// Returns the nonce component of the signature.
+    pub fn nonce(&self) -> &Nonce {
+        &self.nonce
    }

    // SIGNATURE VERIFICATION
    // --------------------------------------------------------------------------------------------

    /// Returns true if this signature is a valid signature for the specified message generated
-    /// against a key pair matching the specified public key commitment.
+    /// against the secret key matching the specified public key commitment.
    pub fn verify(&self, message: Word, pubkey_com: Word) -> bool {
-        // Make sure the expanded public key matches the provided public key commitment
-        let h = self.pub_key_poly();
-        let h_digest: Word = Rpo256::hash_elements(&h.to_elements()).into();
-        if h_digest != pubkey_com {
-            return false;
-        }
-
-        // Make sure the signature is valid
-        let s2 = self.sig_poly();
-        let c = self.hash_to_point(message);
-
-        let s1 = c - s2 * h;
-
-        let sq_norm = s1.sq_norm() + s2.sq_norm();
-        sq_norm <= SIG_L2_BOUND
+        // compute the hash of the public key polynomial
+        let h_felt: Polynomial<Felt> = (&**self.pk_poly()).into();
+        let h_digest: Word = Rpo256::hash_elements(&h_felt.coefficients).into();
+
+        let c = hash_to_point_rpo256(message, &self.nonce);
+        h_digest == pubkey_com && verify_helper(&c, &self.s2, self.pk_poly())
    }
}

// SERIALIZATION / DESERIALIZATION
// ================================================================================================

impl Serializable for Signature {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
-        target.write_bytes(&self.pk);
-        target.write_bytes(&self.sig);
+        target.write(&self.header);
+        target.write(&self.nonce);
+        target.write(&self.s2);
+        target.write(&self.h);
    }
}

impl Deserializable for Signature {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
-        let pk: PublicKeyBytes = source.read_array()?;
-        let sig: SignatureBytes = source.read_array()?;
-
-        // make sure public key and signature can be decoded correctly
-        let pk_polynomial = Polynomial::from_pub_key(&pk)
-            .map_err(|err| DeserializationError::InvalidValue(err.to_string()))?
-            .into();
-        let sig_polynomial = Polynomial::from_signature(&sig)
-            .map_err(|err| DeserializationError::InvalidValue(err.to_string()))?
-            .into();
-
-        Ok(Self { pk, sig, pk_polynomial, sig_polynomial })
+        let header = source.read()?;
+        let nonce = source.read()?;
+        let s2 = source.read()?;
+        let h = source.read()?;
+
+        Ok(Self { header, nonce, s2, h })
    }
}

+// SIGNATURE HEADER
+// ================================================================================================
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct SignatureHeader(u8);
+
+impl Default for SignatureHeader {
+    /// According to section 3.11.3 in the specification [1], the signature header has the format
+    /// `0cc1nnnn` where:
+    ///
+    /// 1. `cc` signifies the encoding method. `01` denotes using the compression encoding method
+    ///    and `10` denotes encoding using the uncompressed method.
+    /// 2. `nnnn` encodes `LOG_N`.
+    ///
+    /// For RPO Falcon 512 we use compression encoding and N = 512. Moreover, to differentiate the
+    /// RPO Falcon variant from the reference variant using SHAKE256, we flip the first bit in the
+    /// header. Thus, for RPO Falcon 512 the header is `10111001`.
+    ///
+    /// [1]: https://falcon-sign.info/falcon.pdf
+    fn default() -> Self {
+        Self(0b1011_1001)
+    }
+}
+
+impl Serializable for &SignatureHeader {
+    fn write_into<W: ByteWriter>(&self, target: &mut W) {
+        target.write_u8(self.0)
+    }
+}
+
+impl Deserializable for SignatureHeader {
+    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
+        let header = source.read_u8()?;
+        let (encoding, log_n) = (header >> 4, header & 0b00001111);
+        if encoding != 0b1011 {
+            return Err(DeserializationError::InvalidValue(
+                "Failed to decode signature: not supported encoding algorithm".to_string(),
+            ));
+        }
+
+        if log_n != LOG_N {
+            return Err(DeserializationError::InvalidValue(
+                format!("Failed to decode signature: only supported irreducible polynomial degree is 512, 2^{log_n} was provided")
+            ));
+        }
+
+        Ok(Self(header))
+    }
+}
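
// Editorial sketch (aside, not part of the diff): the default header 0b1011_1001 splits into a
// high nibble 0b1011 (compressed encoding with the RPO-variant bit flipped) and a low nibble
// 0b1001 = 9 = LOG_N. A standalone check of the nibble split used above:
#[test]
fn header_nibbles_example() {
    let header = 0b1011_1001u8;
    assert_eq!((header >> 4, header & 0b0000_1111), (0b1011, 9));
}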

+// SIGNATURE POLYNOMIAL
+// ================================================================================================
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct SignaturePoly(pub Polynomial<FalconFelt>);
+
+impl Deref for SignaturePoly {
+    type Target = Polynomial<FalconFelt>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
+
+impl From<Polynomial<FalconFelt>> for SignaturePoly {
+    fn from(pk_poly: Polynomial<FalconFelt>) -> Self {
+        Self(pk_poly)
+    }
+}
+
+impl TryFrom<&[i16; N]> for SignaturePoly {
+    type Error = ();
+
+    fn try_from(coefficients: &[i16; N]) -> Result<Self, Self::Error> {
+        if are_coefficients_valid(coefficients) {
+            Ok(Self(coefficients.to_vec().into()))
+        } else {
+            Err(())
+        }
+    }
+}
+
+impl Serializable for &SignaturePoly {
+    fn write_into<W: ByteWriter>(&self, target: &mut W) {
+        let sig_coeff: Vec<i16> = self.0.coefficients.iter().map(|a| a.balanced_value()).collect();
+        let mut sk_bytes = vec![0_u8; SIG_POLY_BYTE_LEN];
+
+        let mut acc = 0;
+        let mut acc_len = 0;
+        let mut v = 0;
+        let mut t;
+        let mut w;
+
+        // For each coefficient of x:
+        // - the sign is encoded on 1 bit
+        // - the 7 lower bits are encoded naively (binary)
+        // - the high bits are encoded in unary encoding
+        //
+        // Algorithm 17 p. 47 of the specification [1].
+        //
+        // [1]: https://falcon-sign.info/falcon.pdf
+        for &c in sig_coeff.iter() {
+            acc <<= 1;
+            t = c;
+
+            if t < 0 {
+                t = -t;
+                acc |= 1;
+            }
+            w = t as u16;
+
+            acc <<= 7;
+            let mask = 127_u32;
+            acc |= (w as u32) & mask;
+            w >>= 7;
+
+            acc_len += 8;
+
+            acc <<= w + 1;
+            acc |= 1;
+            acc_len += w + 1;
+
+            while acc_len >= 8 {
+                acc_len -= 8;
+
+                sk_bytes[v] = (acc >> acc_len) as u8;
+                v += 1;
+            }
+        }
+
+        if acc_len > 0 {
+            sk_bytes[v] = (acc << (8 - acc_len)) as u8;
+        }
+        target.write_bytes(&sk_bytes);
+    }
+}

+impl Deserializable for SignaturePoly {
+    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
+        let input = source.read_array::<SIG_POLY_BYTE_LEN>()?;
+
+        let mut input_idx = 0;
+        let mut acc = 0u32;
+        let mut acc_len = 0;
+        let mut coefficients = [FalconFelt::zero(); N];
+
+        // Algorithm 18 p. 48 of the specification [1].
+        //
+        // [1]: https://falcon-sign.info/falcon.pdf
+        for c in coefficients.iter_mut() {
+            acc = (acc << 8) | (input[input_idx] as u32);
+            input_idx += 1;
+            let b = acc >> acc_len;
+            let s = b & 128;
+            let mut m = b & 127;
+
+            loop {
+                if acc_len == 0 {
+                    acc = (acc << 8) | (input[input_idx] as u32);
+                    input_idx += 1;
+                    acc_len = 8;
+                }
+                acc_len -= 1;
+                if ((acc >> acc_len) & 1) != 0 {
+                    break;
+                }
+                m += 128;
+                if m >= 2048 {
+                    return Err(DeserializationError::InvalidValue(format!(
+                        "Failed to decode signature: high bits {m} exceed 2048"
+                    )));
+                }
+            }
+            if s != 0 && m == 0 {
+                return Err(DeserializationError::InvalidValue(
+                    "Failed to decode signature: -0 is forbidden".to_string(),
+                ));
+            }
+
+            let felt = if s != 0 { (MODULUS as u32 - m) as u16 } else { m as u16 };
+            *c = FalconFelt::new(felt as i16);
+        }
+
+        if (acc & ((1 << acc_len) - 1)) != 0 {
+            return Err(DeserializationError::InvalidValue(
+                "Failed to decode signature: Non-zero unused bits in the last byte".to_string(),
+            ));
+        }
+        Ok(Polynomial::new(coefficients.to_vec()).into())
+    }
+}

// HELPER FUNCTIONS
// ================================================================================================

-/// Returns a polynomial in Z_p\[x\]/(phi) representing the hash of the provided message and
-/// nonce.
-fn hash_to_point(message: Word, nonce: &NonceElements) -> Polynomial {
-    let mut state = [ZERO; Rpo256::STATE_WIDTH];
+/// Takes the hash-to-point polynomial `c` of a message, the signature polynomial over
+/// the message `s2`, and a public key polynomial, and returns `true` if the signature is a valid
+/// signature for the given parameters; otherwise it returns `false`.
+fn verify_helper(c: &Polynomial<FalconFelt>, s2: &SignaturePoly, h: &PubKeyPoly) -> bool {
+    let h_fft = h.fft();
+    let s2_fft = s2.fft();
+    let c_fft = c.fft();

-    // absorb the nonce into the state
-    for (&n, s) in nonce.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) {
-        *s = n;
-    }
-    Rpo256::apply_permutation(&mut state);
+    // compute the signature polynomial s1 using s1 = c - s2 * h
+    let s1_fft = c_fft - s2_fft.hadamard_mul(&h_fft);
+    let s1 = s1_fft.ifft();

-    // absorb message into the state
-    for (&m, s) in message.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) {
-        *s = m;
-    }
+    // compute the norm squared of (s1, s2)
+    let length_squared_s1 = s1.norm_squared();
+    let length_squared_s2 = s2.norm_squared();
+    let length_squared = length_squared_s1 + length_squared_s2;

-    // squeeze the coefficients of the polynomial
-    let mut i = 0;
-    let mut res = [0_u16; N];
-    for _ in 0..64 {
-        Rpo256::apply_permutation(&mut state);
-        for a in &state[Rpo256::RATE_RANGE] {
-            res[i] = (a.as_int() % MODULUS as u64) as u16;
-            i += 1;
-        }
-    }
-
-    // using the raw constructor is OK here because we reduce all coefficients by the modulus above
-    unsafe { Polynomial::new(res) }
+    length_squared < SIG_L2_BOUND
}
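
// Editorial note (aside, not part of the diff): the norm check above uses balanced
// representatives in (-q/2, q/2], so a coefficient stored as 12288 counts as -1, not 12288.
// Standalone illustration of the balancing step and the resulting small norm:
#[test]
fn balanced_norm_example() {
    let q = 12289i64;
    let balanced = |v: i64| if v > (q - 1) / 2 { v - q } else { v };
    assert_eq!(balanced(12288), -1);
    // |(-1)|^2 + |2|^2 = 5, far below SIG_L2_BOUND = 34034726
    assert_eq!(balanced(12288).pow(2) + balanced(2).pow(2), 5);
}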

-/// Converts byte representation of the nonce into field element representation.
-fn decode_nonce(nonce: &NonceBytes) -> NonceElements {
-    let mut buffer = [0_u8; 8];
-    let mut result = [ZERO; 8];
-    for (i, bytes) in nonce.chunks(5).enumerate() {
-        buffer[..5].copy_from_slice(bytes);
-        // we can safely (without overflow) create a new Felt from u64 value here since this value
-        // contains at most 5 bytes
-        result[i] = Felt::new(u64::from_le_bytes(buffer));
-    }
-
-    result
-}
+/// Checks whether a set of coefficients is a valid one for a signature polynomial.
+fn are_coefficients_valid(x: &[i16]) -> bool {
+    if x.len() != N {
+        return false;
+    }
+
+    for &c in x {
+        if !(-2047..=2047).contains(&c) {
+            return false;
+        }
+    }
+
+    true
+}

// TESTS
// ================================================================================================

-#[cfg(all(test, feature = "std"))]
+#[cfg(test)]
mod tests {
-    use libc::c_void;
-    use rand_utils::rand_vector;
+    use rand::SeedableRng;
+    use rand_chacha::ChaCha20Rng;

-    use super::{
-        super::{ffi::*, KeyPair},
-        *,
-    };
+    use super::{super::SecretKey, *};

-    // Wrappers for unsafe functions
-    impl Rpo128Context {
-        /// Initializes the RPO state.
-        pub fn init() -> Self {
-            let mut ctx = Rpo128Context { content: [0u64; 13] };
-            unsafe {
-                rpo128_init(&mut ctx as *mut Rpo128Context);
-            }
-            ctx
-        }
-
-        /// Absorbs data into the RPO state.
-        pub fn absorb(&mut self, data: &[u8]) {
-            unsafe {
-                rpo128_absorb(
-                    self as *mut Rpo128Context,
-                    data.as_ptr() as *const c_void,
-                    data.len(),
-                )
-            }
-        }
-
-        /// Finalizes the RPO state to prepare for squeezing.
-        pub fn finalize(&mut self) {
-            unsafe { rpo128_finalize(self as *mut Rpo128Context) }
-        }
-    }
-
-    #[test]
-    fn test_hash_to_point() {
-        // Create a random message and transform it into a u8 vector
-        let msg_felts: Word = rand_vector::<Felt>(4).try_into().unwrap();
-        let msg_bytes = msg_felts.iter().flat_map(|e| e.as_int().to_le_bytes()).collect::<Vec<_>>();
-
-        // Create a nonce i.e. a [u8; 40] array and pack into a [Felt; 8] array.
-        let nonce: [u8; 40] = rand_vector::<u8>(40).try_into().unwrap();
-
-        let mut buffer = [0_u8; 64];
-        for i in 0..8 {
-            buffer[8 * i] = nonce[5 * i];
-            buffer[8 * i + 1] = nonce[5 * i + 1];
-            buffer[8 * i + 2] = nonce[5 * i + 2];
-            buffer[8 * i + 3] = nonce[5 * i + 3];
-            buffer[8 * i + 4] = nonce[5 * i + 4];
-        }
-
-        // Initialize the RPO state
-        let mut rng = Rpo128Context::init();
-
-        // Absorb the nonce and message into the RPO state
-        rng.absorb(&buffer);
-        rng.absorb(&msg_bytes);
-        rng.finalize();
-
-        // Generate the coefficients of the hash-to-point polynomial.
-        let mut res: [u16; N] = [0; N];
-
-        unsafe {
-            PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(
-                &mut rng as *mut Rpo128Context,
-                res.as_mut_ptr(),
-                9,
-            );
-        }
-
-        // Check that the coefficients are correct
-        let nonce = decode_nonce(&nonce);
-        assert_eq!(res, hash_to_point(msg_felts, &nonce).inner());
-    }
-
    #[test]
    fn test_serialization_round_trip() {
-        let key = KeyPair::new().unwrap();
-        let signature = key.sign(Word::default()).unwrap();
+        let seed = [0_u8; 32];
+        let mut rng = ChaCha20Rng::from_seed(seed);
+
+        let sk = SecretKey::with_rng(&mut rng);
+        let signature = sk.sign_with_rng(Word::default(), &mut rng);
        let serialized = signature.to_bytes();
        let deserialized = Signature::read_from_bytes(&serialized).unwrap();
        assert_eq!(signature.sig_poly(), deserialized.sig_poly());
        assert_eq!(signature.pub_key_poly(), deserialized.pub_key_poly());
    }
}
src/dsa/rpo_stark/mod.rs (new file, 24 lines)
@@ -0,0 +1,24 @@
mod signature;
pub use signature::{PublicKey, SecretKey, Signature};

mod stark;
pub use stark::{PublicInputs, RescueAir};

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::SecretKey;
    use crate::Word;

    #[test]
    fn test_signature() {
        let sk = SecretKey::new(Word::default());

        let message = Word::default();
        let signature = sk.sign(message);
        let pk = sk.public_key();
        assert!(pk.verify(message, &signature))
    }
}
src/dsa/rpo_stark/signature/mod.rs (new file, 173 lines)
@@ -0,0 +1,173 @@
|
||||
use rand::{distributions::Uniform, prelude::Distribution, Rng};
|
||||
use winter_air::{FieldExtension, ProofOptions};
|
||||
use winter_math::{fields::f64::BaseElement, FieldElement};
|
||||
use winter_prover::Proof;
|
||||
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
|
||||
|
||||
use crate::{
|
||||
dsa::rpo_stark::stark::RpoSignatureScheme,
|
||||
hash::{rpo::Rpo256, DIGEST_SIZE},
|
||||
StarkField, Word, ZERO,
|
||||
};
|
||||
|
||||
// CONSTANTS
|
||||
// ================================================================================================
|
||||
|
||||
/// Specifies the parameters of the STARK underlying the signature scheme. These parameters provide
|
||||
/// at least 102 bits of security under the conjectured security of the toy protocol in
|
||||
/// the ethSTARK paper [1].
|
||||
///
|
||||
/// [1]: https://eprint.iacr.org/2021/582
pub const PROOF_OPTIONS: ProofOptions =
    ProofOptions::new(30, 8, 12, FieldExtension::Quadratic, 4, 7, true);

// PUBLIC KEY
// ================================================================================================

/// A public key for verifying signatures.
///
/// The public key is a [Word] (i.e., 4 field elements) that is the hash of the secret key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PublicKey(Word);

impl PublicKey {
    /// Returns the [Word] defining the public key.
    pub fn inner(&self) -> Word {
        self.0
    }
}

impl PublicKey {
    /// Verifies the provided signature against the provided message and this public key.
    pub fn verify(&self, message: Word, signature: &Signature) -> bool {
        signature.verify(message, *self)
    }
}

impl Serializable for PublicKey {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        self.0.write_into(target);
    }
}

impl Deserializable for PublicKey {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let pk = <Word>::read_from(source)?;
        Ok(Self(pk))
    }
}

// SECRET KEY
// ================================================================================================

/// A secret key for generating signatures.
///
/// The secret key is a [Word] (i.e., 4 field elements).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SecretKey(Word);

impl SecretKey {
    /// Generates a secret key from a [Word].
    pub fn new(word: Word) -> Self {
        Self(word)
    }

    /// Generates a secret key from OS-provided randomness.
    #[cfg(feature = "std")]
    pub fn random() -> Self {
        use rand::{rngs::StdRng, SeedableRng};

        let mut rng = StdRng::from_entropy();
        Self::with_rng(&mut rng)
    }

    /// Generates a secret key using the provided random number generator `Rng`.
    pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
        let mut sk = [ZERO; 4];
        let uni_dist = Uniform::from(0..BaseElement::MODULUS);
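        // `Uniform` over `0..MODULUS` avoids modulo bias (the `rand` crate performs rejection
        // sampling internally), so each coordinate is uniform over the field.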
        for s in sk.iter_mut() {
            let sampled_integer = uni_dist.sample(rng);
            *s = BaseElement::new(sampled_integer);
        }

        Self(sk)
    }

    /// Computes the public key corresponding to this secret key.
    pub fn public_key(&self) -> PublicKey {
        let mut elements = [BaseElement::ZERO; 8];
        elements[..DIGEST_SIZE].copy_from_slice(&self.0);
        let pk = Rpo256::hash_elements(&elements);
        PublicKey(pk.into())
    }

    /// Signs a message with this secret key.
    pub fn sign(&self, message: Word) -> Signature {
        let signature: RpoSignatureScheme<Rpo256> = RpoSignatureScheme::new(PROOF_OPTIONS);
        let proof = signature.sign(self.0, message);
        Signature { proof }
    }
}

impl Serializable for SecretKey {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        self.0.write_into(target);
    }
}

impl Deserializable for SecretKey {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let sk = <Word>::read_from(source)?;
        Ok(Self(sk))
    }
}

// SIGNATURE
// ================================================================================================

/// An RPO STARK-based signature over a message.
///
/// The signature is a STARK proof of knowledge of a pre-image given an image where the map is
/// the RPO permutation, the pre-image is the secret key and the image is the public key.
/// The current implementation follows the description in [1] but relies on the conjectured
/// security of the toy protocol in the ethSTARK paper [2]; using the parameter set given in
/// `PROOF_OPTIONS`, this yields a signature with $102$ bits of average-case existential
/// unforgeability security against $2^{113}$-query bounded adversaries that can obtain up to
/// $2^{64}$ signatures under the same public key.
///
/// [1]: https://eprint.iacr.org/2024/1553
/// [2]: https://eprint.iacr.org/2021/582
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Signature {
    proof: Proof,
}

impl Signature {
    /// Returns the STARK proof constituting the signature.
    pub fn inner(&self) -> Proof {
        self.proof.clone()
    }

    /// Returns true if this signature is a valid signature for the specified message generated
    /// against the secret key matching the specified public key.
    pub fn verify(&self, message: Word, pk: PublicKey) -> bool {
        let signature: RpoSignatureScheme<Rpo256> = RpoSignatureScheme::new(PROOF_OPTIONS);

        let res = signature.verify(pk.inner(), message, self.proof.clone());
        res.is_ok()
    }
}

impl Serializable for Signature {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        self.proof.write_into(target);
    }
}

impl Deserializable for Signature {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let proof = Proof::read_from(source)?;
        Ok(Self { proof })
    }
}
198
src/dsa/rpo_stark/stark/air.rs
Normal file
@@ -0,0 +1,198 @@
use alloc::vec::Vec;

use winter_math::{fields::f64::BaseElement, FieldElement, ToElements};
use winter_prover::{
    Air, AirContext, Assertion, EvaluationFrame, ProofOptions, TraceInfo,
    TransitionConstraintDegree,
};

use crate::{
    hash::{ARK1, ARK2, MDS, STATE_WIDTH},
    Word, ZERO,
};

// CONSTANTS
// ================================================================================================

pub const HASH_CYCLE_LEN: usize = 8;

// AIR
// ================================================================================================

pub struct RescueAir {
    context: AirContext<BaseElement>,
    pub_key: Word,
}

impl Air for RescueAir {
    type BaseField = BaseElement;
    type PublicInputs = PublicInputs;

    type GkrProof = ();
    type GkrVerifier = ();

    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------
    fn new(trace_info: TraceInfo, pub_inputs: PublicInputs, options: ProofOptions) -> Self {
        let degrees = vec![
            // Apply RPO rounds.
            TransitionConstraintDegree::new(7),
            TransitionConstraintDegree::new(7),
            TransitionConstraintDegree::new(7),
            TransitionConstraintDegree::new(7),
            TransitionConstraintDegree::new(7),
            TransitionConstraintDegree::new(7),
            TransitionConstraintDegree::new(7),
            TransitionConstraintDegree::new(7),
            TransitionConstraintDegree::new(7),
            TransitionConstraintDegree::new(7),
            TransitionConstraintDegree::new(7),
            TransitionConstraintDegree::new(7),
        ];
        assert_eq!(STATE_WIDTH, trace_info.width());
        let context = AirContext::new(trace_info, degrees, 12, options);
        let context = context.set_num_transition_exemptions(1);
        RescueAir { context, pub_key: pub_inputs.pub_key }
    }

    fn context(&self) -> &AirContext<Self::BaseField> {
        &self.context
    }

    fn evaluate_transition<E: FieldElement + From<Self::BaseField>>(
        &self,
        frame: &EvaluationFrame<E>,
        periodic_values: &[E],
        result: &mut [E],
    ) {
        let current = frame.current();
        let next = frame.next();
        // expected state width is 12 field elements
        debug_assert_eq!(STATE_WIDTH, current.len());
        debug_assert_eq!(STATE_WIDTH, next.len());

        enforce_rpo_round(frame, result, periodic_values);
    }

    fn get_assertions(&self) -> Vec<Assertion<Self::BaseField>> {
        let initial_step = 0;
        let last_step = self.trace_length() - 1;
        vec![
            // Assert that the capacity as well as the second half of the rate portion of the state
            // are initialized to `ZERO`. The first half of the rate is unconstrained as it will
            // contain the secret key.
            Assertion::single(0, initial_step, Self::BaseField::ZERO),
            Assertion::single(1, initial_step, Self::BaseField::ZERO),
            Assertion::single(2, initial_step, Self::BaseField::ZERO),
            Assertion::single(3, initial_step, Self::BaseField::ZERO),
            Assertion::single(8, initial_step, Self::BaseField::ZERO),
            Assertion::single(9, initial_step, Self::BaseField::ZERO),
            Assertion::single(10, initial_step, Self::BaseField::ZERO),
            Assertion::single(11, initial_step, Self::BaseField::ZERO),
            // Assert that the public key is the correct one
            Assertion::single(4, last_step, self.pub_key[0]),
            Assertion::single(5, last_step, self.pub_key[1]),
            Assertion::single(6, last_step, self.pub_key[2]),
            Assertion::single(7, last_step, self.pub_key[3]),
        ]
    }

    fn get_periodic_column_values(&self) -> Vec<Vec<Self::BaseField>> {
        get_round_constants()
    }
}

pub struct PublicInputs {
    pub(crate) pub_key: Word,
    pub(crate) msg: Word,
}

impl PublicInputs {
    pub fn new(pub_key: Word, msg: Word) -> Self {
        Self { pub_key, msg }
    }
}

impl ToElements<BaseElement> for PublicInputs {
    fn to_elements(&self) -> Vec<BaseElement> {
        let mut res = self.pub_key.to_vec();
        res.extend_from_slice(self.msg.as_ref());
        res
    }
}

// HELPER EVALUATORS
// ------------------------------------------------------------------------------------------------

/// Enforces constraints for a single round of the Rescue Prime Optimized hash function.
pub fn enforce_rpo_round<E: FieldElement + From<BaseElement>>(
    frame: &EvaluationFrame<E>,
    result: &mut [E],
    ark: &[E],
) {
    // compute the state that should result from applying the first 5 operations of the RPO round to
    // the current hash state.
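    // (these 5 operations, as applied below, are: MDS, add the first half of the round
    // constants, S-box, MDS again, and add the second half of the round constants)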
    let mut step1 = [E::ZERO; STATE_WIDTH];
    step1.copy_from_slice(frame.current());

    apply_mds(&mut step1);
    // add constants
    for i in 0..STATE_WIDTH {
        step1[i] += ark[i];
    }
    apply_sbox(&mut step1);
    apply_mds(&mut step1);
    // add constants
    for i in 0..STATE_WIDTH {
        step1[i] += ark[STATE_WIDTH + i];
    }

    // compute the state that should result from applying the inverse of the last operation of the
    // RPO round to the next step of the computation.
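    // (the RPO round ends with the inverse S-box, so applying the forward S-box to `next`
    // undoes it; equating both sides keeps the transition constraint at degree 7 instead of
    // the very high degree a direct evaluation of x^{1/7} would require)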
    let mut step2 = [E::ZERO; STATE_WIDTH];
    step2.copy_from_slice(frame.next());
    apply_sbox(&mut step2);

    // make sure that the results are equal.
    for i in 0..STATE_WIDTH {
        result[i] = step2[i] - step1[i]
    }
}

#[inline(always)]
fn apply_sbox<E: FieldElement + From<BaseElement>>(state: &mut [E; STATE_WIDTH]) {
    state.iter_mut().for_each(|v| {
        let t2 = v.square();
        let t4 = t2.square();
        *v *= t2 * t4;
    });
}

#[inline(always)]
fn apply_mds<E: FieldElement + From<BaseElement>>(state: &mut [E; STATE_WIDTH]) {
    let mut result = [E::ZERO; STATE_WIDTH];
    result.iter_mut().zip(MDS).for_each(|(r, mds_row)| {
        state.iter().zip(mds_row).for_each(|(&s, m)| {
            *r += E::from(m) * s;
        });
    });
    *state = result
}

/// Returns RPO round constants arranged in column-major form.
pub fn get_round_constants() -> Vec<Vec<BaseElement>> {
    let mut constants = Vec::new();
    for _ in 0..(STATE_WIDTH * 2) {
        constants.push(vec![ZERO; HASH_CYCLE_LEN]);
    }
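    // Only the first NUM_ROUNDS = 7 steps of each 8-step cycle carry round constants; the final
    // step stays ZERO (presumably so the periodic columns have a power-of-two cycle length).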

    #[allow(clippy::needless_range_loop)]
    for i in 0..HASH_CYCLE_LEN - 1 {
        for j in 0..STATE_WIDTH {
            constants[j][i] = ARK1[i][j];
            constants[j + STATE_WIDTH][i] = ARK2[i][j];
        }
    }

    constants
}
98
src/dsa/rpo_stark/stark/mod.rs
Normal file
@@ -0,0 +1,98 @@
use alloc::vec::Vec;
use core::marker::PhantomData;

use prover::RpoSignatureProver;
use rand_chacha::ChaCha20Rng;
use winter_crypto::{ElementHasher, SaltedMerkleTree};
use winter_math::fields::f64::BaseElement;
use winter_prover::{Proof, ProofOptions, Prover};
use winter_utils::Serializable;
use winter_verifier::{verify, AcceptableOptions, VerifierError};

use crate::{
    hash::{rpo::Rpo256, DIGEST_SIZE},
    rand::RpoRandomCoin,
};

mod air;
pub use air::{PublicInputs, RescueAir};
mod prover;

/// Represents an abstract STARK-based signature scheme with knowledge of RPO pre-image as
/// the hard relation.
pub struct RpoSignatureScheme<H: ElementHasher> {
    options: ProofOptions,
    _h: PhantomData<H>,
}

impl<H: ElementHasher<BaseField = BaseElement> + Sync> RpoSignatureScheme<H> {
    pub fn new(options: ProofOptions) -> Self {
        RpoSignatureScheme { options, _h: PhantomData }
    }

    pub fn sign(&self, sk: [BaseElement; DIGEST_SIZE], msg: [BaseElement; DIGEST_SIZE]) -> Proof {
        // create a prover
        let prover = RpoSignatureProver::<H>::new(msg, self.options.clone());

        // generate execution trace
        let trace = prover.build_trace(sk);

        // generate the initial seed for the PRNG used for zero-knowledge
        let seed: [u8; 32] = generate_seed(sk, msg);

        // generate the proof
        prover.prove(trace, Some(seed)).expect("failed to generate the signature")
    }

    pub fn verify(
        &self,
        pub_key: [BaseElement; DIGEST_SIZE],
        msg: [BaseElement; DIGEST_SIZE],
        proof: Proof,
    ) -> Result<(), VerifierError> {
        // we make sure that the parameters used in generating the proof match the expected ones
        if *proof.options() != self.options {
            return Err(VerifierError::UnacceptableProofOptions);
        }
        let pub_inputs = PublicInputs { pub_key, msg };
        let acceptable_options = AcceptableOptions::OptionSet(vec![proof.options().clone()]);
        verify::<RescueAir, Rpo256, RpoRandomCoin, SaltedMerkleTree<Rpo256, ChaCha20Rng>>(
            proof,
            pub_inputs,
            &acceptable_options,
        )
    }
}

/// Deterministically generates a seed for seeding the PRNG used for zero-knowledge.
///
/// This uses the argument described in [RFC 6979](https://datatracker.ietf.org/doc/html/rfc6979#section-3.5)
/// § 3.5 where the concatenation of the private key and the hashed message, i.e., sk || H(m), is
/// used in order to construct the initial seed of a PRNG.
///
/// Note that we also hash in a context string in order to domain-separate between different
/// instantiations of the signature scheme.
#[inline]
pub fn generate_seed(sk: [BaseElement; DIGEST_SIZE], msg: [BaseElement; DIGEST_SIZE]) -> [u8; 32] {
    let context_bytes = "
        Seed for PRNG used for Zero-knowledge in RPO-STARK signature scheme:
        1. Version: Conjectured security
        2. FRI queries: 30
        3. Blowup factor: 8
        4. Grinding bits: 12
        5. Field extension degree: 2
        6. FRI folding factor: 4
        7. FRI remainder polynomial max degree: 7
    "
    .to_bytes();
    let sk_bytes = sk.to_bytes();
    let msg_bytes = msg.to_bytes();

    let total_length = context_bytes.len() + sk_bytes.len() + msg_bytes.len();
    let mut buffer = Vec::with_capacity(total_length);
    buffer.extend_from_slice(&context_bytes);
    buffer.extend_from_slice(&sk_bytes);
    buffer.extend_from_slice(&msg_bytes);

    blake3::hash(&buffer).into()
}
148
src/dsa/rpo_stark/stark/prover.rs
Normal file
@@ -0,0 +1,148 @@
use core::marker::PhantomData;

use rand_chacha::ChaCha20Rng;
use winter_air::{
    AuxRandElements, ConstraintCompositionCoefficients, PartitionOptions, ZkParameters,
};
use winter_crypto::{ElementHasher, SaltedMerkleTree};
use winter_math::{fields::f64::BaseElement, FieldElement};
use winter_prover::{
    matrix::ColMatrix, CompositionPoly, CompositionPolyTrace, DefaultConstraintCommitment,
    DefaultConstraintEvaluator, DefaultTraceLde, ProofOptions, Prover, StarkDomain, Trace,
    TraceInfo, TracePolyTable, TraceTable,
};

use super::air::{PublicInputs, RescueAir, HASH_CYCLE_LEN};
use crate::{
    hash::{rpo::Rpo256, STATE_WIDTH},
    rand::RpoRandomCoin,
    Word, ZERO,
};

// PROVER
// ================================================================================================

/// A prover for the RPO STARK-based signature scheme.
///
/// The signature is based on the one-wayness of the RPO hash function but it is generic over
/// the hash function used for instantiating the random oracle for the BCS transform.
pub(crate) struct RpoSignatureProver<H: ElementHasher + Sync> {
    message: Word,
    options: ProofOptions,
    _hasher: PhantomData<H>,
}

impl<H: ElementHasher + Sync> RpoSignatureProver<H> {
    pub(crate) fn new(message: Word, options: ProofOptions) -> Self {
        Self { message, options, _hasher: PhantomData }
    }

    pub(crate) fn build_trace(&self, sk: Word) -> TraceTable<BaseElement> {
        let mut trace = TraceTable::new(STATE_WIDTH, HASH_CYCLE_LEN);

        trace.fill(
            |state| {
                // initialize first half of the rate portion of the state with the secret key
                state[0] = ZERO;
                state[1] = ZERO;
                state[2] = ZERO;
                state[3] = ZERO;
                state[4] = sk[0];
                state[5] = sk[1];
                state[6] = sk[2];
                state[7] = sk[3];
                state[8] = ZERO;
                state[9] = ZERO;
                state[10] = ZERO;
                state[11] = ZERO;
            },
            |step, state| {
                Rpo256::apply_round(
                    state.try_into().expect("should not fail given the size of the array"),
                    step,
                );
            },
        );
        trace
    }
}

impl<H: ElementHasher> Prover for RpoSignatureProver<H>
where
    H: ElementHasher<BaseField = BaseElement> + Sync,
{
    type BaseField = BaseElement;
    type Air = RescueAir;
    type Trace = TraceTable<BaseElement>;
    type HashFn = Rpo256;
    type VC = SaltedMerkleTree<Self::HashFn, Self::ZkPrng>;
    type RandomCoin = RpoRandomCoin;
    type TraceLde<E: FieldElement<BaseField = Self::BaseField>> =
        DefaultTraceLde<E, Self::HashFn, Self::VC>;
    type ConstraintCommitment<E: FieldElement<BaseField = Self::BaseField>> =
        DefaultConstraintCommitment<E, Self::HashFn, Self::ZkPrng, Self::VC>;
    type ConstraintEvaluator<'a, E: FieldElement<BaseField = Self::BaseField>> =
        DefaultConstraintEvaluator<'a, Self::Air, E>;
    type ZkPrng = ChaCha20Rng;

    fn get_pub_inputs(&self, trace: &Self::Trace) -> PublicInputs {
        let last_step = trace.length() - 1;
        // Note that the message is not part of the execution trace but is part of the public
        // inputs. This is explained in the reference description of the DSA and intuitively
        // it is done in order to make sure that the message is part of the Fiat-Shamir
        // transcript and hence binds the proof/signature to the message
        PublicInputs {
            pub_key: [
                trace.get(4, last_step),
                trace.get(5, last_step),
                trace.get(6, last_step),
                trace.get(7, last_step),
            ],
            msg: self.message,
        }
    }

    fn options(&self) -> &ProofOptions {
        &self.options
    }

    fn new_trace_lde<E: FieldElement<BaseField = Self::BaseField>>(
        &self,
        trace_info: &TraceInfo,
        main_trace: &ColMatrix<Self::BaseField>,
        domain: &StarkDomain<Self::BaseField>,
        partition_option: PartitionOptions,
        zk_parameters: Option<ZkParameters>,
        prng: &mut Option<Self::ZkPrng>,
    ) -> (Self::TraceLde<E>, TracePolyTable<E>) {
        DefaultTraceLde::new(trace_info, main_trace, domain, partition_option, zk_parameters, prng)
    }

    fn new_evaluator<'a, E: FieldElement<BaseField = Self::BaseField>>(
        &self,
        air: &'a Self::Air,
        aux_rand_elements: Option<AuxRandElements<E>>,
        composition_coefficients: ConstraintCompositionCoefficients<E>,
    ) -> Self::ConstraintEvaluator<'a, E> {
        DefaultConstraintEvaluator::new(air, aux_rand_elements, composition_coefficients)
    }

    fn build_constraint_commitment<E: FieldElement<BaseField = Self::BaseField>>(
        &self,
        composition_poly_trace: CompositionPolyTrace<E>,
        num_constraint_composition_columns: usize,
        domain: &StarkDomain<Self::BaseField>,
        partition_options: PartitionOptions,
        zk_parameters: Option<ZkParameters>,
        prng: &mut Option<Self::ZkPrng>,
    ) -> (Self::ConstraintCommitment<E>, CompositionPoly<E>) {
        DefaultConstraintCommitment::new(
            composition_poly_trace,
            num_constraint_composition_columns,
            domain,
            partition_options,
            zk_parameters,
            prng,
        )
    }
}
@@ -1,12 +1,13 @@
use alloc::{string::String, vec::Vec};
use core::{
    mem::{size_of, transmute, transmute_copy},
    ops::Deref,
    slice::from_raw_parts,
    slice::{self, from_raw_parts},
};

use super::{Digest, ElementHasher, Felt, FieldElement, Hasher};
use crate::utils::{
    bytes_to_hex_string, hex_to_bytes, string::*, ByteReader, ByteWriter, Deserializable,
    bytes_to_hex_string, hex_to_bytes, ByteReader, ByteWriter, Deserializable,
    DeserializationError, HexParseError, Serializable,
};

@@ -32,6 +33,14 @@ const DIGEST20_BYTES: usize = 20;
#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))]
pub struct Blake3Digest<const N: usize>([u8; N]);

impl<const N: usize> Blake3Digest<N> {
    pub fn digests_as_bytes(digests: &[Blake3Digest<N>]) -> &[u8] {
        let p = digests.as_ptr();
        let len = digests.len() * N;
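        // SAFETY (sketch): this relies on `Blake3Digest<N>` being a newtype over `[u8; N]`
        // with identical layout, so `len` bytes starting at `p` are initialized `u8`s.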
        unsafe { slice::from_raw_parts(p as *const u8, len) }
    }
}

impl<const N: usize> Default for Blake3Digest<N> {
    fn default() -> Self {
        Self([0; N])
@@ -113,6 +122,10 @@ impl Hasher for Blake3_256 {
        Self::hash(prepare_merge(values))
    }

    fn merge_many(values: &[Self::Digest]) -> Self::Digest {
        Blake3Digest(blake3::hash(Blake3Digest::digests_as_bytes(values)).into())
    }

    fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
        let mut hasher = blake3::Hasher::new();
        hasher.update(&seed.0);
@@ -173,6 +186,11 @@ impl Hasher for Blake3_192 {
        Blake3Digest(*shrink_bytes(&blake3::hash(bytes).into()))
    }

    fn merge_many(values: &[Self::Digest]) -> Self::Digest {
        let bytes: Vec<u8> = values.iter().flat_map(|v| v.as_bytes()).collect();
        Blake3Digest(*shrink_bytes(&blake3::hash(&bytes).into()))
    }

    fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
        Self::hash(prepare_merge(values))
    }
@@ -241,6 +259,11 @@ impl Hasher for Blake3_160 {
        Self::hash(prepare_merge(values))
    }

    fn merge_many(values: &[Self::Digest]) -> Self::Digest {
        let bytes: Vec<u8> = values.iter().flat_map(|v| v.as_bytes()).collect();
        Blake3Digest(*shrink_bytes(&blake3::hash(&bytes).into()))
    }

    fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
        let mut hasher = blake3::Hasher::new();
        hasher.update(&seed.0);

@@ -1,8 +1,9 @@
use alloc::vec::Vec;

use proptest::prelude::*;
use rand_utils::rand_vector;

use super::*;
use crate::utils::collections::*;

#[test]
fn blake3_hash_elements() {

@@ -1,16 +1,17 @@
//! Cryptographic hash functions used by the Miden VM and the Miden rollup.

use super::{CubeExtension, Felt, FieldElement, StarkField, ONE, ZERO};
use super::{CubeExtension, Felt, FieldElement, StarkField, ZERO};

pub mod blake;

mod rescue;
pub(crate) use rescue::{ARK1, ARK2, DIGEST_SIZE, MDS, STATE_WIDTH};
pub mod rpo {
    pub use super::rescue::{Rpo256, RpoDigest};
    pub use super::rescue::{Rpo256, RpoDigest, RpoDigestError};
}

pub mod rpx {
    pub use super::rescue::{Rpx256, RpxDigest};
    pub use super::rescue::{Rpx256, RpxDigest, RpxDigestError};
}

// RE-EXPORTS

@@ -4,40 +4,43 @@ use core::arch::x86_64::*;
// https://github.com/0xPolygonZero/plonky2/blob/main/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs

// Preliminary notes:
// 1. AVX does not support addition with carry but 128-bit (2-word) addition can be easily
// emulated. The method recognizes that for a + b overflowed iff (a + b) < a:
// i. res_lo = a_lo + b_lo
// ii. carry_mask = res_lo < a_lo
// iii. res_hi = a_hi + b_hi - carry_mask
// 1. AVX does not support addition with carry but 128-bit (2-word) addition can be easily emulated.
// The method relies on the fact that a + b overflowed iff (a + b) < a:
// 1. res_lo = a_lo + b_lo
// 2. carry_mask = res_lo < a_lo
// 3. res_hi = a_hi + b_hi - carry_mask
//
// Notice that carry_mask is subtracted, not added. This is because AVX comparison instructions
// return -1 (all bits 1) for true and 0 for false.
//
// 2. AVX does not have unsigned 64-bit comparisons. Those can be emulated with signed comparisons
// by recognizing that a <u b iff a + (1 << 63) <s b + (1 << 63), where the addition wraps around
// and the comparisons are unsigned and signed respectively. The shift function adds/subtracts
// 1 << 63 to enable this trick.
// Example: addition with carry.
// i. a_lo_s = shift(a_lo)
// ii. res_lo_s = a_lo_s + b_lo
// iii. carry_mask = res_lo_s <s a_lo_s
// iv. res_lo = shift(res_lo_s)
// v. res_hi = a_hi + b_hi - carry_mask
// The suffix _s denotes a value that has been shifted by 1 << 63. The result of addition is
// shifted if exactly one of the operands is shifted, as is the case on line ii. Line iii.
// performs a signed comparison res_lo_s <s a_lo_s on shifted values to emulate unsigned
// comparison res_lo <u a_lo on unshifted values. Finally, line iv. reverses the shift so the
// result can be returned.
// When performing a chain of calculations, we can often save instructions by letting the shift
// propagate through and only undoing it when necessary. For example, to compute the addition of
// three two-word (128-bit) numbers we can do:
// i. a_lo_s = shift(a_lo)
// ii. tmp_lo_s = a_lo_s + b_lo
// iii. tmp_carry_mask = tmp_lo_s <s a_lo_s
// iv. tmp_hi = a_hi + b_hi - tmp_carry_mask
// v. res_lo_s = tmp_lo_s + c_lo
// vi. res_carry_mask = res_lo_s <s tmp_lo_s
// vii. res_lo = shift(res_lo_s)
// viii. res_hi = tmp_hi + c_hi - res_carry_mask
// and the comparisons are unsigned and signed respectively. The shift function adds/subtracts 1
// << 63 to enable this trick. Addition with carry example:
// 1. a_lo_s = shift(a_lo)
// 2. res_lo_s = a_lo_s + b_lo
// 3. carry_mask = res_lo_s <s a_lo_s
// 4. res_lo = shift(res_lo_s)
// 5. res_hi = a_hi + b_hi - carry_mask
//
// The suffix _s denotes a value that has been shifted by 1 << 63. The result of addition
// is shifted if exactly one of the operands is shifted, as is the case on
// line 2. Line 3. performs a signed comparison res_lo_s <s a_lo_s on shifted values to
// emulate unsigned comparison res_lo <u a_lo on unshifted values. Finally, line 4. reverses the
// shift so the result can be returned.
//
// When performing a chain of calculations, we can often save instructions by letting
// the shift propagate through and only undoing it when necessary.
// For example, to compute the addition of three two-word (128-bit) numbers we can do:
// 1. a_lo_s = shift(a_lo)
// 2. tmp_lo_s = a_lo_s + b_lo
// 3. tmp_carry_mask = tmp_lo_s <s a_lo_s
// 4. tmp_hi = a_hi + b_hi - tmp_carry_mask
// 5. res_lo_s = tmp_lo_s + c_lo
// 6. res_carry_mask = res_lo_s <s tmp_lo_s
// 7. res_lo = shift(res_lo_s)
// 8. res_hi = tmp_hi + c_hi - res_carry_mask
//
// Notice that the above 3-value addition still only requires two calls to shift, just like our
// 2-value addition.
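//
// A scalar sketch of the carry-mask trick above (illustrative only; the real code operates on
// AVX lanes, where comparisons yield 0/-1 masks rather than the 0/1 produced here, hence the
// addition below instead of a subtraction):
//
//     fn add128(a_hi: u64, a_lo: u64, b_hi: u64, b_lo: u64) -> (u64, u64) {
//         let res_lo = a_lo.wrapping_add(b_lo);
//         let carry = (res_lo < a_lo) as u64;
//         (a_hi.wrapping_add(b_hi).wrapping_add(carry), res_lo)
//     }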

@@ -60,10 +63,10 @@ pub fn branch_hint() {
}

macro_rules! map3 {
    ($f:ident::<$l:literal>, $v:ident) => {
    ($f:ident:: < $l:literal > , $v:ident) => {
        ($f::<$l>($v.0), $f::<$l>($v.1), $f::<$l>($v.2))
    };
    ($f:ident::<$l:literal>, $v1:ident, $v2:ident) => {
    ($f:ident:: < $l:literal > , $v1:ident, $v2:ident) => {
        ($f::<$l>($v1.0, $v2.0), $f::<$l>($v1.1, $v2.1), $f::<$l>($v1.2, $v2.2))
    };
    ($f:ident, $v:ident) => {
@@ -72,11 +75,11 @@ macro_rules! map3 {
    ($f:ident, $v0:ident, $v1:ident) => {
        ($f($v0.0, $v1.0), $f($v0.1, $v1.1), $f($v0.2, $v1.2))
    };
    ($f:ident, rep $v0:ident, $v1:ident) => {
    ($f:ident,rep $v0:ident, $v1:ident) => {
        ($f($v0, $v1.0), $f($v0, $v1.1), $f($v0, $v1.2))
    };

    ($f:ident, $v0:ident, rep $v1:ident) => {
    ($f:ident, $v0:ident,rep $v1:ident) => {
        ($f($v0.0, $v1), $f($v0.1, $v1), $f($v0.2, $v1))
    };
}

@@ -1,26 +1,28 @@
// FFT-BASED MDS MULTIPLICATION HELPER FUNCTIONS
// ================================================================================================

/// This module contains helper functions as well as constants used to perform the vector-matrix
/// multiplication step of the Rescue prime permutation. The special form of our MDS matrix
/// i.e. being circular, allows us to reduce the vector-matrix multiplication to a Hadamard product
/// of two vectors in "frequency domain". This follows from the simple fact that every circulant
/// matrix has the columns of the discrete Fourier transform matrix as orthogonal eigenvectors.
/// The implementation also avoids the use of 3-point FFTs, and 3-point iFFTs, and substitutes that
/// with explicit expressions. It also avoids, due to the form of our matrix in the frequency domain,
/// divisions by 2 and repeated modular reductions. This is because of our explicit choice of
/// an MDS matrix that has small powers of 2 entries in frequency domain.
/// The following implementation has benefited greatly from the discussions and insights of
/// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero and is base on Nabaglo's Plonky2
/// implementation.
//! This module contains helper functions as well as constants used to perform the vector-matrix
//! multiplication step of the Rescue prime permutation. The special form of our MDS matrix,
//! i.e., being circulant, allows us to reduce the vector-matrix multiplication to a Hadamard product
//! of two vectors in "frequency domain". This follows from the simple fact that every circulant
//! matrix has the columns of the discrete Fourier transform matrix as orthogonal eigenvectors.
//! The implementation also avoids the use of 3-point FFTs, and 3-point iFFTs, and substitutes that
//! with explicit expressions. It also avoids, due to the form of our matrix in the frequency
//! domain, divisions by 2 and repeated modular reductions. This is because of our explicit choice
//! of an MDS matrix that has small powers of 2 entries in frequency domain.
//! The following implementation has benefited greatly from the discussions and insights of
//! Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero and is based on Nabaglo's Plonky2
//! implementation.
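//!
//! Sketch of the underlying identity (illustrative): for a circulant matrix `C` with first
//! column `c`, the product `C·v` equals the cyclic convolution of `c` and `v`, and the DFT
//! turns cyclic convolution into an element-wise (Hadamard) product:
//!
//!     C·v = iDFT( DFT(c) ∘ DFT(v) )
//!
//! The constants below are the (scaled) frequency-domain blocks of the first column of the
//! MDS matrix.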

// Rescue MDS matrix in frequency domain.
//
// More precisely, this is the output of the three 4-point (real) FFTs of the first column of
// the MDS matrix i.e. just before the multiplication with the appropriate twiddle factors
// and application of the final four 3-point FFT in order to get the full 12-point FFT.
// The entries have been scaled appropriately in order to avoid divisions by 2 in iFFT2 and iFFT4.
// The code to generate the matrix in frequency domain is based on an adaptation of a code, to generate
// MDS matrices efficiently in original domain, that was developed by the Polygon Zero team.
// The code to generate the matrix in frequency domain is based on an adaptation of code for
// generating MDS matrices efficiently in the original domain that was developed by the Polygon
// Zero team.
const MDS_FREQ_BLOCK_ONE: [i64; 3] = [16, 8, 16];
const MDS_FREQ_BLOCK_TWO: [(i64, i64); 3] = [(-1, 2), (-1, 1), (4, 8)];
const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-8, 1, 1];

@@ -1,20 +1,18 @@
use core::ops::Range;

use super::{
    CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ONE, ZERO,
};
use super::{CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ZERO};

mod arch;
pub use arch::optimized::{add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox};

mod mds;
use mds::{apply_mds, MDS};
pub(crate) use mds::{apply_mds, MDS};

mod rpo;
pub use rpo::{Rpo256, RpoDigest};
pub use rpo::{Rpo256, RpoDigest, RpoDigestError};

mod rpx;
pub use rpx::{Rpx256, RpxDigest};
pub use rpx::{Rpx256, RpxDigest, RpxDigestError};

#[cfg(test)]
mod tests;
@@ -28,7 +26,7 @@ const NUM_ROUNDS: usize = 7;

/// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
/// the remaining 4 elements are reserved for capacity.
const STATE_WIDTH: usize = 12;
pub(crate) const STATE_WIDTH: usize = 12;

/// The rate portion of the state is located in elements 4 through 11.
const RATE_RANGE: Range<usize> = 4..12;
@@ -44,8 +42,8 @@ const CAPACITY_RANGE: Range<usize> = 0..4;
///
/// The digest is returned from state elements 4, 5, 6, and 7 (the first four elements of the
/// rate portion).
const DIGEST_RANGE: Range<usize> = 4..8;
const DIGEST_SIZE: usize = DIGEST_RANGE.end - DIGEST_RANGE.start;
pub(crate) const DIGEST_RANGE: Range<usize> = 4..8;
pub(crate) const DIGEST_SIZE: usize = DIGEST_RANGE.end - DIGEST_RANGE.start;

/// The number of bytes needed to encode a digest
const DIGEST_BYTES: usize = 32;
@@ -146,7 +144,7 @@ fn add_constants(state: &mut [Felt; STATE_WIDTH], ark: &[Felt; STATE_WIDTH]) {
///
/// The constants are broken up into two arrays ARK1 and ARK2; ARK1 contains the constants for the
/// first half of RPO round, and ARK2 contains constants for the second half of RPO round.
const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
pub(crate) const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
    [
        Felt::new(5789762306288267392),
        Felt::new(6522564764413701783),
@@ -247,7 +245,7 @@ const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
    ],
];

const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
pub(crate) const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
    [
        Felt::new(6077062762357204287),
        Felt::new(15277620170502011191),

@@ -1,10 +1,17 @@
use core::{cmp::Ordering, fmt::Display, ops::Deref};
use alloc::string::String;
use core::{cmp::Ordering, fmt::Display, ops::Deref, slice};

use rand::{
    distributions::{Standard, Uniform},
    prelude::Distribution,
};
use thiserror::Error;

use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO};
use crate::{
    rand::Randomizable,
    utils::{
        bytes_to_hex_string, hex_to_bytes, string::*, ByteReader, ByteWriter, Deserializable,
        bytes_to_hex_string, hex_to_bytes, ByteReader, ByteWriter, Deserializable,
        DeserializationError, HexParseError, Serializable,
    },
};
@@ -18,6 +25,9 @@ use crate::{
pub struct RpoDigest([Felt; DIGEST_SIZE]);

impl RpoDigest {
    /// The serialized size of the digest in bytes.
    pub const SERIALIZED_SIZE: usize = DIGEST_BYTES;

    pub const fn new(value: [Felt; DIGEST_SIZE]) -> Self {
        Self(value)
    }
@@ -30,13 +40,19 @@ impl RpoDigest {
        <Self as Digest>::as_bytes(self)
    }

    pub fn digests_as_elements<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt>
    pub fn digests_as_elements_iter<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt>
    where
        I: Iterator<Item = &'a Self>,
    {
        digests.flat_map(|d| d.0.iter())
    }

    pub fn digests_as_elements(digests: &[Self]) -> &[Felt] {
        let p = digests.as_ptr();
        let len = digests.len() * DIGEST_SIZE;
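        // SAFETY (sketch): this relies on `RpoDigest` being a newtype over
        // `[Felt; DIGEST_SIZE]` with identical layout, so the digest slice can be viewed as a
        // flat `&[Felt]` of length `len`.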
        unsafe { slice::from_raw_parts(p as *const Felt, len) }
    }

    /// Returns hexadecimal representation of this digest prefixed with `0x`.
    pub fn to_hex(&self) -> String {
        bytes_to_hex_string(self.as_bytes())
@@ -114,29 +130,160 @@ impl Randomizable for RpoDigest {
    }
}

// CONVERSIONS: FROM RPO DIGEST
// ================================================================================================

impl From<&RpoDigest> for [Felt; DIGEST_SIZE] {
    fn from(value: &RpoDigest) -> Self {
        value.0
impl Distribution<RpoDigest> for Standard {
    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> RpoDigest {
        let mut res = [ZERO; DIGEST_SIZE];
        let uni_dist = Uniform::from(0..Felt::MODULUS);
        for r in res.iter_mut() {
            let sampled_integer = uni_dist.sample(rng);
            *r = Felt::new(sampled_integer);
        }
        RpoDigest::new(res)
    }
}

impl From<RpoDigest> for [Felt; DIGEST_SIZE] {
    fn from(value: RpoDigest) -> Self {
        value.0
// CONVERSIONS: FROM RPO DIGEST
// ================================================================================================

#[derive(Debug, Error)]
pub enum RpoDigestError {
    #[error("failed to convert digest field element to {0}")]
    TypeConversion(&'static str),
    #[error("failed to convert to field element: {0}")]
    InvalidFieldElement(String),
}

impl TryFrom<&RpoDigest> for [bool; DIGEST_SIZE] {
    type Error = RpoDigestError;

    fn try_from(value: &RpoDigest) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<RpoDigest> for [bool; DIGEST_SIZE] {
    type Error = RpoDigestError;

    fn try_from(value: RpoDigest) -> Result<Self, Self::Error> {
        fn to_bool(v: u64) -> Option<bool> {
            if v <= 1 {
                Some(v == 1)
            } else {
                None
            }
        }

        Ok([
            to_bool(value.0[0].as_int()).ok_or(RpoDigestError::TypeConversion("bool"))?,
            to_bool(value.0[1].as_int()).ok_or(RpoDigestError::TypeConversion("bool"))?,
            to_bool(value.0[2].as_int()).ok_or(RpoDigestError::TypeConversion("bool"))?,
            to_bool(value.0[3].as_int()).ok_or(RpoDigestError::TypeConversion("bool"))?,
        ])
    }
}

impl TryFrom<&RpoDigest> for [u8; DIGEST_SIZE] {
    type Error = RpoDigestError;

    fn try_from(value: &RpoDigest) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<RpoDigest> for [u8; DIGEST_SIZE] {
    type Error = RpoDigestError;

    fn try_from(value: RpoDigest) -> Result<Self, Self::Error> {
        Ok([
            value.0[0]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u8"))?,
            value.0[1]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u8"))?,
            value.0[2]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u8"))?,
            value.0[3]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u8"))?,
        ])
    }
}

impl TryFrom<&RpoDigest> for [u16; DIGEST_SIZE] {
    type Error = RpoDigestError;

    fn try_from(value: &RpoDigest) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<RpoDigest> for [u16; DIGEST_SIZE] {
    type Error = RpoDigestError;

    fn try_from(value: RpoDigest) -> Result<Self, Self::Error> {
        Ok([
            value.0[0]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u16"))?,
            value.0[1]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u16"))?,
            value.0[2]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u16"))?,
            value.0[3]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u16"))?,
        ])
    }
}

impl TryFrom<&RpoDigest> for [u32; DIGEST_SIZE] {
    type Error = RpoDigestError;

    fn try_from(value: &RpoDigest) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<RpoDigest> for [u32; DIGEST_SIZE] {
    type Error = RpoDigestError;

    fn try_from(value: RpoDigest) -> Result<Self, Self::Error> {
        Ok([
            value.0[0]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u32"))?,
            value.0[1]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u32"))?,
            value.0[2]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u32"))?,
            value.0[3]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u32"))?,
        ])
    }
}

impl From<&RpoDigest> for [u64; DIGEST_SIZE] {
    fn from(value: &RpoDigest) -> Self {
        [
            value.0[0].as_int(),
            value.0[1].as_int(),
            value.0[2].as_int(),
            value.0[3].as_int(),
        ]
        (*value).into()
    }
}

@@ -151,9 +298,21 @@ impl From<RpoDigest> for [u64; DIGEST_SIZE] {
    }
}

impl From<&RpoDigest> for [Felt; DIGEST_SIZE] {
    fn from(value: &RpoDigest) -> Self {
        (*value).into()
    }
}

impl From<RpoDigest> for [Felt; DIGEST_SIZE] {
    fn from(value: RpoDigest) -> Self {
        value.0
    }
}

impl From<&RpoDigest> for [u8; DIGEST_BYTES] {
    fn from(value: &RpoDigest) -> Self {
        value.as_bytes()
        (*value).into()
    }
}

@@ -163,13 +322,6 @@ impl From<RpoDigest> for [u8; DIGEST_BYTES] {
    }
}

impl From<RpoDigest> for String {
    /// The returned string starts with `0x`.
    fn from(value: RpoDigest) -> Self {
        value.to_hex()
    }
}

impl From<&RpoDigest> for String {
    /// The returned string starts with `0x`.
    fn from(value: &RpoDigest) -> Self {
@@ -177,13 +329,83 @@ impl From<&RpoDigest> for String {
    }
}

impl From<RpoDigest> for String {
    /// The returned string starts with `0x`.
    fn from(value: RpoDigest) -> Self {
        value.to_hex()
    }
}

// CONVERSIONS: TO RPO DIGEST
// ================================================================================================

#[derive(Copy, Clone, Debug)]
pub enum RpoDigestError {
    /// The provided u64 integer does not fit in the field's moduli.
    InvalidInteger,
impl From<&[bool; DIGEST_SIZE]> for RpoDigest {
    fn from(value: &[bool; DIGEST_SIZE]) -> Self {
        (*value).into()
    }
}

impl From<[bool; DIGEST_SIZE]> for RpoDigest {
    fn from(value: [bool; DIGEST_SIZE]) -> Self {
        [value[0] as u32, value[1] as u32, value[2] as u32, value[3] as u32].into()
    }
}

impl From<&[u8; DIGEST_SIZE]> for RpoDigest {
    fn from(value: &[u8; DIGEST_SIZE]) -> Self {
        (*value).into()
    }
}

impl From<[u8; DIGEST_SIZE]> for RpoDigest {
    fn from(value: [u8; DIGEST_SIZE]) -> Self {
        Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
    }
}

impl From<&[u16; DIGEST_SIZE]> for RpoDigest {
    fn from(value: &[u16; DIGEST_SIZE]) -> Self {
        (*value).into()
    }
}

impl From<[u16; DIGEST_SIZE]> for RpoDigest {
    fn from(value: [u16; DIGEST_SIZE]) -> Self {
        Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
    }
}

impl From<&[u32; DIGEST_SIZE]> for RpoDigest {
    fn from(value: &[u32; DIGEST_SIZE]) -> Self {
        (*value).into()
    }
}

impl From<[u32; DIGEST_SIZE]> for RpoDigest {
    fn from(value: [u32; DIGEST_SIZE]) -> Self {
        Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
    }
}

impl TryFrom<&[u64; DIGEST_SIZE]> for RpoDigest {
    type Error = RpoDigestError;

    fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
        (*value).try_into()
    }
}

impl TryFrom<[u64; DIGEST_SIZE]> for RpoDigest {
    type Error = RpoDigestError;

    fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
        Ok(Self([
            value[0].try_into().map_err(RpoDigestError::InvalidFieldElement)?,
            value[1].try_into().map_err(RpoDigestError::InvalidFieldElement)?,
            value[2].try_into().map_err(RpoDigestError::InvalidFieldElement)?,
            value[3].try_into().map_err(RpoDigestError::InvalidFieldElement)?,
        ]))
    }
}

impl From<&[Felt; DIGEST_SIZE]> for RpoDigest {
@@ -198,6 +420,14 @@ impl From<[Felt; DIGEST_SIZE]> for RpoDigest {
    }
}

impl TryFrom<&[u8; DIGEST_BYTES]> for RpoDigest {
    type Error = HexParseError;

    fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<[u8; DIGEST_BYTES]> for RpoDigest {
    type Error = HexParseError;

@@ -217,14 +447,6 @@ impl TryFrom<[u8; DIGEST_BYTES]> for RpoDigest {
    }
}

impl TryFrom<&[u8; DIGEST_BYTES]> for RpoDigest {
    type Error = HexParseError;

    fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<&[u8]> for RpoDigest {
    type Error = HexParseError;

@@ -233,33 +455,12 @@ impl TryFrom<&[u8]> for RpoDigest {
    }
}

impl TryFrom<[u64; DIGEST_SIZE]> for RpoDigest {
    type Error = RpoDigestError;

    fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
        Ok(Self([
            value[0].try_into().map_err(|_| RpoDigestError::InvalidInteger)?,
            value[1].try_into().map_err(|_| RpoDigestError::InvalidInteger)?,
            value[2].try_into().map_err(|_| RpoDigestError::InvalidInteger)?,
            value[3].try_into().map_err(|_| RpoDigestError::InvalidInteger)?,
        ]))
    }
}

impl TryFrom<&[u64; DIGEST_SIZE]> for RpoDigest {
    type Error = RpoDigestError;

    fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
        (*value).try_into()
    }
}

impl TryFrom<&str> for RpoDigest {
    type Error = HexParseError;

    /// Expects the string to start with `0x`.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        hex_to_bytes(value).and_then(|v| v.try_into())
        hex_to_bytes::<DIGEST_BYTES>(value).and_then(RpoDigest::try_from)
    }
}

@@ -288,6 +489,10 @@ impl Serializable for RpoDigest {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_bytes(&self.as_bytes());
    }

    fn get_size_hint(&self) -> usize {
        Self::SERIALIZED_SIZE
    }
}

impl Deserializable for RpoDigest {
@@ -323,10 +528,12 @@ impl IntoIterator for RpoDigest {

#[cfg(test)]
mod tests {
    use alloc::string::String;

    use rand_utils::rand_value;

    use super::{Deserializable, Felt, RpoDigest, Serializable, DIGEST_BYTES, DIGEST_SIZE};
    use crate::utils::{string::*, SliceReader};
    use crate::utils::SliceReader;

    #[test]
    fn digest_serialization() {
@@ -340,6 +547,7 @@ mod tests {
        let mut bytes = vec![];
        d1.write_into(&mut bytes);
        assert_eq!(DIGEST_BYTES, bytes.len());
        assert_eq!(bytes.len(), d1.get_size_hint());

        let mut reader = SliceReader::new(&bytes);
        let d2 = RpoDigest::read_from(&mut reader).unwrap();
@@ -371,44 +579,72 @@ mod tests {
            Felt::new(rand_value()),
        ]);

        let v: [Felt; DIGEST_SIZE] = digest.into();
        // BY VALUE
        // ----------------------------------------------------------------------------------------
        let v: [bool; DIGEST_SIZE] = [true, false, true, true];
        let v2: RpoDigest = v.into();
        assert_eq!(digest, v2);
        assert_eq!(v, <[bool; DIGEST_SIZE]>::try_from(v2).unwrap());

        let v: [Felt; DIGEST_SIZE] = (&digest).into();
        let v: [u8; DIGEST_SIZE] = [0_u8, 1_u8, 2_u8, 3_u8];
        let v2: RpoDigest = v.into();
        assert_eq!(digest, v2);
        assert_eq!(v, <[u8; DIGEST_SIZE]>::try_from(v2).unwrap());

        let v: [u16; DIGEST_SIZE] = [0_u16, 1_u16, 2_u16, 3_u16];
        let v2: RpoDigest = v.into();
        assert_eq!(v, <[u16; DIGEST_SIZE]>::try_from(v2).unwrap());

        let v: [u32; DIGEST_SIZE] = [0_u32, 1_u32, 2_u32, 3_u32];
        let v2: RpoDigest = v.into();
        assert_eq!(v, <[u32; DIGEST_SIZE]>::try_from(v2).unwrap());

        let v: [u64; DIGEST_SIZE] = digest.into();
        let v2: RpoDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u64; DIGEST_SIZE] = (&digest).into();
        let v2: RpoDigest = v.try_into().unwrap();
        let v: [Felt; DIGEST_SIZE] = digest.into();
        let v2: RpoDigest = v.into();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = digest.into();
        let v2: RpoDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = (&digest).into();
        let v2: RpoDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: String = digest.into();
        let v2: RpoDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: String = (&digest).into();
        let v2: RpoDigest = v.try_into().unwrap();
        // BY REF
        // ----------------------------------------------------------------------------------------
        let v: [bool; DIGEST_SIZE] = [true, false, true, true];
        let v2: RpoDigest = (&v).into();
        assert_eq!(v, <[bool; DIGEST_SIZE]>::try_from(&v2).unwrap());

        let v: [u8; DIGEST_SIZE] = [0_u8, 1_u8, 2_u8, 3_u8];
        let v2: RpoDigest = (&v).into();
        assert_eq!(v, <[u8; DIGEST_SIZE]>::try_from(&v2).unwrap());

        let v: [u16; DIGEST_SIZE] = [0_u16, 1_u16, 2_u16, 3_u16];
        let v2: RpoDigest = (&v).into();
        assert_eq!(v, <[u16; DIGEST_SIZE]>::try_from(&v2).unwrap());

        let v: [u32; DIGEST_SIZE] = [0_u32, 1_u32, 2_u32, 3_u32];
        let v2: RpoDigest = (&v).into();
        assert_eq!(v, <[u32; DIGEST_SIZE]>::try_from(&v2).unwrap());

        let v: [u64; DIGEST_SIZE] = (&digest).into();
        let v2: RpoDigest = (&v).try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = digest.into();
        let v2: RpoDigest = (&v).try_into().unwrap();
        let v: [Felt; DIGEST_SIZE] = (&digest).into();
        let v2: RpoDigest = (&v).into();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = (&digest).into();
        let v2: RpoDigest = (&v).try_into().unwrap();
        assert_eq!(digest, v2);

        let v: String = (&digest).into();
        let v2: RpoDigest = (&v).try_into().unwrap();
        assert_eq!(digest, v2);
    }
}

@@ -4,11 +4,11 @@ use super::{
    add_constants, add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox, apply_inv_sbox,
    apply_mds, apply_sbox, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ARK1,
    ARK2, BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE, DIGEST_SIZE, INPUT1_RANGE,
    INPUT2_RANGE, MDS, NUM_ROUNDS, ONE, RATE_RANGE, RATE_WIDTH, STATE_WIDTH, ZERO,
    INPUT2_RANGE, MDS, NUM_ROUNDS, RATE_RANGE, RATE_WIDTH, STATE_WIDTH, ZERO,
};

mod digest;
pub use digest::RpoDigest;
pub use digest::{RpoDigest, RpoDigestError};

#[cfg(test)]
mod tests;
@@ -19,12 +19,14 @@ mod tests;
/// Implementation of the Rescue Prime Optimized hash function with 256-bit output.
///
/// The hash function is implemented according to the Rescue Prime Optimized
/// [specifications](https://eprint.iacr.org/2022/1577)
/// [specifications](https://eprint.iacr.org/2022/1577) while the padding rule follows the one
/// described [here](https://eprint.iacr.org/2023/1045).
///
/// The parameters used to instantiate the function are:
/// * Field: 64-bit prime field with modulus 2^64 - 2^32 + 1.
/// * Field: 64-bit prime field with modulus p = 2^64 - 2^32 + 1.
/// * State width: 12 field elements.
/// * Capacity size: 4 field elements.
/// * Rate size: r = 8 field elements.
/// * Capacity size: c = 4 field elements.
/// * Number of rounds: 7.
/// * S-Box degree: 7.
///
@@ -50,8 +52,23 @@ mod tests;
///
/// Thus, if the underlying data consists of valid field elements, it might make more sense
/// to deserialize them into field elements and then hash them using
/// [hash_elements()](Rpo256::hash_elements) function rather then hashing the serialized bytes
/// [hash_elements()](Rpo256::hash_elements) function rather than hashing the serialized bytes
/// using [hash()](Rpo256::hash) function.
///
/// ## Domain separation
/// [merge_in_domain()](Rpo256::merge_in_domain) hashes two digests into one digest with some domain
/// identifier and the current implementation sets the second capacity element to the value of
/// this domain identifier. Using a similar argument to the one formulated for domain separation of
/// the RPX hash function in Appendix C of its [specification](https://eprint.iacr.org/2023/1045),
/// one sees that doing so degrades only pre-image resistance, from its initial bound of c.log_2(p),
/// by as much as the log_2 of the size of the domain identifier space. Since pre-image resistance
/// becomes the bottleneck for the security bound of the sponge in overwrite-mode only when it is
/// lower than 2^128, we see that the target 128-bit security level is maintained as long as
/// the size of the domain identifier space, including for padding, is less than 2^128.
|
||||
///
|
||||
/// ## Hashing of empty input
|
||||
/// The current implementation hashes empty input to the zero digest [0, 0, 0, 0]. This has
|
||||
/// the benefit of requiring no calls to the RPO permutation when hashing empty input.
|
||||
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
|
||||
pub struct Rpo256();
|
||||
|
||||
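Aside (illustrative sketch, not part of the diff): the two properties documented above can be exercised as below; the `miden_crypto` crate-root re-exports of `ONE`/`ZERO` and the `hash::rpo` path are assumptions here, not taken from this diff.

    use miden_crypto::{hash::rpo::{Rpo256, RpoDigest}, ONE, ZERO};

    fn rpo_doc_properties() {
        // hashing empty input performs no permutation and yields the zero digest
        assert_eq!(Rpo256::hash(&[]), RpoDigest::default());

        // distinct domain identifiers separate otherwise identical merges
        let digests = [RpoDigest::default(), RpoDigest::default()];
        assert_ne!(
            Rpo256::merge_in_domain(&digests, ZERO),
            Rpo256::merge_in_domain(&digests, ONE),
        );
    }
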
@@ -65,14 +82,16 @@ impl Hasher for Rpo256 {
// initialize the state with zeroes
let mut state = [ZERO; STATE_WIDTH];

// set the capacity (first element) to a flag on whether or not the input length is evenly
// divided by the rate. this will prevent collisions between padded and non-padded inputs,
// and will rule out the need to perform an extra permutation in case of evenly divided
// inputs.
let is_rate_multiple = bytes.len() % RATE_WIDTH == 0;
if !is_rate_multiple {
state[CAPACITY_RANGE.start] = ONE;
}
// determine the number of field elements needed to encode `bytes` when each field element
// represents at most 7 bytes.
let num_field_elem = bytes.len().div_ceil(BINARY_CHUNK_SIZE);

// set the first capacity element to `RATE_WIDTH + (num_field_elem % RATE_WIDTH)`. We do
// this to achieve:
// 1. Domain separating hashing of `[u8]` from hashing of `[Felt]`.
// 2. Avoiding collisions at the `[Felt]` representation of the encoded bytes.
state[CAPACITY_RANGE.start] =
Felt::from((RATE_WIDTH + (num_field_elem % RATE_WIDTH)) as u8);

// initialize a buffer to receive the little-endian elements.
let mut buf = [0_u8; 8];
@@ -81,41 +100,49 @@ impl Hasher for Rpo256 {
// into the state.
//
// every time the rate range is filled, a permutation is performed. if the final value of
// `i` is not zero, then the chunks count wasn't enough to fill the state range, and an
// additional permutation must be performed.
let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| {
// the last element of the iteration may or may not be a full chunk. if it's not, then
// we need to pad the remainder bytes of the chunk with zeroes, separated by a `1`.
// this will avoid collisions.
if chunk.len() == BINARY_CHUNK_SIZE {
// `rate_pos` is not zero, then the chunks count wasn't enough to fill the state range,
// and an additional permutation must be performed.
let mut current_chunk_idx = 0_usize;
// handle the case of an empty `bytes`
let last_chunk_idx = if num_field_elem == 0 {
current_chunk_idx
} else {
num_field_elem - 1
};
let rate_pos = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |rate_pos, chunk| {
// copy the chunk into the buffer
if current_chunk_idx != last_chunk_idx {
buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
} else {
// on the last iteration, we pad `buf` with a 1 followed by as many 0's as are
// needed to fill it
buf.fill(0);
buf[..chunk.len()].copy_from_slice(chunk);
buf[chunk.len()] = 1;
}
current_chunk_idx += 1;

// set the current rate element to the input. since we take at most 7 bytes, we are
// guaranteed that the inputs data will fit into a single field element.
state[RATE_RANGE.start + i] = Felt::new(u64::from_le_bytes(buf));
state[RATE_RANGE.start + rate_pos] = Felt::new(u64::from_le_bytes(buf));

// proceed filling the range. if it's full, then we apply a permutation and reset the
// counter to the beginning of the range.
if i == RATE_WIDTH - 1 {
if rate_pos == RATE_WIDTH - 1 {
Self::apply_permutation(&mut state);
0
} else {
i + 1
rate_pos + 1
}
});

// if we absorbed some elements but didn't apply a permutation to them (would happen when
// the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation. we
// don't need to apply any extra padding because the first capacity element contains a
// flag indicating whether the input is evenly divisible by the rate.
if i != 0 {
state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO);
state[RATE_RANGE.start + i] = ONE;
// flag indicating the number of field elements constituting the last block when the latter
// is not divisible by `RATE_WIDTH`.
if rate_pos != 0 {
state[RATE_RANGE.start + rate_pos..RATE_RANGE.end].fill(ZERO);
Self::apply_permutation(&mut state);
}

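Aside (illustrative sketch, not part of the diff): a standalone, hypothetical helper showing the 7-byte chunk encoding used above. Each chunk is packed into an 8-byte little-endian buffer, so every resulting u64 stays below the modulus p = 2^64 - 2^32 + 1, and the final chunk is always terminated with a single 1 byte, exactly as in the new code.

    fn encode_bytes_as_u64s(bytes: &[u8]) -> Vec<u64> {
        const BINARY_CHUNK_SIZE: usize = 7;
        let num_chunks = bytes.len().div_ceil(BINARY_CHUNK_SIZE);
        let mut out = Vec::with_capacity(num_chunks);
        let mut buf = [0_u8; 8];
        for (i, chunk) in bytes.chunks(BINARY_CHUNK_SIZE).enumerate() {
            buf.fill(0);
            buf[..chunk.len()].copy_from_slice(chunk);
            if i + 1 == num_chunks {
                // the last chunk is always padded with a single 1 byte
                buf[chunk.len()] = 1;
            }
            out.push(u64::from_le_bytes(buf));
        }
        out
    }
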
@@ -127,7 +154,7 @@ impl Hasher for Rpo256 {
// initialize the state by copying the digest elements into the rate portion of the state
// (8 total elements), and set the capacity elements to 0.
let mut state = [ZERO; STATE_WIDTH];
let it = Self::Digest::digests_as_elements(values.iter());
let it = Self::Digest::digests_as_elements_iter(values.iter());
for (i, v) in it.enumerate() {
state[RATE_RANGE.start + i] = *v;
}
@@ -137,29 +164,28 @@ impl Hasher for Rpo256 {
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}

fn merge_many(values: &[Self::Digest]) -> Self::Digest {
Self::hash_elements(Self::Digest::digests_as_elements(values))
}

fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
// initialize the state as follows:
// - seed is copied into the first 4 elements of the rate portion of the state.
// - if the value fits into a single field element, copy it into the fifth rate element
// and set the sixth rate element to 1.
// - if the value doesn't fit into a single field element, split it into two field
// elements, copy them into rate elements 5 and 6, and set the seventh rate element
// to 1.
// - set the first capacity element to 1
// - if the value fits into a single field element, copy it into the fifth rate element and
// set the first capacity element to 5.
// - if the value doesn't fit into a single field element, split it into two field elements,
// copy them into rate elements 5 and 6 and set the first capacity element to 6.
let mut state = [ZERO; STATE_WIDTH];
state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
state[INPUT2_RANGE.start] = Felt::new(value);
if value < Felt::MODULUS {
state[INPUT2_RANGE.start + 1] = ONE;
state[CAPACITY_RANGE.start] = Felt::from(5_u8);
} else {
state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
state[INPUT2_RANGE.start + 2] = ONE;
state[CAPACITY_RANGE.start] = Felt::from(6_u8);
}

// common padding for both cases
state[CAPACITY_RANGE.start] = ONE;

// apply the RPO permutation and return the first four elements of the state
// apply the RPO permutation and return the first four elements of the rate
Self::apply_permutation(&mut state);
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
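Aside (hedged sketch, not part of the diff): the new `merge_with_int` encoding above is equivalent to hashing the seed elements followed by the value split shown below, which is what the tests later in this diff exercise. The high part of the split is always 1 for a u64, since value / p is either 0 or 1 (2p > 2^64). The `miden_crypto` import paths are assumptions.

    use miden_crypto::{hash::rpo::{Rpo256, RpoDigest}, Felt, StarkField, ONE};

    fn check_merge_with_int(seed: RpoDigest, value: u64) {
        let merged = Rpo256::merge_with_int(seed, value);
        let mut elements = seed.as_elements().to_vec();
        elements.push(Felt::new(value));
        if value >= Felt::MODULUS {
            // high part of the split; always 1 because value / p is 0 or 1 for a u64
            elements.push(ONE);
        }
        assert_eq!(merged, Rpo256::hash_elements(&elements));
    }
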
@@ -173,11 +199,9 @@ impl ElementHasher for Rpo256 {
let elements = E::slice_as_base_elements(elements);

// initialize state to all zeros, except for the first element of the capacity part, which
// is set to 1 if the number of elements is not a multiple of RATE_WIDTH.
// is set to `elements.len() % RATE_WIDTH`.
let mut state = [ZERO; STATE_WIDTH];
if elements.len() % RATE_WIDTH != 0 {
state[CAPACITY_RANGE.start] = ONE;
}
state[CAPACITY_RANGE.start] = Self::BaseField::from((elements.len() % RATE_WIDTH) as u8);

// absorb elements into the state one by one until the rate portion of the state is filled
// up; then apply the Rescue permutation and start absorbing again; repeat until all
@@ -194,11 +218,8 @@ impl ElementHasher for Rpo256 {

// if we absorbed some elements but didn't apply a permutation to them (would happen when
// the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation after
// padding by appending a 1 followed by as many 0 as necessary to make the input length a
// multiple of the RATE_WIDTH.
// padding by as many 0 as necessary to make the input length a multiple of the RATE_WIDTH.
if i > 0 {
state[RATE_RANGE.start + i] = ONE;
i += 1;
while i != RATE_WIDTH {
state[RATE_RANGE.start + i] = ZERO;
i += 1;
@@ -273,7 +294,7 @@ impl Rpo256 {
// initialize the state by copying the digest elements into the rate portion of the state
// (8 total elements), and set the capacity elements to 0.
let mut state = [ZERO; STATE_WIDTH];
let it = RpoDigest::digests_as_elements(values.iter());
let it = RpoDigest::digests_as_elements_iter(values.iter());
for (i, v) in it.enumerate() {
state[RATE_RANGE.start + i] = *v;
}

@@ -1,11 +1,16 @@
use alloc::{collections::BTreeSet, vec::Vec};

use proptest::prelude::*;
use rand_utils::rand_value;

use super::{
super::{apply_inv_sbox, apply_sbox, ALPHA, INV_ALPHA},
Felt, FieldElement, Hasher, Rpo256, RpoDigest, StarkField, ONE, STATE_WIDTH, ZERO,
Felt, FieldElement, Hasher, Rpo256, RpoDigest, StarkField, STATE_WIDTH, ZERO,
};
use crate::{
hash::rescue::{BINARY_CHUNK_SIZE, CAPACITY_RANGE, RATE_WIDTH},
Word, ONE,
};
use crate::{utils::collections::*, Word};

#[test]
fn test_sbox() {
@@ -57,7 +62,7 @@ fn merge_vs_merge_in_domain() {
];
let merge_result = Rpo256::merge(&digests);

// ------------- merge with domain = 0 ----------------------------------------------------------
// ------------- merge with domain = 0 -------------

// set domain to ZERO. This should not change the result.
let domain = ZERO;
@@ -65,7 +70,7 @@ fn merge_vs_merge_in_domain() {
let merge_in_domain_result = Rpo256::merge_in_domain(&digests, domain);
assert_eq!(merge_result, merge_in_domain_result);

// ------------- merge with domain = 1 ----------------------------------------------------------
// ------------- merge with domain = 1 -------------

// set domain to ONE. This should change the result.
let domain = ONE;
@@ -124,6 +129,27 @@ fn hash_padding() {
assert_ne!(r1, r2);
}

#[test]
fn hash_padding_no_extra_permutation_call() {
use crate::hash::rescue::DIGEST_RANGE;

// Implementation
let num_bytes = BINARY_CHUNK_SIZE * RATE_WIDTH;
let mut buffer = vec![0_u8; num_bytes];
*buffer.last_mut().unwrap() = 97;
let r1 = Rpo256::hash(&buffer);

// Expected
let final_chunk = [0_u8, 0, 0, 0, 0, 0, 97, 1];
let mut state = [ZERO; STATE_WIDTH];
// padding when hashing bytes
state[CAPACITY_RANGE.start] = Felt::from(RATE_WIDTH as u8);
*state.last_mut().unwrap() = Felt::new(u64::from_le_bytes(final_chunk));
Rpo256::apply_permutation(&mut state);

assert_eq!(&r1[0..4], &state[DIGEST_RANGE]);
}

#[test]
fn hash_elements_padding() {
let e1 = [Felt::new(rand_value()); 2];
@@ -157,6 +183,24 @@ fn hash_elements() {
assert_eq!(m_result, h_result);
}

#[test]
fn hash_empty() {
let elements: Vec<Felt> = vec![];

let zero_digest = RpoDigest::default();
let h_result = Rpo256::hash_elements(&elements);
assert_eq!(zero_digest, h_result);
}

#[test]
fn hash_empty_bytes() {
let bytes: Vec<u8> = vec![];

let zero_digest = RpoDigest::default();
let h_result = Rpo256::hash(&bytes);
assert_eq!(zero_digest, h_result);
}

#[test]
fn hash_test_vectors() {
let elements = [
@@ -227,46 +271,46 @@ proptest! {

const EXPECTED: [Word; 19] = [
[
Felt::new(1502364727743950833),
Felt::new(5880949717274681448),
Felt::new(162790463902224431),
Felt::new(6901340476773664264),
Felt::new(18126731724905382595),
Felt::new(7388557040857728717),
Felt::new(14290750514634285295),
Felt::new(7852282086160480146),
],
[
Felt::new(7478710183745780580),
Felt::new(3308077307559720969),
Felt::new(3383561985796182409),
Felt::new(17205078494700259815),
Felt::new(10139303045932500183),
Felt::new(2293916558361785533),
Felt::new(15496361415980502047),
Felt::new(17904948502382283940),
],
[
Felt::new(17439912364295172999),
Felt::new(17979156346142712171),
Felt::new(8280795511427637894),
Felt::new(9349844417834368814),
Felt::new(17457546260239634015),
Felt::new(803990662839494686),
Felt::new(10386005777401424878),
Felt::new(18168807883298448638),
],
[
Felt::new(5105868198472766874),
Felt::new(13090564195691924742),
Felt::new(1058904296915798891),
Felt::new(18379501748825152268),
Felt::new(13072499238647455740),
Felt::new(10174350003422057273),
Felt::new(9201651627651151113),
Felt::new(6872461887313298746),
],
[
Felt::new(9133662113608941286),
Felt::new(12096627591905525991),
Felt::new(14963426595993304047),
Felt::new(13290205840019973377),
Felt::new(2903803350580990546),
Felt::new(1838870750730563299),
Felt::new(4258619137315479708),
Felt::new(17334260395129062936),
],
[
Felt::new(3134262397541159485),
Felt::new(10106105871979362399),
Felt::new(138768814855329459),
Felt::new(15044809212457404677),
Felt::new(8571221005243425262),
Felt::new(3016595589318175865),
Felt::new(13933674291329928438),
Felt::new(678640375034313072),
],
[
Felt::new(162696376578462826),
Felt::new(4991300494838863586),
Felt::new(660346084748120605),
Felt::new(13179389528641752698),
Felt::new(16314113978986502310),
Felt::new(14587622368743051587),
Felt::new(2808708361436818462),
Felt::new(10660517522478329440),
],
[
Felt::new(2242391899857912644),
@@ -275,46 +319,46 @@ const EXPECTED: [Word; 19] = [
Felt::new(5046143039268215739),
],
[
Felt::new(9585630502158073976),
Felt::new(1310051013427303477),
Felt::new(7491921222636097758),
Felt::new(9417501558995216762),
Felt::new(5218076004221736204),
Felt::new(17169400568680971304),
Felt::new(8840075572473868990),
Felt::new(12382372614369863623),
],
[
Felt::new(1994394001720334744),
Felt::new(10866209900885216467),
Felt::new(13836092831163031683),
Felt::new(10814636682252756697),
Felt::new(9783834557155203486),
Felt::new(12317263104955018849),
Felt::new(3933748931816109604),
Felt::new(1843043029836917214),
],
[
Felt::new(17486854790732826405),
Felt::new(17376549265955727562),
Felt::new(2371059831956435003),
Felt::new(17585704935858006533),
Felt::new(14498234468286984551),
Felt::new(16837257669834682387),
Felt::new(6664141123711355107),
Felt::new(4590460158294697186),
],
[
Felt::new(11368277489137713825),
Felt::new(3906270146963049287),
Felt::new(10236262408213059745),
Felt::new(78552867005814007),
Felt::new(4661800562479916067),
Felt::new(11794407552792839953),
Felt::new(9037742258721863712),
Felt::new(6287820818064278819),
],
[
Felt::new(17899847381280262181),
Felt::new(14717912805498651446),
Felt::new(10769146203951775298),
Felt::new(2774289833490417856),
Felt::new(7752693085194633729),
Felt::new(7379857372245835536),
Felt::new(9270229380648024178),
Felt::new(10638301488452560378),
],
[
Felt::new(3794717687462954368),
Felt::new(4386865643074822822),
Felt::new(8854162840275334305),
Felt::new(7129983987107225269),
Felt::new(11542686762698783357),
Felt::new(15570714990728449027),
Felt::new(7518801014067819501),
Felt::new(12706437751337583515),
],
[
Felt::new(7244773535611633983),
Felt::new(19359923075859320),
Felt::new(10898655967774994333),
Felt::new(9319339563065736480),
Felt::new(9553923701032839042),
Felt::new(7281190920209838818),
Felt::new(2488477917448393955),
Felt::new(5088955350303368837),
],
[
Felt::new(4935426252518736883),
@@ -323,21 +367,21 @@ const EXPECTED: [Word; 19] = [
Felt::new(18159875708229758073),
],
[
Felt::new(14871230873837295931),
Felt::new(11225255908868362971),
Felt::new(18100987641405432308),
Felt::new(1559244340089644233),
Felt::new(12795429638314178838),
Felt::new(14360248269767567855),
Felt::new(3819563852436765058),
Felt::new(10859123583999067291),
],
[
Felt::new(8348203744950016968),
Felt::new(4041411241960726733),
Felt::new(17584743399305468057),
Felt::new(16836952610803537051),
Felt::new(2695742617679420093),
Felt::new(9151515850666059759),
Felt::new(15855828029180595485),
Felt::new(17190029785471463210),
],
[
Felt::new(16139797453633030050),
Felt::new(1090233424040889412),
Felt::new(10770255347785669036),
Felt::new(16982398877290254028),
Felt::new(13205273108219124830),
Felt::new(2524898486192849221),
Felt::new(14618764355375283547),
Felt::new(10615614265042186874),
],
];

@@ -1,10 +1,13 @@
use core::{cmp::Ordering, fmt::Display, ops::Deref};
use alloc::string::String;
use core::{cmp::Ordering, fmt::Display, ops::Deref, slice};

use thiserror::Error;

use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO};
use crate::{
rand::Randomizable,
utils::{
bytes_to_hex_string, hex_to_bytes, string::*, ByteReader, ByteWriter, Deserializable,
bytes_to_hex_string, hex_to_bytes, ByteReader, ByteWriter, Deserializable,
DeserializationError, HexParseError, Serializable,
},
};
@@ -18,6 +21,9 @@ use crate::{
pub struct RpxDigest([Felt; DIGEST_SIZE]);

impl RpxDigest {
/// The serialized size of the digest in bytes.
pub const SERIALIZED_SIZE: usize = DIGEST_BYTES;

pub const fn new(value: [Felt; DIGEST_SIZE]) -> Self {
Self(value)
}
@@ -30,13 +36,19 @@ impl RpxDigest {
<Self as Digest>::as_bytes(self)
}

pub fn digests_as_elements<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt>
pub fn digests_as_elements_iter<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt>
where
I: Iterator<Item = &'a Self>,
{
digests.flat_map(|d| d.0.iter())
}

pub fn digests_as_elements(digests: &[Self]) -> &[Felt] {
let p = digests.as_ptr();
let len = digests.len() * DIGEST_SIZE;
unsafe { slice::from_raw_parts(p as *const Felt, len) }
}

/// Returns hexadecimal representation of this digest prefixed with `0x`.
pub fn to_hex(&self) -> String {
bytes_to_hex_string(self.as_bytes())
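Aside (usage sketch, not part of the diff): the new slice-based `digests_as_elements` above is a zero-copy view; the unsafe cast is sound only if `RpxDigest` is layout-compatible with `[Felt; DIGEST_SIZE]`, which this sketch assumes.

    fn flatten(digests: &[RpxDigest]) -> &[Felt] {
        let felts = RpxDigest::digests_as_elements(digests);
        // every digest contributes exactly DIGEST_SIZE field elements
        debug_assert_eq!(felts.len(), digests.len() * DIGEST_SIZE);
        felts
    }
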
@@ -117,26 +129,145 @@ impl Randomizable for RpxDigest {
// CONVERSIONS: FROM RPX DIGEST
// ================================================================================================

impl From<&RpxDigest> for [Felt; DIGEST_SIZE] {
fn from(value: &RpxDigest) -> Self {
value.0
#[derive(Debug, Error)]
pub enum RpxDigestError {
#[error("failed to convert digest field element to {0}")]
TypeConversion(&'static str),
#[error("failed to convert to field element: {0}")]
InvalidFieldElement(String),
}

impl TryFrom<&RpxDigest> for [bool; DIGEST_SIZE] {
type Error = RpxDigestError;

fn try_from(value: &RpxDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}

impl From<RpxDigest> for [Felt; DIGEST_SIZE] {
fn from(value: RpxDigest) -> Self {
value.0
impl TryFrom<RpxDigest> for [bool; DIGEST_SIZE] {
type Error = RpxDigestError;

fn try_from(value: RpxDigest) -> Result<Self, Self::Error> {
fn to_bool(v: u64) -> Option<bool> {
if v <= 1 {
Some(v == 1)
} else {
None
}
}

Ok([
to_bool(value.0[0].as_int()).ok_or(RpxDigestError::TypeConversion("bool"))?,
to_bool(value.0[1].as_int()).ok_or(RpxDigestError::TypeConversion("bool"))?,
to_bool(value.0[2].as_int()).ok_or(RpxDigestError::TypeConversion("bool"))?,
to_bool(value.0[3].as_int()).ok_or(RpxDigestError::TypeConversion("bool"))?,
])
}
}

impl TryFrom<&RpxDigest> for [u8; DIGEST_SIZE] {
type Error = RpxDigestError;

fn try_from(value: &RpxDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}

impl TryFrom<RpxDigest> for [u8; DIGEST_SIZE] {
type Error = RpxDigestError;

fn try_from(value: RpxDigest) -> Result<Self, Self::Error> {
Ok([
value.0[0]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u8"))?,
value.0[1]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u8"))?,
value.0[2]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u8"))?,
value.0[3]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u8"))?,
])
}
}

impl TryFrom<&RpxDigest> for [u16; DIGEST_SIZE] {
type Error = RpxDigestError;

fn try_from(value: &RpxDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}

impl TryFrom<RpxDigest> for [u16; DIGEST_SIZE] {
type Error = RpxDigestError;

fn try_from(value: RpxDigest) -> Result<Self, Self::Error> {
Ok([
value.0[0]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u16"))?,
value.0[1]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u16"))?,
value.0[2]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u16"))?,
value.0[3]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u16"))?,
])
}
}

impl TryFrom<&RpxDigest> for [u32; DIGEST_SIZE] {
type Error = RpxDigestError;

fn try_from(value: &RpxDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}

impl TryFrom<RpxDigest> for [u32; DIGEST_SIZE] {
type Error = RpxDigestError;

fn try_from(value: RpxDigest) -> Result<Self, Self::Error> {
Ok([
value.0[0]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u32"))?,
value.0[1]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u32"))?,
value.0[2]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u32"))?,
value.0[3]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u32"))?,
])
}
}

impl From<&RpxDigest> for [u64; DIGEST_SIZE] {
fn from(value: &RpxDigest) -> Self {
[
value.0[0].as_int(),
value.0[1].as_int(),
value.0[2].as_int(),
value.0[3].as_int(),
]
(*value).into()
}
}

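Aside (sketch based directly on the impls above, not part of the diff): a digest element that does not fit the target integer type surfaces as `RpxDigestError::TypeConversion`.

    fn digest_to_bytes(digest: RpxDigest) -> Result<[u8; DIGEST_SIZE], RpxDigestError> {
        // fails with RpxDigestError::TypeConversion("u8") if any element exceeds u8::MAX
        <[u8; DIGEST_SIZE]>::try_from(digest)
    }
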
@@ -151,6 +282,18 @@ impl From<RpxDigest> for [u64; DIGEST_SIZE] {
}
}

impl From<&RpxDigest> for [Felt; DIGEST_SIZE] {
fn from(value: &RpxDigest) -> Self {
value.0
}
}

impl From<RpxDigest> for [Felt; DIGEST_SIZE] {
fn from(value: RpxDigest) -> Self {
value.0
}
}

impl From<&RpxDigest> for [u8; DIGEST_BYTES] {
fn from(value: &RpxDigest) -> Self {
value.as_bytes()
@@ -163,13 +306,6 @@ impl From<RpxDigest> for [u8; DIGEST_BYTES] {
}
}

impl From<RpxDigest> for String {
/// The returned string starts with `0x`.
fn from(value: RpxDigest) -> Self {
value.to_hex()
}
}

impl From<&RpxDigest> for String {
/// The returned string starts with `0x`.
fn from(value: &RpxDigest) -> Self {
@@ -177,13 +313,83 @@ impl From<&RpxDigest> for String {
}
}

impl From<RpxDigest> for String {
/// The returned string starts with `0x`.
fn from(value: RpxDigest) -> Self {
value.to_hex()
}
}

// CONVERSIONS: TO RPX DIGEST
// ================================================================================================

#[derive(Copy, Clone, Debug)]
pub enum RpxDigestError {
/// The provided u64 integer does not fit in the field's moduli.
InvalidInteger,
impl From<&[bool; DIGEST_SIZE]> for RpxDigest {
fn from(value: &[bool; DIGEST_SIZE]) -> Self {
(*value).into()
}
}

impl From<[bool; DIGEST_SIZE]> for RpxDigest {
fn from(value: [bool; DIGEST_SIZE]) -> Self {
[value[0] as u32, value[1] as u32, value[2] as u32, value[3] as u32].into()
}
}

impl From<&[u8; DIGEST_SIZE]> for RpxDigest {
fn from(value: &[u8; DIGEST_SIZE]) -> Self {
(*value).into()
}
}

impl From<[u8; DIGEST_SIZE]> for RpxDigest {
fn from(value: [u8; DIGEST_SIZE]) -> Self {
Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
}
}

impl From<&[u16; DIGEST_SIZE]> for RpxDigest {
fn from(value: &[u16; DIGEST_SIZE]) -> Self {
(*value).into()
}
}

impl From<[u16; DIGEST_SIZE]> for RpxDigest {
fn from(value: [u16; DIGEST_SIZE]) -> Self {
Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
}
}

impl From<&[u32; DIGEST_SIZE]> for RpxDigest {
fn from(value: &[u32; DIGEST_SIZE]) -> Self {
(*value).into()
}
}

impl From<[u32; DIGEST_SIZE]> for RpxDigest {
fn from(value: [u32; DIGEST_SIZE]) -> Self {
Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
}
}

impl TryFrom<&[u64; DIGEST_SIZE]> for RpxDigest {
type Error = RpxDigestError;

fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
(*value).try_into()
}
}

impl TryFrom<[u64; DIGEST_SIZE]> for RpxDigest {
type Error = RpxDigestError;

fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
Ok(Self([
value[0].try_into().map_err(RpxDigestError::InvalidFieldElement)?,
value[1].try_into().map_err(RpxDigestError::InvalidFieldElement)?,
value[2].try_into().map_err(RpxDigestError::InvalidFieldElement)?,
value[3].try_into().map_err(RpxDigestError::InvalidFieldElement)?,
]))
}
}

impl From<&[Felt; DIGEST_SIZE]> for RpxDigest {
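Aside (mirror-image sketch for the conversions above, not part of the diff): arrays of small integers convert infallibly, while u64 entries must be valid field elements.

    fn ints_into_digest() {
        let d: RpxDigest = [1_u32, 2, 3, 4].into();
        let bad: Result<RpxDigest, RpxDigestError> = [u64::MAX, 0, 0, 0].try_into();
        assert!(bad.is_err()); // u64::MAX exceeds the field modulus p = 2^64 - 2^32 + 1
        let _ = d;
    }
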
@@ -198,6 +404,14 @@ impl From<[Felt; DIGEST_SIZE]> for RpxDigest {
}
}

impl TryFrom<&[u8; DIGEST_BYTES]> for RpxDigest {
type Error = HexParseError;

fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
(*value).try_into()
}
}

impl TryFrom<[u8; DIGEST_BYTES]> for RpxDigest {
type Error = HexParseError;

@@ -217,14 +431,6 @@ impl TryFrom<[u8; DIGEST_BYTES]> for RpxDigest {
}
}

impl TryFrom<&[u8; DIGEST_BYTES]> for RpxDigest {
type Error = HexParseError;

fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
(*value).try_into()
}
}

impl TryFrom<&[u8]> for RpxDigest {
type Error = HexParseError;

@@ -233,42 +439,12 @@ impl TryFrom<&[u8]> for RpxDigest {
}
}

impl TryFrom<[u64; DIGEST_SIZE]> for RpxDigest {
type Error = RpxDigestError;

fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
Ok(Self([
value[0].try_into().map_err(|_| RpxDigestError::InvalidInteger)?,
value[1].try_into().map_err(|_| RpxDigestError::InvalidInteger)?,
value[2].try_into().map_err(|_| RpxDigestError::InvalidInteger)?,
value[3].try_into().map_err(|_| RpxDigestError::InvalidInteger)?,
]))
}
}

impl TryFrom<&[u64; DIGEST_SIZE]> for RpxDigest {
type Error = RpxDigestError;

fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
(*value).try_into()
}
}

impl TryFrom<&str> for RpxDigest {
type Error = HexParseError;

/// Expects the string to start with `0x`.
fn try_from(value: &str) -> Result<Self, Self::Error> {
hex_to_bytes(value).and_then(|v| v.try_into())
}
}

impl TryFrom<String> for RpxDigest {
type Error = HexParseError;

/// Expects the string to start with `0x`.
fn try_from(value: String) -> Result<Self, Self::Error> {
value.as_str().try_into()
hex_to_bytes::<DIGEST_BYTES>(value).and_then(RpxDigest::try_from)
}
}

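Aside (round-trip sketch for the hex conversions above, not part of the diff): `to_hex` emits a 0x-prefixed string that `TryFrom<&str>` parses back.

    fn hex_round_trip(digest: &RpxDigest) {
        let hex: String = digest.into();
        let parsed = RpxDigest::try_from(hex.as_str()).expect("0x-prefixed hex of a digest");
        assert_eq!(*digest, parsed);
    }
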
@@ -281,6 +457,15 @@ impl TryFrom<&String> for RpxDigest {
}
}

impl TryFrom<String> for RpxDigest {
type Error = HexParseError;

/// Expects the string to start with `0x`.
fn try_from(value: String) -> Result<Self, Self::Error> {
value.as_str().try_into()
}
}

// SERIALIZATION / DESERIALIZATION
// ================================================================================================

@@ -288,6 +473,10 @@ impl Serializable for RpxDigest {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write_bytes(&self.as_bytes());
}

fn get_size_hint(&self) -> usize {
Self::SERIALIZED_SIZE
}
}

impl Deserializable for RpxDigest {
@@ -307,15 +496,28 @@ impl Deserializable for RpxDigest {
}
}

// ITERATORS
// ================================================================================================
impl IntoIterator for RpxDigest {
type Item = Felt;
type IntoIter = <[Felt; 4] as IntoIterator>::IntoIter;

fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
use alloc::string::String;

use rand_utils::rand_value;

use super::{Deserializable, Felt, RpxDigest, Serializable, DIGEST_BYTES, DIGEST_SIZE};
use crate::utils::{string::*, SliceReader};
use crate::utils::SliceReader;

#[test]
fn digest_serialization() {
@@ -329,6 +531,7 @@ mod tests {
let mut bytes = vec![];
d1.write_into(&mut bytes);
assert_eq!(DIGEST_BYTES, bytes.len());
assert_eq!(bytes.len(), d1.get_size_hint());

let mut reader = SliceReader::new(&bytes);
let d2 = RpxDigest::read_from(&mut reader).unwrap();
@@ -336,7 +539,6 @@ mod tests {
assert_eq!(d1, d2);
}

#[cfg(feature = "std")]
#[test]
fn digest_encoding() {
let digest = RpxDigest([
@@ -361,44 +563,72 @@ mod tests {
Felt::new(rand_value()),
]);

let v: [Felt; DIGEST_SIZE] = digest.into();
// BY VALUE
// ----------------------------------------------------------------------------------------
let v: [bool; DIGEST_SIZE] = [true, false, true, true];
let v2: RpxDigest = v.into();
assert_eq!(digest, v2);
assert_eq!(v, <[bool; DIGEST_SIZE]>::try_from(v2).unwrap());

let v: [Felt; DIGEST_SIZE] = (&digest).into();
let v: [u8; DIGEST_SIZE] = [0_u8, 1_u8, 2_u8, 3_u8];
let v2: RpxDigest = v.into();
assert_eq!(digest, v2);
assert_eq!(v, <[u8; DIGEST_SIZE]>::try_from(v2).unwrap());

let v: [u16; DIGEST_SIZE] = [0_u16, 1_u16, 2_u16, 3_u16];
let v2: RpxDigest = v.into();
assert_eq!(v, <[u16; DIGEST_SIZE]>::try_from(v2).unwrap());

let v: [u32; DIGEST_SIZE] = [0_u32, 1_u32, 2_u32, 3_u32];
let v2: RpxDigest = v.into();
assert_eq!(v, <[u32; DIGEST_SIZE]>::try_from(v2).unwrap());

let v: [u64; DIGEST_SIZE] = digest.into();
let v2: RpxDigest = v.try_into().unwrap();
assert_eq!(digest, v2);

let v: [u64; DIGEST_SIZE] = (&digest).into();
let v2: RpxDigest = v.try_into().unwrap();
let v: [Felt; DIGEST_SIZE] = digest.into();
let v2: RpxDigest = v.into();
assert_eq!(digest, v2);

let v: [u8; DIGEST_BYTES] = digest.into();
let v2: RpxDigest = v.try_into().unwrap();
assert_eq!(digest, v2);

let v: [u8; DIGEST_BYTES] = (&digest).into();
let v2: RpxDigest = v.try_into().unwrap();
assert_eq!(digest, v2);

let v: String = digest.into();
let v2: RpxDigest = v.try_into().unwrap();
assert_eq!(digest, v2);

let v: String = (&digest).into();
let v2: RpxDigest = v.try_into().unwrap();
// BY REF
// ----------------------------------------------------------------------------------------
let v: [bool; DIGEST_SIZE] = [true, false, true, true];
let v2: RpxDigest = (&v).into();
assert_eq!(v, <[bool; DIGEST_SIZE]>::try_from(&v2).unwrap());

let v: [u8; DIGEST_SIZE] = [0_u8, 1_u8, 2_u8, 3_u8];
let v2: RpxDigest = (&v).into();
assert_eq!(v, <[u8; DIGEST_SIZE]>::try_from(&v2).unwrap());

let v: [u16; DIGEST_SIZE] = [0_u16, 1_u16, 2_u16, 3_u16];
let v2: RpxDigest = (&v).into();
assert_eq!(v, <[u16; DIGEST_SIZE]>::try_from(&v2).unwrap());

let v: [u32; DIGEST_SIZE] = [0_u32, 1_u32, 2_u32, 3_u32];
let v2: RpxDigest = (&v).into();
assert_eq!(v, <[u32; DIGEST_SIZE]>::try_from(&v2).unwrap());

let v: [u64; DIGEST_SIZE] = (&digest).into();
let v2: RpxDigest = (&v).try_into().unwrap();
assert_eq!(digest, v2);

let v: [u8; DIGEST_BYTES] = digest.into();
let v2: RpxDigest = (&v).try_into().unwrap();
let v: [Felt; DIGEST_SIZE] = (&digest).into();
let v2: RpxDigest = (&v).into();
assert_eq!(digest, v2);

let v: [u8; DIGEST_BYTES] = (&digest).into();
let v2: RpxDigest = (&v).try_into().unwrap();
assert_eq!(digest, v2);

let v: String = (&digest).into();
let v2: RpxDigest = (&v).try_into().unwrap();
assert_eq!(digest, v2);
}
}

@@ -9,7 +9,10 @@ use super::{
};

mod digest;
pub use digest::RpxDigest;
pub use digest::{RpxDigest, RpxDigestError};

#[cfg(test)]
mod tests;

pub type CubicExtElement = CubeExtension<Felt>;

@@ -26,8 +29,10 @@ pub type CubicExtElement = CubeExtension<Felt>;
/// * Capacity size: 4 field elements.
/// * S-Box degree: 7.
/// * Rounds: There are 3 different types of rounds:
/// - (FB): `apply_mds` → `add_constants` → `apply_sbox` → `apply_mds` → `add_constants` → `apply_inv_sbox`.
/// - (E): `add_constants` → `ext_sbox` (which is raising to power 7 in the degree 3 extension field).
/// - (FB): `apply_mds` → `add_constants` → `apply_sbox` → `apply_mds` → `add_constants` →
/// `apply_inv_sbox`.
/// - (E): `add_constants` → `ext_sbox` (which is raising to power 7 in the degree 3 extension
/// field).
/// - (M): `apply_mds` → `add_constants`.
/// * Permutation: (FB) (E) (FB) (E) (FB) (E) (M).
///
@@ -53,8 +58,23 @@ pub type CubicExtElement = CubeExtension<Felt>;
///
/// Thus, if the underlying data consists of valid field elements, it might make more sense
/// to deserialize them into field elements and then hash them using
/// [hash_elements()](Rpx256::hash_elements) function rather then hashing the serialized bytes
/// [hash_elements()](Rpx256::hash_elements) function rather than hashing the serialized bytes
/// using [hash()](Rpx256::hash) function.
///
/// ## Domain separation
/// [merge_in_domain()](Rpx256::merge_in_domain) hashes two digests into one digest with some domain
/// identifier and the current implementation sets the second capacity element to the value of
/// this domain identifier. Using a similar argument to the one formulated for domain separation
/// in Appendix C of the [specifications](https://eprint.iacr.org/2023/1045), one sees that doing
/// so degrades only pre-image resistance, from its initial bound of c.log_2(p), by as much as
/// the log_2 of the size of the domain identifier space. Since pre-image resistance becomes
/// the bottleneck for the security bound of the sponge in overwrite-mode only when it is
/// lower than 2^128, we see that the target 128-bit security level is maintained as long as
/// the size of the domain identifier space, including for padding, is less than 2^128.
///
/// ## Hashing of empty input
/// The current implementation hashes empty input to the zero digest [0, 0, 0, 0]. This has
/// the benefit of requiring no calls to the RPX permutation when hashing empty input.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Rpx256();

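Aside (illustrative encoding of the round schedule documented above; the enum is hypothetical and not part of the crate): the RPX permutation applies seven rounds in the fixed (FB)(E)(FB)(E)(FB)(E)(M) order.

    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    enum RpxRound {
        Fb, // apply_mds -> add_constants -> apply_sbox -> apply_mds -> add_constants -> apply_inv_sbox
        E,  // add_constants -> ext_sbox (x^7 in the degree-3 extension field)
        M,  // apply_mds -> add_constants
    }

    const RPX_SCHEDULE: [RpxRound; 7] = [
        RpxRound::Fb, RpxRound::E, RpxRound::Fb, RpxRound::E,
        RpxRound::Fb, RpxRound::E, RpxRound::M,
    ];
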
@@ -86,11 +106,18 @@ impl Hasher for Rpx256 {
// into the state.
//
// every time the rate range is filled, a permutation is performed. if the final value of
// `i` is not zero, then the chunks count wasn't enough to fill the state range, and an
// additional permutation must be performed.
let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| {
// `rate_pos` is not zero, then the chunks count wasn't enough to fill the state range,
// and an additional permutation must be performed.
let mut current_chunk_idx = 0_usize;
// handle the case of an empty `bytes`
let last_chunk_idx = if num_field_elem == 0 {
current_chunk_idx
} else {
num_field_elem - 1
};
let rate_pos = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |rate_pos, chunk| {
// copy the chunk into the buffer
if i != num_field_elem - 1 {
if current_chunk_idx != last_chunk_idx {
buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
} else {
// on the last iteration, we pad `buf` with a 1 followed by as many 0's as are
@@ -99,18 +126,19 @@ impl Hasher for Rpx256 {
buf[..chunk.len()].copy_from_slice(chunk);
buf[chunk.len()] = 1;
}
current_chunk_idx += 1;

// set the current rate element to the input. since we take at most 7 bytes, we are
// guaranteed that the inputs data will fit into a single field element.
state[RATE_RANGE.start + i] = Felt::new(u64::from_le_bytes(buf));
state[RATE_RANGE.start + rate_pos] = Felt::new(u64::from_le_bytes(buf));

// proceed filling the range. if it's full, then we apply a permutation and reset the
// counter to the beginning of the range.
if i == RATE_WIDTH - 1 {
if rate_pos == RATE_WIDTH - 1 {
Self::apply_permutation(&mut state);
0
} else {
i + 1
rate_pos + 1
}
});

@@ -119,8 +147,8 @@ impl Hasher for Rpx256 {
// don't need to apply any extra padding because the first capacity element contains a
// flag indicating the number of field elements constituting the last block when the latter
// is not divisible by `RATE_WIDTH`.
if i != 0 {
state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO);
if rate_pos != 0 {
state[RATE_RANGE.start + rate_pos..RATE_RANGE.end].fill(ZERO);
Self::apply_permutation(&mut state);
}

@@ -132,7 +160,7 @@ impl Hasher for Rpx256 {
// initialize the state by copying the digest elements into the rate portion of the state
// (8 total elements), and set the capacity elements to 0.
let mut state = [ZERO; STATE_WIDTH];
let it = Self::Digest::digests_as_elements(values.iter());
let it = Self::Digest::digests_as_elements_iter(values.iter());
for (i, v) in it.enumerate() {
state[RATE_RANGE.start + i] = *v;
}
@@ -142,13 +170,17 @@ impl Hasher for Rpx256 {
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}

fn merge_many(values: &[Self::Digest]) -> Self::Digest {
Self::hash_elements(Self::Digest::digests_as_elements(values))
}

fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
// initialize the state as follows:
// - seed is copied into the first 4 elements of the rate portion of the state.
// - if the value fits into a single field element, copy it into the fifth rate element and
// set the first capacity element to 5.
// - if the value doesn't fit into a single field element, split it into two field
// elements, copy them into rate elements 5 and 6 and set the first capacity element to 6.
// - if the value doesn't fit into a single field element, split it into two field elements,
// copy them into rate elements 5 and 6 and set the first capacity element to 6.
let mut state = [ZERO; STATE_WIDTH];
state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
state[INPUT2_RANGE.start] = Felt::new(value);
@@ -159,7 +191,7 @@ impl Hasher for Rpx256 {
state[CAPACITY_RANGE.start] = Felt::from(6_u8);
}

// apply the RPX permutation and return the first four elements of the state
// apply the RPX permutation and return the first four elements of the rate
Self::apply_permutation(&mut state);
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
@@ -265,7 +297,7 @@ impl Rpx256 {
// initialize the state by copying the digest elements into the rate portion of the state
// (8 total elements), and set the capacity elements to 0.
let mut state = [ZERO; STATE_WIDTH];
let it = RpxDigest::digests_as_elements(values.iter());
let it = RpxDigest::digests_as_elements_iter(values.iter());
for (i, v) in it.enumerate() {
state[RATE_RANGE.start + i] = *v;
}

186
src/hash/rescue/rpx/tests.rs
Normal file
@@ -0,0 +1,186 @@
use alloc::{collections::BTreeSet, vec::Vec};

use proptest::prelude::*;
use rand_utils::rand_value;

use super::{Felt, Hasher, Rpx256, StarkField, ZERO};
use crate::{hash::rescue::RpxDigest, ONE};

#[test]
fn hash_elements_vs_merge() {
let elements = [Felt::new(rand_value()); 8];

let digests: [RpxDigest; 2] = [
RpxDigest::new(elements[..4].try_into().unwrap()),
RpxDigest::new(elements[4..].try_into().unwrap()),
];

let m_result = Rpx256::merge(&digests);
let h_result = Rpx256::hash_elements(&elements);
assert_eq!(m_result, h_result);
}

#[test]
fn merge_vs_merge_in_domain() {
let elements = [Felt::new(rand_value()); 8];

let digests: [RpxDigest; 2] = [
RpxDigest::new(elements[..4].try_into().unwrap()),
RpxDigest::new(elements[4..].try_into().unwrap()),
];
let merge_result = Rpx256::merge(&digests);

// ----- merge with domain = 0 ----------------------------------------------------------------

// set domain to ZERO. This should not change the result.
let domain = ZERO;

let merge_in_domain_result = Rpx256::merge_in_domain(&digests, domain);
assert_eq!(merge_result, merge_in_domain_result);

// ----- merge with domain = 1 ----------------------------------------------------------------

// set domain to ONE. This should change the result.
let domain = ONE;

let merge_in_domain_result = Rpx256::merge_in_domain(&digests, domain);
assert_ne!(merge_result, merge_in_domain_result);
}

#[test]
fn hash_elements_vs_merge_with_int() {
let tmp = [Felt::new(rand_value()); 4];
let seed = RpxDigest::new(tmp);

// ----- value fits into a field element ------------------------------------------------------
let val: Felt = Felt::new(rand_value());
let m_result = Rpx256::merge_with_int(seed, val.as_int());

let mut elements = seed.as_elements().to_vec();
elements.push(val);
let h_result = Rpx256::hash_elements(&elements);

assert_eq!(m_result, h_result);

// ----- value does not fit into a field element ----------------------------------------------
let val = Felt::MODULUS + 2;
let m_result = Rpx256::merge_with_int(seed, val);

let mut elements = seed.as_elements().to_vec();
elements.push(Felt::new(val));
elements.push(ONE);
let h_result = Rpx256::hash_elements(&elements);

assert_eq!(m_result, h_result);
}

#[test]
fn hash_padding() {
// adding a zero byte at the end of a byte string should result in a different hash
let r1 = Rpx256::hash(&[1_u8, 2, 3]);
let r2 = Rpx256::hash(&[1_u8, 2, 3, 0]);
assert_ne!(r1, r2);

// same as above but with bigger inputs
let r1 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6]);
let r2 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 0]);
assert_ne!(r1, r2);

// same as above but with input splitting over two elements
let r1 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 7]);
let r2 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 7, 0]);
assert_ne!(r1, r2);

// same as above but with multiple zeros
let r1 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 7, 0, 0]);
let r2 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0]);
assert_ne!(r1, r2);
}

#[test]
fn hash_elements_padding() {
let e1 = [Felt::new(rand_value()); 2];
let e2 = [e1[0], e1[1], ZERO];

let r1 = Rpx256::hash_elements(&e1);
let r2 = Rpx256::hash_elements(&e2);
assert_ne!(r1, r2);
}

#[test]
fn hash_elements() {
let elements = [
ZERO,
ONE,
Felt::new(2),
Felt::new(3),
Felt::new(4),
Felt::new(5),
Felt::new(6),
Felt::new(7),
];

let digests: [RpxDigest; 2] = [
RpxDigest::new(elements[..4].try_into().unwrap()),
RpxDigest::new(elements[4..8].try_into().unwrap()),
];

let m_result = Rpx256::merge(&digests);
let h_result = Rpx256::hash_elements(&elements);
assert_eq!(m_result, h_result);
}

#[test]
fn hash_empty() {
let elements: Vec<Felt> = vec![];

let zero_digest = RpxDigest::default();
let h_result = Rpx256::hash_elements(&elements);
assert_eq!(zero_digest, h_result);
}

#[test]
fn hash_empty_bytes() {
let bytes: Vec<u8> = vec![];

let zero_digest = RpxDigest::default();
let h_result = Rpx256::hash(&bytes);
assert_eq!(zero_digest, h_result);
}

#[test]
fn sponge_bytes_with_remainder_length_wont_panic() {
// this test asserts that no panic happens in the edge case of an input whose length is not
// divisible by the binary chunk size. 113 is a non-negligible input length that is prime,
// hence guaranteed not to be divisible by any choice of chunk size.
//
// this is a preliminary check ahead of the proptest fuzz-stress below.
Rpx256::hash(&[0; 113]);
}

#[test]
fn sponge_collision_for_wrapped_field_element() {
let a = Rpx256::hash(&[0; 8]);
let b = Rpx256::hash(&Felt::MODULUS.to_le_bytes());
assert_ne!(a, b);
}

#[test]
fn sponge_zeroes_collision() {
let mut zeroes = Vec::with_capacity(255);
let mut set = BTreeSet::new();
(0..255).for_each(|_| {
let hash = Rpx256::hash(&zeroes);
zeroes.push(0);
// panic if a collision was found
assert!(set.insert(hash));
});
}

proptest! {
#[test]
fn rpx256_wont_panic_with_arbitrary_input(ref bytes in any::<Vec<u8>>()) {
Rpx256::hash(bytes);
}
}
@@ -1,9 +1,11 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![no_std]

#[cfg(not(feature = "std"))]
#[cfg_attr(test, macro_use)]
#[macro_use]
extern crate alloc;

#[cfg(feature = "std")]
extern crate std;

pub mod dsa;
pub mod hash;
pub mod merkle;

49
src/main.rs
@@ -35,6 +35,7 @@ pub fn benchmark_smt() {

let mut tree = construction(entries, tree_size).unwrap();
insertion(&mut tree, tree_size).unwrap();
batched_insertion(&mut tree, tree_size).unwrap();
proof_generation(&mut tree, tree_size).unwrap();
}

@@ -82,6 +83,54 @@ pub fn insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
Ok(())
}

pub fn batched_insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
println!("Running a batched insertion benchmark:");

let new_pairs: Vec<(RpoDigest, Word)> = (0..1000)
.map(|i| {
let key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
let value = [ONE, ONE, ONE, Felt::new(size + i)];
(key, value)
})
.collect();

let now = Instant::now();
let mutations = tree.compute_mutations(new_pairs);
let compute_elapsed = now.elapsed();

let now = Instant::now();
tree.apply_mutations(mutations).unwrap();
let apply_elapsed = now.elapsed();

println!(
"An average batch computation time measured by a 1k-batch into an SMT with {} key-value pairs over {:.3} milliseconds is {:.3} milliseconds",
size,
compute_elapsed.as_secs_f32() * 1000f32,
// Dividing by the number of iterations, 1000, and then multiplying by 1000 to get
// milliseconds, cancels out.
compute_elapsed.as_secs_f32(),
);

println!(
"An average batch application time measured by a 1k-batch into an SMT with {} key-value pairs over {:.3} milliseconds is {:.3} milliseconds",
size,
apply_elapsed.as_secs_f32() * 1000f32,
// Dividing by the number of iterations, 1000, and then multiplying by 1000 to get
// milliseconds, cancels out.
apply_elapsed.as_secs_f32(),
);

println!(
"An average batch insertion time measured by a 1k-batch into an SMT with {} key-value pairs totals to {:.3} milliseconds",
size,
(compute_elapsed + apply_elapsed).as_secs_f32() * 1000f32,
);

println!();

Ok(())
}

/// Runs the proof generation benchmark for the [`Smt`].
pub fn proof_generation(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
println!("Running a proof generation benchmark:");

@@ -1,6 +1,6 @@
use core::slice;

use super::{Felt, RpoDigest, EMPTY_WORD};
use super::{smt::InnerNode, Felt, RpoDigest, EMPTY_WORD};

// EMPTY NODES SUBTREES
// ================================================================================================
@@ -25,6 +25,17 @@ impl EmptySubtreeRoots {
let pos = 255 - tree_depth + node_depth;
&EMPTY_SUBTREES[pos as usize]
}

/// Returns a sparse Merkle tree [`InnerNode`] with two empty children.
///
/// # Note
/// `node_depth` is the depth of the **parent** to have empty children. That is, `node_depth`
/// and the depth of the returned [`InnerNode`] are the same, and thus the empty hashes are for
/// subtrees of depth `node_depth + 1`.
pub(crate) const fn get_inner_node(tree_depth: u8, node_depth: u8) -> InnerNode {
let &child = Self::entry(tree_depth, node_depth + 1);
InnerNode { left: child, right: child }
}
}

const EMPTY_SUBTREES: [RpoDigest; 256] = [

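Aside (hedged sketch of the relationship spelled out in the doc comment above; the API is crate-internal, so this is illustrative only and assumes RpoDigest is Copy): the inner node returned for node_depth is built from the empty-subtree hash one level deeper.

    fn check_empty_inner_node(tree_depth: u8, node_depth: u8) {
        let node = EmptySubtreeRoots::get_inner_node(tree_depth, node_depth);
        let child = *EmptySubtreeRoots::entry(tree_depth, node_depth + 1);
        assert_eq!(node.left, child);
        assert_eq!(node.right, child);
    }
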
@@ -1,65 +1,34 @@
use core::fmt;
use thiserror::Error;

use super::{smt::SmtLeafError, MerklePath, NodeIndex, RpoDigest};
use crate::utils::collections::*;
use super::{NodeIndex, RpoDigest};

#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Debug, Error)]
pub enum MerkleError {
ConflictingRoots(Vec<RpoDigest>),
#[error("expected merkle root {expected_root} found {actual_root}")]
ConflictingRoots {
expected_root: RpoDigest,
actual_root: RpoDigest,
},
#[error("provided merkle tree depth {0} is too small")]
DepthTooSmall(u8),
#[error("provided merkle tree depth {0} is too big")]
DepthTooBig(u64),
#[error("multiple values provided for merkle tree index {0}")]
DuplicateValuesForIndex(u64),
DuplicateValuesForKey(RpoDigest),
InvalidIndex { depth: u8, value: u64 },
InvalidDepth { expected: u8, provided: u8 },
InvalidSubtreeDepth { subtree_depth: u8, tree_depth: u8 },
InvalidPath(MerklePath),
InvalidNumEntries(usize),
NodeNotInSet(NodeIndex),
NodeNotInStore(RpoDigest, NodeIndex),
#[error("node index value {value} is not valid for depth {depth}")]
InvalidNodeIndex { depth: u8, value: u64 },
#[error("provided node index depth {provided} does not match expected depth {expected}")]
InvalidNodeIndexDepth { expected: u8, provided: u8 },
#[error("merkle subtree depth {subtree_depth} exceeds merkle tree depth {tree_depth}")]
SubtreeDepthExceedsDepth { subtree_depth: u8, tree_depth: u8 },
#[error("number of entries in the merkle tree exceeds the maximum of {0}")]
TooManyEntries(usize),
#[error("node index `{0}` not found in the tree")]
NodeIndexNotFoundInTree(NodeIndex),
#[error("node {0:?} with index `{1}` not found in the store")]
NodeIndexNotFoundInStore(RpoDigest, NodeIndex),
#[error("number of provided merkle tree leaves {0} is not a power of two")]
NumLeavesNotPowerOfTwo(usize),
#[error("root {0:?} is not in the store")]
RootNotInStore(RpoDigest),
SmtLeaf(SmtLeafError),
}

impl fmt::Display for MerkleError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use MerkleError::*;
match self {
ConflictingRoots(roots) => write!(f, "the merkle paths roots do not match {roots:?}"),
DepthTooSmall(depth) => write!(f, "the provided depth {depth} is too small"),
DepthTooBig(depth) => write!(f, "the provided depth {depth} is too big"),
DuplicateValuesForIndex(key) => write!(f, "multiple values provided for key {key}"),
DuplicateValuesForKey(key) => write!(f, "multiple values provided for key {key}"),
InvalidIndex { depth, value } => {
write!(f, "the index value {value} is not valid for the depth {depth}")
}
InvalidDepth { expected, provided } => {
write!(f, "the provided depth {provided} is not valid for {expected}")
}
InvalidSubtreeDepth { subtree_depth, tree_depth } => {
write!(f, "tried inserting a subtree of depth {subtree_depth} into a tree of depth {tree_depth}")
}
InvalidPath(_path) => write!(f, "the provided path is not valid"),
InvalidNumEntries(max) => write!(f, "number of entries exceeded the maximum: {max}"),
NodeNotInSet(index) => write!(f, "the node with index ({index}) is not in the set"),
NodeNotInStore(hash, index) => {
write!(f, "the node {hash:?} with index ({index}) is not in the store")
}
NumLeavesNotPowerOfTwo(leaves) => {
write!(f, "the leaves count {leaves} is not a power of 2")
}
RootNotInStore(root) => write!(f, "the root {:?} is not in the store", root),
SmtLeaf(smt_leaf_error) => write!(f, "smt leaf error: {smt_leaf_error}"),
}
}
}

#[cfg(feature = "std")]
impl std::error::Error for MerkleError {}

impl From<SmtLeafError> for MerkleError {
fn from(value: SmtLeafError) -> Self {
Self::SmtLeaf(value)
}
}

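For readers unfamiliar with the pattern driving this rewrite: with the `thiserror` crate, the `#[error(...)]` attribute generates the `Display` impl that previously had to be written by hand, and `derive(Error)` supplies `std::error::Error`. A reduced sketch (requires the `thiserror` dependency; the variants are illustrative, not the crate's):

use thiserror::Error;

#[derive(Debug, Error)]
pub enum DemoError {
    #[error("provided depth {0} is too small")]
    DepthTooSmall(u8),
    #[error("expected root {expected} found {actual}")]
    ConflictingRoots { expected: String, actual: String },
}

fn main() {
    // the derived Display output matches the #[error] template
    let err = DemoError::DepthTooSmall(3);
    assert_eq!(err.to_string(), "provided depth 3 is too small");
}
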
@@ -38,7 +38,7 @@ impl NodeIndex {
/// Returns an error if the `value` is greater than or equal to 2^{depth}.
pub const fn new(depth: u8, value: u64) -> Result<Self, MerkleError> {
if (64 - value.leading_zeros()) > depth as u32 {
Err(MerkleError::InvalidIndex { depth, value })
Err(MerkleError::InvalidNodeIndex { depth, value })
} else {
Ok(Self { depth, value })
}
@@ -97,6 +97,14 @@ impl NodeIndex {
self
}

/// Returns the parent of the current node. This is the same as [`Self::move_up()`], but returns
/// a new value instead of mutating `self`.
pub const fn parent(mut self) -> Self {
self.depth = self.depth.saturating_sub(1);
self.value >>= 1;
self
}

// PROVIDERS
// --------------------------------------------------------------------------------------------

@@ -182,6 +190,7 @@ impl Deserializable for NodeIndex {

#[cfg(test)]
mod tests {
use assert_matches::assert_matches;
use proptest::prelude::*;

use super::*;
@@ -190,19 +199,19 @@ mod tests {
fn test_node_index_value_too_high() {
assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 });
let err = NodeIndex::new(0, 1).unwrap_err();
assert_eq!(err, MerkleError::InvalidIndex { depth: 0, value: 1 });
assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 0, value: 1 });

assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 });
let err = NodeIndex::new(1, 2).unwrap_err();
assert_eq!(err, MerkleError::InvalidIndex { depth: 1, value: 2 });
assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 1, value: 2 });

assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 });
let err = NodeIndex::new(2, 4).unwrap_err();
assert_eq!(err, MerkleError::InvalidIndex { depth: 2, value: 4 });
assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 2, value: 4 });

assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 });
let err = NodeIndex::new(3, 8).unwrap_err();
assert_eq!(err, MerkleError::InvalidIndex { depth: 3, value: 8 });
assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 3, value: 8 });
}

#[test]

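The bounds check in `NodeIndex::new` is a bit-width test: a `value` is valid at a given `depth` iff it fits in `depth` bits, i.e. `value < 2^depth`. A small stand-alone sketch of the same check:

// value fits at `depth` iff it needs no more than `depth` bits, i.e. value < 2^depth
const fn fits_at_depth(depth: u8, value: u64) -> bool {
    (64 - value.leading_zeros()) <= depth as u32
}

fn main() {
    assert!(fits_at_depth(3, 7)); // 0b111 uses 3 bits: valid at depth 3
    assert!(!fits_at_depth(3, 8)); // 0b1000 uses 4 bits: rejected, as in the tests above
    assert!(fits_at_depth(0, 0)); // depth 0 admits only the value 0
}
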
@@ -1,9 +1,8 @@
use alloc::{string::String, vec::Vec};
use core::{fmt, ops::Deref, slice};

use winter_math::log2;

use super::{InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Word};
use crate::utils::{collections::*, string::*, uninit_vector, word_to_hex};
use crate::utils::{uninit_vector, word_to_hex};

// MERKLE TREE
// ================================================================================================
@@ -69,7 +68,7 @@ impl MerkleTree {
///
/// Merkle tree of depth 1 has two leaves, depth 2 has four leaves etc.
pub fn depth(&self) -> u8 {
log2(self.nodes.len() / 2) as u8
(self.nodes.len() / 2).ilog2() as u8
}

/// Returns a node at the specified depth and index value.
@@ -212,7 +211,7 @@ pub struct InnerNodeIterator<'a> {
index: usize,
}

impl<'a> Iterator for InnerNodeIterator<'a> {
impl Iterator for InnerNodeIterator<'_> {
type Item = InnerNodeInfo;

fn next(&mut self) -> Option<Self::Item> {

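The `depth()` change swaps `winter_math::log2` for the standard library's `usize::ilog2` (stable since Rust 1.67) with identical results: the tree stores `2 * leaf_count` nodes, so half the node count gives the leaf count and its base-2 logarithm gives the depth. A quick sketch:

// a depth-d Merkle tree stores 2^(d+1) nodes, so nodes/2 = leaves and ilog2(leaves) = depth
fn depth_from_node_count(node_count: usize) -> u8 {
    (node_count / 2).ilog2() as u8
}

fn main() {
    assert_eq!(depth_from_node_count(4), 1); // 2 leaves
    assert_eq!(depth_from_node_count(8), 2); // 4 leaves
    assert_eq!(depth_from_node_count(16), 3); // 8 leaves
}
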
@@ -1,5 +1,6 @@
use alloc::vec::Vec;

use super::super::RpoDigest;
use crate::utils::collections::*;

/// Container for the update data of a [super::PartialMmr]
#[derive(Debug)]

@@ -1,35 +1,27 @@
use core::fmt::{Display, Formatter};
#[cfg(feature = "std")]
use std::error::Error;
use alloc::string::String;

use thiserror::Error;

use crate::merkle::MerkleError;

#[derive(Debug, PartialEq, Eq, Clone)]
#[derive(Debug, Error)]
pub enum MmrError {
InvalidPosition(usize),
InvalidPeaks,
InvalidPeak,
#[error("mmr does not contain position {0}")]
PositionNotFound(usize),
#[error("mmr peaks are invalid: {0}")]
InvalidPeaks(String),
#[error(
"mmr peak does not match the computed merkle root of the provided authentication path"
)]
PeakPathMismatch,
#[error("requested peak index is {peak_idx} but the number of peaks is {peaks_len}")]
PeakOutOfBounds { peak_idx: usize, peaks_len: usize },
#[error("invalid mmr update")]
InvalidUpdate,
UnknownPeak,
MerkleError(MerkleError),
#[error("mmr does not contain a peak with depth {0}")]
UnknownPeak(u8),
#[error("invalid merkle path")]
InvalidMerklePath(#[source] MerkleError),
#[error("merkle root computation failed")]
MerkleRootComputationFailed(#[source] MerkleError),
}

impl Display for MmrError {
fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), core::fmt::Error> {
match self {
MmrError::InvalidPosition(pos) => write!(fmt, "Mmr does not contain position {pos}"),
MmrError::InvalidPeaks => write!(fmt, "Invalid peaks count"),
MmrError::InvalidPeak => {
write!(fmt, "Peak values does not match merkle path computed root")
}
MmrError::InvalidUpdate => write!(fmt, "Invalid mmr update"),
MmrError::UnknownPeak => {
write!(fmt, "Peak not in Mmr")
}
MmrError::MerkleError(err) => write!(fmt, "{}", err),
}
}
}

#[cfg(feature = "std")]
impl Error for MmrError {}

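The `#[source]` attribute used in the new `InvalidMerklePath` and `MerkleRootComputationFailed` variants keeps the wrapped error out of the display string but preserves it on the `source()` chain. A reduced sketch of that behavior (again assuming the `thiserror` dependency; the types are illustrative):

use std::error::Error as _;
use thiserror::Error;

#[derive(Debug, Error)]
#[error("low-level failure")]
struct Inner;

#[derive(Debug, Error)]
enum Outer {
    #[error("invalid merkle path")]
    InvalidPath(#[source] Inner),
}

fn main() {
    let err = Outer::InvalidPath(Inner);
    // the wrapped error is absent from the message...
    assert_eq!(err.to_string(), "invalid merkle path");
    // ...but preserved on the source() chain
    assert_eq!(err.source().unwrap().to_string(), "low-level failure");
}
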
@@ -7,16 +7,17 @@
//!
//! Additionally the structure only supports adding leaves to the right-most tree, the one with the
//! least number of leaves. The structure preserves the invariant that each tree has different
//! depths, i.e. as part of adding adding a new element to the forest the trees with same depth are
//! depths, i.e. as part of adding a new element to the forest the trees with same depth are
//! merged, creating a new tree with depth d+1, this process is continued until the property is
//! reestablished.
use alloc::vec::Vec;

use super::{
super::{InnerNodeInfo, MerklePath},
bit::TrueBitPositionIterator,
leaf_to_corresponding_tree, nodes_in_forest, MmrDelta, MmrError, MmrPeaks, MmrProof, Rpo256,
RpoDigest,
};
use crate::utils::collections::*;

// MMR
// ===============================================================================================
@@ -72,19 +73,36 @@ impl Mmr {
// FUNCTIONALITY
// ============================================================================================

/// Given a leaf position, returns the Merkle path to its corresponding peak. If the position
/// is greater-or-equal than the tree size an error is returned.
/// Returns an [MmrProof] for the leaf at the specified position.
///
/// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were
/// added, this corresponds to the MMR size _prior_ to adding the element. So the 1st element
/// has position 0, the second position 1, and so on.
pub fn open(&self, pos: usize, target_forest: usize) -> Result<MmrProof, MmrError> {
///
/// # Errors
/// Returns an error if the specified leaf position is out of bounds for this MMR.
pub fn open(&self, pos: usize) -> Result<MmrProof, MmrError> {
self.open_at(pos, self.forest)
}

/// Returns an [MmrProof] for the leaf at the specified position using the state of the MMR
/// at the specified `forest`.
///
/// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were
/// added, this corresponds to the MMR size _prior_ to adding the element. So the 1st element
/// has position 0, the second position 1, and so on.
///
/// # Errors
/// Returns an error if:
/// - The specified leaf position is out of bounds for this MMR.
/// - The specified `forest` value is not valid for this MMR.
pub fn open_at(&self, pos: usize, forest: usize) -> Result<MmrProof, MmrError> {
// find the target tree responsible for the MMR position
let tree_bit =
leaf_to_corresponding_tree(pos, target_forest).ok_or(MmrError::InvalidPosition(pos))?;
leaf_to_corresponding_tree(pos, forest).ok_or(MmrError::PositionNotFound(pos))?;

// isolate the trees before the target
let forest_before = target_forest & high_bitmask(tree_bit + 1);
let forest_before = forest & high_bitmask(tree_bit + 1);
let index_offset = nodes_in_forest(forest_before);

// update the value position from global to the target tree
@@ -94,7 +112,7 @@ impl Mmr {
let (_, path) = self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset);

Ok(MmrProof {
forest: target_forest,
forest,
position: pos,
merkle_path: MerklePath::new(path),
})
@@ -108,7 +126,7 @@ impl Mmr {
pub fn get(&self, pos: usize) -> Result<RpoDigest, MmrError> {
// find the target tree responsible for the MMR position
let tree_bit =
leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::InvalidPosition(pos))?;
leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::PositionNotFound(pos))?;

// isolate the trees before the target
let forest_before = self.forest & high_bitmask(tree_bit + 1);
@@ -145,10 +163,21 @@ impl Mmr {
self.forest += 1;
}

/// Returns an peaks of the MMR for the version specified by `forest`.
pub fn peaks(&self, forest: usize) -> Result<MmrPeaks, MmrError> {
/// Returns the current peaks of the MMR.
pub fn peaks(&self) -> MmrPeaks {
self.peaks_at(self.forest).expect("failed to get peaks at current forest")
}

/// Returns the peaks of the MMR at the state specified by `forest`.
///
/// # Errors
/// Returns an error if the specified `forest` value is not valid for this MMR.
pub fn peaks_at(&self, forest: usize) -> Result<MmrPeaks, MmrError> {
if forest > self.forest {
return Err(MmrError::InvalidPeaks);
return Err(MmrError::InvalidPeaks(format!(
"requested forest {forest} exceeds current forest {}",
self.forest
)));
}

let peaks: Vec<RpoDigest> = TrueBitPositionIterator::new(forest)
@@ -173,7 +202,7 @@ impl Mmr {
/// that have been merged together, followed by the new peaks of the [Mmr].
pub fn get_delta(&self, from_forest: usize, to_forest: usize) -> Result<MmrDelta, MmrError> {
if to_forest > self.forest || from_forest > to_forest {
return Err(MmrError::InvalidPeaks);
return Err(MmrError::InvalidPeaks(format!("to_forest {to_forest} exceeds the current forest {} or from_forest {from_forest} exceeds to_forest", self.forest)));
}

if from_forest == to_forest {
@@ -344,7 +373,7 @@ pub struct MmrNodes<'a> {
index: usize,
}

impl<'a> Iterator for MmrNodes<'a> {
impl Iterator for MmrNodes<'_> {
type Item = InnerNodeInfo;

fn next(&mut self) -> Option<Self::Item> {
@@ -377,7 +406,8 @@ impl<'a> Iterator for MmrNodes<'a> {
// the next parent position is one above the position of the pair
let parent = self.last_right << 1;

// the left node has been paired and the current parent yielded, removed it from the forest
// the left node has been paired and the current parent yielded, removed it from the
// forest
self.forest ^= self.last_right;
if self.forest & parent == 0 {
// this iteration yielded the left parent node

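Putting the reshaped `Mmr` API together: `open` and `peaks` now act on the current forest, while `open_at` and `peaks_at` take an explicit historical forest. A usage sketch built only from the signatures shown in this diff (the import paths and the `Felt`/`ZERO` re-exports are assumptions about the crate's public interface):

use miden_crypto::{hash::rpo::RpoDigest, merkle::Mmr, Felt, ZERO};

fn main() {
    let leaf = |i: u64| RpoDigest::new([Felt::new(i), ZERO, ZERO, ZERO]);

    let mut mmr = Mmr::new();
    (0..8u64).for_each(|i| mmr.add(leaf(i)));

    // proof against the current state
    let proof = mmr.open(3).unwrap();
    assert!(mmr.peaks().verify(leaf(3), proof).is_ok());

    // proof against the historical state that had only 4 leaves
    let old_proof = mmr.open_at(3, 4).unwrap();
    let old_peaks = mmr.peaks_at(4).unwrap();
    assert!(old_peaks.verify(leaf(3), old_proof).is_ok());
}
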
@@ -6,6 +6,8 @@
//! leaves count.
use core::num::NonZeroUsize;

use winter_utils::{Deserializable, Serializable};

// IN-ORDER INDEX
// ================================================================================================

@@ -112,6 +114,21 @@ impl InOrderIndex {
}
}

impl Serializable for InOrderIndex {
fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
target.write_usize(self.idx);
}
}

impl Deserializable for InOrderIndex {
fn read_from<R: winter_utils::ByteReader>(
source: &mut R,
) -> Result<Self, winter_utils::DeserializationError> {
let idx = source.read_usize()?;
Ok(InOrderIndex { idx })
}
}

// CONVERSIONS FROM IN-ORDER INDEX
// ------------------------------------------------------------------------------------------------

@@ -127,6 +144,7 @@ impl From<InOrderIndex> for u64 {
#[cfg(test)]
mod test {
use proptest::prelude::*;
use winter_utils::{Deserializable, Serializable};

use super::InOrderIndex;

@@ -162,4 +180,12 @@ mod test {
assert_eq!(left.sibling(), right);
assert_eq!(left, right.sibling());
}

#[test]
fn test_inorder_index_serialization() {
let index = InOrderIndex::from_leaf_pos(5);
let bytes = index.to_bytes();
let index2 = InOrderIndex::read_from_bytes(&bytes).unwrap();
assert_eq!(index, index2);
}
}

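A note on the numbering behind `from_leaf_pos`, inferred from the tests later in this range (where a leaf position is recovered as `index / 2`): leaves appear to land on the odd in-order positions, so a 0-based leaf position `p` maps to index `2p + 1`. A sketch of that assumed mapping:

// assumed mapping: 0-based leaf position p -> in-order index 2p + 1
fn from_leaf_pos(pos: usize) -> usize {
    2 * pos + 1
}

fn main() {
    assert_eq!(from_leaf_pos(0), 1);
    assert_eq!(from_leaf_pos(5), 11);
    // and the inverse used by the partial MMR tests later in this diff
    assert_eq!(from_leaf_pos(5) / 2, 5);
}
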
@@ -10,8 +10,6 @@ mod proof;
#[cfg(test)]
mod tests;

use super::{Felt, Rpo256, RpoDigest, Word};

// REEXPORTS
// ================================================================================================
pub use delta::MmrDelta;
@@ -22,6 +20,8 @@ pub use partial::PartialMmr;
pub use peaks::MmrPeaks;
pub use proof::MmrProof;

use super::{Felt, Rpo256, RpoDigest, Word};

// UTILITIES
// ===============================================================================================

@@ -42,8 +42,8 @@ const fn leaf_to_corresponding_tree(pos: usize, forest: usize) -> Option<u32> {
// - this means the first tree owns from `0` up to the `2^k_0` first positions, where `k_0`
// is the highest true bit position, the second tree from `2^k_0 + 1` up to `2^k_1` where
// `k_1` is the second highest bit, so on.
// - this means the highest bits work as a category marker, and the position is owned by
// the first tree which doesn't share a high bit with the position
// - this means the highest bits work as a category marker, and the position is owned by the
// first tree which doesn't share a high bit with the position
let before = forest & pos;
let after = forest ^ before;
let tree = after.ilog2();

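A worked run of the bit manipulation in `leaf_to_corresponding_tree`, re-deriving the three lines shown above for forest `0b0110` (a 4-leaf tree followed by a 2-leaf tree):

fn leaf_to_tree(pos: usize, forest: usize) -> Option<u32> {
    if pos >= forest {
        return None;
    }
    let before = forest & pos; // forest bits of the trees that come before the position
    let after = forest ^ before; // forest bits still in play
    Some(after.ilog2()) // the highest remaining bit owns the position
}

fn main() {
    // forest 0b0110: positions 0..4 live in the 4-leaf tree, positions 4..6 in the 2-leaf tree
    assert_eq!(leaf_to_tree(1, 0b0110), Some(2)); // 4-leaf tree (bit 2)
    assert_eq!(leaf_to_tree(5, 0b0110), Some(1)); // 2-leaf tree (bit 1)
    assert_eq!(leaf_to_tree(6, 0b0110), None); // out of bounds
}
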
@@ -1,10 +1,14 @@
use alloc::{
collections::{BTreeMap, BTreeSet},
vec::Vec,
};

use winter_utils::{Deserializable, Serializable};

use super::{MmrDelta, MmrProof, Rpo256, RpoDigest};
use crate::{
merkle::{
use crate::merkle::{
mmr::{leaf_to_corresponding_tree, nodes_in_forest},
InOrderIndex, InnerNodeInfo, MerklePath, MmrError, MmrPeaks,
},
utils::{collections::*, vec},
};

// TYPE ALIASES
@@ -141,7 +145,7 @@ impl PartialMmr {
/// in the underlying MMR.
pub fn open(&self, pos: usize) -> Result<Option<MmrProof>, MmrError> {
let tree_bit =
leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::InvalidPosition(pos))?;
leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::PositionNotFound(pos))?;
let depth = tree_bit as usize;

let mut nodes = Vec::with_capacity(depth);
@@ -182,7 +186,7 @@ impl PartialMmr {
pub fn inner_nodes<'a, I: Iterator<Item = (usize, RpoDigest)> + 'a>(
&'a self,
mut leaves: I,
) -> impl Iterator<Item = InnerNodeInfo> + '_ {
) -> impl Iterator<Item = InnerNodeInfo> + 'a {
let stack = if let Some((pos, leaf)) = leaves.next() {
let idx = InOrderIndex::from_leaf_pos(pos);
vec![(idx, leaf)]
@@ -294,7 +298,7 @@ impl PartialMmr {
// invalid.
let tree = 1 << path.depth();
if tree & self.forest == 0 {
return Err(MmrError::UnknownPeak);
return Err(MmrError::UnknownPeak(path.depth()));
};

if leaf_pos + 1 == self.forest
@@ -315,9 +319,11 @@ impl PartialMmr {

// Compute the root of the authentication path, and check it matches the current version of
// the PartialMmr.
let computed = path.compute_root(path_idx as u64, leaf).map_err(MmrError::MerkleError)?;
let computed = path
.compute_root(path_idx as u64, leaf)
.map_err(MmrError::MerkleRootComputationFailed)?;
if self.peaks[peak_pos] != computed {
return Err(MmrError::InvalidPeak);
return Err(MmrError::PeakPathMismatch);
}

let mut idx = InOrderIndex::from_leaf_pos(leaf_pos);
@@ -352,7 +358,10 @@ impl PartialMmr {
/// inserted into the partial MMR.
pub fn apply(&mut self, delta: MmrDelta) -> Result<Vec<(InOrderIndex, RpoDigest)>, MmrError> {
if delta.forest < self.forest {
return Err(MmrError::InvalidPeaks);
return Err(MmrError::InvalidPeaks(format!(
"forest of mmr delta {} is less than current forest {}",
delta.forest, self.forest
)));
}

let mut inserted_nodes = Vec::new();
@@ -535,7 +544,7 @@ pub struct InnerNodeIterator<'a, I: Iterator<Item = (usize, RpoDigest)>> {
seen_nodes: BTreeSet<InOrderIndex>,
}

impl<'a, I: Iterator<Item = (usize, RpoDigest)>> Iterator for InnerNodeIterator<'a, I> {
impl<I: Iterator<Item = (usize, RpoDigest)>> Iterator for InnerNodeIterator<'_, I> {
type Item = InnerNodeInfo;

fn next(&mut self) -> Option<Self::Item> {
@@ -571,6 +580,28 @@ impl<'a, I: Iterator<Item = (usize, RpoDigest)>> Iterator for InnerNodeIterator<
}
}

impl Serializable for PartialMmr {
fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
self.forest.write_into(target);
self.peaks.write_into(target);
self.nodes.write_into(target);
target.write_bool(self.track_latest);
}
}

impl Deserializable for PartialMmr {
fn read_from<R: winter_utils::ByteReader>(
source: &mut R,
) -> Result<Self, winter_utils::DeserializationError> {
let forest = usize::read_from(source)?;
let peaks = Vec::<RpoDigest>::read_from(source)?;
let nodes = NodeMap::read_from(source)?;
let track_latest = source.read_bool()?;

Ok(Self { forest, peaks, nodes, track_latest })
}
}

// UTILS
// ================================================================================================

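The `track`/`apply` pair above is what lets a light client follow a growing MMR without refetching it. A flow sketch using only the methods visible in this diff (import paths assumed as in the earlier sketch):

use miden_crypto::{hash::rpo::RpoDigest, merkle::{Mmr, PartialMmr}, Felt, ZERO};

fn main() {
    let leaf = |i: u64| RpoDigest::new([Felt::new(i), ZERO, ZERO, ZERO]);

    let mut mmr = Mmr::new();
    (0..8u64).for_each(|i| mmr.add(leaf(i)));

    // the client keeps only the peaks plus one tracked authentication path
    let mut partial: PartialMmr = mmr.peaks().into();
    let proof = mmr.open(3).unwrap();
    partial.track(3, leaf(3), &proof.merkle_path).unwrap();

    // the full MMR grows; the client catches up with a compact delta
    let old_forest = partial.forest();
    (8..12u64).for_each(|i| mmr.add(leaf(i)));
    let delta = mmr.get_delta(old_forest, mmr.forest()).unwrap();
    partial.apply(delta).unwrap();

    // the tracked leaf is still openable against the new state
    assert_eq!(partial.open(3).unwrap().unwrap(), mmr.open(3).unwrap());
}
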
@@ -612,14 +643,15 @@ fn forest_to_rightmost_index(forest: usize) -> InOrderIndex {

#[cfg(test)]
mod tests {
use alloc::{collections::BTreeSet, vec::Vec};

use winter_utils::{Deserializable, Serializable};

use super::{
forest_to_rightmost_index, forest_to_root_index, InOrderIndex, MmrPeaks, PartialMmr,
RpoDigest,
};
use crate::{
merkle::{int_to_node, MerkleStore, Mmr, NodeIndex},
utils::collections::*,
};
use crate::merkle::{int_to_node, MerkleStore, Mmr, NodeIndex};

const LEAVES: [RpoDigest; 7] = [
int_to_node(0),
@@ -689,18 +721,18 @@ mod tests {
// build an MMR with 10 nodes (2 peaks) and a partial MMR based on it
let mut mmr = Mmr::default();
(0..10).for_each(|i| mmr.add(int_to_node(i)));
let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();
let mut partial_mmr: PartialMmr = mmr.peaks().into();

// add authentication path for position 1 and 8
{
let node = mmr.get(1).unwrap();
let proof = mmr.open(1, mmr.forest()).unwrap();
let proof = mmr.open(1).unwrap();
partial_mmr.track(1, node, &proof.merkle_path).unwrap();
}

{
let node = mmr.get(8).unwrap();
let proof = mmr.open(8, mmr.forest()).unwrap();
let proof = mmr.open(8).unwrap();
partial_mmr.track(8, node, &proof.merkle_path).unwrap();
}

@@ -713,7 +745,7 @@ mod tests {
validate_apply_delta(&mmr, &mut partial_mmr);
{
let node = mmr.get(12).unwrap();
let proof = mmr.open(12, mmr.forest()).unwrap();
let proof = mmr.open(12).unwrap();
partial_mmr.track(12, node, &proof.merkle_path).unwrap();
assert!(partial_mmr.track_latest);
}
@@ -738,7 +770,7 @@ mod tests {
let nodes_delta = partial.apply(delta).unwrap();

// new peaks were computed correctly
assert_eq!(mmr.peaks(mmr.forest()).unwrap(), partial.peaks());
assert_eq!(mmr.peaks(), partial.peaks());

let mut expected_nodes = nodes_before;
for (key, value) in nodes_delta {
@@ -754,7 +786,7 @@ mod tests {
let index_value: u64 = index.into();
let pos = index_value / 2;
let proof1 = partial.open(pos as usize).unwrap().unwrap();
let proof2 = mmr.open(pos as usize, mmr.forest()).unwrap();
let proof2 = mmr.open(pos as usize).unwrap();
assert_eq!(proof1, proof2);
}
}
@@ -763,16 +795,16 @@ mod tests {
fn test_partial_mmr_inner_nodes_iterator() {
// build the MMR
let mmr: Mmr = LEAVES.into();
let first_peak = mmr.peaks(mmr.forest).unwrap().peaks()[0];
let first_peak = mmr.peaks().peaks()[0];

// -- test single tree ----------------------------

// get path and node for position 1
let node1 = mmr.get(1).unwrap();
let proof1 = mmr.open(1, mmr.forest()).unwrap();
let proof1 = mmr.open(1).unwrap();

// create partial MMR and add authentication path to node at position 1
let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();
let mut partial_mmr: PartialMmr = mmr.peaks().into();
partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();

// empty iterator should have no nodes
@@ -790,13 +822,13 @@ mod tests {
// -- test no duplicates --------------------------

// build the partial MMR
let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();
let mut partial_mmr: PartialMmr = mmr.peaks().into();

let node0 = mmr.get(0).unwrap();
let proof0 = mmr.open(0, mmr.forest()).unwrap();
let proof0 = mmr.open(0).unwrap();

let node2 = mmr.get(2).unwrap();
let proof2 = mmr.open(2, mmr.forest()).unwrap();
let proof2 = mmr.open(2).unwrap();

partial_mmr.track(0, node0, &proof0.merkle_path).unwrap();
partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
@@ -827,10 +859,10 @@ mod tests {
// -- test multiple trees -------------------------

// build the partial MMR
let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();
let mut partial_mmr: PartialMmr = mmr.peaks().into();

let node5 = mmr.get(5).unwrap();
let proof5 = mmr.open(5, mmr.forest()).unwrap();
let proof5 = mmr.open(5).unwrap();

partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
partial_mmr.track(5, node5, &proof5.merkle_path).unwrap();
@@ -842,7 +874,7 @@ mod tests {
let index1 = NodeIndex::new(2, 1).unwrap();
let index5 = NodeIndex::new(1, 1).unwrap();

let second_peak = mmr.peaks(mmr.forest).unwrap().peaks()[1];
let second_peak = mmr.peaks().peaks()[1];

let path1 = store.get_path(first_peak, index1).unwrap().path;
let path5 = store.get_path(second_peak, index5).unwrap().path;
@@ -861,8 +893,7 @@ mod tests {
mmr.add(el);
partial_mmr.add(el, false);

let mmr_peaks = mmr.peaks(mmr.forest()).unwrap();
assert_eq!(mmr_peaks, partial_mmr.peaks());
assert_eq!(mmr.peaks(), partial_mmr.peaks());
assert_eq!(mmr.forest(), partial_mmr.forest());
}
}
@@ -878,12 +909,11 @@ mod tests {
mmr.add(el);
partial_mmr.add(el, true);

let mmr_peaks = mmr.peaks(mmr.forest()).unwrap();
assert_eq!(mmr_peaks, partial_mmr.peaks());
assert_eq!(mmr.peaks(), partial_mmr.peaks());
assert_eq!(mmr.forest(), partial_mmr.forest());

for pos in 0..i {
let mmr_proof = mmr.open(pos as usize, mmr.forest()).unwrap();
let mmr_proof = mmr.open(pos as usize).unwrap();
let partialmmr_proof = partial_mmr.open(pos as usize).unwrap().unwrap();
assert_eq!(mmr_proof, partialmmr_proof);
}
@@ -895,8 +925,8 @@ mod tests {
let mut mmr = Mmr::from((0..7).map(int_to_node));

// derive a partial Mmr from it which tracks authentication path to leaf 5
let mut partial_mmr = PartialMmr::from_peaks(mmr.peaks(mmr.forest()).unwrap());
let path_to_5 = mmr.open(5, mmr.forest()).unwrap().merkle_path;
let mut partial_mmr = PartialMmr::from_peaks(mmr.peaks());
let path_to_5 = mmr.open(5).unwrap().merkle_path;
let leaf_at_5 = mmr.get(5).unwrap();
partial_mmr.track(5, leaf_at_5, &path_to_5).unwrap();

@@ -906,6 +936,17 @@ mod tests {
partial_mmr.add(leaf_at_7, false);

// the openings should be the same
assert_eq!(mmr.open(5, mmr.forest()).unwrap(), partial_mmr.open(5).unwrap().unwrap());
assert_eq!(mmr.open(5).unwrap(), partial_mmr.open(5).unwrap().unwrap());
}

#[test]
fn test_partial_mmr_serialization() {
let mmr = Mmr::from((0..7).map(int_to_node));
let partial_mmr = PartialMmr::from_peaks(mmr.peaks());

let bytes = partial_mmr.to_bytes();
let decoded = PartialMmr::read_from_bytes(&bytes).unwrap();

assert_eq!(partial_mmr, decoded);
}
}

@@ -1,5 +1,6 @@
use alloc::vec::Vec;

use super::{super::ZERO, Felt, MmrError, MmrProof, Rpo256, RpoDigest, Word};
use crate::utils::collections::*;

// MMR PEAKS
// ================================================================================================
@@ -18,12 +19,12 @@ pub struct MmrPeaks {
///
/// Examples:
///
/// - With 5 leaves, the binary `0b101`. The number of set bits is equal the number
/// of peaks, in this case there are 2 peaks. The 0-indexed least-significant position of
/// the bit determines the number of elements of a tree, so the rightmost tree has `2**0`
/// elements and the left most has `2**2`.
/// - With 12 leaves, the binary is `0b1100`, this case also has 2 peaks, the
/// leftmost tree has `2**3=8` elements, and the right most has `2**2=4` elements.
/// - With 5 leaves, the binary `0b101`. The number of set bits is equal the number of peaks,
/// in this case there are 2 peaks. The 0-indexed least-significant position of the bit
/// determines the number of elements of a tree, so the rightmost tree has `2**0` elements
/// and the left most has `2**2`.
/// - With 12 leaves, the binary is `0b1100`, this case also has 2 peaks, the leftmost tree has
/// `2**3=8` elements, and the right most has `2**2=4` elements.
num_leaves: usize,

/// All the peaks of every tree in the MMR forest. The peaks are always ordered by number of
@@ -44,7 +45,11 @@ impl MmrPeaks {
/// Returns an error if the number of leaves and the number of peaks are inconsistent.
pub fn new(num_leaves: usize, peaks: Vec<RpoDigest>) -> Result<Self, MmrError> {
if num_leaves.count_ones() as usize != peaks.len() {
return Err(MmrError::InvalidPeaks);
return Err(MmrError::InvalidPeaks(format!(
"number of one bits in leaves is {} which does not equal peak length {}",
num_leaves.count_ones(),
peaks.len()
)));
}

Ok(Self { num_leaves, peaks })
@@ -68,6 +73,17 @@ impl MmrPeaks {
&self.peaks
}

/// Returns the peak by the provided index.
///
/// # Errors
/// Returns an error if the provided peak index is greater or equal to the current number of
/// peaks in the Mmr.
pub fn get_peak(&self, peak_idx: usize) -> Result<&RpoDigest, MmrError> {
self.peaks
.get(peak_idx)
.ok_or(MmrError::PeakOutOfBounds { peak_idx, peaks_len: self.peaks.len() })
}

/// Converts this [MmrPeaks] into its components: number of leaves and a vector of peaks of
/// the underlying MMR.
pub fn into_parts(self) -> (usize, Vec<RpoDigest>) {
@@ -83,9 +99,18 @@ impl MmrPeaks {
Rpo256::hash_elements(&self.flatten_and_pad_peaks())
}

pub fn verify(&self, value: RpoDigest, opening: MmrProof) -> bool {
let root = &self.peaks[opening.peak_index()];
opening.merkle_path.verify(opening.relative_pos() as u64, value, root)
/// Verifies the Merkle opening proof.
///
/// # Errors
/// Returns an error if:
/// - provided opening proof is invalid.
/// - Mmr root value computed using the provided leaf value differs from the actual one.
pub fn verify(&self, value: RpoDigest, opening: MmrProof) -> Result<(), MmrError> {
let root = self.get_peak(opening.peak_index())?;
opening
.merkle_path
.verify(opening.relative_pos() as u64, value, root)
.map_err(MmrError::InvalidMerklePath)
}

/// Flattens and pads the peaks to make hashing inside of the Miden VM easier.
@@ -94,16 +119,15 @@ impl MmrPeaks {
/// - Flatten the vector of Words into a vector of Felts.
/// - Pad the peaks with ZERO to an even number of words, this removes the need to handle RPO
/// padding.
/// - Pad the peaks to a minimum length of 16 words, which reduces the constant cost of
/// hashing.
/// - Pad the peaks to a minimum length of 16 words, which reduces the constant cost of hashing.
pub fn flatten_and_pad_peaks(&self) -> Vec<Felt> {
let num_peaks = self.peaks.len();

// To achieve the padding rules above we calculate the length of the final vector.
// This is calculated as the number of field elements. Each peak is 4 field elements.
// The length is calculated as follows:
// - If there are less than 16 peaks, the data is padded to 16 peaks and as such requires
// 64 field elements.
// - If there are less than 16 peaks, the data is padded to 16 peaks and as such requires 64
// field elements.
// - If there are more than 16 peaks and the number of peaks is odd, the data is padded to
// an even number of peaks and as such requires `(num_peaks + 1) * 4` field elements.
// - If there are more than 16 peaks and the number of peaks is even, the data is not padded

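With `verify` now returning `Result`, callers can propagate the reason a proof failed instead of branching on a bool. A sketch of the new call shape (the error variants follow the mapping shown above; import paths assumed as before):

use miden_crypto::{hash::rpo::RpoDigest, merkle::{Mmr, MmrError}, Felt, ZERO};

fn check(mmr: &Mmr, pos: usize, value: RpoDigest) -> Result<(), MmrError> {
    let proof = mmr.open(pos)?;
    mmr.peaks().verify(value, proof)
}

fn main() {
    let leaf = |i: u64| RpoDigest::new([Felt::new(i), ZERO, ZERO, ZERO]);
    let mut mmr = Mmr::new();
    (0..4u64).for_each(|i| mmr.add(leaf(i)));

    assert!(check(&mmr, 2, leaf(2)).is_ok());
    // a wrong leaf value now yields a structured error rather than `false`
    assert!(matches!(check(&mmr, 2, leaf(3)), Err(MmrError::InvalidMerklePath(_))));
}
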
@@ -1,3 +1,5 @@
use alloc::vec::Vec;

use super::{
super::{InnerNodeInfo, Rpo256, RpoDigest},
bit::TrueBitPositionIterator,
@@ -6,7 +8,6 @@ use super::{
};
use crate::{
merkle::{int_to_node, InOrderIndex, MerklePath, MerkleTree, MmrProof, NodeIndex},
utils::collections::*,
Felt, Word,
};

@@ -138,7 +139,7 @@ fn test_mmr_simple() {
assert_eq!(mmr.nodes.len(), 1);
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);

let acc = mmr.peaks(mmr.forest()).unwrap();
let acc = mmr.peaks();
assert_eq!(acc.num_leaves(), 1);
assert_eq!(acc.peaks(), &[postorder[0]]);

@@ -147,7 +148,7 @@ fn test_mmr_simple() {
assert_eq!(mmr.nodes.len(), 3);
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);

let acc = mmr.peaks(mmr.forest()).unwrap();
let acc = mmr.peaks();
assert_eq!(acc.num_leaves(), 2);
assert_eq!(acc.peaks(), &[postorder[2]]);

@@ -156,7 +157,7 @@ fn test_mmr_simple() {
assert_eq!(mmr.nodes.len(), 4);
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);

let acc = mmr.peaks(mmr.forest()).unwrap();
let acc = mmr.peaks();
assert_eq!(acc.num_leaves(), 3);
assert_eq!(acc.peaks(), &[postorder[2], postorder[3]]);

@@ -165,7 +166,7 @@ fn test_mmr_simple() {
assert_eq!(mmr.nodes.len(), 7);
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);

let acc = mmr.peaks(mmr.forest()).unwrap();
let acc = mmr.peaks();
assert_eq!(acc.num_leaves(), 4);
assert_eq!(acc.peaks(), &[postorder[6]]);

@@ -174,7 +175,7 @@ fn test_mmr_simple() {
assert_eq!(mmr.nodes.len(), 8);
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);

let acc = mmr.peaks(mmr.forest()).unwrap();
let acc = mmr.peaks();
assert_eq!(acc.num_leaves(), 5);
assert_eq!(acc.peaks(), &[postorder[6], postorder[7]]);

@@ -183,7 +184,7 @@ fn test_mmr_simple() {
assert_eq!(mmr.nodes.len(), 10);
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);

let acc = mmr.peaks(mmr.forest()).unwrap();
let acc = mmr.peaks();
assert_eq!(acc.num_leaves(), 6);
assert_eq!(acc.peaks(), &[postorder[6], postorder[9]]);

@@ -192,7 +193,7 @@ fn test_mmr_simple() {
assert_eq!(mmr.nodes.len(), 11);
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);

let acc = mmr.peaks(mmr.forest()).unwrap();
let acc = mmr.peaks();
assert_eq!(acc.num_leaves(), 7);
assert_eq!(acc.peaks(), &[postorder[6], postorder[9], postorder[10]]);
}
@@ -204,97 +205,73 @@ fn test_mmr_open() {
let h23 = merge(LEAVES[2], LEAVES[3]);

// node at pos 7 is the root
assert!(
mmr.open(7, mmr.forest()).is_err(),
"Element 7 is not in the tree, result should be None"
);
assert!(mmr.open(7).is_err(), "Element 7 is not in the tree, result should be None");

// node at pos 6 is the root
let empty: MerklePath = MerklePath::new(vec![]);
let opening = mmr
.open(6, mmr.forest())
.open(6)
.expect("Element 6 is contained in the tree, expected an opening result.");
assert_eq!(opening.merkle_path, empty);
assert_eq!(opening.forest, mmr.forest);
assert_eq!(opening.position, 6);
assert!(
mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[6], opening),
"MmrProof should be valid for the current accumulator."
);
mmr.peaks().verify(LEAVES[6], opening).unwrap();

// nodes 4,5 are depth 1
let root_to_path = MerklePath::new(vec![LEAVES[4]]);
let opening = mmr
.open(5, mmr.forest())
.open(5)
.expect("Element 5 is contained in the tree, expected an opening result.");
assert_eq!(opening.merkle_path, root_to_path);
assert_eq!(opening.forest, mmr.forest);
assert_eq!(opening.position, 5);
assert!(
mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[5], opening),
"MmrProof should be valid for the current accumulator."
);
mmr.peaks().verify(LEAVES[5], opening).unwrap();

let root_to_path = MerklePath::new(vec![LEAVES[5]]);
let opening = mmr
.open(4, mmr.forest())
.open(4)
.expect("Element 4 is contained in the tree, expected an opening result.");
assert_eq!(opening.merkle_path, root_to_path);
assert_eq!(opening.forest, mmr.forest);
assert_eq!(opening.position, 4);
assert!(
mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[4], opening),
"MmrProof should be valid for the current accumulator."
);
mmr.peaks().verify(LEAVES[4], opening).unwrap();

// nodes 0,1,2,3 are detph 2
let root_to_path = MerklePath::new(vec![LEAVES[2], h01]);
let opening = mmr
.open(3, mmr.forest())
.open(3)
.expect("Element 3 is contained in the tree, expected an opening result.");
assert_eq!(opening.merkle_path, root_to_path);
assert_eq!(opening.forest, mmr.forest);
assert_eq!(opening.position, 3);
assert!(
mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[3], opening),
"MmrProof should be valid for the current accumulator."
);
mmr.peaks().verify(LEAVES[3], opening).unwrap();

let root_to_path = MerklePath::new(vec![LEAVES[3], h01]);
let opening = mmr
.open(2, mmr.forest())
.open(2)
.expect("Element 2 is contained in the tree, expected an opening result.");
assert_eq!(opening.merkle_path, root_to_path);
assert_eq!(opening.forest, mmr.forest);
assert_eq!(opening.position, 2);
assert!(
mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[2], opening),
"MmrProof should be valid for the current accumulator."
);
mmr.peaks().verify(LEAVES[2], opening).unwrap();

let root_to_path = MerklePath::new(vec![LEAVES[0], h23]);
let opening = mmr
.open(1, mmr.forest())
.open(1)
.expect("Element 1 is contained in the tree, expected an opening result.");
assert_eq!(opening.merkle_path, root_to_path);
assert_eq!(opening.forest, mmr.forest);
assert_eq!(opening.position, 1);
assert!(
mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[1], opening),
"MmrProof should be valid for the current accumulator."
);
mmr.peaks().verify(LEAVES[1], opening).unwrap();

let root_to_path = MerklePath::new(vec![LEAVES[1], h23]);
let opening = mmr
.open(0, mmr.forest())
.open(0)
.expect("Element 0 is contained in the tree, expected an opening result.");
assert_eq!(opening.merkle_path, root_to_path);
assert_eq!(opening.forest, mmr.forest);
assert_eq!(opening.position, 0);
assert!(
mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[0], opening),
"MmrProof should be valid for the current accumulator."
);
mmr.peaks().verify(LEAVES[0], opening).unwrap();
}

#[test]
@@ -308,7 +285,7 @@ fn test_mmr_open_older_version() {
// merkle path of a node is empty if there are no elements to pair with it
for pos in (0..mmr.forest()).filter(is_even) {
let forest = pos + 1;
let proof = mmr.open(pos, forest).unwrap();
let proof = mmr.open_at(pos, forest).unwrap();
assert_eq!(proof.forest, forest);
assert_eq!(proof.merkle_path.nodes(), []);
assert_eq!(proof.position, pos);
@@ -320,7 +297,7 @@ fn test_mmr_open_older_version() {
for pos in 0..4 {
let idx = NodeIndex::new(2, pos).unwrap();
let path = mtree.get_path(idx).unwrap();
let proof = mmr.open(pos as usize, forest).unwrap();
let proof = mmr.open_at(pos as usize, forest).unwrap();
assert_eq!(path, proof.merkle_path);
}
}
@@ -331,7 +308,7 @@ fn test_mmr_open_older_version() {
let path = mtree.get_path(idx).unwrap();
// account for the bigger tree with 4 elements
let mmr_pos = (pos + 4) as usize;
let proof = mmr.open(mmr_pos, forest).unwrap();
let proof = mmr.open_at(mmr_pos, forest).unwrap();
assert_eq!(path, proof.merkle_path);
}
}
@@ -357,49 +334,49 @@ fn test_mmr_open_eight() {
let root = mtree.root();

let position = 0;
let proof = mmr.open(position, mmr.forest()).unwrap();
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

let position = 1;
let proof = mmr.open(position, mmr.forest()).unwrap();
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

let position = 2;
let proof = mmr.open(position, mmr.forest()).unwrap();
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

let position = 3;
let proof = mmr.open(position, mmr.forest()).unwrap();
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

let position = 4;
let proof = mmr.open(position, mmr.forest()).unwrap();
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

let position = 5;
let proof = mmr.open(position, mmr.forest()).unwrap();
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

let position = 6;
let proof = mmr.open(position, mmr.forest()).unwrap();
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

let position = 7;
let proof = mmr.open(position, mmr.forest()).unwrap();
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
@@ -415,47 +392,47 @@ fn test_mmr_open_seven() {
let mmr: Mmr = LEAVES.into();

let position = 0;
let proof = mmr.open(position, mmr.forest()).unwrap();
let proof = mmr.open(position).unwrap();
let merkle_path: MerklePath =
mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(0, LEAVES[0]).unwrap(), mtree1.root());

let position = 1;
let proof = mmr.open(position, mmr.forest()).unwrap();
let proof = mmr.open(position).unwrap();
let merkle_path: MerklePath =
mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(1, LEAVES[1]).unwrap(), mtree1.root());

let position = 2;
let proof = mmr.open(position, mmr.forest()).unwrap();
let proof = mmr.open(position).unwrap();
let merkle_path: MerklePath =
mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(2, LEAVES[2]).unwrap(), mtree1.root());

let position = 3;
let proof = mmr.open(position, mmr.forest()).unwrap();
let proof = mmr.open(position).unwrap();
let merkle_path: MerklePath =
mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(3, LEAVES[3]).unwrap(), mtree1.root());

let position = 4;
let proof = mmr.open(position, mmr.forest()).unwrap();
let proof = mmr.open(position).unwrap();
let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 0u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(0, LEAVES[4]).unwrap(), mtree2.root());

let position = 5;
let proof = mmr.open(position, mmr.forest()).unwrap();
let proof = mmr.open(position).unwrap();
let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 1u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(1, LEAVES[5]).unwrap(), mtree2.root());

let position = 6;
let proof = mmr.open(position, mmr.forest()).unwrap();
let proof = mmr.open(position).unwrap();
let merkle_path: MerklePath = [].as_ref().into();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(0, LEAVES[6]).unwrap(), LEAVES[6]);
@@ -479,7 +456,7 @@ fn test_mmr_invariants() {
let mut mmr = Mmr::new();
for v in 1..=1028 {
mmr.add(int_to_node(v));
let accumulator = mmr.peaks(mmr.forest()).unwrap();
let accumulator = mmr.peaks();
assert_eq!(v as usize, mmr.forest(), "MMR leaf count must increase by one on every add");
assert_eq!(
v as usize,
@@ -565,37 +542,37 @@ fn test_mmr_peaks() {
let mmr: Mmr = LEAVES.into();

let forest = 0b0001;
let acc = mmr.peaks(forest).unwrap();
let acc = mmr.peaks_at(forest).unwrap();
assert_eq!(acc.num_leaves(), forest);
assert_eq!(acc.peaks(), &[mmr.nodes[0]]);

let forest = 0b0010;
let acc = mmr.peaks(forest).unwrap();
let acc = mmr.peaks_at(forest).unwrap();
assert_eq!(acc.num_leaves(), forest);
assert_eq!(acc.peaks(), &[mmr.nodes[2]]);

let forest = 0b0011;
let acc = mmr.peaks(forest).unwrap();
let acc = mmr.peaks_at(forest).unwrap();
assert_eq!(acc.num_leaves(), forest);
assert_eq!(acc.peaks(), &[mmr.nodes[2], mmr.nodes[3]]);

let forest = 0b0100;
let acc = mmr.peaks(forest).unwrap();
let acc = mmr.peaks_at(forest).unwrap();
assert_eq!(acc.num_leaves(), forest);
assert_eq!(acc.peaks(), &[mmr.nodes[6]]);

let forest = 0b0101;
let acc = mmr.peaks(forest).unwrap();
let acc = mmr.peaks_at(forest).unwrap();
assert_eq!(acc.num_leaves(), forest);
assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[7]]);

let forest = 0b0110;
let acc = mmr.peaks(forest).unwrap();
let acc = mmr.peaks_at(forest).unwrap();
assert_eq!(acc.num_leaves(), forest);
assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9]]);

let forest = 0b0111;
let acc = mmr.peaks(forest).unwrap();
let acc = mmr.peaks_at(forest).unwrap();
assert_eq!(acc.num_leaves(), forest);
assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9], mmr.nodes[10]]);
}
@@ -603,7 +580,7 @@ fn test_mmr_peaks() {
#[test]
fn test_mmr_hash_peaks() {
let mmr: Mmr = LEAVES.into();
let peaks = mmr.peaks(mmr.forest()).unwrap();
let peaks = mmr.peaks();

let first_peak = Rpo256::merge(&[
Rpo256::merge(&[LEAVES[0], LEAVES[1]]),
@@ -657,7 +634,7 @@ fn test_mmr_peaks_hash_odd() {
#[test]
fn test_mmr_delta() {
let mmr: Mmr = LEAVES.into();
let acc = mmr.peaks(mmr.forest()).unwrap();
let acc = mmr.peaks();

// original_forest can't have more elements
assert!(
@@ -757,7 +734,7 @@ fn test_mmr_delta_old_forest() {
#[test]
fn test_partial_mmr_simple() {
let mmr: Mmr = LEAVES.into();
let peaks = mmr.peaks(mmr.forest()).unwrap();
let peaks = mmr.peaks();
let mut partial: PartialMmr = peaks.clone().into();

// check initial state of the partial mmr
@@ -768,7 +745,7 @@ fn test_partial_mmr_simple() {
assert_eq!(partial.nodes.len(), 0);

// check state after adding tracking one element
let proof1 = mmr.open(0, mmr.forest()).unwrap();
let proof1 = mmr.open(0).unwrap();
let el1 = mmr.get(proof1.position).unwrap();
partial.track(proof1.position, el1, &proof1.merkle_path).unwrap();

@@ -780,7 +757,7 @@ fn test_partial_mmr_simple() {
let idx = idx.parent();
assert_eq!(partial.nodes[&idx.sibling()], proof1.merkle_path[1]);

let proof2 = mmr.open(1, mmr.forest()).unwrap();
let proof2 = mmr.open(1).unwrap();
let el2 = mmr.get(proof2.position).unwrap();
partial.track(proof2.position, el2, &proof2.merkle_path).unwrap();

@@ -798,9 +775,9 @@ fn test_partial_mmr_update_single() {
let mut full = Mmr::new();
let zero = int_to_node(0);
full.add(zero);
let mut partial: PartialMmr = full.peaks(full.forest()).unwrap().into();
let mut partial: PartialMmr = full.peaks().into();

let proof = full.open(0, full.forest()).unwrap();
let proof = full.open(0).unwrap();
partial.track(proof.position, zero, &proof.merkle_path).unwrap();

for i in 1..100 {
@@ -810,9 +787,9 @@ fn test_partial_mmr_update_single() {
partial.apply(delta).unwrap();

assert_eq!(partial.forest(), full.forest());
assert_eq!(partial.peaks(), full.peaks(full.forest()).unwrap());
assert_eq!(partial.peaks(), full.peaks());

let proof1 = full.open(i as usize, full.forest()).unwrap();
let proof1 = full.open(i as usize).unwrap();
partial.track(proof1.position, node, &proof1.merkle_path).unwrap();
let proof2 = partial.open(proof1.position).unwrap().unwrap();
assert_eq!(proof1.merkle_path, proof2.merkle_path);
@@ -822,7 +799,7 @@ fn test_partial_mmr_update_single() {
#[test]
fn test_mmr_add_invalid_odd_leaf() {
let mmr: Mmr = LEAVES.into();
let acc = mmr.peaks(mmr.forest()).unwrap();
let acc = mmr.peaks();
let mut partial: PartialMmr = acc.clone().into();

let empty = MerklePath::new(Vec::new());
@@ -837,6 +814,39 @@ fn test_mmr_add_invalid_odd_leaf() {
assert!(result.is_ok());
}

/// Tests that a proof whose peak count exceeds the peak count of the MMR returns an error.
///
/// Here we manipulate the proof to return a peak index of 1 while the MMR only has 1 peak (with
/// index 0).
#[test]
#[should_panic]
fn test_mmr_proof_num_peaks_exceeds_current_num_peaks() {
let mmr: Mmr = LEAVES[0..4].iter().cloned().into();
let mut proof = mmr.open(3).unwrap();
proof.forest = 5;
proof.position = 4;
mmr.peaks().verify(LEAVES[3], proof).unwrap();
}

/// Tests that a proof whose peak count exceeds the peak count of the MMR returns an error.
///
/// We create an MmrProof for a leaf whose peak index to verify against is 1.
/// Then we add another leaf which results in an Mmr with just one peak due to trees
/// being merged. If we try to use the old proof against the new Mmr, we should get an error.
#[test]
#[should_panic]
fn test_mmr_old_proof_num_peaks_exceeds_current_num_peaks() {
let leaves_len = 3;
let mut mmr = Mmr::from(LEAVES[0..leaves_len].iter().cloned());

let leaf_idx = leaves_len - 1;
let proof = mmr.open(leaf_idx).unwrap();
assert!(mmr.peaks().verify(LEAVES[leaf_idx], proof.clone()).is_ok());

mmr.add(LEAVES[leaves_len]);
mmr.peaks().verify(LEAVES[leaf_idx], proof).unwrap();
}

mod property_tests {
use proptest::prelude::*;

@@ -21,9 +21,11 @@ mod path;
pub use path::{MerklePath, RootPath, ValuePath};

mod smt;
#[cfg(feature = "internal")]
pub use smt::build_subtree_for_bench;
pub use smt::{
LeafIndex, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH,
SMT_MAX_DEPTH, SMT_MIN_DEPTH,
LeafIndex, MutationSet, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError,
SubtreeLeaf, SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH,
};

mod mmr;
@@ -44,9 +46,6 @@ pub use error::MerkleError;
// HELPER FUNCTIONS
// ================================================================================================

#[cfg(test)]
use crate::utils::collections::*;

#[cfg(test)]
const fn int_to_node(value: u64) -> RpoDigest {
RpoDigest::new([Felt::new(value), ZERO, ZERO, ZERO])
@@ -58,6 +57,6 @@ const fn int_to_leaf(value: u64) -> Word {
}

#[cfg(test)]
fn digests_to_words(digests: &[RpoDigest]) -> Vec<Word> {
fn digests_to_words(digests: &[RpoDigest]) -> alloc::vec::Vec<Word> {
digests.iter().map(|d| d.into()).collect()
}

@@ -1,3 +1,8 @@
|
||||
use alloc::{
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
string::String,
|
||||
vec::Vec,
|
||||
};
|
||||
use core::fmt;
|
||||
|
||||
use super::{
|
||||
@@ -5,8 +10,7 @@ use super::{
|
||||
EMPTY_WORD,
|
||||
};
|
||||
use crate::utils::{
|
||||
collections::*, format, string::*, vec, word_to_hex, ByteReader, ByteWriter, Deserializable,
|
||||
DeserializationError, Serializable,
|
||||
word_to_hex, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -112,7 +116,7 @@ impl PartialMerkleTree {
|
||||
// depth of 63 because we consider passing in a vector of size 2^64 infeasible.
|
||||
let max = 2usize.pow(63);
|
||||
if layers.len() > max {
|
||||
return Err(MerkleError::InvalidNumEntries(max));
|
||||
return Err(MerkleError::TooManyEntries(max));
|
||||
}
|
||||
|
||||
// Get maximum depth
|
||||
@@ -143,11 +147,12 @@ impl PartialMerkleTree {
|
||||
let index = NodeIndex::new(depth, index_value)?;
|
||||
|
||||
// get hash of the current node
|
||||
let node = nodes.get(&index).ok_or(MerkleError::NodeNotInSet(index))?;
|
||||
let node =
|
||||
nodes.get(&index).ok_or(MerkleError::NodeIndexNotFoundInTree(index))?;
|
||||
// get hash of the sibling node
|
||||
let sibling = nodes
|
||||
.get(&index.sibling())
|
||||
.ok_or(MerkleError::NodeNotInSet(index.sibling()))?;
|
||||
.ok_or(MerkleError::NodeIndexNotFoundInTree(index.sibling()))?;
|
||||
// get parent hash
|
||||
let parent = Rpo256::merge(&index.build_node(*node, *sibling));
|
||||
|
||||
@@ -180,7 +185,10 @@ impl PartialMerkleTree {
|
||||
/// # Errors
|
||||
/// Returns an error if the specified NodeIndex is not contained in the nodes map.
|
||||
pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
|
||||
self.nodes.get(&index).ok_or(MerkleError::NodeNotInSet(index)).copied()
|
||||
self.nodes
|
||||
.get(&index)
|
||||
.ok_or(MerkleError::NodeIndexNotFoundInTree(index))
|
||||
.copied()
|
||||
}
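
A hedged sketch (not part of the diff) of how a caller might handle the renamed `NodeIndexNotFoundInTree` variant; `pmt` is a hypothetical `PartialMerkleTree`.

fn sketch_missing_node(pmt: &PartialMerkleTree, index: NodeIndex) {
    // a lookup at an untracked index now surfaces the renamed error variant
    if let Err(MerkleError::NodeIndexNotFoundInTree(missing)) = pmt.get_node(index) {
        assert_eq!(missing, index);
    }
}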

/// Returns true if the provided index is contained in the leaves set, false otherwise.
@@ -220,7 +228,7 @@ impl PartialMerkleTree {
}

if !self.nodes.contains_key(&index) {
return Err(MerkleError::NodeNotInSet(index));
return Err(MerkleError::NodeIndexNotFoundInTree(index));
}

let mut path = Vec::new();
@@ -331,15 +339,16 @@ impl PartialMerkleTree {
if self.root() == EMPTY_DIGEST {
self.nodes.insert(ROOT_INDEX, root);
} else if self.root() != root {
return Err(MerkleError::ConflictingRoots([self.root(), root].to_vec()));
return Err(MerkleError::ConflictingRoots {
expected_root: self.root(),
actual_root: root,
});
}

Ok(())
}

/// Updates value of the leaf at the specified index returning the old leaf value.
/// By default the specified index is assumed to belong to the deepest layer. If the considered
/// node does not belong to the tree, the first node on the way to the root will be changed.
///
/// By default the specified index is assumed to belong to the deepest layer. If the considered
/// node does not belong to the tree, the first node on the way to the root will be changed.
@@ -348,6 +357,7 @@ impl PartialMerkleTree {
///
/// # Errors
/// Returns an error if:
/// - No entry exists at the specified index.
/// - The specified index is greater than the maximum number of nodes on the deepest layer.
pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<RpoDigest, MerkleError> {
let mut node_index = NodeIndex::new(self.max_depth(), index)?;
@@ -363,7 +373,7 @@ impl PartialMerkleTree {
let old_value = self
.nodes
.insert(node_index, value.into())
.ok_or(MerkleError::NodeNotInSet(node_index))?;
.ok_or(MerkleError::NodeIndexNotFoundInTree(node_index))?;

// if the old value and new value are the same, there is nothing to update
if value == *old_value {

@@ -1,3 +1,5 @@
use alloc::{collections::BTreeMap, vec::Vec};

use super::{
super::{
digests_to_words, int_to_node, DefaultMerkleStore as MerkleStore, MerkleTree, NodeIndex,
@@ -5,7 +7,6 @@ use super::{
},
Deserializable, InnerNodeInfo, RpoDigest, Serializable, ValuePath,
};
use crate::utils::collections::*;

// TEST DATA
// ================================================================================================
@@ -294,7 +295,8 @@ fn leaves() {
assert!(expected_leaves.eq(pmt.leaves()));
}

/// Checks that nodes of the PMT returned by `inner_nodes()` function are equal to the expected ones.
/// Checks that nodes of the PMT returned by `inner_nodes()` function are equal to the expected
/// ones.
#[test]
fn test_inner_node_iterator() {
let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();

@@ -1,8 +1,9 @@
use alloc::vec::Vec;
use core::ops::{Deref, DerefMut};

use super::{InnerNodeInfo, MerkleError, NodeIndex, Rpo256, RpoDigest};
use crate::{
utils::{collections::*, ByteReader, Deserializable, DeserializationError, Serializable},
utils::{ByteReader, Deserializable, DeserializationError, Serializable},
Word,
};

@@ -53,12 +54,20 @@ impl MerklePath {

/// Verifies the Merkle opening proof towards the provided root.
///
/// Returns `true` if `node` exists at `index` in a Merkle tree with `root`.
pub fn verify(&self, index: u64, node: RpoDigest, root: &RpoDigest) -> bool {
match self.compute_root(index, node) {
Ok(computed_root) => root == &computed_root,
Err(_) => false,
/// # Errors
/// Returns an error if:
/// - provided node index is invalid.
/// - root calculated during the verification differs from the provided one.
pub fn verify(&self, index: u64, node: RpoDigest, root: &RpoDigest) -> Result<(), MerkleError> {
let computed_root = self.compute_root(index, node)?;
if &computed_root != root {
return Err(MerkleError::ConflictingRoots {
expected_root: *root,
actual_root: computed_root,
});
}

Ok(())
}
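
A sketch (not part of the diff) of calling the new `Result`-based `verify`: a path that resolves to a different root yields `ConflictingRoots` carrying both roots, while an invalid index yields the underlying index error. All argument values are hypothetical.

fn sketch_verify(path: &MerklePath, index: u64, node: RpoDigest, root: &RpoDigest) {
    match path.verify(index, node, root) {
        Ok(()) => (), // node exists at index under this root
        Err(MerkleError::ConflictingRoots { expected_root, actual_root }) => {
            // the path resolved, but to a different root
            assert_ne!(expected_root, actual_root);
        },
        Err(other) => panic!("invalid node index: {other}"),
    }
}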

/// Returns an iterator over every inner node of this [MerklePath].
@@ -128,7 +137,7 @@ impl FromIterator<RpoDigest> for MerklePath {

impl IntoIterator for MerklePath {
type Item = RpoDigest;
type IntoIter = vec::IntoIter<RpoDigest>;
type IntoIter = alloc::vec::IntoIter<RpoDigest>;

fn into_iter(self) -> Self::IntoIter {
self.nodes.into_iter()
@@ -142,7 +151,7 @@ pub struct InnerNodeIterator<'a> {
value: RpoDigest,
}

impl<'a> Iterator for InnerNodeIterator<'a> {
impl Iterator for InnerNodeIterator<'_> {
type Item = InnerNodeInfo;

fn next(&mut self) -> Option<Self::Item> {

@@ -1,86 +1,39 @@
use core::fmt;
use thiserror::Error;

use crate::{
hash::rpo::RpoDigest,
merkle::{LeafIndex, SMT_DEPTH},
utils::collections::*,
Word,
};

// SMT LEAF ERROR
// =================================================================================================

#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Debug, Error)]
pub enum SmtLeafError {
InconsistentKeys {
entries: Vec<(RpoDigest, Word)>,
key_1: RpoDigest,
key_2: RpoDigest,
},
InvalidNumEntriesForMultiple(usize),
SingleKeyInconsistentWithLeafIndex {
#[error(
"multiple leaf requires all keys to map to the same leaf index but key1 {key_1} and key2 {key_2} map to different indices"
)]
InconsistentMultipleLeafKeys { key_1: RpoDigest, key_2: RpoDigest },
#[error("single leaf key {key} maps to {actual_leaf_index:?} but was expected to map to {expected_leaf_index:?}")]
InconsistentSingleLeafIndices {
key: RpoDigest,
leaf_index: LeafIndex<SMT_DEPTH>,
expected_leaf_index: LeafIndex<SMT_DEPTH>,
actual_leaf_index: LeafIndex<SMT_DEPTH>,
},
MultipleKeysInconsistentWithLeafIndex {
#[error("supplied leaf index {leaf_index_supplied:?} does not match {leaf_index_from_keys:?} for multiple leaf")]
InconsistentMultipleLeafIndices {
leaf_index_from_keys: LeafIndex<SMT_DEPTH>,
leaf_index_supplied: LeafIndex<SMT_DEPTH>,
},
}

#[cfg(feature = "std")]
impl std::error::Error for SmtLeafError {}

impl fmt::Display for SmtLeafError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use SmtLeafError::*;
match self {
InvalidNumEntriesForMultiple(num_entries) => {
write!(f, "Multiple leaf requires 2 or more entries. Got: {num_entries}")
}
InconsistentKeys { entries, key_1, key_2 } => {
write!(f, "Multiple leaf requires all keys to map to the same leaf index. Offending keys: {key_1} and {key_2}. Entries: {entries:?}.")
}
SingleKeyInconsistentWithLeafIndex { key, leaf_index } => {
write!(
f,
"Single key in leaf inconsistent with leaf index. Key: {key}, leaf index: {}",
leaf_index.value()
)
}
MultipleKeysInconsistentWithLeafIndex {
leaf_index_from_keys,
leaf_index_supplied,
} => {
write!(
f,
"Keys in entries map to leaf index {}, but leaf index {} was supplied",
leaf_index_from_keys.value(),
leaf_index_supplied.value()
)
}
}
}
#[error("multiple leaf requires at least two entries but only {0} were given")]
MultipleLeafRequiresTwoEntries(usize),
}
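
A sketch (not part of the diff) exercising one of the renamed variants; with `thiserror` deriving `Display`, the messages above are produced automatically. The key and value literals are hypothetical.

fn sketch_multiple_leaf_error() {
    // `new_multiple` rejects fewer than two entries, per the leaf.rs hunk further below
    let err = SmtLeaf::new_multiple(vec![(RpoDigest::default(), Word::default())]).unwrap_err();
    assert!(matches!(err, SmtLeafError::MultipleLeafRequiresTwoEntries(1)));
}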

// SMT PROOF ERROR
// =================================================================================================

#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Debug, Error)]
pub enum SmtProofError {
InvalidPathLength(usize),
}

#[cfg(feature = "std")]
impl std::error::Error for SmtProofError {}

impl fmt::Display for SmtProofError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use SmtProofError::*;
match self {
InvalidPathLength(path_length) => {
write!(f, "Invalid Merkle path length. Expected {SMT_DEPTH}, got {path_length}")
}
}
}
#[error("merkle path length {0} does not match SMT depth {SMT_DEPTH}")]
InvalidMerklePathLength(usize),
}

@@ -1,10 +1,8 @@
use alloc::{string::ToString, vec::Vec};
use core::cmp::Ordering;

use super::{Felt, LeafIndex, Rpo256, RpoDigest, SmtLeafError, Word, EMPTY_WORD, SMT_DEPTH};
use crate::utils::{
collections::*, string::*, vec, ByteReader, ByteWriter, Deserializable, DeserializationError,
Serializable,
};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
@@ -22,8 +20,8 @@ impl SmtLeaf {
///
/// # Errors
/// - Returns an error if 2 keys in `entries` map to a different leaf index
/// - Returns an error if 1 or more keys in `entries` map to a leaf index
/// different from `leaf_index`
/// - Returns an error if 1 or more keys in `entries` map to a leaf index different from
/// `leaf_index`
pub fn new(
entries: Vec<(RpoDigest, Word)>,
leaf_index: LeafIndex<SMT_DEPTH>,
@@ -33,29 +31,31 @@ impl SmtLeaf {
1 => {
let (key, value) = entries[0];

if LeafIndex::<SMT_DEPTH>::from(key) != leaf_index {
return Err(SmtLeafError::SingleKeyInconsistentWithLeafIndex {
let computed_index = LeafIndex::<SMT_DEPTH>::from(key);
if computed_index != leaf_index {
return Err(SmtLeafError::InconsistentSingleLeafIndices {
key,
leaf_index,
expected_leaf_index: leaf_index,
actual_leaf_index: computed_index,
});
}

Ok(Self::new_single(key, value))
}
},
_ => {
let leaf = Self::new_multiple(entries)?;

// `new_multiple()` checked that all keys map to the same leaf index. We still need
// to ensure that that leaf index is `leaf_index`.
if leaf.index() != leaf_index {
Err(SmtLeafError::MultipleKeysInconsistentWithLeafIndex {
Err(SmtLeafError::InconsistentMultipleLeafIndices {
leaf_index_from_keys: leaf.index(),
leaf_index_supplied: leaf_index,
})
} else {
Ok(leaf)
}
}
},
}
}

@@ -77,7 +77,7 @@ impl SmtLeaf {
/// - Returns an error if 2 keys in `entries` map to a different leaf index
pub fn new_multiple(entries: Vec<(RpoDigest, Word)>) -> Result<Self, SmtLeafError> {
if entries.len() < 2 {
return Err(SmtLeafError::InvalidNumEntriesForMultiple(entries.len()));
return Err(SmtLeafError::MultipleLeafRequiresTwoEntries(entries.len()));
}

// Check that all keys map to the same leaf index
@@ -91,8 +91,7 @@ impl SmtLeaf {
let next_leaf_index: LeafIndex<SMT_DEPTH> = next_key.into();

if next_leaf_index != first_leaf_index {
return Err(SmtLeafError::InconsistentKeys {
entries,
return Err(SmtLeafError::InconsistentMultipleLeafKeys {
key_1: first_key,
key_2: next_key,
});
@@ -120,7 +119,7 @@ impl SmtLeaf {
// Note: All keys are guaranteed to have the same leaf index
let (first_key, _) = entries[0];
first_key.into()
}
},
}
}

@@ -131,7 +130,7 @@ impl SmtLeaf {
SmtLeaf::Single(_) => 1,
SmtLeaf::Multiple(entries) => {
entries.len().try_into().expect("shouldn't have more than 2^64 entries")
}
},
}
}

@@ -143,7 +142,7 @@ impl SmtLeaf {
SmtLeaf::Multiple(kvs) => {
let elements: Vec<Felt> = kvs.iter().copied().flat_map(kv_to_elements).collect();
Rpo256::hash_elements(&elements)
}
},
}
}

@@ -184,7 +183,8 @@ impl SmtLeaf {
// HELPERS
// ---------------------------------------------------------------------------------------------

/// Returns the value associated with `key` in the leaf, or `None` if `key` maps to another leaf.
/// Returns the value associated with `key` in the leaf, or `None` if `key` maps to another
/// leaf.
pub(super) fn get_value(&self, key: &RpoDigest) -> Option<Word> {
// Ensure that `key` maps to this leaf
if self.index() != key.into() {
@@ -199,7 +199,7 @@ impl SmtLeaf {
} else {
Some(EMPTY_WORD)
}
}
},
SmtLeaf::Multiple(kv_pairs) => {
for (key_in_leaf, value_in_leaf) in kv_pairs {
if key == key_in_leaf {
@@ -208,7 +208,7 @@ impl SmtLeaf {
}

Some(EMPTY_WORD)
}
},
}
}

@@ -221,7 +221,7 @@ impl SmtLeaf {
SmtLeaf::Empty(_) => {
*self = SmtLeaf::new_single(key, value);
None
}
},
SmtLeaf::Single(kv_pair) => {
if kv_pair.0 == key {
// the key is already in this leaf. Update the value and return the previous
@@ -239,7 +239,7 @@ impl SmtLeaf {

None
}
}
},
SmtLeaf::Multiple(kv_pairs) => {
match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
Ok(pos) => {
@@ -247,14 +247,14 @@ impl SmtLeaf {
kv_pairs[pos].1 = value;

Some(old_value)
}
},
Err(pos) => {
kv_pairs.insert(pos, (key, value));

None
},
}
}
}
},
}
}

@@ -279,7 +279,7 @@ impl SmtLeaf {
// another key is stored at leaf; nothing to update
(None, false)
}
}
},
SmtLeaf::Multiple(kv_pairs) => {
match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
Ok(pos) => {
@@ -294,13 +294,13 @@ impl SmtLeaf {
}

(Some(old_value), false)
}
},
Err(_) => {
// other keys are stored at leaf; nothing to update
(None, false)
},
}
}
}
},
}
}
}
@@ -351,7 +351,7 @@ impl Deserializable for SmtLeaf {
// ================================================================================================

/// Converts a key-value tuple to an iterator of `Felt`s
fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator<Item = Felt> {
pub(crate) fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator<Item = Felt> {
let key_elements = key.into_iter();
let value_elements = value.into_iter();

@@ -360,7 +360,7 @@ fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator<Item = Felt>

/// Compares two keys element-by-element using their integer representations, starting with
/// the most significant element.
fn cmp_keys(key_1: RpoDigest, key_2: RpoDigest) -> Ordering {
pub(crate) fn cmp_keys(key_1: RpoDigest, key_2: RpoDigest) -> Ordering {
for (v1, v2) in key_1.iter().zip(key_2.iter()).rev() {
let v1 = v1.as_int();
let v2 = v2.as_int();

@@ -1,8 +1,13 @@
use alloc::{
collections::{BTreeMap, BTreeSet},
string::ToString,
vec::Vec,
};

use super::{
EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, MerklePath,
NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
};
use crate::utils::collections::*;

mod error;
pub use error::{SmtLeafError, SmtProofError};
@@ -12,6 +17,7 @@ pub use leaf::SmtLeaf;

mod proof;
pub use proof::SmtProof;
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

#[cfg(test)]
mod tests;
@@ -27,8 +33,8 @@ pub const SMT_DEPTH: u8 = 64;
/// Sparse Merkle tree mapping 256-bit keys to 256-bit values. Both keys and values are represented
/// by 4 field elements.
///
/// All leaves sit at depth 64. The most significant element of the key is used to identify the leaf to
/// which the key maps.
/// All leaves sit at depth 64. The most significant element of the key is used to identify the leaf
/// to which the key maps.
///
/// A leaf is either empty, or holds one or more key-value pairs. An empty leaf hashes to the empty
/// word. Otherwise, a leaf hashes to the hash of its key-value pairs, ordered by key first, value
@@ -65,12 +71,51 @@ impl Smt {

/// Returns a new [Smt] instantiated with leaves set as specified by the provided entries.
///
/// If the `concurrent` feature is enabled, this function uses a parallel implementation to
/// process the entries efficiently, otherwise it defaults to the sequential implementation.
///
/// All leaves omitted from the entries list are set to [Self::EMPTY_VALUE].
///
/// # Errors
/// Returns an error if the provided entries contain multiple values for the same key.
pub fn with_entries(
entries: impl IntoIterator<Item = (RpoDigest, Word)>,
) -> Result<Self, MerkleError> {
#[cfg(feature = "concurrent")]
{
let mut seen_keys = BTreeSet::new();
let entries: Vec<_> = entries
.into_iter()
.map(|(key, value)| {
if seen_keys.insert(key) {
Ok((key, value))
} else {
Err(MerkleError::DuplicateValuesForIndex(
LeafIndex::<SMT_DEPTH>::from(key).value(),
))
}
})
.collect::<Result<_, _>>()?;
if entries.is_empty() {
return Ok(Self::default());
}
<Self as SparseMerkleTree<SMT_DEPTH>>::with_entries_par(entries)
}
#[cfg(not(feature = "concurrent"))]
{
Self::with_entries_sequential(entries)
}
}

/// Returns a new [Smt] instantiated with leaves set as specified by the provided entries.
///
/// This sequential implementation processes entries one at a time to build the tree.
/// All leaves omitted from the entries list are set to [Self::EMPTY_VALUE].
///
/// # Errors
/// Returns an error if the provided entries contain multiple values for the same key.
pub fn with_entries_sequential(
entries: impl IntoIterator<Item = (RpoDigest, Word)>,
) -> Result<Self, MerkleError> {
// create an empty tree
let mut tree = Self::new();
@@ -95,6 +140,23 @@ impl Smt {
Ok(tree)
}
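
A sketch (not part of the diff) of the duplicate-key check performed by `with_entries`; the key and value literals are hypothetical.

fn sketch_duplicate_key_rejected() {
    let key = RpoDigest::default();
    // both entries map to the same key, so construction fails in either feature path
    let pairs = vec![(key, [ONE; WORD_SIZE]), (key, [Felt::new(2); WORD_SIZE])];
    assert!(matches!(Smt::with_entries(pairs), Err(MerkleError::DuplicateValuesForIndex(_))));
}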

/// Returns a new [`Smt`] instantiated from already computed leaves and nodes.
///
/// This function performs minimal consistency checking. It is the caller's responsibility to
/// ensure the passed arguments are correct and consistent with each other.
///
/// # Panics
/// With debug assertions on, this function panics if `root` does not match the root node in
/// `inner_nodes`.
pub fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, SmtLeaf>,
root: RpoDigest,
) -> Self {
// Our particular implementation of `from_raw_parts()` never returns `Err`.
<Self as SparseMerkleTree<SMT_DEPTH>>::from_raw_parts(inner_nodes, leaves, root).unwrap()
}

// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------

@@ -115,12 +177,7 @@ impl Smt {

/// Returns the value associated with `key`
pub fn get_value(&self, key: &RpoDigest) -> Word {
let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();

match self.leaves.get(&leaf_pos) {
Some(leaf) => leaf.get_value(key).unwrap_or_default(),
None => EMPTY_WORD,
}
<Self as SparseMerkleTree<SMT_DEPTH>>::get_value(self, key)
}
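
A sketch (not part of the diff): after this change the accessor delegates to the trait, and a key that was never set still resolves to the empty word.

fn sketch_get_value_default() {
    let smt = Smt::new();
    // no value was inserted for this (hypothetical) key
    assert_eq!(smt.get_value(&RpoDigest::default()), EMPTY_WORD);
}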

/// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
@@ -129,6 +186,12 @@ impl Smt {
<Self as SparseMerkleTree<SMT_DEPTH>>::open(self, key)
}

/// Returns a boolean value indicating whether the SMT is empty.
pub fn is_empty(&self) -> bool {
debug_assert_eq!(self.leaves.is_empty(), self.root == Self::EMPTY_ROOT);
self.root == Self::EMPTY_ROOT
}
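
A sketch (not part of the diff) of the new accessor: an empty tree reports `is_empty`, and its root equals the empty-tree root constant.

fn sketch_is_empty() {
    let smt = Smt::new();
    assert!(smt.is_empty());
    assert_eq!(smt.root(), Smt::EMPTY_ROOT);
}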

// ITERATORS
// --------------------------------------------------------------------------------------------

@@ -166,6 +229,47 @@ impl Smt {
<Self as SparseMerkleTree<SMT_DEPTH>>::insert(self, key, value)
}

/// Computes what changes are necessary to insert the specified key-value pairs into this Merkle
/// tree, allowing for validation before applying those changes.
///
/// This method returns a [`MutationSet`], which contains all the information for inserting
/// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
/// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
/// [`Smt::apply_mutations()`] can be called in order to commit these changes to the Merkle
/// tree, or [`drop()`] to discard them.
///
/// # Example
/// ```
/// # use miden_crypto::{hash::rpo::RpoDigest, Felt, Word};
/// # use miden_crypto::merkle::{Smt, EmptySubtreeRoots, SMT_DEPTH};
/// let mut smt = Smt::new();
/// let pair = (RpoDigest::default(), Word::default());
/// let mutations = smt.compute_mutations(vec![pair]);
/// assert_eq!(mutations.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));
/// smt.apply_mutations(mutations);
/// assert_eq!(smt.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));
/// ```
pub fn compute_mutations(
&self,
kv_pairs: impl IntoIterator<Item = (RpoDigest, Word)>,
) -> MutationSet<SMT_DEPTH, RpoDigest, Word> {
<Self as SparseMerkleTree<SMT_DEPTH>>::compute_mutations(self, kv_pairs)
}

/// Apply the prospective mutations computed with [`Smt::compute_mutations()`] to this tree.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash
/// the `mutations` were computed against, and the second item is the actual current root of
/// this tree.
pub fn apply_mutations(
&mut self,
mutations: MutationSet<SMT_DEPTH, RpoDigest, Word>,
) -> Result<(), MerkleError> {
<Self as SparseMerkleTree<SMT_DEPTH>>::apply_mutations(self, mutations)
}

// HELPERS
// --------------------------------------------------------------------------------------------

@@ -182,7 +286,7 @@ impl Smt {
self.leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value)));

None
}
},
}
}

@@ -210,6 +314,20 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
type Opening = SmtProof;

const EMPTY_VALUE: Self::Value = EMPTY_WORD;
const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);

fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, SmtLeaf>,
root: RpoDigest,
) -> Result<Self, MerkleError> {
if cfg!(debug_assertions) {
let root_node = inner_nodes.get(&NodeIndex::root()).unwrap();
assert_eq!(root_node.hash(), root);
}

Ok(Self { root, inner_nodes, leaves })
}

fn root(&self) -> RpoDigest {
self.root
@@ -220,11 +338,10 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
}

fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
self.inner_nodes.get(&index).cloned().unwrap_or_else(|| {
let node = EmptySubtreeRoots::entry(SMT_DEPTH, index.depth() + 1);

InnerNode { left: *node, right: *node }
})
self.inner_nodes
.get(&index)
.cloned()
.unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth()))
}

fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
@@ -244,6 +361,15 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
}
}

fn get_value(&self, key: &Self::Key) -> Self::Value {
let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();

match self.leaves.get(&leaf_pos) {
Some(leaf) => leaf.get_value(key).unwrap_or_default(),
None => EMPTY_WORD,
}
}

fn get_leaf(&self, key: &RpoDigest) -> Self::Leaf {
let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();

@@ -257,6 +383,28 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
leaf.hash()
}

fn construct_prospective_leaf(
&self,
mut existing_leaf: SmtLeaf,
key: &RpoDigest,
value: &Word,
) -> SmtLeaf {
debug_assert_eq!(existing_leaf.index(), Self::key_to_leaf_index(key));

match existing_leaf {
SmtLeaf::Empty(_) => SmtLeaf::new_single(*key, *value),
_ => {
if *value != EMPTY_WORD {
existing_leaf.insert(*key, *value);
} else {
existing_leaf.remove(*key);
}

existing_leaf
},
}
}
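
A sketch (not part of the diff) of chaining prospective leaves, as the trait documentation further below describes; both keys are assumed to map to the same leaf index, and all values are hypothetical.

fn sketch_chained_prospective(smt: &Smt, key_a: RpoDigest, key_b: RpoDigest, value: Word) -> RpoDigest {
    // start from the current leaf, then layer two prospective insertions on top
    let leaf = smt.get_leaf(&key_a);
    let leaf = smt.construct_prospective_leaf(leaf, &key_a, &value);
    let leaf = smt.construct_prospective_leaf(leaf, &key_b, &value);
    leaf.hash() // the hash the leaf would have, without mutating `smt`
}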

fn key_to_leaf_index(key: &RpoDigest) -> LeafIndex<SMT_DEPTH> {
let most_significant_felt = key[3];
LeafIndex::new_max_depth(most_significant_felt.as_int())
@@ -265,6 +413,23 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
fn path_and_leaf_to_opening(path: MerklePath, leaf: SmtLeaf) -> SmtProof {
SmtProof::new_unchecked(path, leaf)
}

fn pairs_to_leaf(mut pairs: Vec<(RpoDigest, Word)>) -> SmtLeaf {
assert!(!pairs.is_empty());

if pairs.len() > 1 {
SmtLeaf::new_multiple(pairs).unwrap()
} else {
let (key, value) = pairs.pop().unwrap();
// TODO: should we ever be constructing empty leaves from pairs?
if value == Self::EMPTY_VALUE {
let index = Self::key_to_leaf_index(&key);
SmtLeaf::new_empty(index)
} else {
SmtLeaf::new_single(key, value)
}
}
}
}

impl Default for Smt {
@@ -294,3 +459,70 @@ impl From<&RpoDigest> for LeafIndex<SMT_DEPTH> {
Word::from(value).into()
}
}
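
A sketch (not part of the diff) of the key-to-index mapping defined above: the leaf index is taken from the key's most significant element.

fn sketch_key_to_leaf_index() {
    // the most significant element (key[3]) determines the leaf
    let key = RpoDigest::new([ZERO, ZERO, ZERO, Felt::new(42)]);
    assert_eq!(LeafIndex::<SMT_DEPTH>::from(key).value(), 42);
}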

// SERIALIZATION
// ================================================================================================

impl Serializable for Smt {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
// Write the number of filled leaves for this Smt
target.write_usize(self.entries().count());

// Write each (key, value) pair
for (key, value) in self.entries() {
target.write(key);
target.write(value);
}
}

fn get_size_hint(&self) -> usize {
let entries_count = self.entries().count();

// Each entry is the size of a digest plus a word.
entries_count.get_size_hint()
+ entries_count * (RpoDigest::SERIALIZED_SIZE + EMPTY_WORD.get_size_hint())
}
}

impl Deserializable for Smt {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
// Read the number of filled leaves for this Smt
let num_filled_leaves = source.read_usize()?;
let mut entries = Vec::with_capacity(num_filled_leaves);

for _ in 0..num_filled_leaves {
let key = source.read()?;
let value = source.read()?;
entries.push((key, value));
}

Self::with_entries(entries)
.map_err(|err| DeserializationError::InvalidValue(err.to_string()))
}
}

#[test]
fn test_smt_serialization_deserialization() {
// Smt for default types (empty map)
let smt_default = Smt::default();
let bytes = smt_default.to_bytes();
assert_eq!(smt_default, Smt::read_from_bytes(&bytes).unwrap());
assert_eq!(bytes.len(), smt_default.get_size_hint());

// Smt with values
let smt_leaves_2: [(RpoDigest, Word); 2] = [
(
RpoDigest::new([Felt::new(101), Felt::new(102), Felt::new(103), Felt::new(104)]),
[Felt::new(1_u64), Felt::new(2_u64), Felt::new(3_u64), Felt::new(4_u64)],
),
(
RpoDigest::new([Felt::new(105), Felt::new(106), Felt::new(107), Felt::new(108)]),
[Felt::new(5_u64), Felt::new(6_u64), Felt::new(7_u64), Felt::new(8_u64)],
),
];
let smt = Smt::with_entries(smt_leaves_2).unwrap();

let bytes = smt.to_bytes();
assert_eq!(smt, Smt::read_from_bytes(&bytes).unwrap());
assert_eq!(bytes.len(), smt.get_size_hint());
}

@@ -1,7 +1,7 @@
use alloc::string::ToString;

use super::{MerklePath, RpoDigest, SmtLeaf, SmtProofError, Word, SMT_DEPTH};
use crate::utils::{
string::*, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

/// A proof which can be used to assert membership (or non-membership) of key-value pairs in a
/// [`super::Smt`].
@@ -25,7 +25,7 @@ impl SmtProof {
pub fn new(path: MerklePath, leaf: SmtLeaf) -> Result<Self, SmtProofError> {
let depth: usize = SMT_DEPTH.into();
if path.len() != depth {
return Err(SmtProofError::InvalidPathLength(path.len()));
return Err(SmtProofError::InvalidMerklePathLength(path.len()));
}

Ok(Self { path, leaf })
@@ -58,7 +58,7 @@ impl SmtProof {

// make sure the Merkle path resolves to the correct root
self.compute_root() == *root
}
},
// If the key maps to a different leaf, the proof cannot verify membership of `value`
None => false,
}

@@ -1,7 +1,9 @@
use alloc::vec::Vec;

use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH};
use crate::{
merkle::{EmptySubtreeRoots, MerkleStore},
utils::{collections::*, Deserializable, Serializable},
merkle::{smt::SparseMerkleTree, EmptySubtreeRoots, MerkleStore},
utils::{Deserializable, Serializable},
Word, ONE, WORD_SIZE,
};

@@ -256,6 +258,195 @@ fn test_smt_removal() {
}
}

/// This tests that we can correctly calculate prospective leaves -- that is, we can construct
/// correct [`SmtLeaf`] values for a theoretical insertion on a Merkle tree without mutating or
/// cloning the tree.
#[test]
fn test_prospective_hash() {
let mut smt = Smt::default();

let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;

let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]);
// Sort key_3 before key_1, to test non-append insertion.
let key_3: RpoDigest =
RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(raw)]);

let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE];

// insert key-value 1
{
let prospective =
smt.construct_prospective_leaf(smt.get_leaf(&key_1), &key_1, &value_1).hash();
smt.insert(key_1, value_1);

let leaf = smt.get_leaf(&key_1);
assert_eq!(
prospective,
leaf.hash(),
"prospective hash for leaf {leaf:?} did not match actual hash",
);
}

// insert key-value 2
{
let prospective =
smt.construct_prospective_leaf(smt.get_leaf(&key_2), &key_2, &value_2).hash();
smt.insert(key_2, value_2);

let leaf = smt.get_leaf(&key_2);
assert_eq!(
prospective,
leaf.hash(),
"prospective hash for leaf {leaf:?} did not match actual hash",
);
}

// insert key-value 3
{
let prospective =
smt.construct_prospective_leaf(smt.get_leaf(&key_3), &key_3, &value_3).hash();
smt.insert(key_3, value_3);

let leaf = smt.get_leaf(&key_3);
assert_eq!(
prospective,
leaf.hash(),
"prospective hash for leaf {leaf:?} did not match actual hash",
);
}

// remove key 3
{
let old_leaf = smt.get_leaf(&key_3);
let old_value_3 = smt.insert(key_3, EMPTY_WORD);
assert_eq!(old_value_3, value_3);
let prospective_leaf =
smt.construct_prospective_leaf(smt.get_leaf(&key_3), &key_3, &old_value_3);

assert_eq!(
old_leaf.hash(),
prospective_leaf.hash(),
"removing and prospectively re-adding a leaf didn't yield the original leaf:\
\n  original leaf: {old_leaf:?}\
\n  prospective leaf: {prospective_leaf:?}",
);
}

// remove key 2
{
let old_leaf = smt.get_leaf(&key_2);
let old_value_2 = smt.insert(key_2, EMPTY_WORD);
assert_eq!(old_value_2, value_2);
let prospective_leaf =
smt.construct_prospective_leaf(smt.get_leaf(&key_2), &key_2, &old_value_2);

assert_eq!(
old_leaf.hash(),
prospective_leaf.hash(),
"removing and prospectively re-adding a leaf didn't yield the original leaf:\
\n  original leaf: {old_leaf:?}\
\n  prospective leaf: {prospective_leaf:?}",
);
}

// remove key 1
{
let old_leaf = smt.get_leaf(&key_1);
let old_value_1 = smt.insert(key_1, EMPTY_WORD);
assert_eq!(old_value_1, value_1);
let prospective_leaf =
smt.construct_prospective_leaf(smt.get_leaf(&key_1), &key_1, &old_value_1);
assert_eq!(
old_leaf.hash(),
prospective_leaf.hash(),
"removing and prospectively re-adding a leaf didn't yield the original leaf:\
\n  original leaf: {old_leaf:?}\
\n  prospective leaf: {prospective_leaf:?}",
);
}
}

/// This tests that we can perform prospective changes correctly.
#[test]
fn test_prospective_insertion() {
let mut smt = Smt::default();

let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;

let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]);
// Sort key_3 before key_1, to test non-append insertion.
let key_3: RpoDigest =
RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(raw)]);

let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE];

let root_empty = smt.root();

let root_1 = {
smt.insert(key_1, value_1);
smt.root()
};

let root_2 = {
smt.insert(key_2, value_2);
smt.root()
};

let root_3 = {
smt.insert(key_3, value_3);
smt.root()
};

// Test incremental updates.

let mut smt = Smt::default();

let mutations = smt.compute_mutations(vec![(key_1, value_1)]);
assert_eq!(mutations.root(), root_1, "prospective root 1 did not match actual root 1");
smt.apply_mutations(mutations).unwrap();
assert_eq!(smt.root(), root_1, "mutations before and after apply did not match");

let mutations = smt.compute_mutations(vec![(key_2, value_2)]);
assert_eq!(mutations.root(), root_2, "prospective root 2 did not match actual root 2");
let mutations =
smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_2, value_2), (key_3, value_3)]);
assert_eq!(mutations.root(), root_3, "mutations before and after apply did not match");
smt.apply_mutations(mutations).unwrap();

// Edge case: multiple values at the same key, where a later pair restores the original value.
let mutations = smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_3, value_3)]);
assert_eq!(mutations.root(), root_3);
smt.apply_mutations(mutations).unwrap();
assert_eq!(smt.root(), root_3);

// Test batch updates, and that the order doesn't matter.
let pairs =
vec![(key_3, value_2), (key_2, EMPTY_WORD), (key_1, EMPTY_WORD), (key_3, EMPTY_WORD)];
let mutations = smt.compute_mutations(pairs);
assert_eq!(
mutations.root(),
root_empty,
"prospective root for batch removal did not match actual root",
);
smt.apply_mutations(mutations).unwrap();
assert_eq!(smt.root(), root_empty, "mutations before and after apply did not match");

let pairs = vec![(key_3, value_3), (key_1, value_1), (key_2, value_2)];
let mutations = smt.compute_mutations(pairs);
assert_eq!(mutations.root(), root_3);
smt.apply_mutations(mutations).unwrap();
assert_eq!(smt.root(), root_3);
}

/// Tests that 2 key-value pairs stored in the same leaf have the same path
#[test]
fn test_smt_path_to_keys_in_same_leaf_are_equal() {
@@ -286,8 +477,7 @@ fn test_empty_leaf_hash() {
#[test]
fn test_smt_get_value() {
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, ONE]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), 2_u32.into()]);
let key_2: RpoDigest = RpoDigest::from([2_u32, 2_u32, 2_u32, 2_u32]);

let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
@@ -301,8 +491,7 @@ fn test_smt_get_value() {
assert_eq!(value_2, returned_value_2);

// Check that a key with no inserted value returns the empty word
let key_no_value =
RpoDigest::from([42_u32.into(), 42_u32.into(), 42_u32.into(), 42_u32.into()]);
let key_no_value = RpoDigest::from([42_u32, 42_u32, 42_u32, 42_u32]);

assert_eq!(EMPTY_WORD, smt.get_value(&key_no_value));
}
@@ -311,8 +500,7 @@ fn test_smt_get_value() {
#[test]
fn test_smt_entries() {
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, ONE]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), 2_u32.into()]);
let key_2: RpoDigest = RpoDigest::from([2_u32, 2_u32, 2_u32, 2_u32]);

let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
@@ -328,6 +516,16 @@ fn test_smt_entries() {
assert!(entries.next().is_none());
}

/// Tests that `EMPTY_ROOT` constant generated in the `Smt` equals to the root of the empty tree of
/// depth 64
#[test]
fn test_smt_check_empty_root_constant() {
// get the root of the empty tree of depth 64
let empty_root_64_depth = EmptySubtreeRoots::empty_hashes(64)[0];

assert_eq!(empty_root_64_depth, Smt::EMPTY_ROOT);
}

// SMT LEAF
// --------------------------------------------------------------------------------------------

@@ -346,7 +544,7 @@ fn test_empty_smt_leaf_serialization() {
#[test]
fn test_single_smt_leaf_serialization() {
let single_leaf = SmtLeaf::new_single(
RpoDigest::from([10_u32.into(), 11_u32.into(), 12_u32.into(), 13_u32.into()]),
RpoDigest::from([10_u32, 11_u32, 12_u32, 13_u32]),
[1_u32.into(), 2_u32.into(), 3_u32.into(), 4_u32.into()],
);

@@ -362,11 +560,11 @@ fn test_single_smt_leaf_serialization() {
fn test_multiple_smt_leaf_serialization_success() {
let multiple_leaf = SmtLeaf::new_multiple(vec![
(
RpoDigest::from([10_u32.into(), 11_u32.into(), 12_u32.into(), 13_u32.into()]),
RpoDigest::from([10_u32, 11_u32, 12_u32, 13_u32]),
[1_u32.into(), 2_u32.into(), 3_u32.into(), 4_u32.into()],
),
(
RpoDigest::from([100_u32.into(), 101_u32.into(), 102_u32.into(), 13_u32.into()]),
RpoDigest::from([100_u32, 101_u32, 102_u32, 13_u32]),
[11_u32.into(), 12_u32.into(), 13_u32.into(), 14_u32.into()],
),
])

@@ -1,7 +1,11 @@
use alloc::{collections::BTreeMap, vec::Vec};
use core::mem;

use num::Integer;

use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex};
use crate::{
hash::rpo::{Rpo256, RpoDigest},
utils::collections::*,
Felt, Word, EMPTY_WORD,
};

@@ -44,20 +48,34 @@ pub const SMT_MAX_DEPTH: u8 = 64;
/// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs.
pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
/// The type for a key
type Key: Clone;
type Key: Clone + Ord;
/// The type for a value
type Value: Clone + PartialEq;
/// The type for a leaf
type Leaf;
type Leaf: Clone;
/// The type for an opening (i.e. a "proof") of a leaf
type Opening;

/// The default value used to compute the hash of empty leaves
const EMPTY_VALUE: Self::Value;

/// The root of the empty tree with provided DEPTH
const EMPTY_ROOT: RpoDigest;

// PROVIDED METHODS
// ---------------------------------------------------------------------------------------------

/// Creates a new sparse Merkle tree from an existing set of key-value pairs, in parallel.
#[cfg(feature = "concurrent")]
fn with_entries_par(entries: Vec<(Self::Key, Self::Value)>) -> Result<Self, MerkleError>
where
Self: Sized,
{
let (inner_nodes, leaves) = Self::build_subtrees(entries);
let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash();
Self::from_raw_parts(inner_nodes, leaves, root)
}

/// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
/// path to the leaf, as well as the leaf itself.
fn open(&self, key: &Self::Key) -> Self::Opening {
@@ -139,9 +157,165 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
self.set_root(node_hash);
}

/// Computes what changes are necessary to insert the specified key-value pairs into this Merkle
/// tree, allowing for validation before applying those changes.
///
/// This method returns a [`MutationSet`], which contains all the information for inserting
/// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
/// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
/// [`SparseMerkleTree::apply_mutations()`] can be called in order to commit these changes to
/// the Merkle tree, or [`drop()`] to discard them.
fn compute_mutations(
&self,
kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
) -> MutationSet<DEPTH, Self::Key, Self::Value> {
use NodeMutation::*;

let mut new_root = self.root();
let mut new_pairs: BTreeMap<Self::Key, Self::Value> = Default::default();
let mut node_mutations: BTreeMap<NodeIndex, NodeMutation> = Default::default();

for (key, value) in kv_pairs {
// If the old value and the new value are the same, there is nothing to update.
// For the unusual case that kv_pairs has multiple values at the same key, we'll have
// to check the key-value pairs we've already seen to get the "effective" old value.
let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key));
if value == old_value {
continue;
}

let leaf_index = Self::key_to_leaf_index(&key);
let mut node_index = NodeIndex::from(leaf_index);

// We need the current leaf's hash to calculate the new leaf, but in the rare case that
// `kv_pairs` has multiple pairs that go into the same leaf, then those pairs are also
// part of the "current leaf".
let old_leaf = {
let pairs_at_index = new_pairs
.iter()
.filter(|&(new_key, _)| Self::key_to_leaf_index(new_key) == leaf_index);

pairs_at_index.fold(self.get_leaf(&key), |acc, (k, v)| {
// Most of the time `pairs_at_index` should only contain a single entry (or
// none at all), as multi-leaves should be really rare.
let existing_leaf = acc.clone();
self.construct_prospective_leaf(existing_leaf, k, v)
})
};

let new_leaf = self.construct_prospective_leaf(old_leaf, &key, &value);

let mut new_child_hash = Self::hash_leaf(&new_leaf);

for node_depth in (0..node_index.depth()).rev() {
// Whether the node we're replacing is the right child or the left child.
let is_right = node_index.is_value_odd();
node_index.move_up();

let old_node = node_mutations
.get(&node_index)
.map(|mutation| match mutation {
Addition(node) => node.clone(),
Removal => EmptySubtreeRoots::get_inner_node(DEPTH, node_depth),
})
.unwrap_or_else(|| self.get_inner_node(node_index));

let new_node = if is_right {
InnerNode {
left: old_node.left,
right: new_child_hash,
}
} else {
InnerNode {
left: new_child_hash,
right: old_node.right,
}
};

// The next iteration will operate on this new node's hash.
new_child_hash = new_node.hash();

let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth);
let is_removal = new_child_hash == equivalent_empty_hash;
let new_entry = if is_removal { Removal } else { Addition(new_node) };
node_mutations.insert(node_index, new_entry);
}

// Once we're at depth 0, the last node we made is the new root.
new_root = new_child_hash;
// And then we're done with this pair; on to the next one.
new_pairs.insert(key, value);
}

MutationSet {
old_root: self.root(),
new_root,
node_mutations,
new_pairs,
}
}

/// Apply the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to
/// this tree.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash
/// the `mutations` were computed against, and the second item is the actual current root of
/// this tree.
fn apply_mutations(
&mut self,
mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
) -> Result<(), MerkleError>
where
Self: Sized,
{
use NodeMutation::*;
let MutationSet {
old_root,
node_mutations,
new_pairs,
new_root,
} = mutations;

// Guard against accidentally trying to apply mutations that were computed against a
// different tree, including a stale version of this tree.
if old_root != self.root() {
return Err(MerkleError::ConflictingRoots {
expected_root: self.root(),
actual_root: old_root,
});
}

for (index, mutation) in node_mutations {
match mutation {
Removal => self.remove_inner_node(index),
Addition(node) => self.insert_inner_node(index, node),
}
}

for (key, value) in new_pairs {
self.insert_value(key, value);
}

self.set_root(new_root);

Ok(())
}
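
A sketch (not part of the diff) of the validate-then-commit flow these two methods enable; `expected` is a hypothetical root known ahead of time.

fn sketch_validate_then_commit(smt: &mut Smt, pairs: Vec<(RpoDigest, Word)>, expected: RpoDigest) -> Result<(), MerkleError> {
    let mutations = smt.compute_mutations(pairs);
    if mutations.root() == expected {
        smt.apply_mutations(mutations) // commit the precomputed changes
    } else {
        drop(mutations); // discard without touching the tree
        Ok(())
    }
}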

// REQUIRED METHODS
// ---------------------------------------------------------------------------------------------

/// Construct this type from already computed leaves and nodes. The caller ensures passed
/// arguments are correct and consistent with each other.
fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, Self::Leaf>,
root: RpoDigest,
) -> Result<Self, MerkleError>
where
Self: Sized;

/// The root of the tree
fn root(&self) -> RpoDigest;

@@ -160,27 +334,168 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
/// Inserts a leaf node, and returns the value at the key if already exists
fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value>;

/// Returns the value at the specified key. Recall that by definition, any key that hasn't been
/// updated is associated with [`Self::EMPTY_VALUE`].
fn get_value(&self, key: &Self::Key) -> Self::Value;

/// Returns the leaf at the specified index.
fn get_leaf(&self, key: &Self::Key) -> Self::Leaf;

/// Returns the hash of a leaf
fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest;

/// Returns what a leaf would look like if a key-value pair were inserted into the tree, without
/// mutating the tree itself. The existing leaf can be empty.
///
/// To get a prospective leaf based on the current state of the tree, use `self.get_leaf(key)`
/// as the argument for `existing_leaf`. The return value from this function can be chained back
/// into this function as the first argument to continue making prospective changes.
///
/// # Invariants
/// Because this method is for a prospective key-value insertion into a specific leaf,
/// `existing_leaf` must have the same leaf index as `key` (as determined by
/// [`SparseMerkleTree::key_to_leaf_index()`]), or the result will be meaningless.
fn construct_prospective_leaf(
&self,
existing_leaf: Self::Leaf,
key: &Self::Key,
value: &Self::Value,
) -> Self::Leaf;

/// Maps a key to a leaf index
fn key_to_leaf_index(key: &Self::Key) -> LeafIndex<DEPTH>;

/// Constructs a single leaf from an arbitrary amount of key-value pairs.
/// Those pairs must all have the same leaf index.
fn pairs_to_leaf(pairs: Vec<(Self::Key, Self::Value)>) -> Self::Leaf;

/// Maps a (MerklePath, Self::Leaf) to an opening.
///
/// The length of `path` is guaranteed to be equal to `DEPTH`
fn path_and_leaf_to_opening(path: MerklePath, leaf: Self::Leaf) -> Self::Opening;

/// Performs the initial transforms for constructing a [`SparseMerkleTree`] by composing
/// subtrees. In other words, this function takes the key-value inputs to the tree, and produces
/// the inputs to feed into [`build_subtree()`].
///
/// `pairs` *must* already be sorted **by leaf index column**, not simply sorted by key. If
/// `pairs` is not correctly sorted, the returned computations will be incorrect.
///
/// # Panics
/// With debug assertions on, this function panics if it detects that `pairs` is not correctly
/// sorted. Without debug assertions, the returned computations will be incorrect.
fn sorted_pairs_to_leaves(
pairs: Vec<(Self::Key, Self::Value)>,
) -> PairComputations<u64, Self::Leaf> {
debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value()));

let mut accumulator: PairComputations<u64, Self::Leaf> = Default::default();
let mut accumulated_leaves: Vec<SubtreeLeaf> = Vec::with_capacity(pairs.len() / 2);

// As we iterate, we'll keep track of the kv-pairs we've seen so far that correspond to a
// single leaf. When we see a pair that's in a different leaf, we'll swap these pairs
// out and store them in our accumulated leaves.
let mut current_leaf_buffer: Vec<(Self::Key, Self::Value)> = Default::default();

let mut iter = pairs.into_iter().peekable();
while let Some((key, value)) = iter.next() {
let col = Self::key_to_leaf_index(&key).index.value();
let peeked_col = iter.peek().map(|(key, _v)| {
let index = Self::key_to_leaf_index(key);
let next_col = index.index.value();
// We panic if `pairs` is not sorted by column.
debug_assert!(next_col >= col);
next_col
});
current_leaf_buffer.push((key, value));

// If the next pair is the same column as this one, then we're done after adding this
// pair to the buffer.
if peeked_col == Some(col) {
continue;
}

// Otherwise, the next pair is a different column, or there is no next pair. Either way
// it's time to swap out our buffer.
let leaf_pairs = mem::take(&mut current_leaf_buffer);
let leaf = Self::pairs_to_leaf(leaf_pairs);
let hash = Self::hash_leaf(&leaf);

accumulator.nodes.insert(col, leaf);
accumulated_leaves.push(SubtreeLeaf { col, hash });

debug_assert!(current_leaf_buffer.is_empty());
}

// TODO: determine is there is any notable performance difference between computing
|
||||
// subtree boundaries after the fact as an iterator adapter (like this), versus computing
|
||||
// subtree boundaries as we go. Either way this function is only used at the beginning of a
|
||||
// parallel construction, so it should not be a critical path.
|
||||
accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect();
|
||||
accumulator
|
||||
}
|
||||
|
||||
/// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs.
|
||||
///
|
||||
/// `entries` need not be sorted. This function will sort them.
|
||||
#[cfg(feature = "concurrent")]
|
||||
fn build_subtrees(
|
||||
mut entries: Vec<(Self::Key, Self::Value)>,
|
||||
) -> (BTreeMap<NodeIndex, InnerNode>, BTreeMap<u64, Self::Leaf>) {
|
||||
entries.sort_by_key(|item| {
|
||||
let index = Self::key_to_leaf_index(&item.0);
|
||||
index.value()
|
||||
});
|
||||
Self::build_subtrees_from_sorted_entries(entries)
|
||||
}
|
||||
|
||||
/// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs.
|
||||
///
|
||||
/// This function is mostly an implementation detail of
|
||||
/// [`SparseMerkleTree::with_entries_par()`].
|
||||
#[cfg(feature = "concurrent")]
|
||||
fn build_subtrees_from_sorted_entries(
|
||||
entries: Vec<(Self::Key, Self::Value)>,
|
||||
) -> (BTreeMap<NodeIndex, InnerNode>, BTreeMap<u64, Self::Leaf>) {
|
||||
use rayon::prelude::*;
|
||||
|
||||
let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
|
||||
|
||||
let PairComputations {
|
||||
leaves: mut leaf_subtrees,
|
||||
nodes: initial_leaves,
|
||||
} = Self::sorted_pairs_to_leaves(entries);
|
||||
|
||||
for current_depth in (SUBTREE_DEPTH..=DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
|
||||
let (nodes, mut subtree_roots): (Vec<BTreeMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
|
||||
.into_par_iter()
|
||||
.map(|subtree| {
|
||||
debug_assert!(subtree.is_sorted());
|
||||
debug_assert!(!subtree.is_empty());
|
||||
|
||||
let (nodes, subtree_root) = build_subtree(subtree, DEPTH, current_depth);
|
||||
(nodes, subtree_root)
|
||||
})
|
||||
.unzip();
|
||||
|
||||
leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
|
||||
accumulated_nodes.extend(nodes.into_iter().flatten());
|
||||
|
||||
debug_assert!(!leaf_subtrees.is_empty());
|
||||
}
|
||||
(accumulated_nodes, initial_leaves)
|
||||
}
|
||||
}
|
||||
|
||||
// INNER NODE
|
||||
// ================================================================================================
|
||||
|
||||
/// This struct is public so functions returning it can be used in `benches/`, but is otherwise not
|
||||
/// part of the public API.
|
||||
#[doc(hidden)]
|
||||
#[derive(Debug, Default, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub(crate) struct InnerNode {
|
||||
pub struct InnerNode {
|
||||
pub left: RpoDigest,
|
||||
pub right: RpoDigest,
|
||||
}
|
||||
@@ -234,7 +549,7 @@ impl<const DEPTH: u8> TryFrom<NodeIndex> for LeafIndex<DEPTH> {
|
||||
|
||||
fn try_from(node_index: NodeIndex) -> Result<Self, Self::Error> {
|
||||
if node_index.depth() != DEPTH {
|
||||
return Err(MerkleError::InvalidDepth {
|
||||
return Err(MerkleError::InvalidNodeIndexDepth {
|
||||
expected: DEPTH,
|
||||
provided: node_index.depth(),
|
||||
});
|
||||
@@ -243,3 +558,244 @@ impl<const DEPTH: u8> TryFrom<NodeIndex> for LeafIndex<DEPTH> {
|
||||
Self::new(node_index.value())
|
||||
}
|
||||
}
|
||||
|
||||
// MUTATIONS
|
||||
// ================================================================================================
|
||||
|
||||
/// A change to an inner node of a [`SparseMerkleTree`] that hasn't yet been applied.
|
||||
/// [`MutationSet`] stores this type in relation to a [`NodeIndex`] to keep track of what changes
|
||||
/// need to occur at which node indices.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) enum NodeMutation {
|
||||
/// Corresponds to [`SparseMerkleTree::remove_inner_node()`].
|
||||
Removal,
|
||||
/// Corresponds to [`SparseMerkleTree::insert_inner_node()`].
|
||||
Addition(InnerNode),
|
||||
}
|
||||
|
||||
/// Represents a group of prospective mutations to a `SparseMerkleTree`, created by
|
||||
/// `SparseMerkleTree::compute_mutations()`, and that can be applied with
|
||||
/// `SparseMerkleTree::apply_mutations()`.
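///
/// # Example
/// A sketch (not from the original docs) of the intended validate-then-commit flow, shown
/// with hypothetical `smt`, `kv_pairs`, and `expected_root` values:
/// ```ignore
/// let mutations = smt.compute_mutations(kv_pairs);
/// if mutations.root() == expected_root {
///     smt.apply_mutations(mutations)?;
/// }
/// ```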
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct MutationSet<const DEPTH: u8, K, V> {
/// The root of the Merkle tree this MutationSet is for, recorded at the time
/// [`SparseMerkleTree::compute_mutations()`] was called. Exists to guard against applying
/// mutations to the wrong tree or applying stale mutations to a tree that has since changed.
old_root: RpoDigest,
/// The set of nodes that need to be removed or added. The "effective" node at an index is the
/// Merkle tree's existing node at that index, with the [`NodeMutation`] in this map at that
/// index overlaid, if any. Each [`NodeMutation::Addition`] corresponds to a
/// [`SparseMerkleTree::insert_inner_node()`] call, and each [`NodeMutation::Removal`]
/// corresponds to a [`SparseMerkleTree::remove_inner_node()`] call.
node_mutations: BTreeMap<NodeIndex, NodeMutation>,
/// The set of top-level key-value pairs we're prospectively adding to the tree, including
/// adding empty values. The "effective" value for a key is the value in this BTreeMap, falling
/// back to the existing value in the Merkle tree. Each entry corresponds to a
/// [`SparseMerkleTree::insert_value()`] call.
new_pairs: BTreeMap<K, V>,
/// The calculated root for the Merkle tree, given these mutations. Publicly retrievable with
/// [`MutationSet::root()`]. Corresponds to a [`SparseMerkleTree::set_root()`] call.
new_root: RpoDigest,
}

impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
/// Queries the root that was calculated during `SparseMerkleTree::compute_mutations()`. See
/// that method for more information.
pub fn root(&self) -> RpoDigest {
self.new_root
}
}

// SUBTREES
// ================================================================================================
/// A subtree is of depth 8.
const SUBTREE_DEPTH: u8 = 8;

/// A depth-8 subtree contains 256 "columns" that can possibly be occupied.
const COLS_PER_SUBTREE: u64 = u64::pow(2, SUBTREE_DEPTH as u32);

/// Helper struct for organizing the data we care about when computing Merkle subtrees.
///
/// Note that these represent "conceptual" leaves of some subtree, not necessarily
/// the leaf type for the sparse Merkle tree.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)]
pub struct SubtreeLeaf {
/// The 'value' field of [`NodeIndex`]. When computing a subtree, the depth is already known.
pub col: u64,
/// The hash of the node this `SubtreeLeaf` represents.
pub hash: RpoDigest,
}

/// Helper struct to organize the return value of [`SparseMerkleTree::sorted_pairs_to_leaves()`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct PairComputations<K, L> {
/// Literal leaves to be added to the sparse Merkle tree's internal mapping.
pub nodes: BTreeMap<K, L>,
/// "Conceptual" leaves that will be used for computations.
pub leaves: Vec<Vec<SubtreeLeaf>>,
}

// Derive requires `L` to impl Default, even though we don't actually need that.
impl<K, L> Default for PairComputations<K, L> {
fn default() -> Self {
Self {
nodes: Default::default(),
leaves: Default::default(),
}
}
}

#[derive(Debug)]
struct SubtreeLeavesIter<'s> {
leaves: core::iter::Peekable<alloc::vec::Drain<'s, SubtreeLeaf>>,
}
impl<'s> SubtreeLeavesIter<'s> {
fn from_leaves(leaves: &'s mut Vec<SubtreeLeaf>) -> Self {
// TODO: determine if there is any notable performance difference between taking a Vec,
// which may need flattening first, vs storing a `Box<dyn Iterator<Item = SubtreeLeaf>>`.
// The latter may have self-referential properties that are impossible to express in purely
// safe Rust.
Self { leaves: leaves.drain(..).peekable() }
}
}
impl core::iter::Iterator for SubtreeLeavesIter<'_> {
type Item = Vec<SubtreeLeaf>;

/// Each `next()` collects an entire subtree.
fn next(&mut self) -> Option<Vec<SubtreeLeaf>> {
let mut subtree: Vec<SubtreeLeaf> = Default::default();

let mut last_subtree_col = 0;

while let Some(leaf) = self.leaves.peek() {
last_subtree_col = u64::max(1, last_subtree_col);
let is_exact_multiple = Integer::is_multiple_of(&last_subtree_col, &COLS_PER_SUBTREE);
let next_subtree_col = if is_exact_multiple {
u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE)
} else {
last_subtree_col.next_multiple_of(COLS_PER_SUBTREE)
};

last_subtree_col = leaf.col;
if leaf.col < next_subtree_col {
subtree.push(self.leaves.next().unwrap());
} else if subtree.is_empty() {
continue;
} else {
break;
}
}

if subtree.is_empty() {
debug_assert!(self.leaves.peek().is_none());
return None;
}

Some(subtree)
}
}
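
// A sketch (not in the original source) of the grouping behavior described above:
// columns 0 and 255 fall in the first depth-8 subtree, while column 256 starts the
// next one. The `RpoDigest::default()` hashes are placeholders.
#[cfg(test)]
#[test]
fn subtree_leaves_iter_grouping_sketch() {
    let mut leaves = alloc::vec![
        SubtreeLeaf { col: 0, hash: RpoDigest::default() },
        SubtreeLeaf { col: 255, hash: RpoDigest::default() },
        SubtreeLeaf { col: 256, hash: RpoDigest::default() },
    ];
    let grouped: Vec<Vec<SubtreeLeaf>> = SubtreeLeavesIter::from_leaves(&mut leaves).collect();
    assert_eq!(grouped.len(), 2);
    assert_eq!(grouped[0].len(), 2); // cols 0 and 255
    assert_eq!(grouped[1].len(), 1); // col 256
}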

// HELPER FUNCTIONS
// ================================================================================================

/// Builds Merkle nodes from a bottom layer of "leaves" -- represented by a horizontal index and
/// the hash of the leaf at that index. `leaves` *must* be sorted by horizontal index, and
/// `leaves` must not contain more than one depth-8 subtree's worth of leaves.
///
/// This function will then calculate the inner nodes above each leaf for 8 layers, as well as
/// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into
/// itself.
///
/// # Panics
/// With debug assertions on, this function panics under invalid inputs: if `leaves` contains
/// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to
/// different depth-8 subtrees, if `bottom_depth` is lower in the tree than the specified
/// maximum depth (`DEPTH`), or if `leaves` is not sorted.
fn build_subtree(
mut leaves: Vec<SubtreeLeaf>,
tree_depth: u8,
bottom_depth: u8,
) -> (BTreeMap<NodeIndex, InnerNode>, SubtreeLeaf) {
debug_assert!(bottom_depth <= tree_depth);
debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH));
debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32));
let subtree_root = bottom_depth - SUBTREE_DEPTH;
let mut inner_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
let mut next_leaves: Vec<SubtreeLeaf> = Vec::with_capacity(leaves.len() / 2);
for next_depth in (subtree_root..bottom_depth).rev() {
debug_assert!(next_depth <= bottom_depth);
// `next_depth` is the stuff we're making.
// `current_depth` is the stuff we have.
let current_depth = next_depth + 1;
let mut iter = leaves.drain(..).peekable();
while let Some(first) = iter.next() {
// On discontinuous iterations (including the very first iteration), `first` may be
// either a left node or a right node. On subsequent contiguous iterations, we will
// always call `iter.next()` twice, consuming a left-right pair per pass.
let is_right = first.col.is_odd();
let (left, right) = if is_right {
// Discontinuous iteration: we have no left node, so it must be empty.
let left = SubtreeLeaf {
col: first.col - 1,
hash: *EmptySubtreeRoots::entry(tree_depth, current_depth),
};
let right = first;
(left, right)
} else {
let left = first;
let right_col = first.col + 1;
let right = match iter.peek().copied() {
Some(SubtreeLeaf { col, .. }) if col == right_col => {
// Our inputs must be sorted.
debug_assert!(left.col <= col);
// The next leaf in the iterator is our sibling. Use it and consume it!
iter.next().unwrap()
},
// Otherwise, the leaves don't contain our sibling, so our sibling must be
// empty.
_ => SubtreeLeaf {
col: right_col,
hash: *EmptySubtreeRoots::entry(tree_depth, current_depth),
},
};
(left, right)
};
let index = NodeIndex::new_unchecked(current_depth, left.col).parent();
let node = InnerNode { left: left.hash, right: right.hash };
let hash = node.hash();
let &equivalent_empty_hash = EmptySubtreeRoots::entry(tree_depth, next_depth);
// If this hash is empty, then it doesn't become a new inner node, nor does it count
// as a leaf for the next depth.
if hash != equivalent_empty_hash {
inner_nodes.insert(index, node);
next_leaves.push(SubtreeLeaf { col: index.value(), hash });
}
}
// Stop borrowing `leaves`, so we can swap it.
// The iterator is empty at this point anyway.
drop(iter);
// After each depth, consider the stuff we just made the new "leaves", and empty the
// other collection.
mem::swap(&mut leaves, &mut next_leaves);
}
debug_assert_eq!(leaves.len(), 1);
let root = leaves.pop().unwrap();
(inner_nodes, root)
}
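
// A sketch (not in the original source) of the chaining the docs above describe:
// collapse the bottom 8 layers, then feed the returned subtree root back in as the
// sole "leaf" of the next 8-layer pass. The 16-deep tree and zeroed leaf hash are
// hypothetical; this assumes `EmptySubtreeRoots` supports depth-16 trees and that
// the intermediate hashes are non-empty.
#[cfg(test)]
#[test]
fn build_subtree_chaining_sketch() {
    let leaves = alloc::vec![SubtreeLeaf { col: 3, hash: RpoDigest::default() }];
    // First pass: collapse depths 16 -> 8; the returned leaf lives at depth 8.
    let (_nodes, mid) = build_subtree(leaves, 16, 16);
    assert_eq!(mid.col, 0); // 3 >> 8 == 0
    // Second pass: collapse depths 8 -> 0; the returned leaf is the tree root.
    let (_nodes, root) = build_subtree(alloc::vec![mid], 16, 8);
    assert_eq!(root.col, 0);
}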

#[cfg(feature = "internal")]
pub fn build_subtree_for_bench(
leaves: Vec<SubtreeLeaf>,
tree_depth: u8,
bottom_depth: u8,
) -> (BTreeMap<NodeIndex, InnerNode>, SubtreeLeaf) {
build_subtree(leaves, tree_depth, bottom_depth)
}

// TESTS
// ================================================================================================
#[cfg(test)]
mod tests;

@@ -1,9 +1,13 @@
use alloc::{
collections::{BTreeMap, BTreeSet},
vec::Vec,
};

use super::{
super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, LeafIndex, MerkleError,
-MerklePath, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, SMT_MAX_DEPTH,
-SMT_MIN_DEPTH,
+MerklePath, MutationSet, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
+SMT_MAX_DEPTH, SMT_MIN_DEPTH,
};
use crate::utils::collections::*;

#[cfg(test)]
mod tests;
@@ -80,7 +84,7 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {

for (idx, (key, value)) in entries.into_iter().enumerate() {
if idx >= max_num_entries {
-return Err(MerkleError::InvalidNumEntries(max_num_entries));
+return Err(MerkleError::TooManyEntries(max_num_entries));
}

let old_value = tree.insert(LeafIndex::<DEPTH>::new(key)?, value);
@@ -96,6 +100,23 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
Ok(tree)
}

/// Returns a new [`SimpleSmt`] instantiated from already computed leaves and nodes.
///
/// This function performs minimal consistency checking. It is the caller's responsibility to
/// ensure the passed arguments are correct and consistent with each other.
///
/// # Panics
/// With debug assertions on, this function panics if `root` does not match the root node in
/// `inner_nodes`.
pub fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, Word>,
root: RpoDigest,
) -> Self {
// Our particular implementation of `from_raw_parts()` never returns `Err`.
<Self as SparseMerkleTree<DEPTH>>::from_raw_parts(inner_nodes, leaves, root).unwrap()
}

/// Wrapper around [`SimpleSmt::with_leaves`] which inserts leaves at contiguous indices
/// starting at index 0.
pub fn with_contiguous_leaves(
@@ -122,6 +143,11 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
<Self as SparseMerkleTree<DEPTH>>::root(self)
}

/// Returns the number of non-empty leaves in this tree.
pub fn num_leaves(&self) -> usize {
self.leaves.len()
}

/// Returns the leaf at the specified index.
pub fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
<Self as SparseMerkleTree<DEPTH>>::get_leaf(self, key)
@@ -152,6 +178,12 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
<Self as SparseMerkleTree<DEPTH>>::open(self, key)
}

/// Returns a boolean value indicating whether the SMT is empty.
pub fn is_empty(&self) -> bool {
debug_assert_eq!(self.leaves.is_empty(), self.root == Self::EMPTY_ROOT);
self.root == Self::EMPTY_ROOT
}

// ITERATORS
// --------------------------------------------------------------------------------------------

@@ -182,6 +214,48 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
<Self as SparseMerkleTree<DEPTH>>::insert(self, key, value)
}

/// Computes what changes are necessary to insert the specified key-value pairs into this
/// Merkle tree, allowing for validation before applying those changes.
///
/// This method returns a [`MutationSet`], which contains all the information for inserting
/// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
/// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
/// [`SimpleSmt::apply_mutations()`] can be called in order to commit these changes to the
/// Merkle tree, or [`drop()`] to discard them.
///
/// # Example
/// ```
/// # use miden_crypto::{hash::rpo::RpoDigest, Felt, Word};
/// # use miden_crypto::merkle::{LeafIndex, SimpleSmt, EmptySubtreeRoots, SMT_DEPTH};
/// let mut smt: SimpleSmt<3> = SimpleSmt::new().unwrap();
/// let pair = (LeafIndex::default(), Word::default());
/// let mutations = smt.compute_mutations(vec![pair]);
/// assert_eq!(mutations.root(), *EmptySubtreeRoots::entry(3, 0));
/// smt.apply_mutations(mutations);
/// assert_eq!(smt.root(), *EmptySubtreeRoots::entry(3, 0));
/// ```
pub fn compute_mutations(
&self,
kv_pairs: impl IntoIterator<Item = (LeafIndex<DEPTH>, Word)>,
) -> MutationSet<DEPTH, LeafIndex<DEPTH>, Word> {
<Self as SparseMerkleTree<DEPTH>>::compute_mutations(self, kv_pairs)
}

/// Apply the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to this
/// tree.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`] with a two-item [`alloc::vec::Vec`]. The first item is the
/// root hash the `mutations` were computed against, and the second item is the actual
/// current root of this tree.
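///
/// # Example
/// A sketch (not from the original docs) of handling that error; `smt` and `mutations`
/// are hypothetical:
/// ```ignore
/// match smt.apply_mutations(mutations) {
///     Ok(()) => { /* committed */ },
///     Err(MerkleError::ConflictingRoots(roots)) => {
///         // roots[0]: the root the mutations were computed against.
///         // roots[1]: this tree's actual current root.
///     },
///     Err(other) => return Err(other),
/// }
/// ```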
pub fn apply_mutations(
&mut self,
mutations: MutationSet<DEPTH, LeafIndex<DEPTH>, Word>,
) -> Result<(), MerkleError> {
<Self as SparseMerkleTree<DEPTH>>::apply_mutations(self, mutations)
}

/// Inserts a subtree at the specified index. The depth at which the subtree is inserted is
/// computed as `DEPTH - SUBTREE_DEPTH`.
///
@@ -192,7 +266,7 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
subtree: SimpleSmt<SUBTREE_DEPTH>,
) -> Result<RpoDigest, MerkleError> {
if SUBTREE_DEPTH > DEPTH {
-return Err(MerkleError::InvalidSubtreeDepth {
+return Err(MerkleError::SubtreeDepthExceedsDepth {
subtree_depth: SUBTREE_DEPTH,
tree_depth: DEPTH,
});
@@ -250,6 +324,20 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
type Opening = ValuePath;

const EMPTY_VALUE: Self::Value = EMPTY_WORD;
const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(DEPTH, 0);

fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, Word>,
root: RpoDigest,
) -> Result<Self, MerkleError> {
if cfg!(debug_assertions) {
let root_node = inner_nodes.get(&NodeIndex::root()).unwrap();
assert_eq!(root_node.hash(), root);
}

Ok(Self { root, inner_nodes, leaves })
}

fn root(&self) -> RpoDigest {
self.root
@@ -260,11 +348,10 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
}

fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
-self.inner_nodes.get(&index).cloned().unwrap_or_else(|| {
-let node = EmptySubtreeRoots::entry(DEPTH, index.depth() + 1);
-
-InnerNode { left: *node, right: *node }
-})
+self.inner_nodes
+.get(&index)
+.cloned()
+.unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(DEPTH, index.depth()))
}

fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
@@ -276,17 +363,22 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
}

fn insert_value(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Option<Word> {
if value == Self::EMPTY_VALUE {
self.leaves.remove(&key.value())
} else {
self.leaves.insert(key.value(), value)
}
}

fn get_value(&self, key: &LeafIndex<DEPTH>) -> Word {
self.get_leaf(key)
}

fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
// the lookup in empty_hashes could fail only if empty_hashes were not built correctly
// by the constructor as we check the depth of the lookup above.
let leaf_pos = key.value();

match self.leaves.get(&leaf_pos) {
Some(word) => *word,
-None => Word::from(*EmptySubtreeRoots::entry(DEPTH, DEPTH)),
+None => Self::EMPTY_VALUE,
}
}

@@ -295,6 +387,15 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
leaf.into()
}

fn construct_prospective_leaf(
&self,
_existing_leaf: Word,
_key: &LeafIndex<DEPTH>,
value: &Word,
) -> Word {
*value
}

fn key_to_leaf_index(key: &LeafIndex<DEPTH>) -> LeafIndex<DEPTH> {
*key
}
@@ -302,4 +403,11 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
fn path_and_leaf_to_opening(path: MerklePath, leaf: Word) -> ValuePath {
(path, leaf).into()
}

fn pairs_to_leaf(mut pairs: Vec<(LeafIndex<DEPTH>, Word)>) -> Word {
// SimpleSmt can't have more than one value per key.
assert_eq!(pairs.len(), 1);
let (_key, value) = pairs.pop().unwrap();
value
}
}

@@ -1,3 +1,7 @@
use alloc::vec::Vec;

use assert_matches::assert_matches;

use super::{
super::{MerkleError, RpoDigest, SimpleSmt},
NodeIndex,
@@ -8,7 +12,6 @@ use crate::{
digests_to_words, int_to_leaf, int_to_node, smt::SparseMerkleTree, EmptySubtreeRoots,
InnerNodeInfo, LeafIndex, MerkleTree,
},
-utils::collections::*,
Word, EMPTY_WORD,
};

@@ -50,6 +53,8 @@ fn build_sparse_tree() {
let mut smt = SimpleSmt::<DEPTH>::new().unwrap();
let mut values = ZERO_VALUES8.to_vec();

assert_eq!(smt.num_leaves(), 0);

// insert single value
let key = 6;
let new_node = int_to_leaf(7);
@@ -62,6 +67,7 @@ fn build_sparse_tree() {
smt.open(&LeafIndex::<3>::new(6).unwrap()).path
);
assert_eq!(old_value, EMPTY_WORD);
assert_eq!(smt.num_leaves(), 1);

// insert second value at distinct leaf branch
let key = 2;
@@ -75,6 +81,7 @@ fn build_sparse_tree() {
smt.open(&LeafIndex::<3>::new(2).unwrap()).path
);
assert_eq!(old_value, EMPTY_WORD);
assert_eq!(smt.num_leaves(), 2);
}

/// Tests that [`SimpleSmt::with_contiguous_leaves`] works as expected
@@ -146,10 +153,11 @@ fn test_inner_node_iterator() -> Result<(), MerkleError> {
}

#[test]
-fn update_leaf() {
+fn test_insert() {
const DEPTH: u8 = 3;
let mut tree =
SimpleSmt::<DEPTH>::with_leaves(KEYS8.into_iter().zip(digests_to_words(&VALUES8))).unwrap();
assert_eq!(tree.num_leaves(), 8);

// update one value
let key = 3;
@@ -161,6 +169,7 @@ fn update_leaf() {
let old_leaf = tree.insert(LeafIndex::<DEPTH>::new(key as u64).unwrap(), new_node);
assert_eq!(expected_tree.root(), tree.root);
assert_eq!(old_leaf, *VALUES8[key]);
assert_eq!(tree.num_leaves(), 8);

// update another value
let key = 6;
@@ -171,6 +180,18 @@ fn update_leaf() {
let old_leaf = tree.insert(LeafIndex::<DEPTH>::new(key as u64).unwrap(), new_node);
assert_eq!(expected_tree.root(), tree.root);
assert_eq!(old_leaf, *VALUES8[key]);
assert_eq!(tree.num_leaves(), 8);

// set a leaf to empty value
let key = 5;
let new_node = EMPTY_WORD;
expected_values[key] = new_node;
let expected_tree = MerkleTree::new(expected_values.clone()).unwrap();

let old_leaf = tree.insert(LeafIndex::<DEPTH>::new(key as u64).unwrap(), new_node);
assert_eq!(expected_tree.root(), tree.root);
assert_eq!(old_leaf, *VALUES8[key]);
assert_eq!(tree.num_leaves(), 7);
}

#[test]
@@ -238,12 +259,12 @@ fn test_simplesmt_fail_on_duplicates() {
// consecutive
let entries = [(1, *first), (1, *second)];
let smt = SimpleSmt::<64>::with_leaves(entries);
-assert_eq!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));
+assert_matches!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));

// not consecutive
let entries = [(1, *first), (5, int_to_leaf(5)), (1, *second)];
let smt = SimpleSmt::<64>::with_leaves(entries);
-assert_eq!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));
+assert_matches!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));
}
}

@@ -425,6 +446,23 @@ fn test_simplesmt_set_subtree_entire_tree() {
assert_eq!(tree.root(), *EmptySubtreeRoots::entry(DEPTH, 0));
}

/// Tests that the `EMPTY_ROOT` constant generated in `SimpleSmt` equals the root of the empty
/// tree of depth 64
#[test]
fn test_simplesmt_check_empty_root_constant() {
// get the root of the empty tree of depth 64
let empty_root_64_depth = EmptySubtreeRoots::empty_hashes(64)[0];
assert_eq!(empty_root_64_depth, SimpleSmt::<64>::EMPTY_ROOT);

// get the root of the empty tree of depth 32
let empty_root_32_depth = EmptySubtreeRoots::empty_hashes(32)[0];
assert_eq!(empty_root_32_depth, SimpleSmt::<32>::EMPTY_ROOT);

// get the root of the empty tree of depth 1
let empty_root_1_depth = EmptySubtreeRoots::empty_hashes(1)[0];
assert_eq!(empty_root_1_depth, SimpleSmt::<1>::EMPTY_ROOT);
}

// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------

417
src/merkle/smt/tests.rs
Normal file
@@ -0,0 +1,417 @@
use alloc::{collections::BTreeMap, vec::Vec};

use super::{
build_subtree, InnerNode, LeafIndex, NodeIndex, PairComputations, SmtLeaf, SparseMerkleTree,
SubtreeLeaf, SubtreeLeavesIter, COLS_PER_SUBTREE, SUBTREE_DEPTH,
};
use crate::{
hash::rpo::RpoDigest,
merkle::{Smt, SMT_DEPTH},
Felt, Word, ONE,
};

fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf {
SubtreeLeaf {
col: leaf.index().index.value(),
hash: leaf.hash(),
}
}

#[test]
fn test_sorted_pairs_to_leaves() {
let entries: Vec<(RpoDigest, Word)> = vec![
// Subtree 0.
(RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]), [ONE; 4]),
(RpoDigest::new([ONE, ONE, ONE, Felt::new(17)]), [ONE; 4]),
// Leaf index collision.
(RpoDigest::new([ONE, ONE, Felt::new(10), Felt::new(20)]), [ONE; 4]),
(RpoDigest::new([ONE, ONE, Felt::new(20), Felt::new(20)]), [ONE; 4]),
// Subtree 1. Normal single leaf again.
(RpoDigest::new([ONE, ONE, ONE, Felt::new(400)]), [ONE; 4]), // Subtree boundary.
(RpoDigest::new([ONE, ONE, ONE, Felt::new(401)]), [ONE; 4]),
// Subtree 2. Another normal leaf.
(RpoDigest::new([ONE, ONE, ONE, Felt::new(1024)]), [ONE; 4]),
];

let control = Smt::with_entries_sequential(entries.clone()).unwrap();

let control_leaves: Vec<SmtLeaf> = {
let mut entries_iter = entries.iter().cloned();
let mut next_entry = || entries_iter.next().unwrap();
let control_leaves = vec![
// Subtree 0.
SmtLeaf::Single(next_entry()),
SmtLeaf::Single(next_entry()),
SmtLeaf::new_multiple(vec![next_entry(), next_entry()]).unwrap(),
// Subtree 1.
SmtLeaf::Single(next_entry()),
SmtLeaf::Single(next_entry()),
// Subtree 2.
SmtLeaf::Single(next_entry()),
];
assert_eq!(entries_iter.next(), None);
control_leaves
};

let control_subtree_leaves: Vec<Vec<SubtreeLeaf>> = {
let mut control_leaves_iter = control_leaves.iter();
let mut next_leaf = || control_leaves_iter.next().unwrap();

let control_subtree_leaves: Vec<Vec<SubtreeLeaf>> = [
// Subtree 0.
vec![next_leaf(), next_leaf(), next_leaf()],
// Subtree 1.
vec![next_leaf(), next_leaf()],
// Subtree 2.
vec![next_leaf()],
]
.map(|subtree| subtree.into_iter().map(smtleaf_to_subtree_leaf).collect())
.to_vec();
assert_eq!(control_leaves_iter.next(), None);
control_subtree_leaves
};

let subtrees: PairComputations<u64, SmtLeaf> = Smt::sorted_pairs_to_leaves(entries);
// This will check that the hashes, columns, and subtree assignments all match.
assert_eq!(subtrees.leaves, control_subtree_leaves);

// Flattening and re-separating out the leaves into subtrees should have the same result.
let mut all_leaves: Vec<SubtreeLeaf> = subtrees.leaves.clone().into_iter().flatten().collect();
let re_grouped: Vec<Vec<_>> = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect();
assert_eq!(subtrees.leaves, re_grouped);

// Then finally we might as well check the computed leaf nodes too.
let control_leaves: BTreeMap<u64, SmtLeaf> = control
.leaves()
.map(|(index, value)| (index.index.value(), value.clone()))
.collect();

for (column, test_leaf) in subtrees.nodes {
if test_leaf.is_empty() {
continue;
}
let control_leaf = control_leaves
.get(&column)
.unwrap_or_else(|| panic!("no leaf node found for column {column}"));
assert_eq!(control_leaf, &test_leaf);
}
}

// Helper for the below tests.
fn generate_entries(pair_count: u64) -> Vec<(RpoDigest, Word)> {
(0..pair_count)
.map(|i| {
let leaf_index = ((i as f64 / pair_count as f64) * (pair_count as f64)) as u64;
let key = RpoDigest::new([ONE, ONE, Felt::new(i), Felt::new(leaf_index)]);
let value = [ONE, ONE, ONE, Felt::new(i)];
(key, value)
})
.collect()
}

#[test]
fn test_single_subtree() {
// A single subtree's worth of leaves.
const PAIR_COUNT: u64 = COLS_PER_SUBTREE;

let entries = generate_entries(PAIR_COUNT);

let control = Smt::with_entries_sequential(entries.clone()).unwrap();

// `entries` should already be sorted by nature of how we constructed it.
let leaves = Smt::sorted_pairs_to_leaves(entries).leaves;
let leaves = leaves.into_iter().next().unwrap();

let (first_subtree, subtree_root) = build_subtree(leaves, SMT_DEPTH, SMT_DEPTH);
assert!(!first_subtree.is_empty());

// The inner nodes computed from that subtree should match the nodes in our control tree.
for (index, node) in first_subtree.into_iter() {
let control = control.get_inner_node(index);
assert_eq!(
control, node,
"subtree-computed node at index {index:?} does not match control",
);
}

// The root returned should also match the equivalent node in the control tree.
let control_root_index =
NodeIndex::new(SMT_DEPTH - SUBTREE_DEPTH, subtree_root.col).expect("Valid root index");
let control_root_node = control.get_inner_node(control_root_index);
let control_hash = control_root_node.hash();
assert_eq!(
control_hash, subtree_root.hash,
"Subtree-computed root at index {control_root_index:?} does not match control"
);
}

// Test not just that we can compute a subtree correctly, but that we can feed the results of one
// subtree into computing another. In other words, test that `build_subtree()` is correctly
// composable.
#[test]
fn test_two_subtrees() {
// Two subtrees' worth of leaves.
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 2;

let entries = generate_entries(PAIR_COUNT);

let control = Smt::with_entries_sequential(entries.clone()).unwrap();

let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries);
// With two subtrees' worth of leaves, we should have exactly two subtrees.
let [first, second]: [Vec<_>; 2] = leaves.try_into().unwrap();
assert_eq!(first.len() as u64, PAIR_COUNT / 2);
assert_eq!(first.len(), second.len());

let mut current_depth = SMT_DEPTH;
let mut next_leaves: Vec<SubtreeLeaf> = Default::default();

let (first_nodes, first_root) = build_subtree(first, SMT_DEPTH, current_depth);
next_leaves.push(first_root);

let (second_nodes, second_root) = build_subtree(second, SMT_DEPTH, current_depth);
next_leaves.push(second_root);

// All new inner nodes + the new subtree-leaves should be 512, for one depth-cycle.
let total_computed = first_nodes.len() + second_nodes.len() + next_leaves.len();
assert_eq!(total_computed as u64, PAIR_COUNT);

// Verify the computed nodes of both subtrees.
let computed_nodes = first_nodes.clone().into_iter().chain(second_nodes);
for (index, test_node) in computed_nodes {
let control_node = control.get_inner_node(index);
assert_eq!(
control_node, test_node,
"subtree-computed node at index {index:?} does not match control",
);
}

current_depth -= SUBTREE_DEPTH;

let (nodes, root_leaf) = build_subtree(next_leaves, SMT_DEPTH, current_depth);
assert_eq!(nodes.len(), SUBTREE_DEPTH as usize);
assert_eq!(root_leaf.col, 0);

for (index, test_node) in nodes {
let control_node = control.get_inner_node(index);
assert_eq!(
control_node, test_node,
"subtree-computed node at index {index:?} does not match control",
);
}

let index = NodeIndex::new(current_depth - SUBTREE_DEPTH, root_leaf.col).unwrap();
let control_root = control.get_inner_node(index).hash();
assert_eq!(control_root, root_leaf.hash, "Root mismatch");
}

#[test]
fn test_singlethreaded_subtrees() {
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;

let entries = generate_entries(PAIR_COUNT);

let control = Smt::with_entries_sequential(entries.clone()).unwrap();

let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();

let PairComputations {
leaves: mut leaf_subtrees,
nodes: test_leaves,
} = Smt::sorted_pairs_to_leaves(entries);

for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
// There's no flat_map_unzip(), so this is the best we can do.
let (nodes, mut subtree_roots): (Vec<BTreeMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
.into_iter()
.enumerate()
.map(|(i, subtree)| {
// Pre-assertions.
assert!(
subtree.is_sorted(),
"subtree {i} at bottom-depth {current_depth} is not sorted",
);
assert!(
!subtree.is_empty(),
"subtree {i} at bottom-depth {current_depth} is empty!",
);

// Do actual things.
let (nodes, subtree_root) = build_subtree(subtree, SMT_DEPTH, current_depth);

// Post-assertions.
for (&index, test_node) in nodes.iter() {
let control_node = control.get_inner_node(index);
assert_eq!(
test_node, &control_node,
"depth {} subtree {}: test node does not match control at index {:?}",
current_depth, i, index,
);
}

(nodes, subtree_root)
})
.unzip();

// Update state between each depth iteration.

leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
accumulated_nodes.extend(nodes.into_iter().flatten());

assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}");
}

// Make sure the true leaves match, first checking length and then checking each individual
// leaf.
let control_leaves: BTreeMap<_, _> = control.leaves().collect();
let control_leaves_len = control_leaves.len();
let test_leaves_len = test_leaves.len();
assert_eq!(test_leaves_len, control_leaves_len);
for (col, ref test_leaf) in test_leaves {
let index = LeafIndex::new_max_depth(col);
let &control_leaf = control_leaves.get(&index).unwrap();
assert_eq!(test_leaf, control_leaf, "test leaf at column {col} does not match control");
}

// Make sure the inner nodes match, checking length first and then each individual leaf.
let control_nodes_len = control.inner_nodes().count();
let test_nodes_len = accumulated_nodes.len();
assert_eq!(test_nodes_len, control_nodes_len);
for (index, test_node) in accumulated_nodes.clone() {
let control_node = control.get_inner_node(index);
assert_eq!(test_node, control_node, "test node does not match control at {index:?}");
}

// After the last iteration of the above for loop, we should have the new root node actually
// in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from
// `build_subtree()`. So let's check both!

let control_root = control.get_inner_node(NodeIndex::root());

// That for loop should have left us with only one leaf subtree...
let [leaf_subtree]: [Vec<_>; 1] = leaf_subtrees.try_into().unwrap();
// which itself contains only one 'leaf'...
let [root_leaf]: [SubtreeLeaf; 1] = leaf_subtree.try_into().unwrap();
// which matches the expected root.
assert_eq!(control.root(), root_leaf.hash);

// Likewise `accumulated_nodes` should contain a node at the root index...
assert!(accumulated_nodes.contains_key(&NodeIndex::root()));
// and it should match our actual root.
let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap();
assert_eq!(control_root, *test_root);
// And of course the root we got from each place should match.
assert_eq!(control.root(), root_leaf.hash);
}

/// The parallel version of `test_singlethreaded_subtrees()`.
#[test]
#[cfg(feature = "concurrent")]
fn test_multithreaded_subtrees() {
use rayon::prelude::*;

const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;

let entries = generate_entries(PAIR_COUNT);

let control = Smt::with_entries_sequential(entries.clone()).unwrap();

let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();

let PairComputations {
leaves: mut leaf_subtrees,
nodes: test_leaves,
} = Smt::sorted_pairs_to_leaves(entries);

for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
let (nodes, mut subtree_roots): (Vec<BTreeMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
.into_par_iter()
.enumerate()
.map(|(i, subtree)| {
// Pre-assertions.
assert!(
subtree.is_sorted(),
"subtree {i} at bottom-depth {current_depth} is not sorted",
);
assert!(
!subtree.is_empty(),
"subtree {i} at bottom-depth {current_depth} is empty!",
);

let (nodes, subtree_root) = build_subtree(subtree, SMT_DEPTH, current_depth);

// Post-assertions.
for (&index, test_node) in nodes.iter() {
let control_node = control.get_inner_node(index);
assert_eq!(
test_node, &control_node,
"depth {} subtree {}: test node does not match control at index {:?}",
current_depth, i, index,
);
}

(nodes, subtree_root)
})
.unzip();

leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
accumulated_nodes.extend(nodes.into_iter().flatten());

assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}");
}

// Make sure the true leaves match, checking length first and then each individual leaf.
let control_leaves: BTreeMap<_, _> = control.leaves().collect();
let control_leaves_len = control_leaves.len();
let test_leaves_len = test_leaves.len();
assert_eq!(test_leaves_len, control_leaves_len);
for (col, ref test_leaf) in test_leaves {
let index = LeafIndex::new_max_depth(col);
let &control_leaf = control_leaves.get(&index).unwrap();
assert_eq!(test_leaf, control_leaf);
}

// Make sure the inner nodes match, checking length first and then each individual leaf.
let control_nodes_len = control.inner_nodes().count();
let test_nodes_len = accumulated_nodes.len();
assert_eq!(test_nodes_len, control_nodes_len);
for (index, test_node) in accumulated_nodes.clone() {
let control_node = control.get_inner_node(index);
assert_eq!(test_node, control_node, "test node does not match control at {index:?}");
}

// After the last iteration of the above for loop, we should have the new root node actually
// in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from
// `build_subtree()`. So let's check both!

let control_root = control.get_inner_node(NodeIndex::root());

// That for loop should have left us with only one leaf subtree...
let [leaf_subtree]: [_; 1] = leaf_subtrees.try_into().unwrap();
// which itself contains only one 'leaf'...
let [root_leaf]: [_; 1] = leaf_subtree.try_into().unwrap();
// which matches the expected root.
assert_eq!(control.root(), root_leaf.hash);

// Likewise `accumulated_nodes` should contain a node at the root index...
assert!(accumulated_nodes.contains_key(&NodeIndex::root()));
// and it should match our actual root.
let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap();
assert_eq!(control_root, *test_root);
// And of course the root we got from each place should match.
assert_eq!(control.root(), root_leaf.hash);
}

#[test]
#[cfg(feature = "concurrent")]
fn test_with_entries_parallel() {
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;

let entries = generate_entries(PAIR_COUNT);

let control = Smt::with_entries_sequential(entries.clone()).unwrap();

let smt = Smt::with_entries(entries.clone()).unwrap();
assert_eq!(smt.root(), control.root());
assert_eq!(smt, control);
}
@@ -1,3 +1,4 @@
use alloc::{collections::BTreeMap, vec::Vec};
+use core::borrow::Borrow;

use super::{
@@ -5,7 +6,8 @@ use super::{
PartialMerkleTree, RootPath, Rpo256, RpoDigest, SimpleSmt, Smt, ValuePath,
};
use crate::utils::{
-collections::*, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
+collections::{KvMap, RecordingMap},
+ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
};

#[cfg(test)]
@@ -125,8 +127,8 @@ impl<T: KvMap<RpoDigest, StoreNode>> MerkleStore<T> {
/// # Errors
/// This method can return the following errors:
/// - `RootNotInStore` if the `root` is not present in the store.
-/// - `NodeNotInStore` if a node needed to traverse from `root` to `index` is not present in
-///   the store.
+/// - `NodeNotInStore` if a node needed to traverse from `root` to `index` is not present in the
+///   store.
pub fn get_node(&self, root: RpoDigest, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
let mut hash = root;

@@ -134,7 +136,10 @@ impl<T: KvMap<RpoDigest, StoreNode>> MerkleStore<T> {
self.nodes.get(&hash).ok_or(MerkleError::RootNotInStore(hash))?;

for i in (0..index.depth()).rev() {
-let node = self.nodes.get(&hash).ok_or(MerkleError::NodeNotInStore(hash, index))?;
+let node = self
+.nodes
+.get(&hash)
+.ok_or(MerkleError::NodeIndexNotFoundInStore(hash, index))?;

let bit = (index.value() >> i) & 1;
hash = if bit == 0 { node.left } else { node.right }
@@ -150,8 +155,8 @@ impl<T: KvMap<RpoDigest, StoreNode>> MerkleStore<T> {
/// # Errors
/// This method can return the following errors:
/// - `RootNotInStore` if the `root` is not present in the store.
-/// - `NodeNotInStore` if a node needed to traverse from `root` to `index` is not present in
-///   the store.
+/// - `NodeNotInStore` if a node needed to traverse from `root` to `index` is not present in the
+///   store.
pub fn get_path(&self, root: RpoDigest, index: NodeIndex) -> Result<ValuePath, MerkleError> {
let mut hash = root;
let mut path = Vec::with_capacity(index.depth().into());
@@ -160,7 +165,10 @@ impl<T: KvMap<RpoDigest, StoreNode>> MerkleStore<T> {
self.nodes.get(&hash).ok_or(MerkleError::RootNotInStore(hash))?;

for i in (0..index.depth()).rev() {
-let node = self.nodes.get(&hash).ok_or(MerkleError::NodeNotInStore(hash, index))?;
+let node = self
+.nodes
+.get(&hash)
+.ok_or(MerkleError::NodeIndexNotFoundInStore(hash, index))?;

let bit = (index.value() >> i) & 1;
hash = if bit == 0 {
@@ -419,8 +427,8 @@ impl<T: KvMap<RpoDigest, StoreNode>> MerkleStore<T> {
/// # Errors
/// This method can return the following errors:
/// - `RootNotInStore` if the `root` is not present in the store.
-/// - `NodeNotInStore` if a node needed to traverse from `root` to `index` is not present in
-///   the store.
+/// - `NodeNotInStore` if a node needed to traverse from `root` to `index` is not present in the
+///   store.
pub fn set_node(
&mut self,
mut root: RpoDigest,

@@ -1,4 +1,11 @@
use assert_matches::assert_matches;
use seq_macro::seq;
#[cfg(feature = "std")]
use {
super::{Deserializable, Serializable},
alloc::boxed::Box,
std::error::Error,
};

use super::{
DefaultMerkleStore as MerkleStore, EmptySubtreeRoots, MerkleError, MerklePath, NodeIndex,
@@ -11,12 +18,6 @@ use crate::{
Felt, Word, ONE, WORD_SIZE, ZERO,
};

-#[cfg(feature = "std")]
-use {
-super::{Deserializable, Serializable},
-std::error::Error,
-};

// TEST DATA
// ================================================================================================

@@ -42,14 +43,14 @@ const VALUES8: [RpoDigest; 8] = [
fn test_root_not_in_store() -> Result<(), MerkleError> {
let mtree = MerkleTree::new(digests_to_words(&VALUES4))?;
let store = MerkleStore::from(&mtree);
-assert_eq!(
+assert_matches!(
store.get_node(VALUES4[0], NodeIndex::make(mtree.depth(), 0)),
-Err(MerkleError::RootNotInStore(VALUES4[0])),
+Err(MerkleError::RootNotInStore(root)) if root == VALUES4[0],
"Leaf 0 is not a root"
);
-assert_eq!(
+assert_matches!(
store.get_path(VALUES4[0], NodeIndex::make(mtree.depth(), 0)),
-Err(MerkleError::RootNotInStore(VALUES4[0])),
+Err(MerkleError::RootNotInStore(root)) if root == VALUES4[0],
"Leaf 0 is not a root"
);
|
||||
|
||||
@@ -64,46 +65,46 @@ fn test_merkle_tree() -> Result<(), MerkleError> {
|
||||
// STORE LEAVES ARE CORRECT -------------------------------------------------------------------
|
||||
// checks the leaves in the store corresponds to the expected values
|
||||
assert_eq!(
|
||||
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 0)),
|
||||
Ok(VALUES4[0]),
|
||||
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 0)).unwrap(),
|
||||
VALUES4[0],
|
||||
"node 0 must be in the tree"
|
||||
);
|
||||
assert_eq!(
|
||||
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 1)),
|
||||
Ok(VALUES4[1]),
|
||||
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 1)).unwrap(),
|
||||
VALUES4[1],
|
||||
"node 1 must be in the tree"
|
||||
);
|
||||
assert_eq!(
|
||||
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 2)),
|
||||
Ok(VALUES4[2]),
|
||||
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 2)).unwrap(),
|
||||
VALUES4[2],
|
||||
"node 2 must be in the tree"
|
||||
);
|
||||
assert_eq!(
|
||||
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 3)),
|
||||
Ok(VALUES4[3]),
|
||||
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 3)).unwrap(),
|
||||
VALUES4[3],
|
||||
"node 3 must be in the tree"
|
||||
);
|
||||
|
||||
// STORE LEAVES MATCH TREE --------------------------------------------------------------------
|
||||
// sanity check the values returned by the store and the tree
|
||||
assert_eq!(
|
||||
mtree.get_node(NodeIndex::make(mtree.depth(), 0)),
|
||||
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 0)),
|
||||
mtree.get_node(NodeIndex::make(mtree.depth(), 0)).unwrap(),
|
||||
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 0)).unwrap(),
|
||||
"node 0 must be the same for both MerkleTree and MerkleStore"
|
||||
);
|
||||
assert_eq!(
|
||||
mtree.get_node(NodeIndex::make(mtree.depth(), 1)),
|
||||
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 1)),
|
||||
mtree.get_node(NodeIndex::make(mtree.depth(), 1)).unwrap(),
|
||||
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 1)).unwrap(),
|
||||
"node 1 must be the same for both MerkleTree and MerkleStore"
|
||||
);
|
||||
assert_eq!(
|
||||
mtree.get_node(NodeIndex::make(mtree.depth(), 2)),
|
||||
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 2)),
|
||||
mtree.get_node(NodeIndex::make(mtree.depth(), 2)).unwrap(),
|
||||
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 2)).unwrap(),
|
||||
"node 2 must be the same for both MerkleTree and MerkleStore"
|
||||
);
|
||||
assert_eq!(
|
||||
mtree.get_node(NodeIndex::make(mtree.depth(), 3)),
|
||||
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 3)),
|
||||
mtree.get_node(NodeIndex::make(mtree.depth(), 3)).unwrap(),
|
||||
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 3)).unwrap(),
|
||||
"node 3 must be the same for both MerkleTree and MerkleStore"
|
||||
);
|
||||
|
||||
@@ -115,8 +116,8 @@ fn test_merkle_tree() -> Result<(), MerkleError> {
|
||||
"Value for merkle path at index 0 must match leaf value"
|
||||
);
|
||||
assert_eq!(
|
||||
mtree.get_path(NodeIndex::make(mtree.depth(), 0)),
|
||||
Ok(result.path),
|
||||
mtree.get_path(NodeIndex::make(mtree.depth(), 0)).unwrap(),
|
||||
result.path,
|
||||
"merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
|
||||
);
|
||||
|
||||
@@ -126,8 +127,8 @@ fn test_merkle_tree() -> Result<(), MerkleError> {
|
||||
"Value for merkle path at index 0 must match leaf value"
|
||||
);
|
||||
assert_eq!(
|
||||
mtree.get_path(NodeIndex::make(mtree.depth(), 1)),
|
||||
Ok(result.path),
|
||||
mtree.get_path(NodeIndex::make(mtree.depth(), 1)).unwrap(),
|
||||
result.path,
|
||||
"merkle path for index 1 must be the same for the MerkleTree and MerkleStore"
|
||||
);
|
||||
|
||||
@@ -137,8 +138,8 @@ fn test_merkle_tree() -> Result<(), MerkleError> {
|
||||
"Value for merkle path at index 0 must match leaf value"
|
||||
);
|
||||
assert_eq!(
|
||||
mtree.get_path(NodeIndex::make(mtree.depth(), 2)),
|
||||
Ok(result.path),
|
||||
mtree.get_path(NodeIndex::make(mtree.depth(), 2)).unwrap(),
|
||||
result.path,
|
||||
"merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
|
||||
);
|
||||
|
||||
@@ -148,8 +149,8 @@ fn test_merkle_tree() -> Result<(), MerkleError> {
|
||||
"Value for merkle path at index 0 must match leaf value"
|
||||
);
|
||||
assert_eq!(
|
||||
mtree.get_path(NodeIndex::make(mtree.depth(), 3)),
|
||||
Ok(result.path),
|
||||
mtree.get_path(NodeIndex::make(mtree.depth(), 3)).unwrap(),
|
||||
result.path,
|
||||
"merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
|
||||
);
|
||||
|
||||
@@ -240,56 +241,56 @@ fn test_sparse_merkle_tree() -> Result<(), MerkleError> {
     // STORE LEAVES ARE CORRECT ==============================================================
     // checks the leaves in the store corresponds to the expected values
     assert_eq!(
-        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 0)),
-        Ok(VALUES4[0]),
+        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 0)).unwrap(),
+        VALUES4[0],
         "node 0 must be in the tree"
     );
     assert_eq!(
-        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 1)),
-        Ok(VALUES4[1]),
+        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 1)).unwrap(),
+        VALUES4[1],
         "node 1 must be in the tree"
     );
     assert_eq!(
-        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 2)),
-        Ok(VALUES4[2]),
+        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 2)).unwrap(),
+        VALUES4[2],
         "node 2 must be in the tree"
     );
     assert_eq!(
-        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 3)),
-        Ok(VALUES4[3]),
+        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 3)).unwrap(),
+        VALUES4[3],
         "node 3 must be in the tree"
     );
     assert_eq!(
-        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 4)),
-        Ok(RpoDigest::default()),
+        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 4)).unwrap(),
+        RpoDigest::default(),
         "unmodified node 4 must be ZERO"
     );

     // STORE LEAVES MATCH TREE ===============================================================
     // sanity check the values returned by the store and the tree
     assert_eq!(
-        smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 0)),
-        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 0)),
+        smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 0)).unwrap(),
+        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 0)).unwrap(),
         "node 0 must be the same for both SparseMerkleTree and MerkleStore"
     );
     assert_eq!(
-        smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 1)),
-        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 1)),
+        smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 1)).unwrap(),
+        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 1)).unwrap(),
         "node 1 must be the same for both SparseMerkleTree and MerkleStore"
     );
     assert_eq!(
-        smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 2)),
-        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 2)),
+        smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 2)).unwrap(),
+        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 2)).unwrap(),
         "node 2 must be the same for both SparseMerkleTree and MerkleStore"
     );
     assert_eq!(
-        smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 3)),
-        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 3)),
+        smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 3)).unwrap(),
+        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 3)).unwrap(),
         "node 3 must be the same for both SparseMerkleTree and MerkleStore"
     );
     assert_eq!(
-        smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 4)),
-        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 4)),
+        smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 4)).unwrap(),
+        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 4)).unwrap(),
         "node 4 must be the same for both SparseMerkleTree and MerkleStore"
     );

@@ -385,46 +386,46 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> {
     // STORE LEAVES ARE CORRECT ==============================================================
     // checks the leaves in the store corresponds to the expected values
    assert_eq!(
-        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 0)),
-        Ok(VALUES4[0]),
+        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 0)).unwrap(),
+        VALUES4[0],
         "node 0 must be in the pmt"
     );
     assert_eq!(
-        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 1)),
-        Ok(VALUES4[1]),
+        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 1)).unwrap(),
+        VALUES4[1],
         "node 1 must be in the pmt"
     );
     assert_eq!(
-        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 2)),
-        Ok(VALUES4[2]),
+        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 2)).unwrap(),
+        VALUES4[2],
         "node 2 must be in the pmt"
     );
     assert_eq!(
-        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 3)),
-        Ok(VALUES4[3]),
+        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 3)).unwrap(),
+        VALUES4[3],
         "node 3 must be in the pmt"
     );

     // STORE LEAVES MATCH PMT ================================================================
     // sanity check the values returned by the store and the pmt
     assert_eq!(
-        pmt.get_node(NodeIndex::make(pmt.max_depth(), 0)),
-        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 0)),
+        pmt.get_node(NodeIndex::make(pmt.max_depth(), 0)).unwrap(),
+        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 0)).unwrap(),
         "node 0 must be the same for both PartialMerkleTree and MerkleStore"
     );
     assert_eq!(
-        pmt.get_node(NodeIndex::make(pmt.max_depth(), 1)),
-        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 1)),
+        pmt.get_node(NodeIndex::make(pmt.max_depth(), 1)).unwrap(),
+        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 1)).unwrap(),
         "node 1 must be the same for both PartialMerkleTree and MerkleStore"
     );
     assert_eq!(
-        pmt.get_node(NodeIndex::make(pmt.max_depth(), 2)),
-        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 2)),
+        pmt.get_node(NodeIndex::make(pmt.max_depth(), 2)).unwrap(),
+        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 2)).unwrap(),
         "node 2 must be the same for both PartialMerkleTree and MerkleStore"
     );
     assert_eq!(
-        pmt.get_node(NodeIndex::make(pmt.max_depth(), 3)),
-        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 3)),
+        pmt.get_node(NodeIndex::make(pmt.max_depth(), 3)).unwrap(),
+        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 3)).unwrap(),
         "node 3 must be the same for both PartialMerkleTree and MerkleStore"
     );

@@ -436,8 +437,8 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> {
         "Value for merkle path at index 0 must match leaf value"
     );
     assert_eq!(
-        pmt.get_path(NodeIndex::make(pmt.max_depth(), 0)),
-        Ok(result.path),
+        pmt.get_path(NodeIndex::make(pmt.max_depth(), 0)).unwrap(),
+        result.path,
         "merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
     );

@@ -447,8 +448,8 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> {
         "Value for merkle path at index 0 must match leaf value"
     );
     assert_eq!(
-        pmt.get_path(NodeIndex::make(pmt.max_depth(), 1)),
-        Ok(result.path),
+        pmt.get_path(NodeIndex::make(pmt.max_depth(), 1)).unwrap(),
+        result.path,
         "merkle path for index 1 must be the same for the MerkleTree and MerkleStore"
     );

@@ -458,8 +459,8 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> {
         "Value for merkle path at index 0 must match leaf value"
     );
     assert_eq!(
-        pmt.get_path(NodeIndex::make(pmt.max_depth(), 2)),
-        Ok(result.path),
+        pmt.get_path(NodeIndex::make(pmt.max_depth(), 2)).unwrap(),
+        result.path,
         "merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
     );

@@ -469,8 +470,8 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> {
         "Value for merkle path at index 0 must match leaf value"
     );
     assert_eq!(
-        pmt.get_path(NodeIndex::make(pmt.max_depth(), 3)),
-        Ok(result.path),
+        pmt.get_path(NodeIndex::make(pmt.max_depth(), 3)).unwrap(),
+        result.path,
         "merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
     );

@@ -498,7 +499,7 @@ fn wont_open_to_different_depth_root() {
     let store = MerkleStore::from(&mtree);
     let index = NodeIndex::root();
     let err = store.get_node(root, index).err().unwrap();
-    assert_eq!(err, MerkleError::RootNotInStore(root));
+    assert_matches!(err, MerkleError::RootNotInStore(err_root) if err_root == root);
 }

 #[test]
@@ -537,7 +538,7 @@ fn test_set_node() -> Result<(), MerkleError> {
     let value = int_to_node(42);
     let index = NodeIndex::make(mtree.depth(), 0);
     let new_root = store.set_node(mtree.root(), index, value)?.root;
-    assert_eq!(store.get_node(new_root, index), Ok(value), "Value must have changed");
+    assert_eq!(store.get_node(new_root, index).unwrap(), value, "value must have changed");

     Ok(())
 }
@@ -613,7 +614,7 @@ fn node_path_should_be_truncated_by_midtier_insert() {
     let path = store.get_path(root, index).unwrap().path;
     assert_eq!(node, result);
     assert_eq!(path.depth(), depth);
-    assert!(path.verify(index.value(), result, &root));
+    assert!(path.verify(index.value(), result, &root).is_ok());

     // flip the first bit of the key and insert the second node on a different depth
     let key = key ^ (1 << 63);
@@ -626,7 +627,7 @@ fn node_path_should_be_truncated_by_midtier_insert() {
     let path = store.get_path(root, index).unwrap().path;
     assert_eq!(node, result);
     assert_eq!(path.depth(), depth);
-    assert!(path.verify(index.value(), result, &root));
+    assert!(path.verify(index.value(), result, &root).is_ok());

     // attempt to fetch a path of the second node to depth 64
     // should fail because the previously inserted node will remove its sub-tree from the set
@@ -745,7 +746,7 @@ fn get_leaf_depth_works_with_depth_8() {
     // duplicate the tree on `a` and assert the depth is short-circuited by such sub-tree
     let index = NodeIndex::new(8, a).unwrap();
     root = store.set_node(root, index, root).unwrap().root;
-    assert_eq!(Err(MerkleError::DepthTooBig(9)), store.get_leaf_depth(root, 8, a));
+    assert_matches!(store.get_leaf_depth(root, 8, a).unwrap_err(), MerkleError::DepthTooBig(9));
 }

 #[test]

src/rand/mod.rs
@@ -1,17 +1,20 @@
 //! Pseudo-random element generation.

+use rand::RngCore;
 pub use winter_crypto::{DefaultRandomCoin as WinterRandomCoin, RandomCoin, RandomCoinError};
 pub use winter_utils::Randomizable;

 use crate::{Felt, FieldElement, Word, ZERO};

 mod rpo;
+mod rpx;
 pub use rpo::RpoRandomCoin;
+pub use rpx::RpxRandomCoin;

 /// Pseudo-random element generator.
 ///
 /// An instance can be used to draw, uniformly at random, base field elements as well as [Word]s.
-pub trait FeltRng {
+pub trait FeltRng: RngCore {
     /// Draw, uniformly at random, a base field element.
     fn draw_element(&mut self) -> Felt;

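Since `FeltRng` now has `RngCore` as a supertrait, any implementor can also be driven through the generic `rand` APIs. A minimal sketch of what this enables, assuming the crate-root re-exports used elsewhere in this diff (`Rng::gen` comes from the `rand` crate; the all-zero seed is purely illustrative):

use miden_crypto::{
    rand::{FeltRng, RpoRandomCoin},
    Word, ZERO,
};
use rand::Rng;

// draws a word via FeltRng and a raw integer via the blanket
// rand::Rng extension that the RngCore bound enables
fn sample(coin: &mut impl FeltRng) -> (Word, u32) {
    let word: Word = coin.draw_word();
    let x: u32 = coin.gen();
    (word, x)
}

fn main() {
    let mut coin = RpoRandomCoin::new([ZERO; 4]);
    let (word, x) = sample(&mut coin);
    println!("{word:?} {x}");
}
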
src/rand/rpo.rs
@@ -1,10 +1,11 @@
-use super::{Felt, FeltRng, FieldElement, RandomCoin, RandomCoinError, Word, ZERO};
+use alloc::{string::ToString, vec::Vec};
+
+use rand_core::impls;
+
+use super::{Felt, FeltRng, FieldElement, RandomCoin, RandomCoinError, RngCore, Word, ZERO};
 use crate::{
     hash::rpo::{Rpo256, RpoDigest},
-    utils::{
-        collections::*, string::*, vec, ByteReader, ByteWriter, Deserializable,
-        DeserializationError, Serializable,
-    },
+    utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
 };

 // CONSTANTS
@@ -21,8 +22,8 @@ const HALF_RATE_WIDTH: usize = (Rpo256::RATE_RANGE.end - Rpo256::RATE_RANGE.star
 /// described in <https://eprint.iacr.org/2011/499.pdf>.
 ///
 /// The simplification is related to the following facts:
-/// 1. A call to the reseed method implies one and only one call to the permutation function.
-///    This is possible because in our case we never reseed with more than 4 field elements.
+/// 1. A call to the reseed method implies one and only one call to the permutation function. This
+///    is possible because in our case we never reseed with more than 4 field elements.
 /// 2. As a result of the previous point, we don't make use of an input buffer to accumulate seed
 ///    material.
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -63,6 +64,11 @@ impl RpoRandomCoin {
         (self.state, self.current)
     }

+    /// Fills `dest` with random data.
+    pub fn fill_bytes(&mut self, dest: &mut [u8]) {
+        <Self as RngCore>::fill_bytes(self, dest)
+    }
+
     fn draw_basefield(&mut self) -> Felt {
         if self.current == RATE_END {
             Rpo256::apply_permutation(&mut self.state);
@@ -139,8 +145,10 @@ impl RandomCoin for RpoRandomCoin {
         self.state[RATE_START] += nonce;
         Rpo256::apply_permutation(&mut self.state);

-        // reset the buffer
-        self.current = RATE_START;
+        // reset the buffer and move the next random element pointer to the second rate element.
+        // this is done as the first rate element will be "biased" via the provided `nonce` to
+        // contain some number of leading zeros.
+        self.current = RATE_START + 1;

         // determine how many bits are needed to represent valid values in the domain
         let v_mask = (domain_size - 1) as u64;
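For context on the `v_mask` line above: because `domain_size` is asserted to be a power of two, `domain_size - 1` is an all-ones bit mask, so masking a random 64-bit draw reduces it into `0..domain_size` without any modular reduction. A standalone sketch (the function name is illustrative):

// maps a random 64-bit value into 0..domain_size, assuming the
// domain size is a power of two so that `domain_size - 1` is a mask
fn mask_into_domain(value: u64, domain_size: usize) -> usize {
    assert!(domain_size.is_power_of_two(), "domain size must be a power of two");
    let v_mask = (domain_size - 1) as u64;
    (value & v_mask) as usize
}

fn main() {
    // only the low 8 bits of the draw survive for a domain of size 256
    assert_eq!(mask_into_domain(0xdead_beef, 1 << 8), 0xef);
}
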
@@ -166,6 +174,36 @@ impl RandomCoin for RpoRandomCoin {

         Ok(values)
     }
+
+    fn reseed_with_salt(
+        &mut self,
+        data: <Self::Hasher as winter_crypto::Hasher>::Digest,
+        salt: Option<<Self::Hasher as winter_crypto::Hasher>::Digest>,
+    ) {
+        // Reset buffer
+        self.current = RATE_START;
+
+        // Add the new seed material to the first half of the rate portion of the RPO state
+        let data: Word = data.into();
+
+        self.state[RATE_START] += data[0];
+        self.state[RATE_START + 1] += data[1];
+        self.state[RATE_START + 2] += data[2];
+        self.state[RATE_START + 3] += data[3];
+
+        if let Some(salt) = salt {
+            // Add the salt to the second half of the rate portion of the RPO state
+            let data: Word = salt.into();
+
+            self.state[RATE_START + 4] += data[0];
+            self.state[RATE_START + 5] += data[1];
+            self.state[RATE_START + 6] += data[2];
+            self.state[RATE_START + 7] += data[3];
+        }
+
+        // Absorb
+        Rpo256::apply_permutation(&mut self.state);
+    }
 }

 // FELT RNG IMPLEMENTATION
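A hedged usage sketch of the new salted reseed: `data` lands in the first half of the rate, `salt` in the second half, and the permutation is still applied exactly once, preserving the one-permutation-per-reseed property described in the type's docs. The digests and import paths below are illustrative, and this assumes the extended `RandomCoin` trait from this diff is in scope:

use miden_crypto::{
    hash::rpo::Rpo256,
    rand::{RandomCoin, RpoRandomCoin},
    ZERO,
};
use winter_crypto::Hasher;

fn main() {
    let mut coin = RpoRandomCoin::new([ZERO; 4]);
    // stand-ins for real transcript material
    let data = Rpo256::hash(b"transcript data");
    let salt = Rpo256::hash(b"per-proof salt");
    // absorbs data into rate[0..4] and salt into rate[4..8], then permutes once
    coin.reseed_with_salt(data, Some(salt));
}
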
@@ -185,6 +223,28 @@ impl FeltRng for RpoRandomCoin {
     }
 }

+// RNGCORE IMPLEMENTATION
+// ------------------------------------------------------------------------------------------------
+
+impl RngCore for RpoRandomCoin {
+    fn next_u32(&mut self) -> u32 {
+        self.draw_basefield().as_int() as u32
+    }
+
+    fn next_u64(&mut self) -> u64 {
+        impls::next_u64_via_u32(self)
+    }
+
+    fn fill_bytes(&mut self, dest: &mut [u8]) {
+        impls::fill_bytes_via_next(self, dest)
+    }
+
+    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
+        self.fill_bytes(dest);
+        Ok(())
+    }
+}
+
 // SERIALIZATION
 // ------------------------------------------------------------------------------------------------

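Note that `next_u32` keeps only the low 32 bits of the drawn field element, and `impls::next_u64_via_u32` stitches two such draws together in little-endian order, so each `next_u64` consumes two field elements. A behavioural sketch of the rand_core helper (equivalent logic, not its exact source):

// behaviourally equivalent to rand_core::impls::next_u64_via_u32
fn next_u64_via_u32<R: rand_core::RngCore>(rng: &mut R) -> u64 {
    let low = rng.next_u32() as u64; // low word is drawn first
    let high = rng.next_u32() as u64;
    (high << 32) | low
}

fn main() {
    // StepRng yields 1, then 2, so the composed u64 is (2 << 32) | 1
    let mut rng = rand::rngs::mock::StepRng::new(1, 1);
    assert_eq!(next_u64_via_u32(&mut rng), (2u64 << 32) | 1);
}
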
324
src/rand/rpx.rs
Normal file
@@ -0,0 +1,324 @@
use alloc::{string::ToString, vec::Vec};

use rand_core::impls;

use super::{Felt, FeltRng, FieldElement, RandomCoin, RandomCoinError, RngCore, Word, ZERO};
use crate::{
    hash::rpx::{Rpx256, RpxDigest},
    utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
};

// CONSTANTS
// ================================================================================================

const STATE_WIDTH: usize = Rpx256::STATE_WIDTH;
const RATE_START: usize = Rpx256::RATE_RANGE.start;
const RATE_END: usize = Rpx256::RATE_RANGE.end;
const HALF_RATE_WIDTH: usize = (Rpx256::RATE_RANGE.end - Rpx256::RATE_RANGE.start) / 2;

// RPX RANDOM COIN
// ================================================================================================
/// A simplified version of the `SPONGE_PRG` reseedable pseudo-random number generator algorithm
/// described in <https://eprint.iacr.org/2011/499.pdf>.
///
/// The simplification is related to the following facts:
/// 1. A call to the reseed method implies one and only one call to the permutation function. This
///    is possible because in our case we never reseed with more than 4 field elements.
/// 2. As a result of the previous point, we don't make use of an input buffer to accumulate seed
///    material.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RpxRandomCoin {
    state: [Felt; STATE_WIDTH],
    current: usize,
}

impl RpxRandomCoin {
    /// Returns a new [RpxRandomCoin] initialized with the specified seed.
    pub fn new(seed: Word) -> Self {
        let mut state = [ZERO; STATE_WIDTH];

        for i in 0..HALF_RATE_WIDTH {
            state[RATE_START + i] += seed[i];
        }

        // Absorb
        Rpx256::apply_permutation(&mut state);

        RpxRandomCoin { state, current: RATE_START }
    }

    /// Returns an [RpxRandomCoin] instantiated from the provided components.
    ///
    /// # Panics
    /// Panics if `current` is smaller than 4 or greater than or equal to 12.
    pub fn from_parts(state: [Felt; STATE_WIDTH], current: usize) -> Self {
        assert!(
            (RATE_START..RATE_END).contains(&current),
            "current value outside of valid range"
        );
        Self { state, current }
    }

    /// Returns components of this random coin.
    pub fn into_parts(self) -> ([Felt; STATE_WIDTH], usize) {
        (self.state, self.current)
    }

    /// Fills `dest` with random data.
    pub fn fill_bytes(&mut self, dest: &mut [u8]) {
        <Self as RngCore>::fill_bytes(self, dest)
    }

    fn draw_basefield(&mut self) -> Felt {
        if self.current == RATE_END {
            Rpx256::apply_permutation(&mut self.state);
            self.current = RATE_START;
        }

        self.current += 1;
        self.state[self.current - 1]
    }
}

// RANDOM COIN IMPLEMENTATION
// ------------------------------------------------------------------------------------------------

impl RandomCoin for RpxRandomCoin {
    type BaseField = Felt;
    type Hasher = Rpx256;

    fn new(seed: &[Self::BaseField]) -> Self {
        let digest: Word = Rpx256::hash_elements(seed).into();
        Self::new(digest)
    }

    fn reseed(&mut self, data: RpxDigest) {
        // Reset buffer
        self.current = RATE_START;

        // Add the new seed material to the first half of the rate portion of the RPX state
        let data: Word = data.into();

        self.state[RATE_START] += data[0];
        self.state[RATE_START + 1] += data[1];
        self.state[RATE_START + 2] += data[2];
        self.state[RATE_START + 3] += data[3];

        // Absorb
        Rpx256::apply_permutation(&mut self.state);
    }

    fn check_leading_zeros(&self, value: u64) -> u32 {
        let value = Felt::new(value);
        let mut state_tmp = self.state;

        state_tmp[RATE_START] += value;

        Rpx256::apply_permutation(&mut state_tmp);

        let first_rate_element = state_tmp[RATE_START].as_int();
        first_rate_element.trailing_zeros()
    }

    fn draw<E: FieldElement<BaseField = Felt>>(&mut self) -> Result<E, RandomCoinError> {
        let ext_degree = E::EXTENSION_DEGREE;
        let mut result = vec![ZERO; ext_degree];
        for r in result.iter_mut().take(ext_degree) {
            *r = self.draw_basefield();
        }

        let result = E::slice_from_base_elements(&result);
        Ok(result[0])
    }

    fn draw_integers(
        &mut self,
        num_values: usize,
        domain_size: usize,
        nonce: u64,
    ) -> Result<Vec<usize>, RandomCoinError> {
        assert!(domain_size.is_power_of_two(), "domain size must be a power of two");
        assert!(num_values < domain_size, "number of values must be smaller than domain size");

        // absorb the nonce
        let nonce = Felt::new(nonce);
        self.state[RATE_START] += nonce;
        Rpx256::apply_permutation(&mut self.state);

        // reset the buffer
        self.current = RATE_START;

        // determine how many bits are needed to represent valid values in the domain
        let v_mask = (domain_size - 1) as u64;

        // draw values from PRNG until we get as many unique values as specified by num_queries
        let mut values = Vec::new();
        for _ in 0..1000 {
            // get the next pseudo-random field element
            let value = self.draw_basefield().as_int();

            // use the mask to get a value within the range
            let value = (value & v_mask) as usize;

            values.push(value);
            if values.len() == num_values {
                break;
            }
        }

        if values.len() < num_values {
            return Err(RandomCoinError::FailedToDrawIntegers(num_values, values.len(), 1000));
        }

        Ok(values)
    }

    fn reseed_with_salt(
        &mut self,
        data: <Self::Hasher as winter_crypto::Hasher>::Digest,
        salt: Option<<Self::Hasher as winter_crypto::Hasher>::Digest>,
    ) {
        // Reset buffer
        self.current = RATE_START;

        // Add the new seed material to the first half of the rate portion of the RPX state
        let data: Word = data.into();

        self.state[RATE_START] += data[0];
        self.state[RATE_START + 1] += data[1];
        self.state[RATE_START + 2] += data[2];
        self.state[RATE_START + 3] += data[3];

        if let Some(salt) = salt {
            // Add the salt to the second half of the rate portion of the RPX state
            let data: Word = salt.into();

            self.state[RATE_START + 4] += data[0];
            self.state[RATE_START + 5] += data[1];
            self.state[RATE_START + 6] += data[2];
            self.state[RATE_START + 7] += data[3];
        }

        // Absorb
        Rpx256::apply_permutation(&mut self.state);
    }
}

// FELT RNG IMPLEMENTATION
// ------------------------------------------------------------------------------------------------

impl FeltRng for RpxRandomCoin {
    fn draw_element(&mut self) -> Felt {
        self.draw_basefield()
    }

    fn draw_word(&mut self) -> Word {
        let mut output = [ZERO; 4];
        for o in output.iter_mut() {
            *o = self.draw_basefield();
        }
        output
    }
}

// RNGCORE IMPLEMENTATION
// ------------------------------------------------------------------------------------------------

impl RngCore for RpxRandomCoin {
    fn next_u32(&mut self) -> u32 {
        self.draw_basefield().as_int() as u32
    }

    fn next_u64(&mut self) -> u64 {
        impls::next_u64_via_u32(self)
    }

    fn fill_bytes(&mut self, dest: &mut [u8]) {
        impls::fill_bytes_via_next(self, dest)
    }

    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
        self.fill_bytes(dest);
        Ok(())
    }
}

// SERIALIZATION
// ------------------------------------------------------------------------------------------------

impl Serializable for RpxRandomCoin {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        self.state.iter().for_each(|v| v.write_into(target));
        // casting to u8 is OK because `current` is always between 4 and 12.
        target.write_u8(self.current as u8);
    }
}

impl Deserializable for RpxRandomCoin {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let state = [
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
        ];
        let current = source.read_u8()? as usize;
        if !(RATE_START..RATE_END).contains(&current) {
            return Err(DeserializationError::InvalidValue(
                "current value outside of valid range".to_string(),
            ));
        }
        Ok(Self { state, current })
    }
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::{Deserializable, FeltRng, RpxRandomCoin, Serializable, ZERO};
    use crate::ONE;

    #[test]
    fn test_feltrng_felt() {
        let mut rpxcoin = RpxRandomCoin::new([ZERO; 4]);
        let output = rpxcoin.draw_element();

        let mut rpxcoin = RpxRandomCoin::new([ZERO; 4]);
        let expected = rpxcoin.draw_basefield();

        assert_eq!(output, expected);
    }

    #[test]
    fn test_feltrng_word() {
        let mut rpxcoin = RpxRandomCoin::new([ZERO; 4]);
        let output = rpxcoin.draw_word();

        let mut rpxcoin = RpxRandomCoin::new([ZERO; 4]);
        let mut expected = [ZERO; 4];
        for o in expected.iter_mut() {
            *o = rpxcoin.draw_basefield();
        }

        assert_eq!(output, expected);
    }

    #[test]
    fn test_feltrng_serialization() {
        let coin1 = RpxRandomCoin::from_parts([ONE; 12], 5);

        let bytes = coin1.to_bytes();
        let coin2 = RpxRandomCoin::read_from_bytes(&bytes).unwrap();
        assert_eq!(coin1, coin2);
    }
}

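Taken together, the new coin mirrors `RpoRandomCoin` one-for-one, just over the RPX permutation. A quick usage sketch (import paths assumed to mirror the RPO ones; the zero seed is illustrative):

use miden_crypto::{
    rand::{FeltRng, RpxRandomCoin},
    ZERO,
};

fn main() {
    let mut coin = RpxRandomCoin::new([ZERO; 4]);
    let felt = coin.draw_element(); // one field element
    let word = coin.draw_word();    // four field elements
    println!("{felt} {word:?}");
}
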
src/utils/kv_map.rs
@@ -1,9 +1,8 @@
-use core::cell::RefCell;
-
-use super::{
-    boxed::*,
-    collections::{btree_map::*, *},
+use alloc::{
+    boxed::Box,
+    collections::{BTreeMap, BTreeSet},
 };
+use core::cell::RefCell;

 // KEY-VALUE MAP TRAIT
 // ================================================================================================
@@ -127,11 +126,10 @@ impl<K: Ord + Clone, V: Clone> KvMap<K, V> for RecordingMap<K, V> {
     ///
     /// If the key is part of the initial data set, the key access is recorded.
     fn get(&self, key: &K) -> Option<&V> {
-        self.data.get(key).map(|value| {
+        self.data.get(key).inspect(|&value| {
             if !self.updates.contains(key) {
                 self.trace.borrow_mut().insert(key.clone(), value.clone());
             }
-            value
         })
     }

@@ -156,11 +154,10 @@ impl<K: Ord + Clone, V: Clone> KvMap<K, V> for RecordingMap<K, V> {
     /// returned.
     fn insert(&mut self, key: K, value: V) -> Option<V> {
         let new_update = self.updates.insert(key.clone());
-        self.data.insert(key.clone(), value).map(|old_value| {
+        self.data.insert(key.clone(), value).inspect(|old_value| {
             if new_update {
                 self.trace.borrow_mut().insert(key, old_value.clone());
             }
-            old_value
         })
     }

@@ -168,12 +165,11 @@ impl<K: Ord + Clone, V: Clone> KvMap<K, V> for RecordingMap<K, V> {
     ///
     /// If the key exists in the data set, the old value is returned.
     fn remove(&mut self, key: &K) -> Option<V> {
-        self.data.remove(key).map(|old_value| {
+        self.data.remove(key).inspect(|old_value| {
             let new_update = self.updates.insert(key.clone());
             if new_update {
                 self.trace.borrow_mut().insert(key.clone(), old_value.clone());
             }
-            old_value
         })
     }

@@ -202,7 +198,7 @@ impl<K: Clone + Ord, V: Clone> FromIterator<(K, V)> for RecordingMap<K, V> {

 impl<K: Clone + Ord, V: Clone> IntoIterator for RecordingMap<K, V> {
     type Item = (K, V);
-    type IntoIter = IntoIter<K, V>;
+    type IntoIter = alloc::collections::btree_map::IntoIter<K, V>;

     fn into_iter(self) -> Self::IntoIter {
         self.data.into_iter()
@@ -329,7 +325,8 @@ mod tests {
         let mut map = RecordingMap::new(ITEMS.to_vec());
         assert!(map.iter().all(|(x, y)| ITEMS.contains(&(*x, *y))));

-        // when inserting entry with key that already exists the iterator should return the new value
+        // when inserting entry with key that already exists the iterator should return the new
+        // value
         let new_value = 5;
         map.insert(4, new_value);
         assert_eq!(map.iter().count(), ITEMS.len());

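The `map` → `inspect` rewrites above work because `Option::inspect` (stable since Rust 1.76) runs a closure on the contained value purely for its side effect and returns the original `Option` unchanged, which is why the trailing `value`/`old_value` expressions could be dropped. In miniature:

fn main() {
    let data = Some(42);

    // with `map`, the closure must re-return the value to keep it
    let kept = data.map(|v| {
        println!("accessed {v}");
        v
    });

    // with `inspect`, the side effect is expressed directly
    let same = data.inspect(|v| println!("accessed {v}"));

    assert_eq!(kept, same);
}
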
src/utils/mod.rs
@@ -1,28 +1,23 @@
 //! Utilities used in this crate which can also be generally useful downstream.

-use core::fmt::{self, Display, Write};
+use alloc::string::String;
+use core::fmt::{self, Write};

-#[cfg(feature = "std")]
-pub use std::{format, vec};
-
-#[cfg(not(feature = "std"))]
-pub use alloc::{format, vec};
+use thiserror::Error;

 use super::Word;
-use crate::utils::string::*;

 mod kv_map;

 // RE-EXPORTS
 // ================================================================================================

 pub use winter_utils::{
-    boxed, string, uninit_vector, Box, ByteReader, ByteWriter, Deserializable,
-    DeserializationError, Serializable, SliceReader,
+    uninit_vector, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
+    SliceReader,
 };

 pub mod collections {
-    pub use winter_utils::collections::*;
-
     pub use super::kv_map::*;
 }

@@ -53,36 +48,20 @@ pub fn bytes_to_hex_string<const N: usize>(data: [u8; N]) -> String {
 }

 /// Defines errors which can occur during parsing of hexadecimal strings.
-#[derive(Debug)]
+#[derive(Debug, Error)]
 pub enum HexParseError {
+    #[error(
+        "expected hex data to have length {expected}, including the 0x prefix, found {actual}"
+    )]
     InvalidLength { expected: usize, actual: usize },
+    #[error("hex encoded data must start with 0x prefix")]
     MissingPrefix,
+    #[error("hex encoded data must contain only characters [a-zA-Z0-9]")]
     InvalidChar,
+    #[error("hex encoded values of a Digest must be inside the field modulus")]
     OutOfRange,
 }

-impl Display for HexParseError {
-    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
-        match self {
-            HexParseError::InvalidLength { expected, actual } => {
-                write!(f, "Hex encoded RpoDigest must have length 66, including the 0x prefix. expected {expected} got {actual}")
-            }
-            HexParseError::MissingPrefix => {
-                write!(f, "Hex encoded RpoDigest must start with 0x prefix")
-            }
-            HexParseError::InvalidChar => {
-                write!(f, "Hex encoded RpoDigest must contain characters [a-zA-Z0-9]")
-            }
-            HexParseError::OutOfRange => {
-                write!(f, "Hex encoded values of an RpoDigest must be inside the field modulus")
-            }
-        }
-    }
-}
-
-#[cfg(feature = "std")]
-impl std::error::Error for HexParseError {}
-
 /// Parses a hex string into an array of bytes of known size.
 pub fn hex_to_bytes<const N: usize>(value: &str) -> Result<[u8; N], HexParseError> {
     let expected: usize = (N * 2) + 2;
@@ -102,12 +81,11 @@ pub fn hex_to_bytes<const N: usize>(value: &str) -> Result<[u8; N], HexParseErro
     });

     let mut decoded = [0u8; N];
-    #[allow(clippy::needless_range_loop)]
-    for pos in 0..N {
+    for byte in decoded.iter_mut() {
         // These `unwrap` calls are okay because the length was checked above
         let high: u8 = data.next().unwrap()?;
         let low: u8 = data.next().unwrap()?;
-        decoded[pos] = (high << 4) + low;
+        *byte = (high << 4) + low;
     }

     Ok(decoded)

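For reference, a hedged sketch of driving the reworked `hex_to_bytes` (the module path is assumed from this diff; the const generic `N` fixes the output length, so the input must be exactly `2 + 2 * N` characters including the `0x` prefix):

use miden_crypto::utils::{hex_to_bytes, HexParseError};

fn main() -> Result<(), HexParseError> {
    // N = 4 bytes => the string must be 2 + 2 * 4 = 10 characters long
    let bytes: [u8; 4] = hex_to_bytes("0xdeadbeef")?;
    assert_eq!(bytes, [0xde, 0xad, 0xbe, 0xef]);
    Ok(())
}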